Skip to content

Commit

Permalink
final review
Browse files Browse the repository at this point in the history
  • Loading branch information
frederik-sandfort1 committed Oct 7, 2024
1 parent d345191 commit 08d58f4
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 19 deletions.
116 changes: 107 additions & 9 deletions molpipeline/mol2mol/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
from molpipeline.abstract_pipeline_elements.mol2mol import (
BasePatternsFilter as _BasePatternsFilter,
)
from molpipeline.abstract_pipeline_elements.mol2mol.filter import _within_boundaries
from molpipeline.abstract_pipeline_elements.mol2mol.filter import (
FilterModeType,
_within_boundaries,
)
from molpipeline.utils.molpipeline_types import (
FloatCountRange,
IntCountRange,
Expand Down Expand Up @@ -290,8 +293,8 @@ class ComplexFilter(_BaseKeepMatchesFilter):
Attributes
----------
filter_elements: Sequence[_MolToMolPipelineElement]
MolToMol elements to use as filters.
pipeline_filter_elements: Sequence[tuple[str, _MolToMolPipelineElement]]
pairs of unique names and MolToMol elements to use as filters.
[...]
Notes
Expand All @@ -303,28 +306,123 @@ class ComplexFilter(_BaseKeepMatchesFilter):
- mode = "all" & keep_matches = False: Must not match all filter elements.
"""

_filter_elements: Mapping[_MolToMolPipelineElement, tuple[int, Optional[int]]]
_filter_elements: Mapping[str, tuple[int, Optional[int]]]

def __init__(
    self,
    pipeline_filter_elements: Sequence[tuple[str, _MolToMolPipelineElement]],
    keep_matches: bool = True,
    mode: FilterModeType = "any",
    name: str | None = None,
    n_jobs: int = 1,
    uuid: str | None = None,
) -> None:
    """Create a ComplexFilter from named MolToMol elements.

    Parameters
    ----------
    pipeline_filter_elements: Sequence[tuple[str, _MolToMolPipelineElement]]
        Pairs of a name and the MolToMol element used as a filter.
    keep_matches: bool, optional (default: True)
        If True, keep matches, else remove matches.
    mode: FilterModeType, optional (default: "any")
        Mode to filter by.
    name: str, optional (default: None)
        Name of the pipeline element.
    n_jobs: int, optional (default: 1)
        Number of parallel jobs to use.
    uuid: str, optional (default: None)
        Unique identifier of the pipeline element.
    """
    # Keep the pairs exactly as passed in; get_params reports them under
    # this attribute name.
    self.pipeline_filter_elements = pipeline_filter_elements
    # The base class receives the same pairs as its ``filter_elements``.
    base_kwargs = {
        "filter_elements": pipeline_filter_elements,
        "keep_matches": keep_matches,
        "mode": mode,
        "name": name,
        "n_jobs": n_jobs,
        "uuid": uuid,
    }
    super().__init__(**base_kwargs)

def get_params(self, deep: bool = True) -> dict[str, Any]:
    """Return the parameters of this ComplexFilter.

    The ``filter_elements`` entry reported by the parent class is removed
    and the named pairs are exposed as ``pipeline_filter_elements`` instead.

    Parameters
    ----------
    deep: bool, optional (default: True)
        If True, additionally report the parameters of every filter element
        under ``pipeline_filter_elements__<element name>__<parameter>``.

    Returns
    -------
    dict[str, Any]
        Parameters of ComplexFilter.
    """
    params = super().get_params(deep)
    params.pop("filter_elements")
    params["pipeline_filter_elements"] = self.pipeline_filter_elements
    if not deep:
        return params

    for element_name, element in self.pipeline_filter_elements:
        for param_key, param_value in element.get_params().items():
            nested_key = "__".join(
                ("pipeline_filter_elements", element_name, param_key)
            )
            params[nested_key] = param_value
    return params

def set_params(self, **parameters: Any) -> Self:
    """Set parameters of ComplexFilter.

    Keys of the form ``pipeline_filter_elements__<element name>__<parameter>``
    are forwarded to the filter element with that name. Only the first two
    ``__`` separators are interpreted here, so ``<parameter>`` may itself
    contain ``__`` and is passed on unchanged to the element's own
    ``set_params``. All remaining keys are handed to the parent class.

    Parameters
    ----------
    parameters: Any
        Parameters to set.

    Returns
    -------
    Self
        Self.

    Raises
    ------
    ValueError
        If a nested key lacks the parameter part or names a filter element
        that is not part of ``pipeline_filter_elements``.
    """
    remaining = dict(parameters)
    if "pipeline_filter_elements" in remaining:
        new_elements = remaining.pop("pipeline_filter_elements")
        # Go through the filter_elements setter first: it rejects duplicate
        # names, so a rejected value never overwrites the stored pairs.
        self.filter_elements = new_elements  # type: ignore
        self.pipeline_filter_elements = new_elements

    prefix = "pipeline_filter_elements__"
    elements_by_name = dict(self.pipeline_filter_elements)
    # Collect the nested parameters per element before applying anything, so
    # that a malformed key is reported before any element has been modified.
    nested_params: dict[str, dict[str, Any]] = {}
    for key in parameters:
        if not key.startswith(prefix):
            continue
        value = remaining.pop(key)
        element_name, separator, element_key = key[len(prefix):].partition("__")
        if not separator or not element_key:
            raise ValueError(
                f"Invalid parameter {key!r}: expected the form "
                "'pipeline_filter_elements__<element name>__<parameter>'."
            )
        if element_name not in elements_by_name:
            raise ValueError(
                f"Invalid parameter {key!r}: unknown filter element "
                f"{element_name!r}. Valid names are: {list(elements_by_name)}."
            )
        nested_params.setdefault(element_name, {})[element_key] = value

    for element_name, element_params in nested_params.items():
        elements_by_name[element_name].set_params(**element_params)

    super().set_params(**remaining)
    return self

@property
def filter_elements(self) -> Mapping[str, tuple[int, Optional[int]]]:
    """Get filter elements.

    Returns
    -------
    Mapping[str, tuple[int, Optional[int]]]
        Maps each filter element name to the count range ``(1, None)``
        assigned by the setter.
    """
    return self._filter_elements

@filter_elements.setter
def filter_elements(
    self,
    filter_elements: Sequence[tuple[str, _MolToMolPipelineElement]],
) -> None:
    """Set filter elements.

    Parameters
    ----------
    filter_elements: Sequence[tuple[str, _MolToMolPipelineElement]]
        Pairs of a unique name and the MolToMol element to use as filter.

    Raises
    ------
    ValueError
        If two pairs share the same name. The previously stored elements
        are left untouched in that case.
    """
    named_elements = tuple(filter_elements)
    elements_by_name = dict(named_elements)
    # dict() collapses repeated keys, so a shorter dict means duplicate names.
    # Validate before assigning anything to keep the instance consistent.
    if len(elements_by_name) != len(named_elements):
        raise ValueError("Filter elements names need to be unique.")
    # Name -> element lookup.
    self.filter_elements_dict = elements_by_name
    # Every named element gets the count range (1, None).
    self._filter_elements = {
        element_name: (1, None) for element_name in elements_by_name
    }

def _calculate_single_element_value(
self, filter_element: Any, value: RDKitMol
Expand All @@ -343,7 +441,7 @@ def _calculate_single_element_value(
int
Filter match.
"""
mol = filter_element.pretransform_single(value)
mol = self.filter_elements_dict[filter_element].pretransform_single(value)
if isinstance(mol, InvalidInstance):
return 0
return 1
Expand Down
89 changes: 79 additions & 10 deletions tests/test_elements/test_mol2mol/test_mol2mol_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
SmartsFilter,
SmilesFilter,
)
from molpipeline.utils.json_operations import recursive_from_json, recursive_to_json
from molpipeline.utils.molpipeline_types import FloatCountRange, IntOrIntCountRange

