From fa27bfd518cd396a64a3fa78003a06e5d2c9c13d Mon Sep 17 00:00:00 2001 From: frederik-sandfort1 Date: Mon, 19 Aug 2024 17:19:31 +0200 Subject: [PATCH 01/25] remove unnecessary inits and refactor --- .../abstract_pipeline_elements/core.py | 71 +------- molpipeline/any2mol/bin2mol.py | 21 --- molpipeline/any2mol/smiles2mol.py | 20 --- molpipeline/mol2any/mol2bin.py | 25 --- molpipeline/mol2any/mol2bool.py | 24 +-- molpipeline/mol2any/mol2inchi.py | 23 --- .../mol2any/mol2maccs_key_fingerprint.py | 29 ---- molpipeline/mol2any/mol2smiles.py | 25 --- molpipeline/mol2mol/scaffolds.py | 48 ----- molpipeline/mol2mol/standardization.py | 164 ------------------ 10 files changed, 9 insertions(+), 441 deletions(-) diff --git a/molpipeline/abstract_pipeline_elements/core.py b/molpipeline/abstract_pipeline_elements/core.py index 7fe430a9..62f888cf 100644 --- a/molpipeline/abstract_pipeline_elements/core.py +++ b/molpipeline/abstract_pipeline_elements/core.py @@ -91,13 +91,13 @@ def __repr__(self) -> str: class ABCPipelineElement(abc.ABC): """Ancestor of all PipelineElements.""" - name: str + name: Optional[str] _requires_fitting: bool = False uuid: str def __init__( self, - name: str = "ABCPipelineElement", + name: Optional[str] = None, n_jobs: int = 1, uuid: Optional[str] = None, ) -> None: @@ -105,13 +105,15 @@ def __init__( Parameters ---------- - name: str + name: Optional[str], optional (default=None) Name of PipelineElement n_jobs: int Number of cores used for processing. uuid: Optional[str] Unique identifier of the PipelineElement. """ + if name is None: + name = self.__class__.__name__ self.name = name self.n_jobs = n_jobs if uuid is None: @@ -334,11 +336,11 @@ class TransformingPipelineElement(ABCPipelineElement): _input_type: str _output_type: str - name: str + name: Optional[str] def __init__( self, - name: str = "ABCPipelineElement", + name: Optional[str] = None, n_jobs: int = 1, uuid: Optional[str] = None, ) -> None: @@ -346,7 +348,7 @@ def __init__( Parameters ---------- - name: str + name: Optional[str], optional (default=None) Name of PipelineElement n_jobs: int Number of cores used for processing. @@ -616,25 +618,6 @@ class MolToMolPipelineElement(TransformingPipelineElement, abc.ABC): _input_type = "RDKitMol" _output_type = "RDKitMol" - def __init__( - self, - name: str = "MolToMolPipelineElement", - n_jobs: int = 1, - uuid: Optional[str] = None, - ) -> None: - """Initialize MolToMolPipelineElement. - - Parameters - ---------- - name: str - Name of the PipelineElement. - n_jobs: int - Number of cores used for processing. - uuid: Optional[str] - Unique identifier of the PipelineElement. - """ - super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) - def transform(self, values: list[OptionalMol]) -> list[OptionalMol]: """Transform list of molecules to list of molecules. @@ -700,25 +683,6 @@ class AnyToMolPipelineElement(TransformingPipelineElement, abc.ABC): _output_type = "RDKitMol" - def __init__( - self, - name: str = "AnyToMolPipelineElement", - n_jobs: int = 1, - uuid: Optional[str] = None, - ) -> None: - """Initialize AnyToMolPipelineElement. - - Parameters - ---------- - name: str - Name of the PipelineElement. - n_jobs: int - Number of cores used for processing. - uuid: Optional[str] - Unique identifier of the PipelineElement. - """ - super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) - def transform(self, values: Any) -> list[OptionalMol]: """Transform list of instances to list of molecules. @@ -756,25 +720,6 @@ class MolToAnyPipelineElement(TransformingPipelineElement, abc.ABC): _input_type = "RDKitMol" - def __init__( - self, - name: str = "MolToAnyPipelineElement", - n_jobs: int = 1, - uuid: Optional[str] = None, - ) -> None: - """Initialize MolToAnyPipelineElement. - - Parameters - ---------- - name: str - Name of the PipelineElement. - n_jobs: int - Number of cores used for processing. - uuid: Optional[str] - Unique identifier of the PipelineElement. - """ - super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) - @abc.abstractmethod def pretransform_single(self, value: RDKitMol) -> Any: """Transform the molecule, but skip parameters learned during fitting. diff --git a/molpipeline/any2mol/bin2mol.py b/molpipeline/any2mol/bin2mol.py index 9e2c94f9..90c70cbf 100644 --- a/molpipeline/any2mol/bin2mol.py +++ b/molpipeline/any2mol/bin2mol.py @@ -2,8 +2,6 @@ from __future__ import annotations -from typing import Optional - from rdkit import Chem from molpipeline.abstract_pipeline_elements.core import ( @@ -16,25 +14,6 @@ class BinaryToMol(AnyToMolPipelineElement): """Transforms binary string representation to RDKit Mol objects.""" - def __init__( - self, - name: str = "bin2mol", - n_jobs: int = 1, - uuid: Optional[str] = None, - ) -> None: - """Initialize BinaryToMol. - - Parameters - ---------- - name: str, optional (default="bin2mol") - Name of PipelineElement. - n_jobs: int, optional (default=1) - Number of cores used. - uuid: str | None, optional (default=None) - UUID of the pipeline element. If None, a random UUID is generated. - """ - super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) - def pretransform_single(self, value: str) -> OptionalMol: """Transform binary string to molecule. diff --git a/molpipeline/any2mol/smiles2mol.py b/molpipeline/any2mol/smiles2mol.py index 40a0ac80..b28a5fb9 100644 --- a/molpipeline/any2mol/smiles2mol.py +++ b/molpipeline/any2mol/smiles2mol.py @@ -2,8 +2,6 @@ from __future__ import annotations -from typing import Optional - from rdkit import Chem from molpipeline.abstract_pipeline_elements.any2mol.string2mol import ( @@ -16,24 +14,6 @@ class SmilesToMol(_StringToMolPipelineElement): """Transforms Smiles to RDKit Mol objects.""" - def __init__( - self, - name: str = "smiles2mol", - n_jobs: int = 1, - uuid: Optional[str] = None, - ) -> None: - """Initialize SmilesToMol. - - Parameters - ---------- - name: str, optional (default="smiles2mol") - Name of PipelineElement. - n_jobs: int, optional (default=1) - Number of cores used. - uuid: str | None, optional (default=None) - UUID of the pipeline element. If None, a random UUID is generated. - """ - super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) def pretransform_single(self, value: str) -> OptionalMol: """Transform Smiles string to molecule. diff --git a/molpipeline/mol2any/mol2bin.py b/molpipeline/mol2any/mol2bin.py index 61edb9d1..3f14b8b7 100644 --- a/molpipeline/mol2any/mol2bin.py +++ b/molpipeline/mol2any/mol2bin.py @@ -1,7 +1,5 @@ """Converter element for molecules to binary string representation.""" -from typing import Optional - from rdkit import Chem from molpipeline.abstract_pipeline_elements.core import MolToAnyPipelineElement @@ -10,29 +8,6 @@ class MolToBinary(MolToAnyPipelineElement): """PipelineElement to transform a molecule to a binary.""" - def __init__( - self, - name: str = "Mol2Binary", - n_jobs: int = 1, - uuid: Optional[str] = None, - ) -> None: - """Initialize MolToBinaryPipelineElement. - - Parameters - ---------- - name: str, optional (default="Mol2Binary") - name of PipelineElement - n_jobs: int, optional (default=1) - number of jobs to use for parallelization - uuid: Optional[str], optional (default=None) - uuid of PipelineElement, by default None - - Returns - ------- - None - """ - super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) - def pretransform_single(self, value: Chem.Mol) -> str: """Transform a molecule to a binary string. diff --git a/molpipeline/mol2any/mol2bool.py b/molpipeline/mol2any/mol2bool.py index e372fcb1..b7e7ef0e 100644 --- a/molpipeline/mol2any/mol2bool.py +++ b/molpipeline/mol2any/mol2bool.py @@ -3,36 +3,14 @@ from typing import Any from molpipeline.abstract_pipeline_elements.core import ( - MolToAnyPipelineElement, InvalidInstance, + MolToAnyPipelineElement, ) class MolToBool(MolToAnyPipelineElement): """Element to generate a bool array from input.""" - def __init__( - self, - name: str = "Mol2Bool", - n_jobs: int = 1, - uuid: str | None = None, - ) -> None: - """Initialize MolToBinaryPipelineElement. - - Parameters - ---------- - name: str, optional (default="Mol2Bool") - name of PipelineElement - n_jobs: int, optional (default=1) - number of jobs to use for parallelization - uuid: Optional[str], optional (default=None) - uuid of PipelineElement, by default None - - Returns - ------- - None - """ - super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) def pretransform_single(self, value: Any) -> bool: """Transform a value to a bool representation. diff --git a/molpipeline/mol2any/mol2inchi.py b/molpipeline/mol2any/mol2inchi.py index d0174f9a..20dcf0e9 100644 --- a/molpipeline/mol2any/mol2inchi.py +++ b/molpipeline/mol2any/mol2inchi.py @@ -15,29 +15,6 @@ class MolToInchi(_MolToStringPipelineElement): """PipelineElement to transform a molecule to an INCHI string.""" - def __init__( - self, - name: str = "Mol2Inchi", - n_jobs: int = 1, - uuid: Optional[str] = None, - ) -> None: - """Initialize MolToInchiPipelineElement. - - Parameters - ---------- - name: str - name of PipelineElement - n_jobs: int - number of jobs to use for parallelization - uuid: Optional[str], optional - uuid of PipelineElement, by default None - - Returns - ------- - None - """ - super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) - def pretransform_single(self, value: RDKitMol) -> str: """Transform a molecule to a INCHI-key string. diff --git a/molpipeline/mol2any/mol2maccs_key_fingerprint.py b/molpipeline/mol2any/mol2maccs_key_fingerprint.py index 701b5ef9..9b70773e 100644 --- a/molpipeline/mol2any/mol2maccs_key_fingerprint.py +++ b/molpipeline/mol2any/mol2maccs_key_fingerprint.py @@ -1,7 +1,5 @@ """Implementation of MACCS key fingerprint.""" -from typing import Literal - import numpy as np from numpy import typing as npt from rdkit.Chem import MACCSkeys @@ -24,33 +22,6 @@ class MolToMACCSFP(MolToFingerprintPipelineElement): _n_bits = 167 # MACCS keys have 166 bits + 1 bit for an all-zero vector (bit 0) - def __init__( - self, - return_as: Literal["sparse", "dense", "explicit_bit_vect"] = "sparse", - name: str = "MolToMACCS", - n_jobs: int = 1, - uuid: str | None = None, - ) -> None: - """Initialize MolToMACCS. - - Parameters - ---------- - return_as: Literal["sparse", "dense", "explicit_bit_vect"], optional (default="sparse") - Type of output. When "sparse" the fingerprints will be returned as a - scipy.sparse.csr_matrix holding a sparse representation of the bit vectors. - With "dense" a numpy matrix will be returned. - With "explicit_bit_vect" the fingerprints will be returned as a list of RDKit's - rdkit.DataStructs.cDataStructs.ExplicitBitVect. - name: str, optional (default="MolToMACCS") - Name of PipelineElement - n_jobs: int, optional (default=1) - Number of cores to use. - uuid: str | None, optional (default=None) - UUID of the PipelineElement. - - """ - super().__init__(return_as=return_as, name=name, n_jobs=n_jobs, uuid=uuid) - def pretransform_single( self, value: RDKitMol ) -> dict[int, int] | npt.NDArray[np.int_] | ExplicitBitVect: diff --git a/molpipeline/mol2any/mol2smiles.py b/molpipeline/mol2any/mol2smiles.py index 3638e711..756beb0e 100644 --- a/molpipeline/mol2any/mol2smiles.py +++ b/molpipeline/mol2any/mol2smiles.py @@ -2,8 +2,6 @@ from __future__ import annotations -from typing import Optional - from rdkit import Chem from molpipeline.abstract_pipeline_elements.mol2any.mol2string import ( @@ -14,29 +12,6 @@ class MolToSmiles(_MolToStringPipelineElement): """PipelineElement to transform a molecule to a SMILES string.""" - def __init__( - self, - name: str = "Mol2Smiles", - n_jobs: int = 1, - uuid: Optional[str] = None, - ): - """Initialize MolToSmilesPipelineElement. - - Parameters - ---------- - name: str - name of PipelineElement - n_jobs: int - number of jobs to use for parallelization - uuid: Optional[str], optional - uuid of PipelineElement, by default None - - Returns - ------- - None - """ - super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) - def pretransform_single(self, value: Chem.Mol) -> str: """Transform a molecule to a SMILES string. diff --git a/molpipeline/mol2mol/scaffolds.py b/molpipeline/mol2mol/scaffolds.py index 08f4674a..00cb1768 100644 --- a/molpipeline/mol2mol/scaffolds.py +++ b/molpipeline/mol2mol/scaffolds.py @@ -2,8 +2,6 @@ from __future__ import annotations -from typing import Optional - from rdkit.Chem.Scaffolds import MurckoScaffold as RDKIT_MurckoScaffold from molpipeline.abstract_pipeline_elements.core import ( @@ -18,29 +16,6 @@ class MurckoScaffold(_MolToMolPipelineElement): The Murcko-scaffold is composed of all rings and the linker atoms between them. """ - def __init__( - self, - name: str = "MurckoScaffold", - n_jobs: int = 1, - uuid: Optional[str] = None, - ) -> None: - """Initialize MurckoScaffold. - - Parameters - ---------- - name: str - Name of pipeline element. - n_jobs: int - Number of jobs to use for parallelization. - uuid: Optional[str] - UUID of pipeline element. - - Returns - ------- - None - """ - super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) - def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Extract Murco-scaffold of molecule. @@ -63,29 +38,6 @@ class MakeScaffoldGeneric(_MolToMolPipelineElement): Done to make scaffolds less speciffic. """ - def __init__( - self, - name: str = "MakeScaffoldGeneric", - n_jobs: int = 1, - uuid: Optional[str] = None, - ) -> None: - """Initialize MakeScaffoldGeneric. - - Parameters - ---------- - name: str - Name of pipeline element. - n_jobs: int - Number of jobs to use for parallelization. - uuid: Optional[str] - UUID of pipeline element. - - Returns - ------- - None - """ - super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) - def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Set all atoms to carbon and all bonds to single bond and return mol object. diff --git a/molpipeline/mol2mol/standardization.py b/molpipeline/mol2mol/standardization.py index 336e6231..4399ed49 100644 --- a/molpipeline/mol2mol/standardization.py +++ b/molpipeline/mol2mol/standardization.py @@ -46,25 +46,6 @@ class TautomerCanonicalizer(_MolToMolPipelineElement): """MolToMolPipelineElement which canonicalizes tautomers of a molecule.""" - def __init__( - self, - name: str = "TautomerCanonicalizer", - n_jobs: int = 1, - uuid: Optional[str] = None, - ) -> None: - """Initialize TautomerCanonicalizer. - - Parameters - ---------- - name: str, optional (default="TautomerCanonicalizer") - Name of PipelineElement - n_jobs: int, optional (default=1) - Number of jobs to use for parallelization - uuid: Optional[str], optional (default=None) - UUID of PipelineElement. If None, a random UUID is generated. - """ - super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) - def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Canonicalize tautomers of molecule. @@ -93,25 +74,6 @@ class ChargeParentExtractor(_MolToMolPipelineElement): The charge-parent is the largest fragment after neutralization. """ - def __init__( - self, - name: str = "ChargeParentExtractor", - n_jobs: int = 1, - uuid: Optional[str] = None, - ) -> None: - """Initialize ChargeParentExtractor. - - Parameters - ---------- - name: str - Name of PipelineElement - n_jobs: int - Number of jobs to use for parallelization - uuid: Optional[str], optional - uuid of PipelineElement, by default None - """ - super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) - def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Return charge-parent of molecule, which is the largest fragment after neutralization. @@ -200,24 +162,6 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: class LargestFragmentChooser(_MolToMolPipelineElement): """MolToMolPipelineElement which returns the largest fragment of a molecule.""" - def __init__( - self, - name: str = "LargestFragmentChooser", - n_jobs: int = 1, - uuid: Optional[str] = None, - ) -> None: - """Initialize LargestFragmentChooser. - - Parameters - ---------- - name: str, optional (default="LargestFragmentChooser") - Name of PipelineElement. - n_jobs: int, optional (default=1) - Number of jobs to use for parallelization. - uuid: Optional[str], optional (default=None) - UUID of PipelineElement. If None, a random UUID is generated. - """ - super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Return largest fragment of molecule. @@ -238,24 +182,6 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: class MetalDisconnector(_MolToMolPipelineElement): """MolToMolPipelineElement which removes bonds between organic compounds and metals.""" - def __init__( - self, - name: str = "MetalDisconnector", - n_jobs: int = 1, - uuid: Optional[str] = None, - ) -> None: - """Initialize MetalDisconnector. - - Parameters - ---------- - name: str - Name of PipelineElement - n_jobs: int - Number of jobs to use for parallelization - uuid: Optional[str], optional - uuid of PipelineElement, by default None - """ - super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Cleave bonds with metals. @@ -280,24 +206,6 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: class IsotopeRemover(_MolToMolPipelineElement): """MolToMolPipelineElement which removes isotope information of atoms in a molecule.""" - def __init__( - self, - name: str = "IsotopeRemover", - n_jobs: int = 1, - uuid: Optional[str] = None, - ) -> None: - """Initialize IsotopeRemover. - - Parameters - ---------- - name: str, optional (default="IsotopeRemover") - Name of PipelineElement. - n_jobs: int, optional (default=1) - Number of jobs to use for parallelization. - uuid: Optional[str], optional (default=None) - UUID of PipelineElement. If None, a random UUID is generated. - """ - super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Remove isotope information of each atom. @@ -323,24 +231,6 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: class ExplicitHydrogenRemover(_MolToMolPipelineElement): """MolToMolPipelineElement which removes explicit hydrogen atoms from a molecule.""" - def __init__( - self, - name: str = "ExplicitHydrogenRemover", - n_jobs: int = 1, - uuid: Optional[str] = None, - ) -> None: - """Initialize ExplicitHydrogenRemover. - - Parameters - ---------- - name: str, optional (default="ExplicitHydrogenRemover") - Name of PipelineElement - n_jobs: int, optional (default=1) - Number of jobs to use for parallelization - uuid: Optional[str], optional (default=None) - UUID of PipelineElement. If None, a random UUID is generated. - """ - super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Remove explicit hydrogen atoms. @@ -361,24 +251,6 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: class StereoRemover(_MolToMolPipelineElement): """MolToMolPipelineElement which removes stereo-information from the molecule.""" - def __init__( - self, - name: str = "StereoRemover", - n_jobs: int = 1, - uuid: Optional[str] = None, - ) -> None: - """Initialize StereoRemover. - - Parameters - ---------- - name: str - Name of PipelineElement - n_jobs: int - Number of jobs to use for parallelization - uuid: Optional[str], optional - uuid of PipelineElement, by default None - """ - super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Remove stereo-information in molecule. @@ -401,24 +273,6 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: class SaltRemover(_MolToMolPipelineElement): """MolToMolPipelineElement which removes metal ions from molecule.""" - def __init__( - self, - name: str = "SaltRemover", - n_jobs: int = 1, - uuid: Optional[str] = None, - ) -> None: - """Initialize SaltRemover. - - Parameters - ---------- - name: str - Name of PipelineElement - n_jobs: int - Number of jobs to use for parallelization - uuid: Optional[str], optional - uuid of PipelineElement, by default None - """ - super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Remove metal ions. @@ -605,24 +459,6 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: class Uncharger(_MolToMolPipelineElement): """MolToMolPipelineElement which removes charges in a molecule, if possible.""" - def __init__( - self, - name: str = "Uncharger", - n_jobs: int = 1, - uuid: Optional[str] = None, - ) -> None: - """Initialize Uncharger. - - Parameters - ---------- - name: str, optional (default="Uncharger") - Name of PipelineElement. - n_jobs: int, optional (default=1) - Number of jobs to use for parallelization. - uuid: str | None, optional (default=None) - UUID of the pipeline element. If None, a random UUID is generated. - """ - super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Remove charges of molecule. From b706268b3ee51afc39a33525a3635711f507b672 Mon Sep 17 00:00:00 2001 From: frederik-sandfort1 Date: Tue, 20 Aug 2024 15:18:38 +0200 Subject: [PATCH 02/25] include smarts filter, smiles filter, descriptors filter --- .../mol2mol/__init__.py | 5 + .../mol2mol/filter.py | 129 ++++++ molpipeline/any2mol/smiles2mol.py | 1 - molpipeline/mol2any/mol2bool.py | 1 - molpipeline/mol2mol/__init__.py | 6 + molpipeline/mol2mol/filter.py | 416 +++++++++++++++--- molpipeline/mol2mol/standardization.py | 7 - .../test_mol2mol/test_mol2mol_filter.py | 188 +++++--- 8 files changed, 625 insertions(+), 128 deletions(-) create mode 100644 molpipeline/abstract_pipeline_elements/mol2mol/__init__.py create mode 100644 molpipeline/abstract_pipeline_elements/mol2mol/filter.py diff --git a/molpipeline/abstract_pipeline_elements/mol2mol/__init__.py b/molpipeline/abstract_pipeline_elements/mol2mol/__init__.py new file mode 100644 index 00000000..eb352dba --- /dev/null +++ b/molpipeline/abstract_pipeline_elements/mol2mol/__init__.py @@ -0,0 +1,5 @@ +"""Initialize the module for abstract mol2mol elements.""" + +from molpipeline.abstract_pipeline_elements.mol2mol.filter import BasePatternsFilter + +__all__ = ["BasePatternsFilter"] diff --git a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py new file mode 100644 index 00000000..191b3ddd --- /dev/null +++ b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py @@ -0,0 +1,129 @@ +"""Abstract classes for filters.""" + +import abc +from typing import Any, Literal, Optional, Union + +try: + from typing import Self # type: ignore[attr-defined] +except ImportError: + from typing_extensions import Self + +from molpipeline.abstract_pipeline_elements.core import MolToMolPipelineElement + + +class BasePatternsFilter(MolToMolPipelineElement, abc.ABC): + """Filter to keep or remove molecules based on patterns.""" + + def __init__( + self, + patterns: Union[ + list[str], dict[str, Union[int, tuple[Optional[int], Optional[int]]]] + ], + keep: bool = True, + mode: Literal["any", "all"] = "any", + name: Optional[str] = None, + n_jobs: int = 1, + uuid: Optional[str] = None, + ) -> None: + """Initialize BasePatternsFilter. + + Parameters + ---------- + patterns: Union[list[str], dict[str, Union[int, tuple[Optional[int], Optional[int]]]]] + List of patterns to allow in molecules. + Alternatively, a dictionary can be passed with patterns as keys + and an int for exact count or a tuple of minimum and maximum. + keep: bool, optional (default: True) + If True, molecules containing the specified patterns are kept, else removed. + mode: Literal["any", "all"], optional (default: "any") + If "any", at least one of the specified patterns must be present in the molecule. + If "all", all of the specified patterns must be present in the molecule. + name: Optional[str], optional (default: None) + Name of the pipeline element. + n_jobs: int, optional (default: 1) + Number of parallel jobs to use. + uuid: str, optional (default: None) + Unique identifier of the pipeline element. + """ + super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) + self.patterns = patterns # type: ignore + self.keep = keep + self.mode = mode + + @property + def patterns(self) -> dict[str, tuple[Optional[int], Optional[int]]]: + """Get allowed patterns as dict.""" + return self._patterns + + @patterns.setter + def patterns( + self, + patterns: Union[ + list[str], dict[str, Union[int, tuple[Optional[int], Optional[int]]]] + ], + ) -> None: + """Set allowed patterns as dict. + + Parameters + ---------- + patterns: Union[list[str], dict[str, Union[int, tuple[Optional[int], Optional[int]]]]] + List of patterns. + """ + self._patterns: dict[str, tuple[Optional[int], Optional[int]]] + if isinstance(patterns, list) or isinstance(patterns, set): + self._patterns = {pat: (1, None) for pat in patterns} + else: + self._patterns = {} + for pat, count in patterns.items(): + if isinstance(count, int): + self._patterns[pat] = (count, count) + else: + self._patterns[pat] = count + + def get_params(self, deep: bool = True) -> dict[str, Any]: + """Get parameters of PatternFilter. + + Parameters + ---------- + deep: bool, optional (default: True) + If True, return the parameters of all subobjects that are PipelineElements. + + Returns + ------- + dict[str, Any] + Parameters of PatternFilter. + """ + params = super().get_params(deep=deep) + if deep: + params["patterns"] = { + pat: (count_tuple[0], count_tuple[1]) + for pat, count_tuple in self.patterns.items() + } + else: + params["patterns"] = self.patterns + params["keep"] = self.keep + params["mode"] = self.mode + return params + + def set_params(self, **parameters: Any) -> Self: + """Set parameters of PatternFilter. + + Parameters + ---------- + parameters: Any + Parameters to set. + + Returns + ------- + Self + Self. + """ + parameter_copy = dict(parameters) + if "patterns" in parameter_copy: + self.patterns = parameter_copy.pop("patterns") + if "keep" in parameter_copy: + self.keep = parameter_copy.pop("keep") + if "mode" in parameter_copy: + self.mode = parameter_copy.pop("mode") + super().set_params(**parameter_copy) + return self diff --git a/molpipeline/any2mol/smiles2mol.py b/molpipeline/any2mol/smiles2mol.py index b28a5fb9..79db23bd 100644 --- a/molpipeline/any2mol/smiles2mol.py +++ b/molpipeline/any2mol/smiles2mol.py @@ -14,7 +14,6 @@ class SmilesToMol(_StringToMolPipelineElement): """Transforms Smiles to RDKit Mol objects.""" - def pretransform_single(self, value: str) -> OptionalMol: """Transform Smiles string to molecule. diff --git a/molpipeline/mol2any/mol2bool.py b/molpipeline/mol2any/mol2bool.py index b7e7ef0e..55d473f9 100644 --- a/molpipeline/mol2any/mol2bool.py +++ b/molpipeline/mol2any/mol2bool.py @@ -11,7 +11,6 @@ class MolToBool(MolToAnyPipelineElement): """Element to generate a bool array from input.""" - def pretransform_single(self, value: Any) -> bool: """Transform a value to a bool representation. diff --git a/molpipeline/mol2mol/__init__.py b/molpipeline/mol2mol/__init__.py index 356e4a63..6114de1b 100644 --- a/molpipeline/mol2mol/__init__.py +++ b/molpipeline/mol2mol/__init__.py @@ -1,10 +1,13 @@ """Init the module for mol2mol pipeline elements.""" from molpipeline.mol2mol.filter import ( + DescriptorsFilter, ElementFilter, EmptyMoleculeFilter, InorganicsFilter, MixtureFilter, + SmartsFilter, + SmilesFilter, ) from molpipeline.mol2mol.reaction import MolToMolReaction from molpipeline.mol2mol.scaffolds import MakeScaffoldGeneric, MurckoScaffold @@ -41,4 +44,7 @@ "SolventRemover", "Uncharger", "InorganicsFilter", + "SmartsFilter", + "SmilesFilter", + "DescriptorsFilter", ) diff --git a/molpipeline/mol2mol/filter.py b/molpipeline/mol2mol/filter.py index 2b384ebf..c40b55c8 100644 --- a/molpipeline/mol2mol/filter.py +++ b/molpipeline/mol2mol/filter.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Optional +from typing import Any, Literal, Optional, TypeVar, Union try: from typing import Self # type: ignore[attr-defined] @@ -10,16 +10,36 @@ from typing_extensions import Self from rdkit import Chem +from rdkit.Chem import Descriptors, FilterCatalog from molpipeline.abstract_pipeline_elements.core import InvalidInstance from molpipeline.abstract_pipeline_elements.core import ( MolToMolPipelineElement as _MolToMolPipelineElement, ) +from molpipeline.abstract_pipeline_elements.mol2mol import ( + BasePatternsFilter as _BasePatternsFilter, +) from molpipeline.utils.molpipeline_types import OptionalMol, RDKitMol +_T = TypeVar("_T") + + +def _list_to_dict_with_counts(elements_list: list[_T]) -> dict[_T, int]: + counts_dict: dict[_T, int] = {} + for element in elements_list: + if element in counts_dict: + counts_dict[element] += 1 + else: + counts_dict[element] = 1 + return counts_dict + class ElementFilter(_MolToMolPipelineElement): - """ElementFilter which removes molecules containing chemical elements other than specified.""" + """ElementFilter which removes molecules containing chemical elements other than specified. + + Molecular elements are filtered based on their atomic number. + The filter can be configured to allow only specific elements and/or a specific number of atoms of each element. + """ DEFAULT_ALLOWED_ELEMENT_NUMBERS = [ 1, @@ -39,7 +59,9 @@ class ElementFilter(_MolToMolPipelineElement): def __init__( self, - allowed_element_numbers: Optional[list[int]] = None, + allowed_element_numbers: Optional[ + Union[list[int], dict[int, Union[int, tuple[Optional[int], Optional[int]]]]] + ] = None, name: str = "ElementFilter", n_jobs: int = 1, uuid: Optional[str] = None, @@ -48,9 +70,10 @@ def __init__( Parameters ---------- - allowed_element_numbers: list[int] + allowed_element_numbers: Optional[Union[list[int], dict[int, Union[int, tuple[Optional[int], Optional[int]]]]]] List of atomic numbers of elements to allowed in molecules. Per default allowed elements are: H, B, C, N, O, F, Si, P, S, Cl, Se, Br, I. + Alternatively, a dictionary can be passed with atomic numbers as keys and an int for exact count or a tuple of minimum and maximum name: str, optional (default: "ElementFilterPipe") Name of the pipeline element. n_jobs: int, optional (default: 1) @@ -59,12 +82,43 @@ def __init__( Unique identifier of the pipeline element. """ super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) + self.allowed_element_numbers = allowed_element_numbers # type: ignore + + @property + def allowed_element_numbers(self) -> dict[int, tuple[Optional[int], Optional[int]]]: + """Get allowed element numbers as dict.""" + return self._allowed_element_numbers + + @allowed_element_numbers.setter + def allowed_element_numbers( + self, + allowed_element_numbers: Optional[ + Union[list[int], dict[int, Union[int, tuple[Optional[int], Optional[int]]]]] + ], + ) -> None: + """Set allowed element numbers as dict. + + Parameters + ---------- + allowed_element_numbers: Optional[Union[list[int], dict[int, Union[int, tuple[Optional[int], Optional[int]]]]] + List of atomic numbers of elements to allowed in molecules. + """ + self._allowed_element_numbers: dict[int, tuple[Optional[int], Optional[int]]] if allowed_element_numbers is None: allowed_element_numbers = self.DEFAULT_ALLOWED_ELEMENT_NUMBERS - if not isinstance(allowed_element_numbers, set): - self.allowed_element_numbers = set(allowed_element_numbers) + if isinstance(allowed_element_numbers, list) or isinstance( + allowed_element_numbers, set + ): + self._allowed_element_numbers = { + atom_number: (1, None) for atom_number in allowed_element_numbers + } else: - self.allowed_element_numbers = allowed_element_numbers + self._allowed_element_numbers = {} + for atom_number, count in allowed_element_numbers.items(): + if isinstance(count, int): + self._allowed_element_numbers[atom_number] = (count, count) + else: + self._allowed_element_numbers[atom_number] = count def get_params(self, deep: bool = True) -> dict[str, Any]: """Get parameters of ElementFilter. @@ -82,7 +136,8 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: params = super().get_params(deep=deep) if deep: params["allowed_element_numbers"] = { - int(atom) for atom in self.allowed_element_numbers + atom_number: (count_tuple[0], count_tuple[1]) + for atom_number, count_tuple in self.allowed_element_numbers.items() } else: params["allowed_element_numbers"] = self.allowed_element_numbers @@ -120,31 +175,196 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: OptionalMol Molecule if it contains only allowed elements, else InvalidInstance. """ - unique_elements = set(atom.GetAtomicNum() for atom in value.GetAtoms()) - if not unique_elements.issubset(self.allowed_element_numbers): - forbidden_elements = unique_elements - self.allowed_element_numbers - return InvalidInstance( - self.uuid, - f"Molecule contains following forbidden elements: {forbidden_elements}", - self.name, - ) + elements_list = [atom.GetAtomicNum() for atom in value.GetAtoms()] + elements_count_dict = _list_to_dict_with_counts(elements_list) + for element, count in elements_count_dict.items(): + min_count, max_count = self.allowed_element_numbers[element] + if (min_count is not None and count < min_count) or ( + max_count is not None and count > max_count + ): + return InvalidInstance( + self.uuid, + f"Molecule contains forbidden number of element {element}.", + self.name, + ) return value -class MixtureFilter(_MolToMolPipelineElement): - """MolToMol which removes molecules composed of multiple fragments.""" +class SmartsFilter(_BasePatternsFilter): + """Filter to keep or remove molecules based on SMARTS patterns.""" + + @property + def smarts_filter(self) -> FilterCatalog.FilterCatalog: + """Get the SMARTS filter.""" + smarts_matcher_list = [ + FilterCatalog.SmartsMatcher(smarts, smarts) + for i, smarts in enumerate(self.patterns) + ] + rdkit_filter = FilterCatalog.FilterCatalog() + for smarts_matcher in smarts_matcher_list: + if not smarts_matcher.IsValid(): + raise ValueError(f"Invalid SMARTS: {smarts_matcher.GetPattern()}") + entry = FilterCatalog.FilterCatalogEntry( + smarts_matcher.GetName(), smarts_matcher + ) + rdkit_filter.AddEntry(entry) + return rdkit_filter + + def pretransform_single(self, value: RDKitMol) -> OptionalMol: + """Invalidate or validate molecule matching any or all of the specified SMARTS patterns. + + Parameters + ---------- + value: RDKitMol + Molecule to check. + + Returns + ------- + OptionalMol + Molecule that matches defined smarts filter, else InvalidInstance. + """ + match_counts = 0 + for smarts_match in self.smarts_filter.GetMatches(value): + match_smarts = smarts_match.GetDescription() + all_matches = value.GetSubstructMatches(Chem.MolFromSmarts(match_smarts)) + min_count, max_count = self.patterns[match_smarts] + if (min_count is None or len(all_matches) >= min_count) and ( + max_count is None or len(all_matches) <= max_count + ): + if self.mode == "any": + return ( + value + if self.keep + else InvalidInstance( + self.uuid, + f"Molecule contains forbidden SMARTS pattern {match_smarts}.", + self.name, + ) + ) + else: + match_counts += 1 + if self.mode == "any": + return ( + value + if not self.keep + else InvalidInstance( + self.uuid, + "Molecule does not match any of the SmartsFilter patterns.", + self.name, + ) + ) + else: + if match_counts == len(self.patterns): + return ( + value + if self.keep + else InvalidInstance( + self.uuid, + "Molecule matches one of the SmartsFilter patterns.", + self.name, + ) + ) + else: + return ( + value + if not self.keep + else InvalidInstance( + self.uuid, + "Molecule does not match all of the SmartsFilter patterns.", + self.name, + ) + ) + + +class SmilesFilter(_BasePatternsFilter): + """Filter to keep or remove molecules based on SMILES patterns.""" + + def pretransform_single(self, value: RDKitMol) -> OptionalMol: + """Invalidate or validate molecule matching any or all of the specified SMILES patterns. + + Parameters + ---------- + value: RDKitMol + Molecule to check. + + Returns + ------- + OptionalMol + Molecule that matches defined smiles filter, else InvalidInstance. + """ + for pattern, (min_count, max_count) in self.patterns.items(): + all_matches = value.GetSubstructMatches(Chem.MolFromSmiles(pattern)) + if (min_count is None or len(all_matches) >= min_count) and ( + max_count is None or len(all_matches) <= max_count + ): + if self.mode == "any": + return ( + value + if self.keep + else InvalidInstance( + self.uuid, + f"Molecule contains forbidden SMILES pattern {pattern}.", + self.name, + ) + ) + else: + if self.mode == "all": + return ( + value + if not self.keep + else InvalidInstance( + self.uuid, + "Molecule does not match all required patterns.", + self.name, + ) + ) + if self.mode == "any": + return ( + value + if not self.keep + else InvalidInstance( + self.uuid, + "Molecule does not match any of the SmilesFilter patterns.", + self.name, + ) + ) + else: + return ( + value + if self.keep + else InvalidInstance( + self.uuid, + "Molecule does not match all of the SmilesFilter patterns.", + self.name, + ) + ) + - def __int__( +class DescriptorsFilter(_MolToMolPipelineElement): + """Filter to keep or remove molecules based on RDKit descriptors.""" + + def __init__( self, - name: str = "MixtureFilter", + descriptors: dict[str, tuple[Optional[float], Optional[float]]], + keep: bool = True, + mode: Literal["any", "all"] = "any", + name: Optional[str] = None, n_jobs: int = 1, uuid: Optional[str] = None, ) -> None: - """Initialize MixtureFilter. + """Initialize DescriptorsFilter. Parameters ---------- - name: str, optional (default: "MixtureFilterPipe") + descriptors: dict[str, tuple[Optional[float], Optional[float]]] + Dictionary of RDKit descriptors to filter by. + The value must be a tuple of minimum and maximum. If None, no limit is set. + keep: bool, optional (default: True) + If True, molecules containing the specified descriptors are kept, else removed. + mode: Literal["any", "all"], optional (default: "any") + If "any", at least one of the specified descriptors must be present in the molecule. + If "all", all of the specified descriptors must be present in the molecule. + name: Optional[str], optional (default: None) Name of the pipeline element. n_jobs: int, optional (default: 1) Number of parallel jobs to use. @@ -152,6 +372,120 @@ def __int__( Unique identifier of the pipeline element. """ super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) + self.descriptors = descriptors + self.keep = keep + self.mode = mode + + @property + def descriptors(self) -> dict[str, tuple[Optional[float], Optional[float]]]: + """Get allowed descriptors as dict.""" + return self._descriptors + + @descriptors.setter + def descriptors( + self, descriptors: dict[str, tuple[Optional[float], Optional[float]]] + ) -> None: + """Set allowed descriptors as dict. + + Parameters + ---------- + descriptors: dict[str, tuple[Optional[float], Optional[float]]] + Dictionary of RDKit descriptors to filter by. + """ + self._descriptors = descriptors + if not all(hasattr(Descriptors, descriptor) for descriptor in descriptors): + raise ValueError( + "You are trying to use an invalid descriptor. Use RDKit Descriptors module." + ) + + def get_params(self, deep: bool = True) -> dict[str, Any]: + """Get parameters of DescriptorFilter. + + Parameters + ---------- + deep: bool, optional (default: True) + If True, return the parameters of all subobjects that are PipelineElements. + + Returns + ------- + dict[str, Any] + Parameters of DescriptorFilter. + """ + params = super().get_params(deep=deep) + if deep: + params["descriptors"] = { + descriptor: (count_tuple[0], count_tuple[1]) + for descriptor, count_tuple in self.descriptors.items() + } + else: + params["descriptors"] = self.descriptors + params["keep"] = self.keep + params["mode"] = self.mode + return params + + def pretransform_single(self, value: RDKitMol) -> OptionalMol: + """Invalidate or validate molecule based on specified RDKit descriptors. + + Parameters + ---------- + value: RDKitMol + Molecule to check. + + Returns + ------- + OptionalMol + Molecule that matches defined descriptors filter, else InvalidInstance. + """ + for descriptor, (min_count, max_count) in self.descriptors.items(): + descriptor_value = getattr(Descriptors, descriptor)(value) + if (min_count is None or descriptor_value >= min_count) and ( + max_count is None or descriptor_value <= max_count + ): + if self.mode == "any": + return ( + value + if self.keep + else InvalidInstance( + self.uuid, + f"Molecule contains forbidden descriptor {descriptor}.", + self.name, + ) + ) + else: + if self.mode == "all": + return ( + value + if not self.keep + else InvalidInstance( + self.uuid, + "Molecule does not match all required descriptors.", + self.name, + ) + ) + if self.mode == "any": + return ( + value + if not self.keep + else InvalidInstance( + self.uuid, + "Molecule does not match any of the DescriptorsFilter descriptors.", + self.name, + ) + ) + else: + return ( + value + if self.keep + else InvalidInstance( + self.uuid, + "Molecule does not match all of the DescriptorsFilter descriptors.", + self.name, + ) + ) + + +class MixtureFilter(_MolToMolPipelineElement): + """MolToMol which removes molecules composed of multiple fragments.""" def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Invalidate molecule containing multiple fragments. @@ -180,25 +514,6 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: class EmptyMoleculeFilter(_MolToMolPipelineElement): """EmptyMoleculeFilter which removes empty molecules.""" - def __init__( - self, - name: str = "EmptyMoleculeFilter", - n_jobs: int = 1, - uuid: Optional[str] = None, - ) -> None: - """Initialize EmptyMoleculeFilter. - - Parameters - ---------- - name: str, optional (default: "EmptyMoleculeFilterPipe") - Name of the pipeline element. - n_jobs: int, optional (default: 1) - Number of parallel jobs to use. - uuid: str, optional (default: None) - Unique identifier of the pipeline element. - """ - super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) - def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Invalidate empty molecule. @@ -223,25 +538,6 @@ class InorganicsFilter(_MolToMolPipelineElement): CARBON_INORGANICS = ["O=C=O", "[C-]#[O+]"] # CO2 and CO are not organic CARBON_INORGANICS_MAX_ATOMS = 3 - def __init__( - self, - name: str = "InorganicsFilter", - n_jobs: int = 1, - uuid: Optional[str] = None, - ) -> None: - """Initialize InorganicsFilter. - - Parameters - ---------- - name: str, optional (default: "InorganicsFilter") - Name of the pipeline element. - n_jobs: int, optional (default: 1) - Number of parallel jobs to use. - uuid: str, optional (default: None) - Unique identifier of the pipeline element. - """ - super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) - def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Invalidate molecules not containing a carbon atom. diff --git a/molpipeline/mol2mol/standardization.py b/molpipeline/mol2mol/standardization.py index 4399ed49..24ad7136 100644 --- a/molpipeline/mol2mol/standardization.py +++ b/molpipeline/mol2mol/standardization.py @@ -162,7 +162,6 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: class LargestFragmentChooser(_MolToMolPipelineElement): """MolToMolPipelineElement which returns the largest fragment of a molecule.""" - def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Return largest fragment of molecule. @@ -182,7 +181,6 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: class MetalDisconnector(_MolToMolPipelineElement): """MolToMolPipelineElement which removes bonds between organic compounds and metals.""" - def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Cleave bonds with metals. @@ -206,7 +204,6 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: class IsotopeRemover(_MolToMolPipelineElement): """MolToMolPipelineElement which removes isotope information of atoms in a molecule.""" - def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Remove isotope information of each atom. @@ -231,7 +228,6 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: class ExplicitHydrogenRemover(_MolToMolPipelineElement): """MolToMolPipelineElement which removes explicit hydrogen atoms from a molecule.""" - def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Remove explicit hydrogen atoms. @@ -251,7 +247,6 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: class StereoRemover(_MolToMolPipelineElement): """MolToMolPipelineElement which removes stereo-information from the molecule.""" - def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Remove stereo-information in molecule. @@ -273,7 +268,6 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: class SaltRemover(_MolToMolPipelineElement): """MolToMolPipelineElement which removes metal ions from molecule.""" - def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Remove metal ions. @@ -459,7 +453,6 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: class Uncharger(_MolToMolPipelineElement): """MolToMolPipelineElement which removes charges in a molecule, if possible.""" - def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Remove charges of molecule. diff --git a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py index 9a5572d3..ad79ea40 100644 --- a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py +++ b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py @@ -1,11 +1,19 @@ """Test MolFilter, which invalidate molecules based on criteria defined in the respective filter.""" import unittest +from typing import Optional, Union from molpipeline import ErrorFilter, FilterReinserter, Pipeline from molpipeline.any2mol import SmilesToMol from molpipeline.mol2any import MolToSmiles -from molpipeline.mol2mol import ElementFilter, InorganicsFilter, MixtureFilter +from molpipeline.mol2mol import ( + DescriptorsFilter, + ElementFilter, + InorganicsFilter, + MixtureFilter, + SmartsFilter, + SmilesFilter, +) # pylint: disable=duplicate-code # test case molecules are allowed to be duplicated SMILES_ANTIMONY = "[SbH6+3]" @@ -14,68 +22,143 @@ SMILES_CL_BR = "NC(Cl)(Br)C(=O)O" SMILES_METAL_AU = "OC[C@H]1OC(S[Au])[C@H](O)[C@@H](O)[C@@H]1O" +SMILES_LIST = [ + SMILES_ANTIMONY, + SMILES_BENZENE, + SMILES_CHLOROBENZENE, + SMILES_METAL_AU, + SMILES_CL_BR, +] + class MolFilterTest(unittest.TestCase): """Unittest for MolFilter, which invalidate molecules based on criteria defined in the respective filter.""" def test_element_filter(self) -> None: - """Test if molecules are filtered correctly by allowed chemical elements. - - Returns - ------- - None - """ - smiles2mol = SmilesToMol() - default_atoms = { - 1, - 5, - 6, - 7, - 8, - 9, - 14, - 15, - 16, - 17, - 34, - 35, - 53, + """Test if molecules are filtered correctly by allowed chemical elements.""" + default_atoms_dict = { + 1: (None, None), + 5: (None, None), + 6: (None, None), + 7: (None, None), + 8: (None, None), + 9: (None, None), + 14: (None, None), + 15: (None, None), + 16: (None, None), + 17: (None, None), + 34: (None, None), + 35: (None, None), + 53: (None, None), } element_filter = ElementFilter() - self.assertEqual(element_filter.allowed_element_numbers, default_atoms) - mol2smiles = MolToSmiles() - error_filter = ErrorFilter.from_element_list( - [smiles2mol, element_filter, mol2smiles] - ) + self.assertEqual(element_filter.allowed_element_numbers, default_atoms_dict) pipeline = Pipeline( [ - ("Smiles2Mol", smiles2mol), + ("Smiles2Mol", SmilesToMol()), ("ElementFilter", element_filter), - ("Mol2Smiles", mol2smiles), - ("ErrorFilter", error_filter), + ("Mol2Smiles", MolToSmiles()), + ("ErrorFilter", ErrorFilter()), ], ) - filtered_smiles = pipeline.fit_transform( + filtered_smiles = pipeline.fit_transform(SMILES_LIST) + self.assertEqual( + filtered_smiles, [SMILES_BENZENE, SMILES_CHLOROBENZENE, SMILES_CL_BR] + ) + pipeline.set_params( + ElementFilter__allowed_element_numbers={6: 6, 1: (5, 6), 17: 1} + ) + filtered_smiles_2 = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(filtered_smiles_2, [SMILES_BENZENE, SMILES_CHLOROBENZENE]) + + def test_smarts_smiles_filter(self) -> None: + """Test if molecules are filtered correctly by allowed SMARTS patterns.""" + smarts_pats: dict[str, Union[int, tuple[Optional[int], Optional[int]]]] = { + "c": (4, None), + "Cl": 1, + } + smarts_filter = SmartsFilter(smarts_pats) + + smiles_pats: dict[str, Union[int, tuple[Optional[int], Optional[int]]]] = { + "c1ccccc1": (1, None), + "Cl": 1, + } + smiles_filter = SmilesFilter(smiles_pats) + + for filter_ in [smarts_filter, smiles_filter]: + new_input_as_list = list(filter_.patterns.keys()) + pipeline = Pipeline( + [ + ("Smiles2Mol", SmilesToMol()), + ("SmartsFilter", filter_), + ("Mol2Smiles", MolToSmiles()), + ("ErrorFilter", ErrorFilter()), + ], + ) + filtered_smiles = pipeline.fit_transform(SMILES_LIST) + self.assertEqual( + filtered_smiles, [SMILES_BENZENE, SMILES_CHLOROBENZENE, SMILES_CL_BR] + ) + + pipeline.set_params(SmartsFilter__keep=False) + filtered_smiles_2 = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(filtered_smiles_2, [SMILES_ANTIMONY, SMILES_METAL_AU]) + + pipeline.set_params(SmartsFilter__mode="all", SmartsFilter__keep=True) + filtered_smiles_3 = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(filtered_smiles_3, [SMILES_CHLOROBENZENE]) + + pipeline.set_params(SmartsFilter__keep=True, SmartsFilter__patterns=["I"]) + filtered_smiles_4 = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(filtered_smiles_4, []) + + pipeline.set_params( + SmartsFilter__keep=False, + SmartsFilter__mode="any", + SmartsFilter__patterns=new_input_as_list, + ) + filtered_smiles_5 = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(filtered_smiles_5, [SMILES_ANTIMONY, SMILES_METAL_AU]) + + def test_descriptor_filter(self) -> None: + """Test if molecules are filtered correctly by allowed descriptors.""" + descriptors: dict[str, tuple[Optional[float], Optional[float]]] = { + "MolWt": (None, 190), + "NumHAcceptors": (2, 10), + } + + descriptor_filter = DescriptorsFilter(descriptors) + + pipeline = Pipeline( [ - SMILES_ANTIMONY, - SMILES_BENZENE, - SMILES_CHLOROBENZENE, - SMILES_METAL_AU, - SMILES_CL_BR, - ] + ("Smiles2Mol", SmilesToMol()), + ("DescriptorsFilter", descriptor_filter), + ("Mol2Smiles", MolToSmiles()), + ("ErrorFilter", ErrorFilter()), + ], ) + filtered_smiles = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(filtered_smiles, SMILES_LIST) + + pipeline.set_params(DescriptorsFilter__mode="all") + filtered_smiles_2 = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(filtered_smiles_2, [SMILES_CL_BR]) + + pipeline.set_params(DescriptorsFilter__keep=False) + filtered_smiles_3 = pipeline.fit_transform(SMILES_LIST) + # why is this not self.assertEqual( - filtered_smiles, [SMILES_BENZENE, SMILES_CHLOROBENZENE, SMILES_CL_BR] + filtered_smiles_3, + [SMILES_ANTIMONY, SMILES_BENZENE, SMILES_CHLOROBENZENE, SMILES_METAL_AU], ) - def test_invalidate_mixtures(self) -> None: - """Test if mixtures are correctly invalidated. + pipeline.set_params(DescriptorsFilter__mode="any") + filtered_smiles_4 = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(filtered_smiles_4, []) - Returns - ------- - None - """ + def test_invalidate_mixtures(self) -> None: + """Test if mixtures are correctly invalidated.""" mol_list = ["CCC.CC.C", "c1ccccc1.[Na+].[Cl-]", "c1ccccc1"] expected_invalidated_mol_list = [None, None, "c1ccccc1"] @@ -98,12 +181,7 @@ def test_invalidate_mixtures(self) -> None: self.assertEqual(expected_invalidated_mol_list, mols_processed) def test_inorganic_filter(self) -> None: - """Test if molecules are filtered correctly by allowed chemical elements. - - Returns - ------- - None - """ + """Test if molecules are filtered correctly by allowed chemical elements.""" smiles2mol = SmilesToMol() inorganics_filter = InorganicsFilter() mol2smiles = MolToSmiles() @@ -118,15 +196,7 @@ def test_inorganic_filter(self) -> None: ("ErrorFilter", error_filter), ], ) - filtered_smiles = pipeline.fit_transform( - [ - SMILES_ANTIMONY, - SMILES_BENZENE, - SMILES_CHLOROBENZENE, - SMILES_METAL_AU, - SMILES_CL_BR, - ] - ) + filtered_smiles = pipeline.fit_transform(SMILES_LIST) self.assertEqual( filtered_smiles, [SMILES_BENZENE, SMILES_CHLOROBENZENE, SMILES_METAL_AU, SMILES_CL_BR], From 476d65ac444afd6c58da8728789f1f66d8f5e7f4 Mon Sep 17 00:00:00 2001 From: frederik-sandfort1 Date: Tue, 20 Aug 2024 15:34:12 +0200 Subject: [PATCH 03/25] Fix wrong typing that caused thousands of type ignores --- .../abstract_pipeline_elements/core.py | 8 +++---- .../mol2any/mol2bitvector.py | 16 +++++++------- .../mol2any/mol2floatvector.py | 6 ++--- molpipeline/any2mol/sdf2mol.py | 6 ++--- molpipeline/error_handling.py | 8 +++---- .../mol2any/mol2concatinated_vector.py | 6 ++--- molpipeline/mol2any/mol2morgan_fingerprint.py | 6 ++--- molpipeline/mol2any/mol2net_charge.py | 6 ++--- molpipeline/mol2any/mol2path_fingerprint.py | 22 +++++++++---------- molpipeline/mol2any/mol2rdkit_phys_chem.py | 4 ++-- molpipeline/mol2mol/filter.py | 18 ++++++++++++--- molpipeline/mol2mol/reaction.py | 8 +++---- molpipeline/mol2mol/standardization.py | 6 ++--- molpipeline/pipeline/_molpipeline.py | 8 +++---- tests/utils/mock_element.py | 4 ++-- 15 files changed, 72 insertions(+), 60 deletions(-) diff --git a/molpipeline/abstract_pipeline_elements/core.py b/molpipeline/abstract_pipeline_elements/core.py index 62f888cf..1c1e7aeb 100644 --- a/molpipeline/abstract_pipeline_elements/core.py +++ b/molpipeline/abstract_pipeline_elements/core.py @@ -184,12 +184,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: "uuid": self.uuid, } - def set_params(self, **parameters: dict[str, Any]) -> Self: + def set_params(self, **parameters: Any) -> Self: """As the setter function cannot be assessed with super(), this method is implemented for inheritance. Parameters ---------- - parameters: dict[str, Any] + parameters: Any Parameters to be set. Returns @@ -379,12 +379,12 @@ def parameters(self) -> dict[str, Any]: return self.get_params() @parameters.setter - def parameters(self, **parameters: dict[str, Any]) -> None: + def parameters(self, **parameters: Any) -> None: """Set the parameters of the object. Parameters ---------- - parameters: dict[str, Any] + parameters: Any Object parameters as a dictionary. Returns diff --git a/molpipeline/abstract_pipeline_elements/mol2any/mol2bitvector.py b/molpipeline/abstract_pipeline_elements/mol2any/mol2bitvector.py index 66317966..7de2cbab 100644 --- a/molpipeline/abstract_pipeline_elements/mol2any/mol2bitvector.py +++ b/molpipeline/abstract_pipeline_elements/mol2any/mol2bitvector.py @@ -140,12 +140,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: return parameters - def set_params(self, **parameters: dict[str, Any]) -> Self: + def set_params(self, **parameters: Any) -> Self: """Set object parameters relevant for copying the class. Parameters ---------- - parameters: dict[str, Any] + parameters: Any Dictionary of parameter names and values. Returns @@ -160,7 +160,7 @@ def set_params(self, **parameters: dict[str, Any]) -> Self: raise ValueError( f"return_as has to be one of {get_args(OutputDatatype)}! (Received: {return_as})" ) - self._return_as = return_as # type: ignore + self._return_as = return_as super().set_params(**parameter_dict_copy) return self @@ -300,12 +300,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: return parameters - def set_params(self, **parameters: dict[str, Any]) -> Self: + def set_params(self, **parameters: Any) -> Self: """Set object parameters relevant for copying the class. Parameters ---------- - parameters: dict[str, Any] + parameters: Any Dictionary of parameter names and values. Returns @@ -398,12 +398,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: parameters.pop("fill_value", None) return parameters - def set_params(self, **parameters: dict[str, Any]) -> Self: + def set_params(self, **parameters: Any) -> Self: """Set parameters. Parameters ---------- - parameters: dict[str, Any] + parameters: Any Dictionary of parameter names and values. Returns @@ -417,7 +417,7 @@ def set_params(self, **parameters: dict[str, Any]) -> Self: # explicitly check for None, since 0 is a valid value if radius is not None: - self._radius = radius # type: ignore + self._radius = radius # explicitly check for None, since False is a valid value if use_features is not None: self._use_features = bool(use_features) diff --git a/molpipeline/abstract_pipeline_elements/mol2any/mol2floatvector.py b/molpipeline/abstract_pipeline_elements/mol2any/mol2floatvector.py index 3f25e2a4..3c0711c3 100644 --- a/molpipeline/abstract_pipeline_elements/mol2any/mol2floatvector.py +++ b/molpipeline/abstract_pipeline_elements/mol2any/mol2floatvector.py @@ -107,12 +107,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: params["standardizer"] = self._standardizer return params - def set_params(self, **parameters: dict[str, Any]) -> Self: + def set_params(self, **parameters: Any) -> Self: """Set parameters. Parameters ---------- - parameters: dict[str, Any] + parameters: Any Dictionary with parameter names and corresponding values. Returns @@ -123,7 +123,7 @@ def set_params(self, **parameters: dict[str, Any]) -> Self: parameter_copy = dict(parameters) standardizer = parameter_copy.pop("standardizer", None) if standardizer is not None: - self._standardizer = standardizer # type: ignore + self._standardizer = standardizer super().set_params(**parameter_copy) return self diff --git a/molpipeline/any2mol/sdf2mol.py b/molpipeline/any2mol/sdf2mol.py index ded4567e..660ed413 100644 --- a/molpipeline/any2mol/sdf2mol.py +++ b/molpipeline/any2mol/sdf2mol.py @@ -71,12 +71,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: params["identifier"] = self.identifier return params - def set_params(self, **parameters: dict[str, Any]) -> Self: + def set_params(self, **parameters: Any) -> Self: """Set parameters of the object. Parameters ---------- - parameters: dict[str, Any] + parameters: Any Dictionary containing all parameters defining the object. Returns @@ -86,7 +86,7 @@ def set_params(self, **parameters: dict[str, Any]) -> Self: """ super().set_params(**parameters) if "identifier" in parameters: - self.identifier = parameters["identifier"] # type: ignore + self.identifier = parameters["identifier"] return self def finish(self) -> None: diff --git a/molpipeline/error_handling.py b/molpipeline/error_handling.py index 4ef5efa1..147ef8d7 100644 --- a/molpipeline/error_handling.py +++ b/molpipeline/error_handling.py @@ -127,12 +127,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: params["element_ids"] = self.element_ids return params - def set_params(self, **parameters: dict[str, Any]) -> Self: + def set_params(self, **parameters: Any) -> Self: """Set parameters for this element. Parameters ---------- - parameters: dict[str, Any] + parameters: Any Dict of arameters to set. Returns @@ -508,12 +508,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: params["fill_value"] = self.fill_value return params - def set_params(self, **parameters: dict[str, Any]) -> Self: + def set_params(self, **parameters: Any) -> Self: """Set parameters for this element. Parameters ---------- - parameters: dict[str, Any] + parameters: Any Parameter dict. Returns diff --git a/molpipeline/mol2any/mol2concatinated_vector.py b/molpipeline/mol2any/mol2concatinated_vector.py index a0a24406..09c2e3db 100644 --- a/molpipeline/mol2any/mol2concatinated_vector.py +++ b/molpipeline/mol2any/mol2concatinated_vector.py @@ -113,12 +113,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: return parameters - def set_params(self, **parameters: dict[str, Any]) -> Self: + def set_params(self, **parameters: Any) -> Self: """Set parameters. Parameters ---------- - parameters: dict[str, Any] + parameters: Any Parameters to set. Returns @@ -129,7 +129,7 @@ def set_params(self, **parameters: dict[str, Any]) -> Self: parameter_copy = dict(parameters) element_list = parameter_copy.pop("element_list", None) if element_list is not None: - self._element_list = element_list # type: ignore + self._element_list = element_list step_params: dict[str, dict[str, Any]] = {} step_dict = dict(self._element_list) to_delete_list = [] diff --git a/molpipeline/mol2any/mol2morgan_fingerprint.py b/molpipeline/mol2any/mol2morgan_fingerprint.py index 2f079e38..1c93295d 100644 --- a/molpipeline/mol2any/mol2morgan_fingerprint.py +++ b/molpipeline/mol2any/mol2morgan_fingerprint.py @@ -104,12 +104,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: parameters["n_bits"] = self._n_bits return parameters - def set_params(self, **parameters: dict[str, Any]) -> Self: + def set_params(self, **parameters: Any) -> Self: """Set parameters. Parameters ---------- - parameters: dict[str, Any] + parameters: Any Dictionary of parameter names and values. Returns @@ -120,7 +120,7 @@ def set_params(self, **parameters: dict[str, Any]) -> Self: parameter_copy = dict(parameters) n_bits = parameter_copy.pop("n_bits", None) if n_bits is not None: - self._n_bits = n_bits # type: ignore + self._n_bits = n_bits super().set_params(**parameter_copy) return self diff --git a/molpipeline/mol2any/mol2net_charge.py b/molpipeline/mol2any/mol2net_charge.py index 6636f48c..759ec84b 100644 --- a/molpipeline/mol2any/mol2net_charge.py +++ b/molpipeline/mol2any/mol2net_charge.py @@ -145,12 +145,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: parent_dict["charge_policy"] = self._charge_method return parent_dict - def set_params(self, **parameters: dict[str, Any]) -> Self: + def set_params(self, **parameters: Any) -> Self: """Set parameters. Parameters ---------- - parameters: dict[str, Any] + parameters: Any Parameters to set Returns @@ -161,6 +161,6 @@ def set_params(self, **parameters: dict[str, Any]) -> Self: parameters_shallow_copy = dict(parameters) charge_policy = parameters_shallow_copy.pop("charge_policy", None) if charge_policy is not None: - self._charge_method = charge_policy # type: ignore + self._charge_method = charge_policy super().set_params(**parameters_shallow_copy) return self diff --git a/molpipeline/mol2any/mol2path_fingerprint.py b/molpipeline/mol2any/mol2path_fingerprint.py index 368e8cb7..38e98d4a 100644 --- a/molpipeline/mol2any/mol2path_fingerprint.py +++ b/molpipeline/mol2any/mol2path_fingerprint.py @@ -152,12 +152,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: parameters["n_bits"] = self._n_bits return parameters - def set_params(self, **parameters: dict[str, Any]) -> Self: + def set_params(self, **parameters: Any) -> Self: """Set parameters. Parameters ---------- - parameters: dict[str, Any] + parameters: Any Dictionary of parameter names and values. Returns @@ -168,28 +168,28 @@ def set_params(self, **parameters: dict[str, Any]) -> Self: parameter_copy = dict(parameters) min_path = parameter_copy.pop("min_path", None) if min_path is not None: - self._min_path = min_path # type: ignore + self._min_path = min_path max_path = parameter_copy.pop("max_path", None) if max_path is not None: - self._max_path = max_path # type: ignore + self._max_path = max_path use_hs = parameter_copy.pop("use_hs", None) if use_hs is not None: - self._use_hs = use_hs # type: ignore + self._use_hs = use_hs branched_paths = parameter_copy.pop("branched_paths", None) if branched_paths is not None: - self._branched_paths = branched_paths # type: ignore + self._branched_paths = branched_paths use_bond_order = parameter_copy.pop("use_bond_order", None) if use_bond_order is not None: - self._use_bond_order = use_bond_order # type: ignore + self._use_bond_order = use_bond_order count_simulation = parameter_copy.pop("count_simulation", None) if count_simulation is not None: - self._count_simulation = count_simulation # type: ignore + self._count_simulation = count_simulation count_bounds = parameter_copy.pop("count_bounds", None) if count_bounds is not None: - self._count_bounds = count_bounds # type: ignore + self._count_bounds = count_bounds num_bits_per_feature = parameter_copy.pop("num_bits_per_feature", None) if num_bits_per_feature is not None: - self._num_bits_per_feature = num_bits_per_feature # type: ignore + self._num_bits_per_feature = num_bits_per_feature atom_invariants_generator = parameter_copy.pop( "atom_invariants_generator", None ) @@ -197,7 +197,7 @@ def set_params(self, **parameters: dict[str, Any]) -> Self: self._atom_invariants_generator = atom_invariants_generator n_bits = parameter_copy.pop("n_bits", None) # pylint: disable=duplicate-code if n_bits is not None: - self._n_bits = n_bits # type: ignore + self._n_bits = n_bits super().set_params(**parameter_copy) return self diff --git a/molpipeline/mol2any/mol2rdkit_phys_chem.py b/molpipeline/mol2any/mol2rdkit_phys_chem.py index 968b0d1a..251b6b3c 100644 --- a/molpipeline/mol2any/mol2rdkit_phys_chem.py +++ b/molpipeline/mol2any/mol2rdkit_phys_chem.py @@ -170,12 +170,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: parent_dict["log_exceptions"] = self._log_exceptions return parent_dict - def set_params(self, **parameters: dict[str, Any]) -> Self: + def set_params(self, **parameters: Any) -> Self: """Set parameters. Parameters ---------- - parameters: dict[str, Any] + parameters: Any Parameters to set Returns diff --git a/molpipeline/mol2mol/filter.py b/molpipeline/mol2mol/filter.py index c40b55c8..55edde5b 100644 --- a/molpipeline/mol2mol/filter.py +++ b/molpipeline/mol2mol/filter.py @@ -25,6 +25,18 @@ def _list_to_dict_with_counts(elements_list: list[_T]) -> dict[_T, int]: + """Convert list to dictionary with counts of elements. + + Parameters + ---------- + elements_list: list[_T] + List of elements. + + Returns + ------- + dict[_T, int] + Dictionary with elements as keys and counts as values. + """ counts_dict: dict[_T, int] = {} for element in elements_list: if element in counts_dict: @@ -143,12 +155,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: params["allowed_element_numbers"] = self.allowed_element_numbers return params - def set_params(self, **parameters: dict[str, Any]) -> Self: + def set_params(self, **parameters: Any) -> Self: """Set parameters of ElementFilter. Parameters ---------- - parameters: dict[str, Any] + parameters: Any Parameters to set. Returns @@ -158,7 +170,7 @@ def set_params(self, **parameters: dict[str, Any]) -> Self: """ parameter_copy = dict(parameters) if "allowed_element_numbers" in parameter_copy: - self.allowed_element_numbers = parameter_copy.pop("allowed_element_numbers") # type: ignore + self.allowed_element_numbers = parameter_copy.pop("allowed_element_numbers") super().set_params(**parameter_copy) return self diff --git a/molpipeline/mol2mol/reaction.py b/molpipeline/mol2mol/reaction.py index aad5bfd7..97b13500 100644 --- a/molpipeline/mol2mol/reaction.py +++ b/molpipeline/mol2mol/reaction.py @@ -91,12 +91,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: parameters["handle_multi"] = self.handle_multi return parameters - def set_params(self, **parameters: dict[str, Any]) -> Self: + def set_params(self, **parameters: Any) -> Self: """Set the parameters. Parameters ---------- - parameters: dict[str, Any] + parameters: Any Dictionary containing parameters to be set. Returns @@ -108,9 +108,9 @@ def set_params(self, **parameters: dict[str, Any]) -> Self: if "reaction" in parameters: self.reaction = parameters["reaction"] if "additive_list" in parameters: - self.additive_list = parameters["additive_list"] # type: ignore + self.additive_list = parameters["additive_list"] if "handle_multi" in parameters: - self.handle_multi = parameters["handle_multi"] # type: ignore + self.handle_multi = parameters["handle_multi"] return self @property diff --git a/molpipeline/mol2mol/standardization.py b/molpipeline/mol2mol/standardization.py index 24ad7136..bda92773 100644 --- a/molpipeline/mol2mol/standardization.py +++ b/molpipeline/mol2mol/standardization.py @@ -394,12 +394,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: params["solvent_smiles_list"] = self.solvent_smiles_list return params - def set_params(self, **parameters: dict[str, Any]) -> Self: + def set_params(self, **parameters: Any) -> Self: """Set parameters of pipeline element. Parameters ---------- - parameters: dict[str, Any] + parameters: Any Parameters to set. Returns @@ -410,7 +410,7 @@ def set_params(self, **parameters: dict[str, Any]) -> Self: param_copy = dict(parameters) solvent_smiles_list = param_copy.pop("solvent_smiles_list", None) if solvent_smiles_list is not None: - self.solvent_smiles_list = solvent_smiles_list # type: ignore + self.solvent_smiles_list = solvent_smiles_list super().set_params(**param_copy) return self diff --git a/molpipeline/pipeline/_molpipeline.py b/molpipeline/pipeline/_molpipeline.py index 2771b7e3..3ddb7c9b 100644 --- a/molpipeline/pipeline/_molpipeline.py +++ b/molpipeline/pipeline/_molpipeline.py @@ -166,12 +166,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: "raise_nones": self.raise_nones, } - def set_params(self, **parameter_dict: dict[str, Any]) -> Self: + def set_params(self, **parameter_dict: Any) -> Self: """Set parameters of the pipeline and pipeline elements. Parameters ---------- - parameter_dict: dict[str, Any] + parameter_dict: Any Dictionary containing the parameter names and corresponding values to be set. Returns @@ -180,9 +180,9 @@ def set_params(self, **parameter_dict: dict[str, Any]) -> Self: MolPipeline object with updated parameters. """ if "element_list" in parameter_dict: - self._element_list = parameter_dict["element_list"] # type: ignore + self._element_list = parameter_dict["element_list"] if "n_jobs" in parameter_dict: - self.n_jobs = int(parameter_dict["n_jobs"]) # type: ignore + self.n_jobs = int(parameter_dict["n_jobs"]) if "name" in parameter_dict: self.name = str(parameter_dict["name"]) if "raise_nones" in parameter_dict: diff --git a/tests/utils/mock_element.py b/tests/utils/mock_element.py index 8b0ba64f..4bf8eee4 100644 --- a/tests/utils/mock_element.py +++ b/tests/utils/mock_element.py @@ -73,12 +73,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: params["return_as_numpy_array"] = self.return_as_numpy_array return params - def set_params(self, **parameters: dict[str, Any]) -> Self: + def set_params(self, **parameters: Any) -> Self: """Set parameters of the object. Parameters ---------- - parameters: dict[str, Any] + parameters: Any Dictionary containing all parameters defining the object. Returns From f14b71aba02f10833445401040ec4c07e8aa099c Mon Sep 17 00:00:00 2001 From: frederik-sandfort1 Date: Tue, 20 Aug 2024 15:46:42 +0200 Subject: [PATCH 04/25] linting and fix element number test --- .../mol2mol/filter.py | 2 +- molpipeline/mol2mol/filter.py | 85 +++++++++---------- .../test_mol2mol/test_mol2mol_filter.py | 26 +++--- 3 files changed, 56 insertions(+), 57 deletions(-) diff --git a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py index 191b3ddd..acfbf0c0 100644 --- a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py +++ b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py @@ -70,7 +70,7 @@ def patterns( List of patterns. """ self._patterns: dict[str, tuple[Optional[int], Optional[int]]] - if isinstance(patterns, list) or isinstance(patterns, set): + if isinstance(patterns, (list, set)): self._patterns = {pat: (1, None) for pat in patterns} else: self._patterns = {} diff --git a/molpipeline/mol2mol/filter.py b/molpipeline/mol2mol/filter.py index 55edde5b..95adb85f 100644 --- a/molpipeline/mol2mol/filter.py +++ b/molpipeline/mol2mol/filter.py @@ -118,9 +118,7 @@ def allowed_element_numbers( self._allowed_element_numbers: dict[int, tuple[Optional[int], Optional[int]]] if allowed_element_numbers is None: allowed_element_numbers = self.DEFAULT_ALLOWED_ELEMENT_NUMBERS - if isinstance(allowed_element_numbers, list) or isinstance( - allowed_element_numbers, set - ): + if isinstance(allowed_element_numbers, (list, set)): self._allowed_element_numbers = { atom_number: (1, None) for atom_number in allowed_element_numbers } @@ -190,6 +188,12 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: elements_list = [atom.GetAtomicNum() for atom in value.GetAtoms()] elements_count_dict = _list_to_dict_with_counts(elements_list) for element, count in elements_count_dict.items(): + if element not in self.allowed_element_numbers: + return InvalidInstance( + self.uuid, + f"Molecule contains forbidden element {element}.", + self.name, + ) min_count, max_count = self.allowed_element_numbers[element] if (min_count is not None and count < min_count) or ( max_count is not None and count > max_count @@ -253,8 +257,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: self.name, ) ) - else: - match_counts += 1 + match_counts += 1 if self.mode == "any": return ( value @@ -265,27 +268,25 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: self.name, ) ) - else: - if match_counts == len(self.patterns): - return ( - value - if self.keep - else InvalidInstance( - self.uuid, - "Molecule matches one of the SmartsFilter patterns.", - self.name, - ) - ) - else: - return ( - value - if not self.keep - else InvalidInstance( - self.uuid, - "Molecule does not match all of the SmartsFilter patterns.", - self.name, - ) + if match_counts == len(self.patterns): + return ( + value + if self.keep + else InvalidInstance( + self.uuid, + "Molecule matches one of the SmartsFilter patterns.", + self.name, ) + ) + return ( + value + if not self.keep + else InvalidInstance( + self.uuid, + "Molecule does not match all of the SmartsFilter patterns.", + self.name, + ) + ) class SmilesFilter(_BasePatternsFilter): @@ -340,16 +341,15 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: self.name, ) ) - else: - return ( - value - if self.keep - else InvalidInstance( - self.uuid, - "Molecule does not match all of the SmilesFilter patterns.", - self.name, - ) + return ( + value + if self.keep + else InvalidInstance( + self.uuid, + "Molecule does not match all of the SmilesFilter patterns.", + self.name, ) + ) class DescriptorsFilter(_MolToMolPipelineElement): @@ -484,16 +484,15 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: self.name, ) ) - else: - return ( - value - if self.keep - else InvalidInstance( - self.uuid, - "Molecule does not match all of the DescriptorsFilter descriptors.", - self.name, - ) + return ( + value + if self.keep + else InvalidInstance( + self.uuid, + "Molecule does not match all of the DescriptorsFilter descriptors.", + self.name, ) + ) class MixtureFilter(_MolToMolPipelineElement): diff --git a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py index ad79ea40..8b2ecf5b 100644 --- a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py +++ b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py @@ -37,19 +37,19 @@ class MolFilterTest(unittest.TestCase): def test_element_filter(self) -> None: """Test if molecules are filtered correctly by allowed chemical elements.""" default_atoms_dict = { - 1: (None, None), - 5: (None, None), - 6: (None, None), - 7: (None, None), - 8: (None, None), - 9: (None, None), - 14: (None, None), - 15: (None, None), - 16: (None, None), - 17: (None, None), - 34: (None, None), - 35: (None, None), - 53: (None, None), + 1: (1, None), + 5: (1, None), + 6: (1, None), + 7: (1, None), + 8: (1, None), + 9: (1, None), + 14: (1, None), + 15: (1, None), + 16: (1, None), + 17: (1, None), + 34: (1, None), + 35: (1, None), + 53: (1, None), } element_filter = ElementFilter() From c352144a42d109f7f2c00e13863d137ff40ce03c Mon Sep 17 00:00:00 2001 From: frederik-sandfort1 Date: Wed, 21 Aug 2024 15:48:07 +0200 Subject: [PATCH 05/25] reset name typing --- molpipeline/abstract_pipeline_elements/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/molpipeline/abstract_pipeline_elements/core.py b/molpipeline/abstract_pipeline_elements/core.py index 1c1e7aeb..276a9fa5 100644 --- a/molpipeline/abstract_pipeline_elements/core.py +++ b/molpipeline/abstract_pipeline_elements/core.py @@ -91,7 +91,7 @@ def __repr__(self) -> str: class ABCPipelineElement(abc.ABC): """Ancestor of all PipelineElements.""" - name: Optional[str] + name: str _requires_fitting: bool = False uuid: str @@ -336,7 +336,7 @@ class TransformingPipelineElement(ABCPipelineElement): _input_type: str _output_type: str - name: Optional[str] + name: str def __init__( self, From 5c95f8115bb394b9ea2869d17329c0150b6b6be5 Mon Sep 17 00:00:00 2001 From: frederik-sandfort1 Date: Thu, 22 Aug 2024 11:56:59 +0200 Subject: [PATCH 06/25] Christians first review --- .../mol2mol/filter.py | 13 ++-- molpipeline/mol2mol/filter.py | 64 +++++++------------ molpipeline/utils/value_conversions.py | 28 ++++++++ .../test_mol2mol/test_mol2mol_filter.py | 53 +++++++++++---- 4 files changed, 95 insertions(+), 63 deletions(-) create mode 100644 molpipeline/utils/value_conversions.py diff --git a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py index acfbf0c0..5adcac4a 100644 --- a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py +++ b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py @@ -9,11 +9,14 @@ from typing_extensions import Self from molpipeline.abstract_pipeline_elements.core import MolToMolPipelineElement +from molpipeline.utils.value_conversions import count_value_to_tuple class BasePatternsFilter(MolToMolPipelineElement, abc.ABC): """Filter to keep or remove molecules based on patterns.""" + _patterns: dict[str, tuple[Optional[int], Optional[int]]] + def __init__( self, patterns: Union[ @@ -69,16 +72,12 @@ def patterns( patterns: Union[list[str], dict[str, Union[int, tuple[Optional[int], Optional[int]]]]] List of patterns. """ - self._patterns: dict[str, tuple[Optional[int], Optional[int]]] if isinstance(patterns, (list, set)): self._patterns = {pat: (1, None) for pat in patterns} else: - self._patterns = {} - for pat, count in patterns.items(): - if isinstance(count, int): - self._patterns[pat] = (count, count) - else: - self._patterns[pat] = count + self._patterns = { + pat: count_value_to_tuple(count) for pat, count in patterns.items() + } def get_params(self, deep: bool = True) -> dict[str, Any]: """Get parameters of PatternFilter. diff --git a/molpipeline/mol2mol/filter.py b/molpipeline/mol2mol/filter.py index 95adb85f..a2ed202a 100644 --- a/molpipeline/mol2mol/filter.py +++ b/molpipeline/mol2mol/filter.py @@ -2,7 +2,8 @@ from __future__ import annotations -from typing import Any, Literal, Optional, TypeVar, Union +from collections import Counter +from typing import Any, Literal, Optional, Union try: from typing import Self # type: ignore[attr-defined] @@ -20,30 +21,7 @@ BasePatternsFilter as _BasePatternsFilter, ) from molpipeline.utils.molpipeline_types import OptionalMol, RDKitMol - -_T = TypeVar("_T") - - -def _list_to_dict_with_counts(elements_list: list[_T]) -> dict[_T, int]: - """Convert list to dictionary with counts of elements. - - Parameters - ---------- - elements_list: list[_T] - List of elements. - - Returns - ------- - dict[_T, int] - Dictionary with elements as keys and counts as values. - """ - counts_dict: dict[_T, int] = {} - for element in elements_list: - if element in counts_dict: - counts_dict[element] += 1 - else: - counts_dict[element] = 1 - return counts_dict +from molpipeline.utils.value_conversions import count_value_to_tuple class ElementFilter(_MolToMolPipelineElement): @@ -120,15 +98,13 @@ def allowed_element_numbers( allowed_element_numbers = self.DEFAULT_ALLOWED_ELEMENT_NUMBERS if isinstance(allowed_element_numbers, (list, set)): self._allowed_element_numbers = { - atom_number: (1, None) for atom_number in allowed_element_numbers + atom_number: (0, None) for atom_number in allowed_element_numbers } else: - self._allowed_element_numbers = {} - for atom_number, count in allowed_element_numbers.items(): - if isinstance(count, int): - self._allowed_element_numbers[atom_number] = (count, count) - else: - self._allowed_element_numbers[atom_number] = count + self._allowed_element_numbers = { + atom_number: count_value_to_tuple(count) + for atom_number, count in allowed_element_numbers.items() + } def get_params(self, deep: bool = True) -> dict[str, Any]: """Get parameters of ElementFilter. @@ -185,16 +161,20 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: OptionalMol Molecule if it contains only allowed elements, else InvalidInstance. """ - elements_list = [atom.GetAtomicNum() for atom in value.GetAtoms()] - elements_count_dict = _list_to_dict_with_counts(elements_list) - for element, count in elements_count_dict.items(): - if element not in self.allowed_element_numbers: - return InvalidInstance( - self.uuid, - f"Molecule contains forbidden element {element}.", - self.name, - ) - min_count, max_count = self.allowed_element_numbers[element] + to_process_value = ( + Chem.AddHs(value) if 1 in self.allowed_element_numbers else value + ) + + elements_list = [atom.GetAtomicNum() for atom in to_process_value.GetAtoms()] + elements_counter = Counter(elements_list) + if any( + element not in self.allowed_element_numbers for element in elements_counter + ): + return InvalidInstance( + self.uuid, "Molecule contains forbidden chemical element.", self.name + ) + for element, (min_count, max_count) in self.allowed_element_numbers.items(): + count = elements_counter[element] if (min_count is not None and count < min_count) or ( max_count is not None and count > max_count ): diff --git a/molpipeline/utils/value_conversions.py b/molpipeline/utils/value_conversions.py new file mode 100644 index 00000000..a348c885 --- /dev/null +++ b/molpipeline/utils/value_conversions.py @@ -0,0 +1,28 @@ +"""Module for utilities converting values.""" + +from typing import Optional, Sequence, Union + + +def count_value_to_tuple( + count: Union[int, tuple[Optional[int], Optional[int]]] +) -> tuple[Optional[int], Optional[int]]: + """Convert a count value to a tuple. + + Parameters + ---------- + count: Union[int, float, tuple[Optional[int], Optional[int]]] + Count value. Can be a single float or int or a tuple of two values. + + Returns + ------- + tuple[Optional[int], Optional[int]] + Tuple of count values. + """ + if isinstance(count, int): + return count, count + if isinstance(count, Sequence): + count_tuple = tuple(count) + if len(count_tuple) != 2: + raise ValueError(f"Expected a sequence of length 2, got: {count_tuple}") + return count_tuple + raise TypeError(f"Got unexpected type: {type(count)}") diff --git a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py index 8b2ecf5b..edf066b3 100644 --- a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py +++ b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py @@ -37,19 +37,19 @@ class MolFilterTest(unittest.TestCase): def test_element_filter(self) -> None: """Test if molecules are filtered correctly by allowed chemical elements.""" default_atoms_dict = { - 1: (1, None), - 5: (1, None), - 6: (1, None), - 7: (1, None), - 8: (1, None), - 9: (1, None), - 14: (1, None), - 15: (1, None), - 16: (1, None), - 17: (1, None), - 34: (1, None), - 35: (1, None), - 53: (1, None), + 1: (0, None), + 5: (0, None), + 6: (0, None), + 7: (0, None), + 8: (0, None), + 9: (0, None), + 14: (0, None), + 15: (0, None), + 16: (0, None), + 17: (0, None), + 34: (0, None), + 35: (0, None), + 53: (0, None), } element_filter = ElementFilter() @@ -67,7 +67,7 @@ def test_element_filter(self) -> None: filtered_smiles, [SMILES_BENZENE, SMILES_CHLOROBENZENE, SMILES_CL_BR] ) pipeline.set_params( - ElementFilter__allowed_element_numbers={6: 6, 1: (5, 6), 17: 1} + ElementFilter__allowed_element_numbers={6: 6, 1: (5, 6), 17: (0, 1)} ) filtered_smiles_2 = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles_2, [SMILES_BENZENE, SMILES_CHLOROBENZENE]) @@ -121,6 +121,31 @@ def test_smarts_smiles_filter(self) -> None: filtered_smiles_5 = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles_5, [SMILES_ANTIMONY, SMILES_METAL_AU]) + def test_smarts_filter_parallel(self) -> None: + """Test if molecules are filtered correctly by allowed SMARTS patterns in parallel.""" + smarts_pats: dict[str, Union[int, tuple[Optional[int], Optional[int]]]] = { + "c": (4, None), + "Cl": 1, + "cc": (1, None), + "ccc": (1, None), + "cccc": (1, None), + "ccccc": (1, None), + "cccccc": (1, None), + "c1ccccc1": (1, None), + "cCl": 1, + } + smarts_filter = SmartsFilter(smarts_pats, mode="all", n_jobs=-1) + pipeline = Pipeline( + [ + ("Smiles2Mol", SmilesToMol()), + ("SmartsFilter", smarts_filter), + ("Mol2Smiles", MolToSmiles()), + ("ErrorFilter", ErrorFilter()), + ], + ) + filtered_smiles = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(filtered_smiles, [SMILES_CHLOROBENZENE]) + def test_descriptor_filter(self) -> None: """Test if molecules are filtered correctly by allowed descriptors.""" descriptors: dict[str, tuple[Optional[float], Optional[float]]] = { From 16088db7c4c6f6b2c9710238896758820ad38091 Mon Sep 17 00:00:00 2001 From: frederik-sandfort1 Date: Thu, 22 Aug 2024 12:12:20 +0200 Subject: [PATCH 07/25] more changes --- .../mol2mol/filter.py | 12 ++-- molpipeline/mol2mol/filter.py | 55 +++++++++++++------ .../test_mol2mol/test_mol2mol_filter.py | 10 ++-- 3 files changed, 50 insertions(+), 27 deletions(-) diff --git a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py index 5adcac4a..b38c3e45 100644 --- a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py +++ b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py @@ -22,7 +22,7 @@ def __init__( patterns: Union[ list[str], dict[str, Union[int, tuple[Optional[int], Optional[int]]]] ], - keep: bool = True, + keep_matches: bool = True, mode: Literal["any", "all"] = "any", name: Optional[str] = None, n_jobs: int = 1, @@ -36,7 +36,7 @@ def __init__( List of patterns to allow in molecules. Alternatively, a dictionary can be passed with patterns as keys and an int for exact count or a tuple of minimum and maximum. - keep: bool, optional (default: True) + keep_matches: bool, optional (default: True) If True, molecules containing the specified patterns are kept, else removed. mode: Literal["any", "all"], optional (default: "any") If "any", at least one of the specified patterns must be present in the molecule. @@ -50,7 +50,7 @@ def __init__( """ super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) self.patterns = patterns # type: ignore - self.keep = keep + self.keep_matches = keep_matches self.mode = mode @property @@ -100,7 +100,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: } else: params["patterns"] = self.patterns - params["keep"] = self.keep + params["keep_matches"] = self.keep_matches params["mode"] = self.mode return params @@ -120,8 +120,8 @@ def set_params(self, **parameters: Any) -> Self: parameter_copy = dict(parameters) if "patterns" in parameter_copy: self.patterns = parameter_copy.pop("patterns") - if "keep" in parameter_copy: - self.keep = parameter_copy.pop("keep") + if "keep_matches" in parameter_copy: + self.keep_matches = parameter_copy.pop("keep_matches") if "mode" in parameter_copy: self.mode = parameter_copy.pop("mode") super().set_params(**parameter_copy) diff --git a/molpipeline/mol2mol/filter.py b/molpipeline/mol2mol/filter.py index a2ed202a..f30c34bc 100644 --- a/molpipeline/mol2mol/filter.py +++ b/molpipeline/mol2mol/filter.py @@ -230,7 +230,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: if self.mode == "any": return ( value - if self.keep + if self.keep_matches else InvalidInstance( self.uuid, f"Molecule contains forbidden SMARTS pattern {match_smarts}.", @@ -241,7 +241,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: if self.mode == "any": return ( value - if not self.keep + if not self.keep_matches else InvalidInstance( self.uuid, "Molecule does not match any of the SmartsFilter patterns.", @@ -251,7 +251,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: if match_counts == len(self.patterns): return ( value - if self.keep + if self.keep_matches else InvalidInstance( self.uuid, "Molecule matches one of the SmartsFilter patterns.", @@ -260,7 +260,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: ) return ( value - if not self.keep + if not self.keep_matches else InvalidInstance( self.uuid, "Molecule does not match all of the SmartsFilter patterns.", @@ -293,7 +293,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: if self.mode == "any": return ( value - if self.keep + if self.keep_matches else InvalidInstance( self.uuid, f"Molecule contains forbidden SMILES pattern {pattern}.", @@ -304,7 +304,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: if self.mode == "all": return ( value - if not self.keep + if not self.keep_matches else InvalidInstance( self.uuid, "Molecule does not match all required patterns.", @@ -314,7 +314,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: if self.mode == "any": return ( value - if not self.keep + if not self.keep_matches else InvalidInstance( self.uuid, "Molecule does not match any of the SmilesFilter patterns.", @@ -323,7 +323,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: ) return ( value - if self.keep + if self.keep_matches else InvalidInstance( self.uuid, "Molecule does not match all of the SmilesFilter patterns.", @@ -338,7 +338,7 @@ class DescriptorsFilter(_MolToMolPipelineElement): def __init__( self, descriptors: dict[str, tuple[Optional[float], Optional[float]]], - keep: bool = True, + keep_matches: bool = True, mode: Literal["any", "all"] = "any", name: Optional[str] = None, n_jobs: int = 1, @@ -351,7 +351,7 @@ def __init__( descriptors: dict[str, tuple[Optional[float], Optional[float]]] Dictionary of RDKit descriptors to filter by. The value must be a tuple of minimum and maximum. If None, no limit is set. - keep: bool, optional (default: True) + keep_matches: bool, optional (default: True) If True, molecules containing the specified descriptors are kept, else removed. mode: Literal["any", "all"], optional (default: "any") If "any", at least one of the specified descriptors must be present in the molecule. @@ -365,7 +365,7 @@ def __init__( """ super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) self.descriptors = descriptors - self.keep = keep + self.keep_matches = keep_matches self.mode = mode @property @@ -411,10 +411,33 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: } else: params["descriptors"] = self.descriptors - params["keep"] = self.keep + params["keep_matches"] = self.keep_matches params["mode"] = self.mode return params + def set_params(self, **parameters: Any) -> Self: + """Set parameters of PatternFilter. + + Parameters + ---------- + parameters: Any + Parameters to set. + + Returns + ------- + Self + Self. + """ + parameter_copy = dict(parameters) + if "descriptors" in parameter_copy: + self.patterns = parameter_copy.pop("descriptors") + if "keep_matches" in parameter_copy: + self.keep_matches = parameter_copy.pop("keep_matches") + if "mode" in parameter_copy: + self.mode = parameter_copy.pop("mode") + super().set_params(**parameter_copy) + return self + def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Invalidate or validate molecule based on specified RDKit descriptors. @@ -436,7 +459,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: if self.mode == "any": return ( value - if self.keep + if self.keep_matches else InvalidInstance( self.uuid, f"Molecule contains forbidden descriptor {descriptor}.", @@ -447,7 +470,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: if self.mode == "all": return ( value - if not self.keep + if not self.keep_matches else InvalidInstance( self.uuid, "Molecule does not match all required descriptors.", @@ -457,7 +480,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: if self.mode == "any": return ( value - if not self.keep + if not self.keep_matches else InvalidInstance( self.uuid, "Molecule does not match any of the DescriptorsFilter descriptors.", @@ -466,7 +489,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: ) return ( value - if self.keep + if self.keep_matches else InvalidInstance( self.uuid, "Molecule does not match all of the DescriptorsFilter descriptors.", diff --git a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py index edf066b3..d488cdd8 100644 --- a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py +++ b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py @@ -101,20 +101,20 @@ def test_smarts_smiles_filter(self) -> None: filtered_smiles, [SMILES_BENZENE, SMILES_CHLOROBENZENE, SMILES_CL_BR] ) - pipeline.set_params(SmartsFilter__keep=False) + pipeline.set_params(SmartsFilter__keep_matches=False) filtered_smiles_2 = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles_2, [SMILES_ANTIMONY, SMILES_METAL_AU]) - pipeline.set_params(SmartsFilter__mode="all", SmartsFilter__keep=True) + pipeline.set_params(SmartsFilter__mode="all", SmartsFilter__keep_matches=True) filtered_smiles_3 = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles_3, [SMILES_CHLOROBENZENE]) - pipeline.set_params(SmartsFilter__keep=True, SmartsFilter__patterns=["I"]) + pipeline.set_params(SmartsFilter__keep_matches=True, SmartsFilter__patterns=["I"]) filtered_smiles_4 = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles_4, []) pipeline.set_params( - SmartsFilter__keep=False, + SmartsFilter__keep_matches=False, SmartsFilter__mode="any", SmartsFilter__patterns=new_input_as_list, ) @@ -170,7 +170,7 @@ def test_descriptor_filter(self) -> None: filtered_smiles_2 = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles_2, [SMILES_CL_BR]) - pipeline.set_params(DescriptorsFilter__keep=False) + pipeline.set_params(DescriptorsFilter__keep_matches=False) filtered_smiles_3 = pipeline.fit_transform(SMILES_LIST) # why is this not self.assertEqual( From b2ca26dbb207573f256e260d0705d5413754439c Mon Sep 17 00:00:00 2001 From: frederik-sandfort1 Date: Thu, 22 Aug 2024 12:13:09 +0200 Subject: [PATCH 08/25] linting --- tests/test_elements/test_mol2mol/test_mol2mol_filter.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py index d488cdd8..1387577d 100644 --- a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py +++ b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py @@ -105,11 +105,15 @@ def test_smarts_smiles_filter(self) -> None: filtered_smiles_2 = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles_2, [SMILES_ANTIMONY, SMILES_METAL_AU]) - pipeline.set_params(SmartsFilter__mode="all", SmartsFilter__keep_matches=True) + pipeline.set_params( + SmartsFilter__mode="all", SmartsFilter__keep_matches=True + ) filtered_smiles_3 = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles_3, [SMILES_CHLOROBENZENE]) - pipeline.set_params(SmartsFilter__keep_matches=True, SmartsFilter__patterns=["I"]) + pipeline.set_params( + SmartsFilter__keep_matches=True, SmartsFilter__patterns=["I"] + ) filtered_smiles_4 = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles_4, []) From 81ffb7cc8ea02cf9d0a166ae7ac6f02017447df8 Mon Sep 17 00:00:00 2001 From: frederik-sandfort1 Date: Thu, 22 Aug 2024 12:27:19 +0200 Subject: [PATCH 09/25] pylint --- .../mol2mol/__init__.py | 7 +- .../mol2mol/filter.py | 88 ++++++++++++++++--- molpipeline/mol2mol/filter.py | 17 ++-- 3 files changed, 90 insertions(+), 22 deletions(-) diff --git a/molpipeline/abstract_pipeline_elements/mol2mol/__init__.py b/molpipeline/abstract_pipeline_elements/mol2mol/__init__.py index eb352dba..15d7b5a4 100644 --- a/molpipeline/abstract_pipeline_elements/mol2mol/__init__.py +++ b/molpipeline/abstract_pipeline_elements/mol2mol/__init__.py @@ -1,5 +1,8 @@ """Initialize the module for abstract mol2mol elements.""" -from molpipeline.abstract_pipeline_elements.mol2mol.filter import BasePatternsFilter +from molpipeline.abstract_pipeline_elements.mol2mol.filter import ( + BaseKeepMatchesFilter, + BasePatternsFilter, +) -__all__ = ["BasePatternsFilter"] +__all__ = ["BasePatternsFilter", "BaseKeepMatchesFilter"] diff --git a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py index b38c3e45..9be3c42c 100644 --- a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py +++ b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py @@ -12,7 +12,81 @@ from molpipeline.utils.value_conversions import count_value_to_tuple -class BasePatternsFilter(MolToMolPipelineElement, abc.ABC): +class BaseKeepMatchesFilter(MolToMolPipelineElement, abc.ABC): + """Filter to keep or remove molecules based on patterns.""" + + keep_matches: bool + mode: Literal["any", "all"] + + def __init__( + self, + keep_matches: bool = True, + mode: Literal["any", "all"] = "any", + name: Optional[str] = None, + n_jobs: int = 1, + uuid: Optional[str] = None, + ) -> None: + """Initialize BasePatternsFilter. + + Parameters + ---------- + keep_matches: bool, optional (default: True) + If True, molecules containing the specified patterns are kept, else removed. + mode: Literal["any", "all"], optional (default: "any") + If "any", at least one of the specified patterns must be present in the molecule. + If "all", all of the specified patterns must be present in the molecule. + name: Optional[str], optional (default: None) + Name of the pipeline element. + n_jobs: int, optional (default: 1) + Number of parallel jobs to use. + uuid: str, optional (default: None) + Unique identifier of the pipeline element. + """ + super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) + self.keep_matches = keep_matches + self.mode = mode + + def set_params(self, **parameters: Any) -> Self: + """Set parameters of BaseKeepMatchesFilter. + + Parameters + ---------- + parameters: Any + Parameters to set. + + Returns + ------- + Self + Self. + """ + parameter_copy = dict(parameters) + if "keep_matches" in parameter_copy: + self.keep_matches = parameter_copy.pop("keep_matches") + if "mode" in parameter_copy: + self.mode = parameter_copy.pop("mode") + super().set_params(**parameter_copy) + return self + + def get_params(self, deep: bool = True) -> dict[str, Any]: + """Get parameters of PatternFilter. + + Parameters + ---------- + deep: bool, optional (default: True) + If True, return the parameters of all subobjects that are PipelineElements. + + Returns + ------- + dict[str, Any] + Parameters of BaseKeepMatchesFilter. + """ + params = super().get_params(deep=deep) + params["keep_matches"] = self.keep_matches + params["mode"] = self.mode + return params + + +class BasePatternsFilter(BaseKeepMatchesFilter, abc.ABC): """Filter to keep or remove molecules based on patterns.""" _patterns: dict[str, tuple[Optional[int], Optional[int]]] @@ -48,10 +122,10 @@ def __init__( uuid: str, optional (default: None) Unique identifier of the pipeline element. """ - super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) + super().__init__( + keep_matches=keep_matches, mode=mode, name=name, n_jobs=n_jobs, uuid=uuid + ) self.patterns = patterns # type: ignore - self.keep_matches = keep_matches - self.mode = mode @property def patterns(self) -> dict[str, tuple[Optional[int], Optional[int]]]: @@ -100,8 +174,6 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: } else: params["patterns"] = self.patterns - params["keep_matches"] = self.keep_matches - params["mode"] = self.mode return params def set_params(self, **parameters: Any) -> Self: @@ -120,9 +192,5 @@ def set_params(self, **parameters: Any) -> Self: parameter_copy = dict(parameters) if "patterns" in parameter_copy: self.patterns = parameter_copy.pop("patterns") - if "keep_matches" in parameter_copy: - self.keep_matches = parameter_copy.pop("keep_matches") - if "mode" in parameter_copy: - self.mode = parameter_copy.pop("mode") super().set_params(**parameter_copy) return self diff --git a/molpipeline/mol2mol/filter.py b/molpipeline/mol2mol/filter.py index f30c34bc..46151aef 100644 --- a/molpipeline/mol2mol/filter.py +++ b/molpipeline/mol2mol/filter.py @@ -17,6 +17,9 @@ from molpipeline.abstract_pipeline_elements.core import ( MolToMolPipelineElement as _MolToMolPipelineElement, ) +from molpipeline.abstract_pipeline_elements.mol2mol import ( + BaseKeepMatchesFilter as _BaseKeepMatchesFilter, +) from molpipeline.abstract_pipeline_elements.mol2mol import ( BasePatternsFilter as _BasePatternsFilter, ) @@ -332,7 +335,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: ) -class DescriptorsFilter(_MolToMolPipelineElement): +class DescriptorsFilter(_BaseKeepMatchesFilter): """Filter to keep or remove molecules based on RDKit descriptors.""" def __init__( @@ -363,10 +366,10 @@ def __init__( uuid: str, optional (default: None) Unique identifier of the pipeline element. """ - super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) + super().__init__( + keep_matches=keep_matches, mode=mode, name=name, n_jobs=n_jobs, uuid=uuid + ) self.descriptors = descriptors - self.keep_matches = keep_matches - self.mode = mode @property def descriptors(self) -> dict[str, tuple[Optional[float], Optional[float]]]: @@ -411,8 +414,6 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: } else: params["descriptors"] = self.descriptors - params["keep_matches"] = self.keep_matches - params["mode"] = self.mode return params def set_params(self, **parameters: Any) -> Self: @@ -431,10 +432,6 @@ def set_params(self, **parameters: Any) -> Self: parameter_copy = dict(parameters) if "descriptors" in parameter_copy: self.patterns = parameter_copy.pop("descriptors") - if "keep_matches" in parameter_copy: - self.keep_matches = parameter_copy.pop("keep_matches") - if "mode" in parameter_copy: - self.mode = parameter_copy.pop("mode") super().set_params(**parameter_copy) return self From 9fed198bd72db33ddb2ba16c17881038b11147f3 Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" <128160984+c-w-feldmann@users.noreply.github.com> Date: Thu, 22 Aug 2024 14:55:24 +0200 Subject: [PATCH 10/25] rewrite filter logic (#71) --- molpipeline/mol2mol/filter.py | 104 ++++++++++++++++++++++------------ 1 file changed, 68 insertions(+), 36 deletions(-) diff --git a/molpipeline/mol2mol/filter.py b/molpipeline/mol2mol/filter.py index 46151aef..e7fbf368 100644 --- a/molpipeline/mol2mol/filter.py +++ b/molpipeline/mol2mol/filter.py @@ -27,6 +27,34 @@ from molpipeline.utils.value_conversions import count_value_to_tuple +def _within_boundaries( + lower_bound: Optional[float], upper_bound: Optional[float], value: float +) -> bool: + """Check if a value is within the specified boundaries. + + Boundaries given as None are ignored. + + Parameters + ---------- + lower_bound: Optional[float] + Lower boundary. + upper_bound: Optional[float] + Upper boundary. + value: float + Value to check. + + Returns + ------- + bool + True if the value is within the boundaries, else False. + """ + if lower_bound is not None and value < lower_bound: + return False + if upper_bound is not None and value > upper_bound: + return False + return True + + class ElementFilter(_MolToMolPipelineElement): """ElementFilter which removes molecules containing chemical elements other than specified. @@ -227,9 +255,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: match_smarts = smarts_match.GetDescription() all_matches = value.GetSubstructMatches(Chem.MolFromSmarts(match_smarts)) min_count, max_count = self.patterns[match_smarts] - if (min_count is None or len(all_matches) >= min_count) and ( - max_count is None or len(all_matches) <= max_count - ): + if _within_boundaries(min_count, max_count, len(all_matches)): if self.mode == "any": return ( value @@ -290,9 +316,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: """ for pattern, (min_count, max_count) in self.patterns.items(): all_matches = value.GetSubstructMatches(Chem.MolFromSmiles(pattern)) - if (min_count is None or len(all_matches) >= min_count) and ( - max_count is None or len(all_matches) <= max_count - ): + if _within_boundaries(min_count, max_count, len(all_matches)): if self.mode == "any": return ( value @@ -431,13 +455,19 @@ def set_params(self, **parameters: Any) -> Self: """ parameter_copy = dict(parameters) if "descriptors" in parameter_copy: - self.patterns = parameter_copy.pop("descriptors") + self.descriptors = parameter_copy.pop("descriptors") super().set_params(**parameter_copy) return self def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Invalidate or validate molecule based on specified RDKit descriptors. + There are four possible scenarios: + - Mode = "any" & "keep_matches" = True: Needs to match at least one descriptor. + - Mode = "any" & "keep_matches" = False: Must not match any descriptor. + - Mode = "all" & "keep_matches" = True: Needs to match all descriptors. + - Mode = "all" & "keep_matches" = False: Must not match all descriptors. + Parameters ---------- value: RDKitMol @@ -450,49 +480,51 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: """ for descriptor, (min_count, max_count) in self.descriptors.items(): descriptor_value = getattr(Descriptors, descriptor)(value) - if (min_count is None or descriptor_value >= min_count) and ( - max_count is None or descriptor_value <= max_count - ): + + if _within_boundaries(min_count, max_count, descriptor_value): + # For "any" mode we can return early if a match is found if self.mode == "any": - return ( - value - if self.keep_matches - else InvalidInstance( + if not self.keep_matches: + value = InvalidInstance( self.uuid, f"Molecule contains forbidden descriptor {descriptor}.", self.name, ) - ) + return value else: + # For "all" mode we can return early if a match is not found if self.mode == "all": - return ( - value - if not self.keep_matches - else InvalidInstance( + if self.keep_matches: + value = InvalidInstance( self.uuid, - "Molecule does not match all required descriptors.", + f"Molecule does not contain required descriptor {descriptor}.", self.name, ) - ) + return value + + # If this point is reached, no or all patterns were found + # If mode is "any", finishing the loop means no match was found if self.mode == "any": - return ( - value - if not self.keep_matches - else InvalidInstance( + if self.keep_matches: + value = InvalidInstance( self.uuid, - "Molecule does not match any of the DescriptorsFilter descriptors.", + "Molecule does not match any of the required descriptors.", self.name, ) - ) - return ( - value - if self.keep_matches - else InvalidInstance( - self.uuid, - "Molecule does not match all of the DescriptorsFilter descriptors.", - self.name, - ) - ) + # else: No match with forbidden descriptors was found, return original molecule + return value + + if self.mode == "all": + if not self.keep_matches: + value = InvalidInstance( + self.uuid, + "Molecule matches all forbidden descriptors.", + self.name, + ) + # else: All required descriptors were found, return original molecule + return value + + raise ValueError(f"Invalid mode: {self.mode}") class MixtureFilter(_MolToMolPipelineElement): From f49cb7058db5ffc26ac1022e73372fbd791df659 Mon Sep 17 00:00:00 2001 From: frederik-sandfort1 Date: Thu, 22 Aug 2024 17:18:49 +0200 Subject: [PATCH 11/25] Combine filters with one base logic --- .../mol2mol/filter.py | 132 ++++++++- molpipeline/mol2mol/filter.py | 260 ++++-------------- 2 files changed, 179 insertions(+), 213 deletions(-) diff --git a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py index 9be3c42c..f8906dcc 100644 --- a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py +++ b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py @@ -8,10 +8,43 @@ except ImportError: from typing_extensions import Self -from molpipeline.abstract_pipeline_elements.core import MolToMolPipelineElement +from molpipeline.abstract_pipeline_elements.core import ( + InvalidInstance, + MolToMolPipelineElement, + OptionalMol, + RDKitMol, +) from molpipeline.utils.value_conversions import count_value_to_tuple +def _within_boundaries( + lower_bound: Optional[float], upper_bound: Optional[float], value: float +) -> bool: + """Check if a value is within the specified boundaries. + + Boundaries given as None are ignored. + + Parameters + ---------- + lower_bound: Optional[float] + Lower boundary. + upper_bound: Optional[float] + Upper boundary. + value: float + Value to check. + + Returns + ------- + bool + True if the value is within the boundaries, else False. + """ + if lower_bound is not None and value < lower_bound: + return False + if upper_bound is not None and value > upper_bound: + return False + return True + + class BaseKeepMatchesFilter(MolToMolPipelineElement, abc.ABC): """Filter to keep or remove molecules based on patterns.""" @@ -85,6 +118,98 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: params["mode"] = self.mode return params + def pretransform_single(self, value: RDKitMol) -> OptionalMol: + """Invalidate or validate molecule based on specified filter. + + There are four possible scenarios: + - Mode = "any" & "keep_matches" = True: Needs to match at least one filter element. + - Mode = "any" & "keep_matches" = False: Must not match any filter element. + - Mode = "all" & "keep_matches" = True: Needs to match all filter elements. + - Mode = "all" & "keep_matches" = False: Must not match all filter elements. + + Parameters + ---------- + value: RDKitMol + Molecule to check. + + Returns + ------- + OptionalMol + Molecule that matches defined filter elements, else InvalidInstance. + """ + for filter_element, (min_count, max_count) in self.filter_elements.items(): + count = self._calculate_single_element_value(filter_element, value) + if _within_boundaries(min_count, max_count, count): + # For "any" mode we can return early if a match is found + if self.mode == "any": + if not self.keep_matches: + value = InvalidInstance( + self.uuid, + f"Molecule contains forbidden filter element {filter_element}.", + self.name, + ) + return value + else: + # For "all" mode we can return early if a match is not found + if self.mode == "all": + if self.keep_matches: + value = InvalidInstance( + self.uuid, + f"Molecule does not contain required filter element {filter_element}.", + self.name, + ) + return value + + # If this point is reached, no or all patterns were found + # If mode is "any", finishing the loop means no match was found + if self.mode == "any": + if self.keep_matches: + value = InvalidInstance( + self.uuid, + "Molecule does not match any of the required filter elements.", + self.name, + ) + # else: No match with forbidden filter elements was found, return original molecule + return value + + if self.mode == "all": + if not self.keep_matches: + value = InvalidInstance( + self.uuid, + "Molecule matches all forbidden filter elements.", + self.name, + ) + # else: All required filter elements were found, return original molecule + return value + + raise ValueError(f"Invalid mode: {self.mode}") + + @abc.abstractmethod + def _calculate_single_element_value( + self, filter_element: Any, value: RDKitMol + ) -> float: + """Calculate the value of a single match. + + Parameters + ---------- + filter_element: Any + Match case to calculate. + value: RDKitMol + Molecule to calculate the match for. + + Returns + ------- + float + Value of the match. + """ + + @property + @abc.abstractmethod + def filter_elements( + self, + ) -> dict[str, tuple[Optional[Union[float, int]], Optional[Union[float, int]]]]: + """Get filter elements as dict.""" + class BasePatternsFilter(BaseKeepMatchesFilter, abc.ABC): """Filter to keep or remove molecules based on patterns.""" @@ -153,6 +278,11 @@ def patterns( pat: count_value_to_tuple(count) for pat, count in patterns.items() } + @property + def filter_elements(self) -> dict[str, tuple[Optional[int], Optional[int]]]: + """Get filter elements as dict.""" + return self.patterns + def get_params(self, deep: bool = True) -> dict[str, Any]: """Get parameters of PatternFilter. diff --git a/molpipeline/mol2mol/filter.py b/molpipeline/mol2mol/filter.py index e7fbf368..ae230397 100644 --- a/molpipeline/mol2mol/filter.py +++ b/molpipeline/mol2mol/filter.py @@ -11,7 +11,7 @@ from typing_extensions import Self from rdkit import Chem -from rdkit.Chem import Descriptors, FilterCatalog +from rdkit.Chem import Descriptors from molpipeline.abstract_pipeline_elements.core import InvalidInstance from molpipeline.abstract_pipeline_elements.core import ( @@ -27,34 +27,6 @@ from molpipeline.utils.value_conversions import count_value_to_tuple -def _within_boundaries( - lower_bound: Optional[float], upper_bound: Optional[float], value: float -) -> bool: - """Check if a value is within the specified boundaries. - - Boundaries given as None are ignored. - - Parameters - ---------- - lower_bound: Optional[float] - Lower boundary. - upper_bound: Optional[float] - Upper boundary. - value: float - Value to check. - - Returns - ------- - bool - True if the value is within the boundaries, else False. - """ - if lower_bound is not None and value < lower_bound: - return False - if upper_bound is not None and value > upper_bound: - return False - return True - - class ElementFilter(_MolToMolPipelineElement): """ElementFilter which removes molecules containing chemical elements other than specified. @@ -217,146 +189,53 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: return value +# should we combine smarts and smiles filter within a single class? option usesmiles? +# should we check the input patterns for valid smarts/smiles? +# should we apply the same logic to ElementFilter? class SmartsFilter(_BasePatternsFilter): """Filter to keep or remove molecules based on SMARTS patterns.""" - @property - def smarts_filter(self) -> FilterCatalog.FilterCatalog: - """Get the SMARTS filter.""" - smarts_matcher_list = [ - FilterCatalog.SmartsMatcher(smarts, smarts) - for i, smarts in enumerate(self.patterns) - ] - rdkit_filter = FilterCatalog.FilterCatalog() - for smarts_matcher in smarts_matcher_list: - if not smarts_matcher.IsValid(): - raise ValueError(f"Invalid SMARTS: {smarts_matcher.GetPattern()}") - entry = FilterCatalog.FilterCatalogEntry( - smarts_matcher.GetName(), smarts_matcher - ) - rdkit_filter.AddEntry(entry) - return rdkit_filter - - def pretransform_single(self, value: RDKitMol) -> OptionalMol: - """Invalidate or validate molecule matching any or all of the specified SMARTS patterns. + def _calculate_single_element_value( + self, filter_element: Any, value: RDKitMol + ) -> float: + """Calculate a single smarts match count for a molecule. Parameters ---------- + filter_element: Any + smarts to calculate match count for. value: RDKitMol - Molecule to check. + Molecule to calculate smarts match count for. Returns ------- - OptionalMol - Molecule that matches defined smarts filter, else InvalidInstance. + float + smarts match count value. """ - match_counts = 0 - for smarts_match in self.smarts_filter.GetMatches(value): - match_smarts = smarts_match.GetDescription() - all_matches = value.GetSubstructMatches(Chem.MolFromSmarts(match_smarts)) - min_count, max_count = self.patterns[match_smarts] - if _within_boundaries(min_count, max_count, len(all_matches)): - if self.mode == "any": - return ( - value - if self.keep_matches - else InvalidInstance( - self.uuid, - f"Molecule contains forbidden SMARTS pattern {match_smarts}.", - self.name, - ) - ) - match_counts += 1 - if self.mode == "any": - return ( - value - if not self.keep_matches - else InvalidInstance( - self.uuid, - "Molecule does not match any of the SmartsFilter patterns.", - self.name, - ) - ) - if match_counts == len(self.patterns): - return ( - value - if self.keep_matches - else InvalidInstance( - self.uuid, - "Molecule matches one of the SmartsFilter patterns.", - self.name, - ) - ) - return ( - value - if not self.keep_matches - else InvalidInstance( - self.uuid, - "Molecule does not match all of the SmartsFilter patterns.", - self.name, - ) - ) + return len(value.GetSubstructMatches(Chem.MolFromSmarts(filter_element))) class SmilesFilter(_BasePatternsFilter): """Filter to keep or remove molecules based on SMILES patterns.""" - def pretransform_single(self, value: RDKitMol) -> OptionalMol: - """Invalidate or validate molecule matching any or all of the specified SMILES patterns. + def _calculate_single_element_value( + self, filter_element: Any, value: RDKitMol + ) -> float: + """Calculate a single smiles match count for a molecule. Parameters ---------- + filter_element: Any + smiles to calculate match count for. value: RDKitMol - Molecule to check. + Molecule to calculate smiles match count for. Returns ------- - OptionalMol - Molecule that matches defined smiles filter, else InvalidInstance. + float + smiles match count value. """ - for pattern, (min_count, max_count) in self.patterns.items(): - all_matches = value.GetSubstructMatches(Chem.MolFromSmiles(pattern)) - if _within_boundaries(min_count, max_count, len(all_matches)): - if self.mode == "any": - return ( - value - if self.keep_matches - else InvalidInstance( - self.uuid, - f"Molecule contains forbidden SMILES pattern {pattern}.", - self.name, - ) - ) - else: - if self.mode == "all": - return ( - value - if not self.keep_matches - else InvalidInstance( - self.uuid, - "Molecule does not match all required patterns.", - self.name, - ) - ) - if self.mode == "any": - return ( - value - if not self.keep_matches - else InvalidInstance( - self.uuid, - "Molecule does not match any of the SmilesFilter patterns.", - self.name, - ) - ) - return ( - value - if self.keep_matches - else InvalidInstance( - self.uuid, - "Molecule does not match all of the SmilesFilter patterns.", - self.name, - ) - ) + return len(value.GetSubstructMatches(Chem.MolFromSmiles(filter_element))) class DescriptorsFilter(_BaseKeepMatchesFilter): @@ -417,6 +296,30 @@ def descriptors( "You are trying to use an invalid descriptor. Use RDKit Descriptors module." ) + @property + def filter_elements(self) -> dict[str, tuple[Optional[float], Optional[float]]]: + """Get filter elements.""" + return self.descriptors + + def _calculate_single_element_value( + self, filter_element: Any, value: RDKitMol + ) -> float: + """Calculate a single descriptor value for a molecule. + + Parameters + ---------- + filter_element: Any + Descriptor to calculate. + value: RDKitMol + Molecule to calculate descriptor for. + + Returns + ------- + float + Descriptor value. + """ + return getattr(Descriptors, filter_element)(value) + def get_params(self, deep: bool = True) -> dict[str, Any]: """Get parameters of DescriptorFilter. @@ -459,73 +362,6 @@ def set_params(self, **parameters: Any) -> Self: super().set_params(**parameter_copy) return self - def pretransform_single(self, value: RDKitMol) -> OptionalMol: - """Invalidate or validate molecule based on specified RDKit descriptors. - - There are four possible scenarios: - - Mode = "any" & "keep_matches" = True: Needs to match at least one descriptor. - - Mode = "any" & "keep_matches" = False: Must not match any descriptor. - - Mode = "all" & "keep_matches" = True: Needs to match all descriptors. - - Mode = "all" & "keep_matches" = False: Must not match all descriptors. - - Parameters - ---------- - value: RDKitMol - Molecule to check. - - Returns - ------- - OptionalMol - Molecule that matches defined descriptors filter, else InvalidInstance. - """ - for descriptor, (min_count, max_count) in self.descriptors.items(): - descriptor_value = getattr(Descriptors, descriptor)(value) - - if _within_boundaries(min_count, max_count, descriptor_value): - # For "any" mode we can return early if a match is found - if self.mode == "any": - if not self.keep_matches: - value = InvalidInstance( - self.uuid, - f"Molecule contains forbidden descriptor {descriptor}.", - self.name, - ) - return value - else: - # For "all" mode we can return early if a match is not found - if self.mode == "all": - if self.keep_matches: - value = InvalidInstance( - self.uuid, - f"Molecule does not contain required descriptor {descriptor}.", - self.name, - ) - return value - - # If this point is reached, no or all patterns were found - # If mode is "any", finishing the loop means no match was found - if self.mode == "any": - if self.keep_matches: - value = InvalidInstance( - self.uuid, - "Molecule does not match any of the required descriptors.", - self.name, - ) - # else: No match with forbidden descriptors was found, return original molecule - return value - - if self.mode == "all": - if not self.keep_matches: - value = InvalidInstance( - self.uuid, - "Molecule matches all forbidden descriptors.", - self.name, - ) - # else: All required descriptors were found, return original molecule - return value - - raise ValueError(f"Invalid mode: {self.mode}") - class MixtureFilter(_MolToMolPipelineElement): """MolToMol which removes molecules composed of multiple fragments.""" From 91feed1ea41f2a2f76923ecd4741d76a8ca65caa Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Thu, 22 Aug 2024 19:34:14 +0200 Subject: [PATCH 12/25] change dict to Mapping --- molpipeline/abstract_pipeline_elements/mol2mol/filter.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py index f8906dcc..cebc9d5d 100644 --- a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py +++ b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py @@ -1,7 +1,7 @@ """Abstract classes for filters.""" import abc -from typing import Any, Literal, Optional, Union +from typing import Any, Literal, Optional, Union, Mapping try: from typing import Self # type: ignore[attr-defined] @@ -207,7 +207,7 @@ def _calculate_single_element_value( @abc.abstractmethod def filter_elements( self, - ) -> dict[str, tuple[Optional[Union[float, int]], Optional[Union[float, int]]]]: + ) -> Mapping[str, tuple[Optional[Union[float, int]], Optional[Union[float, int]]]]: """Get filter elements as dict.""" @@ -279,7 +279,7 @@ def patterns( } @property - def filter_elements(self) -> dict[str, tuple[Optional[int], Optional[int]]]: + def filter_elements(self) -> Mapping[str, tuple[Optional[int], Optional[int]]]: """Get filter elements as dict.""" return self.patterns From 93e6183a8cf3d7e5cd961cb3d634efae993e1412 Mon Sep 17 00:00:00 2001 From: frederik-sandfort1 Date: Mon, 26 Aug 2024 06:56:36 +0200 Subject: [PATCH 13/25] isort --- molpipeline/abstract_pipeline_elements/mol2mol/filter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py index cebc9d5d..aeecff1a 100644 --- a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py +++ b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py @@ -1,7 +1,7 @@ """Abstract classes for filters.""" import abc -from typing import Any, Literal, Optional, Union, Mapping +from typing import Any, Literal, Mapping, Optional, Union try: from typing import Self # type: ignore[attr-defined] From cd1831052b278413fd15a0f35880d310b7d2e103 Mon Sep 17 00:00:00 2001 From: frederik-sandfort1 Date: Thu, 12 Sep 2024 06:36:07 +0200 Subject: [PATCH 14/25] Include comments --- .../mol2mol/filter.py | 91 +++++++++++++---- molpipeline/mol2mol/__init__.py | 4 +- molpipeline/mol2mol/filter.py | 97 +++++++++---------- molpipeline/utils/value_conversions.py | 15 ++- .../test_mol2mol/test_mol2mol_filter.py | 44 +++++++-- 5 files changed, 170 insertions(+), 81 deletions(-) diff --git a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py index aeecff1a..b6a92ead 100644 --- a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py +++ b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py @@ -1,7 +1,7 @@ """Abstract classes for filters.""" import abc -from typing import Any, Literal, Mapping, Optional, Union +from typing import Any, Literal, Mapping, Optional, TypeAlias, Union try: from typing import Self # type: ignore[attr-defined] @@ -14,7 +14,15 @@ OptionalMol, RDKitMol, ) -from molpipeline.utils.value_conversions import count_value_to_tuple +from molpipeline.utils.value_conversions import IntCountRange, count_value_to_tuple + + +# possible mode types for a KeepMatchesFilter: +# - "any" means one match is enough +# - "all" means all elements must be matched +FilterModeType: TypeAlias = Literal["any", "all"] + + def _within_boundaries( @@ -46,15 +54,24 @@ def _within_boundaries( class BaseKeepMatchesFilter(MolToMolPipelineElement, abc.ABC): - """Filter to keep or remove molecules based on patterns.""" + """Filter to keep or remove molecules based on patterns. + + Notes + ----- + There are four possible scenarios: + - mode = "any" & keep_matches = True: Needs to match at least one filter element. + - mode = "any" & keep_matches = False: Must not match any filter element. + - mode = "all" & keep_matches = True: Needs to match all filter elements. + - mode = "all" & keep_matches = False: Must not match all filter elements. + """ keep_matches: bool - mode: Literal["any", "all"] + mode: FilterModeType def __init__( self, keep_matches: bool = True, - mode: Literal["any", "all"] = "any", + mode: FilterModeType = "any", name: Optional[str] = None, n_jobs: int = 1, uuid: Optional[str] = None, @@ -65,7 +82,7 @@ def __init__( ---------- keep_matches: bool, optional (default: True) If True, molecules containing the specified patterns are kept, else removed. - mode: Literal["any", "all"], optional (default: "any") + mode: FilterModeType, optional (default: "any") If "any", at least one of the specified patterns must be present in the molecule. If "all", all of the specified patterns must be present in the molecule. name: Optional[str], optional (default: None) @@ -122,10 +139,10 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: """Invalidate or validate molecule based on specified filter. There are four possible scenarios: - - Mode = "any" & "keep_matches" = True: Needs to match at least one filter element. - - Mode = "any" & "keep_matches" = False: Must not match any filter element. - - Mode = "all" & "keep_matches" = True: Needs to match all filter elements. - - Mode = "all" & "keep_matches" = False: Must not match all filter elements. + - mode = "any" & keep_matches = True: Needs to match at least one filter element. + - mode = "any" & keep_matches = False: Must not match any filter element. + - mode = "all" & keep_matches = True: Needs to match all filter elements. + - mode = "all" & keep_matches = False: Must not match all filter elements. Parameters ---------- @@ -212,17 +229,25 @@ def filter_elements( class BasePatternsFilter(BaseKeepMatchesFilter, abc.ABC): - """Filter to keep or remove molecules based on patterns.""" + """Filter to keep or remove molecules based on patterns. + + Notes + ----- + There are four possible scenarios: + - mode = "any" & keep_matches = True: Needs to match at least one filter element. + - mode = "any" & keep_matches = False: Must not match any filter element. + - mode = "all" & keep_matches = True: Needs to match all filter elements. + - mode = "all" & keep_matches = False: Must not match all filter elements.""" _patterns: dict[str, tuple[Optional[int], Optional[int]]] def __init__( self, patterns: Union[ - list[str], dict[str, Union[int, tuple[Optional[int], Optional[int]]]] + list[str], dict[str, IntCountRange] ], keep_matches: bool = True, - mode: Literal["any", "all"] = "any", + mode: FilterModeType = "any", name: Optional[str] = None, n_jobs: int = 1, uuid: Optional[str] = None, @@ -231,13 +256,13 @@ def __init__( Parameters ---------- - patterns: Union[list[str], dict[str, Union[int, tuple[Optional[int], Optional[int]]]]] + patterns: Union[list[str], dict[str, CountRange]] List of patterns to allow in molecules. Alternatively, a dictionary can be passed with patterns as keys and an int for exact count or a tuple of minimum and maximum. keep_matches: bool, optional (default: True) If True, molecules containing the specified patterns are kept, else removed. - mode: Literal["any", "all"], optional (default: "any") + mode: FilterModeType, optional (default: "any") If "any", at least one of the specified patterns must be present in the molecule. If "all", all of the specified patterns must be present in the molecule. name: Optional[str], optional (default: None) @@ -261,14 +286,14 @@ def patterns(self) -> dict[str, tuple[Optional[int], Optional[int]]]: def patterns( self, patterns: Union[ - list[str], dict[str, Union[int, tuple[Optional[int], Optional[int]]]] + list[str], dict[str, IntCountRange] ], ) -> None: """Set allowed patterns as dict. Parameters ---------- - patterns: Union[list[str], dict[str, Union[int, tuple[Optional[int], Optional[int]]]]] + patterns: Union[list[str], dict[str, CountRange]] List of patterns. """ if isinstance(patterns, (list, set)): @@ -277,11 +302,43 @@ def patterns( self._patterns = { pat: count_value_to_tuple(count) for pat, count in patterns.items() } + self.patterns_mol_dict = list(self._patterns.keys()) + + @property + def patterns_mol_dict(self) -> Mapping[str, RDKitMol]: + return self._patterns_mol_dict + + @patterns_mol_dict.setter + def patterns_mol_dict(self, patterns: list[str]) -> None: + self._patterns_mol_dict = {pat: self._pattern_to_mol(pat) for pat in patterns} + + @abc.abstractmethod + def _pattern_to_mol(self, pattern: str) -> RDKitMol: + """Function to convert pattern to Rdkitmol object.""" @property def filter_elements(self) -> Mapping[str, tuple[Optional[int], Optional[int]]]: """Get filter elements as dict.""" return self.patterns + + def _calculate_single_element_value( + self, filter_element: Any, value: RDKitMol + ) -> int: + """Calculate a single match count for a molecule. + + Parameters + ---------- + filter_element: Any + smarts to calculate match count for. + value: RDKitMol + Molecule to calculate smarts match count for. + + Returns + ------- + int + smarts match count value. + """ + return len(value.GetSubstructMatches(self.patterns_mol_dict[filter_element])) def get_params(self, deep: bool = True) -> dict[str, Any]: """Get parameters of PatternFilter. diff --git a/molpipeline/mol2mol/__init__.py b/molpipeline/mol2mol/__init__.py index 6114de1b..30c7b3e0 100644 --- a/molpipeline/mol2mol/__init__.py +++ b/molpipeline/mol2mol/__init__.py @@ -1,7 +1,7 @@ """Init the module for mol2mol pipeline elements.""" from molpipeline.mol2mol.filter import ( - DescriptorsFilter, + RDKitDescriptorsFilter, ElementFilter, EmptyMoleculeFilter, InorganicsFilter, @@ -46,5 +46,5 @@ "InorganicsFilter", "SmartsFilter", "SmilesFilter", - "DescriptorsFilter", + "RDKitDescriptorsFilter", ) diff --git a/molpipeline/mol2mol/filter.py b/molpipeline/mol2mol/filter.py index ae230397..ce6f5c25 100644 --- a/molpipeline/mol2mol/filter.py +++ b/molpipeline/mol2mol/filter.py @@ -5,6 +5,8 @@ from collections import Counter from typing import Any, Literal, Optional, Union +from molpipeline.utils.value_conversions import FloatCountRange, IntCountRange + try: from typing import Self # type: ignore[attr-defined] except ImportError: @@ -53,7 +55,7 @@ class ElementFilter(_MolToMolPipelineElement): def __init__( self, allowed_element_numbers: Optional[ - Union[list[int], dict[int, Union[int, tuple[Optional[int], Optional[int]]]]] + Union[list[int], dict[int, IntCountRange]] ] = None, name: str = "ElementFilter", n_jobs: int = 1, @@ -63,7 +65,7 @@ def __init__( Parameters ---------- - allowed_element_numbers: Optional[Union[list[int], dict[int, Union[int, tuple[Optional[int], Optional[int]]]]]] + allowed_element_numbers: Optional[Union[list[int], dict[int, CountRange]]] List of atomic numbers of elements to allowed in molecules. Per default allowed elements are: H, B, C, N, O, F, Si, P, S, Cl, Se, Br, I. Alternatively, a dictionary can be passed with atomic numbers as keys and an int for exact count or a tuple of minimum and maximum @@ -86,14 +88,14 @@ def allowed_element_numbers(self) -> dict[int, tuple[Optional[int], Optional[int def allowed_element_numbers( self, allowed_element_numbers: Optional[ - Union[list[int], dict[int, Union[int, tuple[Optional[int], Optional[int]]]]] + Union[list[int], dict[int, IntCountRange]] ], ) -> None: """Set allowed element numbers as dict. Parameters ---------- - allowed_element_numbers: Optional[Union[list[int], dict[int, Union[int, tuple[Optional[int], Optional[int]]]]] + allowed_element_numbers: Optional[Union[list[int], dict[int, CountRange]] List of atomic numbers of elements to allowed in molecules. """ self._allowed_element_numbers: dict[int, tuple[Optional[int], Optional[int]]] @@ -193,57 +195,52 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: # should we check the input patterns for valid smarts/smiles? # should we apply the same logic to ElementFilter? class SmartsFilter(_BasePatternsFilter): - """Filter to keep or remove molecules based on SMARTS patterns.""" - - def _calculate_single_element_value( - self, filter_element: Any, value: RDKitMol - ) -> float: - """Calculate a single smarts match count for a molecule. - - Parameters - ---------- - filter_element: Any - smarts to calculate match count for. - value: RDKitMol - Molecule to calculate smarts match count for. + """Filter to keep or remove molecules based on SMARTS patterns. + + Notes + ----- + There are four possible scenarios: + - mode = "any" & keep_matches = True: Needs to match at least one filter element. + - mode = "any" & keep_matches = False: Must not match any filter element. + - mode = "all" & keep_matches = True: Needs to match all filter elements. + - mode = "all" & keep_matches = False: Must not match all filter elements. + """ - Returns - ------- - float - smarts match count value. - """ - return len(value.GetSubstructMatches(Chem.MolFromSmarts(filter_element))) + def _pattern_to_mol(self, pattern: str) -> RDKitMol: + return Chem.MolFromSmarts(pattern) class SmilesFilter(_BasePatternsFilter): - """Filter to keep or remove molecules based on SMILES patterns.""" - - def _calculate_single_element_value( - self, filter_element: Any, value: RDKitMol - ) -> float: - """Calculate a single smiles match count for a molecule. - - Parameters - ---------- - filter_element: Any - smiles to calculate match count for. - value: RDKitMol - Molecule to calculate smiles match count for. + """Filter to keep or remove molecules based on SMILES patterns. + + Notes + ----- + There are four possible scenarios: + - mode = "any" & keep_matches = True: Needs to match at least one filter element. + - mode = "any" & keep_matches = False: Must not match any filter element. + - mode = "all" & keep_matches = True: Needs to match all filter elements. + - mode = "all" & keep_matches = False: Must not match all filter elements. + """ - Returns - ------- - float - smiles match count value. - """ - return len(value.GetSubstructMatches(Chem.MolFromSmiles(filter_element))) + def _pattern_to_mol(self, pattern: str) -> RDKitMol: + return Chem.MolFromSmiles(pattern) -class DescriptorsFilter(_BaseKeepMatchesFilter): - """Filter to keep or remove molecules based on RDKit descriptors.""" +class RDKitDescriptorsFilter(_BaseKeepMatchesFilter): + """Filter to keep or remove molecules based on RDKit descriptors. + + Notes + ----- + There are four possible scenarios: + - mode = "any" & keep_matches = True: Needs to match at least one filter element. + - mode = "any" & keep_matches = False: Must not match any filter element. + - mode = "all" & keep_matches = True: Needs to match all filter elements. + - mode = "all" & keep_matches = False: Must not match all filter elements. + """ def __init__( self, - descriptors: dict[str, tuple[Optional[float], Optional[float]]], + descriptors: dict[str, FloatCountRange], keep_matches: bool = True, mode: Literal["any", "all"] = "any", name: Optional[str] = None, @@ -254,7 +251,7 @@ def __init__( Parameters ---------- - descriptors: dict[str, tuple[Optional[float], Optional[float]]] + descriptors: dict[str, FloatCountRange] Dictionary of RDKit descriptors to filter by. The value must be a tuple of minimum and maximum. If None, no limit is set. keep_matches: bool, optional (default: True) @@ -275,19 +272,19 @@ def __init__( self.descriptors = descriptors @property - def descriptors(self) -> dict[str, tuple[Optional[float], Optional[float]]]: + def descriptors(self) -> dict[str, FloatCountRange]: """Get allowed descriptors as dict.""" return self._descriptors @descriptors.setter def descriptors( - self, descriptors: dict[str, tuple[Optional[float], Optional[float]]] + self, descriptors: dict[str, FloatCountRange] ) -> None: """Set allowed descriptors as dict. Parameters ---------- - descriptors: dict[str, tuple[Optional[float], Optional[float]]] + descriptors: dict[str, FloatCountRange] Dictionary of RDKit descriptors to filter by. """ self._descriptors = descriptors @@ -297,7 +294,7 @@ def descriptors( ) @property - def filter_elements(self) -> dict[str, tuple[Optional[float], Optional[float]]]: + def filter_elements(self) -> dict[str, FloatCountRange]: """Get filter elements.""" return self.descriptors diff --git a/molpipeline/utils/value_conversions.py b/molpipeline/utils/value_conversions.py index a348c885..711da3df 100644 --- a/molpipeline/utils/value_conversions.py +++ b/molpipeline/utils/value_conversions.py @@ -1,17 +1,24 @@ """Module for utilities converting values.""" -from typing import Optional, Sequence, Union +from typing import Optional, Sequence, TypeAlias, Union +# IntCountRange for Typing of count ranges +# - a single int for an exact value match +# - a range given as a tuple with a lower and upper bound +# - both limits are optional +IntCountRange: TypeAlias = Union[int, tuple[Optional[int], Optional[int]]] + +FloatCountRange: TypeAlias = tuple[Optional[float], Optional[float]] def count_value_to_tuple( - count: Union[int, tuple[Optional[int], Optional[int]]] + count: IntCountRange ) -> tuple[Optional[int], Optional[int]]: """Convert a count value to a tuple. Parameters ---------- - count: Union[int, float, tuple[Optional[int], Optional[int]]] - Count value. Can be a single float or int or a tuple of two values. + count: Union[int, tuple[Optional[int], Optional[int]]] + Count value. Can be a single int or a tuple of two values. Returns ------- diff --git a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py index 1387577d..ee9b63fb 100644 --- a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py +++ b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py @@ -7,13 +7,14 @@ from molpipeline.any2mol import SmilesToMol from molpipeline.mol2any import MolToSmiles from molpipeline.mol2mol import ( - DescriptorsFilter, + RDKitDescriptorsFilter, ElementFilter, InorganicsFilter, MixtureFilter, SmartsFilter, SmilesFilter, ) +from molpipeline.utils.value_conversions import FloatCountRange, IntCountRange # pylint: disable=duplicate-code # test case molecules are allowed to be duplicated SMILES_ANTIMONY = "[SbH6+3]" @@ -74,13 +75,13 @@ def test_element_filter(self) -> None: def test_smarts_smiles_filter(self) -> None: """Test if molecules are filtered correctly by allowed SMARTS patterns.""" - smarts_pats: dict[str, Union[int, tuple[Optional[int], Optional[int]]]] = { + smarts_pats: dict[str, IntCountRange] = { "c": (4, None), "Cl": 1, } smarts_filter = SmartsFilter(smarts_pats) - smiles_pats: dict[str, Union[int, tuple[Optional[int], Optional[int]]]] = { + smiles_pats: dict[str, IntCountRange] = { "c1ccccc1": (1, None), "Cl": 1, } @@ -127,7 +128,7 @@ def test_smarts_smiles_filter(self) -> None: def test_smarts_filter_parallel(self) -> None: """Test if molecules are filtered correctly by allowed SMARTS patterns in parallel.""" - smarts_pats: dict[str, Union[int, tuple[Optional[int], Optional[int]]]] = { + smarts_pats: dict[str, IntCountRange] = { "c": (4, None), "Cl": 1, "cc": (1, None), @@ -138,7 +139,7 @@ def test_smarts_filter_parallel(self) -> None: "c1ccccc1": (1, None), "cCl": 1, } - smarts_filter = SmartsFilter(smarts_pats, mode="all", n_jobs=-1) + smarts_filter = SmartsFilter(smarts_pats, mode="all", n_jobs=2) pipeline = Pipeline( [ ("Smiles2Mol", SmilesToMol()), @@ -152,12 +153,13 @@ def test_smarts_filter_parallel(self) -> None: def test_descriptor_filter(self) -> None: """Test if molecules are filtered correctly by allowed descriptors.""" - descriptors: dict[str, tuple[Optional[float], Optional[float]]] = { + descriptors: dict[str, FloatCountRange] = { "MolWt": (None, 190), "NumHAcceptors": (2, 10), } - descriptor_filter = DescriptorsFilter(descriptors) + + descriptor_filter = RDKitDescriptorsFilter(descriptors) pipeline = Pipeline( [ @@ -176,7 +178,6 @@ def test_descriptor_filter(self) -> None: pipeline.set_params(DescriptorsFilter__keep_matches=False) filtered_smiles_3 = pipeline.fit_transform(SMILES_LIST) - # why is this not self.assertEqual( filtered_smiles_3, [SMILES_ANTIMONY, SMILES_BENZENE, SMILES_CHLOROBENZENE, SMILES_METAL_AU], @@ -186,6 +187,33 @@ def test_descriptor_filter(self) -> None: filtered_smiles_4 = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles_4, []) + pipeline.set_params(DescriptorsFilter__mode = "any", DescriptorsFilter__keep_matches=True) + + pipeline.set_params(DescriptorsFilter__descriptors = { + "NumHAcceptors": (1.99, 4), + }) + result_lower_in_bound = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(result_lower_in_bound, [SMILES_CL_BR]) + + pipeline.set_params(DescriptorsFilter__descriptors = { + "NumHAcceptors": (2.01, 4), + }) + result_lower_out_bound = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(result_lower_out_bound, []) + + pipeline.set_params(DescriptorsFilter__descriptors = { + "NumHAcceptors": (1, 2.01), + }) + result_upper_in_bound = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(result_upper_in_bound, [SMILES_CL_BR]) + + pipeline.set_params(DescriptorsFilter__descriptors = { + "NumHAcceptors": (1, 1.99), + }) + result_upper_out_bound = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(result_upper_out_bound, []) + + def test_invalidate_mixtures(self) -> None: """Test if mixtures are correctly invalidated.""" mol_list = ["CCC.CC.C", "c1ccccc1.[Na+].[Cl-]", "c1ccccc1"] From c0427abca84ed1b94f66928640ea4adfd2402d00 Mon Sep 17 00:00:00 2001 From: frederik-sandfort1 Date: Thu, 12 Sep 2024 06:49:59 +0200 Subject: [PATCH 15/25] linting --- .../mol2mol/filter.py | 51 ++++++++++++------- molpipeline/mol2mol/__init__.py | 2 +- molpipeline/mol2mol/filter.py | 38 ++++++++++---- molpipeline/utils/value_conversions.py | 5 +- .../test_mol2mol/test_mol2mol_filter.py | 41 ++++++++------- 5 files changed, 88 insertions(+), 49 deletions(-) diff --git a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py index b6a92ead..2da2b53b 100644 --- a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py +++ b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py @@ -16,15 +16,12 @@ ) from molpipeline.utils.value_conversions import IntCountRange, count_value_to_tuple - # possible mode types for a KeepMatchesFilter: # - "any" means one match is enough # - "all" means all elements must be matched FilterModeType: TypeAlias = Literal["any", "all"] - - def _within_boundaries( lower_bound: Optional[float], upper_bound: Optional[float], value: float ) -> bool: @@ -55,7 +52,7 @@ def _within_boundaries( class BaseKeepMatchesFilter(MolToMolPipelineElement, abc.ABC): """Filter to keep or remove molecules based on patterns. - + Notes ----- There are four possible scenarios: @@ -230,22 +227,21 @@ def filter_elements( class BasePatternsFilter(BaseKeepMatchesFilter, abc.ABC): """Filter to keep or remove molecules based on patterns. - + Notes ----- There are four possible scenarios: - - mode = "any" & keep_matches = True: Needs to match at least one filter element. - - mode = "any" & keep_matches = False: Must not match any filter element. - - mode = "all" & keep_matches = True: Needs to match all filter elements. - - mode = "all" & keep_matches = False: Must not match all filter elements.""" + - mode = "any" & keep_matches = True: Needs to match at least one filter element. + - mode = "any" & keep_matches = False: Must not match any filter element. + - mode = "all" & keep_matches = True: Needs to match all filter elements. + - mode = "all" & keep_matches = False: Must not match all filter elements. + """ _patterns: dict[str, tuple[Optional[int], Optional[int]]] def __init__( self, - patterns: Union[ - list[str], dict[str, IntCountRange] - ], + patterns: Union[list[str], dict[str, IntCountRange]], keep_matches: bool = True, mode: FilterModeType = "any", name: Optional[str] = None, @@ -285,9 +281,7 @@ def patterns(self) -> dict[str, tuple[Optional[int], Optional[int]]]: @patterns.setter def patterns( self, - patterns: Union[ - list[str], dict[str, IntCountRange] - ], + patterns: Union[list[str], dict[str, IntCountRange]], ) -> None: """Set allowed patterns as dict. @@ -302,25 +296,44 @@ def patterns( self._patterns = { pat: count_value_to_tuple(count) for pat, count in patterns.items() } - self.patterns_mol_dict = list(self._patterns.keys()) + self.patterns_mol_dict = list(self._patterns.keys()) # type: ignore @property def patterns_mol_dict(self) -> Mapping[str, RDKitMol]: + """Get patterns as dict with RDKitMol objects.""" return self._patterns_mol_dict - + @patterns_mol_dict.setter def patterns_mol_dict(self, patterns: list[str]) -> None: + """Set patterns as dict with RDKitMol objects. + + Parameters + ---------- + patterns: list[str] + List of patterns. + """ self._patterns_mol_dict = {pat: self._pattern_to_mol(pat) for pat in patterns} @abc.abstractmethod def _pattern_to_mol(self, pattern: str) -> RDKitMol: - """Function to convert pattern to Rdkitmol object.""" + """Convert pattern to Rdkitmol object. + + Parameters + ---------- + pattern: str + Pattern to convert. + + Returns + ------- + RDKitMol + RDKitMol object of the pattern. + """ @property def filter_elements(self) -> Mapping[str, tuple[Optional[int], Optional[int]]]: """Get filter elements as dict.""" return self.patterns - + def _calculate_single_element_value( self, filter_element: Any, value: RDKitMol ) -> int: diff --git a/molpipeline/mol2mol/__init__.py b/molpipeline/mol2mol/__init__.py index 30c7b3e0..36df054d 100644 --- a/molpipeline/mol2mol/__init__.py +++ b/molpipeline/mol2mol/__init__.py @@ -1,11 +1,11 @@ """Init the module for mol2mol pipeline elements.""" from molpipeline.mol2mol.filter import ( - RDKitDescriptorsFilter, ElementFilter, EmptyMoleculeFilter, InorganicsFilter, MixtureFilter, + RDKitDescriptorsFilter, SmartsFilter, SmilesFilter, ) diff --git a/molpipeline/mol2mol/filter.py b/molpipeline/mol2mol/filter.py index ce6f5c25..3d0b3391 100644 --- a/molpipeline/mol2mol/filter.py +++ b/molpipeline/mol2mol/filter.py @@ -87,9 +87,7 @@ def allowed_element_numbers(self) -> dict[int, tuple[Optional[int], Optional[int @allowed_element_numbers.setter def allowed_element_numbers( self, - allowed_element_numbers: Optional[ - Union[list[int], dict[int, IntCountRange]] - ], + allowed_element_numbers: Optional[Union[list[int], dict[int, IntCountRange]]], ) -> None: """Set allowed element numbers as dict. @@ -196,7 +194,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: # should we apply the same logic to ElementFilter? class SmartsFilter(_BasePatternsFilter): """Filter to keep or remove molecules based on SMARTS patterns. - + Notes ----- There are four possible scenarios: @@ -207,12 +205,24 @@ class SmartsFilter(_BasePatternsFilter): """ def _pattern_to_mol(self, pattern: str) -> RDKitMol: + """Convert SMARTS pattern to RDKit molecule. + + Parameters + ---------- + pattern: str + SMARTS pattern to convert. + + Returns + ------- + RDKitMol + RDKit molecule. + """ return Chem.MolFromSmarts(pattern) class SmilesFilter(_BasePatternsFilter): """Filter to keep or remove molecules based on SMILES patterns. - + Notes ----- There are four possible scenarios: @@ -223,12 +233,24 @@ class SmilesFilter(_BasePatternsFilter): """ def _pattern_to_mol(self, pattern: str) -> RDKitMol: + """Convert SMILES pattern to RDKit molecule. + + Parameters + ---------- + pattern: str + SMILES pattern to convert. + + Returns + ------- + RDKitMol + RDKit molecule. + """ return Chem.MolFromSmiles(pattern) class RDKitDescriptorsFilter(_BaseKeepMatchesFilter): """Filter to keep or remove molecules based on RDKit descriptors. - + Notes ----- There are four possible scenarios: @@ -277,9 +299,7 @@ def descriptors(self) -> dict[str, FloatCountRange]: return self._descriptors @descriptors.setter - def descriptors( - self, descriptors: dict[str, FloatCountRange] - ) -> None: + def descriptors(self, descriptors: dict[str, FloatCountRange]) -> None: """Set allowed descriptors as dict. Parameters diff --git a/molpipeline/utils/value_conversions.py b/molpipeline/utils/value_conversions.py index 711da3df..fb508276 100644 --- a/molpipeline/utils/value_conversions.py +++ b/molpipeline/utils/value_conversions.py @@ -10,9 +10,8 @@ FloatCountRange: TypeAlias = tuple[Optional[float], Optional[float]] -def count_value_to_tuple( - count: IntCountRange -) -> tuple[Optional[int], Optional[int]]: + +def count_value_to_tuple(count: IntCountRange) -> tuple[Optional[int], Optional[int]]: """Convert a count value to a tuple. Parameters diff --git a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py index ee9b63fb..e08c35a9 100644 --- a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py +++ b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py @@ -1,16 +1,15 @@ """Test MolFilter, which invalidate molecules based on criteria defined in the respective filter.""" import unittest -from typing import Optional, Union from molpipeline import ErrorFilter, FilterReinserter, Pipeline from molpipeline.any2mol import SmilesToMol from molpipeline.mol2any import MolToSmiles from molpipeline.mol2mol import ( - RDKitDescriptorsFilter, ElementFilter, InorganicsFilter, MixtureFilter, + RDKitDescriptorsFilter, SmartsFilter, SmilesFilter, ) @@ -158,7 +157,6 @@ def test_descriptor_filter(self) -> None: "NumHAcceptors": (2, 10), } - descriptor_filter = RDKitDescriptorsFilter(descriptors) pipeline = Pipeline( @@ -187,33 +185,42 @@ def test_descriptor_filter(self) -> None: filtered_smiles_4 = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles_4, []) - pipeline.set_params(DescriptorsFilter__mode = "any", DescriptorsFilter__keep_matches=True) + pipeline.set_params( + DescriptorsFilter__mode="any", DescriptorsFilter__keep_matches=True + ) - pipeline.set_params(DescriptorsFilter__descriptors = { - "NumHAcceptors": (1.99, 4), - }) + pipeline.set_params( + DescriptorsFilter__descriptors={ + "NumHAcceptors": (1.99, 4), + } + ) result_lower_in_bound = pipeline.fit_transform(SMILES_LIST) self.assertEqual(result_lower_in_bound, [SMILES_CL_BR]) - pipeline.set_params(DescriptorsFilter__descriptors = { - "NumHAcceptors": (2.01, 4), - }) + pipeline.set_params( + DescriptorsFilter__descriptors={ + "NumHAcceptors": (2.01, 4), + } + ) result_lower_out_bound = pipeline.fit_transform(SMILES_LIST) self.assertEqual(result_lower_out_bound, []) - pipeline.set_params(DescriptorsFilter__descriptors = { - "NumHAcceptors": (1, 2.01), - }) + pipeline.set_params( + DescriptorsFilter__descriptors={ + "NumHAcceptors": (1, 2.01), + } + ) result_upper_in_bound = pipeline.fit_transform(SMILES_LIST) self.assertEqual(result_upper_in_bound, [SMILES_CL_BR]) - pipeline.set_params(DescriptorsFilter__descriptors = { - "NumHAcceptors": (1, 1.99), - }) + pipeline.set_params( + DescriptorsFilter__descriptors={ + "NumHAcceptors": (1, 1.99), + } + ) result_upper_out_bound = pipeline.fit_transform(SMILES_LIST) self.assertEqual(result_upper_out_bound, []) - def test_invalidate_mixtures(self) -> None: """Test if mixtures are correctly invalidated.""" mol_list = ["CCC.CC.C", "c1ccccc1.[Na+].[Cl-]", "c1ccccc1"] From cfdfd832264be12aba66b4dca185bcd9cb470a74 Mon Sep 17 00:00:00 2001 From: frederik-sandfort1 Date: Thu, 12 Sep 2024 07:32:32 +0200 Subject: [PATCH 16/25] linting and ComplexFilter --- .../mol2mol/filter.py | 8 +- molpipeline/mol2mol/__init__.py | 2 + molpipeline/mol2mol/filter.py | 142 +++++++++++++++++- .../test_mol2mol/test_mol2mol_filter.py | 24 +++ 4 files changed, 167 insertions(+), 9 deletions(-) diff --git a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py index 2da2b53b..3f4efa5d 100644 --- a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py +++ b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py @@ -14,7 +14,11 @@ OptionalMol, RDKitMol, ) -from molpipeline.utils.value_conversions import IntCountRange, count_value_to_tuple +from molpipeline.utils.value_conversions import ( + FloatCountRange, + IntCountRange, + count_value_to_tuple, +) # possible mode types for a KeepMatchesFilter: # - "any" means one match is enough @@ -221,7 +225,7 @@ def _calculate_single_element_value( @abc.abstractmethod def filter_elements( self, - ) -> Mapping[str, tuple[Optional[Union[float, int]], Optional[Union[float, int]]]]: + ) -> Mapping[Any, FloatCountRange]: """Get filter elements as dict.""" diff --git a/molpipeline/mol2mol/__init__.py b/molpipeline/mol2mol/__init__.py index 36df054d..4fa3bd95 100644 --- a/molpipeline/mol2mol/__init__.py +++ b/molpipeline/mol2mol/__init__.py @@ -5,6 +5,7 @@ EmptyMoleculeFilter, InorganicsFilter, MixtureFilter, + ComplexFilter, RDKitDescriptorsFilter, SmartsFilter, SmilesFilter, @@ -47,4 +48,5 @@ "SmartsFilter", "SmilesFilter", "RDKitDescriptorsFilter", + "ComplexFilter", ) diff --git a/molpipeline/mol2mol/filter.py b/molpipeline/mol2mol/filter.py index 3d0b3391..5b8b9b37 100644 --- a/molpipeline/mol2mol/filter.py +++ b/molpipeline/mol2mol/filter.py @@ -5,8 +5,6 @@ from collections import Counter from typing import Any, Literal, Optional, Union -from molpipeline.utils.value_conversions import FloatCountRange, IntCountRange - try: from typing import Self # type: ignore[attr-defined] except ImportError: @@ -26,7 +24,11 @@ BasePatternsFilter as _BasePatternsFilter, ) from molpipeline.utils.molpipeline_types import OptionalMol, RDKitMol -from molpipeline.utils.value_conversions import count_value_to_tuple +from molpipeline.utils.value_conversions import ( + FloatCountRange, + IntCountRange, + count_value_to_tuple, +) class ElementFilter(_MolToMolPipelineElement): @@ -189,9 +191,6 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: return value -# should we combine smarts and smiles filter within a single class? option usesmiles? -# should we check the input patterns for valid smarts/smiles? -# should we apply the same logic to ElementFilter? class SmartsFilter(_BasePatternsFilter): """Filter to keep or remove molecules based on SMARTS patterns. @@ -248,6 +247,135 @@ def _pattern_to_mol(self, pattern: str) -> RDKitMol: return Chem.MolFromSmiles(pattern) +class ComplexFilter(_BaseKeepMatchesFilter): + """Filter to keep or remove molecules based on multiple filter elements. + + Notes + ----- + There are four possible scenarios: + - mode = "any" & keep_matches = True: Needs to match at least one filter element. + - mode = "any" & keep_matches = False: Must not match any filter element. + - mode = "all" & keep_matches = True: Needs to match all filter elements. + - mode = "all" & keep_matches = False: Must not match all filter elements. + """ + + def __init__( + self, + filter_elements: tuple[_MolToMolPipelineElement, ...], + keep_matches: bool = True, + mode: Literal["any", "all"] = "any", + name: Optional[str] = None, + n_jobs: int = 1, + uuid: Optional[str] = None, + ) -> None: + """Initialize ComplexFilter. + + Parameters + ---------- + filter_elements: tuple[_MolToMolPipelineElement, ...] + tuple of filter elements. + keep_matches: bool, optional (default: True) + If True, molecules containing the specified patterns are kept, else removed. + mode: Literal["any", "all"], optional (default: "any") + If "any", at least one of the specified patterns must be present in the molecule. + If "all", all of the specified patterns must be present in the molecule. + name: Optional[str], optional (default: None) + Name of the pipeline element. + n_jobs: int, optional (default: 1) + Number of parallel jobs to use. + uuid: str, optional (default: None) + Unique identifier of the pipeline element. + """ + super().__init__( + keep_matches=keep_matches, mode=mode, name=name, n_jobs=n_jobs, uuid=uuid + ) + self.filter_elements = {element: (1, None) for element in filter_elements} + + @property + def filter_elements( + self, + ) -> dict[_MolToMolPipelineElement, tuple[int, Optional[int]]]: + """Get filter elements.""" + return self._filter_elements + + @filter_elements.setter + def filter_elements( + self, filter_elements: dict[_MolToMolPipelineElement, tuple[int, Optional[int]]] + ) -> None: + """Set filter elements. + + Parameters + ---------- + filter_elements: dict[_MolToMolPipelineElement, tuple[int, Optional[int]]] + Filter elements to set. + """ + self._filter_elements = filter_elements + + def _calculate_single_element_value( + self, filter_element: Any, value: RDKitMol + ) -> int: + """Calculate a single filter match for a molecule. + + Parameters + ---------- + filter_element: Any + MolToMol Filter to calculate. + value: RDKitMol + Molecule to calculate filter match for. + + Returns + ------- + int + Filter match. + """ + mol = filter_element.pretransform_single(value) + if isinstance(mol, InvalidInstance): + return 0 + return 1 + + def get_params(self, deep: bool = True) -> dict[str, Any]: + """Get parameters of ComplexFilter. + + Parameters + ---------- + deep: bool, optional (default: True) + If True, return the parameters of all subobjects that are PipelineElements. + + Returns + ------- + dict[str, Any] + Parameters of ComplexFilter. + """ + params = super().get_params(deep=deep) + if deep: + params["filter_elements"] = { + element: (count_tuple[0], count_tuple[1]) + for element, count_tuple in self.filter_elements.items() + } + else: + params["filter_elements"] = self.filter_elements + return params + + def set_params(self, **parameters: Any) -> Self: + """Set parameters of ComplexFilter. + + Parameters + ---------- + parameters: Any + Parameters to set. + + Returns + ------- + Self + Self. + """ + parameter_copy = dict(parameters) + if "filter_elements" in parameter_copy: + self.filter_elements = parameter_copy.pop("filter_elements") + super().set_params(**parameter_copy) + return self + + class RDKitDescriptorsFilter(_BaseKeepMatchesFilter): """Filter to keep or remove molecules based on RDKit descriptors. @@ -361,7 +489,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: return params def set_params(self, **parameters: Any) -> Self: - """Set parameters of PatternFilter. + """Set parameters of DescriptorFilter. Parameters ---------- diff --git a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py index e08c35a9..7d274ed3 100644 --- a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py +++ b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py @@ -9,6 +9,7 @@ ElementFilter, InorganicsFilter, MixtureFilter, + ComplexFilter, RDKitDescriptorsFilter, SmartsFilter, SmilesFilter, @@ -72,6 +73,29 @@ def test_element_filter(self) -> None: filtered_smiles_2 = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles_2, [SMILES_BENZENE, SMILES_CHLOROBENZENE]) + def test_multi_element_filter(self) -> None: + """Test if molecules are filtered correctly by allowed chemical elements.""" + element_filter_1 = ElementFilter({6: 6, 1: 6}) + element_filter_2 = ElementFilter({6: 6, 1: 5, 17: 1}) + + multi_element_filter = ComplexFilter((element_filter_1, element_filter_2)) + + pipeline = Pipeline( + [ + ("Smiles2Mol", SmilesToMol()), + ("MultiElementFilter", multi_element_filter), + ("Mol2Smiles", MolToSmiles()), + ("ErrorFilter", ErrorFilter()), + ], + ) + + filtered_smiles = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(filtered_smiles, [SMILES_BENZENE, SMILES_CHLOROBENZENE]) + + pipeline.set_params(MultiElementFilter__mode="all") + filtered_smiles_2 = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(filtered_smiles_2, []) + def test_smarts_smiles_filter(self) -> None: """Test if molecules are filtered correctly by allowed SMARTS patterns.""" smarts_pats: dict[str, IntCountRange] = { From b8436574242005ad9808fd948f90f7addd4a2dd9 Mon Sep 17 00:00:00 2001 From: frederik-sandfort1 Date: Thu, 12 Sep 2024 16:18:01 +0200 Subject: [PATCH 17/25] typing, tests, complex filter naming --- .../mol2mol/filter.py | 15 +++++----- molpipeline/mol2mol/__init__.py | 2 +- molpipeline/mol2mol/filter.py | 15 ++++++---- molpipeline/utils/value_conversions.py | 16 ++++++----- .../test_mol2mol/test_mol2mol_filter.py | 28 +++++++++++++++---- 5 files changed, 49 insertions(+), 27 deletions(-) diff --git a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py index 3f4efa5d..52b64a56 100644 --- a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py +++ b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py @@ -17,6 +17,7 @@ from molpipeline.utils.value_conversions import ( FloatCountRange, IntCountRange, + IntOrIntCountRange, count_value_to_tuple, ) @@ -241,11 +242,11 @@ class BasePatternsFilter(BaseKeepMatchesFilter, abc.ABC): - mode = "all" & keep_matches = False: Must not match all filter elements. """ - _patterns: dict[str, tuple[Optional[int], Optional[int]]] + _patterns: dict[str, IntCountRange] def __init__( self, - patterns: Union[list[str], dict[str, IntCountRange]], + patterns: Union[list[str], dict[str, IntOrIntCountRange]], keep_matches: bool = True, mode: FilterModeType = "any", name: Optional[str] = None, @@ -256,7 +257,7 @@ def __init__( Parameters ---------- - patterns: Union[list[str], dict[str, CountRange]] + patterns: Union[list[str], dict[str, IntOrIntCountRange]] List of patterns to allow in molecules. Alternatively, a dictionary can be passed with patterns as keys and an int for exact count or a tuple of minimum and maximum. @@ -278,20 +279,20 @@ def __init__( self.patterns = patterns # type: ignore @property - def patterns(self) -> dict[str, tuple[Optional[int], Optional[int]]]: + def patterns(self) -> dict[str, IntCountRange]: """Get allowed patterns as dict.""" return self._patterns @patterns.setter def patterns( self, - patterns: Union[list[str], dict[str, IntCountRange]], + patterns: Union[list[str], dict[str, IntOrIntCountRange]], ) -> None: """Set allowed patterns as dict. Parameters ---------- - patterns: Union[list[str], dict[str, CountRange]] + patterns: Union[list[str], dict[str, IntOrIntCountRange]] List of patterns. """ if isinstance(patterns, (list, set)): @@ -334,7 +335,7 @@ def _pattern_to_mol(self, pattern: str) -> RDKitMol: """ @property - def filter_elements(self) -> Mapping[str, tuple[Optional[int], Optional[int]]]: + def filter_elements(self) -> Mapping[str, IntCountRange]: """Get filter elements as dict.""" return self.patterns diff --git a/molpipeline/mol2mol/__init__.py b/molpipeline/mol2mol/__init__.py index 4fa3bd95..7f6ed1ae 100644 --- a/molpipeline/mol2mol/__init__.py +++ b/molpipeline/mol2mol/__init__.py @@ -1,11 +1,11 @@ """Init the module for mol2mol pipeline elements.""" from molpipeline.mol2mol.filter import ( + ComplexFilter, ElementFilter, EmptyMoleculeFilter, InorganicsFilter, MixtureFilter, - ComplexFilter, RDKitDescriptorsFilter, SmartsFilter, SmilesFilter, diff --git a/molpipeline/mol2mol/filter.py b/molpipeline/mol2mol/filter.py index 5b8b9b37..5e46f7e1 100644 --- a/molpipeline/mol2mol/filter.py +++ b/molpipeline/mol2mol/filter.py @@ -27,6 +27,7 @@ from molpipeline.utils.value_conversions import ( FloatCountRange, IntCountRange, + IntOrIntCountRange, count_value_to_tuple, ) @@ -57,7 +58,7 @@ class ElementFilter(_MolToMolPipelineElement): def __init__( self, allowed_element_numbers: Optional[ - Union[list[int], dict[int, IntCountRange]] + Union[list[int], dict[int, IntOrIntCountRange]] ] = None, name: str = "ElementFilter", n_jobs: int = 1, @@ -67,7 +68,7 @@ def __init__( Parameters ---------- - allowed_element_numbers: Optional[Union[list[int], dict[int, CountRange]]] + allowed_element_numbers: Optional[Union[list[int], dict[int, IntOrIntCountRange]]] List of atomic numbers of elements to allowed in molecules. Per default allowed elements are: H, B, C, N, O, F, Si, P, S, Cl, Se, Br, I. Alternatively, a dictionary can be passed with atomic numbers as keys and an int for exact count or a tuple of minimum and maximum @@ -82,23 +83,25 @@ def __init__( self.allowed_element_numbers = allowed_element_numbers # type: ignore @property - def allowed_element_numbers(self) -> dict[int, tuple[Optional[int], Optional[int]]]: + def allowed_element_numbers(self) -> dict[int, IntCountRange]: """Get allowed element numbers as dict.""" return self._allowed_element_numbers @allowed_element_numbers.setter def allowed_element_numbers( self, - allowed_element_numbers: Optional[Union[list[int], dict[int, IntCountRange]]], + allowed_element_numbers: Optional[ + Union[list[int], dict[int, IntOrIntCountRange]] + ], ) -> None: """Set allowed element numbers as dict. Parameters ---------- - allowed_element_numbers: Optional[Union[list[int], dict[int, CountRange]] + allowed_element_numbers: Optional[Union[list[int], dict[int, IntOrIntCountRange]] List of atomic numbers of elements to allowed in molecules. """ - self._allowed_element_numbers: dict[int, tuple[Optional[int], Optional[int]]] + self._allowed_element_numbers: dict[int, IntCountRange] if allowed_element_numbers is None: allowed_element_numbers = self.DEFAULT_ALLOWED_ELEMENT_NUMBERS if isinstance(allowed_element_numbers, (list, set)): diff --git a/molpipeline/utils/value_conversions.py b/molpipeline/utils/value_conversions.py index fb508276..df595a84 100644 --- a/molpipeline/utils/value_conversions.py +++ b/molpipeline/utils/value_conversions.py @@ -2,26 +2,28 @@ from typing import Optional, Sequence, TypeAlias, Union -# IntCountRange for Typing of count ranges +FloatCountRange: TypeAlias = tuple[Optional[float], Optional[float]] + +IntCountRange: TypeAlias = tuple[Optional[int], Optional[int]] + +# IntOrIntCountRange for Typing of count ranges # - a single int for an exact value match # - a range given as a tuple with a lower and upper bound # - both limits are optional -IntCountRange: TypeAlias = Union[int, tuple[Optional[int], Optional[int]]] - -FloatCountRange: TypeAlias = tuple[Optional[float], Optional[float]] +IntOrIntCountRange: TypeAlias = Union[int, IntCountRange] -def count_value_to_tuple(count: IntCountRange) -> tuple[Optional[int], Optional[int]]: +def count_value_to_tuple(count: IntOrIntCountRange) -> IntCountRange: """Convert a count value to a tuple. Parameters ---------- - count: Union[int, tuple[Optional[int], Optional[int]]] + count: Union[int, IntCountRange] Count value. Can be a single int or a tuple of two values. Returns ------- - tuple[Optional[int], Optional[int]] + IntCountRange Tuple of count values. """ if isinstance(count, int): diff --git a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py index 7d274ed3..f4bc8df4 100644 --- a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py +++ b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py @@ -6,15 +6,15 @@ from molpipeline.any2mol import SmilesToMol from molpipeline.mol2any import MolToSmiles from molpipeline.mol2mol import ( + ComplexFilter, ElementFilter, InorganicsFilter, MixtureFilter, - ComplexFilter, RDKitDescriptorsFilter, SmartsFilter, SmilesFilter, ) -from molpipeline.utils.value_conversions import FloatCountRange, IntCountRange +from molpipeline.utils.value_conversions import FloatCountRange, IntOrIntCountRange # pylint: disable=duplicate-code # test case molecules are allowed to be duplicated SMILES_ANTIMONY = "[SbH6+3]" @@ -73,7 +73,7 @@ def test_element_filter(self) -> None: filtered_smiles_2 = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles_2, [SMILES_BENZENE, SMILES_CHLOROBENZENE]) - def test_multi_element_filter(self) -> None: + def test_complex_filter(self) -> None: """Test if molecules are filtered correctly by allowed chemical elements.""" element_filter_1 = ElementFilter({6: 6, 1: 6}) element_filter_2 = ElementFilter({6: 6, 1: 5, 17: 1}) @@ -98,13 +98,13 @@ def test_multi_element_filter(self) -> None: def test_smarts_smiles_filter(self) -> None: """Test if molecules are filtered correctly by allowed SMARTS patterns.""" - smarts_pats: dict[str, IntCountRange] = { + smarts_pats: dict[str, IntOrIntCountRange] = { "c": (4, None), "Cl": 1, } smarts_filter = SmartsFilter(smarts_pats) - smiles_pats: dict[str, IntCountRange] = { + smiles_pats: dict[str, IntOrIntCountRange] = { "c1ccccc1": (1, None), "Cl": 1, } @@ -151,7 +151,7 @@ def test_smarts_smiles_filter(self) -> None: def test_smarts_filter_parallel(self) -> None: """Test if molecules are filtered correctly by allowed SMARTS patterns in parallel.""" - smarts_pats: dict[str, IntCountRange] = { + smarts_pats: dict[str, IntOrIntCountRange] = { "c": (4, None), "Cl": 1, "cc": (1, None), @@ -213,6 +213,14 @@ def test_descriptor_filter(self) -> None: DescriptorsFilter__mode="any", DescriptorsFilter__keep_matches=True ) + pipeline.set_params( + DescriptorsFilter__descriptors={ + "NumHAcceptors": (2.00, 4), + } + ) + result_lower_exact = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(result_lower_exact, [SMILES_CL_BR]) + pipeline.set_params( DescriptorsFilter__descriptors={ "NumHAcceptors": (1.99, 4), @@ -229,6 +237,14 @@ def test_descriptor_filter(self) -> None: result_lower_out_bound = pipeline.fit_transform(SMILES_LIST) self.assertEqual(result_lower_out_bound, []) + pipeline.set_params( + DescriptorsFilter__descriptors={ + "NumHAcceptors": (1, 2.00), + } + ) + result_upper_exact = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(result_upper_exact, [SMILES_CL_BR]) + pipeline.set_params( DescriptorsFilter__descriptors={ "NumHAcceptors": (1, 2.01), From a93344c6dbd0bf151aa182315875013ce6adde54 Mon Sep 17 00:00:00 2001 From: frederik-sandfort1 Date: Thu, 12 Sep 2024 17:26:27 +0200 Subject: [PATCH 18/25] finalize filter refactoring --- .../mol2mol/filter.py | 165 ++++++---------- molpipeline/mol2mol/filter.py | 186 +++--------------- .../test_mol2mol/test_mol2mol_filter.py | 18 +- 3 files changed, 95 insertions(+), 274 deletions(-) diff --git a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py index 52b64a56..951fec8a 100644 --- a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py +++ b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py @@ -1,7 +1,7 @@ """Abstract classes for filters.""" import abc -from typing import Any, Literal, Mapping, Optional, TypeAlias, Union +from typing import Any, Literal, Mapping, Optional, Sequence, TypeAlias, Union try: from typing import Self # type: ignore[attr-defined] @@ -72,6 +72,10 @@ class BaseKeepMatchesFilter(MolToMolPipelineElement, abc.ABC): def __init__( self, + filter_elements: Union[ + Mapping[Any, Union[FloatCountRange, IntCountRange, IntOrIntCountRange]], + Sequence[Any], + ], keep_matches: bool = True, mode: FilterModeType = "any", name: Optional[str] = None, @@ -82,6 +86,10 @@ def __init__( Parameters ---------- + filter_elements: Union[Mapping[Any, Union[FloatCountRange, IntCountRange, IntOrIntCountRange]], Sequence[Any]] + List of filter elements. Typically can be a list of patterns or a dictionary with patterns as keys and + an int for exact count or a tuple of minimum and maximum. + NOTE: for each child class, the type of filter_elements must be specified by the filter_elements setter. keep_matches: bool, optional (default: True) If True, molecules containing the specified patterns are kept, else removed. mode: FilterModeType, optional (default: "any") @@ -95,9 +103,31 @@ def __init__( Unique identifier of the pipeline element. """ super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) + self.filter_elements = filter_elements # type: ignore self.keep_matches = keep_matches self.mode = mode + @property + @abc.abstractmethod + def filter_elements( + self, + ) -> Mapping[Any, FloatCountRange]: + """Get filter elements as dict.""" + + @filter_elements.setter + @abc.abstractmethod + def filter_elements( + self, + filter_elements: Union[Mapping[Any, FloatCountRange], Sequence[Any]], + ) -> None: + """Set filter elements as dict. + + Parameters + ---------- + filter_elements: Union[Mapping[Any, FloatCountRange], Sequence[Any]] + List of filter elements. + """ + def set_params(self, **parameters: Any) -> Self: """Set parameters of BaseKeepMatchesFilter. @@ -116,6 +146,8 @@ def set_params(self, **parameters: Any) -> Self: self.keep_matches = parameter_copy.pop("keep_matches") if "mode" in parameter_copy: self.mode = parameter_copy.pop("mode") + if "filter_elements" in parameter_copy: + self.filter_elements = parameter_copy.pop("filter_elements") super().set_params(**parameter_copy) return self @@ -135,6 +167,13 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: params = super().get_params(deep=deep) params["keep_matches"] = self.keep_matches params["mode"] = self.mode + if deep: + params["filter_elements"] = { + element: (count_tuple[0], count_tuple[1]) + for element, count_tuple in self.filter_elements.items() + } + else: + params["filter_elements"] = self.filter_elements return params def pretransform_single(self, value: RDKitMol) -> OptionalMol: @@ -222,17 +261,18 @@ def _calculate_single_element_value( Value of the match. """ - @property - @abc.abstractmethod - def filter_elements( - self, - ) -> Mapping[Any, FloatCountRange]: - """Get filter elements as dict.""" - class BasePatternsFilter(BaseKeepMatchesFilter, abc.ABC): """Filter to keep or remove molecules based on patterns. + Parameters + ---------- + filter_elements: Union[Sequence[str], Mapping[str, IntOrIntCountRange]] + List of patterns to allow in molecules. + Alternatively, a dictionary can be passed with patterns as keys + and an int for exact count or a tuple of minimum and maximum. + [...] + Notes ----- There are four possible scenarios: @@ -242,66 +282,32 @@ class BasePatternsFilter(BaseKeepMatchesFilter, abc.ABC): - mode = "all" & keep_matches = False: Must not match all filter elements. """ - _patterns: dict[str, IntCountRange] - - def __init__( - self, - patterns: Union[list[str], dict[str, IntOrIntCountRange]], - keep_matches: bool = True, - mode: FilterModeType = "any", - name: Optional[str] = None, - n_jobs: int = 1, - uuid: Optional[str] = None, - ) -> None: - """Initialize BasePatternsFilter. - - Parameters - ---------- - patterns: Union[list[str], dict[str, IntOrIntCountRange]] - List of patterns to allow in molecules. - Alternatively, a dictionary can be passed with patterns as keys - and an int for exact count or a tuple of minimum and maximum. - keep_matches: bool, optional (default: True) - If True, molecules containing the specified patterns are kept, else removed. - mode: FilterModeType, optional (default: "any") - If "any", at least one of the specified patterns must be present in the molecule. - If "all", all of the specified patterns must be present in the molecule. - name: Optional[str], optional (default: None) - Name of the pipeline element. - n_jobs: int, optional (default: 1) - Number of parallel jobs to use. - uuid: str, optional (default: None) - Unique identifier of the pipeline element. - """ - super().__init__( - keep_matches=keep_matches, mode=mode, name=name, n_jobs=n_jobs, uuid=uuid - ) - self.patterns = patterns # type: ignore + _filter_elements: Mapping[str, IntCountRange] @property - def patterns(self) -> dict[str, IntCountRange]: - """Get allowed patterns as dict.""" - return self._patterns + def filter_elements(self) -> Mapping[str, IntCountRange]: + """Get allowed filter elements (patterns) as dict.""" + return self._filter_elements - @patterns.setter - def patterns( + @filter_elements.setter + def filter_elements( self, - patterns: Union[list[str], dict[str, IntOrIntCountRange]], + patterns: Union[list[str], Mapping[str, IntOrIntCountRange]], ) -> None: - """Set allowed patterns as dict. + """Set allowed filter elements (patterns) as dict. Parameters ---------- - patterns: Union[list[str], dict[str, IntOrIntCountRange]] + patterns: Union[list[str], Mapping[str, IntOrIntCountRange]] List of patterns. """ if isinstance(patterns, (list, set)): - self._patterns = {pat: (1, None) for pat in patterns} + self._filter_elements = {pat: (1, None) for pat in patterns} else: - self._patterns = { + self._filter_elements = { pat: count_value_to_tuple(count) for pat, count in patterns.items() } - self.patterns_mol_dict = list(self._patterns.keys()) # type: ignore + self.patterns_mol_dict = list(self._filter_elements.keys()) # type: ignore @property def patterns_mol_dict(self) -> Mapping[str, RDKitMol]: @@ -309,12 +315,12 @@ def patterns_mol_dict(self) -> Mapping[str, RDKitMol]: return self._patterns_mol_dict @patterns_mol_dict.setter - def patterns_mol_dict(self, patterns: list[str]) -> None: + def patterns_mol_dict(self, patterns: Sequence[str]) -> None: """Set patterns as dict with RDKitMol objects. Parameters ---------- - patterns: list[str] + patterns: Sequence[str] List of patterns. """ self._patterns_mol_dict = {pat: self._pattern_to_mol(pat) for pat in patterns} @@ -334,11 +340,6 @@ def _pattern_to_mol(self, pattern: str) -> RDKitMol: RDKitMol object of the pattern. """ - @property - def filter_elements(self) -> Mapping[str, IntCountRange]: - """Get filter elements as dict.""" - return self.patterns - def _calculate_single_element_value( self, filter_element: Any, value: RDKitMol ) -> int: @@ -357,45 +358,3 @@ def _calculate_single_element_value( smarts match count value. """ return len(value.GetSubstructMatches(self.patterns_mol_dict[filter_element])) - - def get_params(self, deep: bool = True) -> dict[str, Any]: - """Get parameters of PatternFilter. - - Parameters - ---------- - deep: bool, optional (default: True) - If True, return the parameters of all subobjects that are PipelineElements. - - Returns - ------- - dict[str, Any] - Parameters of PatternFilter. - """ - params = super().get_params(deep=deep) - if deep: - params["patterns"] = { - pat: (count_tuple[0], count_tuple[1]) - for pat, count_tuple in self.patterns.items() - } - else: - params["patterns"] = self.patterns - return params - - def set_params(self, **parameters: Any) -> Self: - """Set parameters of PatternFilter. - - Parameters - ---------- - parameters: Any - Parameters to set. - - Returns - ------- - Self - Self. - """ - parameter_copy = dict(parameters) - if "patterns" in parameter_copy: - self.patterns = parameter_copy.pop("patterns") - super().set_params(**parameter_copy) - return self diff --git a/molpipeline/mol2mol/filter.py b/molpipeline/mol2mol/filter.py index 5e46f7e1..f16d15d7 100644 --- a/molpipeline/mol2mol/filter.py +++ b/molpipeline/mol2mol/filter.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections import Counter -from typing import Any, Literal, Optional, Union +from typing import Any, Mapping, Optional, Sequence, Union try: from typing import Self # type: ignore[attr-defined] @@ -253,6 +253,12 @@ def _pattern_to_mol(self, pattern: str) -> RDKitMol: class ComplexFilter(_BaseKeepMatchesFilter): """Filter to keep or remove molecules based on multiple filter elements. + Parameters + ---------- + filter_elements: Sequence[_MolToMolPipelineElement] + MolToMol elements to use as filters. + [...] + Notes ----- There are four possible scenarios: @@ -262,48 +268,19 @@ class ComplexFilter(_BaseKeepMatchesFilter): - mode = "all" & keep_matches = False: Must not match all filter elements. """ - def __init__( - self, - filter_elements: tuple[_MolToMolPipelineElement, ...], - keep_matches: bool = True, - mode: Literal["any", "all"] = "any", - name: Optional[str] = None, - n_jobs: int = 1, - uuid: Optional[str] = None, - ) -> None: - """Initialize ComplexFilter. - - Parameters - ---------- - filter_elements: tuple[_MolToMolPipelineElement, ...] - tuple of filter elements. - keep_matches: bool, optional (default: True) - If True, molecules containing the specified patterns are kept, else removed. - mode: Literal["any", "all"], optional (default: "any") - If "any", at least one of the specified patterns must be present in the molecule. - If "all", all of the specified patterns must be present in the molecule. - name: Optional[str], optional (default: None) - Name of the pipeline element. - n_jobs: int, optional (default: 1) - Number of parallel jobs to use. - uuid: str, optional (default: None) - Unique identifier of the pipeline element. - """ - super().__init__( - keep_matches=keep_matches, mode=mode, name=name, n_jobs=n_jobs, uuid=uuid - ) - self.filter_elements = {element: (1, None) for element in filter_elements} + _filter_elements: Mapping[_MolToMolPipelineElement, tuple[int, Optional[int]]] @property def filter_elements( self, - ) -> dict[_MolToMolPipelineElement, tuple[int, Optional[int]]]: + ) -> Mapping[_MolToMolPipelineElement, tuple[int, Optional[int]]]: """Get filter elements.""" return self._filter_elements @filter_elements.setter def filter_elements( - self, filter_elements: dict[_MolToMolPipelineElement, tuple[int, Optional[int]]] + self, + filter_elements: Sequence[_MolToMolPipelineElement], ) -> None: """Set filter elements. @@ -312,7 +289,7 @@ def filter_elements( filter_elements: dict[_MolToMolPipelineElement, tuple[int, Optional[int]]] Filter elements to set. """ - self._filter_elements = filter_elements + self._filter_elements = {element: (1, None) for element in filter_elements} def _calculate_single_element_value( self, filter_element: Any, value: RDKitMol @@ -336,52 +313,17 @@ def _calculate_single_element_value( return 0 return 1 - def get_params(self, deep: bool = True) -> dict[str, Any]: - """Get parameters of ComplexFilter. - - Parameters - ---------- - deep: bool, optional (default: True) - If True, return the parameters of all subobjects that are PipelineElements. - - Returns - ------- - dict[str, Any] - Parameters of ComplexFilter. - """ - params = super().get_params(deep=deep) - if deep: - params["filter_elements"] = { - element: (count_tuple[0], count_tuple[1]) - for element, count_tuple in self.filter_elements.items() - } - else: - params["filter_elements"] = self.filter_elements - return params - - def set_params(self, **parameters: Any) -> Self: - """Set parameters of ComplexFilter. - - Parameters - ---------- - parameters: Any - Parameters to set. - - Returns - ------- - Self - Self. - """ - parameter_copy = dict(parameters) - if "filter_elements" in parameter_copy: - self.filter_elements = parameter_copy.pop("filter_elements") - super().set_params(**parameter_copy) - return self - class RDKitDescriptorsFilter(_BaseKeepMatchesFilter): """Filter to keep or remove molecules based on RDKit descriptors. + Parameters + ---------- + filter_elements: dict[str, FloatCountRange] + Dictionary of RDKit descriptors to filter by. + The value must be a tuple of minimum and maximum. If None, no limit is set. + [...] + Notes ----- There are four possible scenarios: @@ -391,46 +333,13 @@ class RDKitDescriptorsFilter(_BaseKeepMatchesFilter): - mode = "all" & keep_matches = False: Must not match all filter elements. """ - def __init__( - self, - descriptors: dict[str, FloatCountRange], - keep_matches: bool = True, - mode: Literal["any", "all"] = "any", - name: Optional[str] = None, - n_jobs: int = 1, - uuid: Optional[str] = None, - ) -> None: - """Initialize DescriptorsFilter. - - Parameters - ---------- - descriptors: dict[str, FloatCountRange] - Dictionary of RDKit descriptors to filter by. - The value must be a tuple of minimum and maximum. If None, no limit is set. - keep_matches: bool, optional (default: True) - If True, molecules containing the specified descriptors are kept, else removed. - mode: Literal["any", "all"], optional (default: "any") - If "any", at least one of the specified descriptors must be present in the molecule. - If "all", all of the specified descriptors must be present in the molecule. - name: Optional[str], optional (default: None) - Name of the pipeline element. - n_jobs: int, optional (default: 1) - Number of parallel jobs to use. - uuid: str, optional (default: None) - Unique identifier of the pipeline element. - """ - super().__init__( - keep_matches=keep_matches, mode=mode, name=name, n_jobs=n_jobs, uuid=uuid - ) - self.descriptors = descriptors - @property - def descriptors(self) -> dict[str, FloatCountRange]: + def filter_elements(self) -> dict[str, FloatCountRange]: """Get allowed descriptors as dict.""" - return self._descriptors + return self._filter_elements - @descriptors.setter - def descriptors(self, descriptors: dict[str, FloatCountRange]) -> None: + @filter_elements.setter + def filter_elements(self, descriptors: dict[str, FloatCountRange]) -> None: """Set allowed descriptors as dict. Parameters @@ -438,17 +347,12 @@ def descriptors(self, descriptors: dict[str, FloatCountRange]) -> None: descriptors: dict[str, FloatCountRange] Dictionary of RDKit descriptors to filter by. """ - self._descriptors = descriptors + self._filter_elements = descriptors if not all(hasattr(Descriptors, descriptor) for descriptor in descriptors): raise ValueError( "You are trying to use an invalid descriptor. Use RDKit Descriptors module." ) - @property - def filter_elements(self) -> dict[str, FloatCountRange]: - """Get filter elements.""" - return self.descriptors - def _calculate_single_element_value( self, filter_element: Any, value: RDKitMol ) -> float: @@ -468,48 +372,6 @@ def _calculate_single_element_value( """ return getattr(Descriptors, filter_element)(value) - def get_params(self, deep: bool = True) -> dict[str, Any]: - """Get parameters of DescriptorFilter. - - Parameters - ---------- - deep: bool, optional (default: True) - If True, return the parameters of all subobjects that are PipelineElements. - - Returns - ------- - dict[str, Any] - Parameters of DescriptorFilter. - """ - params = super().get_params(deep=deep) - if deep: - params["descriptors"] = { - descriptor: (count_tuple[0], count_tuple[1]) - for descriptor, count_tuple in self.descriptors.items() - } - else: - params["descriptors"] = self.descriptors - return params - - def set_params(self, **parameters: Any) -> Self: - """Set parameters of DescriptorFilter. - - Parameters - ---------- - parameters: Any - Parameters to set. - - Returns - ------- - Self - Self. - """ - parameter_copy = dict(parameters) - if "descriptors" in parameter_copy: - self.descriptors = parameter_copy.pop("descriptors") - super().set_params(**parameter_copy) - return self - class MixtureFilter(_MolToMolPipelineElement): """MolToMol which removes molecules composed of multiple fragments.""" diff --git a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py index f4bc8df4..5e353b7a 100644 --- a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py +++ b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py @@ -111,7 +111,7 @@ def test_smarts_smiles_filter(self) -> None: smiles_filter = SmilesFilter(smiles_pats) for filter_ in [smarts_filter, smiles_filter]: - new_input_as_list = list(filter_.patterns.keys()) + new_input_as_list = list(filter_.filter_elements.keys()) pipeline = Pipeline( [ ("Smiles2Mol", SmilesToMol()), @@ -136,7 +136,7 @@ def test_smarts_smiles_filter(self) -> None: self.assertEqual(filtered_smiles_3, [SMILES_CHLOROBENZENE]) pipeline.set_params( - SmartsFilter__keep_matches=True, SmartsFilter__patterns=["I"] + SmartsFilter__keep_matches=True, SmartsFilter__filter_elements=["I"] ) filtered_smiles_4 = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles_4, []) @@ -144,7 +144,7 @@ def test_smarts_smiles_filter(self) -> None: pipeline.set_params( SmartsFilter__keep_matches=False, SmartsFilter__mode="any", - SmartsFilter__patterns=new_input_as_list, + SmartsFilter__filter_elements=new_input_as_list, ) filtered_smiles_5 = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles_5, [SMILES_ANTIMONY, SMILES_METAL_AU]) @@ -214,7 +214,7 @@ def test_descriptor_filter(self) -> None: ) pipeline.set_params( - DescriptorsFilter__descriptors={ + DescriptorsFilter__filter_elements={ "NumHAcceptors": (2.00, 4), } ) @@ -222,7 +222,7 @@ def test_descriptor_filter(self) -> None: self.assertEqual(result_lower_exact, [SMILES_CL_BR]) pipeline.set_params( - DescriptorsFilter__descriptors={ + DescriptorsFilter__filter_elements={ "NumHAcceptors": (1.99, 4), } ) @@ -230,7 +230,7 @@ def test_descriptor_filter(self) -> None: self.assertEqual(result_lower_in_bound, [SMILES_CL_BR]) pipeline.set_params( - DescriptorsFilter__descriptors={ + DescriptorsFilter__filter_elements={ "NumHAcceptors": (2.01, 4), } ) @@ -238,7 +238,7 @@ def test_descriptor_filter(self) -> None: self.assertEqual(result_lower_out_bound, []) pipeline.set_params( - DescriptorsFilter__descriptors={ + DescriptorsFilter__filter_elements={ "NumHAcceptors": (1, 2.00), } ) @@ -246,7 +246,7 @@ def test_descriptor_filter(self) -> None: self.assertEqual(result_upper_exact, [SMILES_CL_BR]) pipeline.set_params( - DescriptorsFilter__descriptors={ + DescriptorsFilter__filter_elements={ "NumHAcceptors": (1, 2.01), } ) @@ -254,7 +254,7 @@ def test_descriptor_filter(self) -> None: self.assertEqual(result_upper_in_bound, [SMILES_CL_BR]) pipeline.set_params( - DescriptorsFilter__descriptors={ + DescriptorsFilter__filter_elements={ "NumHAcceptors": (1, 1.99), } ) From 47d4d90b6c6c364d47f0603e792a559d1491af6e Mon Sep 17 00:00:00 2001 From: frederik-sandfort1 Date: Tue, 1 Oct 2024 11:25:34 +0200 Subject: [PATCH 19/25] review Christian --- .../mol2mol/filter.py | 30 +-- molpipeline/mol2mol/filter.py | 64 +++-- molpipeline/utils/molpipeline_types.py | 22 +- molpipeline/utils/value_conversions.py | 12 +- .../test_mol2mol/test_mol2mol_filter.py | 221 +++++++++--------- 5 files changed, 201 insertions(+), 148 deletions(-) diff --git a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py index 951fec8a..50caecc5 100644 --- a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py +++ b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py @@ -14,12 +14,12 @@ OptionalMol, RDKitMol, ) -from molpipeline.utils.value_conversions import ( +from molpipeline.utils.molpipeline_types import ( FloatCountRange, IntCountRange, IntOrIntCountRange, - count_value_to_tuple, ) +from molpipeline.utils.value_conversions import count_value_to_tuple # possible mode types for a KeepMatchesFilter: # - "any" means one match is enough @@ -28,7 +28,7 @@ def _within_boundaries( - lower_bound: Optional[float], upper_bound: Optional[float], value: float + lower_bound: Optional[float], upper_bound: Optional[float], property: float ) -> bool: """Check if a value is within the specified boundaries. @@ -40,17 +40,17 @@ def _within_boundaries( Lower boundary. upper_bound: Optional[float] Upper boundary. - value: float - Value to check. + property: float + Property to check. Returns ------- bool True if the value is within the boundaries, else False. """ - if lower_bound is not None and value < lower_bound: + if lower_bound is not None and property < lower_bound: return False - if upper_bound is not None and value > upper_bound: + if upper_bound is not None and property > upper_bound: return False return True @@ -167,13 +167,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: params = super().get_params(deep=deep) params["keep_matches"] = self.keep_matches params["mode"] = self.mode - if deep: - params["filter_elements"] = { - element: (count_tuple[0], count_tuple[1]) - for element, count_tuple in self.filter_elements.items() - } - else: - params["filter_elements"] = self.filter_elements + params["filter_elements"] = self.filter_elements return params def pretransform_single(self, value: RDKitMol) -> OptionalMol: @@ -195,9 +189,9 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: OptionalMol Molecule that matches defined filter elements, else InvalidInstance. """ - for filter_element, (min_count, max_count) in self.filter_elements.items(): - count = self._calculate_single_element_value(filter_element, value) - if _within_boundaries(min_count, max_count, count): + for filter_element, (lower_limit, upper_limit) in self.filter_elements.items(): + property = self._calculate_single_element_value(filter_element, value) + if _within_boundaries(lower_limit, upper_limit, property): # For "any" mode we can return early if a match is found if self.mode == "any": if not self.keep_matches: @@ -265,7 +259,7 @@ def _calculate_single_element_value( class BasePatternsFilter(BaseKeepMatchesFilter, abc.ABC): """Filter to keep or remove molecules based on patterns. - Parameters + Attributes ---------- filter_elements: Union[Sequence[str], Mapping[str, IntOrIntCountRange]] List of patterns to allow in molecules. diff --git a/molpipeline/mol2mol/filter.py b/molpipeline/mol2mol/filter.py index f16d15d7..6d6b6f8a 100644 --- a/molpipeline/mol2mol/filter.py +++ b/molpipeline/mol2mol/filter.py @@ -5,11 +5,14 @@ from collections import Counter from typing import Any, Mapping, Optional, Sequence, Union +from molpipeline.abstract_pipeline_elements.mol2mol.filter import _within_boundaries + try: from typing import Self # type: ignore[attr-defined] except ImportError: from typing_extensions import Self +from loguru import logger from rdkit import Chem from rdkit.Chem import Descriptors @@ -23,13 +26,14 @@ from molpipeline.abstract_pipeline_elements.mol2mol import ( BasePatternsFilter as _BasePatternsFilter, ) -from molpipeline.utils.molpipeline_types import OptionalMol, RDKitMol -from molpipeline.utils.value_conversions import ( +from molpipeline.utils.molpipeline_types import ( FloatCountRange, IntCountRange, IntOrIntCountRange, - count_value_to_tuple, + OptionalMol, + RDKitMol, ) +from molpipeline.utils.value_conversions import count_value_to_tuple class ElementFilter(_MolToMolPipelineElement): @@ -60,6 +64,7 @@ def __init__( allowed_element_numbers: Optional[ Union[list[int], dict[int, IntOrIntCountRange]] ] = None, + add_hydrogens: bool = True, name: str = "ElementFilter", n_jobs: int = 1, uuid: Optional[str] = None, @@ -72,6 +77,8 @@ def __init__( List of atomic numbers of elements to allowed in molecules. Per default allowed elements are: H, B, C, N, O, F, Si, P, S, Cl, Se, Br, I. Alternatively, a dictionary can be passed with atomic numbers as keys and an int for exact count or a tuple of minimum and maximum + add_hydrogens: bool, optional (default: True) + If True, in case Hydrogens are in allowed_element_list, add hydrogens to the molecule before filtering. name: str, optional (default: "ElementFilterPipe") Name of the pipeline element. n_jobs: int, optional (default: 1) @@ -81,6 +88,32 @@ def __init__( """ super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) self.allowed_element_numbers = allowed_element_numbers # type: ignore + self.add_hydrogens = add_hydrogens + + @property + def add_hydrogens(self) -> bool: + """Get add_hydrogens.""" + return self._add_hydrogens + + @add_hydrogens.setter + def add_hydrogens(self, add_hydrogens: bool) -> None: + """Set add_hydrogens. + + Parameters + ---------- + add_hydrogens: bool + If True, in case Hydrogens are in allowed_element_list, add hydrogens to the molecule before filtering. + """ + self._add_hydrogens = add_hydrogens + if self.add_hydrogens and 1 in self.allowed_element_numbers: + self.process_hydrogens = True + else: + if 1 in self.allowed_element_numbers: + logger.warning( + "Hydrogens are included in allowed_element_numbers, but add_hydrogens is set to False. " + "Thus hydrogens are NOT added before filtering. You might receive unexpected results." + ) + self.process_hydrogens = False @property def allowed_element_numbers(self) -> dict[int, IntCountRange]: @@ -135,6 +168,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: } else: params["allowed_element_numbers"] = self.allowed_element_numbers + params["add_hydrogens"] = self.add_hydrogens return params def set_params(self, **parameters: Any) -> Self: @@ -153,6 +187,8 @@ def set_params(self, **parameters: Any) -> Self: parameter_copy = dict(parameters) if "allowed_element_numbers" in parameter_copy: self.allowed_element_numbers = parameter_copy.pop("allowed_element_numbers") + if "add_hydrogens" in parameter_copy: + self.add_hydrogens = parameter_copy.pop("add_hydrogens") super().set_params(**parameter_copy) return self @@ -169,10 +205,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: OptionalMol Molecule if it contains only allowed elements, else InvalidInstance. """ - to_process_value = ( - Chem.AddHs(value) if 1 in self.allowed_element_numbers else value - ) - + to_process_value = Chem.AddHs(value) if self.process_hydrogens else value elements_list = [atom.GetAtomicNum() for atom in to_process_value.GetAtoms()] elements_counter = Counter(elements_list) if any( @@ -181,11 +214,9 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: return InvalidInstance( self.uuid, "Molecule contains forbidden chemical element.", self.name ) - for element, (min_count, max_count) in self.allowed_element_numbers.items(): + for element, (lower_limit, upper_limit) in self.allowed_element_numbers.items(): count = elements_counter[element] - if (min_count is not None and count < min_count) or ( - max_count is not None and count > max_count - ): + if not _within_boundaries(lower_limit, upper_limit, count): return InvalidInstance( self.uuid, f"Molecule contains forbidden number of element {element}.", @@ -225,6 +256,11 @@ def _pattern_to_mol(self, pattern: str) -> RDKitMol: class SmilesFilter(_BasePatternsFilter): """Filter to keep or remove molecules based on SMILES patterns. + In contrast to the SMARTSFilter, which also can match SMILES, the SmilesFilter + sanitizes the molecules and, e.g. checks kekulized bonds for aromaticity and + then sets it to aromatic while the SmartsFilter detects alternating single and + double bonds. + Notes ----- There are four possible scenarios: @@ -253,7 +289,7 @@ def _pattern_to_mol(self, pattern: str) -> RDKitMol: class ComplexFilter(_BaseKeepMatchesFilter): """Filter to keep or remove molecules based on multiple filter elements. - Parameters + Attributes ---------- filter_elements: Sequence[_MolToMolPipelineElement] MolToMol elements to use as filters. @@ -317,7 +353,7 @@ def _calculate_single_element_value( class RDKitDescriptorsFilter(_BaseKeepMatchesFilter): """Filter to keep or remove molecules based on RDKit descriptors. - Parameters + Attributes ---------- filter_elements: dict[str, FloatCountRange] Dictionary of RDKit descriptors to filter by. @@ -347,11 +383,11 @@ def filter_elements(self, descriptors: dict[str, FloatCountRange]) -> None: descriptors: dict[str, FloatCountRange] Dictionary of RDKit descriptors to filter by. """ - self._filter_elements = descriptors if not all(hasattr(Descriptors, descriptor) for descriptor in descriptors): raise ValueError( "You are trying to use an invalid descriptor. Use RDKit Descriptors module." ) + self._filter_elements = descriptors def _calculate_single_element_value( self, filter_element: Any, value: RDKitMol diff --git a/molpipeline/utils/molpipeline_types.py b/molpipeline/utils/molpipeline_types.py index ff59fdec..e17b48b3 100644 --- a/molpipeline/utils/molpipeline_types.py +++ b/molpipeline/utils/molpipeline_types.py @@ -3,7 +3,17 @@ from __future__ import annotations from numbers import Number -from typing import Any, List, Literal, Optional, Protocol, Tuple, TypeVar, Union +from typing import ( + Any, + List, + Literal, + Optional, + Protocol, + Tuple, + TypeAlias, + TypeVar, + Union, +) try: from typing import Self # type: ignore[attr-defined] @@ -47,6 +57,16 @@ TypeConserverdIterable = TypeVar("TypeConserverdIterable", List[_T], npt.NDArray[_T]) +FloatCountRange: TypeAlias = tuple[Optional[float], Optional[float]] + +IntCountRange: TypeAlias = tuple[Optional[int], Optional[int]] + +# IntOrIntCountRange for Typing of count ranges +# - a single int for an exact value match +# - a range given as a tuple with a lower and upper bound +# - both limits are optional +IntOrIntCountRange: TypeAlias = Union[int, IntCountRange] + class AnySklearnEstimator(Protocol): """Protocol for sklearn estimators.""" diff --git a/molpipeline/utils/value_conversions.py b/molpipeline/utils/value_conversions.py index df595a84..4206e97f 100644 --- a/molpipeline/utils/value_conversions.py +++ b/molpipeline/utils/value_conversions.py @@ -1,16 +1,8 @@ """Module for utilities converting values.""" -from typing import Optional, Sequence, TypeAlias, Union +from typing import Sequence -FloatCountRange: TypeAlias = tuple[Optional[float], Optional[float]] - -IntCountRange: TypeAlias = tuple[Optional[int], Optional[int]] - -# IntOrIntCountRange for Typing of count ranges -# - a single int for an exact value match -# - a range given as a tuple with a lower and upper bound -# - both limits are optional -IntOrIntCountRange: TypeAlias = Union[int, IntCountRange] +from molpipeline.utils.molpipeline_types import IntCountRange, IntOrIntCountRange def count_value_to_tuple(count: IntOrIntCountRange) -> IntCountRange: diff --git a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py index 5e353b7a..d814c4e9 100644 --- a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py +++ b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py @@ -14,7 +14,7 @@ SmartsFilter, SmilesFilter, ) -from molpipeline.utils.value_conversions import FloatCountRange, IntOrIntCountRange +from molpipeline.utils.molpipeline_types import FloatCountRange, IntOrIntCountRange # pylint: disable=duplicate-code # test case molecules are allowed to be duplicated SMILES_ANTIMONY = "[SbH6+3]" @@ -63,15 +63,29 @@ def test_element_filter(self) -> None: ("ErrorFilter", ErrorFilter()), ], ) - filtered_smiles = pipeline.fit_transform(SMILES_LIST) - self.assertEqual( - filtered_smiles, [SMILES_BENZENE, SMILES_CHLOROBENZENE, SMILES_CL_BR] - ) - pipeline.set_params( - ElementFilter__allowed_element_numbers={6: 6, 1: (5, 6), 17: (0, 1)} - ) - filtered_smiles_2 = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(filtered_smiles_2, [SMILES_BENZENE, SMILES_CHLOROBENZENE]) + + test_params_list_with_results = [ + { + "params": {}, + "result": [SMILES_BENZENE, SMILES_CHLOROBENZENE, SMILES_CL_BR], + }, + { + "params": { + "ElementFilter__allowed_element_numbers": { + 6: 6, + 1: (5, 6), + 17: (0, 1), + } + }, + "result": [SMILES_BENZENE, SMILES_CHLOROBENZENE], + }, + {"params": {"ElementFilter__add_hydrogens": False}, "result": []}, + ] + + for test_params in test_params_list_with_results: + pipeline.set_params(**test_params["params"]) + filtered_smiles = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(filtered_smiles, test_params["result"]) def test_complex_filter(self) -> None: """Test if molecules are filtered correctly by allowed chemical elements.""" @@ -120,34 +134,44 @@ def test_smarts_smiles_filter(self) -> None: ("ErrorFilter", ErrorFilter()), ], ) - filtered_smiles = pipeline.fit_transform(SMILES_LIST) - self.assertEqual( - filtered_smiles, [SMILES_BENZENE, SMILES_CHLOROBENZENE, SMILES_CL_BR] - ) - - pipeline.set_params(SmartsFilter__keep_matches=False) - filtered_smiles_2 = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(filtered_smiles_2, [SMILES_ANTIMONY, SMILES_METAL_AU]) - pipeline.set_params( - SmartsFilter__mode="all", SmartsFilter__keep_matches=True - ) - filtered_smiles_3 = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(filtered_smiles_3, [SMILES_CHLOROBENZENE]) - - pipeline.set_params( - SmartsFilter__keep_matches=True, SmartsFilter__filter_elements=["I"] - ) - filtered_smiles_4 = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(filtered_smiles_4, []) + test_params_list_with_results = [ + { + "params": {}, + "result": [SMILES_BENZENE, SMILES_CHLOROBENZENE, SMILES_CL_BR], + }, + { + "params": {"SmartsFilter__keep_matches": False}, + "result": [SMILES_ANTIMONY, SMILES_METAL_AU], + }, + { + "params": { + "SmartsFilter__mode": "all", + "SmartsFilter__keep_matches": True, + }, + "result": [SMILES_CHLOROBENZENE], + }, + { + "params": { + "SmartsFilter__keep_matches": True, + "SmartsFilter__filter_elements": ["I"], + }, + "result": [], + }, + { + "params": { + "SmartsFilter__keep_matches": False, + "SmartsFilter__mode": "any", + "SmartsFilter__filter_elements": new_input_as_list, + }, + "result": [SMILES_ANTIMONY, SMILES_METAL_AU], + }, + ] - pipeline.set_params( - SmartsFilter__keep_matches=False, - SmartsFilter__mode="any", - SmartsFilter__filter_elements=new_input_as_list, - ) - filtered_smiles_5 = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(filtered_smiles_5, [SMILES_ANTIMONY, SMILES_METAL_AU]) + for test_params in test_params_list_with_results: + pipeline.set_params(**test_params["params"]) + filtered_smiles = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(filtered_smiles, test_params["result"]) def test_smarts_filter_parallel(self) -> None: """Test if molecules are filtered correctly by allowed SMARTS patterns in parallel.""" @@ -191,75 +215,62 @@ def test_descriptor_filter(self) -> None: ("ErrorFilter", ErrorFilter()), ], ) - filtered_smiles = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(filtered_smiles, SMILES_LIST) - - pipeline.set_params(DescriptorsFilter__mode="all") - filtered_smiles_2 = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(filtered_smiles_2, [SMILES_CL_BR]) - - pipeline.set_params(DescriptorsFilter__keep_matches=False) - filtered_smiles_3 = pipeline.fit_transform(SMILES_LIST) - self.assertEqual( - filtered_smiles_3, - [SMILES_ANTIMONY, SMILES_BENZENE, SMILES_CHLOROBENZENE, SMILES_METAL_AU], - ) - - pipeline.set_params(DescriptorsFilter__mode="any") - filtered_smiles_4 = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(filtered_smiles_4, []) - - pipeline.set_params( - DescriptorsFilter__mode="any", DescriptorsFilter__keep_matches=True - ) - - pipeline.set_params( - DescriptorsFilter__filter_elements={ - "NumHAcceptors": (2.00, 4), - } - ) - result_lower_exact = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(result_lower_exact, [SMILES_CL_BR]) - - pipeline.set_params( - DescriptorsFilter__filter_elements={ - "NumHAcceptors": (1.99, 4), - } - ) - result_lower_in_bound = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(result_lower_in_bound, [SMILES_CL_BR]) - - pipeline.set_params( - DescriptorsFilter__filter_elements={ - "NumHAcceptors": (2.01, 4), - } - ) - result_lower_out_bound = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(result_lower_out_bound, []) - - pipeline.set_params( - DescriptorsFilter__filter_elements={ - "NumHAcceptors": (1, 2.00), - } - ) - result_upper_exact = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(result_upper_exact, [SMILES_CL_BR]) - - pipeline.set_params( - DescriptorsFilter__filter_elements={ - "NumHAcceptors": (1, 2.01), - } - ) - result_upper_in_bound = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(result_upper_in_bound, [SMILES_CL_BR]) - - pipeline.set_params( - DescriptorsFilter__filter_elements={ - "NumHAcceptors": (1, 1.99), - } - ) - result_upper_out_bound = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(result_upper_out_bound, []) + test_params_list_with_results = [ + {"params": {}, "result": SMILES_LIST}, + {"params": {"DescriptorsFilter__mode": "all"}, "result": [SMILES_CL_BR]}, + { + "params": {"DescriptorsFilter__keep_matches": False}, + "result": [ + SMILES_ANTIMONY, + SMILES_BENZENE, + SMILES_CHLOROBENZENE, + SMILES_METAL_AU, + ], + }, + {"params": {"DescriptorsFilter__mode": "any"}, "result": []}, + { + "params": { + "DescriptorsFilter__keep_matches": True, + "DescriptorsFilter__filter_elements": {"NumHAcceptors": (2.00, 4)}, + }, + "result": [SMILES_CL_BR], + }, + { + "params": { + "DescriptorsFilter__filter_elements": {"NumHAcceptors": (1.99, 4)} + }, + "result": [SMILES_CL_BR], + }, + { + "params": { + "DescriptorsFilter__filter_elements": {"NumHAcceptors": (2.01, 4)} + }, + "result": [], + }, + { + "params": { + "DescriptorsFilter__filter_elements": {"NumHAcceptors": (1, 2.00)} + }, + "result": [SMILES_CL_BR], + }, + { + "params": { + "DescriptorsFilter__filter_elements": {"NumHAcceptors": (1, 2.01)} + }, + "result": [SMILES_CL_BR], + }, + { + "params": { + "DescriptorsFilter__filter_elements": {"NumHAcceptors": (1, 1.99)} + }, + "result": [], + }, + ] + + for test_params in test_params_list_with_results: + pipeline.set_params(**test_params["params"]) + filtered_smiles = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(filtered_smiles, test_params["result"]) def test_invalidate_mixtures(self) -> None: """Test if mixtures are correctly invalidated.""" From 1f8dc1c6f6f34b413632f7ea6a8285fd8f67080f Mon Sep 17 00:00:00 2001 From: frederik-sandfort1 Date: Tue, 1 Oct 2024 11:35:23 +0200 Subject: [PATCH 20/25] pylint --- .../abstract_pipeline_elements/mol2mol/filter.py | 14 +++++++------- molpipeline/mol2mol/filter.py | 3 +-- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py index 50caecc5..4142efde 100644 --- a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py +++ b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py @@ -28,7 +28,7 @@ def _within_boundaries( - lower_bound: Optional[float], upper_bound: Optional[float], property: float + lower_bound: Optional[float], upper_bound: Optional[float], property_value: float ) -> bool: """Check if a value is within the specified boundaries. @@ -40,17 +40,17 @@ def _within_boundaries( Lower boundary. upper_bound: Optional[float] Upper boundary. - property: float - Property to check. + property_value: float + Property value to check. Returns ------- bool True if the value is within the boundaries, else False. """ - if lower_bound is not None and property < lower_bound: + if lower_bound is not None and property_value < lower_bound: return False - if upper_bound is not None and property > upper_bound: + if upper_bound is not None and property_value > upper_bound: return False return True @@ -190,8 +190,8 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: Molecule that matches defined filter elements, else InvalidInstance. """ for filter_element, (lower_limit, upper_limit) in self.filter_elements.items(): - property = self._calculate_single_element_value(filter_element, value) - if _within_boundaries(lower_limit, upper_limit, property): + property_value = self._calculate_single_element_value(filter_element, value) + if _within_boundaries(lower_limit, upper_limit, property_value): # For "any" mode we can return early if a match is found if self.mode == "any": if not self.keep_matches: diff --git a/molpipeline/mol2mol/filter.py b/molpipeline/mol2mol/filter.py index 6d6b6f8a..5902eee7 100644 --- a/molpipeline/mol2mol/filter.py +++ b/molpipeline/mol2mol/filter.py @@ -5,8 +5,6 @@ from collections import Counter from typing import Any, Mapping, Optional, Sequence, Union -from molpipeline.abstract_pipeline_elements.mol2mol.filter import _within_boundaries - try: from typing import Self # type: ignore[attr-defined] except ImportError: @@ -26,6 +24,7 @@ from molpipeline.abstract_pipeline_elements.mol2mol import ( BasePatternsFilter as _BasePatternsFilter, ) +from molpipeline.abstract_pipeline_elements.mol2mol.filter import _within_boundaries from molpipeline.utils.molpipeline_types import ( FloatCountRange, IntCountRange, From d345191c40a628e2d5bf573acf4bd32c3d6899ef Mon Sep 17 00:00:00 2001 From: frederik-sandfort1 Date: Mon, 7 Oct 2024 14:45:15 +0200 Subject: [PATCH 21/25] include check for failed patterns in init --- .../abstract_pipeline_elements/mol2mol/filter.py | 5 +++++ .../test_mol2mol/test_mol2mol_filter.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py index 4142efde..b19cdb5b 100644 --- a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py +++ b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py @@ -318,6 +318,11 @@ def patterns_mol_dict(self, patterns: Sequence[str]) -> None: List of patterns. """ self._patterns_mol_dict = {pat: self._pattern_to_mol(pat) for pat in patterns} + failed_patterns = [ + pat for pat, mol in self._patterns_mol_dict.items() if not mol + ] + if failed_patterns: + raise ValueError("Invalid pattern(s): " + ", ".join(failed_patterns)) @abc.abstractmethod def _pattern_to_mol(self, pattern: str) -> RDKitMol: diff --git a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py index d814c4e9..ca0f9dd5 100644 --- a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py +++ b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py @@ -173,6 +173,22 @@ def test_smarts_smiles_filter(self) -> None: filtered_smiles = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles, test_params["result"]) + def test_smarts_smiles_filter_wrong_pattern(self) -> None: + """Test if molecules are filtered correctly by allowed SMARTS patterns.""" + smarts_pats: dict[str, IntOrIntCountRange] = { + "cIOnk": (4, None), + "cC": 1, + } + with self.assertRaises(ValueError): + SmartsFilter(smarts_pats) + + smiles_pats: dict[str, IntOrIntCountRange] = { + "cC": (1, None), + "Cl": 1, + } + with self.assertRaises(ValueError): + SmilesFilter(smiles_pats) + def test_smarts_filter_parallel(self) -> None: """Test if molecules are filtered correctly by allowed SMARTS patterns in parallel.""" smarts_pats: dict[str, IntOrIntCountRange] = { From 08d58f4236efeccf544648083a4725397296fe91 Mon Sep 17 00:00:00 2001 From: frederik-sandfort1 Date: Mon, 7 Oct 2024 17:25:20 +0200 Subject: [PATCH 22/25] final review --- molpipeline/mol2mol/filter.py | 116 ++++++++++++++++-- .../test_mol2mol/test_mol2mol_filter.py | 89 ++++++++++++-- 2 files changed, 186 insertions(+), 19 deletions(-) diff --git a/molpipeline/mol2mol/filter.py b/molpipeline/mol2mol/filter.py index 5902eee7..894758ba 100644 --- a/molpipeline/mol2mol/filter.py +++ b/molpipeline/mol2mol/filter.py @@ -24,7 +24,10 @@ from molpipeline.abstract_pipeline_elements.mol2mol import ( BasePatternsFilter as _BasePatternsFilter, ) -from molpipeline.abstract_pipeline_elements.mol2mol.filter import _within_boundaries +from molpipeline.abstract_pipeline_elements.mol2mol.filter import ( + FilterModeType, + _within_boundaries, +) from molpipeline.utils.molpipeline_types import ( FloatCountRange, IntCountRange, @@ -290,8 +293,8 @@ class ComplexFilter(_BaseKeepMatchesFilter): Attributes ---------- - filter_elements: Sequence[_MolToMolPipelineElement] - MolToMol elements to use as filters. + pipeline_filter_elements: Sequence[tuple[str, _MolToMolPipelineElement]] + pairs of unique names and MolToMol elements to use as filters. [...] Notes @@ -303,28 +306,123 @@ class ComplexFilter(_BaseKeepMatchesFilter): - mode = "all" & keep_matches = False: Must not match all filter elements. """ - _filter_elements: Mapping[_MolToMolPipelineElement, tuple[int, Optional[int]]] + _filter_elements: Mapping[str, tuple[int, Optional[int]]] + + def __init__( + self, + pipeline_filter_elements: Sequence[tuple[str, _MolToMolPipelineElement]], + keep_matches: bool = True, + mode: FilterModeType = "any", + name: str | None = None, + n_jobs: int = 1, + uuid: str | None = None, + ) -> None: + """Initialize ComplexFilter. + + Parameters + ---------- + pipeline_filter_elements: Sequence[tuple[str, _MolToMolPipelineElement]] + Filter elements to use. + keep_matches: bool, optional (default: True) + If True, keep matches, else remove matches. + mode: FilterModeType, optional (default: "any") + Mode to filter by. + name: str, optional (default: None) + Name of the pipeline element. + n_jobs: int, optional (default: 1) + Number of parallel jobs to use. + uuid: str, optional (default: None) + Unique identifier of the pipeline element. + """ + self.pipeline_filter_elements = pipeline_filter_elements + super().__init__( + filter_elements=pipeline_filter_elements, + keep_matches=keep_matches, + mode=mode, + name=name, + n_jobs=n_jobs, + uuid=uuid, + ) + + def get_params(self, deep: bool = True) -> dict[str, Any]: + """Get parameters of ComplexFilter. + + Parameters + ---------- + deep: bool, optional (default: True) + If True, return the parameters of all subobjects that are PipelineElements. + + Returns + ------- + dict[str, Any] + Parameters of ComplexFilter. + """ + params = super().get_params(deep) + params.pop("filter_elements") + params["pipeline_filter_elements"] = self.pipeline_filter_elements + if deep: + for name, element in self.pipeline_filter_elements: + deep_items = element.get_params().items() + params.update( + ("pipeline_filter_elements" + "__" + name + "__" + key, val) + for key, val in deep_items + ) + return params + + def set_params(self, **parameters: Any) -> Self: + """Set parameters of ComplexFilter. + + Parameters + ---------- + parameters: Any + Parameters to set. + + Returns + ------- + Self + Self. + """ + parameter_copy = dict(parameters) + if "pipeline_filter_elements" in parameter_copy: + self.pipeline_filter_elements = parameter_copy.pop( + "pipeline_filter_elements" + ) + self.filter_elements = self.pipeline_filter_elements # type: ignore + for key in parameters: + if key.startswith("pipeline_filter_elements__"): + value = parameter_copy.pop(key) + element_name, element_key = key.split("__")[1:] + for name, element in self.pipeline_filter_elements: + if name == element_name: + element.set_params(**{element_key: value}) + super().set_params(**parameter_copy) + return self @property def filter_elements( self, - ) -> Mapping[_MolToMolPipelineElement, tuple[int, Optional[int]]]: + ) -> Mapping[str, tuple[int, Optional[int]]]: """Get filter elements.""" return self._filter_elements @filter_elements.setter def filter_elements( self, - filter_elements: Sequence[_MolToMolPipelineElement], + filter_elements: Sequence[tuple[str, _MolToMolPipelineElement]], ) -> None: """Set filter elements. Parameters ---------- - filter_elements: dict[_MolToMolPipelineElement, tuple[int, Optional[int]]] + filter_elements: Sequence[tuple[str, _MolToMolPipelineElement]] Filter elements to set. """ - self._filter_elements = {element: (1, None) for element in filter_elements} + self.filter_elements_dict = dict(filter_elements) + if not len(self.filter_elements_dict) == len(filter_elements): + raise ValueError("Filter elements names need to be unique.") + self._filter_elements = { + element_name: (1, None) for element_name, _element in filter_elements + } def _calculate_single_element_value( self, filter_element: Any, value: RDKitMol @@ -343,7 +441,7 @@ def _calculate_single_element_value( int Filter match. """ - mol = filter_element.pretransform_single(value) + mol = self.filter_elements_dict[filter_element].pretransform_single(value) if isinstance(mol, InvalidInstance): return 0 return 1 diff --git a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py index ca0f9dd5..2c100b62 100644 --- a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py +++ b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py @@ -14,6 +14,7 @@ SmartsFilter, SmilesFilter, ) +from molpipeline.utils.json_operations import recursive_from_json, recursive_to_json from molpipeline.utils.molpipeline_types import FloatCountRange, IntOrIntCountRange # pylint: disable=duplicate-code # test case molecules are allowed to be duplicated @@ -32,8 +33,8 @@ ] -class MolFilterTest(unittest.TestCase): - """Unittest for MolFilter, which invalidate molecules based on criteria defined in the respective filter.""" +class ElementFilterTest(unittest.TestCase): + """Unittest for Elementiflter.""" def test_element_filter(self) -> None: """Test if molecules are filtered correctly by allowed chemical elements.""" @@ -87,12 +88,22 @@ def test_element_filter(self) -> None: filtered_smiles = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles, test_params["result"]) - def test_complex_filter(self) -> None: - """Test if molecules are filtered correctly by allowed chemical elements.""" + +class ComplexFilterTest(unittest.TestCase): + """Unittest for ComplexFilter.""" + + @staticmethod + def _create_pipeline(): + """Create a pipeline with a complex filter.""" element_filter_1 = ElementFilter({6: 6, 1: 6}) element_filter_2 = ElementFilter({6: 6, 1: 5, 17: 1}) - multi_element_filter = ComplexFilter((element_filter_1, element_filter_2)) + multi_element_filter = ComplexFilter( + ( + ("element_filter_1", element_filter_1), + ("element_filter_2", element_filter_2), + ) + ) pipeline = Pipeline( [ @@ -102,13 +113,59 @@ def test_complex_filter(self) -> None: ("ErrorFilter", ErrorFilter()), ], ) + return pipeline + + def test_complex_filter(self) -> None: + """Test if molecules are filtered correctly by allowed chemical elements.""" + pipeline = ComplexFilterTest._create_pipeline() + + test_params_list_with_results = [ + { + "params": {}, + "result": [SMILES_BENZENE, SMILES_CHLOROBENZENE], + }, + { + "params": {"MultiElementFilter__mode": "all"}, + "result": [], + }, + { + "params": { + "MultiElementFilter__mode": "any", + "MultiElementFilter__pipeline_filter_elements__element_filter_1__add_hydrogens": False, + }, + "result": [SMILES_CHLOROBENZENE], + }, + ] + + for test_params in test_params_list_with_results: + pipeline.set_params(**test_params["params"]) + filtered_smiles = pipeline.fit_transform(SMILES_LIST) + self.assertEqual(filtered_smiles, test_params["result"]) + + def test_json_serialization(self) -> None: + """Test if complex filter can be serialized and deserialized.""" + pipeline = ComplexFilterTest._create_pipeline() + json_object = recursive_to_json(pipeline) + newpipeline = recursive_from_json(json_object) + self.assertEqual(json_object, recursive_to_json(newpipeline)) + + pipeline_result = pipeline.fit_transform(SMILES_LIST) + newpipeline_result = newpipeline.fit_transform(SMILES_LIST) + self.assertEqual(pipeline_result, newpipeline_result) + + def test_complex_filter_non_unique_names(self) -> None: + """Test if molecules are filtered correctly by allowed chemical elements.""" + element_filter_1 = ElementFilter({6: 6, 1: 6}) + element_filter_2 = ElementFilter({6: 6, 1: 5, 17: 1}) + + with self.assertRaises(ValueError): + ComplexFilter( + (("filter_1", element_filter_1), ("filter_1", element_filter_2)) + ) - filtered_smiles = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(filtered_smiles, [SMILES_BENZENE, SMILES_CHLOROBENZENE]) - pipeline.set_params(MultiElementFilter__mode="all") - filtered_smiles_2 = pipeline.fit_transform(SMILES_LIST) - self.assertEqual(filtered_smiles_2, []) +class SmartsSmilesFilterTest(unittest.TestCase): + """Unittest for SmartsFilter and SmilesFilter.""" def test_smarts_smiles_filter(self) -> None: """Test if molecules are filtered correctly by allowed SMARTS patterns.""" @@ -214,6 +271,10 @@ def test_smarts_filter_parallel(self) -> None: filtered_smiles = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles, [SMILES_CHLOROBENZENE]) + +class RDKitDescriptorsFilterTest(unittest.TestCase): + """Unittest for RDKitDescriptorsFilter.""" + def test_descriptor_filter(self) -> None: """Test if molecules are filtered correctly by allowed descriptors.""" descriptors: dict[str, FloatCountRange] = { @@ -288,6 +349,10 @@ def test_descriptor_filter(self) -> None: filtered_smiles = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles, test_params["result"]) + +class MixtureFilterTest(unittest.TestCase): + """Unittest for MixtureFilter.""" + def test_invalidate_mixtures(self) -> None: """Test if mixtures are correctly invalidated.""" mol_list = ["CCC.CC.C", "c1ccccc1.[Na+].[Cl-]", "c1ccccc1"] @@ -311,6 +376,10 @@ def test_invalidate_mixtures(self) -> None: mols_processed = pipeline.fit_transform(mol_list) self.assertEqual(expected_invalidated_mol_list, mols_processed) + +class InorganicsFilterTest(unittest.TestCase): + """Unittest for InorganicsFilter.""" + def test_inorganic_filter(self) -> None: """Test if molecules are filtered correctly by allowed chemical elements.""" smiles2mol = SmilesToMol() From ede7d3914908f055c1438656b95f4354df6a036b Mon Sep 17 00:00:00 2001 From: frederik-sandfort1 Date: Mon, 7 Oct 2024 17:30:34 +0200 Subject: [PATCH 23/25] final linting --- tests/test_elements/test_mol2mol/test_mol2mol_filter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py index 2c100b62..d1fd6e18 100644 --- a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py +++ b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py @@ -93,7 +93,7 @@ class ComplexFilterTest(unittest.TestCase): """Unittest for ComplexFilter.""" @staticmethod - def _create_pipeline(): + def _create_pipeline() -> None: """Create a pipeline with a complex filter.""" element_filter_1 = ElementFilter({6: 6, 1: 6}) element_filter_2 = ElementFilter({6: 6, 1: 5, 17: 1}) From 222675c5b0df186490af8204425cd400adee3cbc Mon Sep 17 00:00:00 2001 From: frederik-sandfort1 Date: Mon, 7 Oct 2024 17:34:59 +0200 Subject: [PATCH 24/25] final final linting --- .../test_elements/test_mol2mol/test_mol2mol_filter.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py index d1fd6e18..3d54b03d 100644 --- a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py +++ b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py @@ -93,8 +93,14 @@ class ComplexFilterTest(unittest.TestCase): """Unittest for ComplexFilter.""" @staticmethod - def _create_pipeline() -> None: - """Create a pipeline with a complex filter.""" + def _create_pipeline() -> Pipeline: + """Create a pipeline with a complex filter. + + Returns + ------- + Pipeline + Pipeline with a complex filter. + """ element_filter_1 = ElementFilter({6: 6, 1: 6}) element_filter_2 = ElementFilter({6: 6, 1: 5, 17: 1}) From 30e57011724292a4095c08bdf1febe9f663ab6ad Mon Sep 17 00:00:00 2001 From: frederik-sandfort1 Date: Mon, 7 Oct 2024 17:41:52 +0200 Subject: [PATCH 25/25] final final final linting --- tests/test_elements/test_mol2mol/test_mol2mol_filter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py index 3d54b03d..8a0d5c0c 100644 --- a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py +++ b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py @@ -95,7 +95,7 @@ class ComplexFilterTest(unittest.TestCase): @staticmethod def _create_pipeline() -> Pipeline: """Create a pipeline with a complex filter. - + Returns ------- Pipeline