From 039c7e960c6493cff2887d82dbb7d8e3a47daa5b Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" <128160984+c-w-feldmann@users.noreply.github.com> Date: Tue, 29 Oct 2024 10:27:44 +0100 Subject: [PATCH] 102 bug remove irrelevant params fails for int vaules (#103) * add test to reproduce error * cover dicts which are not param-dicts --- molpipeline/utils/comparison.py | 2 + tests/test_utils/test_comparison.py | 20 +++++++--- tests/utils/default_models.py | 61 +++++++++++++++++++++++++++-- 3 files changed, 75 insertions(+), 8 deletions(-) diff --git a/molpipeline/utils/comparison.py b/molpipeline/utils/comparison.py index 941f5a0e..b75b1b96 100644 --- a/molpipeline/utils/comparison.py +++ b/molpipeline/utils/comparison.py @@ -32,6 +32,8 @@ def remove_irrelevant_params(params: _T) -> _T: if isinstance(params, dict): params_new = {} for key, value in params.items(): + if not isinstance(key, str): + continue if key.split("__")[-1] in irrelevant_params: continue params_new[key] = remove_irrelevant_params(value) diff --git a/tests/test_utils/test_comparison.py b/tests/test_utils/test_comparison.py index b2b5d707..636faf72 100644 --- a/tests/test_utils/test_comparison.py +++ b/tests/test_utils/test_comparison.py @@ -1,9 +1,14 @@ """Test the comparison functions.""" +from typing import Callable from unittest import TestCase +from molpipeline import Pipeline from molpipeline.utils.comparison import check_pipelines_equivalent -from tests.utils.default_models import get_morgan_physchem_rf_pipeline +from tests.utils.default_models import ( + get_morgan_physchem_rf_pipeline, + get_standardization_pipeline, +) class TestComparison(TestCase): @@ -11,10 +16,15 @@ class TestComparison(TestCase): def test_are_equal(self) -> None: """Test if two equivalent pipelines are detected as such.""" - - pipeline_a = get_morgan_physchem_rf_pipeline() - pipeline_b = get_morgan_physchem_rf_pipeline() - self.assertTrue(check_pipelines_equivalent(pipeline_a, pipeline_b)) + # Test standardization 
pipelines + pipline_method_list: list[Callable[[int], Pipeline]] = [ + get_standardization_pipeline, + get_morgan_physchem_rf_pipeline, + ] + for pipeline_method in pipline_method_list: + pipeline_a = pipeline_method(1) + pipeline_b = pipeline_method(1) + self.assertTrue(check_pipelines_equivalent(pipeline_a, pipeline_b)) def test_are_not_equal(self) -> None: """Test if two different pipelines are detected as such.""" diff --git a/tests/utils/default_models.py b/tests/utils/default_models.py index a82fe40d..457ef3ae 100644 --- a/tests/utils/default_models.py +++ b/tests/utils/default_models.py @@ -9,13 +9,30 @@ MolToConcatenatedVector, MolToMorganFP, MolToRDKitPhysChem, + MolToSmiles, ) +from molpipeline.mol2mol import ( + EmptyMoleculeFilter, + FragmentDeduplicator, + MetalDisconnector, + MixtureFilter, + SaltRemover, + StereoRemover, + TautomerCanonicalizer, + Uncharger, +) +from molpipeline.mol2mol.filter import ElementFilter from molpipeline.post_prediction import PostPredictionWrapper -def get_morgan_physchem_rf_pipeline() -> Pipeline: +def get_morgan_physchem_rf_pipeline(n_jobs: int = 1) -> Pipeline: """Get a pipeline combining Morgan fingerprints and physicochemical properties with a RandomForestClassifier. + Parameters + ---------- + n_jobs: int, default=1 + Number of parallel jobs to use. + Returns ------- Pipeline @@ -35,7 +52,7 @@ def get_morgan_physchem_rf_pipeline() -> Pipeline: ), ), ("error_filter", error_filter), - ("rf", RandomForestClassifier()), + ("rf", RandomForestClassifier(n_jobs=n_jobs)), ( "filter_reinserter", PostPredictionWrapper( @@ -43,6 +60,44 @@ def get_morgan_physchem_rf_pipeline() -> Pipeline: ), ), ], - n_jobs=1, + n_jobs=n_jobs, ) return pipeline + + +def get_standardization_pipeline(n_jobs: int = 1) -> Pipeline: + """Get the standardization pipeline. + + Parameters + ---------- + n_jobs: int, optional (default=1) + The number of jobs to use for standardization. + In case of -1, all available CPUs are used. 
+ + Returns + ------- + Pipeline + The standardization pipeline. + """ + error_filter = ErrorFilter(filter_everything=True) + # Set up pipeline + standardization_pipeline = Pipeline( + [ + ("smi2mol", SmilesToMol()), + ("metal_disconnector", MetalDisconnector()), + ("salt_remover", SaltRemover()), + ("element_filter", ElementFilter()), + ("uncharge1", Uncharger()), + ("canonical_tautomer", TautomerCanonicalizer()), + ("uncharge2", Uncharger()), + ("stereo_remover", StereoRemover()), + ("fragment_deduplicator", FragmentDeduplicator()), + ("mixture_remover", MixtureFilter()), + ("empty_molecule_remover", EmptyMoleculeFilter()), + ("mol2smi", MolToSmiles()), + ("error_filter", error_filter), + ("error_replacer", FilterReinserter.from_error_filter(error_filter, None)), + ], + n_jobs=n_jobs, + ) + return standardization_pipeline