from functools import partial
from multiprocessing import Pool
from ignite.utils import setup_logger
class Parallelization:
    """Parallelization framework for running functions in parallel.

    It is a simple framework that uses :class:`multiprocessing.Pool` to run a
    function over many sets of arguments in parallel. You need to prepare the
    inputs and the function to run.

    The inputs are stored column-wise in ``self.inputs``: a list of lists in
    which the *i*-th inner list holds the values of the *i*-th positional
    argument for every call. All inner lists must have the same length, and
    the *k*-th call receives the *k*-th element of each inner list. For
    example, ``[[1, 2], ["a", "b"]]`` results in the calls ``func(1, "a")``
    and ``func(2, "b")``.

    Arguments:
        max_workers (int): maximum number of parallel workers (default: ``8``)

    Example:
        .. code-block:: python

            # Import the Parallelization class.
            from utils.parallel.framework import Parallelization

            # Create a Parallelization object and prepare the inputs.
            p_job = Parallelization(max_workers=4)
            p_job.prepare(test=True)

            # Print out the inputs in parallel.
            p_job.run(func=p_job.test_func)
    """

    def __init__(self, max_workers=8):
        self.max_workers = max_workers
        # Initialised up front so that ``run`` fails with its own assertion
        # messages (not an AttributeError) when ``prepare`` / ``set_function``
        # have not been called yet.
        self.inputs = []
        self.func = None

    def prepare(self, test=False):
        """Prepare the inputs.

        Args:
            test (bool): if ``True``, load the small demo inputs from
                :meth:`test_prepare`; otherwise reset the inputs to an empty
                list (subclasses override this to build real inputs).
        """
        if test:
            self.test_prepare()
        else:
            self.inputs = []

    def set_function(self, func, **kwargs):
        """Set the default function used by :meth:`run`.

        Args:
            func (callable): function called once per set of inputs.
            **kwargs: keyword arguments bound to every call through
                :func:`functools.partial`.
        """
        self.func = partial(func, **kwargs)

    def run(self, func=None):
        """Run the function over the prepared inputs in a process pool.

        Args:
            func (callable, optional): function to run. If ``None``, the one
                set by :meth:`set_function` is used. It must be picklable,
                because it is sent to the worker processes.

        Returns:
            list: the return value of each call, in input order.

        Raises:
            AssertionError: if there are no inputs, no function, or the sized
                input columns do not all have the same length.
        """
        inputs = getattr(self, "inputs", None)
        assert inputs, "No inputs to run the function."
        if func is None:
            func = getattr(self, "func", None)
        assert func, "No function to run."
        # zip() stops at the shortest column, which would silently skip calls.
        lengths = {len(column) for column in inputs if hasattr(column, "__len__")}
        assert len(lengths) <= 1, (
            f"All input lists must have the same length (got {sorted(lengths)})."
        )
        with Pool(self.max_workers) as pool:
            # zip(*inputs) transposes the argument columns into per-call tuples.
            return pool.starmap(func, zip(*inputs))

    def test_run(self):
        """Run :meth:`test_func` over the current inputs and return the results."""
        return self.run(self.test_func)

    def test_prepare(self):
        """Load three demo calls: ``(1, "a", 99)``, ``(2, "b", 98)``, ``(3, "c", 97)``."""
        a = [1, 2, 3]
        b = ["a", "b", "c"]
        c = [99, 98, 97]
        self.inputs = [a, b, c]

    def test_func(self, a, b, c):
        """Demo worker: print one set of arguments."""
        print(f"{a} {b} {c} done\n")
