Skip to content

Commit

Permalink
Error handling sanitize bug (#85)
Browse files Browse the repository at this point in the history
* fix molsanitize exception error catching

* linting

* isort on other stuff
  • Loading branch information
frederik-sandfort1 authored and JochenSiegWork committed Sep 20, 2024
1 parent 18a2543 commit 91152e1
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 9 deletions.
6 changes: 2 additions & 4 deletions molpipeline/estimators/chemprop/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@

try:
from chemprop.data import MoleculeDataset, build_dataloader
from chemprop.nn.predictors import (
BinaryClassificationFFNBase,
)
from chemprop.nn.predictors import BinaryClassificationFFNBase
from lightning import pytorch as pl
except ImportError as error:
logger.error(
Expand All @@ -31,9 +29,9 @@
MPNN,
BinaryClassificationFFN,
BondMessagePassing,
MulticlassClassificationFFN,
RegressionFFN,
SumAggregation,
MulticlassClassificationFFN,
)
from molpipeline.estimators.chemprop.neural_fingerprint import ChempropNeuralFP

Expand Down
2 changes: 1 addition & 1 deletion molpipeline/pipeline/_molpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def transform_single(self, input_value: Any) -> Any:
elif isinstance(p_element, FilterReinserter):
iter_value = p_element.transform_single(iter_value)
except MolSanitizeException as err:
return InvalidInstance(
iter_value = InvalidInstance(
p_element.uuid,
f"RDKit MolSanitizeException: {err.args}",
p_element.name,
Expand Down
2 changes: 1 addition & 1 deletion test_extras/test_chemprop/test_chemprop_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
from molpipeline.estimators.chemprop.models import (
ChempropClassifier,
ChempropModel,
ChempropRegressor,
ChempropMulticlassClassifier,
ChempropRegressor,
)
from molpipeline.mol2any.mol2chemprop import MolToChemprop
from molpipeline.pipeline import Pipeline
Expand Down
4 changes: 2 additions & 2 deletions test_extras/test_chemprop/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
# pylint: disable=relative-beyond-top-level
from test_extras.test_chemprop.chemprop_test_utils.compare_models import compare_params
from test_extras.test_chemprop.chemprop_test_utils.constant_vars import (
NO_IDENTITY_CHECK,
DEFAULT_SET_PARAMS,
DEFAULT_BINARY_CLASSIFICATION_PARAMS,
DEFAULT_MULTICLASS_CLASSIFICATION_PARAMS,
DEFAULT_SET_PARAMS,
NO_IDENTITY_CHECK,
)
from test_extras.test_chemprop.chemprop_test_utils.default_models import (
get_chemprop_model_binary_classification_mpnn,
Expand Down
46 changes: 45 additions & 1 deletion tests/test_elements/test_error_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,16 @@
from typing import Any

import numpy as np
from rdkit import RDLogger
from rdkit import Chem, RDLogger
from rdkit.Chem.rdchem import MolSanitizeException
from sklearn.base import clone

from molpipeline import ErrorFilter, FilterReinserter, Pipeline, PostPredictionWrapper
from molpipeline.abstract_pipeline_elements.core import MolToMolPipelineElement
from molpipeline.any2mol import SmilesToMol
from molpipeline.any2mol.auto2mol import AutoToMol
from molpipeline.mol2any import MolToMorganFP, MolToRDKitPhysChem, MolToSmiles
from molpipeline.utils.molpipeline_types import OptionalMol, RDKitMol
from tests.utils.mock_element import MockTransformingPipelineElement

rdlog = RDLogger.logger()
Expand Down Expand Up @@ -247,3 +251,43 @@ def test_replace_mixed_datatypes_expected_failures(self) -> None:
self.assertRaises(ValueError, pipeline.fit, test_values)
self.assertRaises(ValueError, pipeline.transform, test_values)
self.assertRaises(ValueError, pipeline2.fit_transform, test_values)

def test_molsanitize_error(self) -> None:
    """Test that a MolSanitizeException is caught and handled by ErrorFilter."""

    class DummyMolSanitizeExc(MolToMolPipelineElement):
        """MolToMolPipelineElement raising a dummy MolSanitizeException."""

        def pretransform_single(self, value: RDKitMol) -> OptionalMol:
            """Return the molecule unchanged unless its SMILES is "c1ccccc1".

            Parameters
            ----------
            value: RDKitMol
                Molecule.

            Returns
            -------
            OptionalMol
                The unmodified input molecule.

            Raises
            ------
            MolSanitizeException
                If the SMILES of the molecule equals "c1ccccc1".
            """
            if Chem.MolToSmiles(value) == "c1ccccc1":
                raise MolSanitizeException("This is a dummy exception.")
            return value

    # The same ErrorFilter instance is shared with the FilterReinserter,
    # so it is created up front instead of inline in the step list.
    error_filter = ErrorFilter()
    pipeline = Pipeline(
        [
            ("autotomol", AutoToMol()),
            ("dummymolsanitizeexc", DummyMolSanitizeExc()),
            ("moltosmiles", MolToSmiles()),
            ("errorfilter", error_filter),
            (
                "filterreinserter",
                FilterReinserter.from_error_filter(error_filter, None),
            ),
        ],
        n_jobs=-1,
    )

    # "c1ccccc1" triggers the dummy exception in the custom element.
    # "c1cc" is presumably rejected earlier as an unparsable SMILES
    # (NOTE(review): confirm). Both are expected to be reported as None.
    result = pipeline.transform(["c1ccccc1", "CCCCCCC", "c1cc"])
    self.assertEqual(result, [None, "CCCCCCC", None])

0 comments on commit 91152e1

Please sign in to comment.