Skip to content

Commit

Permalink
Search trajectory for outliers (#381)
Browse files Browse the repository at this point in the history
* apply

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add fail check, fix non initialized issue

* rename to `AllowedMoleculeFilter`

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* pass `frames` and rename to `AllowedStructuresFilter`

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PythonFZ and pre-commit-ci[bot] authored Dec 17, 2024
1 parent 536a54b commit 9a7519e
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 0 deletions.
2 changes: 2 additions & 0 deletions ipsuite/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ from . import base

# Analysis
from .analysis import (
AllowedStructuresFilter,
AnalyseDensity,
AnalyseGlobalForceSensitivity,
AnalyseSingleForceSensitivity,
Expand Down Expand Up @@ -170,6 +171,7 @@ __all__ = [
"TemperatureCheck",
"AnalyseDensity",
"CollectMDSteps",
"AllowedStructuresFilter",
# Calculators
"CP2KSinglePoint",
"ASEGeoOpt",
Expand Down
2 changes: 2 additions & 0 deletions ipsuite/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
TemperatureCheck,
ThresholdCheck,
)
from ipsuite.analysis.molecules import AllowedStructuresFilter
from ipsuite.analysis.sensitivity import (
AnalyseGlobalForceSensitivity,
AnalyseSingleForceSensitivity,
Expand Down Expand Up @@ -63,4 +64,5 @@
"EnergyUncertaintyHistogram",
"AnalyseDensity",
"CollectMDSteps",
"AllowedStructuresFilter",
]
59 changes: 59 additions & 0 deletions ipsuite/analysis/molecules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import ase
import rdkit2ase
import tqdm
import zntrack

from ipsuite import base
from ipsuite.geometry import BarycenterMapping


class AllowedStructuresFilter(base.IPSNode):
    """Search a given dataset for outliers.

    Iterates all structures in the dataset, uses covalent radii to determine
    the atoms in each molecule, and checks if the molecule is allowed.

    A molecule is considered allowed when its sorted atomic numbers are equal
    to the sorted atomic numbers of one of the reference molecules, i.e. only
    the composition is compared, not the connectivity or geometry.

    Attributes
    ----------
    data : list[ase.Atoms]
        The dataset to search.
    molecules : list[ase.Atoms], optional
        The molecules that are allowed.
    smiles : list[str], optional
        The SMILES strings of the allowed molecules.
    fail : bool, optional
        If True, raise a ``ValueError`` as soon as the first molecule that
        is not allowed is found. If False (default), report it and continue.
    outliers : list[int]
        Indices into ``data`` of the frames that contain at least one
        molecule that is not allowed. Every frame index appears at most once.
    """

    data: list[ase.Atoms] = zntrack.deps()
    molecules: list[ase.Atoms] = zntrack.deps(default_factory=list)
    smiles: list[str] = zntrack.params(default_factory=list)
    fail: bool = zntrack.params(False)

    outliers: list[int] = zntrack.outs()

    @staticmethod
    def _composition(atoms: ase.Atoms) -> tuple[int, ...]:
        """Return the sorted atomic numbers of ``atoms`` as a hashable tuple."""
        return tuple(sorted(int(z) for z in atoms.get_atomic_numbers()))

    def run(self):
        """Collect the indices of all frames that contain a disallowed molecule.

        Raises
        ------
        ValueError
            If ``fail`` is True and a molecule that is not allowed is found.
        """
        molecules = list(self.molecules)
        molecules += [rdkit2ase.smiles2atoms(s) for s in self.smiles]
        # The reference compositions do not change while scanning the data,
        # so build them once and use a set for constant-time membership tests.
        allowed = {self._composition(m) for m in molecules}

        mapping = BarycenterMapping()
        self.outliers = []
        for idx, atoms in enumerate(tqdm.tqdm(self.data)):
            _, mols = mapping.forward_mapping(atoms)
            frame_is_outlier = False
            for mol in mols:
                # check if the atomic numbers are the same
                if self._composition(mol) in allowed:
                    continue
                if self.fail:
                    raise ValueError(f"Outlier found at index {idx} for molecule {mol}")
                print(f"Outlier found at index {idx} for molecule {mol}")
                frame_is_outlier = True
            # Record the frame once, even if several of its molecules are
            # outliers; otherwise `excluded_frames` would return duplicates.
            if frame_is_outlier:
                self.outliers.append(idx)

    @property
    def excluded_frames(self) -> list[ase.Atoms]:
        """Frames of ``data`` that contain at least one disallowed molecule."""
        return [self.data[idx] for idx in self.outliers]

    @property
    def frames(self) -> list[ase.Atoms]:
        """Frames of ``data`` that contain only allowed molecules."""
        excluded = set(self.outliers)
        return [atoms for idx, atoms in enumerate(self.data) if idx not in excluded]

0 comments on commit 9a7519e

Please sign in to comment.