class TrajHandlerPreprocess(Parallelization):
    """Parallelization framework for preprocessing trajectory data.

    You will need to prepare the inputs and the function to run.

    Arguments:
        max_workers (int): maximum number of parallel workers (default: ``8``)
        logger_path (str): path to save the logger file (default: ``None``). If
            ``None``, no logger will be created.
        task_name (str): name of the task for the logger (default: ``"core-dataprep"``)

    Example:
        .. code-block:: python

            # Import the TrajHandlerPreprocess class.
            from utils.parallel.framework import TrajHandlerPreprocess

            # Create a TrajHandlerPreprocess object and prepare the inputs.
            p_job = TrajHandlerPreprocess(max_workers=4, logger_path="preprocess.log")
            p_job.prepare(
                traj_handler=traj_handler,
                root_path="output/p_data",
                filename="traj_data",
                frames_list=[0, 1, 2, 3, 4],
            )

            # Run the preprocess workflow in parallel.
            p_job.run(func=preprocess_workflow)

    .. note::
        The source code of the ``preprocess_workflow`` function can be found in
        :func:`utils.datasets.general.preprocess_workflow`.
    """

    def __init__(
        self, max_workers=8, logger_path: str = None, task_name: str = "core-dataprep"
    ):
        super().__init__(max_workers)
        # Only create a logger when a file path is given; ``set_function``
        # forwards it to the worker function as the ``logger`` keyword.
        self.logger = (
            setup_logger(name=task_name, filepath=logger_path) if logger_path else None
        )

    @staticmethod
    def _resolve_option(kwargs, key, config, attr):
        """Return ``kwargs[key]`` when truthy, else ``config.<attr>`` (``None`` if absent)."""
        if kwargs.get(key):
            return kwargs[key]
        return getattr(config, attr, None)

    def prepare(
        self,
        traj_handler,
        config=None,
        **kwargs,
    ):
        """Prepare the inputs for the preprocess workflow.

        Each option is taken from the keyword arguments when given (and
        truthy), otherwise from the matching ``config`` attribute.

        Args:
            traj_handler: MDAnalysis Universe object; ``universe.trajectory`` is
                used to count the available frames.
            config: configuration object.
                If ``None``, the ``root_path`` and ``filename`` are required in the
                kwargs (optional keyword arguments).
            root_path (str, optional keyword argument): root path for the output
                files. If not provided, ``config.output_p_data_folderpath`` is used.
            filename (str, optional keyword argument): base filename for the output
                files. If not provided, ``config.p_filename`` is used.
            frames_list (iterable of int, optional keyword argument): frames to
                process (e.g. ``[1, 2, 3]``). If not provided,
                ``config.frames_list`` is used; if that is missing or ``None``, all
                frames are processed.
            index_path (str, optional keyword argument): path to the index file.
                If not provided, ``config.output_index_path`` is used when it is
                not ``None``; otherwise ``[root_path]/[filename]_index.txt``.

        Raises:
            AssertionError: if ``root_path`` or ``filename`` cannot be resolved,
                or if a frame lies outside ``[0, number of frames - 1]``.
            ValueError: if there are no frames to process.
        """
        root_path = self._resolve_option(
            kwargs, "root_path", config, "output_p_data_folderpath"
        )
        assert root_path is not None, (
            "`root_path` or `config.output_p_data_folderpath` is required (not `None`)."
        )
        filename = self._resolve_option(kwargs, "filename", config, "p_filename")
        assert filename is not None, (
            "`filename` or `config.p_filename` is required (not `None`)."
        )
        frames_list = self._resolve_option(
            kwargs, "frames_list", config, "frames_list"
        )
        index_path = self._resolve_option(
            kwargs, "index_path", config, "output_index_path"
        )
        if index_path is None:
            index_path = f"{root_path}/{filename}_index.txt"

        # Resolve and validate the frames; copy so later changes to the
        # caller's list do not leak into the prepared inputs.
        frame_length = len(traj_handler.universe.trajectory)
        frames = (
            list(range(frame_length)) if frames_list is None else list(frames_list)
        )
        if not frames:
            raise ValueError("The frames list is empty.")
        assert min(frames) >= 0, (
            f"The provided frames list contains a negative frame ({min(frames)} < 0)."
        )
        assert max(frames) < frame_length, (
            f"The provided frames list is out of range ({max(frames)} > "
            f"{frame_length - 1})."
        )

        n_frames = len(frames)
        self.inputs_dict = {
            # the same handler / index path is shared by every call
            "traj_handlers": [traj_handler] * n_frames,
            "pdb_filepaths": [f"{root_path}/{filename}_{frame}.pdb" for frame in frames],
            "ply_filepaths": [f"{root_path}/{filename}_{frame}.ply" for frame in frames],
            "h5_filepaths": [f"{root_path}/{filename}_{frame}.h5" for frame in frames],
            "frames": frames,
            "index_paths": [index_path] * n_frames,
        }
        # follow the input in the order of the function
        # utils.datasets.general.preprocess_workflow
        self.inputs = [
            self.inputs_dict["traj_handlers"],
            self.inputs_dict["pdb_filepaths"],
            self.inputs_dict["ply_filepaths"],
            self.inputs_dict["h5_filepaths"],
            self.inputs_dict["frames"],
            self.inputs_dict["index_paths"],
        ]

    def set_function(self, func, **kwargs):
        """Set the worker function, binding ``logger=self.logger`` when a logger exists."""
        if self.logger:
            return super().set_function(func, logger=self.logger, **kwargs)
        else:
            return super().set_function(func, **kwargs)
