Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

72 make chemprop multiclass classification model #73

Merged
merged 25 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
737db7f
add multi class classifier
JenniferHem Aug 26, 2024
8692bc2
use input check to prevent confusing message by torch of the class la…
JenniferHem Aug 26, 2024
b560d18
remove get_params
JenniferHem Aug 26, 2024
b3d0af8
make n classes non-optional
JenniferHem Aug 26, 2024
fe40e4a
black
JenniferHem Aug 26, 2024
70e5928
ignore loghtning logs
JenniferHem Aug 27, 2024
3de60fd
add test for multiclass
JenniferHem Aug 27, 2024
3003493
mock data for test
JenniferHem Aug 27, 2024
1b84120
remove random write csv
JenniferHem Aug 27, 2024
3dbcff8
add test for full coverage of multiclass chemprop
JenniferHem Aug 28, 2024
dd0ebbe
add missing parameters for docsig
JenniferHem Aug 28, 2024
d744d4d
code review requests
JenniferHem Aug 28, 2024
e404579
Adapt Eror message
JenniferHem Aug 28, 2024
2b2d687
check classifier in init
JenniferHem Aug 28, 2024
f87d68b
docstring adaptations
JenniferHem Aug 28, 2024
7faedc1
fix docstings and naming in tests
JenniferHem Aug 28, 2024
4834614
split instace check from validation
JenniferHem Aug 30, 2024
261d7db
add test for set_params and initialize Multiclass FFN properlky
JenniferHem Sep 3, 2024
4e84111
raise attribute error if wrong model.predictor is passed
JenniferHem Sep 3, 2024
c5810fc
test multiclass setter and getter
JenniferHem Sep 3, 2024
a064100
pass correct tasks
JenniferHem Sep 3, 2024
33e1202
black
JenniferHem Sep 3, 2024
4a117d5
docsig and pydocstyle
JenniferHem Sep 3, 2024
fdb1d31
lint: docstrings and tests
JenniferHem Sep 3, 2024
eef3d22
missing space
JenniferHem Sep 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ __pycache__
molpipeline.egg-info/
lib/
build/
lightning_logs/
c-w-feldmann marked this conversation as resolved.
Show resolved Hide resolved

127 changes: 127 additions & 0 deletions molpipeline/estimators/chemprop/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,3 +348,130 @@ def __init__(
n_jobs=n_jobs,
**kwargs,
)


