diff --git a/ipsuite/__init__.pyi b/ipsuite/__init__.pyi index 20c269fd..0f730805 100644 --- a/ipsuite/__init__.pyi +++ b/ipsuite/__init__.pyi @@ -2,6 +2,7 @@ from . import base # Analysis from .analysis import ( + AllowedStructuresFilter, AnalyseDensity, AnalyseGlobalForceSensitivity, AnalyseSingleForceSensitivity, @@ -170,6 +171,7 @@ __all__ = [ "TemperatureCheck", "AnalyseDensity", "CollectMDSteps", + "AllowedStructuresFilter", # Calculators "CP2KSinglePoint", "ASEGeoOpt", diff --git a/ipsuite/analysis/__init__.py b/ipsuite/analysis/__init__.py index 28a74dd8..cf90ee96 100644 --- a/ipsuite/analysis/__init__.py +++ b/ipsuite/analysis/__init__.py @@ -27,6 +27,7 @@ TemperatureCheck, ThresholdCheck, ) +from ipsuite.analysis.molecules import AllowedStructuresFilter from ipsuite.analysis.sensitivity import ( AnalyseGlobalForceSensitivity, AnalyseSingleForceSensitivity, @@ -63,4 +64,5 @@ "EnergyUncertaintyHistogram", "AnalyseDensity", "CollectMDSteps", + "AllowedStructuresFilter", ] diff --git a/ipsuite/analysis/molecules.py b/ipsuite/analysis/molecules.py new file mode 100644 index 00000000..c9b018ae --- /dev/null +++ b/ipsuite/analysis/molecules.py @@ -0,0 +1,59 @@ +import ase +import rdkit2ase +import tqdm +import zntrack + +from ipsuite import base +from ipsuite.geometry import BarycenterMapping + + +class AllowedStructuresFilter(base.IPSNode): + """Search a given dataset for outliers. + + Iterates all structures in the dataset, uses covalent radii to determine + the atoms in each molecule, and checks if the molecule is allowed. + + Attributes + ---------- + data : list[ase.Atoms] + The dataset to search. + molecules : list[ase.Atoms], optional + The molecules that are allowed. + smiles : list[str], optional + The SMILES strings of the allowed molecules. + """ + + data: list[ase.Atoms] = zntrack.deps() + molecules: list[ase.Atoms] = zntrack.deps(default_factory=list) + smiles: list[str] = zntrack.params(default_factory=list) + fail: bool = zntrack.params(False) + + outliers: list[int] = zntrack.outs() + + def run(self): + molecules = self.molecules + [rdkit2ase.smiles2atoms(s) for s in self.smiles] + mapping = BarycenterMapping() + self.outliers = [] + for idx, atoms in enumerate(tqdm.tqdm(self.data)): + _, mols = mapping.forward_mapping(atoms) + for mol in mols: + # check if the atomic numbers are the same + if sorted(mol.get_atomic_numbers()) in [ + sorted(m.get_atomic_numbers()) for m in molecules + ]: + continue + if self.fail: + raise ValueError(f"Outlier found at index {idx} for molecule {mol}") + else: + print(f"Outlier found at index {idx} for molecule {mol}") + self.outliers.append(idx) + + @property + def excluded_frames(self) -> list[ase.Atoms]: + return [self.data[idx] for idx in self.outliers] + + @property + def frames(self) -> list[ase.Atoms]: + return [ + self.data[idx] for idx in range(len(self.data)) if idx not in self.outliers + ]