From e428e08d6a42f99f14a4f4297e9a756dda4abc82 Mon Sep 17 00:00:00 2001 From: Jochen Sieg Date: Tue, 13 Aug 2024 14:29:58 +0200 Subject: [PATCH] mol2any: add mol2bool element - Add new MolToBool element that converts values to a bool array. Can be used for masking. --- molpipeline/mol2any/mol2bool.py | 52 +++++++++++++++++++ .../test_mol2any/test_mol2bool.py | 24 +++++++++ 2 files changed, 76 insertions(+) create mode 100644 molpipeline/mol2any/mol2bool.py create mode 100644 tests/test_elements/test_mol2any/test_mol2bool.py diff --git a/molpipeline/mol2any/mol2bool.py b/molpipeline/mol2any/mol2bool.py new file mode 100644 index 00000000..e372fcb1 --- /dev/null +++ b/molpipeline/mol2any/mol2bool.py @@ -0,0 +1,52 @@ +"""Pipeline elements for converting instances to bool.""" + +from typing import Any + +from molpipeline.abstract_pipeline_elements.core import ( + MolToAnyPipelineElement, + InvalidInstance, +) + + +class MolToBool(MolToAnyPipelineElement): + """Element to generate a bool array from input.""" + + def __init__( + self, + name: str = "Mol2Bool", + n_jobs: int = 1, + uuid: str | None = None, + ) -> None: + """Initialize MolToBinaryPipelineElement. + + Parameters + ---------- + name: str, optional (default="Mol2Bool") + name of PipelineElement + n_jobs: int, optional (default=1) + number of jobs to use for parallelization + uuid: Optional[str], optional (default=None) + uuid of PipelineElement, by default None + + Returns + ------- + None + """ + super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) + + def pretransform_single(self, value: Any) -> bool: + """Transform a value to a bool representation. + + Parameters + ---------- + value: Any + Value to be transformed to bool representation. + + Returns + ------- + str + Binary representation of molecule. + """ + if isinstance(value, InvalidInstance): + return False + return True diff --git a/tests/test_elements/test_mol2any/test_mol2bool.py b/tests/test_elements/test_mol2any/test_mol2bool.py new file mode 100644 index 00000000..5847aec7 --- /dev/null +++ b/tests/test_elements/test_mol2any/test_mol2bool.py @@ -0,0 +1,24 @@ +"""Test mol to bool conversion.""" + +import unittest + +from molpipeline.abstract_pipeline_elements.core import InvalidInstance +from molpipeline.mol2any.mol2bool import MolToBool + + +class TestMolToBool(unittest.TestCase): + """Unittest for MolToBool.""" + + def test_bool_conversion(self) -> None: + """Test if the invalid instances are converted to bool.""" + + mol2bool = MolToBool() + result = mol2bool.transform( + [ + 1, + 2, + InvalidInstance(element_id="test", message="test", element_name="Test"), + 4, + ] + ) + self.assertEqual(result, [True, True, False, True])