Source code for utils.parallel.framework

from functools import partial
from multiprocessing import Pool

from ignite.utils import setup_logger


class Parallelization:
    """Parallelization framework for running functions in parallel.

    It is a simple framework built on :class:`multiprocessing.Pool`. You need
    to prepare the inputs and the function to run.

    The inputs are stored *column-wise*: ``self.inputs`` is a list holding one
    inner list per positional argument of the function. The i-th elements of
    all inner lists together form the arguments of the i-th call, because the
    inner lists are combined with ``zip(*self.inputs)`` before being handed to
    ``Pool.starmap``.

    Arguments:
        max_workers (int): maximum number of parallel workers (default: ``8``)

    Example:

    .. code-block:: python

        # Import the Parallelization class.
        from utils.parallel.framework import Parallelization

        # Create a Parallelization object and prepare the inputs.
        p_job = Parallelization(max_workers=4)
        p_job.prepare(test=True)

        # Print out the inputs in parallel.
        p_job.run(func=p_job.test_func)
    """

    def __init__(self, max_workers=8):
        self.max_workers = max_workers
        # Initialised here so that `run()` can report a clear error when
        # `prepare()` / `set_function()` have not been called yet, instead of
        # failing with an AttributeError.
        self.inputs = []
        self.func = None

    def prepare(self, test=False):
        """Reset the inputs, or load the built-in demo inputs if ``test``."""
        if test:
            self.test_prepare()
        else:
            self.inputs = []

    def set_function(self, func, **kwargs):
        """Store ``func`` with ``kwargs`` pre-bound as the function to run."""
        self.func = partial(func, **kwargs)

    def run(self, func=None):
        """Run the function once per prepared set of arguments.

        Args:
            func: callable to run. If ``None``, the callable registered via
                :meth:`set_function` is used.

        Returns:
            list: the return values of the calls, in input order (the result
            of ``Pool.starmap``).

        Raises:
            AssertionError: if no inputs were prepared or no function is
                available.
        """
        # Raised explicitly (rather than via `assert`) so the checks are not
        # stripped under `python -O`; the exception type is unchanged.
        if not self.inputs:
            raise AssertionError("No inputs to run the function.")
        if func is None:
            func = self.func
        if not func:
            raise AssertionError("No function to run.")
        with Pool(self.max_workers) as pool:
            # zip(*inputs) turns the per-argument columns into per-call tuples.
            return pool.starmap(func, zip(*self.inputs))

    def test_run(self):
        """Run the demo function on the currently prepared inputs."""
        return self.run(self.test_func)

    def test_prepare(self):
        """Load three demo argument columns (three calls, three arguments)."""
        a = [1, 2, 3]
        b = ["a", "b", "c"]
        c = [99, 98, 97]
        self.inputs = [a, b, c]

    def test_func(self, a, b, c):
        """Demo function: print the three arguments it received."""
        print(f"{a} {b} {c} done\n")
