diff --git a/molpipeline/estimators/chemprop/component_wrapper.py b/molpipeline/estimators/chemprop/component_wrapper.py index 00571d2a..f7182aa2 100644 --- a/molpipeline/estimators/chemprop/component_wrapper.py +++ b/molpipeline/estimators/chemprop/component_wrapper.py @@ -10,17 +10,7 @@ from chemprop.nn.agg import MeanAggregation as _MeanAggregation from chemprop.nn.agg import SumAggregation as _SumAggregation from chemprop.nn.ffn import MLP -from chemprop.nn.loss import ( - BCELoss, - BinaryDirichletLoss, - CrossEntropyLoss, - EvidentialLoss, - LossFunction, - MSELoss, - MulticlassDirichletLoss, - MVELoss, - SIDLoss, -) +from chemprop.nn.loss import LossFunction from chemprop.nn.message_passing import BondMessagePassing as _BondMessagePassing from chemprop.nn.message_passing import MessagePassing from chemprop.nn.metrics import ( @@ -48,6 +38,17 @@ from sklearn.base import BaseEstimator from torch import Tensor, nn +from molpipeline.estimators.chemprop.loss_wrapper import ( + BCELoss, + BinaryDirichletLoss, + CrossEntropyLoss, + EvidentialLoss, + MSELoss, + MulticlassDirichletLoss, + MVELoss, + SIDLoss, +) + # pylint: disable=too-many-ancestors, too-many-instance-attributes class BondMessagePassing(_BondMessagePassing, BaseEstimator): diff --git a/molpipeline/estimators/chemprop/loss_wrapper.py b/molpipeline/estimators/chemprop/loss_wrapper.py new file mode 100644 index 00000000..11ecde9c --- /dev/null +++ b/molpipeline/estimators/chemprop/loss_wrapper.py @@ -0,0 +1,104 @@ +"""Wrapper for Chemprop loss functions.""" + +from typing import Any + +import torch +from chemprop.nn.loss import BCELoss as _BCELoss +from chemprop.nn.loss import BinaryDirichletLoss as _BinaryDirichletLoss +from chemprop.nn.loss import CrossEntropyLoss as _CrossEntropyLoss +from chemprop.nn.loss import EvidentialLoss as _EvidentialLoss +from chemprop.nn.loss import LossFunction as _LossFunction +from chemprop.nn.loss import MSELoss as _MSELoss +from chemprop.nn.loss import MulticlassDirichletLoss as _MulticlassDirichletLoss +from chemprop.nn.loss import MVELoss as _MVELoss +from chemprop.nn.loss import SIDLoss as _SIDLoss +from numpy.typing import ArrayLike + + +class LossFunctionParamMixin: + """Mixin for loss functions to get and set parameters.""" + + _original_task_weights: ArrayLike + + def __init__(self: _LossFunction, task_weights: ArrayLike) -> None: + """Initialize the loss function. + + Parameters + ---------- + task_weights : ArrayLike + The weights for each task. + + """ + super().__init__(task_weights=task_weights) # type: ignore + self._original_task_weights = task_weights + + # pylint: disable=unused-argument + def get_params(self: _LossFunction, deep: bool = True) -> dict[str, Any]: + """Get the parameters of the loss function. + + Parameters + ---------- + deep : bool, optional + Not used, only present to match the sklearn API. + + Returns + ------- + dict[str, Any] + The parameters of the loss function. + """ + return {"task_weights": self._original_task_weights} + + def set_params(self: _LossFunction, **params: Any) -> _LossFunction: + """Set the parameters of the loss function. + + Parameters + ---------- + **params : Any + The parameters to set. + + Returns + ------- + Self + The loss function with the new parameters. 
+        """
+        task_weights = params.pop("task_weights", None)
+        if task_weights is not None:
+            self._original_task_weights = task_weights
+            state_dict = self.state_dict()
+            state_dict["task_weights"] = torch.as_tensor(
+                task_weights, dtype=torch.float
+            ).view(1, -1)
+            self.load_state_dict(state_dict)
+        return self
+
+
+class BCELoss(LossFunctionParamMixin, _BCELoss):
+    """Binary cross-entropy loss function."""
+
+
+class BinaryDirichletLoss(LossFunctionParamMixin, _BinaryDirichletLoss):
+    """Binary Dirichlet loss function."""
+
+
+class CrossEntropyLoss(LossFunctionParamMixin, _CrossEntropyLoss):
+    """Cross-entropy loss function."""
+
+
+class EvidentialLoss(LossFunctionParamMixin, _EvidentialLoss):
+    """Evidential loss function."""
+
+
+class MSELoss(LossFunctionParamMixin, _MSELoss):
+    """Mean squared error loss function."""
+
+
+class MulticlassDirichletLoss(LossFunctionParamMixin, _MulticlassDirichletLoss):
+    """Multiclass Dirichlet loss function."""
+
+
+class MVELoss(LossFunctionParamMixin, _MVELoss):
+    """Mean-variance estimation (MVE) loss function."""
+
+
+class SIDLoss(LossFunctionParamMixin, _SIDLoss):
+    """Spectral information divergence (SID) loss function."""
diff --git a/molpipeline/estimators/chemprop/neural_fingerprint.py b/molpipeline/estimators/chemprop/neural_fingerprint.py
index 5f04826f..cda5e9a9 100644
--- a/molpipeline/estimators/chemprop/neural_fingerprint.py
+++ b/molpipeline/estimators/chemprop/neural_fingerprint.py
@@ -1,6 +1,6 @@
 """Wrap Chemprop in a sklearn like transformer returning the neural fingerprint as a numpy array."""
 
-from typing import Self, Sequence
+from typing import Any, Self, Sequence
 
 import numpy as np
 import numpy.typing as npt
@@ -21,6 +21,7 @@ def __init__(
         batch_size: int = 64,
         n_jobs: int = 1,
         disable_fitting: bool = False,
+        **kwargs: Any,
     ) -> None:
         """Initialize the chemprop neural fingerprint model.
 
@@ -36,6 +37,8 @@ def __init__(
             The number of jobs to use.
         disable_fitting : bool, optional (default=False)
             Whether to allow fitting or set to fixed encoding.
+        **kwargs : Any
+            Parameters passed to the components of the model.
         """
         # pylint: disable=duplicate-code
         super().__init__(
@@ -43,6 +46,7 @@ def __init__(
             lightning_trainer=lightning_trainer,
             batch_size=batch_size,
             n_jobs=n_jobs,
+            **kwargs,
         )
         self.disable_fitting = disable_fitting
diff --git a/molpipeline/utils/json_operations.py b/molpipeline/utils/json_operations.py
index 2a7098a3..45fdeb62 100644
--- a/molpipeline/utils/json_operations.py
+++ b/molpipeline/utils/json_operations.py
@@ -8,6 +8,7 @@ from typing import Any
 
 from molpipeline.pipeline import Pipeline
+from molpipeline.utils.json_operations_torch import tensor_to_json
 
 __all__ = [
     "builtin_to_json",
@@ -258,6 +259,10 @@ def recursive_to_json(obj: Any) -> Any:
         for key, value in model_params.items():
             object_dict[key] = recursive_to_json(value)
     else:
+            obj_dict, success = tensor_to_json(obj)
+            # success is False if obj is not a tensor or torch is unavailable
+            if success:
+                return obj_dict
            # If the object is not a sklearn model, a warning is raised
            # as it might not be possible to recreate the object.
            warnings.warn(
diff --git a/molpipeline/utils/json_operations_torch.py b/molpipeline/utils/json_operations_torch.py
new file mode 100644
index 00000000..347fa06f
--- /dev/null
+++ b/molpipeline/utils/json_operations_torch.py
@@ -0,0 +1,60 @@
+"""Functions for serializing PyTorch tensors to JSON."""
+
+from typing import TypeVar
+
+try:
+    import torch
+
+    TORCH_AVAILABLE = True
+except ImportError:
+    TORCH_AVAILABLE = False
+from typing import Any, Literal
+
+_T = TypeVar("_T")
+
+if TORCH_AVAILABLE:
+
+    def tensor_to_json(
+        obj: _T,
+    ) -> tuple[dict[str, Any], Literal[True]] | tuple[_T, Literal[False]]:
+        """Convert a PyTorch tensor to a JSON-serializable dictionary.
+
+        Parameters
+        ----------
+        obj : object
+            The object to convert.
+
+        Returns
+        -------
+        tuple
+            The JSON-serializable dict and True if obj is a tensor, otherwise obj and False.
+        """
+        if isinstance(obj, torch.Tensor):
+            object_dict: dict[str, Any] = {
+                "__name__": obj.__class__.__name__,
+                "__module__": obj.__class__.__module__,
+                "__init__": True,
+            }
+        else:
+            return obj, False
+        object_dict["data"] = obj.tolist()
+        return object_dict, True
+
+else:
+
+    def tensor_to_json(
+        obj: _T,
+    ) -> tuple[dict[str, Any], Literal[True]] | tuple[_T, Literal[False]]:
+        """Return the object unchanged, as torch is not available.
+
+        Parameters
+        ----------
+        obj : object
+            The object to convert.
+
+        Returns
+        -------
+        tuple
+            The unchanged object and False.
+        """
+        return obj, False
diff --git a/test_extras/test_chemprop/chemprop_test_utils/__init__.py b/test_extras/test_chemprop/chemprop_test_utils/__init__.py
new file mode 100644
index 00000000..337f6c1b
--- /dev/null
+++ b/test_extras/test_chemprop/chemprop_test_utils/__init__.py
@@ -0,0 +1 @@
+"""Functions repeatedly used in tests for Chemprop models."""
diff --git a/test_extras/test_chemprop/chemprop_test_utils/compare_models.py b/test_extras/test_chemprop/chemprop_test_utils/compare_models.py
new file mode 100644
index 00000000..5718439c
--- /dev/null
+++ b/test_extras/test_chemprop/chemprop_test_utils/compare_models.py
@@ -0,0 +1,54 @@
+"""Functions for comparing chemprop models."""
+
+from typing import Sequence
+from unittest import TestCase
+
+import torch
+from chemprop.nn.loss import LossFunction
+from lightning.pytorch.accelerators import Accelerator
+from lightning.pytorch.profilers.base import PassThroughProfiler
+from sklearn.base import BaseEstimator
+from torch import nn
+
+
+def compare_params(
+    test_case: TestCase, model_a: BaseEstimator, model_b: BaseEstimator
+) -> None:
+    """Compare the parameters of two models.
+
+    Parameters
+    ----------
+    test_case : TestCase
+        The test case used to raise assertions on mismatches.
+    model_a : BaseEstimator
+        The first model.
+    model_b : BaseEstimator
+        The second model.
+ """ + model_a_params = model_a.get_params(deep=True) + model_b_params = model_b.get_params(deep=True) + test_case.assertSetEqual(set(model_a_params.keys()), set(model_b_params.keys())) + for param_name, param_a in model_a_params.items(): + param_b = model_b_params[param_name] + test_case.assertEqual(param_a.__class__, param_b.__class__) + if hasattr(param_a, "get_params"): + test_case.assertTrue(hasattr(param_b, "get_params")) + test_case.assertNotEqual(id(param_a), id(param_b)) + elif isinstance(param_a, LossFunction): + test_case.assertEqual( + param_a.state_dict()["task_weights"], + param_b.state_dict()["task_weights"], + ) + test_case.assertEqual(type(param_a), type(param_b)) + elif isinstance(param_a, (nn.Identity, Accelerator, PassThroughProfiler)): + test_case.assertEqual(type(param_a), type(param_b)) + elif isinstance(param_a, torch.Tensor): + test_case.assertTrue( + torch.equal(param_a, param_b), f"Test failed for {param_name}" + ) + elif param_name == "lightning_trainer__callbacks": + test_case.assertIsInstance(param_b, Sequence) + for i, callback in enumerate(param_a): + test_case.assertIsInstance(callback, type(param_b[i])) + else: + test_case.assertEqual(param_a, param_b, f"Test failed for {param_name}") diff --git a/test_extras/test_chemprop/chemprop_test_utils/constant_vars.py b/test_extras/test_chemprop/chemprop_test_utils/constant_vars.py new file mode 100644 index 00000000..89663866 --- /dev/null +++ b/test_extras/test_chemprop/chemprop_test_utils/constant_vars.py @@ -0,0 +1,96 @@ +"""Variables that are used in multiple tests.""" + +from chemprop.nn import BCELoss +from torch import Tensor, nn + +from molpipeline.estimators.chemprop.component_wrapper import ( + MPNN, + BinaryClassificationFFN, + BondMessagePassing, + SumAggregation, +) + +# These are model parameters which are copied by value, but are too complex to check for equality. +# Thus, for these model parameters, only the type is checked. +NO_IDENTITY_CHECK = [ + "model__agg", + "model__message_passing", + "model", + "model__predictor", + "model__predictor__criterion", + "model__predictor__output_transform", +] + +# Default parameters for the Chemprop model. 
+ +DEFAULT_PARAMS = { + "batch_size": 64, + "lightning_trainer": None, + "lightning_trainer__enable_checkpointing": False, + "lightning_trainer__enable_model_summary": False, + "lightning_trainer__max_epochs": 500, + "lightning_trainer__accelerator": "cpu", + "lightning_trainer__default_root_dir": None, + "lightning_trainer__limit_predict_batches": 1.0, + "lightning_trainer__detect_anomaly": False, + "lightning_trainer__reload_dataloaders_every_n_epochs": 0, + "lightning_trainer__precision": "32-true", + "lightning_trainer__min_steps": None, + "lightning_trainer__max_time": None, + "lightning_trainer__limit_train_batches": 1.0, + "lightning_trainer__strategy": "auto", + "lightning_trainer__gradient_clip_algorithm": None, + "lightning_trainer__log_every_n_steps": 50, + "lightning_trainer__limit_val_batches": 1.0, + "lightning_trainer__gradient_clip_val": None, + "lightning_trainer__overfit_batches": 0.0, + "lightning_trainer__num_nodes": 1, + "lightning_trainer__use_distributed_sampler": True, + "lightning_trainer__check_val_every_n_epoch": 1, + "lightning_trainer__benchmark": False, + "lightning_trainer__inference_mode": True, + "lightning_trainer__limit_test_batches": 1.0, + "lightning_trainer__fast_dev_run": False, + "lightning_trainer__logger": None, + "lightning_trainer__max_steps": -1, + "lightning_trainer__num_sanity_val_steps": 2, + "lightning_trainer__devices": "auto", + "lightning_trainer__min_epochs": None, + "lightning_trainer__val_check_interval": 1.0, + "lightning_trainer__barebones": False, + "lightning_trainer__accumulate_grad_batches": 1, + "lightning_trainer__deterministic": False, + "lightning_trainer__enable_progress_bar": True, + "model": MPNN, + "model__agg__dim": 0, + "model__agg": SumAggregation, + "model__batch_norm": True, + "model__final_lr": 0.0001, + "model__init_lr": 0.0001, + "model__max_lr": 0.001, + "model__message_passing__activation": "relu", + "model__message_passing__bias": False, + "model__message_passing__d_e": 14, + "model__message_passing__d_h": 300, + "model__message_passing__d_v": 72, + "model__message_passing__d_vd": None, + "model__message_passing__depth": 3, + "model__message_passing__dropout_rate": 0.0, + "model__message_passing__undirected": False, + "model__message_passing": BondMessagePassing, + "model__metric_list": None, + "model__predictor__activation": "relu", + "model__warmup_epochs": 2, + "model__predictor": BinaryClassificationFFN, + "model__predictor__criterion": BCELoss, + "model__predictor__criterion__task_weights": Tensor([1.0]), + "model__predictor__dropout": 0, + "model__predictor__hidden_dim": 300, + "model__predictor__input_dim": 300, + "model__predictor__n_layers": 1, + "model__predictor__n_tasks": 1, + "model__predictor__output_transform": nn.Identity, + "model__predictor__task_weights": Tensor([1.0]), + "model__predictor__threshold": None, + "n_jobs": 1, +} diff --git a/test_extras/test_chemprop/chemprop_test_utils/default_models.py b/test_extras/test_chemprop/chemprop_test_utils/default_models.py new file mode 100644 index 00000000..9f72db33 --- /dev/null +++ b/test_extras/test_chemprop/chemprop_test_utils/default_models.py @@ -0,0 +1,54 @@ +"""Functions for creating default chemprop models.""" + +from molpipeline.estimators.chemprop import ChempropModel, ChempropNeuralFP +from molpipeline.estimators.chemprop.component_wrapper import ( + MPNN, + BinaryClassificationFFN, + BondMessagePassing, + SumAggregation, +) + + +def get_binary_classification_mpnn() -> MPNN: + """Get a Chemprop model for binary classification. 
+
+    Returns
+    -------
+    MPNN
+        The MPNN model.
+    """
+    binary_clf_ffn = BinaryClassificationFFN()
+    aggregate = SumAggregation()
+    bond_message_passing = BondMessagePassing()
+    mpnn = MPNN(
+        message_passing=bond_message_passing,
+        agg=aggregate,
+        predictor=binary_clf_ffn,
+    )
+    return mpnn
+
+
+def get_neural_fp_encoder() -> ChempropNeuralFP:
+    """Get a Chemprop neural fingerprint encoder.
+
+    Returns
+    -------
+    ChempropNeuralFP
+        The neural fingerprint encoder.
+    """
+    mpnn = get_binary_classification_mpnn()
+    chemprop_model = ChempropNeuralFP(model=mpnn, lightning_trainer__accelerator="cpu")
+    return chemprop_model
+
+
+def get_chemprop_model_binary_classification_mpnn() -> ChempropModel:
+    """Get a Chemprop model for binary classification.
+
+    Returns
+    -------
+    ChempropModel
+        The Chemprop model.
+    """
+    mpnn = get_binary_classification_mpnn()
+    chemprop_model = ChempropModel(model=mpnn, lightning_trainer__accelerator="cpu")
+    return chemprop_model
diff --git a/test_extras/test_chemprop/test_models.py b/test_extras/test_chemprop/test_models.py
index 639b8c4d..66c1c94e 100644
--- a/test_extras/test_chemprop/test_models.py
+++ b/test_extras/test_chemprop/test_models.py
@@ -2,22 +2,16 @@
 
 import logging
 import unittest
-from typing import Iterable, Sequence
+from typing import Iterable
 
 import torch
-from chemprop.nn.loss import BCELoss, LossFunction, MSELoss
-from lightning.pytorch.accelerators import Accelerator
-from lightning.pytorch.profilers.base import PassThroughProfiler
+from chemprop.nn.loss import MSELoss
 from sklearn.base import clone
-from torch import Tensor, nn
 
 from molpipeline.estimators.chemprop.component_wrapper import (
     MPNN,
-    BinaryClassificationFFN,
-    BondMessagePassing,
     MeanAggregation,
     RegressionFFN,
-    SumAggregation,
 )
 from molpipeline.estimators.chemprop.models import (
     ChempropClassifier,
@@ -27,108 +21,14 @@ from molpipeline.estimators.chemprop.neural_fingerprint import ChempropNeuralFP
 from molpipeline.utils.json_operations import recursive_from_json, recursive_to_json
 
-logging.getLogger("lightning.pytorch.utilities.rank_zero").setLevel(logging.WARNING)
-
-
-def get_model() -> ChempropModel:
-    """Get the Chemprop model.
-
-    Returns
-    -------
-    ChempropModel
-        The Chemprop model.
- """ - binary_clf_ffn = BinaryClassificationFFN() - aggregate = SumAggregation() - bond_message_passing = BondMessagePassing() - mpnn = MPNN( - message_passing=bond_message_passing, - agg=aggregate, - predictor=binary_clf_ffn, - ) - chemprop_model = ChempropModel(model=mpnn, lightning_trainer__accelerator="cpu") - return chemprop_model - - -DEFAULT_PARAMS = { - "batch_size": 64, - "lightning_trainer": None, - "lightning_trainer__enable_checkpointing": False, - "lightning_trainer__enable_model_summary": False, - "lightning_trainer__max_epochs": 500, - "lightning_trainer__accelerator": "cpu", - "lightning_trainer__default_root_dir": None, - "lightning_trainer__limit_predict_batches": 1.0, - "lightning_trainer__detect_anomaly": False, - "lightning_trainer__reload_dataloaders_every_n_epochs": 0, - "lightning_trainer__precision": "32-true", - "lightning_trainer__min_steps": None, - "lightning_trainer__max_time": None, - "lightning_trainer__limit_train_batches": 1.0, - "lightning_trainer__strategy": "auto", - "lightning_trainer__gradient_clip_algorithm": None, - "lightning_trainer__log_every_n_steps": 50, - "lightning_trainer__limit_val_batches": 1.0, - "lightning_trainer__gradient_clip_val": None, - "lightning_trainer__overfit_batches": 0.0, - "lightning_trainer__num_nodes": 1, - "lightning_trainer__use_distributed_sampler": True, - "lightning_trainer__check_val_every_n_epoch": 1, - "lightning_trainer__benchmark": False, - "lightning_trainer__inference_mode": True, - "lightning_trainer__limit_test_batches": 1.0, - "lightning_trainer__fast_dev_run": False, - "lightning_trainer__logger": None, - "lightning_trainer__max_steps": -1, - "lightning_trainer__num_sanity_val_steps": 2, - "lightning_trainer__devices": "auto", - "lightning_trainer__min_epochs": None, - "lightning_trainer__val_check_interval": 1.0, - "lightning_trainer__barebones": False, - "lightning_trainer__accumulate_grad_batches": 1, - "lightning_trainer__deterministic": False, - "lightning_trainer__enable_progress_bar": True, - "model": MPNN, - "model__agg__dim": 0, - "model__agg": SumAggregation, - "model__batch_norm": True, - "model__final_lr": 0.0001, - "model__init_lr": 0.0001, - "model__max_lr": 0.001, - "model__message_passing__activation": "relu", - "model__message_passing__bias": False, - "model__message_passing__d_e": 14, - "model__message_passing__d_h": 300, - "model__message_passing__d_v": 72, - "model__message_passing__d_vd": None, - "model__message_passing__depth": 3, - "model__message_passing__dropout_rate": 0.0, - "model__message_passing__undirected": False, - "model__message_passing": BondMessagePassing, - "model__metric_list": None, - "model__predictor__activation": "relu", - "model__warmup_epochs": 2, - "model__predictor": BinaryClassificationFFN, - "model__predictor__criterion": BCELoss, - "model__predictor__dropout": 0, - "model__predictor__hidden_dim": 300, - "model__predictor__input_dim": 300, - "model__predictor__n_layers": 1, - "model__predictor__n_tasks": 1, - "model__predictor__output_transform": nn.Identity, - "model__predictor__task_weights": Tensor([1.0]), - "model__predictor__threshold": None, - "n_jobs": 1, -} +# pylint: disable=relative-beyond-top-level +from .chemprop_test_utils.compare_models import compare_params +from .chemprop_test_utils.constant_vars import DEFAULT_PARAMS, NO_IDENTITY_CHECK +from .chemprop_test_utils.default_models import ( + get_chemprop_model_binary_classification_mpnn, +) -NO_IDENTITY_CHECK = [ - "model__agg", - "model__message_passing", - "model", - "model__predictor", - 
"model__predictor__criterion", - "model__predictor__output_transform", -] +logging.getLogger("lightning.pytorch.utilities.rank_zero").setLevel(logging.WARNING) class TestChempropModel(unittest.TestCase): @@ -136,7 +36,7 @@ class TestChempropModel(unittest.TestCase): def test_get_params(self) -> None: """Test the get_params and set_params methods.""" - chemprop_model = get_model() + chemprop_model = get_chemprop_model_binary_classification_mpnn() orig_params = chemprop_model.get_params(deep=True) expected_params = dict(DEFAULT_PARAMS) # Shallow copy @@ -174,33 +74,14 @@ def test_get_params(self) -> None: def test_clone(self) -> None: """Test the clone method.""" - chemprop_model = get_model() + chemprop_model = get_chemprop_model_binary_classification_mpnn() cloned_model = clone(chemprop_model) self.assertIsInstance(cloned_model, ChempropModel) - cloned_model_params = cloned_model.get_params(deep=True) - for param_name, param in chemprop_model.get_params(deep=True).items(): - cloned_param = cloned_model_params[param_name] - if hasattr(param, "get_params"): - self.assertEqual(param.__class__, cloned_param.__class__) - self.assertNotEqual(id(param), id(cloned_param)) - elif isinstance(param, LossFunction): - self.assertEqual( - param.state_dict()["task_weights"], - cloned_param.state_dict()["task_weights"], - ) - self.assertEqual(type(param), type(cloned_param)) - elif isinstance(param, (nn.Identity, Accelerator, PassThroughProfiler)): - self.assertEqual(type(param), type(cloned_param)) - elif param_name == "lightning_trainer__callbacks": - self.assertIsInstance(cloned_param, Sequence) - for i, callback in enumerate(param): - self.assertIsInstance(callback, type(cloned_param[i])) - else: - self.assertEqual(param, cloned_param, f"Test failed for {param_name}") + compare_params(self, chemprop_model, cloned_model) def test_classifier_methods(self) -> None: """Test the classifier methods.""" - chemprop_model = get_model() + chemprop_model = get_chemprop_model_binary_classification_mpnn() # pylint: disable=protected-access self.assertTrue(chemprop_model._is_binary_classifier()) self.assertFalse(chemprop_model._is_multiclass_classifier()) @@ -209,7 +90,7 @@ def test_classifier_methods(self) -> None: def test_neural_fp(self) -> None: """Test the to_encoder method.""" - chemprop_model = get_model() + chemprop_model = get_chemprop_model_binary_classification_mpnn() neural_fp = chemprop_model.to_encoder() self.assertIsInstance(neural_fp, ChempropNeuralFP) self.assertIsInstance(neural_fp.model, MPNN) @@ -219,7 +100,7 @@ def test_neural_fp(self) -> None: def test_json_serialization(self) -> None: """Test the to_json and from_json methods.""" - chemprop_model = get_model() + chemprop_model = get_chemprop_model_binary_classification_mpnn() chemprop_json = recursive_to_json(chemprop_model) chemprop_model_copy = recursive_from_json(chemprop_json) param_dict = chemprop_model_copy.get_params(deep=True) diff --git a/test_extras/test_chemprop/test_neural_fingerprint.py b/test_extras/test_chemprop/test_neural_fingerprint.py new file mode 100644 index 00000000..3a4f13b3 --- /dev/null +++ b/test_extras/test_chemprop/test_neural_fingerprint.py @@ -0,0 +1,33 @@ +"""Test Chemprop neural fingerprint.""" + +import logging +import unittest + +from sklearn.base import clone + +from molpipeline.estimators.chemprop.neural_fingerprint import ChempropNeuralFP +from molpipeline.utils.json_operations import recursive_from_json, recursive_to_json + +# pylint: disable=relative-beyond-top-level +from 
.chemprop_test_utils.compare_models import compare_params
+from .chemprop_test_utils.default_models import get_neural_fp_encoder
+
+logging.getLogger("lightning.pytorch.utilities.rank_zero").setLevel(logging.WARNING)
+
+
+class TestChempropNeuralFingerprint(unittest.TestCase):
+    """Test the Chemprop neural fingerprint transformer."""
+
+    def test_clone(self) -> None:
+        """Test the clone method."""
+        chemprop_fp_encoder = get_neural_fp_encoder()
+        cloned_encoder = clone(chemprop_fp_encoder)
+        self.assertIsInstance(cloned_encoder, ChempropNeuralFP)
+        compare_params(self, chemprop_fp_encoder, cloned_encoder)
+
+    def test_json_serialization(self) -> None:
+        """Test the to_json and from_json methods."""
+        chemprop_fp_encoder = get_neural_fp_encoder()
+        chemprop_json = recursive_to_json(chemprop_fp_encoder)
+        chemprop_encoder_copy = recursive_from_json(chemprop_json)
+        compare_params(self, chemprop_fp_encoder, chemprop_encoder_copy)
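
Below is a minimal usage sketch of the new loss wrappers and the tensor serialization hook introduced above. It is illustrative only and not part of the diff; it assumes torch and chemprop are installed and uses the module paths added in this changeset.

```python
"""Sketch: exercising LossFunctionParamMixin and tensor JSON serialization."""

import torch

from molpipeline.estimators.chemprop.loss_wrapper import BCELoss
from molpipeline.utils.json_operations import recursive_to_json

# The mixin stores the original task_weights so that get_params/set_params
# behave like sklearn estimator parameters.
loss = BCELoss(task_weights=[1.0, 2.0])
print(loss.get_params())  # {'task_weights': [1.0, 2.0]}

# set_params also refreshes the task_weights buffer in the state dict.
loss.set_params(task_weights=[3.0, 1.0])
print(loss.state_dict()["task_weights"])  # tensor([[3., 1.]])

# Tensors reaching the fallback branch of recursive_to_json are converted by
# tensor_to_json into a plain dict with '__name__', '__module__', '__init__'
# and 'data' keys instead of triggering the "not serializable" warning.
print(recursive_to_json(torch.tensor([1.0, 2.0])))
```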