class TrajHandlerPrediction(TrajHandlerPreprocess):
    """Parallelization framework for prediction workflow.

    You will need to prepare the inputs and the function to run.

    Arguments:
        max_workers (int): maximum number of parallel workers (default: ``8``)
        logger_path (str): path to save the logger file (default: ``None``). If
            ``None``, no logger will be created.
        task_name (str): name of the task for the logger (default: ``"core-predict"``)

    Example:
        .. code-block:: python

            # Import the TrajHandlerPrediction class.
            from utils.parallel.framework import TrajHandlerPrediction

            # Create a TrajHandlerPrediction object and prepare the inputs.
            p_job = TrajHandlerPrediction(max_workers=4, logger_path="predict.log")
            p_job.prepare(
                traj_handler=traj_handler,
                root_path="output/p_data",
                filename="traj_data",
                frames_list=[0, 1, 2, 3, 4],
            )

            # Setup the function.
            p_job.set_function(
                func=add_prediction_to_ply,
                model_path="model/best_model.pt",
            )

            # Run the prediction in parallel.
            p_job.run()

    .. note::
        The source code of the ``add_prediction_to_ply`` function can be found in
        :func:`utils.datasets.general.add_prediction_to_ply`.
    """

    def __init__(
        self,
        max_workers=8,
        logger_path: str = None,
        task_name: str = "core-predict",
    ):
        super().__init__(max_workers, logger_path=logger_path, task_name=task_name)

    @staticmethod
    def _with_input_locations(config, kwargs):
        """Return a copy of ``kwargs`` whose ``root_path``/``filename`` default to
        ``config.input_p_data_folderpath`` / ``config.input_p_filename``.

        Explicit (truthy) keyword arguments keep precedence. ``config`` itself
        is not modified, so callers can keep using its ``output_*`` entries.
        """
        resolved = dict(kwargs)
        for key, attr in (
            ("root_path", "input_p_data_folderpath"),
            ("filename", "input_p_filename"),
        ):
            value = getattr(config, attr, None)
            if value is not None and not resolved.get(key):
                resolved[key] = value
        return resolved

    def prepare(self, traj_handler, config=None, **kwargs):
        """Prepare the inputs for the prediction workflow.

        Takes the same arguments as the parent ``prepare``. When ``root_path``
        / ``filename`` are not passed explicitly, the files are located with
        ``config.input_p_data_folderpath`` / ``config.input_p_filename``
        (falling back to the ``output_*`` entries the parent reads).

        The resulting ``self.inputs`` columns are: traj handlers, ply paths,
        h5 paths, json paths, frames (assumed to match the positional
        parameters of ``add_prediction_to_ply`` — confirm against its
        signature).
        """
        # Pass the input locations as kwargs instead of writing them into the
        # caller's config object.
        super().prepare(
            traj_handler,
            config=config,
            **self._with_input_locations(config, kwargs),
        )
        # Swap only the trailing ".ply" extension; str.replace would also
        # rewrite any ".ply" inside the directory or file name.
        self.inputs_dict["json_filepaths"] = [
            path[: -len(".ply")] + ".json" for path in self.inputs_dict["ply_filepaths"]
        ]
        self.inputs = [
            self.inputs_dict["traj_handlers"],  # traj_handlers
            self.inputs_dict["ply_filepaths"],  # ply_filepaths
            self.inputs_dict["h5_filepaths"],  # h5_filepaths
            self.inputs_dict["json_filepaths"],  # json_filepaths
            self.inputs_dict["frames"],  # frames
        ]
