import json
import warnings
from pathlib import Path
from typing import Literal
import MDAnalysis as mda
import numpy as np
import pandas as pd
import pymesh
from MDAnalysis.analysis import align
from natsort import natsorted
from ..ppseg.dataset import get_mask_col_idx, load_h5, preprocess_h5, write_h5
from ..ppseg.fragment import fragment_idx_label_dict
from ..ppseg.inference import inference
from ..ppseg.myo.default_config import LIGAND_FRAG_INFO_PATH
from ..ppseg.visualization.visualization import fragmentation_from_universe
from ..ppseg.voxelization import map_voxel_to_xyz, site_voxelization
from ..thirdparty.deepdrug3d.build_grid import read_aux_file
from ..thirdparty.deepdrug3d.write_aux_file import write_aux_file
from .feature_handler import generate_masif_features
# Allowed values for the ``structure_type`` argument of the writer methods.
_TYPES = Literal["complex", "protein", "ligand"]
def check_standard_names(u: mda.Universe):
    """Check that the protein uses standard residue names and atom names.

    Arguments:
        u: The input MDAnalysis Universe.

    Returns:
        None

    Raises:
        AssertionError: If the ``protein`` selection contains a residue name
            outside the 20 names listed below, or an atom name that starts
            with ``1``, ``2`` or ``3`` (e.g. ``1HG2``).
    """
    standard_resnames = {
        "ALA",
        "GLY",
        "SER",
        "THR",
        "LEU",
        "ILE",
        "VAL",
        "ASN",
        "GLN",
        "ARG",
        "HIS",
        "TRP",
        "PHE",
        "TYR",
        "GLU",
        "ASP",
        "LYS",
        "PRO",
        "CYS",
        "MET",
    }
    # Shared hint; each line ends with a newline so the message renders
    # identically for both checks.
    conversion_hint = (
        "Try to use convert_to_standard_names:\n"
        " from utils.datasets.traj_handler import convert_to_standard_names\n"
        " convert_to_standard_names(PDB_path)\n"
        "Then, you can use the converted PDB file as the topology for your trajectory."
    )
    protein = u.select_atoms("protein")

    # Raise explicitly instead of using ``assert`` so the validation is not
    # stripped under ``python -O``; the exception type stays AssertionError.
    unexpected_resnames = set(protein.residues.resnames) - standard_resnames
    if unexpected_resnames:
        raise AssertionError(
            f"Unexpected residues found: {unexpected_resnames}.\n" + conversion_hint
        )

    digit_prefixed_names = [
        name
        for name in set(protein.atoms.names)
        if name.startswith(("1", "2", "3"))
    ]
    if digit_prefixed_names:
        raise AssertionError(
            "Unexpected atom names found (starting with 1, 2, or 3): "
            f"{digit_prefixed_names}.\n" + conversion_hint
        )
def convert_to_standard_names(PDB_path: str | Path):
    """Convert the resnames and atom names in the PDB file to standard ones.

    Residue names found in the mapping below are replaced by their standard
    three-letter name, and atom names starting with ``1``, ``2`` or ``3`` have
    that leading character moved to the end (e.g. ``1HG2`` becomes ``HG21``).

    Arguments:
        PDB_path: The path to the input PDB file.

    Returns:
        None. The converted PDB file is saved in the same directory as the
        input, named ``<input stem>_converted.pdb``. A summary of the
        conversion is printed.
    """
    u = mda.Universe(PDB_path)
    original_resnames = set(u.residues.resnames)
    original_atomnames = set(u.atoms.names)

    # Define a mapping from non-standard to standard residue names
    resname_mapping = {
        # alanine
        "ALAD": "ALA",
        "DALA": "ALA",
        "NALA": "ALA",
        "CALA": "ALA",
        # arginine
        "ARGN": "ARG",
        "NARG": "ARG",
        "CARG": "ARG",
        # asparagine
        "NASN": "ASN",
        "CASN": "ASN",
        "CASF": "ASN",
        "ASF": "ASN",
        "ASN1": "ASN",
        # aspartic acid
        "ASH": "ASP",
        "ASPH": "ASP",
        "ASPP": "ASP",
        "NASP": "ASP",
        "CASP": "ASP",
        # cysteine
        "CYS1": "CYS",
        "CYS2": "CYS",
        "CYSH": "CYS",
        "CYM": "CYS",
        "CYN": "CYS",
        "CYX": "CYS",
        "NCYS": "CYS",
        "CCYS": "CYS",
        "NCYX": "CYS",
        "CCYX": "CYS",
        # glutamic acid
        "GLUH": "GLU",
        "GLUP": "GLU",
        "GLH": "GLU",
        "NGLU": "GLU",
        "CGLU": "GLU",
        # glutamine
        "NGLN": "GLN",
        "CGLN": "GLN",
        # glycine
        "NGLY": "GLY",
        "CGLY": "GLY",
        # histidine
        "HISD": "HIS",
        "HISE": "HIS",
        "HIS1": "HIS",
        "HIS2": "HIS",
        "HISA": "HIS",
        "HISB": "HIS",
        "HISH": "HIS",
        "HID": "HIS",
        "HIE": "HIS",
        "HIP": "HIS",
        "HSD": "HIS",
        "HSE": "HIS",
        "HSP": "HIS",
        "NHID": "HIS",
        "NHIE": "HIS",
        "NHIP": "HIS",
        "CHID": "HIS",
        "CHIE": "HIS",
        "CHIP": "HIS",
        # isoleucine
        "NILE": "ILE",
        "CILE": "ILE",
        # leucine
        "NLEU": "LEU",
        "CLEU": "LEU",
        # lysine
        "LYN": "LYS",
        "LSN": "LYS",
        "LYSH": "LYS",
        "NLYS": "LYS",
        "CLYS": "LYS",
        # methionine
        "NMET": "MET",
        "CMET": "MET",
        # phenylalanine
        "NPHE": "PHE",
        "CPHE": "PHE",
        # proline
        "NPRO": "PRO",
        "CPRO": "PRO",
        # serine
        "NSER": "SER",
        "CSER": "SER",
        # threonine
        "NTHR": "THR",
        "CTHR": "THR",
        # tryptophan
        "NTRP": "TRP",
        "CTRP": "TRP",
        # tyrosine
        "NTYR": "TYR",
        "CTYR": "TYR",
        # valine
        "NVAL": "VAL",
        "CVAL": "VAL",
    }

    # Convert residue names, remembering which non-standard names were seen
    converted_resnames = []
    for residue in u.residues:
        standard_name = resname_mapping.get(residue.resname)
        if standard_name is not None:
            converted_resnames.append(residue.resname)
            residue.resname = standard_name

    # Convert atom names (e.g., 1HG2 to HG21)
    converted_atomnames = []
    for atom in u.atoms:
        if atom.name.startswith(("1", "2", "3")) and len(atom.name) > 1:
            converted_atomnames.append(atom.name)
            atom.name = atom.name[1:] + atom.name[0]

    # Build the output name from the file stem only. A plain
    # ``str.replace(".pdb", ...)`` would also rewrite ".pdb" inside directory
    # names, and would leave the path unchanged (overwriting the input) when
    # the file does not end in ".pdb".
    source_path = Path(PDB_path)
    converted_PDB_path = str(
        source_path.with_name(f"{source_path.stem}_converted.pdb")
    )
    u.atoms.write(converted_PDB_path)

    print(f"Converted PDB saved to: {converted_PDB_path}")
    print(
        f"Original residue names: {original_resnames};\n"
        f"Converted residue names: {set(converted_resnames)}"
    )
    print(
        f"Original atom names: {original_atomnames};\n"
        f"Converted atom names: {set(converted_atomnames)}"
    )
