Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

31 add kwargs to chempropneuralfp init #32

Merged
merged 20 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions molpipeline/estimators/chemprop/component_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,7 @@
from chemprop.nn.agg import MeanAggregation as _MeanAggregation
from chemprop.nn.agg import SumAggregation as _SumAggregation
from chemprop.nn.ffn import MLP
from chemprop.nn.loss import (
BCELoss,
BinaryDirichletLoss,
CrossEntropyLoss,
EvidentialLoss,
LossFunction,
MSELoss,
MulticlassDirichletLoss,
MVELoss,
SIDLoss,
)
from chemprop.nn.loss import LossFunction
from chemprop.nn.message_passing import BondMessagePassing as _BondMessagePassing
from chemprop.nn.message_passing import MessagePassing
from chemprop.nn.metrics import (
Expand Down Expand Up @@ -48,6 +38,17 @@
from sklearn.base import BaseEstimator
from torch import Tensor, nn

from molpipeline.estimators.chemprop.loss_wrapper import (
BCELoss,
BinaryDirichletLoss,
CrossEntropyLoss,
EvidentialLoss,
MSELoss,
MulticlassDirichletLoss,
MVELoss,
SIDLoss,
)


# pylint: disable=too-many-ancestors, too-many-instance-attributes
class BondMessagePassing(_BondMessagePassing, BaseEstimator):
Expand Down
104 changes: 104 additions & 0 deletions molpipeline/estimators/chemprop/loss_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""Wrapper for Chemprop loss functions."""

from typing import Any

import torch
from chemprop.nn.loss import BCELoss as _BCELoss
from chemprop.nn.loss import BinaryDirichletLoss as _BinaryDirichletLoss
from chemprop.nn.loss import CrossEntropyLoss as _CrossEntropyLoss
from chemprop.nn.loss import EvidentialLoss as _EvidentialLoss
from chemprop.nn.loss import LossFunction as _LossFunction
from chemprop.nn.loss import MSELoss as _MSELoss
from chemprop.nn.loss import MulticlassDirichletLoss as _MulticlassDirichletLoss
from chemprop.nn.loss import MVELoss as _MVELoss
from chemprop.nn.loss import SIDLoss as _SIDLoss
from numpy.typing import ArrayLike


class LossFunctionParamMixin:
"""Mixin for loss functions to get and set parameters."""

_original_task_weights: ArrayLike

def __init__(self: _LossFunction, task_weights: ArrayLike) -> None:
"""Initialize the loss function.

Parameters
----------
task_weights : ArrayLike
The weights for each task.

"""
super().__init__(task_weights=task_weights) # type: ignore
self._original_task_weights = task_weights

# pylint: disable=unused-argument
def get_params(self: _LossFunction, deep: bool = True) -> dict[str, Any]:
"""Get the parameters of the loss function.

Parameters
----------
deep : bool, optional
Not used, only present to match the sklearn API.

Returns
-------
dict[str, Any]
The parameters of the loss function.
"""
return {"task_weights": self._original_task_weights}

def set_params(self: _LossFunction, **params: Any) -> _LossFunction:
"""Set the parameters of the loss function.

Parameters
----------
**params : Any
The parameters to set.

Returns
-------
Self
The loss function with the new parameters.
"""
task_weights = params.pop("task_weights", None)
if task_weights is not None:
self._original_task_weights = task_weights
state_dict = self.state_dict()
state_dict["task_weights"] = torch.as_tensor(
task_weights, dtype=torch.float
).view(1, -1)
self.load_state_dict(state_dict)
return self


class BCELoss(LossFunctionParamMixin, _BCELoss):
"""Binary cross-entropy loss function."""


class BinaryDirichletLoss(LossFunctionParamMixin, _BinaryDirichletLoss):
"""Binary Dirichlet loss function."""


class CrossEntropyLoss(LossFunctionParamMixin, _CrossEntropyLoss):
"""Cross-entropy loss function."""


class EvidentialLoss(LossFunctionParamMixin, _EvidentialLoss):
"""Evidential loss function."""


class MSELoss(LossFunctionParamMixin, _MSELoss):
"""Mean squared error loss function."""


class MulticlassDirichletLoss(LossFunctionParamMixin, _MulticlassDirichletLoss):
"""Multiclass Dirichlet loss function."""


class MVELoss(LossFunctionParamMixin, _MVELoss):
"""Mean value entropy loss function."""


class SIDLoss(LossFunctionParamMixin, _SIDLoss):
"""SID loss function."""
6 changes: 5 additions & 1 deletion molpipeline/estimators/chemprop/neural_fingerprint.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Wrap Chemprop in a sklearn like transformer returning the neural fingerprint as a numpy array."""

from typing import Self, Sequence
from typing import Any, Self, Sequence

import numpy as np
import numpy.typing as npt
Expand All @@ -21,6 +21,7 @@ def __init__(
batch_size: int = 64,
n_jobs: int = 1,
disable_fitting: bool = False,
**kwargs: Any,
) -> None:
"""Initialize the chemprop neural fingerprint model.

Expand All @@ -36,13 +37,16 @@ def __init__(
The number of jobs to use.
disable_fitting : bool, optional (default=False)
Whether to allow fitting or set to fixed encoding.
**kwargs: Any
Parameters for components of the model.
"""
# pylint: disable=duplicate-code
super().__init__(
model=model,
lightning_trainer=lightning_trainer,
batch_size=batch_size,
n_jobs=n_jobs,
**kwargs,
)
self.disable_fitting = disable_fitting

Expand Down
5 changes: 5 additions & 0 deletions molpipeline/utils/json_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Any

from molpipeline.pipeline import Pipeline
from molpipeline.utils.json_operations_torch import tensor_to_json

__all__ = [
"builtin_to_json",
Expand Down Expand Up @@ -258,6 +259,10 @@ def recursive_to_json(obj: Any) -> Any:
for key, value in model_params.items():
object_dict[key] = recursive_to_json(value)
else:
obj_dict, success = tensor_to_json(obj)
# Either not a tensor or torch is not available
if success:
return obj_dict
# If the object is not a sklearn model, a warning is raised
# as it might not be possible to recreate the object.
warnings.warn(
Expand Down
60 changes: 60 additions & 0 deletions molpipeline/utils/json_operations_torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""Functions for serializing and deserializing PyTorch models."""

from typing import TypeVar

try:
import torch

TORCH_AVAILABLE = True
except ImportError:
TORCH_AVAILABLE = False
from typing import Any, Literal

_T = TypeVar("_T")

if TORCH_AVAILABLE:

def tensor_to_json(
obj: _T,
) -> tuple[dict[str, Any], Literal[True]] | tuple[_T, Literal[False]]:
"""Recursively convert a PyTorch model to a JSON-serializable object.

Parameters
----------
obj : object
The object to convert.

Returns
-------
object
The JSON-serializable object.
"""
if isinstance(obj, torch.Tensor):
object_dict: dict[str, Any] = {
"__name__": obj.__class__.__name__,
"__module__": obj.__class__.__module__,
"__init__": True,
}
else:
return obj, False
object_dict["data"] = obj.tolist()
return object_dict, True

else:

def tensor_to_json(
obj: _T,
) -> tuple[dict[str, Any], Literal[True]] | tuple[_T, Literal[False]]:
"""Recursively convert a PyTorch model to a JSON-serializable object.

Parameters
----------
obj : object
The object to convert.

Returns
-------
object
The JSON-serializable object.
"""
return obj, False
1 change: 1 addition & 0 deletions test_extras/test_chemprop/chemprop_test_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Functions repeatedly used in tests for Chemprop models."""
54 changes: 54 additions & 0 deletions test_extras/test_chemprop/chemprop_test_utils/compare_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Functions for comparing chemprop models."""

from typing import Sequence
from unittest import TestCase

import torch
from chemprop.nn.loss import LossFunction
from lightning.pytorch.accelerators import Accelerator
from lightning.pytorch.profilers.base import PassThroughProfiler
from sklearn.base import BaseEstimator
from torch import nn


def compare_params(
test_case: TestCase, model_a: BaseEstimator, model_b: BaseEstimator
) -> None:
"""Compare the parameters of two models.

Parameters
----------
test_case : TestCase
The test case for which to raise the assertion.
model_a : BaseEstimator
The first model.
model_b : BaseEstimator
The second model.
"""
model_a_params = model_a.get_params(deep=True)
model_b_params = model_b.get_params(deep=True)
test_case.assertSetEqual(set(model_a_params.keys()), set(model_b_params.keys()))
for param_name, param_a in model_a_params.items():
param_b = model_b_params[param_name]
test_case.assertEqual(param_a.__class__, param_b.__class__)
if hasattr(param_a, "get_params"):
test_case.assertTrue(hasattr(param_b, "get_params"))
test_case.assertNotEqual(id(param_a), id(param_b))
elif isinstance(param_a, LossFunction):
test_case.assertEqual(
param_a.state_dict()["task_weights"],
param_b.state_dict()["task_weights"],
)
test_case.assertEqual(type(param_a), type(param_b))
elif isinstance(param_a, (nn.Identity, Accelerator, PassThroughProfiler)):
test_case.assertEqual(type(param_a), type(param_b))
elif isinstance(param_a, torch.Tensor):
test_case.assertTrue(
torch.equal(param_a, param_b), f"Test failed for {param_name}"
)
elif param_name == "lightning_trainer__callbacks":
test_case.assertIsInstance(param_b, Sequence)
for i, callback in enumerate(param_a):
test_case.assertIsInstance(callback, type(param_b[i]))
else:
test_case.assertEqual(param_a, param_b, f"Test failed for {param_name}")
Loading
Loading