diff --git a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py index 4142efde..b19cdb5b 100644 --- a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py +++ b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py @@ -318,6 +318,11 @@ def patterns_mol_dict(self, patterns: Sequence[str]) -> None: List of patterns. """ self._patterns_mol_dict = {pat: self._pattern_to_mol(pat) for pat in patterns} + failed_patterns = [ + pat for pat, mol in self._patterns_mol_dict.items() if not mol + ] + if failed_patterns: + raise ValueError("Invalid pattern(s): " + ", ".join(failed_patterns)) @abc.abstractmethod def _pattern_to_mol(self, pattern: str) -> RDKitMol: diff --git a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py index d814c4e9..ca0f9dd5 100644 --- a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py +++ b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py @@ -173,6 +173,22 @@ def test_smarts_smiles_filter(self) -> None: filtered_smiles = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles, test_params["result"]) + def test_smarts_smiles_filter_wrong_pattern(self) -> None: + """Test if molecules are filtered correctly by allowed SMARTS patterns.""" + smarts_pats: dict[str, IntOrIntCountRange] = { + "cIOnk": (4, None), + "cC": 1, + } + with self.assertRaises(ValueError): + SmartsFilter(smarts_pats) + + smiles_pats: dict[str, IntOrIntCountRange] = { + "cC": (1, None), + "Cl": 1, + } + with self.assertRaises(ValueError): + SmilesFilter(smiles_pats) + def test_smarts_filter_parallel(self) -> None: """Test if molecules are filtered correctly by allowed SMARTS patterns in parallel.""" smarts_pats: dict[str, IntOrIntCountRange] = {