def get_ligand_around_resids(
    u: mda.Universe,
    ligand_name: str,
    ligand_aa_dist: int,
    aa_existence_time: float = 0.50,
    with_segid: bool = False,
) -> list:
    """Get the residues around the ligand over a trajectory.

    Arguments:
        u (required): The input MDAnalysis Universe, representing the molecular system.
        ligand_name (required): The name of the ligand to find residues around.
        ligand_aa_dist: The distance (Å) from the ligand to consider residues as nearby.
        aa_existence_time: The fraction of the trajectory during which a residue must
            be present near the ligand to be included (default = 50%).
        with_segid: If ``True``, includes the segment ID (segid) in the returned
            residues.

    Returns:
        list: The residues around the ligand, as strings in sorted (lexicographic)
            order. With `with_segid` ``True`` each entry is ``"<segid>_<resid>"``;
            otherwise each entry is the resid only.
    """
    u.trajectory[0]  # set to frame 0
    around_u = u.select_atoms(
        f"protein and (around {ligand_aa_dist} resname {ligand_name})", updating=True
    )

    # Collect, frame by frame, the unique "<segid>_<resid>" labels near the
    # ligand. Each label appears at most once per frame, so its total number of
    # occurrences equals the number of frames in which the residue was nearby.
    n_frames = 0
    labels_seen = []
    for _ in u.trajectory:
        n_frames += 1
        labels_seen.extend(
            np.unique(around_u.segids + "_" + around_u.resids.astype(str)).tolist()
        )
    if not labels_seen:
        return []

    labels, frames_present = np.unique(np.array(labels_seen), return_counts=True)
    # keep the residues present in at least `aa_existence_time` of the frames
    around_u_segresids = labels[
        (frames_present / n_frames) >= aa_existence_time
    ].tolist()

    if with_segid:
        return around_u_segresids
    # The resid is always the part after the LAST underscore; splitting from the
    # right keeps this correct when the segid itself contains underscores.
    return [label.rsplit("_", 1)[-1] for label in around_u_segresids]
def get_resname_with_resid(u: mda.Universe, resids: list) -> list:
    """Build ``<amino-acid code><resid>`` labels for the given residue IDs.

    For every resid, the first residue matched by ``resid <id>`` is looked up and
    its residue name is passed through ``mda.lib.util.convert_aa_code``.

    Arguments:
        u: The input MDAnalysis Universe.
        resids: The residue IDs to get the residue names.

    Returns:
        list: One label per entry of `resids`, in the same order.

    Example:
        .. code-block:: python

            from MDAnalysis import Universe
            u = Universe("example.pdb")
            resids = [1, 2, 3]
            resname_with_resid = get_resname_with_resid(u, resids)
            print(resname_with_resid)
            # NOTE(review): convert_aa_code is expected to turn three-letter
            # names into one-letter codes, so the output should look like
            # ['A1', 'R2', 'E3'] - confirm against MDAnalysis.
    """

    def _label(resid):
        # Only the first matching residue is used.
        first_match = u.select_atoms(f"resid {resid}").residues[0]
        return f"{mda.lib.util.convert_aa_code(first_match.resname)}{resid}"

    return [_label(resid) for resid in resids]