class TrajHandlerPreprocess(Parallelization):
    """Parallelization framework for preprocessing trajectory data.

    You will need to prepare the inputs and the function to run.

    Arguments:
        max_workers (int): maximum number of parallel workers (default: ``8``)
        logger_path (str): path to save the logger file (default: ``None``).
            If ``None``, no logger will be created.
        task_name (str): name of the task for the logger
            (default: ``"core-dataprep"``)

    Example:

    .. code-block:: python

        # Import the TrajHandlerPreprocess class.
        from utils.parallel.framework import TrajHandlerPreprocess

        # Create a TrajHandlerPreprocess object and prepare the inputs.
        p_job = TrajHandlerPreprocess(max_workers=4, logger_path="preprocess.log")
        p_job.prepare(
            traj_handler=traj_handler,
            root_path="output/p_data",
            filename="traj_data",
            frames_list=[0, 1, 2, 3, 4],
        )

        # Run the preprocessing in parallel.
        p_job.run(func=preprocess_workflow)

    .. note::
        The source code of the ``preprocess_workflow`` function can be found in
        :func:`utils.datasets.general.preprocess_workflow`.
    """

    def __init__(
        self, max_workers=8, logger_path: str = None, task_name: str = "core-dataprep"
    ):
        super().__init__(max_workers)
        # A logger is only created when a log file path is given.
        self.logger = (
            setup_logger(name=task_name, filepath=logger_path) if logger_path else None
        )

    @staticmethod
    def _pick_option(kwargs, key, config, attr):
        """Return ``kwargs[key]`` if truthy, else ``config.<attr>``, else ``None``."""
        value = kwargs.get(key)
        if value:
            return value
        return getattr(config, attr, None)

    def prepare(
        self,
        traj_handler,
        config=None,
        **kwargs,
    ):
        """Prepare the inputs for the preprocess workflow.

        Fills ``self.inputs_dict`` with one list per argument (one entry per
        frame) and sets ``self.inputs`` to those lists in the order:
        trajectory handler, pdb path, ply path, h5 path, frame, index path.

        Args:
            traj_handler: trajectory handler object; the number of frames is
                read from ``len(traj_handler.universe.trajectory)``
                (presumably it wraps an MDAnalysis Universe).
            config: configuration object. If ``None``, the ``root_path`` and
                ``filename`` are required in the kwargs (optional keyword
                arguments).
            root_path (str, optional keyword argument): root path for the
                output files. If ``None``, the root path is required in the
                config file (``output_p_data_folderpath``).
            filename (str, optional keyword argument): filename for the output
                files. If ``None``, the filename is required in the config
                file (``p_filename``).
            frames_list (list, optional keyword argument): list of frames to
                process (e.g. ``[1, 2, 3]``). If ``None``, all frames will be
                processed.
            index_path (str, optional keyword argument): path to the index
                file. If not provided, the one found in the config file
                (``output_index_path``) will be used. Otherwise, it will be
                generated by ``[root_path]/[filename]_index.txt``.

        Raises:
            AssertionError: if ``root_path`` or ``filename`` cannot be
                resolved, or a requested frame is beyond the trajectory.
            ValueError: if there is no frame to process.
        """
        # Validation errors are raised explicitly (not via `assert`) so they
        # are not stripped under `python -O`; the exception type is unchanged.
        root_path = self._pick_option(
            kwargs, "root_path", config, "output_p_data_folderpath"
        )
        if root_path is None:
            raise AssertionError(
                "`root_path` or `config.output_p_data_folderpath` is required "
                "(not `None`)."
            )

        filename = self._pick_option(kwargs, "filename", config, "p_filename")
        if filename is None:
            raise AssertionError(
                "`filename` or `config.p_filename` is required (not `None`)."
            )

        frames_list = self._pick_option(kwargs, "frames_list", config, "frames_list")

        # Explicit index path first, then the configured one, then a default
        # derived from the output location.
        index_path = self._pick_option(
            kwargs, "index_path", config, "output_index_path"
        )
        if index_path is None:
            index_path = f"{root_path}/{filename}_index.txt"

        frame_length = len(traj_handler.universe.trajectory)
        # Copy the selection so later changes to the caller's list cannot
        # desynchronise the per-frame input columns built below.
        frames = (
            list(range(frame_length)) if frames_list is None else list(frames_list)
        )
        if not frames:
            raise ValueError("No frames to process: the frame selection is empty.")
        if max(frames) >= frame_length:
            raise AssertionError(
                f"The provided frames list is out of range ({max(frames)} > "
                f"{frame_length - 1})."
            )

        n_frames = len(frames)
        self.inputs_dict = {
            # the same handler / index path is shared by every call
            "traj_handlers": [traj_handler] * n_frames,
            "pdb_filepaths": [f"{root_path}/{filename}_{frame}.pdb" for frame in frames],
            "ply_filepaths": [f"{root_path}/{filename}_{frame}.ply" for frame in frames],
            "h5_filepaths": [f"{root_path}/{filename}_{frame}.h5" for frame in frames],
            "frames": frames,
            "index_paths": [index_path] * n_frames,
        }

        # follow the input in the order of the function
        # utils.datasets.general.preprocess_workflow
        self.inputs = [
            self.inputs_dict["traj_handlers"],
            self.inputs_dict["pdb_filepaths"],
            self.inputs_dict["ply_filepaths"],
            self.inputs_dict["h5_filepaths"],
            self.inputs_dict["frames"],
            self.inputs_dict["index_paths"],
        ]

    def set_function(self, func, **kwargs):
        """Register ``func``; pass ``logger=self.logger`` to it when a logger exists."""
        if self.logger:
            return super().set_function(func, logger=self.logger, **kwargs)
        else:
            return super().set_function(func, **kwargs)
