Skip to content

Commit

Permalink
102 bug remove irrelevant params fails for int values (#103)
Browse files Browse the repository at this point in the history
* add test to reproduce error
* cover dicts which are not param-dicts
  • Loading branch information
c-w-feldmann authored Oct 29, 2024
1 parent 43cb758 commit 039c7e9
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 8 deletions.
2 changes: 2 additions & 0 deletions molpipeline/utils/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def remove_irrelevant_params(params: _T) -> _T:
if isinstance(params, dict):
params_new = {}
for key, value in params.items():
if not isinstance(key, str):
continue
if key.split("__")[-1] in irrelevant_params:
continue
params_new[key] = remove_irrelevant_params(value)
Expand Down
20 changes: 15 additions & 5 deletions tests/test_utils/test_comparison.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,30 @@
"""Test the comparison functions."""

from typing import Callable
from unittest import TestCase

from molpipeline import Pipeline
from molpipeline.utils.comparison import check_pipelines_equivalent
from tests.utils.default_models import get_morgan_physchem_rf_pipeline
from tests.utils.default_models import (
get_morgan_physchem_rf_pipeline,
get_standardization_pipeline,
)


class TestComparison(TestCase):
"""Test if functional equivalent pipelines are detected as such."""

def test_are_equal(self) -> None:
"""Test if two equivalent pipelines are detected as such."""

pipeline_a = get_morgan_physchem_rf_pipeline()
pipeline_b = get_morgan_physchem_rf_pipeline()
self.assertTrue(check_pipelines_equivalent(pipeline_a, pipeline_b))
# Test standardization pipelines
pipline_method_list: list[Callable[[int], Pipeline]] = [
get_standardization_pipeline,
get_morgan_physchem_rf_pipeline,
]
for pipeline_method in pipline_method_list:
pipeline_a = pipeline_method(1)
pipeline_b = pipeline_method(1)
self.assertTrue(check_pipelines_equivalent(pipeline_a, pipeline_b))

def test_are_not_equal(self) -> None:
"""Test if two different pipelines are detected as such."""
Expand Down
61 changes: 58 additions & 3 deletions tests/utils/default_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,30 @@
MolToConcatenatedVector,
MolToMorganFP,
MolToRDKitPhysChem,
MolToSmiles,
)
from molpipeline.mol2mol import (
EmptyMoleculeFilter,
FragmentDeduplicator,
MetalDisconnector,
MixtureFilter,
SaltRemover,
StereoRemover,
TautomerCanonicalizer,
Uncharger,
)
from molpipeline.mol2mol.filter import ElementFilter
from molpipeline.post_prediction import PostPredictionWrapper


def get_morgan_physchem_rf_pipeline() -> Pipeline:
def get_morgan_physchem_rf_pipeline(n_jobs: int = 1) -> Pipeline:
"""Get a pipeline combining Morgan fingerprints and physicochemical properties with a RandomForestClassifier.
Parameters
----------
n_jobs: int, default=1
Number of parallel jobs to use.
Returns
-------
Pipeline
Expand All @@ -35,14 +52,52 @@ def get_morgan_physchem_rf_pipeline() -> Pipeline:
),
),
("error_filter", error_filter),
("rf", RandomForestClassifier()),
("rf", RandomForestClassifier(n_jobs=n_jobs)),
(
"filter_reinserter",
PostPredictionWrapper(
FilterReinserter.from_error_filter(error_filter, None)
),
),
],
n_jobs=1,
n_jobs=n_jobs,
)
return pipeline


def get_standardization_pipeline(n_jobs: int = 1) -> Pipeline:
    """Get the standardization pipeline.

    The pipeline starts with a SMILES-to-molecule step, runs a fixed sequence
    of mol-to-mol standardization steps, converts back to SMILES and ends
    with an error filter plus a reinserter built from that same filter.

    Parameters
    ----------
    n_jobs: int, optional (default=1)
        The number of jobs to use for standardization.
        In case of -1, all available CPUs are used.

    Returns
    -------
    Pipeline
        The standardization pipeline.
    """
    # The same ErrorFilter instance is shared by the filter step and the
    # reinserter created from it, so both refer to one filter object.
    error_filter = ErrorFilter(filter_everything=True)
    # Set up pipeline
    standardization_pipeline = Pipeline(
        [
            ("smi2mol", SmilesToMol()),
            ("metal_disconnector", MetalDisconnector()),
            ("salt_remover", SaltRemover()),
            ("element_filter", ElementFilter()),
            # Uncharger is applied twice: before and after tautomer
            # canonicalization.
            ("uncharge1", Uncharger()),
            ("canonical_tautomer", TautomerCanonicalizer()),
            ("uncharge2", Uncharger()),
            ("stereo_remover", StereoRemover()),
            ("fragment_deduplicator", FragmentDeduplicator()),
            ("mixture_remover", MixtureFilter()),
            ("empty_molecule_remover", EmptyMoleculeFilter()),
            ("mol2smi", MolToSmiles()),
            ("error_filter", error_filter),
            # NOTE(review): the second argument (None) is presumably the fill
            # value used for filtered entries - confirm against FilterReinserter.
            ("error_replacer", FilterReinserter.from_error_filter(error_filter, None)),
        ],
        n_jobs=n_jobs,
    )
    return standardization_pipeline

0 comments on commit 039c7e9

Please sign in to comment.