# handling the complex trajectory
[docs]class TrajectoryHandler:
"""Trajectory handler for protein-ligand complex / protein-only trajectory.
This class is used to handle the trajectory of a protein-ligand complex or
protein-only trajectory. It provides methods to read the trajectory, get the
residues at the pocket, get the pocket center, write the structure, features,
labels, interest region, and voxelised data. It also provides methods to
preprocess the data and write the auxiliary files.
Arguments:
top_path (required): the path to the topology file
(.pdb, .gro, ... [MDAnalysis compatible])
trajectory_path (required): the path to the trajectory file
(.trr, .xtc, ... [MDAnalysis compatible])
ligand_name (optional, recommended to provide): the name of the ligand
radius_of_interest (optional): the radius (Å)
to consider the interest region (default: ``16.0``)
spacing (optional): the spacing (Å) between the grid points
(default = ``0.5`` due to the sampling theorem
from the mesh spacing 1Å)
distance_cutoff (optional): the surface points will be labelled only
if the distance of the point to the ligand's heavy atoms within
this distance cutoff. (default: ``5.0`` Å)
warning_check (optional): if ``True``, the warnings will be shown
(default = ``True``)
Returns:
`self.top_path` was set to the top_path
`self.trajectory_path` was set to the trajectory_path
`self.ligand_name` was set to the ligand_name
`self.universe` was set to the MDAnalysis Universe object
`self.warning_check` was set to the warning_check
.. note::
Functions:
high-level functions (can use self.variables and self.functions):
low-level functions (can only use self.functions):
"""
def __init__(
    self,
    top_path: str | Path,
    trajectory_path: str | Path = None,
    ligand_name: str = None,
    radius_of_interest: float = 16.0,
    spacing: float = 0.5,
    distance_cutoff: float = 5.0,
    warning_check: bool = True,
):
    """Load the universe, validate its naming and store the configuration.

    Args:
        top_path (str | Path): Path to the topology file.
        trajectory_path (str | Path, optional): Path to the trajectory file.
            When ``None`` the universe is built from the topology alone.
        ligand_name (str, optional): Name of the ligand. Defaults to None.
        radius_of_interest (float, optional): Radius to consider the interest
            region. Defaults to ``16.0``.
        spacing (float, optional): Spacing for the grid points. Defaults to ``0.5``.
        distance_cutoff (float, optional): Distance cutoff for labeling surface
            points. Defaults to ``5.0``.
        warning_check (bool, optional): When ``True`` the ligand check is run;
            when ``False`` Python warnings are silenced instead. Defaults to
            ``True``.
    """
    self.top_path = top_path
    self.trajectory_path = trajectory_path
    self.ligand_name = ligand_name

    if trajectory_path is None:
        self.universe = mda.Universe(top_path)
    else:
        self.universe = mda.Universe(top_path, trajectory_path)
    self.warning_check = warning_check

    # Fail early on non-standard residue/atom names.
    check_standard_names(self.universe)

    # optional settings
    self.set_config(
        radius_of_interest=radius_of_interest,
        spacing=spacing,
        distance_cutoff=distance_cutoff,
    )

    if not self.warning_check:
        # NOTE: this changes the process-wide warnings filter.
        warnings.filterwarnings("ignore")
    else:
        self._check_ligand()
[docs] def set_config(
self,
radius_of_interest: float = None,
spacing: float = None,
distance_cutoff: float = None,
):
"""Set the configuration (radius of interest, spacing, distance_cutoff), and
detect whether the trajectory has multiple segids.
Arguments:
radius_of_interest: float (recommend = 16), the radius (Å) to
consider the interest region
spacing: float (recommend = 0.5), the spacing (Å) to
consider the interest region
distance_cutoff: float (Å) (recommend = 5), the surface points will be
labelled only if the distance of the point to the ligand's heavy atoms
within this distance cutoff.
Returns:
`self.radius_of_interest` was set to the radius_of_interest
`self.spacing` was set to the spacing
`self.distance_cutoff` was set to the distance_cutoff
"""
if radius_of_interest is not None:
self.radius_of_interest = radius_of_interest
if spacing is not None:
self.spacing = spacing
if distance_cutoff is not None:
self.distance_cutoff = distance_cutoff
if len(np.unique(self.universe.residues.segids)) > 1:
self.multi_segids = True
else:
self.multi_segids = False
[docs] def get_frame(self, frame_number: int):
"""Get the frame of the trajectory.
Arguments:
frame_number: int, the frame number to get
Returns:
`self.universe.trajectory` was set to the frame_number
"""
self.universe.trajectory[frame_number]
def get_residues_at_pocket_by_center(self, pocket_center: list = None):
    """Get the residues at the pocket from a point and a cutoff distance.

    The protein residues having atoms within `self.distance_cutoff` of the
    pocket center are selected (``byres point ...``).

    NOTE(review): an earlier description of this method mentioned
    `self.radius_of_interest`, but the selection uses `self.distance_cutoff`;
    confirm which radius is intended.

    Arguments:
        pocket_center: list, the pocket center. The default is `None`,
            which will attempt to use the pocket center stored in the trajectory
            handler.

    Returns:
        `self.residues_at_pocket` was set to the residues at the pocket
        (``"<segid>_<resid>"`` strings when `self.multi_segids` is true,
        resid strings otherwise)
        `self.residues_at_pocket_str` was set to the residues at the pocket
        in string format
    """
    center = self.pocket_center if pocket_center is None else pocket_center
    selection = (
        f"protein and byres point {self.__list2str(center)} {self.distance_cutoff}"
    )
    residues = self.universe.select_atoms(selection).residues

    if self.multi_segids:
        labels = [
            f"{segid}_{resid}"
            for segid, resid in zip(residues.segids, residues.resids)
        ]
    else:
        labels = [str(resid) for resid in residues.resids]

    self.residues_at_pocket = labels
    self.residues_at_pocket_str = self.__list2str(self.residues_at_pocket)
[docs] def get_residues_at_pocket(
self, ligand_aa_dist: int = 5, aa_existence_time: float = 0.5
):
"""[Require ligand name] Get the resnames of the anchored residues
at the pocket over a trajectory.
Arguments:
ligand_aa_dist: int, the distance (Å) from the ligand to
consider the residues.
aa_existence_time: float, the fraction of the trajectory that
the residue should be present to be considered.
Returns:
`self.residues_at_pocket` was set to the residues at the pocket
`self.residues_at_pocket_str` was set to the residues at the pocket
in string format
"""
self.ligand_aa_dist = ligand_aa_dist
self.aa_existence_time = aa_existence_time
# sanity check
if self.ligand_name is None:
raise ValueError("ligand_name is not provided.")
# get the residues at the pocket
self.residues_at_pocket = get_ligand_around_resids(
u=self.universe,
ligand_name=self.ligand_name,
ligand_aa_dist=ligand_aa_dist,
aa_existence_time=aa_existence_time,
with_segid=True if self.multi_segids else False,
)
self.residues_at_pocket_str = self.__list2str(self.residues_at_pocket)
def get_pocket_center(self, frame: int = 0):
    """[Require ligand name] Compute the pocket center at a specific frame
    (default = 0).

    The center is the center of geometry of `self.pocket_residues`, which is
    built first through `get_pocket_residues` when it does not exist yet.

    Arguments:
        frame: int, the frame number to get the pocket center

    Returns:
        self.pocket_center was set to the pocket center (a list)
        self.pocket_center_str was set to the pocket center in string format

    .. note::
        Deprecated the deepdrug3d version to calculate the pocket center.
        Instead, use mdanalysis to calculate the center of geometry.
        The trajectory is moved back to frame 0 afterwards.
    """
    if not hasattr(self, "pocket_residues"):
        self.get_pocket_residues()

    self.get_frame(frame)
    center = self.pocket_residues.center_of_geometry().tolist()
    self.pocket_center = center
    self.pocket_center_str = self.__list2str(self.pocket_center)

    # reset the universe
    self.get_frame(0)