class TrajHandlerPrediction(TrajHandlerPreprocess):
    """Parallelization framework for prediction workflow.

    You will need to prepare the inputs and the function to run.

    Arguments:
        max_workers (int): maximum number of parallel workers (default: ``8``)
        logger_path (str): path to save the logger file (default: ``None``).
            If ``None``, no logger will be created.
        task_name (str): name of the task for the logger
            (default: ``"core-predict"``)

    Example:

    .. code-block:: python

        # Import the TrajHandlerPrediction class.
        from utils.parallel.framework import TrajHandlerPrediction

        # Create a TrajHandlerPrediction object and prepare the inputs.
        p_job = TrajHandlerPrediction(max_workers=4, logger_path="predict.log")
        p_job.prepare(
            traj_handler=traj_handler,
            root_path="output/p_data",
            filename="traj_data",
            frames_list=[0, 1, 2, 3, 4],
        )

        # Setup the function.
        p_job.set_function(
            func=add_prediction_to_ply,
            model_path="model/best_model.pt",
        )

        # Run the prediction in parallel.
        p_job.run()

    .. note::
        The source code of the ``add_prediction_to_ply`` function can be found
        in :func:`utils.datasets.general.add_prediction_to_ply`.
    """

    def __init__(
        self,
        max_workers=8,
        logger_path: str = None,
        task_name: str = "core-predict",
    ):
        super().__init__(max_workers, logger_path=logger_path, task_name=task_name)

    @staticmethod
    def _swap_extension(path, old, new):
        """Replace a trailing ``old`` extension of ``path`` with ``new``."""
        if path.endswith(old):
            # Only the extension is swapped; an ``old`` substring elsewhere in
            # the directory or file name is left alone.
            return path[: -len(old)] + new
        return path.replace(old, new)

    def prepare(self, traj_handler, config=None, **kwargs):
        """Prepare the inputs for the prediction workflow.

        Accepts the same arguments as the parent ``prepare``. When
        ``root_path`` / ``filename`` are not given as keyword arguments, the
        config attributes ``input_p_data_folderpath`` / ``input_p_filename``
        are used if they are set; otherwise the parent's own lookup applies.
        The ``config`` object itself is not modified.

        After the call ``self.inputs`` holds, in order: trajectory handlers,
        ply paths, h5 paths, json paths and frames.
        """
        # Feed the prediction-specific locations in as keyword defaults rather
        # than overwriting fields on the caller's config object.
        for key, attr in (
            ("root_path", "input_p_data_folderpath"),
            ("filename", "input_p_filename"),
        ):
            if kwargs.get(key):
                continue
            configured = getattr(config, attr, None)
            if configured is not None:
                kwargs[key] = configured

        # setup default values
        super().prepare(
            traj_handler,
            config=config,
            **kwargs,
        )

        self.inputs_dict["json_filepaths"] = [
            self._swap_extension(each, ".ply", ".json")
            for each in self.inputs_dict["ply_filepaths"]
        ]
        self.inputs = [
            self.inputs_dict["traj_handlers"],  # traj_handlers
            self.inputs_dict["ply_filepaths"],  # ply_filepaths
            self.inputs_dict["h5_filepaths"],  # h5_filepaths
            self.inputs_dict["json_filepaths"],  # json_filepaths
            self.inputs_dict["frames"],  # frames
        ]
class TrajHandlerVisualization(TrajHandlerPreprocess):
    """Parallelization framework for visualization workflow.

    You will need to prepare the inputs and the function to run.

    Arguments:
        max_workers (int): maximum number of parallel workers (default: ``8``)
        logger_path (str): path to save the logger file (default: ``None``).
            If ``None``, no logger will be created.
        task_name (str): name of the task for the logger
            (default: ``"core-vis"``)

    Example:

    .. code-block:: python

        # Import the TrajHandlerVisualization class.
        from utils.parallel.framework import TrajHandlerVisualization

        # Create a TrajHandlerVisualization object and prepare the inputs.
        p_job = TrajHandlerVisualization(max_workers=4)
        p_job.prepare(
            traj_handler=traj_handler,
            root_path="output/p_data",
            filename="traj_data",
            frames_list=[0, 1, 2, 3, 4],
        )

        # Setup the function.
        p_job.set_function(func=generate_pse, pymol_path="path/to/pymol")

        # Run the visualization in parallel.
        p_job.run()

    .. note::
        The source code of the ``generate_pse`` function can be found in
        :func:`utils.pymol_scripts.vis_pdb_ply.generate_pse`.
    """

    def __init__(
        self,
        max_workers=8,
        logger_path: str = None,
        task_name: str = "core-vis",
    ):
        super().__init__(max_workers, logger_path=logger_path, task_name=task_name)

    @staticmethod
    def _swap_extension(path, old, new):
        """Replace a trailing ``old`` extension of ``path`` with ``new``."""
        if path.endswith(old):
            # Only the extension is swapped; an ``old`` substring elsewhere in
            # the directory or file name is left alone.
            return path[: -len(old)] + new
        return path.replace(old, new)

    def prepare(self, traj_handler, config=None, **kwargs):
        """Prepare the inputs for the visualization workflow.

        Accepts the same arguments as the parent ``prepare``. When
        ``root_path`` / ``filename`` are not given as keyword arguments, the
        config attributes ``input_p_data_folderpath`` / ``input_p_filename``
        are used if they are set; otherwise the parent's own lookup applies.
        The ``config`` object itself is not modified.

        After the call ``self.inputs`` holds, in order: pdb paths, ply paths
        and pse paths.
        """
        # Feed the visualization-specific locations in as keyword defaults
        # rather than overwriting fields on the caller's config object.
        for key, attr in (
            ("root_path", "input_p_data_folderpath"),
            ("filename", "input_p_filename"),
        ):
            if kwargs.get(key):
                continue
            configured = getattr(config, attr, None)
            if configured is not None:
                kwargs[key] = configured

        # setup default values
        super().prepare(
            traj_handler,
            config=config,
            **kwargs,
        )

        self.inputs_dict["pse_filepaths"] = [
            self._swap_extension(each, ".ply", ".pse")
            for each in self.inputs_dict["ply_filepaths"]
        ]
        self.inputs = [
            self.inputs_dict["pdb_filepaths"],  # pdb_filepaths
            self.inputs_dict["ply_filepaths"],  # ply_filepaths
            self.inputs_dict["pse_filepaths"],  # pse_filepaths
        ]