-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add functionality to compare pipelines (#95)
* Add functionality to compare pipelines * Add tests to verify functionalities
- Loading branch information
1 parent
21b1f10
commit ebd9887
Showing
3 changed files
with
184 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
"""Functions for comparing pipelines.""" | ||
|
||
import math
from typing import Any, TypeVar

from molpipeline import Pipeline
from molpipeline.utils.json_operations import recursive_to_json
|
||
_T = TypeVar("_T", list[Any], tuple[Any, ...], set[Any], dict[Any, Any], Any)

# Parameter names that are treated as irrelevant and dropped from dictionaries.
_IRRELEVANT_PARAMS = frozenset({"n_jobs", "uuid", "error_filter_id"})


def remove_irrelevant_params(params: _T) -> _T:
    """Recursively remove irrelevant parameters from a nested structure.

    Lists, tuples and sets are rebuilt with every element processed
    recursively. Dictionaries are rebuilt without the entries whose key, or
    the last ``__``-separated component of the key, is one of ``n_jobs``,
    ``uuid`` or ``error_filter_id``; the remaining values are processed
    recursively. Any other object is returned unchanged.

    Parameters
    ----------
    params : _T
        Arbitrarily nested structure to remove irrelevant parameters from.

    Returns
    -------
    _T
        Structure without the irrelevant parameters. Containers are newly
        created, the input itself is not modified.
    """
    if isinstance(params, list):
        return [remove_irrelevant_params(val) for val in params]
    if isinstance(params, tuple):
        return tuple(remove_irrelevant_params(val) for val in params)
    if isinstance(params, set):
        return {remove_irrelevant_params(val) for val in params}
    if isinstance(params, dict):
        params_new = {}
        for key, value in params.items():
            # Only string keys can name a parameter. Keys of any other type
            # (e.g. integers) have no ``split`` method and are always kept.
            if isinstance(key, str) and key.rsplit("__", 1)[-1] in _IRRELEVANT_PARAMS:
                continue
            params_new[key] = remove_irrelevant_params(value)
        return params_new
    return params
|
||
|
||
def compare_recursive(  # pylint: disable=too-many-return-statements
    value_a: Any, value_b: Any
) -> bool:
    """Compare two values recursively.

    Values of different classes are never considered the same. Dictionaries
    are the same if they have the same keys and the values of each key are
    the same. Lists and tuples are the same if they have the same length and
    all elements are pairwise the same. Two NaN floats are considered the
    same. Everything else is compared with ``==``.

    Parameters
    ----------
    value_a : Any
        First value to compare.
    value_b : Any
        Second value to compare.

    Returns
    -------
    bool
        True if the values are the same, False otherwise.
    """
    if value_a.__class__ != value_b.__class__:
        return False

    if isinstance(value_a, dict):
        if set(value_a.keys()) != set(value_b.keys()):
            return False
        return all(compare_recursive(value_a[key], value_b[key]) for key in value_a)

    if isinstance(value_a, (list, tuple)):
        if len(value_a) != len(value_b):
            return False
        return all(
            compare_recursive(val_a, val_b) for val_a, val_b in zip(value_a, value_b)
        )

    # NaN never equals itself under ``==``, so identical structures holding
    # NaN would be reported as different. value_b has the same class as
    # value_a at this point, hence it is a float as well.
    if isinstance(value_a, float) and math.isnan(value_a):
        return math.isnan(value_b)
    return value_a == value_b
|
||
|
||
def check_pipelines_equivalent(pipeline_a: Pipeline, pipeline_b: Pipeline) -> bool:
    """Decide whether two pipelines are equivalent.

    Both pipelines are converted with ``recursive_to_json``, cleaned with
    ``remove_irrelevant_params`` and the results are compared with
    ``compare_recursive``.

    Parameters
    ----------
    pipeline_a : Pipeline
        First pipeline of the comparison.
    pipeline_b : Pipeline
        Second pipeline of the comparison.

    Returns
    -------
    bool
        True if the pipelines are the same, False otherwise.

    Raises
    ------
    ValueError
        If at least one of the inputs is not a Pipeline.
    """
    for candidate in (pipeline_a, pipeline_b):
        if not isinstance(candidate, Pipeline):
            raise ValueError("Both inputs should be of type Pipeline.")

    # Serialise and clean each pipeline in turn before comparing the results.
    cleaned_a, cleaned_b = (
        remove_irrelevant_params(recursive_to_json(pipeline))
        for pipeline in (pipeline_a, pipeline_b)
    )
    return compare_recursive(cleaned_a, cleaned_b)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
"""Test the comparison functions.""" | ||
|
||
from unittest import TestCase | ||
|
||
from molpipeline.utils.comparison import check_pipelines_equivalent | ||
from tests.utils.default_models import get_morgan_physchem_rf_pipeline | ||
|
||
|
||
class TestComparison(TestCase):
    """Verify that equivalent and differing pipelines are told apart."""

    def test_are_equal(self) -> None:
        """Two independently built default pipelines are equivalent."""
        first, second = (get_morgan_physchem_rf_pipeline() for _ in range(2))
        self.assertTrue(check_pipelines_equivalent(first, second))

    def test_are_not_equal(self) -> None:
        """Changed parameters or steps make pipelines differ."""
        reference = get_morgan_physchem_rf_pipeline()

        # A changed parameter must be detected.
        modified = get_morgan_physchem_rf_pipeline()
        modified.set_params(mol2fp__morgan__n_bits=1024)
        self.assertFalse(check_pipelines_equivalent(reference, modified))

        # A missing final step must be detected.
        shortened = get_morgan_physchem_rf_pipeline()
        *leading_steps, removed_step = shortened.steps
        shortened.steps = leading_steps
        self.assertFalse(check_pipelines_equivalent(reference, shortened))

        # Restoring the removed step makes the pipelines equivalent again.
        shortened.steps.append(removed_step)
        self.assertTrue(check_pipelines_equivalent(reference, shortened))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
"""This module contains the default models used for testing molpipeline functions and classes.""" | ||
|
||
from sklearn.ensemble import RandomForestClassifier | ||
|
||
from molpipeline import Pipeline | ||
from molpipeline.any2mol import SmilesToMol | ||
from molpipeline.error_handling import ErrorFilter, FilterReinserter | ||
from molpipeline.mol2any import ( | ||
MolToConcatenatedVector, | ||
MolToMorganFP, | ||
MolToRDKitPhysChem, | ||
) | ||
from molpipeline.post_prediction import PostPredictionWrapper | ||
|
||
|
||
def get_morgan_physchem_rf_pipeline() -> Pipeline:
    """Build the default Morgan + PhysChem random forest pipeline used in tests.

    The steps are, in order: ``smi2mol``, ``mol2fp`` (concatenation of a
    2048 bit Morgan fingerprint and RDKit physicochemical properties),
    ``error_filter``, ``rf`` and ``filter_reinserter``. The reinserter is
    created from the same error filter instance that is used as a step.

    Returns
    -------
    Pipeline
        A pipeline combining Morgan fingerprints and physicochemical
        properties with a RandomForestClassifier.
    """
    error_filter = ErrorFilter(filter_everything=True)

    smiles_reader = SmilesToMol()
    featurizer = MolToConcatenatedVector(
        [
            ("morgan", MolToMorganFP(n_bits=2048)),
            ("physchem", MolToRDKitPhysChem()),
        ]
    )
    classifier = RandomForestClassifier()
    reinserter = PostPredictionWrapper(
        FilterReinserter.from_error_filter(error_filter, None)
    )

    steps = [
        ("smi2mol", smiles_reader),
        ("mol2fp", featurizer),
        ("error_filter", error_filter),
        ("rf", classifier),
        ("filter_reinserter", reinserter),
    ]
    return Pipeline(steps, n_jobs=1)