[docs] def get_protein(self):
"""Get the protein from the trajectory by MDAnalysis selection.
Returns:
`self.protein` was set to the protein
"""
self.protein = self.universe.select_atoms("protein", updating=True)
[docs] def get_ligand(self):
"""[Require ligand name] Get the ligand from the trajectory
by MDAnalysis selection.
Returns:
`self.ligand` was set to the ligand
"""
self.ligand = self.universe.select_atoms(
f"resname {self.ligand_name}", updating=True
)
[docs] def get_complex(self):
"""[Require ligand name] Get the complex from the trajectory
(inlcuding, protein, ligand, protein + ligand) by MDAnalysis selection.
Returns:
`self.ligand` was set to the ligand
`self.protein` was set to the protein
`self.complex` was set to the complex (protein + ligand)
"""
self.get_protein()
self.get_ligand()
self.complex = self.universe.select_atoms(
f"protein or resname {self.ligand_name}", updating=True
)
[docs] def get_pocket_residues(self):
"""[Require `self.residues_at_pocket`] Get the residues at the pocket from
the trajectory by `MDAnalysis` selection.
Returns:
`self.pocket_residues` (`MDAnlysiis` AtomGroup) was set to the residues
at the pocket
"""
# requirement: get_residues_at_pocket
if not hasattr(self, "residues_at_pocket"):
self.get_residues_at_pocket()
assert self.residues_at_pocket != [], "residues_at_pocket is empty"
pocket_residue_str = self.__resid_for_selection(
self.residues_at_pocket, self.multi_segids
)
self.pocket_residues = self.universe.select_atoms(
f"protein and ({pocket_residue_str})", updating=True
)
[docs] def read_pocket_from_string(
self,
residues_at_pocket_str: str = None,
pocket_center_str: str = None,
):
"""Read the residues at the pocket and the pocket center from strings.
Arguments:
residues_at_pocket_str: str, the residues at the pocket in string format
(default: None)
pocket_center_str: str, the pocket center in string format
(default: None)
Returns:
`self.residues_at_pocket` was set to the residues at the pocket
`self.residues_at_pocket_str` was set to the residues at the pocket
in string format
`self.pocket_center` was set to the pocket center
`self.pocket_center_str` was set to the pocket center in string format
"""
if residues_at_pocket_str:
self.residues_at_pocket_str = residues_at_pocket_str.strip()
self.residues_at_pocket = [
each for each in residues_at_pocket_str.split(" ") if each != ""
]
if pocket_center_str:
self.pocket_center_str = pocket_center_str.strip()
pocket_center = [
float(each) for each in self.pocket_center_str.split(" ") if each != ""
]
assert len(pocket_center) == 3, (
"The pocket center is not providely correctly. "
"Please provide the pocket center in the format 'x y z'."
)
self.pocket_center = pocket_center
self.get_pocket_residues()
def read_pocket_aux_file(self, aux_file_path: str | Path):
    """Load the pocket residues and the pocket center from an auxiliary file.

    The file is parsed with `read_aux_file`. Whichever of the two pieces of
    information is missing is derived from the other one: missing residues are
    selected around the given center, a missing center is computed with
    `get_pocket_center`. `get_pocket_residues` is called at the end.

    Arguments:
        aux_file_path: str, the path to the auxiliary file

    Returns:
        `self.residues_at_pocket` was set to the residues at the pocket
        `self.residues_at_pocket_str` was set to the residues at the pocket
        in string format
        `self.pocket_center` was set to the pocket center
        `self.pocket_center_str` was set to the pocket center in string format

    Raises:
        AssertionError: If neither residues nor center are provided, or if the
            provided center does not have three components.
    """

    def _is_blank(text):
        return text.replace(" ", "") == ""

    def _to_floats(text):
        return [float(token) for token in text.split(" ") if token != ""]

    self.residues_at_pocket, content = read_aux_file(aux_file_path)
    assert not (_is_blank(content[0]) and _is_blank(content[1])), (
        "Both residues IDs and pocket center are not provided."
    )

    # residues in the pocket
    if self.residues_at_pocket == []:
        # no residues in the file: select them around the given center
        self.get_residues_at_pocket_by_center(pocket_center=_to_floats(content[1]))
    else:
        self.residues_at_pocket_str = content[0]

    # pocket center
    if _is_blank(content[1]):
        # no center in the file: calculate it
        self.get_pocket_center()
    else:
        self.pocket_center_str = content[1]
        self.pocket_center = _to_floats(self.pocket_center_str)
        assert len(self.pocket_center) == 3, (
            "The pocket center is not providely correctly. "
            "Please provide the pocket center in the format 'x y z'."
        )

    self.get_pocket_residues()
def read_fragment_aux_file(self, aux_file_path: str | Path = None):
    """Read the ligand fragment definitions from a JSON auxiliary file.

    Arguments:
        aux_file_path: str, the path to the auxiliary file
            (format: json), if not provided, use the default example file
            (`LIGAND_FRAG_INFO_PATH`)

    Returns:
        `self.label_fragment_info` was set to the parsed JSON content
        `self.fragidx_label_dict` was set to the result of
        `fragment_idx_label_dict` for that content

    Example:
        .. code-block:: python

            from ProBiSEnSe.utils.datasets.traj_handler import TrajectoryHandler
            traj_handler = TrajectoryHandler(
                top_path="example.pdb",
                trajectory_path="example.xtc",
                ligand_name="LIG",
            )
            aux_file_path = "example.json"
            traj_handler.read_fragment_aux_file(aux_file_path)
            print(traj_handler.label_fragment_info)
            # Output:
            >>> {
                "0": {
                    "name": "out of the threshold"
                },
                "1": {
                    "name": "fragment 1",
                    "fragments_idx": [0, 1, 2, 3, 30, 45, 52]
                },
                "2": {
                    "name": "fragment 2",
                    "fragments_idx":
                        [4, 5, 6, 7, 26, 27, 29, 31, 32, 43, 44, 46, 47, 50]
                }
            }
    """
    if aux_file_path is None:
        aux_file_path = LIGAND_FRAG_INFO_PATH
    # JSON is UTF-8 by convention; do not depend on the platform default encoding.
    with open(aux_file_path, encoding="utf-8") as f:
        self.label_fragment_info = json.load(f)
    self.fragidx_label_dict = fragment_idx_label_dict(
        labels_info=self.label_fragment_info
    )
