diff --git a/molpipeline/abstract_pipeline_elements/mol2mol/__init__.py b/molpipeline/abstract_pipeline_elements/mol2mol/__init__.py new file mode 100644 index 00000000..15d7b5a4 --- /dev/null +++ b/molpipeline/abstract_pipeline_elements/mol2mol/__init__.py @@ -0,0 +1,8 @@ +"""Initialize the module for abstract mol2mol elements.""" + +from molpipeline.abstract_pipeline_elements.mol2mol.filter import ( + BaseKeepMatchesFilter, + BasePatternsFilter, +) + +__all__ = ["BasePatternsFilter", "BaseKeepMatchesFilter"] diff --git a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py new file mode 100644 index 00000000..b19cdb5b --- /dev/null +++ b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py @@ -0,0 +1,359 @@ +"""Abstract classes for filters.""" + +import abc +from typing import Any, Literal, Mapping, Optional, Sequence, TypeAlias, Union + +try: + from typing import Self # type: ignore[attr-defined] +except ImportError: + from typing_extensions import Self + +from molpipeline.abstract_pipeline_elements.core import ( + InvalidInstance, + MolToMolPipelineElement, + OptionalMol, + RDKitMol, +) +from molpipeline.utils.molpipeline_types import ( + FloatCountRange, + IntCountRange, + IntOrIntCountRange, +) +from molpipeline.utils.value_conversions import count_value_to_tuple + +# possible mode types for a KeepMatchesFilter: +# - "any" means one match is enough +# - "all" means all elements must be matched +FilterModeType: TypeAlias = Literal["any", "all"] + + +def _within_boundaries( + lower_bound: Optional[float], upper_bound: Optional[float], property_value: float +) -> bool: + """Check if a value is within the specified boundaries. + + Boundaries given as None are ignored. + + Parameters + ---------- + lower_bound: Optional[float] + Lower boundary. + upper_bound: Optional[float] + Upper boundary. + property_value: float + Property value to check. 
+ + Returns + ------- + bool + True if the value is within the boundaries, else False. + """ + if lower_bound is not None and property_value < lower_bound: + return False + if upper_bound is not None and property_value > upper_bound: + return False + return True + + +class BaseKeepMatchesFilter(MolToMolPipelineElement, abc.ABC): + """Filter to keep or remove molecules based on patterns. + + Notes + ----- + There are four possible scenarios: + - mode = "any" & keep_matches = True: Needs to match at least one filter element. + - mode = "any" & keep_matches = False: Must not match any filter element. + - mode = "all" & keep_matches = True: Needs to match all filter elements. + - mode = "all" & keep_matches = False: Must not match all filter elements. + """ + + keep_matches: bool + mode: FilterModeType + + def __init__( + self, + filter_elements: Union[ + Mapping[Any, Union[FloatCountRange, IntCountRange, IntOrIntCountRange]], + Sequence[Any], + ], + keep_matches: bool = True, + mode: FilterModeType = "any", + name: Optional[str] = None, + n_jobs: int = 1, + uuid: Optional[str] = None, + ) -> None: + """Initialize BasePatternsFilter. + + Parameters + ---------- + filter_elements: Union[Mapping[Any, Union[FloatCountRange, IntCountRange, IntOrIntCountRange]], Sequence[Any]] + List of filter elements. Typically can be a list of patterns or a dictionary with patterns as keys and + an int for exact count or a tuple of minimum and maximum. + NOTE: for each child class, the type of filter_elements must be specified by the filter_elements setter. + keep_matches: bool, optional (default: True) + If True, molecules containing the specified patterns are kept, else removed. + mode: FilterModeType, optional (default: "any") + If "any", at least one of the specified patterns must be present in the molecule. + If "all", all of the specified patterns must be present in the molecule. + name: Optional[str], optional (default: None) + Name of the pipeline element. 
+ n_jobs: int, optional (default: 1) + Number of parallel jobs to use. + uuid: str, optional (default: None) + Unique identifier of the pipeline element. + """ + super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) + self.filter_elements = filter_elements # type: ignore + self.keep_matches = keep_matches + self.mode = mode + + @property + @abc.abstractmethod + def filter_elements( + self, + ) -> Mapping[Any, FloatCountRange]: + """Get filter elements as dict.""" + + @filter_elements.setter + @abc.abstractmethod + def filter_elements( + self, + filter_elements: Union[Mapping[Any, FloatCountRange], Sequence[Any]], + ) -> None: + """Set filter elements as dict. + + Parameters + ---------- + filter_elements: Union[Mapping[Any, FloatCountRange], Sequence[Any]] + List of filter elements. + """ + + def set_params(self, **parameters: Any) -> Self: + """Set parameters of BaseKeepMatchesFilter. + + Parameters + ---------- + parameters: Any + Parameters to set. + + Returns + ------- + Self + Self. + """ + parameter_copy = dict(parameters) + if "keep_matches" in parameter_copy: + self.keep_matches = parameter_copy.pop("keep_matches") + if "mode" in parameter_copy: + self.mode = parameter_copy.pop("mode") + if "filter_elements" in parameter_copy: + self.filter_elements = parameter_copy.pop("filter_elements") + super().set_params(**parameter_copy) + return self + + def get_params(self, deep: bool = True) -> dict[str, Any]: + """Get parameters of PatternFilter. + + Parameters + ---------- + deep: bool, optional (default: True) + If True, return the parameters of all subobjects that are PipelineElements. + + Returns + ------- + dict[str, Any] + Parameters of BaseKeepMatchesFilter. 
+ """ + params = super().get_params(deep=deep) + params["keep_matches"] = self.keep_matches + params["mode"] = self.mode + params["filter_elements"] = self.filter_elements + return params + + def pretransform_single(self, value: RDKitMol) -> OptionalMol: + """Invalidate or validate molecule based on specified filter. + + There are four possible scenarios: + - mode = "any" & keep_matches = True: Needs to match at least one filter element. + - mode = "any" & keep_matches = False: Must not match any filter element. + - mode = "all" & keep_matches = True: Needs to match all filter elements. + - mode = "all" & keep_matches = False: Must not match all filter elements. + + Parameters + ---------- + value: RDKitMol + Molecule to check. + + Returns + ------- + OptionalMol + Molecule that matches defined filter elements, else InvalidInstance. + """ + for filter_element, (lower_limit, upper_limit) in self.filter_elements.items(): + property_value = self._calculate_single_element_value(filter_element, value) + if _within_boundaries(lower_limit, upper_limit, property_value): + # For "any" mode we can return early if a match is found + if self.mode == "any": + if not self.keep_matches: + value = InvalidInstance( + self.uuid, + f"Molecule contains forbidden filter element {filter_element}.", + self.name, + ) + return value + else: + # For "all" mode we can return early if a match is not found + if self.mode == "all": + if self.keep_matches: + value = InvalidInstance( + self.uuid, + f"Molecule does not contain required filter element {filter_element}.", + self.name, + ) + return value + + # If this point is reached, no or all patterns were found + # If mode is "any", finishing the loop means no match was found + if self.mode == "any": + if self.keep_matches: + value = InvalidInstance( + self.uuid, + "Molecule does not match any of the required filter elements.", + self.name, + ) + # else: No match with forbidden filter elements was found, return original molecule + return value + 
+ if self.mode == "all": + if not self.keep_matches: + value = InvalidInstance( + self.uuid, + "Molecule matches all forbidden filter elements.", + self.name, + ) + # else: All required filter elements were found, return original molecule + return value + + raise ValueError(f"Invalid mode: {self.mode}") + + @abc.abstractmethod + def _calculate_single_element_value( + self, filter_element: Any, value: RDKitMol + ) -> float: + """Calculate the value of a single match. + + Parameters + ---------- + filter_element: Any + Match case to calculate. + value: RDKitMol + Molecule to calculate the match for. + + Returns + ------- + float + Value of the match. + """ + + +class BasePatternsFilter(BaseKeepMatchesFilter, abc.ABC): + """Filter to keep or remove molecules based on patterns. + + Attributes + ---------- + filter_elements: Union[Sequence[str], Mapping[str, IntOrIntCountRange]] + List of patterns to allow in molecules. + Alternatively, a dictionary can be passed with patterns as keys + and an int for exact count or a tuple of minimum and maximum. + [...] + + Notes + ----- + There are four possible scenarios: + - mode = "any" & keep_matches = True: Needs to match at least one filter element. + - mode = "any" & keep_matches = False: Must not match any filter element. + - mode = "all" & keep_matches = True: Needs to match all filter elements. + - mode = "all" & keep_matches = False: Must not match all filter elements. + """ + + _filter_elements: Mapping[str, IntCountRange] + + @property + def filter_elements(self) -> Mapping[str, IntCountRange]: + """Get allowed filter elements (patterns) as dict.""" + return self._filter_elements + + @filter_elements.setter + def filter_elements( + self, + patterns: Union[list[str], Mapping[str, IntOrIntCountRange]], + ) -> None: + """Set allowed filter elements (patterns) as dict. + + Parameters + ---------- + patterns: Union[list[str], Mapping[str, IntOrIntCountRange]] + List of patterns. 
+ """ + if isinstance(patterns, (list, set)): + self._filter_elements = {pat: (1, None) for pat in patterns} + else: + self._filter_elements = { + pat: count_value_to_tuple(count) for pat, count in patterns.items() + } + self.patterns_mol_dict = list(self._filter_elements.keys()) # type: ignore + + @property + def patterns_mol_dict(self) -> Mapping[str, RDKitMol]: + """Get patterns as dict with RDKitMol objects.""" + return self._patterns_mol_dict + + @patterns_mol_dict.setter + def patterns_mol_dict(self, patterns: Sequence[str]) -> None: + """Set patterns as dict with RDKitMol objects. + + Parameters + ---------- + patterns: Sequence[str] + List of patterns. + """ + self._patterns_mol_dict = {pat: self._pattern_to_mol(pat) for pat in patterns} + failed_patterns = [ + pat for pat, mol in self._patterns_mol_dict.items() if not mol + ] + if failed_patterns: + raise ValueError("Invalid pattern(s): " + ", ".join(failed_patterns)) + + @abc.abstractmethod + def _pattern_to_mol(self, pattern: str) -> RDKitMol: + """Convert pattern to Rdkitmol object. + + Parameters + ---------- + pattern: str + Pattern to convert. + + Returns + ------- + RDKitMol + RDKitMol object of the pattern. + """ + + def _calculate_single_element_value( + self, filter_element: Any, value: RDKitMol + ) -> int: + """Calculate a single match count for a molecule. + + Parameters + ---------- + filter_element: Any + smarts to calculate match count for. + value: RDKitMol + Molecule to calculate smarts match count for. + + Returns + ------- + int + smarts match count value. 
+ """ + return len(value.GetSubstructMatches(self.patterns_mol_dict[filter_element])) diff --git a/molpipeline/mol2mol/__init__.py b/molpipeline/mol2mol/__init__.py index 356e4a63..7f6ed1ae 100644 --- a/molpipeline/mol2mol/__init__.py +++ b/molpipeline/mol2mol/__init__.py @@ -1,10 +1,14 @@ """Init the module for mol2mol pipeline elements.""" from molpipeline.mol2mol.filter import ( + ComplexFilter, ElementFilter, EmptyMoleculeFilter, InorganicsFilter, MixtureFilter, + RDKitDescriptorsFilter, + SmartsFilter, + SmilesFilter, ) from molpipeline.mol2mol.reaction import MolToMolReaction from molpipeline.mol2mol.scaffolds import MakeScaffoldGeneric, MurckoScaffold @@ -41,4 +45,8 @@ "SolventRemover", "Uncharger", "InorganicsFilter", + "SmartsFilter", + "SmilesFilter", + "RDKitDescriptorsFilter", + "ComplexFilter", ) diff --git a/molpipeline/mol2mol/filter.py b/molpipeline/mol2mol/filter.py index f0b4f24a..894758ba 100644 --- a/molpipeline/mol2mol/filter.py +++ b/molpipeline/mol2mol/filter.py @@ -2,24 +2,48 @@ from __future__ import annotations -from typing import Any, Optional +from collections import Counter +from typing import Any, Mapping, Optional, Sequence, Union try: from typing import Self # type: ignore[attr-defined] except ImportError: from typing_extensions import Self +from loguru import logger from rdkit import Chem +from rdkit.Chem import Descriptors from molpipeline.abstract_pipeline_elements.core import InvalidInstance from molpipeline.abstract_pipeline_elements.core import ( MolToMolPipelineElement as _MolToMolPipelineElement, ) -from molpipeline.utils.molpipeline_types import OptionalMol, RDKitMol +from molpipeline.abstract_pipeline_elements.mol2mol import ( + BaseKeepMatchesFilter as _BaseKeepMatchesFilter, +) +from molpipeline.abstract_pipeline_elements.mol2mol import ( + BasePatternsFilter as _BasePatternsFilter, +) +from molpipeline.abstract_pipeline_elements.mol2mol.filter import ( + FilterModeType, + _within_boundaries, +) +from 
molpipeline.utils.molpipeline_types import ( + FloatCountRange, + IntCountRange, + IntOrIntCountRange, + OptionalMol, + RDKitMol, +) +from molpipeline.utils.value_conversions import count_value_to_tuple class ElementFilter(_MolToMolPipelineElement): - """ElementFilter which removes molecules containing chemical elements other than specified.""" + """ElementFilter which removes molecules containing chemical elements other than specified. + + Molecular elements are filtered based on their atomic number. + The filter can be configured to allow only specific elements and/or a specific number of atoms of each element. + """ DEFAULT_ALLOWED_ELEMENT_NUMBERS = [ 1, @@ -39,7 +63,10 @@ class ElementFilter(_MolToMolPipelineElement): def __init__( self, - allowed_element_numbers: Optional[list[int]] = None, + allowed_element_numbers: Optional[ + Union[list[int], dict[int, IntOrIntCountRange]] + ] = None, + add_hydrogens: bool = True, name: str = "ElementFilter", n_jobs: int = 1, uuid: Optional[str] = None, @@ -48,9 +75,12 @@ def __init__( Parameters ---------- - allowed_element_numbers: list[int] + allowed_element_numbers: Optional[Union[list[int], dict[int, IntOrIntCountRange]]] List of atomic numbers of elements to allowed in molecules. Per default allowed elements are: H, B, C, N, O, F, Si, P, S, Cl, Se, Br, I. + Alternatively, a dictionary can be passed with atomic numbers as keys and an int for exact count or a tuple of minimum and maximum + add_hydrogens: bool, optional (default: True) + If True, in case Hydrogens are in allowed_element_list, add hydrogens to the molecule before filtering. name: str, optional (default: "ElementFilterPipe") Name of the pipeline element. n_jobs: int, optional (default: 1) @@ -59,12 +89,65 @@ def __init__( Unique identifier of the pipeline element. 
""" super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) + self.allowed_element_numbers = allowed_element_numbers # type: ignore + self.add_hydrogens = add_hydrogens + + @property + def add_hydrogens(self) -> bool: + """Get add_hydrogens.""" + return self._add_hydrogens + + @add_hydrogens.setter + def add_hydrogens(self, add_hydrogens: bool) -> None: + """Set add_hydrogens. + + Parameters + ---------- + add_hydrogens: bool + If True, in case Hydrogens are in allowed_element_list, add hydrogens to the molecule before filtering. + """ + self._add_hydrogens = add_hydrogens + if self.add_hydrogens and 1 in self.allowed_element_numbers: + self.process_hydrogens = True + else: + if 1 in self.allowed_element_numbers: + logger.warning( + "Hydrogens are included in allowed_element_numbers, but add_hydrogens is set to False. " + "Thus hydrogens are NOT added before filtering. You might receive unexpected results." + ) + self.process_hydrogens = False + + @property + def allowed_element_numbers(self) -> dict[int, IntCountRange]: + """Get allowed element numbers as dict.""" + return self._allowed_element_numbers + + @allowed_element_numbers.setter + def allowed_element_numbers( + self, + allowed_element_numbers: Optional[ + Union[list[int], dict[int, IntOrIntCountRange]] + ], + ) -> None: + """Set allowed element numbers as dict. + + Parameters + ---------- + allowed_element_numbers: Optional[Union[list[int], dict[int, IntOrIntCountRange]] + List of atomic numbers of elements to allowed in molecules. 
+ """ + self._allowed_element_numbers: dict[int, IntCountRange] if allowed_element_numbers is None: allowed_element_numbers = self.DEFAULT_ALLOWED_ELEMENT_NUMBERS - if not isinstance(allowed_element_numbers, set): - self.allowed_element_numbers = set(allowed_element_numbers) + if isinstance(allowed_element_numbers, (list, set)): + self._allowed_element_numbers = { + atom_number: (0, None) for atom_number in allowed_element_numbers + } else: - self.allowed_element_numbers = allowed_element_numbers + self._allowed_element_numbers = { + atom_number: count_value_to_tuple(count) + for atom_number, count in allowed_element_numbers.items() + } def get_params(self, deep: bool = True) -> dict[str, Any]: """Get parameters of ElementFilter. @@ -82,10 +165,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: params = super().get_params(deep=deep) if deep: params["allowed_element_numbers"] = { - int(atom) for atom in self.allowed_element_numbers + atom_number: (count_tuple[0], count_tuple[1]) + for atom_number, count_tuple in self.allowed_element_numbers.items() } else: params["allowed_element_numbers"] = self.allowed_element_numbers + params["add_hydrogens"] = self.add_hydrogens return params def set_params(self, **parameters: Any) -> Self: @@ -104,6 +189,8 @@ def set_params(self, **parameters: Any) -> Self: parameter_copy = dict(parameters) if "allowed_element_numbers" in parameter_copy: self.allowed_element_numbers = parameter_copy.pop("allowed_element_numbers") + if "add_hydrogens" in parameter_copy: + self.add_hydrogens = parameter_copy.pop("add_hydrogens") super().set_params(**parameter_copy) return self @@ -120,38 +207,307 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: OptionalMol Molecule if it contains only allowed elements, else InvalidInstance. 
""" - unique_elements = set(atom.GetAtomicNum() for atom in value.GetAtoms()) - if not unique_elements.issubset(self.allowed_element_numbers): - forbidden_elements = unique_elements - self.allowed_element_numbers + to_process_value = Chem.AddHs(value) if self.process_hydrogens else value + elements_list = [atom.GetAtomicNum() for atom in to_process_value.GetAtoms()] + elements_counter = Counter(elements_list) + if any( + element not in self.allowed_element_numbers for element in elements_counter + ): return InvalidInstance( - self.uuid, - f"Molecule contains following forbidden elements: {forbidden_elements}", - self.name, + self.uuid, "Molecule contains forbidden chemical element.", self.name ) + for element, (lower_limit, upper_limit) in self.allowed_element_numbers.items(): + count = elements_counter[element] + if not _within_boundaries(lower_limit, upper_limit, count): + return InvalidInstance( + self.uuid, + f"Molecule contains forbidden number of element {element}.", + self.name, + ) return value -class MixtureFilter(_MolToMolPipelineElement): - """MolToMol which removes molecules composed of multiple fragments.""" +class SmartsFilter(_BasePatternsFilter): + """Filter to keep or remove molecules based on SMARTS patterns. + + Notes + ----- + There are four possible scenarios: + - mode = "any" & keep_matches = True: Needs to match at least one filter element. + - mode = "any" & keep_matches = False: Must not match any filter element. + - mode = "all" & keep_matches = True: Needs to match all filter elements. + - mode = "all" & keep_matches = False: Must not match all filter elements. + """ + + def _pattern_to_mol(self, pattern: str) -> RDKitMol: + """Convert SMARTS pattern to RDKit molecule. + + Parameters + ---------- + pattern: str + SMARTS pattern to convert. + + Returns + ------- + RDKitMol + RDKit molecule. 
+ """ + return Chem.MolFromSmarts(pattern) + + +class SmilesFilter(_BasePatternsFilter): + """Filter to keep or remove molecules based on SMILES patterns. + + In contrast to the SMARTSFilter, which also can match SMILES, the SmilesFilter + sanitizes the molecules and, e.g. checks kekulized bonds for aromaticity and + then sets it to aromatic while the SmartsFilter detects alternating single and + double bonds. + + Notes + ----- + There are four possible scenarios: + - mode = "any" & keep_matches = True: Needs to match at least one filter element. + - mode = "any" & keep_matches = False: Must not match any filter element. + - mode = "all" & keep_matches = True: Needs to match all filter elements. + - mode = "all" & keep_matches = False: Must not match all filter elements. + """ + + def _pattern_to_mol(self, pattern: str) -> RDKitMol: + """Convert SMILES pattern to RDKit molecule. - def __int__( + Parameters + ---------- + pattern: str + SMILES pattern to convert. + + Returns + ------- + RDKitMol + RDKit molecule. + """ + return Chem.MolFromSmiles(pattern) + + +class ComplexFilter(_BaseKeepMatchesFilter): + """Filter to keep or remove molecules based on multiple filter elements. + + Attributes + ---------- + pipeline_filter_elements: Sequence[tuple[str, _MolToMolPipelineElement]] + pairs of unique names and MolToMol elements to use as filters. + [...] + + Notes + ----- + There are four possible scenarios: + - mode = "any" & keep_matches = True: Needs to match at least one filter element. + - mode = "any" & keep_matches = False: Must not match any filter element. + - mode = "all" & keep_matches = True: Needs to match all filter elements. + - mode = "all" & keep_matches = False: Must not match all filter elements. 
+ """ + + _filter_elements: Mapping[str, tuple[int, Optional[int]]] + + def __init__( self, - name: str = "MixtureFilter", + pipeline_filter_elements: Sequence[tuple[str, _MolToMolPipelineElement]], + keep_matches: bool = True, + mode: FilterModeType = "any", + name: str | None = None, n_jobs: int = 1, - uuid: Optional[str] = None, + uuid: str | None = None, ) -> None: - """Initialize MixtureFilter. + """Initialize ComplexFilter. Parameters ---------- - name: str, optional (default: "MixtureFilterPipe") + pipeline_filter_elements: Sequence[tuple[str, _MolToMolPipelineElement]] + Filter elements to use. + keep_matches: bool, optional (default: True) + If True, keep matches, else remove matches. + mode: FilterModeType, optional (default: "any") + Mode to filter by. + name: str, optional (default: None) Name of the pipeline element. n_jobs: int, optional (default: 1) Number of parallel jobs to use. uuid: str, optional (default: None) Unique identifier of the pipeline element. """ - super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) + self.pipeline_filter_elements = pipeline_filter_elements + super().__init__( + filter_elements=pipeline_filter_elements, + keep_matches=keep_matches, + mode=mode, + name=name, + n_jobs=n_jobs, + uuid=uuid, + ) + + def get_params(self, deep: bool = True) -> dict[str, Any]: + """Get parameters of ComplexFilter. + + Parameters + ---------- + deep: bool, optional (default: True) + If True, return the parameters of all subobjects that are PipelineElements. + + Returns + ------- + dict[str, Any] + Parameters of ComplexFilter. 
+ """ + params = super().get_params(deep) + params.pop("filter_elements") + params["pipeline_filter_elements"] = self.pipeline_filter_elements + if deep: + for name, element in self.pipeline_filter_elements: + deep_items = element.get_params().items() + params.update( + ("pipeline_filter_elements" + "__" + name + "__" + key, val) + for key, val in deep_items + ) + return params + + def set_params(self, **parameters: Any) -> Self: + """Set parameters of ComplexFilter. + + Parameters + ---------- + parameters: Any + Parameters to set. + + Returns + ------- + Self + Self. + """ + parameter_copy = dict(parameters) + if "pipeline_filter_elements" in parameter_copy: + self.pipeline_filter_elements = parameter_copy.pop( + "pipeline_filter_elements" + ) + self.filter_elements = self.pipeline_filter_elements # type: ignore + for key in parameters: + if key.startswith("pipeline_filter_elements__"): + value = parameter_copy.pop(key) + element_name, element_key = key.split("__")[1:] + for name, element in self.pipeline_filter_elements: + if name == element_name: + element.set_params(**{element_key: value}) + super().set_params(**parameter_copy) + return self + + @property + def filter_elements( + self, + ) -> Mapping[str, tuple[int, Optional[int]]]: + """Get filter elements.""" + return self._filter_elements + + @filter_elements.setter + def filter_elements( + self, + filter_elements: Sequence[tuple[str, _MolToMolPipelineElement]], + ) -> None: + """Set filter elements. + + Parameters + ---------- + filter_elements: Sequence[tuple[str, _MolToMolPipelineElement]] + Filter elements to set. 
+ """ + self.filter_elements_dict = dict(filter_elements) + if not len(self.filter_elements_dict) == len(filter_elements): + raise ValueError("Filter elements names need to be unique.") + self._filter_elements = { + element_name: (1, None) for element_name, _element in filter_elements + } + + def _calculate_single_element_value( + self, filter_element: Any, value: RDKitMol + ) -> int: + """Calculate a single filter match for a molecule. + + Parameters + ---------- + filter_element: Any + MolToMol Filter to calculate. + value: RDKitMol + Molecule to calculate filter match for. + + Returns + ------- + int + Filter match. + """ + mol = self.filter_elements_dict[filter_element].pretransform_single(value) + if isinstance(mol, InvalidInstance): + return 0 + return 1 + + +class RDKitDescriptorsFilter(_BaseKeepMatchesFilter): + """Filter to keep or remove molecules based on RDKit descriptors. + + Attributes + ---------- + filter_elements: dict[str, FloatCountRange] + Dictionary of RDKit descriptors to filter by. + The value must be a tuple of minimum and maximum. If None, no limit is set. + [...] + + Notes + ----- + There are four possible scenarios: + - mode = "any" & keep_matches = True: Needs to match at least one filter element. + - mode = "any" & keep_matches = False: Must not match any filter element. + - mode = "all" & keep_matches = True: Needs to match all filter elements. + - mode = "all" & keep_matches = False: Must not match all filter elements. + """ + + @property + def filter_elements(self) -> dict[str, FloatCountRange]: + """Get allowed descriptors as dict.""" + return self._filter_elements + + @filter_elements.setter + def filter_elements(self, descriptors: dict[str, FloatCountRange]) -> None: + """Set allowed descriptors as dict. + + Parameters + ---------- + descriptors: dict[str, FloatCountRange] + Dictionary of RDKit descriptors to filter by. 
+ """ + if not all(hasattr(Descriptors, descriptor) for descriptor in descriptors): + raise ValueError( + "You are trying to use an invalid descriptor. Use RDKit Descriptors module." + ) + self._filter_elements = descriptors + + def _calculate_single_element_value( + self, filter_element: Any, value: RDKitMol + ) -> float: + """Calculate a single descriptor value for a molecule. + + Parameters + ---------- + filter_element: Any + Descriptor to calculate. + value: RDKitMol + Molecule to calculate descriptor for. + + Returns + ------- + float + Descriptor value. + """ + return getattr(Descriptors, filter_element)(value) + + +class MixtureFilter(_MolToMolPipelineElement): + """MolToMol which removes molecules composed of multiple fragments.""" def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Invalidate molecule containing multiple fragments. @@ -180,25 +536,6 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: class EmptyMoleculeFilter(_MolToMolPipelineElement): """EmptyMoleculeFilter which removes empty molecules.""" - def __init__( - self, - name: str = "EmptyMoleculeFilter", - n_jobs: int = 1, - uuid: Optional[str] = None, - ) -> None: - """Initialize EmptyMoleculeFilter. - - Parameters - ---------- - name: str, optional (default: "EmptyMoleculeFilterPipe") - Name of the pipeline element. - n_jobs: int, optional (default: 1) - Number of parallel jobs to use. - uuid: str, optional (default: None) - Unique identifier of the pipeline element. - """ - super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) - def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Invalidate empty molecule. @@ -223,25 +560,6 @@ class InorganicsFilter(_MolToMolPipelineElement): CARBON_INORGANICS = ["O=C=O", "[C-]#[O+]"] # CO2 and CO are not organic CARBON_INORGANICS_MAX_ATOMS = 3 - def __init__( - self, - name: str = "InorganicsFilter", - n_jobs: int = 1, - uuid: Optional[str] = None, - ) -> None: - """Initialize InorganicsFilter. 
- - Parameters - ---------- - name: str, optional (default: "InorganicsFilter") - Name of the pipeline element. - n_jobs: int, optional (default: 1) - Number of parallel jobs to use. - uuid: str, optional (default: None) - Unique identifier of the pipeline element. - """ - super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) - def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Invalidate molecules not containing a carbon atom. diff --git a/molpipeline/utils/molpipeline_types.py b/molpipeline/utils/molpipeline_types.py index ff59fdec..e17b48b3 100644 --- a/molpipeline/utils/molpipeline_types.py +++ b/molpipeline/utils/molpipeline_types.py @@ -3,7 +3,17 @@ from __future__ import annotations from numbers import Number -from typing import Any, List, Literal, Optional, Protocol, Tuple, TypeVar, Union +from typing import ( + Any, + List, + Literal, + Optional, + Protocol, + Tuple, + TypeAlias, + TypeVar, + Union, +) try: from typing import Self # type: ignore[attr-defined] @@ -47,6 +57,16 @@ TypeConserverdIterable = TypeVar("TypeConserverdIterable", List[_T], npt.NDArray[_T]) +FloatCountRange: TypeAlias = tuple[Optional[float], Optional[float]] + +IntCountRange: TypeAlias = tuple[Optional[int], Optional[int]] + +# IntOrIntCountRange for Typing of count ranges +# - a single int for an exact value match +# - a range given as a tuple with a lower and upper bound +# - both limits are optional +IntOrIntCountRange: TypeAlias = Union[int, IntCountRange] + class AnySklearnEstimator(Protocol): """Protocol for sklearn estimators.""" diff --git a/molpipeline/utils/value_conversions.py b/molpipeline/utils/value_conversions.py new file mode 100644 index 00000000..4206e97f --- /dev/null +++ b/molpipeline/utils/value_conversions.py @@ -0,0 +1,28 @@ +"""Module for utilities converting values.""" + +from typing import Sequence + +from molpipeline.utils.molpipeline_types import IntCountRange, IntOrIntCountRange + + +def count_value_to_tuple(count: IntOrIntCountRange) -> 
IntCountRange: + """Convert a count value to a tuple. + + Parameters + ---------- + count: Union[int, IntCountRange] + Count value. Can be a single int or a tuple of two values. + + Returns + ------- + IntCountRange + Tuple of count values. + """ + if isinstance(count, int): + return count, count + if isinstance(count, Sequence): + count_tuple = tuple(count) + if len(count_tuple) != 2: + raise ValueError(f"Expected a sequence of length 2, got: {count_tuple}") + return count_tuple + raise TypeError(f"Got unexpected type: {type(count)}") diff --git a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py index 9a5572d3..8a0d5c0c 100644 --- a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py +++ b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py @@ -5,7 +5,17 @@ from molpipeline import ErrorFilter, FilterReinserter, Pipeline from molpipeline.any2mol import SmilesToMol from molpipeline.mol2any import MolToSmiles -from molpipeline.mol2mol import ElementFilter, InorganicsFilter, MixtureFilter +from molpipeline.mol2mol import ( + ComplexFilter, + ElementFilter, + InorganicsFilter, + MixtureFilter, + RDKitDescriptorsFilter, + SmartsFilter, + SmilesFilter, +) +from molpipeline.utils.json_operations import recursive_from_json, recursive_to_json +from molpipeline.utils.molpipeline_types import FloatCountRange, IntOrIntCountRange # pylint: disable=duplicate-code # test case molecules are allowed to be duplicated SMILES_ANTIMONY = "[SbH6+3]" @@ -14,68 +24,343 @@ SMILES_CL_BR = "NC(Cl)(Br)C(=O)O" SMILES_METAL_AU = "OC[C@H]1OC(S[Au])[C@H](O)[C@@H](O)[C@@H]1O" +SMILES_LIST = [ + SMILES_ANTIMONY, + SMILES_BENZENE, + SMILES_CHLOROBENZENE, + SMILES_METAL_AU, + SMILES_CL_BR, +] -class MolFilterTest(unittest.TestCase): - """Unittest for MolFilter, which invalidate molecules based on criteria defined in the respective filter.""" + +class ElementFilterTest(unittest.TestCase): + """Unittest for Elementiflter.""" def 
test_element_filter(self) -> None: - """Test if molecules are filtered correctly by allowed chemical elements. + """Test if molecules are filtered correctly by allowed chemical elements.""" + default_atoms_dict = { + 1: (0, None), + 5: (0, None), + 6: (0, None), + 7: (0, None), + 8: (0, None), + 9: (0, None), + 14: (0, None), + 15: (0, None), + 16: (0, None), + 17: (0, None), + 34: (0, None), + 35: (0, None), + 53: (0, None), + } + + element_filter = ElementFilter() + self.assertEqual(element_filter.allowed_element_numbers, default_atoms_dict) + pipeline = Pipeline( + [ + ("Smiles2Mol", SmilesToMol()), + ("ElementFilter", element_filter), + ("Mol2Smiles", MolToSmiles()), + ("ErrorFilter", ErrorFilter()), + ], + ) + + test_params_list_with_results = [ + { + "params": {}, + "result": [SMILES_BENZENE, SMILES_CHLOROBENZENE, SMILES_CL_BR], + }, + { + "params": { + "ElementFilter__allowed_element_numbers": { + 6: 6, + 1: (5, 6), + 17: (0, 1), + } + }, + "result": [SMILES_BENZENE, SMILES_CHLOROBENZENE], + }, + {"params": {"ElementFilter__add_hydrogens": False}, "result": []}, + ] + + for test_params in test_params_list_with_results: + pipeline.set_params(**test_params["params"]) + filtered_smiles = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(filtered_smiles, test_params["result"]) + + +class ComplexFilterTest(unittest.TestCase): + """Unittest for ComplexFilter.""" + + @staticmethod + def _create_pipeline() -> Pipeline: + """Create a pipeline with a complex filter. Returns ------- - None + Pipeline + Pipeline with a complex filter. 
""" - smiles2mol = SmilesToMol() - default_atoms = { - 1, - 5, - 6, - 7, - 8, - 9, - 14, - 15, - 16, - 17, - 34, - 35, - 53, - } + element_filter_1 = ElementFilter({6: 6, 1: 6}) + element_filter_2 = ElementFilter({6: 6, 1: 5, 17: 1}) - element_filter = ElementFilter() - self.assertEqual(element_filter.allowed_element_numbers, default_atoms) - mol2smiles = MolToSmiles() - error_filter = ErrorFilter.from_element_list( - [smiles2mol, element_filter, mol2smiles] + multi_element_filter = ComplexFilter( + ( + ("element_filter_1", element_filter_1), + ("element_filter_2", element_filter_2), + ) ) + pipeline = Pipeline( [ - ("Smiles2Mol", smiles2mol), - ("ElementFilter", element_filter), - ("Mol2Smiles", mol2smiles), - ("ErrorFilter", error_filter), + ("Smiles2Mol", SmilesToMol()), + ("MultiElementFilter", multi_element_filter), + ("Mol2Smiles", MolToSmiles()), + ("ErrorFilter", ErrorFilter()), ], ) - filtered_smiles = pipeline.fit_transform( - [ - SMILES_ANTIMONY, - SMILES_BENZENE, - SMILES_CHLOROBENZENE, - SMILES_METAL_AU, - SMILES_CL_BR, + return pipeline + + def test_complex_filter(self) -> None: + """Test if molecules are filtered correctly by allowed chemical elements.""" + pipeline = ComplexFilterTest._create_pipeline() + + test_params_list_with_results = [ + { + "params": {}, + "result": [SMILES_BENZENE, SMILES_CHLOROBENZENE], + }, + { + "params": {"MultiElementFilter__mode": "all"}, + "result": [], + }, + { + "params": { + "MultiElementFilter__mode": "any", + "MultiElementFilter__pipeline_filter_elements__element_filter_1__add_hydrogens": False, + }, + "result": [SMILES_CHLOROBENZENE], + }, + ] + + for test_params in test_params_list_with_results: + pipeline.set_params(**test_params["params"]) + filtered_smiles = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(filtered_smiles, test_params["result"]) + + def test_json_serialization(self) -> None: + """Test if complex filter can be serialized and deserialized.""" + pipeline = 
ComplexFilterTest._create_pipeline() + json_object = recursive_to_json(pipeline) + newpipeline = recursive_from_json(json_object) + self.assertEqual(json_object, recursive_to_json(newpipeline)) + + pipeline_result = pipeline.fit_transform(SMILES_LIST) + newpipeline_result = newpipeline.fit_transform(SMILES_LIST) + self.assertEqual(pipeline_result, newpipeline_result) + + def test_complex_filter_non_unique_names(self) -> None: + """Test if a ValueError is raised when filter names are not unique.""" + element_filter_1 = ElementFilter({6: 6, 1: 6}) + element_filter_2 = ElementFilter({6: 6, 1: 5, 17: 1}) + + with self.assertRaises(ValueError): + ComplexFilter( + (("filter_1", element_filter_1), ("filter_1", element_filter_2)) + ) + + +class SmartsSmilesFilterTest(unittest.TestCase): + """Unittest for SmartsFilter and SmilesFilter.""" + + def test_smarts_smiles_filter(self) -> None: + """Test if molecules are filtered correctly by allowed SMARTS patterns.""" + smarts_pats: dict[str, IntOrIntCountRange] = { + "c": (4, None), + "Cl": 1, + } + smarts_filter = SmartsFilter(smarts_pats) + + smiles_pats: dict[str, IntOrIntCountRange] = { + "c1ccccc1": (1, None), + "Cl": 1, + } + smiles_filter = SmilesFilter(smiles_pats) + + for filter_ in [smarts_filter, smiles_filter]: + new_input_as_list = list(filter_.filter_elements.keys()) + pipeline = Pipeline( + [ + ("Smiles2Mol", SmilesToMol()), + ("SmartsFilter", filter_), + ("Mol2Smiles", MolToSmiles()), + ("ErrorFilter", ErrorFilter()), + ], + ) + + test_params_list_with_results = [ + { + "params": {}, + "result": [SMILES_BENZENE, SMILES_CHLOROBENZENE, SMILES_CL_BR], + }, + { + "params": {"SmartsFilter__keep_matches": False}, + "result": [SMILES_ANTIMONY, SMILES_METAL_AU], + }, + { + "params": { + "SmartsFilter__mode": "all", + "SmartsFilter__keep_matches": True, + }, + "result": [SMILES_CHLOROBENZENE], + }, + { + "params": { + "SmartsFilter__keep_matches": True, + "SmartsFilter__filter_elements": ["I"], + }, + "result": [],
+ }, + { + "params": { + "SmartsFilter__keep_matches": False, + "SmartsFilter__mode": "any", + "SmartsFilter__filter_elements": new_input_as_list, + }, + "result": [SMILES_ANTIMONY, SMILES_METAL_AU], + }, ] + + for test_params in test_params_list_with_results: + pipeline.set_params(**test_params["params"]) + filtered_smiles = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(filtered_smiles, test_params["result"]) + + def test_smarts_smiles_filter_wrong_pattern(self) -> None: + """Test if SmartsFilter and SmilesFilter raise a ValueError for wrong patterns.""" + smarts_pats: dict[str, IntOrIntCountRange] = { + "cIOnk": (4, None), + "cC": 1, + } + with self.assertRaises(ValueError): + SmartsFilter(smarts_pats) + + smiles_pats: dict[str, IntOrIntCountRange] = { + "cC": (1, None), + "Cl": 1, + } + with self.assertRaises(ValueError): + SmilesFilter(smiles_pats) + + def test_smarts_filter_parallel(self) -> None: + """Test if molecules are filtered correctly by allowed SMARTS patterns in parallel.""" + smarts_pats: dict[str, IntOrIntCountRange] = { + "c": (4, None), + "Cl": 1, + "cc": (1, None), + "ccc": (1, None), + "cccc": (1, None), + "ccccc": (1, None), + "cccccc": (1, None), + "c1ccccc1": (1, None), + "cCl": 1, + } + smarts_filter = SmartsFilter(smarts_pats, mode="all", n_jobs=2) + pipeline = Pipeline( + [ + ("Smiles2Mol", SmilesToMol()), + ("SmartsFilter", smarts_filter), + ("Mol2Smiles", MolToSmiles()), + ("ErrorFilter", ErrorFilter()), + ], ) - self.assertEqual( - filtered_smiles, [SMILES_BENZENE, SMILES_CHLOROBENZENE, SMILES_CL_BR] + filtered_smiles = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(filtered_smiles, [SMILES_CHLOROBENZENE]) + + +class RDKitDescriptorsFilterTest(unittest.TestCase): + """Unittest for RDKitDescriptorsFilter.""" + + def test_descriptor_filter(self) -> None: + """Test if molecules are filtered correctly by allowed descriptors.""" + descriptors: dict[str, FloatCountRange] = { + "MolWt": (None, 190), + "NumHAcceptors": (2, 10), + } +
+ descriptor_filter = RDKitDescriptorsFilter(descriptors) + + pipeline = Pipeline( + [ + ("Smiles2Mol", SmilesToMol()), + ("DescriptorsFilter", descriptor_filter), + ("Mol2Smiles", MolToSmiles()), + ("ErrorFilter", ErrorFilter()), + ], ) + test_params_list_with_results = [ + {"params": {}, "result": SMILES_LIST}, + {"params": {"DescriptorsFilter__mode": "all"}, "result": [SMILES_CL_BR]}, + { + "params": {"DescriptorsFilter__keep_matches": False}, + "result": [ + SMILES_ANTIMONY, + SMILES_BENZENE, + SMILES_CHLOROBENZENE, + SMILES_METAL_AU, + ], + }, + {"params": {"DescriptorsFilter__mode": "any"}, "result": []}, + { + "params": { + "DescriptorsFilter__keep_matches": True, + "DescriptorsFilter__filter_elements": {"NumHAcceptors": (2.00, 4)}, + }, + "result": [SMILES_CL_BR], + }, + { + "params": { + "DescriptorsFilter__filter_elements": {"NumHAcceptors": (1.99, 4)} + }, + "result": [SMILES_CL_BR], + }, + { + "params": { + "DescriptorsFilter__filter_elements": {"NumHAcceptors": (2.01, 4)} + }, + "result": [], + }, + { + "params": { + "DescriptorsFilter__filter_elements": {"NumHAcceptors": (1, 2.00)} + }, + "result": [SMILES_CL_BR], + }, + { + "params": { + "DescriptorsFilter__filter_elements": {"NumHAcceptors": (1, 2.01)} + }, + "result": [SMILES_CL_BR], + }, + { + "params": { + "DescriptorsFilter__filter_elements": {"NumHAcceptors": (1, 1.99)} + }, + "result": [], + }, + ] - def test_invalidate_mixtures(self) -> None: - """Test if mixtures are correctly invalidated. 
+ for test_params in test_params_list_with_results: + pipeline.set_params(**test_params["params"]) + filtered_smiles = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(filtered_smiles, test_params["result"]) - Returns - ------- - None - """ + +class MixtureFilterTest(unittest.TestCase): + """Unittest for MixtureFilter.""" + + def test_invalidate_mixtures(self) -> None: + """Test if mixtures are correctly invalidated.""" mol_list = ["CCC.CC.C", "c1ccccc1.[Na+].[Cl-]", "c1ccccc1"] expected_invalidated_mol_list = [None, None, "c1ccccc1"] @@ -97,13 +382,12 @@ def test_invalidate_mixtures(self) -> None: mols_processed = pipeline.fit_transform(mol_list) self.assertEqual(expected_invalidated_mol_list, mols_processed) - def test_inorganic_filter(self) -> None: - """Test if molecules are filtered correctly by allowed chemical elements. - Returns - ------- - None - """ +class InorganicsFilterTest(unittest.TestCase): + """Unittest for InorganicsFilter.""" + + def test_inorganic_filter(self) -> None: + """Test if molecules are filtered correctly by allowed chemical elements.""" smiles2mol = SmilesToMol() inorganics_filter = InorganicsFilter() mol2smiles = MolToSmiles() @@ -118,15 +402,7 @@ def test_inorganic_filter(self) -> None: ("ErrorFilter", error_filter), ], ) - filtered_smiles = pipeline.fit_transform( - [ - SMILES_ANTIMONY, - SMILES_BENZENE, - SMILES_CHLOROBENZENE, - SMILES_METAL_AU, - SMILES_CL_BR, - ] - ) + filtered_smiles = pipeline.fit_transform(SMILES_LIST) self.assertEqual( filtered_smiles, [SMILES_BENZENE, SMILES_CHLOROBENZENE, SMILES_METAL_AU, SMILES_CL_BR],