diff --git a/molpipeline/utils/comparison.py b/molpipeline/utils/comparison.py
new file mode 100644
index 00000000..941f5a0e
--- /dev/null
+++ b/molpipeline/utils/comparison.py
@@ -0,0 +1,112 @@
+"""Functions for comparing pipelines."""
+
+from typing import Any, TypeVar
+
+from molpipeline import Pipeline
+from molpipeline.utils.json_operations import recursive_to_json
+
+_T = TypeVar("_T", list[Any], tuple[Any, ...], set[Any], dict[Any, Any], Any)
+
+
+def remove_irrelevant_params(params: _T) -> _T:
+    """Remove irrelevant parameters from a possibly nested parameter structure.
+
+    Lists, tuples, sets and dictionaries are traversed recursively. A dictionary
+    entry is dropped if the last "__"-separated part of its key is "n_jobs",
+    "uuid" or "error_filter_id". Keys which are not strings are always kept.
+
+    Parameters
+    ----------
+    params : TypeVar
+        Parameters to remove irrelevant parameters from.
+
+    Returns
+    -------
+    TypeVar
+        Parameters without irrelevant parameters.
+    """
+    if isinstance(params, list):
+        return [remove_irrelevant_params(val) for val in params]
+    if isinstance(params, tuple):
+        return tuple(remove_irrelevant_params(val) for val in params)
+    if isinstance(params, set):
+        return {remove_irrelevant_params(val) for val in params}
+
+    irrelevant_params = ["n_jobs", "uuid", "error_filter_id"]
+    if isinstance(params, dict):
+        params_new = {}
+        for key, value in params.items():
+            # Only string keys can be parameter names, others cannot be split.
+            if isinstance(key, str) and key.split("__")[-1] in irrelevant_params:
+                continue
+            params_new[key] = remove_irrelevant_params(value)
+        return params_new
+    return params
+
+
+def compare_recursive(value_a: Any, value_b: Any) -> bool:
+    """Compare two values recursively.
+
+    Values of different classes are never equal. Dictionaries are compared key
+    by key and lists and tuples element by element. All other values are
+    compared with the equality operator.
+
+    Parameters
+    ----------
+    value_a : Any
+        First value to compare.
+    value_b : Any
+        Second value to compare.
+
+    Returns
+    -------
+    bool
+        True if the values are the same, False otherwise.
+    """
+    if value_a.__class__ != value_b.__class__:
+        return False
+
+    if isinstance(value_a, dict):
+        if set(value_a.keys()) != set(value_b.keys()):
+            return False
+        return all(compare_recursive(value_a[key], value_b[key]) for key in value_a)
+
+    if isinstance(value_a, (list, tuple)):
+        if len(value_a) != len(value_b):
+            return False
+        return all(
+            compare_recursive(val_a, val_b) for val_a, val_b in zip(value_a, value_b)
+        )
+    return value_a == value_b
+
+
+def check_pipelines_equivalent(pipeline_a: Pipeline, pipeline_b: Pipeline) -> bool:
+    """Check if two pipelines are the same.
+
+    Both pipelines are converted with recursive_to_json. Parameters which are
+    irrelevant for the comparison are removed before the results are compared.
+
+    Parameters
+    ----------
+    pipeline_a : Pipeline
+        Pipeline to compare.
+    pipeline_b : Pipeline
+        Pipeline to compare.
+
+    Returns
+    -------
+    bool
+        True if the pipelines are the same, False otherwise.
+
+    Raises
+    ------
+    ValueError
+        If one of the inputs is not a Pipeline.
+    """
+    if not isinstance(pipeline_a, Pipeline) or not isinstance(pipeline_b, Pipeline):
+        raise ValueError("Both inputs should be of type Pipeline.")
+    pipeline_json_a = recursive_to_json(pipeline_a)
+    pipeline_json_a = remove_irrelevant_params(pipeline_json_a)
+    pipeline_json_b = recursive_to_json(pipeline_b)
+    pipeline_json_b = remove_irrelevant_params(pipeline_json_b)
+    return compare_recursive(pipeline_json_a, pipeline_json_b)
diff --git a/tests/test_utils/test_comparison.py b/tests/test_utils/test_comparison.py
new file mode 100644
index 00000000..b2b5d707
--- /dev/null
+++ b/tests/test_utils/test_comparison.py
@@ -0,0 +1,54 @@
+"""Test the comparison functions."""
+
+from unittest import TestCase
+
+from molpipeline.utils.comparison import (
+    check_pipelines_equivalent,
+    remove_irrelevant_params,
+)
+from tests.utils.default_models import get_morgan_physchem_rf_pipeline
+
+
+class TestComparison(TestCase):
+    """Test the functions for comparing pipelines."""
+
+    def test_are_equal(self) -> None:
+        """Test if two equivalent pipelines are detected as such."""
+
+        pipeline_a = get_morgan_physchem_rf_pipeline()
+        pipeline_b = get_morgan_physchem_rf_pipeline()
+        self.assertTrue(check_pipelines_equivalent(pipeline_a, pipeline_b))
+
+    def test_are_not_equal(self) -> None:
+        """Test if two different pipelines are detected as such."""
+        # Test changed parameters
+        pipeline_a = get_morgan_physchem_rf_pipeline()
+        pipeline_b = get_morgan_physchem_rf_pipeline()
+        pipeline_b.set_params(mol2fp__morgan__n_bits=1024)
+        self.assertFalse(check_pipelines_equivalent(pipeline_a, pipeline_b))
+
+        # Test changed steps
+        pipeline_b = get_morgan_physchem_rf_pipeline()
+        last_step = pipeline_b.steps[-1]
+        pipeline_b.steps = pipeline_b.steps[:-1]
+        self.assertFalse(check_pipelines_equivalent(pipeline_a, pipeline_b))
+
+        # Test if adding the step back makes the pipelines equivalent
+        pipeline_b.steps.append(last_step)
+        self.assertTrue(check_pipelines_equivalent(pipeline_a, pipeline_b))
+
+    def test_remove_irrelevant_params(self) -> None:
+        """Test if irrelevant parameters are removed, also for non-string keys."""
+        params = {
+            "n_jobs": 1,
+            "rf__n_jobs": 2,
+            "rf__n_estimators": 10,
+            "steps": [{"uuid": "abc", "name": "smi2mol"}],
+            0: ("kept", {"error_filter_id": "xyz"}),
+        }
+        expected = {
+            "rf__n_estimators": 10,
+            "steps": [{"name": "smi2mol"}],
+            0: ("kept", {}),
+        }
+        self.assertEqual(remove_irrelevant_params(params), expected)
diff --git a/tests/utils/default_models.py b/tests/utils/default_models.py
new file mode 100644
index 00000000..a82fe40d
--- /dev/null
+++ b/tests/utils/default_models.py
@@ -0,0 +1,48 @@
+"""This module contains the default models used for testing molpipeline functions and classes."""
+
+from sklearn.ensemble import RandomForestClassifier
+
+from molpipeline import Pipeline
+from molpipeline.any2mol import SmilesToMol
+from molpipeline.error_handling import ErrorFilter, FilterReinserter
+from molpipeline.mol2any import (
+    MolToConcatenatedVector,
+    MolToMorganFP,
+    MolToRDKitPhysChem,
+)
+from molpipeline.post_prediction import PostPredictionWrapper
+
+
+def get_morgan_physchem_rf_pipeline() -> Pipeline:
+    """Get a pipeline combining Morgan fingerprints and physicochemical properties with a RandomForestClassifier.
+
+    Returns
+    -------
+    Pipeline
+        A pipeline combining Morgan fingerprints and physicochemical properties with a RandomForestClassifier.
+    """
+    error_filter = ErrorFilter(filter_everything=True)
+    pipeline = Pipeline(
+        [
+            ("smi2mol", SmilesToMol()),
+            (
+                "mol2fp",
+                MolToConcatenatedVector(
+                    [
+                        ("morgan", MolToMorganFP(n_bits=2048)),
+                        ("physchem", MolToRDKitPhysChem()),
+                    ]
+                ),
+            ),
+            ("error_filter", error_filter),
+            ("rf", RandomForestClassifier()),
+            (
+                "filter_reinserter",
+                PostPredictionWrapper(
+                    FilterReinserter.from_error_filter(error_filter, None)
+                ),
+            ),
+        ],
+        n_jobs=1,
+    )
+    return pipeline