Skip to content

Commit

Permalink
Add functionality to compare pipelines (#95)
Browse files Browse the repository at this point in the history
* Add functionality to compare pipelines
* Add tests to verify functionalities
  • Loading branch information
c-w-feldmann authored Oct 7, 2024
1 parent 21b1f10 commit ebd9887
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 0 deletions.
101 changes: 101 additions & 0 deletions molpipeline/utils/comparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""Functions for comparing pipelines."""

from typing import Any, TypeVar

from molpipeline import Pipeline
from molpipeline.utils.json_operations import recursive_to_json

_T = TypeVar("_T", list[Any], tuple[Any, ...], set[Any], dict[Any, Any], Any)


def remove_irrelevant_params(params: _T) -> _T:
"""Remove irrelevant parameters from a dictionary.
Parameters
----------
params : TypeVar
Parameters to remove irrelevant parameters from.
Returns
-------
TypeVar
Parameters without irrelevant parameters.
"""
if isinstance(params, list):
return [remove_irrelevant_params(val) for val in params]
if isinstance(params, tuple):
return tuple(remove_irrelevant_params(val) for val in params)
if isinstance(params, set):
return {remove_irrelevant_params(val) for val in params}

irrelevant_params = ["n_jobs", "uuid", "error_filter_id"]
if isinstance(params, dict):
params_new = {}
for key, value in params.items():
if key.split("__")[-1] in irrelevant_params:
continue
params_new[key] = remove_irrelevant_params(value)
return params_new
return params


def compare_recursive( # pylint: disable=too-many-return-statements
value_a: Any, value_b: Any
) -> bool:
"""Compare two values recursively.
Parameters
----------
value_a : Any
First value to compare.
value_b : Any
Second value to compare.
Returns
-------
bool
True if the values are the same, False otherwise.
"""
if value_a.__class__ != value_b.__class__:
return False

if isinstance(value_a, dict):
if set(value_a.keys()) != set(value_b.keys()):
return False
for key in value_a:
if not compare_recursive(value_a[key], value_b[key]):
return False
return True

if isinstance(value_a, (list, tuple)):
if len(value_a) != len(value_b):
return False
for val_a, val_b in zip(value_a, value_b):
if not compare_recursive(val_a, val_b):
return False
return True
return value_a == value_b


def check_pipelines_equivalent(pipeline_a: Pipeline, pipeline_b: Pipeline) -> bool:
"""Check if two pipelines are the same.
Parameters
----------
pipeline_a : Pipeline
Pipeline to compare.
pipeline_b : Pipeline
Pipeline to compare.
Returns
-------
bool
True if the pipelines are the same, False otherwise.
"""
if not isinstance(pipeline_a, Pipeline) or not isinstance(pipeline_b, Pipeline):
raise ValueError("Both inputs should be of type Pipeline.")
pipeline_json_a = recursive_to_json(pipeline_a)
pipeline_json_a = remove_irrelevant_params(pipeline_json_a)
pipeline_json_b = recursive_to_json(pipeline_b)
pipeline_json_b = remove_irrelevant_params(pipeline_json_b)
return compare_recursive(pipeline_json_a, pipeline_json_b)
35 changes: 35 additions & 0 deletions tests/test_utils/test_comparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Test the comparison functions."""

from unittest import TestCase

from molpipeline.utils.comparison import check_pipelines_equivalent
from tests.utils.default_models import get_morgan_physchem_rf_pipeline


class TestComparison(TestCase):
"""Test if functional equivalent pipelines are detected as such."""

def test_are_equal(self) -> None:
"""Test if two equivalent pipelines are detected as such."""

pipeline_a = get_morgan_physchem_rf_pipeline()
pipeline_b = get_morgan_physchem_rf_pipeline()
self.assertTrue(check_pipelines_equivalent(pipeline_a, pipeline_b))

def test_are_not_equal(self) -> None:
"""Test if two different pipelines are detected as such."""
# Test changed parameters
pipeline_a = get_morgan_physchem_rf_pipeline()
pipeline_b = get_morgan_physchem_rf_pipeline()
pipeline_b.set_params(mol2fp__morgan__n_bits=1024)
self.assertFalse(check_pipelines_equivalent(pipeline_a, pipeline_b))

# Test changed steps
pipeline_b = get_morgan_physchem_rf_pipeline()
last_step = pipeline_b.steps[-1]
pipeline_b.steps = pipeline_b.steps[:-1]
self.assertFalse(check_pipelines_equivalent(pipeline_a, pipeline_b))

# Test if adding the step back makes the pipelines equivalent
pipeline_b.steps.append(last_step)
self.assertTrue(check_pipelines_equivalent(pipeline_a, pipeline_b))
48 changes: 48 additions & 0 deletions tests/utils/default_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""This module contains the default models used for testing molpipeline functions and classes."""

from sklearn.ensemble import RandomForestClassifier

from molpipeline import Pipeline
from molpipeline.any2mol import SmilesToMol
from molpipeline.error_handling import ErrorFilter, FilterReinserter
from molpipeline.mol2any import (
MolToConcatenatedVector,
MolToMorganFP,
MolToRDKitPhysChem,
)
from molpipeline.post_prediction import PostPredictionWrapper


def get_morgan_physchem_rf_pipeline() -> Pipeline:
"""Get a pipeline combining Morgan fingerprints and physicochemical properties with a RandomForestClassifier.
Returns
-------
Pipeline
A pipeline combining Morgan fingerprints and physicochemical properties with a RandomForestClassifier.
"""
error_filter = ErrorFilter(filter_everything=True)
pipeline = Pipeline(
[
("smi2mol", SmilesToMol()),
(
"mol2fp",
MolToConcatenatedVector(
[
("morgan", MolToMorganFP(n_bits=2048)),
("physchem", MolToRDKitPhysChem()),
]
),
),
("error_filter", error_filter),
("rf", RandomForestClassifier()),
(
"filter_reinserter",
PostPredictionWrapper(
FilterReinserter.from_error_filter(error_filter, None)
),
),
],
n_jobs=1,
)
return pipeline

0 comments on commit ebd9887

Please sign in to comment.