# pylint: disable=duplicate-code # test case molecules are allowed to be duplicated
Expand All @@ -32,8 +33,8 @@
]


class MolFilterTest(unittest.TestCase):
"""Unittest for MolFilter, which invalidate molecules based on criteria defined in the respective filter."""
class ElementFilterTest(unittest.TestCase):
"""Unittest for ElementFilter."""

def test_element_filter(self) -> None:
"""Test if molecules are filtered correctly by allowed chemical elements."""
Expand Down Expand Up @@ -87,12 +88,22 @@ def test_element_filter(self) -> None:
filtered_smiles = pipeline.fit_transform(SMILES_LIST)
self.assertEqual(filtered_smiles, test_params["result"])

def test_complex_filter(self) -> None:
"""Test if molecules are filtered correctly by allowed chemical elements."""

class ComplexFilterTest(unittest.TestCase):
"""Unittest for ComplexFilter."""

@staticmethod
def _create_pipeline():
"""Create a pipeline with a complex filter."""
element_filter_1 = ElementFilter({6: 6, 1: 6})
element_filter_2 = ElementFilter({6: 6, 1: 5, 17: 1})

multi_element_filter = ComplexFilter((element_filter_1, element_filter_2))
multi_element_filter = ComplexFilter(
(
("element_filter_1", element_filter_1),
("element_filter_2", element_filter_2),
)
)

