diff --git a/posebusters/__init__.py b/posebusters/__init__.py index 2385bfc..fad5467 100644 --- a/posebusters/__init__.py +++ b/posebusters/__init__.py @@ -24,4 +24,4 @@ "check_volume_overlap", ] -__version__ = "0.2.7" +__version__ = "0.2.8" diff --git a/posebusters/modules/rmsd.py b/posebusters/modules/rmsd.py index c9eb368..24e0d99 100644 --- a/posebusters/modules/rmsd.py +++ b/posebusters/modules/rmsd.py @@ -25,7 +25,9 @@ tautomer_enumerator.SetRemoveSp3Stereo(True) -def check_rmsd(mol_pred: Mol, mol_true: Mol, rmsd_threshold: float = 2.0) -> dict[str, dict[str, bool | float]]: +def check_rmsd( + mol_pred: Mol, mol_true: Mol, rmsd_threshold: float = 2.0, heavy_only: bool = True, choose_by: str = "rmsd" +) -> dict[str, dict[str, bool | float]]: """Calculate RMSD and related metrics between predicted molecule and closest ground truth molecule. Args: @@ -33,6 +35,8 @@ def check_rmsd(mol_pred: Mol, mol_true: Mol, rmsd_threshold: float = 2.0) -> dic mol_true: Ground truth molecule (crystal ligand) with at least one conformer. If multiple conformers are present, the lowest RMSD will be reported. rmsd_threshold: Threshold in angstrom for reporting whether RMSD is within threshold. Defaults to 2.0. + heavy_only: Whether to only consider heavy atoms for RMSD calculation. Defaults to True. + choose_by: Metric to choose which mol_true conformation to compare to. Defaults to "rmsd". Returns: PoseBusters results dictionary. @@ -43,16 +47,34 @@ def check_rmsd(mol_pred: Mol, mol_true: Mol, rmsd_threshold: float = 2.0) -> dic assert num_conf > 0, "Crystal ligand needs at least one conformer." assert mol_pred.GetNumConformers() == 1, "Docked ligand should only have one conformer." - rmsds = [robust_rmsd(mol_true, mol_pred, conf_id_probe=i) for i in range(num_conf)] - kabsch_rmsd = [robust_rmsd(mol_true, mol_pred, conf_id_probe=i, kabsch=True) for i in range(num_conf)] - - i = np.argmin(np.nan_to_num(rmsds, nan=np.inf)) + rmsds = [robust_rmsd(mol_true, mol_pred, conf_id_probe=i, heavy_only=heavy_only) for i in range(num_conf)] + kabsch_rmsds = [ + robust_rmsd(mol_true, mol_pred, conf_id_probe=i, kabsch=True, heavy_only=heavy_only) for i in range(num_conf) + ] + intercentroids = [ + intercentroid(mol_true, mol_pred, conf_id_probe=i, heavy_only=heavy_only) for i in range(num_conf) + ] + + if choose_by == "rmsd": + i = np.argmin(np.nan_to_num(rmsds, nan=np.inf)) + elif choose_by == "kabsch_rmsd": + i = np.argmin(np.nan_to_num(kabsch_rmsds, nan=np.inf)) + elif choose_by == "centroid_distance": + i = np.argmin(np.nan_to_num(intercentroids, nan=np.inf)) + else: + raise ValueError(f"Invalid value {choose_by} for choose_by. Use 'rmsd', 'kabsch_rmsd', 'centroid_distance'.") rmsd = rmsds[i] - kabsch_rmsd = kabsch_rmsd[i] + kabsch_rmsd = kabsch_rmsds[i] + centroid_dist = intercentroids[i] rmsd_within_threshold = rmsd <= rmsd_threshold - results = {"rmsd": rmsd, "kabsch_rmsd": kabsch_rmsd, "rmsd_within_threshold": rmsd_within_threshold} + results = { + "rmsd": rmsd, + "kabsch_rmsd": kabsch_rmsd, + "centroid_distance": centroid_dist, + "rmsd_within_threshold": rmsd_within_threshold, + } return {"results": results} @@ -136,3 +158,17 @@ def _rmsd(mol_probe: Mol, mol_ref: Mol, conf_id_probe: int, conf_id_ref: int, ka if kabsch is True: return GetBestRMS(prbMol=mol_probe, refMol=mol_ref, prbId=conf_id_probe, refId=conf_id_ref, **params) return CalcRMS(prbMol=mol_probe, refMol=mol_ref, prbId=conf_id_probe, refId=conf_id_ref, **params) + + +def intercentroid( + mol_probe: Mol, mol_ref: Mol, conf_id_probe: int = -1, conf_id_ref: int = -1, heavy_only: bool = True +) -> float: + """Distance between centroids of two molecules.""" + if heavy_only: + mol_probe = RemoveHs(mol_probe) + mol_ref = RemoveHs(mol_ref) + + centroid_probe = mol_probe.GetConformer(conf_id_probe).GetPositions().mean(axis=0) + centroid_ref = mol_ref.GetConformer(conf_id_ref).GetPositions().mean(axis=0) + + return float(np.linalg.norm(centroid_probe - centroid_ref))