def align_traj_to_pocket(
    self,
    reference: mda.Universe | mda.AtomGroup | int = None,
    select_Hs: bool = False,
    update_pocket_center: bool = True,
):
    """Use the pocket resids to align the trajectory to the pocket.

    Requirement: `residues_at_pocket`.

    Arguments:
        reference: MDAnalysis Universe object, AtomGroup object, or int,
            the reference to align the trajectory. ``None`` and int values
            align `self.universe` against itself, passing ``ref_frame=0`` or
            ``ref_frame=<int>`` respectively.
        select_Hs: bool, if ``True``, the H atoms will be selected
        update_pocket_center: bool, if ``True``, the pocket center will be
            updated after the alignment (default: ``True``)

    Returns:
        Align the trajectory to the pocket (in memory). See `self.universe`,
        it will be updated.
    """
    pocket_terms = self.__resid_for_selection(
        resid_list=self.residues_at_pocket,
        with_segid=self.multi_segids,
    )
    # Group the pocket terms explicitly (as `get_pocket_residues` does) so the
    # hydrogen filter appended below applies to the whole pocket selection
    # regardless of how and/or are combined by the selection parser.
    selection = f"({pocket_terms})"
    # ignore the H atoms when aligning the structure
    if not select_Hs:
        # exclude hydrogens (some H name like 1HD1)
        selection = (
            f"{selection} and "
            "(not ((name *H* and not name N* and not name O*) or (type H)))"
        )

    # Choose the reference once instead of repeating the AlignTraj call.
    # NOTE(review): confirm that `AlignTraj` honours the `ref_frame` keyword in
    # the MDAnalysis version in use; it is forwarded here exactly as before.
    if reference is None:
        reference_kwargs = {"reference": self.universe, "ref_frame": 0}
    elif isinstance(reference, int):
        reference_kwargs = {"reference": self.universe, "ref_frame": reference}
    else:
        reference_kwargs = {"reference": reference}

    # align the trajectory to the pocket
    self.get_frame(0)  # set to frame 0
    alignment = align.AlignTraj(
        mobile=self.universe,
        select=selection,
        in_memory=True,
        **reference_kwargs,
    )
    alignment.run()

    # update the pocket center
    if update_pocket_center:
        self.get_pocket_center()
        if self.warning_check:
            print(
                "The pocket center has been updated to "
                f"{self.pocket_center}"
                " after the alignment."
            )
[docs] def preprocess_workflow(
self,
pdb_path: str | Path,
ply_path: str | Path,
h5_path: str | Path,
frame: int = 0,
with_label: bool = True,
):
"""Preprocessing workflow for a frame, including writing the structure,
features, labels, interest region, and voxelised data.
Arguments:
pdb_path: str, the path to the PDB file
ply_path: str, the path to the PLY file
h5_path: str, the path to the H5 file
frame: int, the frame number to get the features
with_label: bool, if ``True``, the labels will be included in the H5 file
Returns:
Save the PDB file in `pdb_path`
Save the PLY file with the MASIF features in `ply_path`
Save the H5 file in `h5_path` with ['raw'] or ['raw' and 'label']
(if `with_label` is ``True``)
.. note::
- `raw`: the voxelised data of the features
- `label: the voxelised data of the labels (if `with_label` is ``True``)
"""
# check
assert self.pocket_center is not None, "pocket_center is not provided"
# write the structure
self.write_structure(pdb_path=pdb_path, frame=frame)
if self.warning_check:
print(f"Writing the PDB file: {frame} completed")
# write the features
self.write_features_to_ply(pdb_path=pdb_path, ply_path=ply_path, frame=frame)
self.add_interest_region_to_ply(ply_path=ply_path)
if with_label:
self.add_labels_to_ply(
ply_path=ply_path,
ref_ligand_frame=frame,
)
if self.warning_check:
print(
"Writing the PLY file with the features"
f"{' and labels' if with_label else ''}: {frame} completed"
)
# write the h5 file
self.write_voxelised_data_to_h5(
ply_path=ply_path,
h5_path=h5_path,
with_label=with_label,
)
if self.warning_check:
print(f"Writing the h5 file: {frame} completed")
def write_pocket_aux_file(self, aux_file_path: str | Path):
    """Save the pocket residues and the pocket center to an auxiliary file.

    Missing information is computed first: `get_residues_at_pocket` when
    `self.residues_at_pocket` does not exist, `get_pocket_center` when
    `self.pocket_center` does not exist.

    Arguments:
        aux_file_path: str, the path to the auxiliary file

    Returns:
        Save the auxiliary file in `aux_file_path`
    """
    prerequisites = (
        ("residues_at_pocket", self.get_residues_at_pocket),
        ("pocket_center", self.get_pocket_center),
    )
    for attribute, compute in prerequisites:
        if not hasattr(self, attribute):
            compute()

    write_aux_file(
        aux_filepath=aux_file_path,
        binding_residue_ids=self.residues_at_pocket_str,
        binding_site_center=self.pocket_center_str,
    )
