Skip to content

Commit

Permalink
include check for failed patterns in init
Browse files Browse the repository at this point in the history
  • Loading branch information
frederik-sandfort1 committed Oct 7, 2024
1 parent 1f8dc1c commit d345191
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
5 changes: 5 additions & 0 deletions molpipeline/abstract_pipeline_elements/mol2mol/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,11 @@ def patterns_mol_dict(self, patterns: Sequence[str]) -> None:
List of patterns.
"""
self._patterns_mol_dict = {pat: self._pattern_to_mol(pat) for pat in patterns}
failed_patterns = [
pat for pat, mol in self._patterns_mol_dict.items() if not mol
]
if failed_patterns:
raise ValueError("Invalid pattern(s): " + ", ".join(failed_patterns))

@abc.abstractmethod
def _pattern_to_mol(self, pattern: str) -> RDKitMol:
Expand Down
16 changes: 16 additions & 0 deletions tests/test_elements/test_mol2mol/test_mol2mol_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,22 @@ def test_smarts_smiles_filter(self) -> None:
filtered_smiles = pipeline.fit_transform(SMILES_LIST)
self.assertEqual(filtered_smiles, test_params["result"])

def test_smarts_smiles_filter_wrong_pattern(self) -> None:
"""Test if molecules are filtered correctly by allowed SMARTS patterns."""
smarts_pats: dict[str, IntOrIntCountRange] = {
"cIOnk": (4, None),
"cC": 1,
}
with self.assertRaises(ValueError):
SmartsFilter(smarts_pats)

smiles_pats: dict[str, IntOrIntCountRange] = {
"cC": (1, None),
"Cl": 1,
}
with self.assertRaises(ValueError):
SmilesFilter(smiles_pats)

def test_smarts_filter_parallel(self) -> None:
"""Test if molecules are filtered correctly by allowed SMARTS patterns in parallel."""
smarts_pats: dict[str, IntOrIntCountRange] = {
Expand Down

0 comments on commit d345191

Please sign in to comment.