Pull request #32: 31 add kwargs to chempropneuralfp init
Status: Merged (20 commits, merged Jun 21, 2024)
Showing changes from 7 commits
6 changes: 5 additions & 1 deletion molpipeline/estimators/chemprop/neural_fingerprint.py
@@ -1,6 +1,6 @@
"""Wrap Chemprop in a sklearn like transformer returning the neural fingerprint as a numpy array."""

from typing import Self, Sequence
from typing import Any, Self, Sequence

import numpy as np
import numpy.typing as npt
@@ -21,6 +21,7 @@ def __init__(
batch_size: int = 64,
n_jobs: int = 1,
disable_fitting: bool = False,
**kwargs: Any,
) -> None:
"""Initialize the chemprop neural fingerprint model.

@@ -36,13 +37,16 @@
The number of jobs to use.
disable_fitting : bool, optional (default=False)
Whether to allow fitting or set to fixed encoding.
**kwargs: Any
Parameters for components of the model.
"""
# pylint: disable=duplicate-code
super().__init__(
model=model,
lightning_trainer=lightning_trainer,
batch_size=batch_size,
n_jobs=n_jobs,
**kwargs,
)
self.disable_fitting = disable_fitting

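For orientation, a minimal usage sketch of the new keyword pass-through (not part of the diff): it mirrors the constructor call used in test_neural_fingerprint.py further down, where the nested lightning_trainer__accelerator parameter is accepted via the new **kwargs and forwarded to the parent estimator. The MPNN components come from molpipeline's component_wrapper module, as in the new default_models.py helper.

from molpipeline.estimators.chemprop.component_wrapper import (
    MPNN,
    BinaryClassificationFFN,
    BondMessagePassing,
    SumAggregation,
)
from molpipeline.estimators.chemprop.neural_fingerprint import ChempropNeuralFP

# Build a default binary-classification MPNN, as done in default_models.py.
mpnn = MPNN(
    message_passing=BondMessagePassing(),
    agg=SumAggregation(),
    predictor=BinaryClassificationFFN(),
)

# The nested parameter below is collected by the new **kwargs and passed on
# to super().__init__(), so no explicit lightning trainer object is needed.
encoder = ChempropNeuralFP(model=mpnn, lightning_trainer__accelerator="cpu")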
1 change: 1 addition & 0 deletions test_extras/test_chemprop/chemprop_test_utils/__init__.py
@@ -0,0 +1 @@
"""Functions repeatedly used in tests for Chemprop models."""
49 changes: 49 additions & 0 deletions test_extras/test_chemprop/chemprop_test_utils/compare_models.py
@@ -0,0 +1,49 @@
"""Functions for comparing chemprop models."""

from typing import Sequence
from unittest import TestCase

from chemprop.nn.loss import LossFunction
from lightning.pytorch.accelerators import Accelerator
from lightning.pytorch.profilers.base import PassThroughProfiler
from sklearn.base import BaseEstimator
from torch import nn


def compare_params(
test_case: TestCase, model_a: BaseEstimator, model_b: BaseEstimator
) -> None:
"""Compare the parameters of two models.

Parameters
----------
test_case : TestCase
The test case for which to raise the assertion.
model_a : BaseEstimator
The first model.
model_b : BaseEstimator
The second model.
"""
model_a_params = model_a.get_params(deep=True)
model_b_params = model_b.get_params(deep=True)
test_case.assertSetEqual(set(model_a_params.keys()), set(model_b_params.keys()))
for param_name, param_a in model_a_params.items():
param_b = model_b_params[param_name]
test_case.assertEqual(param_a.__class__, param_b.__class__)
if hasattr(param_a, "get_params"):
test_case.assertTrue(hasattr(param_b, "get_params"))
test_case.assertNotEqual(id(param_a), id(param_b))
elif isinstance(param_a, LossFunction):
test_case.assertEqual(
param_a.state_dict()["task_weights"],
param_b.state_dict()["task_weights"],
)
test_case.assertEqual(type(param_a), type(param_b))
elif isinstance(param_a, (nn.Identity, Accelerator, PassThroughProfiler)):
test_case.assertEqual(type(param_a), type(param_b))
elif param_name == "lightning_trainer__callbacks":
test_case.assertIsInstance(param_b, Sequence)
for i, callback in enumerate(param_a):
test_case.assertIsInstance(callback, type(param_b[i]))
else:
test_case.assertEqual(param_a, param_b, f"Test failed for {param_name}")
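A hedged sketch of the intended call pattern for compare_params (the real call sites are in test_models.py and test_neural_fingerprint.py below). The test class name is hypothetical, and the helper imports are written as absolute paths here, whereas the test package itself uses relative imports.

import unittest

from sklearn.base import clone

from molpipeline.estimators.chemprop.neural_fingerprint import ChempropNeuralFP

# Relative imports (from .chemprop_test_utils ...) in the real test package.
from chemprop_test_utils.compare_models import compare_params
from chemprop_test_utils.default_models import get_classification_mpnn


class CloneComparisonExample(unittest.TestCase):  # hypothetical test class
    def test_clone_keeps_params(self) -> None:
        encoder = ChempropNeuralFP(model=get_classification_mpnn())
        cloned = clone(encoder)
        # Passing the TestCase lets the helper raise assertions on this test.
        compare_params(self, encoder, cloned)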
10 changes: 10 additions & 0 deletions test_extras/test_chemprop/chemprop_test_utils/constant_vars.py
@@ -0,0 +1,10 @@
"""Variables that are used in multiple tests."""

NO_IDENTITY_CHECK = [
"model__agg",
"model__message_passing",
"model",
"model__predictor",
"model__predictor__criterion",
"model__predictor__output_transform",
]
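The double-underscore names above are sklearn deep-parameter paths: get_params(deep=True) flattens nested components into keys such as model__predictor__criterion. A small sketch of how the list is consumed (imports written as absolute paths; the test package uses relative imports):

from molpipeline.estimators.chemprop.neural_fingerprint import ChempropNeuralFP

from chemprop_test_utils.constant_vars import NO_IDENTITY_CHECK
from chemprop_test_utils.default_models import get_classification_mpnn

encoder = ChempropNeuralFP(model=get_classification_mpnn())
deep_params = encoder.get_params(deep=True)

# These entries are freshly constructed components after cloning or a JSON
# round trip, so the tests compare only their types, not object identity.
print(sorted(name for name in deep_params if name in NO_IDENTITY_CHECK))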
27 changes: 27 additions & 0 deletions test_extras/test_chemprop/chemprop_test_utils/default_models.py
@@ -0,0 +1,27 @@
"""Functions for creating default chemprop models."""

from molpipeline.estimators.chemprop.component_wrapper import (
MPNN,
BinaryClassificationFFN,
BondMessagePassing,
SumAggregation,
)


def get_classification_mpnn() -> MPNN:
"""Get a Chemprop model for classification.

Returns
-------
MPNN
The Chemprop MPNN model.
"""
binary_clf_ffn = BinaryClassificationFFN()
aggregate = SumAggregation()
bond_message_passing = BondMessagePassing()
mpnn = MPNN(
message_passing=bond_message_passing,
agg=aggregate,
predictor=binary_clf_ffn,
)
return mpnn
50 changes: 9 additions & 41 deletions test_extras/test_chemprop/test_models.py
@@ -2,12 +2,10 @@

import logging
import unittest
from typing import Iterable, Sequence
from typing import Iterable

import torch
from chemprop.nn.loss import BCELoss, LossFunction, MSELoss
from lightning.pytorch.accelerators import Accelerator
from lightning.pytorch.profilers.base import PassThroughProfiler
from chemprop.nn.loss import BCELoss, MSELoss
from sklearn.base import clone
from torch import Tensor, nn

@@ -27,6 +25,11 @@
from molpipeline.estimators.chemprop.neural_fingerprint import ChempropNeuralFP
from molpipeline.utils.json_operations import recursive_from_json, recursive_to_json

# pylint: disable=relative-beyond-top-level
from .chemprop_test_utils.compare_models import compare_params
from .chemprop_test_utils.constant_vars import NO_IDENTITY_CHECK
from .chemprop_test_utils.default_models import get_classification_mpnn

logging.getLogger("lightning.pytorch.utilities.rank_zero").setLevel(logging.WARNING)


@@ -38,14 +41,7 @@ def get_model() -> ChempropModel:
ChempropModel
The Chemprop model.
"""
binary_clf_ffn = BinaryClassificationFFN()
aggregate = SumAggregation()
bond_message_passing = BondMessagePassing()
mpnn = MPNN(
message_passing=bond_message_passing,
agg=aggregate,
predictor=binary_clf_ffn,
)
mpnn = get_classification_mpnn()
chemprop_model = ChempropModel(model=mpnn, lightning_trainer__accelerator="cpu")
return chemprop_model

@@ -121,15 +117,6 @@ def get_model() -> ChempropModel:
"n_jobs": 1,
}

NO_IDENTITY_CHECK = [
"model__agg",
"model__message_passing",
"model",
"model__predictor",
"model__predictor__criterion",
"model__predictor__output_transform",
]


class TestChempropModel(unittest.TestCase):
"""Test the Chemprop model."""
@@ -177,26 +164,7 @@ def test_clone(self) -> None:
chemprop_model = get_model()
cloned_model = clone(chemprop_model)
self.assertIsInstance(cloned_model, ChempropModel)
cloned_model_params = cloned_model.get_params(deep=True)
for param_name, param in chemprop_model.get_params(deep=True).items():
cloned_param = cloned_model_params[param_name]
if hasattr(param, "get_params"):
self.assertEqual(param.__class__, cloned_param.__class__)
self.assertNotEqual(id(param), id(cloned_param))
elif isinstance(param, LossFunction):
self.assertEqual(
param.state_dict()["task_weights"],
cloned_param.state_dict()["task_weights"],
)
self.assertEqual(type(param), type(cloned_param))
elif isinstance(param, (nn.Identity, Accelerator, PassThroughProfiler)):
self.assertEqual(type(param), type(cloned_param))
elif param_name == "lightning_trainer__callbacks":
self.assertIsInstance(cloned_param, Sequence)
for i, callback in enumerate(param):
self.assertIsInstance(callback, type(cloned_param[i]))
else:
self.assertEqual(param, cloned_param, f"Test failed for {param_name}")
compare_params(self, chemprop_model, cloned_model)

def test_classifier_methods(self) -> None:
"""Test the classifier methods."""
64 changes: 64 additions & 0 deletions test_extras/test_chemprop/test_neural_fingerprint.py
@@ -0,0 +1,64 @@
"""Test Chemprop neural fingerprint."""

import logging
import unittest
from typing import Iterable

import torch
from sklearn.base import clone

from molpipeline.estimators.chemprop.neural_fingerprint import ChempropNeuralFP
from molpipeline.utils.json_operations import recursive_from_json, recursive_to_json

# pylint: disable=relative-beyond-top-level
from .chemprop_test_utils.compare_models import compare_params
from .chemprop_test_utils.constant_vars import NO_IDENTITY_CHECK
from .chemprop_test_utils.default_models import get_classification_mpnn

logging.getLogger("lightning.pytorch.utilities.rank_zero").setLevel(logging.WARNING)


def get_neural_fp_encoder() -> ChempropNeuralFP:
"""Get the Chemprop model.

Returns
-------
ChempropNeuralFP
The Chemprop model.
"""
mpnn = get_classification_mpnn()
chemprop_model = ChempropNeuralFP(model=mpnn, lightning_trainer__accelerator="cpu")
return chemprop_model


class TestChempropNeuralFingerprint(unittest.TestCase):
"""Test the Chemprop model."""

def test_clone(self) -> None:
"""Test the clone method."""
chemprop_fp_encoder = get_neural_fp_encoder()
cloned_encoder = clone(chemprop_fp_encoder)
self.assertIsInstance(cloned_encoder, ChempropNeuralFP)
compare_params(self, chemprop_fp_encoder, cloned_encoder)

def test_json_serialization(self) -> None:
"""Test the to_json and from_json methods."""
chemprop_fp_encoder = get_neural_fp_encoder()
chemprop_json = recursive_to_json(chemprop_fp_encoder)
chemprop_encoder_copy = recursive_from_json(chemprop_json)
original_params = chemprop_fp_encoder.get_params(deep=True)
recreated_params = chemprop_encoder_copy.get_params(deep=True)

self.assertSetEqual(set(original_params.keys()), set(recreated_params.keys()))
for param_name, param in original_params.items():
if param_name in NO_IDENTITY_CHECK:
self.assertIsInstance(recreated_params[param_name], type(param))
if isinstance(param, Iterable):
for i, p in enumerate(param):
self.assertIsInstance(recreated_params[param_name][i], type(p))
elif param_name == "model__predictor__task_weights":
self.assertTrue(torch.allclose(param, recreated_params[param_name]))
else:
self.assertEqual(
recreated_params[param_name], param, f"Test failed for {param_name}"
)