class ChempropMulticlassClassifier(ChempropModel):
    """Chemprop model with default parameters for multiclass classification tasks."""

    def __init__(
        self,
        n_classes: int,
        model: MPNN | None = None,
        lightning_trainer: pl.Trainer | None = None,
        batch_size: int = 64,
        n_jobs: int = 1,
        **kwargs: Any,
    ) -> None:
        """Initialize the chemprop multiclass model.

        Parameters
        ----------
        n_classes : int
            The number of classes for the classifier.
        model : MPNN | None, optional
            The chemprop model to wrap. If None, a default model will be used.
        lightning_trainer : pl.Trainer, optional
            The lightning trainer to use, by default None
        batch_size : int, optional (default=64)
            The batch size to use.
        n_jobs : int, optional (default=1)
            The number of jobs to use.
        kwargs : Any
            Parameters set using `set_params`.
            Can be used to modify components of the model.

        Raises
        ------
        ValueError
            If the resulting model is not a valid multiclass classifier,
            i.e. `_is_valid_multiclass_classifier` returns False.
        """
        if model is None:
            bond_encoder = BondMessagePassing()
            agg = SumAggregation()
            predictor = MulticlassClassificationFFN(n_classes=n_classes)
            model = MPNN(message_passing=bond_encoder, agg=agg, predictor=predictor)
        # n_classes is assigned before the parent initializer runs because the
        # validation in `set_params` reads it.
        # NOTE(review): presumably the parent forwards `kwargs` to `set_params`
        # - confirm against ChempropModel.
        self.n_classes = n_classes
        super().__init__(
            model=model,
            lightning_trainer=lightning_trainer,
            batch_size=batch_size,
            n_jobs=n_jobs,
            **kwargs,
        )
        # The result of the validity check must not be discarded: an invalid
        # configuration has to fail here, exactly as it does in `set_params`.
        self._raise_if_invalid_multiclass_classifier()

    def set_params(self, **params: Any) -> Self:
        """Set the parameters of the model and check if it is a multiclass classifier.

        Parameters
        ----------
        **params
            The parameters to set.

        Returns
        -------
        Self
            The model with the new parameters.

        Raises
        ------
        ValueError
            If the model is not a valid multiclass classifier after the
            parameters have been set.
        """
        super().set_params(**params)
        self._raise_if_invalid_multiclass_classifier()
        return self

    def fit(
        self,
        X: MoleculeDataset,
        y: Sequence[int | float] | npt.NDArray[np.int_ | np.float64],
    ) -> Self:
        """Fit the model to the data.

        Parameters
        ----------
        X : MoleculeDataset
            The input data.
        y : Sequence[int | float] | npt.NDArray[np.int_ | np.float64]
            The target data.

        Returns
        -------
        Self
            The fitted model.
        """
        # Validate the labels up front so the user gets a readable error
        # before the data is handed to the parent `fit`.
        self._check_correct_input(y)
        return super().fit(X, y)

    def _check_correct_input(
        self, y: Sequence[int | float] | npt.NDArray[np.int_ | np.float64]
    ) -> None:
        """Check if the input for the multi-class classifier is correct.

        Parameters
        ----------
        y : Sequence[int | float] | npt.NDArray[np.int_ | np.float64]
            Intended classes for the dataset

        Raises
        ------
        ValueError
            If the classes found in y are not matching n_classes or if the class labels do not start from 0 to n_classes-1.
        """
        unique_y = np.unique(y)
        # Collect every problem first so that a single error reports all of them.
        log: list[str] = []
        if self.n_classes != len(unique_y):
            log.append(
                f"Given number of classes in init (n_classes) does not match the number of unique classes (found {unique_y}) in the target data."
            )
        if sorted(unique_y) != list(range(self.n_classes)):
            err = f"Classes need to be in the range from 0 to {self.n_classes-1}. Found {unique_y}. Please correct the input data accordingly."
            log.append(err)
        if log:
            raise ValueError("\n".join(log))

    def _is_valid_multiclass_classifier(self) -> bool:
        """Check if a multiclass classifier is valid. Needs to be of the correct class and have more than 2 classes.

        Returns
        -------
        bool
            True if is a valid multiclass classifier, False otherwise.
        """
        has_correct_class = self._is_multiclass_classifier()
        has_classes = self.n_classes > 2
        return has_correct_class and has_classes

    def _raise_if_invalid_multiclass_classifier(self) -> None:
        """Raise an error if the model is not a valid multiclass classifier.

        Raises
        ------
        ValueError
            If `_is_valid_multiclass_classifier` returns False.
        """
        if not self._is_valid_multiclass_classifier():
            raise ValueError(
                "The model's predictor or the number of classes are invalid. Use a multiclass predictor and more than 2 classes."
            )
82 changes: 81 additions & 1 deletion test_extras/test_chemprop/test_chemprop_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
ChempropClassifier,
ChempropModel,
ChempropRegressor,
ChempropMulticlassClassifier,
)
from molpipeline.mol2any.mol2chemprop import MolToChemprop
from molpipeline.pipeline import Pipeline
Expand Down Expand Up @@ -139,6 +140,40 @@ def get_classification_pipeline() -> Pipeline:
return model_pipeline


def get_multiclass_classification_pipeline(n_classes: int) -> Pipeline:
    """Build the Chemprop pipeline used for multiclass classification.

    Parameters
    ----------
    n_classes : int
        Number of classes handed to the multiclass model on initialization.

    Returns
    -------
    Pipeline
        Pipeline with the steps "smiles2mol", "mol2chemprop", "error_filter",
        "model" and "filter_reinserter".
    """
    # The elements are created in the same order as they appear in the pipeline.
    pipeline_steps: list[tuple[str, Any]] = [
        ("smiles2mol", SmilesToMol()),
        ("mol2chemprop", MolToChemprop()),
    ]

    remove_errors = ErrorFilter(filter_everything=True)
    # The reinserter is derived from the same filter instance that is used as
    # a pipeline step; filtered positions are filled with NaN.
    reinsert_errors = FilterReinserter.from_error_filter(
        remove_errors, fill_value=np.nan
    )
    pipeline_steps.append(("error_filter", remove_errors))

    multiclass_model = ChempropMulticlassClassifier(
        n_classes=n_classes, lightning_trainer=DEFAULT_TRAINER
    )
    pipeline_steps.append(("model", multiclass_model))
    pipeline_steps.append(
        ("filter_reinserter", PostPredictionWrapper(reinsert_errors))
    )

    return Pipeline(steps=pipeline_steps)