[docs] def write_trajectory(
self,
traj_path: str | Path,
start_frame: int = 0,
end_frame: int = None,
structure_type: _TYPES = None,
step: int = 1,
):
"""Write the trajectory as a traj file.
Arguments:
traj_path: str, the path to the trajectory file
start_frame: int, the frame number to start
end_frame: int, the frame number to end. If ``None``,
it will be the total number of frames.
structure_type: str, the type of structure to write
(complex, protein, ligand). If ``None``, it will be all atoms.
step: int, the step to write the frames
Returns:
Save the trajectory file in `traj_path`.
"""
# initialisation
if end_frame is None:
end_frame = len(self.universe.trajectory)
# check
assert end_frame <= len(self.universe.trajectory), (
"end_frame should be less than the total number of frames."
)
if structure_type is not None:
assert hasattr(self, structure_type), (
f"{structure_type} is not provided. Please run `get_complex()` first."
)
# write the trajectory
atoms_to_save = (
self.universe.select_atoms("all", updating=True)
if structure_type is None
else getattr(self, structure_type)
)
with mda.Writer(traj_path, atoms_to_save.n_atoms) as w:
for ts in self.universe.trajectory[start_frame:end_frame:step]:
w.write(atoms_to_save)
# reset the universe
self.get_frame(0)
def write_structure(
    self,
    pdb_path: str | Path,
    frame: int,
    structure_type: _TYPES = "protein",
    fragmentation: bool = False,
):
    """Write one frame of the chosen selection as a PDB file.

    When the requested selection does not exist yet, `get_complex` is called
    to create it.

    Arguments:
        pdb_path: str, the path to the PDB file
        frame: int, the frame number to get the structure
        structure_type: str, the type of structure to write
            (complex, protein, ligand)
        fragmentation: bool, if ``True`` and `structure_type` is not
            ``"protein"``, `_fragment_universe` is called before writing

    Returns:
        Save the PDB file in `pdb_path`. The trajectory is moved back to
        frame 0 afterwards.
    """
    if not hasattr(self, structure_type):
        self.get_complex()

    self.get_frame(frame)
    if structure_type != "protein" and fragmentation:
        self._fragment_universe()

    atoms = getattr(self, structure_type)
    with mda.Writer(pdb_path, atoms.n_atoms) as pdb_writer:
        pdb_writer.write(atoms)

    # reset the universe
    self.get_frame(0)
[docs] def write_features_to_ply(
self,
pdb_path: str | Path,
ply_path: str | Path,
frame: int = None,
):
"""Write the MASIF features to a PLY file. If the PDB file does
not exist, it will be created from the trajectory.
Arguments:
pdb_path: str, the path to the PDB file
ply_path: str, the path to the PLY file
frame: int, the frame number to get the features
Returns:
Save the PLY file with the MASIF features in `ply_path`
"""
# check
if not Path(pdb_path).is_file() and frame is not None:
if self.warning_check:
print(
f"The PDB file does not exist. Creating the PDB file from "
f"the trajectory (frame {frame})..."
)
self.write_structure(pdb_path=pdb_path, frame=frame)
elif not Path(pdb_path).is_file() and frame is None:
raise ValueError(
"The PDB file does not exist. If providing the frame number, "
"an PDB file will be created from the trajectory."
)
else:
if self.warning_check:
print(f"Using the PDB file: {pdb_path}")
# workflow of generating masif features
generate_masif_features(pdb_path, ply_path)
[docs] def write_voxelised_data_to_h5(
self,
ply_path: str | Path,
h5_path: str | Path,
with_label: bool = True,
):
"""Write the surface vertices into voxelised data and save in an H5 file.
Arguments:
ply_path: str, the path to the input PLY file
h5_path: str, the path to the outpu H5 file
with_label: bool, if ``True``, the labels will be included in the H5 file
Returns:
Save the H5 file in `h5_path` with ['raw'] or ['raw' and 'label']
(if `with_label` is ``True``)
.. note::
- `raw`: the voxelised data of the features
- `label`: the voxelised data of the labels \
(if `with_label` is ``True``)
"""
# check
assert Path(ply_path).is_file(), "The PLY file does not exist."
assert str(h5_path).endswith(".h5"), (
"The H5 file should have the extension '.h5'."
)
self._check_config()
# read mesh file
regular_mesh = pymesh.load_mesh(ply_path)
# voxelisation
voxel_features, voxel_labels = self._voxelisation(
mesh=regular_mesh, with_label=with_label
)
# write h5 file
if with_label:
write_h5(data=voxel_features, h5_filename=h5_path, label=voxel_labels)
else:
write_h5(data=voxel_features, h5_filename=h5_path, label=None)
def add_labels_to_ply(
    self,
    ply_path: str | Path,
    ref_ligand_frame: int,
    ply_path_output: str | Path = None,
):
    """Add the fragment labels to a PLY file as the ``vertex_label`` attribute.

    Arguments:
        ply_path: str, the path to the PLY file
        ref_ligand_frame: int, the frame number to get the reference
            ligand for the surface.
        ply_path_output: str, the path to the output PLY file

    Returns:
        Save the PLY file with the labels in `ply_path_output` if it is
        provided, otherwise in `ply_path`. The trajectory is moved back to
        frame 0 afterwards.

    Raises:
        AssertionError: If `distance_cutoff` is ``None`` or the fragment
            auxiliary file has not been read yet.
    """
    # check
    if not hasattr(self, "ligand"):
        self.get_complex()
    assert self.distance_cutoff is not None, "distance_cutoff is not provided"
    assert hasattr(self, "label_fragment_info"), (
        "label_fragment_info is not provided, please read the "
        "fragment auxiliary file first (using `read_fragment_aux_file`)."
    )

    # set output path
    if ply_path_output is None:
        ply_path_output = ply_path

    # load the mesh
    regular_mesh = pymesh.load_mesh(ply_path)
    mesh_vertices_coords = regular_mesh.vertices

    # load ligand atoms
    ligand_heavy_atoms = self.ligand.select_atoms("(not name H*)", updating=True)
    self.get_frame(ref_ligand_frame)
    ligand_coords = self.ligand.atoms.positions

    # Map every heavy atom to its position inside the ligand selection. A
    # lookup table avoids one array scan per heavy atom and the deprecated
    # conversion of a one-element array to a Python int.
    position_in_ligand = {
        atom_index: position
        for position, atom_index in enumerate(self.ligand.atoms.indices.tolist())
    }
    ligand_heavy_atoms_idxs = [
        position_in_ligand[atom_index]
        for atom_index in ligand_heavy_atoms.atoms.indices.tolist()
    ]

    # generate the labels
    labels = self.__generate_label_for_each_vertex_with_fragment(
        mesh_coords=mesh_vertices_coords,
        ligand_coords=ligand_coords,
        ligand_heavy_atoms_idxs=ligand_heavy_atoms_idxs,
        fragidx_label_dict=self.fragidx_label_dict,
        distance_cutoff=self.distance_cutoff,
    )

    # add attribute and save mesh
    regular_mesh.add_attribute("vertex_label")
    regular_mesh.set_attribute("vertex_label", labels)
    self.__save_mesh(ply_path_output, regular_mesh)

    # reset the universe
    self.get_frame(0)