pipeline = Pipeline(
[
Expand All @@ -102,13 +113,59 @@ def test_complex_filter(self) -> None:
("ErrorFilter", ErrorFilter()),
],
)
return pipeline

def test_complex_filter(self) -> None:
    """Test that ComplexFilter output follows its mode and nested parameters.

    The parameter sets are applied one after another to the same pipeline,
    so each case also carries over the settings of the previous ones.
    """
    pipeline = ComplexFilterTest._create_pipeline()

    # (parameters to set, expected SMILES after filtering)
    scenarios = [
        ({}, [SMILES_BENZENE, SMILES_CHLOROBENZENE]),
        ({"MultiElementFilter__mode": "all"}, []),
        (
            {
                "MultiElementFilter__mode": "any",
                "MultiElementFilter__pipeline_filter_elements__element_filter_1__add_hydrogens": False,
            },
            [SMILES_CHLOROBENZENE],
        ),
    ]

    for params, expected_smiles in scenarios:
        pipeline.set_params(**params)
        observed_smiles = pipeline.fit_transform(SMILES_LIST)
        self.assertEqual(observed_smiles, expected_smiles)

def test_json_serialization(self) -> None:
    """Test that a pipeline with a ComplexFilter survives a JSON round trip.

    The restored pipeline must serialize to the same JSON object and must
    produce the same output as the original pipeline.
    """
    original = ComplexFilterTest._create_pipeline()
    serialized = recursive_to_json(original)
    restored = recursive_from_json(serialized)
    reserialized = recursive_to_json(restored)
    self.assertEqual(serialized, reserialized)

    original_output = original.fit_transform(SMILES_LIST)
    restored_output = restored.fit_transform(SMILES_LIST)
    self.assertEqual(original_output, restored_output)

def test_complex_filter_non_unique_names(self) -> None:
    """Test that ComplexFilter raises a ValueError for duplicate element names."""
    first_filter = ElementFilter({6: 6, 1: 6})
    second_filter = ElementFilter({6: 6, 1: 5, 17: 1})
    # Both elements are deliberately registered under the same name.
    duplicated_names = (("filter_1", first_filter), ("filter_1", second_filter))

    self.assertRaises(ValueError, ComplexFilter, duplicated_names)

filtered_smiles = pipeline.fit_transform(SMILES_LIST)
self.assertEqual(filtered_smiles, [SMILES_BENZENE, SMILES_CHLOROBENZENE])

pipeline.set_params(MultiElementFilter__mode="all")
filtered_smiles_2 = pipeline.fit_transform(SMILES_LIST)
self.assertEqual(filtered_smiles_2, [])
class SmartsSmilesFilterTest(unittest.TestCase):
"""Unittest for SmartsFilter and SmilesFilter."""

def test_smarts_smiles_filter(self) -> None:
"""Test if molecules are filtered correctly by allowed SMARTS patterns."""
Expand Down Expand Up @@ -214,6 +271,10 @@ def test_smarts_filter_parallel(self) -> None:
filtered_smiles = pipeline.fit_transform(SMILES_LIST)
self.assertEqual(filtered_smiles, [SMILES_CHLOROBENZENE])


class RDKitDescriptorsFilterTest(unittest.TestCase):
"""Unittest for RDKitDescriptorsFilter."""

def test_descriptor_filter(self) -> None:
"""Test if molecules are filtered correctly by allowed descriptors."""
descriptors: dict[str, FloatCountRange] = {
Expand Down Expand Up @@ -288,6 +349,10 @@ def test_descriptor_filter(self) -> None:
filtered_smiles = pipeline.fit_transform(SMILES_LIST)
self.assertEqual(filtered_smiles, test_params["result"])


class MixtureFilterTest(unittest.TestCase):
"""Unittest for MixtureFilter."""

def test_invalidate_mixtures(self) -> None:
"""Test if mixtures are correctly invalidated."""
mol_list = ["CCC.CC.C", "c1ccccc1.[Na+].[Cl-]", "c1ccccc1"]
Expand All @@ -311,6 +376,10 @@ def test_invalidate_mixtures(self) -> None:
mols_processed = pipeline.fit_transform(mol_list)
self.assertEqual(expected_invalidated_mol_list, mols_processed)


class InorganicsFilterTest(unittest.TestCase):
"""Unittest for InorganicsFilter."""

def test_inorganic_filter(self) -> None:
"""Test if molecules are filtered correctly by allowed chemical elements."""
smiles2mol = SmilesToMol()
Expand Down

0 comments on commit 08d58f4

Please sign in to comment.