_T = TypeVar("_T")


Expand Down Expand Up @@ -282,7 +317,6 @@ def test_prediction(self) -> None:
molecule_net_bbbp_df = pd.read_csv(
TEST_DATA_DIR / "molecule_net_bbbp.tsv.gz", sep="\t", nrows=100
)
molecule_net_bbbp_df.to_csv("molecule_net_bbbp.tsv.gz", sep="\t", index=False)
classification_model = get_classification_pipeline()
classification_model.fit(
molecule_net_bbbp_df["smiles"].tolist(),
Expand All @@ -306,3 +340,49 @@ def test_prediction(self) -> None:

self.assertEqual(proba.shape, proba_copy.shape)
self.assertTrue(np.allclose(proba[~nan_indices], proba_copy[~nan_indices]))


class TestMulticlassClassificationPipeline(unittest.TestCase):
    """Test the Chemprop model pipeline for multiclass classification."""

    def test_prediction(self) -> None:
        """Fit, predict and round-trip the multiclass model, then check invalid setups.

        The pipeline is trained on the mock data, the shapes of the predictions
        are checked, and the predictions of a dumped and reloaded copy are
        compared with the original ones. Finally, fitting with shifted labels
        and building/fitting a pipeline with `n_classes=2` must raise a
        ValueError.
        """
        mock_df = pd.read_csv(
            TEST_DATA_DIR / "multiclass_mock.tsv", sep="\t", index_col=False
        )
        n_rows = len(mock_df)
        smiles_list = mock_df["Molecule"].tolist()
        labels = mock_df["Label"]

        classification_model = get_multiclass_classification_pipeline(n_classes=3)
        classification_model.fit(smiles_list, labels.to_numpy())

        pred = classification_model.predict(smiles_list)
        proba = classification_model.predict_proba(smiles_list)
        self.assertEqual(len(pred), n_rows)
        self.assertEqual(proba.shape[1], 3)
        self.assertEqual(proba.shape[0], n_rows)

        # A dumped and reloaded pipeline has to reproduce the predictions.
        reloaded_model = joblib_dump_load(classification_model)
        pred_reloaded = reloaded_model.predict(smiles_list)
        proba_reloaded = reloaded_model.predict_proba(smiles_list)

        # NaN positions must agree; values are only compared where not NaN.
        is_nan = np.isnan(pred)
        valid = ~is_nan
        self.assertListEqual(is_nan.tolist(), np.isnan(pred_reloaded).tolist())
        self.assertTrue(np.allclose(pred[valid], pred_reloaded[valid]))

        self.assertEqual(proba.shape, proba_reloaded.shape)
        self.assertEqual(pred.shape, pred_reloaded.shape)
        self.assertTrue(np.allclose(proba[valid], proba_reloaded[valid]))

        # Labels shifted by one no longer start at 0.
        with self.assertRaises(ValueError):
            classification_model.fit(smiles_list, labels.add(1).to_numpy())

        # Two classes are requested while the data contains three labels.
        with self.assertRaises(ValueError):
            classification_model = get_multiclass_classification_pipeline(n_classes=2)
            classification_model.fit(smiles_list, labels.to_numpy())
13 changes: 13 additions & 0 deletions tests/test_data/multiclass_mock.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
Molecule Label
"CCCCCC" 0
"CCCCCCCO" 1
"CCCC" 0
"CCCN" 2
"CCCCCC" 0
"CCCO" 1
"CCCCC" 0
"CCCCCN" 2
"CC(C)CCC" 0
"CCCCCCO" 1
"CCCCCl" 0
"CCC#N" 2
Loading