class TrajHandlerVisualization(TrajHandlerPreprocess):
    """Parallelization framework for visualization workflow.

    You will need to prepare the inputs and the function to run.

    Arguments:
        max_workers (int): maximum number of parallel workers (default: ``8``)
        logger_path (str): path to save the logger file (default: ``None``). If
            ``None``, no logger will be created.
        task_name (str): name of the task for the logger (default: ``"core-vis"``)

    Example:
        .. code-block:: python

            # Import the TrajHandlerVisualization class.
            from utils.parallel.framework import TrajHandlerVisualization

            # Create a TrajHandlerVisualization object and prepare the inputs.
            p_job = TrajHandlerVisualization(max_workers=4)
            p_job.prepare(
                traj_handler=traj_handler,
                root_path="output/p_data",
                filename="traj_data",
                frames_list=[0, 1, 2, 3, 4],
            )

            # Setup the function.
            p_job.set_function(func=generate_pse, pymol_path="path/to/pymol")

            # Run the visualization in parallel.
            p_job.run()

    .. note::
        The source code of the ``generate_pse`` function can be found in
        :func:`utils.pymol_scripts.vis_pdb_ply.generate_pse`.
    """

    def __init__(
        self,
        max_workers=8,
        logger_path: str = None,
        task_name: str = "core-vis",
    ):
        super().__init__(max_workers, logger_path=logger_path, task_name=task_name)

    @staticmethod
    def _with_input_locations(config, kwargs):
        """Return a copy of ``kwargs`` whose ``root_path``/``filename`` default to
        ``config.input_p_data_folderpath`` / ``config.input_p_filename``.

        Explicit (truthy) keyword arguments keep precedence. ``config`` itself
        is not modified, so callers can keep using its ``output_*`` entries.
        """
        resolved = dict(kwargs)
        for key, attr in (
            ("root_path", "input_p_data_folderpath"),
            ("filename", "input_p_filename"),
        ):
            value = getattr(config, attr, None)
            if value is not None and not resolved.get(key):
                resolved[key] = value
        return resolved

    def prepare(self, traj_handler, config=None, **kwargs):
        """Prepare the inputs for the visualization workflow.

        Takes the same arguments as the parent ``prepare``. When ``root_path``
        / ``filename`` are not passed explicitly, the files are located with
        ``config.input_p_data_folderpath`` / ``config.input_p_filename``
        (falling back to the ``output_*`` entries the parent reads).

        The resulting ``self.inputs`` columns are: pdb paths, ply paths, pse
        paths (assumed to match the positional parameters of
        ``generate_pse`` — confirm against its signature).
        """
        # Pass the input locations as kwargs instead of writing them into the
        # caller's config object.
        super().prepare(
            traj_handler,
            config=config,
            **self._with_input_locations(config, kwargs),
        )
        # Swap only the trailing ".ply" extension; str.replace would also
        # rewrite any ".ply" inside the directory or file name.
        self.inputs_dict["pse_filepaths"] = [
            path[: -len(".ply")] + ".pse" for path in self.inputs_dict["ply_filepaths"]
        ]
        self.inputs = [
            self.inputs_dict["pdb_filepaths"],  # pdb_filepaths
            self.inputs_dict["ply_filepaths"],  # ply_filepaths
            self.inputs_dict["pse_filepaths"],  # pse_filepaths
        ]