diff --git a/molpipeline/mol2mol/filter.py b/molpipeline/mol2mol/filter.py index 5902eee7..894758ba 100644 --- a/molpipeline/mol2mol/filter.py +++ b/molpipeline/mol2mol/filter.py @@ -24,7 +24,10 @@ from molpipeline.abstract_pipeline_elements.mol2mol import ( BasePatternsFilter as _BasePatternsFilter, ) -from molpipeline.abstract_pipeline_elements.mol2mol.filter import _within_boundaries +from molpipeline.abstract_pipeline_elements.mol2mol.filter import ( + FilterModeType, + _within_boundaries, +) from molpipeline.utils.molpipeline_types import ( FloatCountRange, IntCountRange, @@ -290,8 +293,8 @@ class ComplexFilter(_BaseKeepMatchesFilter): Attributes ---------- - filter_elements: Sequence[_MolToMolPipelineElement] - MolToMol elements to use as filters. + pipeline_filter_elements: Sequence[tuple[str, _MolToMolPipelineElement]] + pairs of unique names and MolToMol elements to use as filters. [...] Notes @@ -303,28 +306,123 @@ class ComplexFilter(_BaseKeepMatchesFilter): - mode = "all" & keep_matches = False: Must not match all filter elements. """ - _filter_elements: Mapping[_MolToMolPipelineElement, tuple[int, Optional[int]]] + _filter_elements: Mapping[str, tuple[int, Optional[int]]] + + def __init__( + self, + pipeline_filter_elements: Sequence[tuple[str, _MolToMolPipelineElement]], + keep_matches: bool = True, + mode: FilterModeType = "any", + name: str | None = None, + n_jobs: int = 1, + uuid: str | None = None, + ) -> None: + """Initialize ComplexFilter. + + Parameters + ---------- + pipeline_filter_elements: Sequence[tuple[str, _MolToMolPipelineElement]] + Filter elements to use. + keep_matches: bool, optional (default: True) + If True, keep matches, else remove matches. + mode: FilterModeType, optional (default: "any") + Mode to filter by. + name: str, optional (default: None) + Name of the pipeline element. + n_jobs: int, optional (default: 1) + Number of parallel jobs to use. 
+ uuid: str, optional (default: None) + Unique identifier of the pipeline element. + """ + self.pipeline_filter_elements = pipeline_filter_elements + super().__init__( + filter_elements=pipeline_filter_elements, + keep_matches=keep_matches, + mode=mode, + name=name, + n_jobs=n_jobs, + uuid=uuid, + ) + + def get_params(self, deep: bool = True) -> dict[str, Any]: + """Get parameters of ComplexFilter. + + Parameters + ---------- + deep: bool, optional (default: True) + If True, return the parameters of all subobjects that are PipelineElements. + + Returns + ------- + dict[str, Any] + Parameters of ComplexFilter. + """ + params = super().get_params(deep) + params.pop("filter_elements") + params["pipeline_filter_elements"] = self.pipeline_filter_elements + if deep: + for name, element in self.pipeline_filter_elements: + deep_items = element.get_params().items() + params.update( + ("pipeline_filter_elements" + "__" + name + "__" + key, val) + for key, val in deep_items + ) + return params + + def set_params(self, **parameters: Any) -> Self: + """Set parameters of ComplexFilter. + + Parameters + ---------- + parameters: Any + Parameters to set. + + Returns + ------- + Self + Self. 
+ """ + parameter_copy = dict(parameters) + if "pipeline_filter_elements" in parameter_copy: + self.pipeline_filter_elements = parameter_copy.pop( + "pipeline_filter_elements" + ) + self.filter_elements = self.pipeline_filter_elements # type: ignore + for key in parameters: + if key.startswith("pipeline_filter_elements__"): + value = parameter_copy.pop(key) + element_name, element_key = key.split("__")[1:] + for name, element in self.pipeline_filter_elements: + if name == element_name: + element.set_params(**{element_key: value}) + super().set_params(**parameter_copy) + return self @property def filter_elements( self, - ) -> Mapping[_MolToMolPipelineElement, tuple[int, Optional[int]]]: + ) -> Mapping[str, tuple[int, Optional[int]]]: """Get filter elements.""" return self._filter_elements @filter_elements.setter def filter_elements( self, - filter_elements: Sequence[_MolToMolPipelineElement], + filter_elements: Sequence[tuple[str, _MolToMolPipelineElement]], ) -> None: """Set filter elements. Parameters ---------- - filter_elements: dict[_MolToMolPipelineElement, tuple[int, Optional[int]]] + filter_elements: Sequence[tuple[str, _MolToMolPipelineElement]] Filter elements to set. """ - self._filter_elements = {element: (1, None) for element in filter_elements} + self.filter_elements_dict = dict(filter_elements) + if not len(self.filter_elements_dict) == len(filter_elements): + raise ValueError("Filter elements names need to be unique.") + self._filter_elements = { + element_name: (1, None) for element_name, _element in filter_elements + } def _calculate_single_element_value( self, filter_element: Any, value: RDKitMol @@ -343,7 +441,7 @@ def _calculate_single_element_value( int Filter match. 
""" - mol = filter_element.pretransform_single(value) + mol = self.filter_elements_dict[filter_element].pretransform_single(value) if isinstance(mol, InvalidInstance): return 0 return 1 diff --git a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py index ca0f9dd5..2c100b62 100644 --- a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py +++ b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py @@ -14,6 +14,7 @@ SmartsFilter, SmilesFilter, ) +from molpipeline.utils.json_operations import recursive_from_json, recursive_to_json from molpipeline.utils.molpipeline_types import FloatCountRange, IntOrIntCountRange # pylint: disable=duplicate-code # test case molecules are allowed to be duplicated @@ -32,8 +33,8 @@ ] -class MolFilterTest(unittest.TestCase): - """Unittest for MolFilter, which invalidate molecules based on criteria defined in the respective filter.""" +class ElementFilterTest(unittest.TestCase): + """Unittest for ElementFilter.""" def test_element_filter(self) -> None: """Test if molecules are filtered correctly by allowed chemical elements.""" @@ -87,12 +88,22 @@ def test_element_filter(self) -> None: filtered_smiles = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles, test_params["result"]) - def test_complex_filter(self) -> None: - """Test if molecules are filtered correctly by allowed chemical elements.""" + +class ComplexFilterTest(unittest.TestCase): + """Unittest for ComplexFilter.""" + + @staticmethod + def _create_pipeline(): + """Create a pipeline with a complex filter.""" element_filter_1 = ElementFilter({6: 6, 1: 6}) element_filter_2 = ElementFilter({6: 6, 1: 5, 17: 1}) - multi_element_filter = ComplexFilter((element_filter_1, element_filter_2)) + multi_element_filter = ComplexFilter( + ( + ("element_filter_1", element_filter_1), + ("element_filter_2", element_filter_2), + ) + ) pipeline = Pipeline( [ @@ -102,13 +113,59 @@ def test_complex_filter(self) -> 
None: ("ErrorFilter", ErrorFilter()), ], ) + return pipeline + + def test_complex_filter(self) -> None: + """Test if molecules are filtered correctly by allowed chemical elements.""" + pipeline = ComplexFilterTest._create_pipeline() + + test_params_list_with_results = [ + { + "params": {}, + "result": [SMILES_BENZENE, SMILES_CHLOROBENZENE], + }, + { + "params": {"MultiElementFilter__mode": "all"}, + "result": [], + }, + { + "params": { + "MultiElementFilter__mode": "any", + "MultiElementFilter__pipeline_filter_elements__element_filter_1__add_hydrogens": False, + }, + "result": [SMILES_CHLOROBENZENE], + }, + ] + + for test_params in test_params_list_with_results: + pipeline.set_params(**test_params["params"]) + filtered_smiles = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(filtered_smiles, test_params["result"]) + + def test_json_serialization(self) -> None: + """Test if complex filter can be serialized and deserialized.""" + pipeline = ComplexFilterTest._create_pipeline() + json_object = recursive_to_json(pipeline) + newpipeline = recursive_from_json(json_object) + self.assertEqual(json_object, recursive_to_json(newpipeline)) + + pipeline_result = pipeline.fit_transform(SMILES_LIST) + newpipeline_result = newpipeline.fit_transform(SMILES_LIST) + self.assertEqual(pipeline_result, newpipeline_result) + + def test_complex_filter_non_unique_names(self) -> None: + """Test if a ValueError is raised for non-unique filter element names.""" + element_filter_1 = ElementFilter({6: 6, 1: 6}) + element_filter_2 = ElementFilter({6: 6, 1: 5, 17: 1}) + + with self.assertRaises(ValueError): + ComplexFilter( + (("filter_1", element_filter_1), ("filter_1", element_filter_2)) + ) - filtered_smiles = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(filtered_smiles, [SMILES_BENZENE, SMILES_CHLOROBENZENE]) - pipeline.set_params(MultiElementFilter__mode="all") - filtered_smiles_2 = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(filtered_smiles_2, []) +class 
SmartsSmilesFilterTest(unittest.TestCase): + """Unittest for SmartsFilter and SmilesFilter.""" def test_smarts_smiles_filter(self) -> None: """Test if molecules are filtered correctly by allowed SMARTS patterns.""" @@ -214,6 +271,10 @@ def test_smarts_filter_parallel(self) -> None: filtered_smiles = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles, [SMILES_CHLOROBENZENE]) + +class RDKitDescriptorsFilterTest(unittest.TestCase): + """Unittest for RDKitDescriptorsFilter.""" + def test_descriptor_filter(self) -> None: """Test if molecules are filtered correctly by allowed descriptors.""" descriptors: dict[str, FloatCountRange] = { @@ -288,6 +349,10 @@ def test_descriptor_filter(self) -> None: filtered_smiles = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles, test_params["result"]) + +class MixtureFilterTest(unittest.TestCase): + """Unittest for MixtureFilter.""" + def test_invalidate_mixtures(self) -> None: """Test if mixtures are correctly invalidated.""" mol_list = ["CCC.CC.C", "c1ccccc1.[Na+].[Cl-]", "c1ccccc1"] @@ -311,6 +376,10 @@ def test_invalidate_mixtures(self) -> None: mols_processed = pipeline.fit_transform(mol_list) self.assertEqual(expected_invalidated_mol_list, mols_processed) + +class InorganicsFilterTest(unittest.TestCase): + """Unittest for InorganicsFilter.""" + def test_inorganic_filter(self) -> None: """Test if molecules are filtered correctly by allowed chemical elements.""" smiles2mol = SmilesToMol()