diff --git a/molpipeline/any2mol/smiles2mol.py b/molpipeline/any2mol/smiles2mol.py
index 0d7c45e6..34d0d31d 100644
--- a/molpipeline/any2mol/smiles2mol.py
+++ b/molpipeline/any2mol/smiles2mol.py
@@ -2,6 +2,13 @@
 
 from __future__ import annotations
 
+from typing import Any, Optional
+
+try:
+    from typing import Self  # type: ignore[attr-defined]
+except ImportError:
+    from typing_extensions import Self
+
 from rdkit import Chem
 
 from molpipeline.abstract_pipeline_elements.any2mol.string2mol import (
@@ -13,6 +20,42 @@
 class SmilesToMol(SimpleStringToMolElement):
     """Transforms Smiles to RDKit Mol objects."""
 
+    def __init__(
+        self,
+        remove_hydrogens: bool = True,
+        name: str = "smiles2mol",
+        n_jobs: int = 1,
+        uuid: Optional[str] = None,
+    ) -> None:
+        """Initialize SmilesToMol object.
+
+        Parameters
+        ----------
+        remove_hydrogens: bool
+            Whether the SMILES parser removes hydrogens (RDKit's removeHs flag).
+        name: str
+            Name of the object; forwarded to the parent constructor.
+        n_jobs: int
+            Number of jobs to run in parallel; forwarded to the parent constructor.
+        uuid: Optional[str]
+            UUID of the object; forwarded to the parent constructor.
+        """
+        super().__init__(name=name, n_jobs=n_jobs, uuid=uuid)
+        self._remove_hydrogens = remove_hydrogens
+
+    def _get_parser_config(self) -> Chem.SmilesParserParams:
+        """Build the RDKit SMILES parser parameters for this element.
+
+        Returns
+        -------
+        Chem.SmilesParserParams
+            Parser parameters with removeHs set from the remove_hydrogens setting.
+        """
+        # start from RDKit's default parser parameters and override only removeHs
+        parser_params = Chem.SmilesParserParams()
+        parser_params.removeHs = self._remove_hydrogens
+        return parser_params
+
     def string_to_mol(self, value: str) -> RDKitMol:
         """Transform Smiles string to molecule.
 
@@ -26,4 +69,45 @@ def string_to_mol(self, value: str) -> RDKitMol:
         RDKitMol
             Rdkit molecule if valid SMILES, else None.
         """
-        return Chem.MolFromSmiles(value)
+        return Chem.MolFromSmiles(value, self._get_parser_config())
+
+    def get_params(self, deep: bool = True) -> dict[str, Any]:
+        """Get parameters for this object, including remove_hydrogens.
+
+        Parameters
+        ----------
+        deep: bool
+            Forwarded to the parent; if True, remove_hydrogens is copied via bool().
+
+        Returns
+        -------
+        dict[str, Any]
+            Parent parameters extended with the "remove_hydrogens" entry.
+        """
+        parameters = super().get_params(deep)
+        if deep:
+            parameters["remove_hydrogens"] = bool(self._remove_hydrogens)
+
+        else:
+            parameters["remove_hydrogens"] = self._remove_hydrogens
+        return parameters
+
+    def set_params(self, **parameters: Any) -> Self:
+        """Set parameters; remaining entries are passed to the parent class.
+
+        Parameters
+        ----------
+        parameters: Any
+            Parameter names and values; remove_hydrogens is skipped when None.
+
+        Returns
+        -------
+        Self
+            SmilesToMol pipeline element with updated parameters.
+        """
+        parameter_copy = dict(parameters)
+        remove_hydrogens = parameter_copy.pop("remove_hydrogens", None)
+        if remove_hydrogens is not None:
+            self._remove_hydrogens = remove_hydrogens
+        super().set_params(**parameter_copy)
+        return self
diff --git a/tests/test_elements/test_any2mol/test_smiles2mol.py b/tests/test_elements/test_any2mol/test_smiles2mol.py
new file mode 100644
index 00000000..97edd6d1
--- /dev/null
+++ b/tests/test_elements/test_any2mol/test_smiles2mol.py
@@ -0,0 +1,53 @@
+"""Test smiles to mol pipeline element."""
+
+import unittest
+from typing import Any
+
+from molpipeline import Pipeline
+from molpipeline.any2mol import SmilesToMol
+
+
+class TestSmiles2Mol(unittest.TestCase):
+    """Test case for testing conversion of SMILES input to molecules."""
+
+    def test_smiles2mol_explict_hydrogens(self) -> None:
+        """Test SMILES reading with and without removal of explicit hydrogens."""
+        smiles = "C[H]"
+
+        # test: remove explicit Hs -> a single atom is expected
+        pipeline = Pipeline(
+            [
+                (
+                    "Smiles2Mol",
+                    SmilesToMol(remove_hydrogens=True),
+                ),
+            ]
+        )
+        mols = pipeline.fit_transform([smiles])
+        self.assertEqual(len(mols), 1)
+        self.assertIsNotNone(mols[0])
+        self.assertEqual(mols[0].GetNumAtoms(), 1)
+
+        # test: keep explicit Hs -> two atoms are expected
+        pipeline2 = Pipeline(
+            [
+                (
+                    "Smiles2Mol",
+                    SmilesToMol(remove_hydrogens=False),
+                ),
+            ]
+        )
+        mols2 = pipeline2.fit_transform([smiles])
+        self.assertEqual(len(mols2), 1)
+        self.assertIsNotNone(mols2[0])
+        self.assertEqual(mols2[0].GetNumAtoms(), 2)
+
+    def test_getter_setter(self) -> None:
+        """Test that set_params changes the value reported by get_params."""
+        smiles2mol = SmilesToMol(remove_hydrogens=False)
+        self.assertEqual(smiles2mol.get_params()["remove_hydrogens"], False)
+        params: dict[str, Any] = {
+            "remove_hydrogens": True,
+        }
+        smiles2mol.set_params(**params)
+        self.assertEqual(smiles2mol.get_params()["remove_hydrogens"], True)