def add_interest_region_to_ply(
    self,
    ply_path: str | Path,
    ply_path_output: str | Path = None,
):
    """Add the interest region to a PLY file as the ``vertex_interest`` attribute.

    A vertex gets the value 1 when its distance to `self.pocket_center` is at
    most `self.radius_of_interest`, and 0 otherwise.

    Arguments:
        ply_path: str, the path to the input PLY file.
        ply_path_output: str, the path to the output PLY file. If not provided,
            the input PLY file will be overwritten.

    Returns:
        Save the PLY file with the interest region in `ply_path_output` if it
        is provided, otherwise in `ply_path`.
    """
    # set output path
    if ply_path_output is None:
        ply_path_output = ply_path

    # load the mesh
    regular_mesh = pymesh.load_mesh(ply_path)

    # distance between each vertex and the pocket center
    distances = np.linalg.norm(regular_mesh.vertices - self.pocket_center, axis=1)
    interest_vertices = (distances <= self.radius_of_interest).astype(int)

    # add attribute and save mesh
    regular_mesh.add_attribute("vertex_interest")
    regular_mesh.set_attribute("vertex_interest", interest_vertices)
    # Save to the resolved output path; writing to `ply_path` here would
    # silently ignore a caller-supplied `ply_path_output`.
    self.__save_mesh(ply_path_output, regular_mesh)
def add_prediction_to_ply(
    self,
    ply_path: str | Path,
    h5_path: str | Path,
    model,
    ply_path_output: str | Path = None,
    device: str = "cpu",
):
    """Add model predictions to a PLY file as vertex attributes.

    The prediction is obtained from `self.__predict(h5_path, model, device)`
    and mapped onto the mesh vertices with `_map_voxel_to_vertices`.

    Arguments:
        ply_path: str, the path to the input PLY file
        h5_path: str, the path to the H5 file passed to the prediction
        model: the model passed to the prediction
        ply_path_output: str, the path to the output PLY file. If not provided,
            the input PLY file will be overwritten.
        device: str, the device passed to the prediction (default: ``"cpu"``)

    Returns:
        Save the PLY file with the ``vertex_pred`` and ``vertex_predprobs``
        attributes in `ply_path_output` if provided, otherwise in `ply_path`.
    """
    output_path = ply_path if ply_path_output is None else ply_path_output

    # load the mesh
    regular_mesh = pymesh.load_mesh(ply_path)

    # prediction: only the last three of the six returned values are used
    _first, _second, _third, pred, probs, mask = self.__predict(
        h5_path, model, device
    )

    # map the voxel to the xyz
    preds_on_vertex, probs_on_vertex = self._map_voxel_to_vertices(
        voxel_pred=pred,
        voxel_probs=probs,
        voxel_mask=mask,
        regular_mesh=regular_mesh,
    )

    # add attributes and save mesh
    vertex_attributes = (
        ("vertex_pred", preds_on_vertex),
        ("vertex_predprobs", probs_on_vertex),
    )
    for attribute_name, values in vertex_attributes:
        regular_mesh.add_attribute(attribute_name)
        regular_mesh.set_attribute(attribute_name, values)
    self.__save_mesh(output_path, regular_mesh)
def _check_config(self):
if self.radius_of_interest is None:
raise ValueError("radius_of_interest is not provided")
elif self.spacing is None:
raise ValueError("spacing is not provided")
else:
assert self.radius_of_interest / self.spacing < 100, (
"The r/spacing should be less than 100 to avoid computer freezing."
)
def _check_ligand(self):
if self.ligand_name is None:
print("ligand_name is not provided")
else:
if (
self.ligand_name
not in self.universe.select_atoms("not protein").residues.resnames
):
raise ValueError(f"{self.ligand_name} is not in the trajectory")
else:
print(f"{self.ligand_name} is in the trajectory")
def _get_attribute(self):
return {
each_object: type(getattr(self, each_object))
for each_object in self.__dict__.keys()
}
def _voxelisation(self, mesh: pymesh.Mesh, with_label: bool = True):
    """Voxelise the vertices of interest of a mesh.

    Arguments:
        mesh: the mesh providing the `vertex_charge`, `vertex_hbond`,
            `vertex_hphob` and `vertex_interest` attributes, plus
            `vertex_label` when `with_label` is True.
        with_label: whether the `vertex_label` attribute is voxelised too.

    Returns:
        tuple:
            - the first `channel` entries of the voxel grid, where `channel`
              is the number of feature columns
            - the grid entry at index `channel` if `with_label` else None
    """
    # make sure the pocket center is available
    if not hasattr(self, "pocket_center"):
        self.get_pocket_center()
    # express the vertex coordinates relative to the pocket center
    centred_vertices = mesh.vertices - self.pocket_center
    # per-vertex features stored in the mesh file
    charge = mesh.get_attribute("vertex_charge")
    hbond = mesh.get_attribute("vertex_hbond")
    hphob = mesh.get_attribute("vertex_hphob")
    interest = mesh.get_attribute("vertex_interest")
    feature_table = np.column_stack((charge, hbond, hphob, interest))
    n_channels = feature_table.shape[1]
    # stack coordinates, features and (optionally) labels column-wise
    if with_label:
        vertex_label = mesh.get_attribute("vertex_label")
        columns = (centred_vertices, feature_table, vertex_label)
    else:
        columns = (centred_vertices, feature_table)
    site = np.column_stack(columns)
    # keep only the vertices flagged as being of interest
    site = site[(interest == 1).astype(bool)]
    # voxelisation
    grid = site_voxelization(
        site=site,
        r=self.radius_of_interest,
        spacing=self.spacing,
        shape=False,
        pass_voxel_coord=False,
        log_info=bool(self.warning_check),
    )
    feature_grid = grid[0:n_channels, :, :, :]
    if not with_label:
        return feature_grid, None
    return feature_grid, grid[n_channels, :, :, :]
def _fragment_universe(self):
    """Replace `self.universe` with its fragmented version.

    The fragment indices are read from `self.label_fragment_info` in
    natural-sorted key order, leaving out the "0" entry.
    """
    fragment_info = self.label_fragment_info
    label_keys = natsorted(fragment_info.keys())
    # ignore the background label
    label_keys.remove("0")
    fragments_per_label = []
    for label_key in label_keys:
        fragments_per_label.append(fragment_info[label_key]["fragments_idx"])
    self.universe = fragmentation_from_universe(
        universe=self.universe,
        ligand_name=self.ligand_name,
        mol_f_idx=fragments_per_label,
        using_type="fragment_idx",
    )
def _map_voxel_to_vertices(
    self,
    voxel_pred: np.array,
    voxel_probs: np.array,
    voxel_mask: np.array,
    regular_mesh: pymesh.Mesh,
):
    """Transfer voxel predictions and probabilities onto the mesh vertices.

    Each vertex whose `vertex_interest` value equals 1 receives the values of
    the closest voxel point; every other vertex receives 0 for both.

    Arguments:
        voxel_pred: the voxel-wise predictions.
        voxel_probs: the voxel-wise probabilities.
        voxel_mask: the mask passed on to `map_voxel_to_xyz`.
        regular_mesh: the mesh providing the vertices and `vertex_interest`.

    Returns:
        tuple:
            - `preds_on_vertex`: one prediction per vertex
            - `probs_on_vertex`: one probability per vertex (float array)
    """
    # map voxel to xyz: columns 0-2 are coordinates, 3 is pred, 4 is probs
    stacked = np.vstack([voxel_pred, voxel_probs])
    xyz_pred_probs = map_voxel_to_xyz(
        data=stacked,
        r=self.radius_of_interest,
        spacing=self.spacing,
        pocket_center=self.pocket_center,
        mask=voxel_mask,
    )
    point_xyz = xyz_pred_probs[:, 0:3]
    point_pred = xyz_pred_probs[:, 3]
    point_probs = xyz_pred_probs[:, 4]
    # map to vertices
    vertices = regular_mesh.vertices
    interest_flags = regular_mesh.get_attribute("vertex_interest")
    vertex_preds = []
    vertex_probs = []
    for vertex_idx, flag in enumerate(interest_flags):
        if not (flag == 1):
            # vertices outside the interest region get a zero entry
            vertex_preds.append(0)
            vertex_probs.append(0)
            continue
        # find the closest voxel point to this vertex
        offsets = point_xyz - vertices[vertex_idx]
        nearest = np.argmin(np.linalg.norm(offsets, axis=1))
        vertex_preds.append(point_pred[nearest])
        vertex_probs.append(point_probs[nearest])
    return np.array(vertex_preds), np.array(vertex_probs, dtype=float)
# basic functions (without using self.variables)
@staticmethod
def __list2str(list_input: list, sep: str = " "):
return sep.join([str(each) for each in list_input])
@staticmethod
def __resid_for_selection(resid_list: list, with_segid: bool = False):
if with_segid:
return " or ".join(
[
f"(segid {each.split('_')[0]} and resid {each.split('_')[1]})"
for each in resid_list
]
)
else:
return " or ".join([f"resid {each}" for each in resid_list])
@staticmethod
def __save_mesh(ply_path: str | Path, mesh: pymesh.Mesh):
    """Save `mesh` to `ply_path` together with its extra attributes.

    Arguments:
        ply_path: str, the path of the file to write.
        mesh: the mesh to save.
    """
    # attribute names that are not passed on as extra attributes
    skipped_names = ("vertex_x", "vertex_y", "vertex_z", "face_vertex_indices")
    extra_attributes = [
        name for name in mesh.get_attribute_names() if name not in skipped_names
    ]
    pymesh.save_mesh(
        ply_path,
        mesh,
        *extra_attributes,
        use_float=True,
        ascii=True,
    )
@staticmethod
def __generate_label_for_each_vertex_with_fragment(
mesh_coords: np.array,
ligand_coords: np.array,
ligand_heavy_atoms_idxs: list[int],
fragidx_label_dict: dict,
distance_cutoff: float = 5.0,
):
labels = []
for vertex in mesh_coords:
# compute the distance between each vertex and the heavy atoms to
# make sure it is in the threshold
distances = np.linalg.norm(ligand_coords - vertex, axis=1)
distances_heavy_atoms = distances[ligand_heavy_atoms_idxs]
if np.min(distances_heavy_atoms) <= distance_cutoff:
min_distance_idx = ligand_heavy_atoms_idxs[
np.argmin(distances_heavy_atoms)
]
# In case of multiple occurrences of the minimum values,
# the indices corresponding to the first occurrence are returned.
labels.append(fragidx_label_dict[min_distance_idx])
else:
labels.append(0)
return np.array(labels, dtype=float)
@staticmethod
def __predict(
    h5_path: str | Path,
    model,
    device: str = "cpu",
):
    """Run the model on the data stored in an H5 file.

    Arguments:
        h5_path: str, the path to the H5 file
        model: the model used for the inference
        device: str, the device to run the model on

    Returns:
        tuple:
            - `data`: the preprocessed input data
            - `_label_if_exist`: the preprocessed labels, as returned by
              `preprocess_h5`
            - `outputs`: the outputs from the model
            - `pred`: the predictions with a leading axis added
            - `probs`: the probabilities with a leading axis added
            - `mask`: the mask
    """
    # load the raw content and preprocess it, leaving the mask channel out
    raw_data, raw_label = load_h5(h5_path)
    mask_channel = get_mask_col_idx(raw_data)
    data, label_if_exist = preprocess_h5(
        raw_data, raw_label, except_channels=[mask_channel]
    )
    # data: [C, H, W, D]
    outputs, voxel_pred, voxel_probs, mask = inference(
        model, data, device, return_mask=True
    )
    # outputs: [C, H, W, D]; voxel_pred, voxel_probs, mask: [H, W, D]
    pred_with_channel = np.expand_dims(voxel_pred, axis=0)
    probs_with_channel = np.expand_dims(voxel_probs, axis=0)
    # pred_with_channel, probs_with_channel: [1, H, W, D]
    return data, label_if_exist, outputs, pred_with_channel, probs_with_channel, mask