Skip to content

Commit

Permalink
Christian comments 1
Browse files Browse the repository at this point in the history
  • Loading branch information
JochenSiegWork committed Dec 4, 2024
1 parent 964213b commit 52387f5
Show file tree
Hide file tree
Showing 15 changed files with 51 additions and 39 deletions.
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
"""Explainability module for the molpipeline package."""

from molpipeline.explainability.explainer import SHAPTreeExplainer
from molpipeline.explainability.explanation import (
from molpipeline.experimental.explainability.explainer import (
SHAPKernelExplainer,
SHAPTreeExplainer,
)
from molpipeline.experimental.explainability.explanation import (
SHAPFeatureAndAtomExplanation,
SHAPFeatureExplanation,
)
from molpipeline.explainability.visualization.visualization import (
from molpipeline.experimental.explainability.visualization.visualization import (
structure_heatmap,
structure_heatmap_shap,
)

__all__ = [
"SHAPTreeExplainer",
"SHAPKernelExplainer",
"SHAPFeatureExplanation",
"SHAPFeatureAndAtomExplanation",
"SHAPTreeExplainer",
"structure_heatmap",
"structure_heatmap_shap",
]
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from molpipeline import Pipeline
from molpipeline.abstract_pipeline_elements.core import OptionalMol
from molpipeline.explainability.explanation import (
from molpipeline.experimental.explainability.explanation import (
AtomExplanationMixin,
BondExplanationMixin,
FeatureExplanationMixin,
Expand All @@ -24,12 +24,13 @@
SHAPFeatureAndAtomExplanation,
SHAPFeatureExplanation,
)
from molpipeline.explainability.fingerprint_utils import fingerprint_shap_to_atomweights
from molpipeline.experimental.explainability.fingerprint_utils import (
fingerprint_shap_to_atomweights,
)
from molpipeline.mol2any import MolToMorganFP
from molpipeline.utils.subpipeline import SubpipelineExtractor, get_model_from_pipeline


# pylint: disable=C0103,W0613
def _to_dense(
feature_matrix: npt.NDArray[Any] | spmatrix,
) -> npt.NDArray[Any]:
Expand Down Expand Up @@ -170,13 +171,13 @@ def _convert_shap_feature_weights_to_atom_weights(
]


# pylint: disable=R0903
class AbstractSHAPExplainer(abc.ABC):
"""Abstract class for SHAP explainer objects."""

# pylint: disable=C0103,W0613
@abc.abstractmethod
def explain(self, X: Any, **kwargs: Any) -> _SHAPExplainer_return_type_:
def explain(
self, X: Any, **kwargs: Any
) -> _SHAPExplainer_return_type_: # pylint: disable=invalid-name,unused-argument
"""Explain the predictions for the input data.
Parameters
Expand All @@ -193,7 +194,6 @@ def explain(self, X: Any, **kwargs: Any) -> _SHAPExplainer_return_type_:
"""


# pylint: disable=R0903
class SHAPExplainerAdapter(AbstractSHAPExplainer, abc.ABC):
"""Adapter for SHAP explainer wrappers for handling molecules and pipelines."""

Expand Down Expand Up @@ -269,9 +269,10 @@ def _prediction_is_valid(prediction: Any) -> bool:

return True

# pylint: disable=C0103,W0613
@override
def explain(self, X: Any, **kwargs: Any) -> _SHAPExplainer_return_type_:
def explain(
self, X: Any, **kwargs: Any
) -> _SHAPExplainer_return_type_: # pylint: disable=invalid-name,unused-argument
"""Explain the predictions for the input data.
If the calculation of the SHAP values for an input sample fails, the explanation will be invalid.
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
import numpy.typing as npt


# pylint: disable=too-few-public-methods
class GaussFunctor2D:
class GaussFunctor2D: # pylint: disable=too-few-public-methods
"""2D Gaussian functor."""

def __init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,10 @@ def get_color_map_from_input(
# read user definer color scheme as ColorMap
if color is None:
coolwarm = (
(0.017, 0.50, 0.850, 0.5),
(1.0, 1.0, 1.0, 0.5),
(1.0, 0.25, 0.0, 0.5),
(1.0, 1.0, 1.0, 0.5),
(0.017, 0.50, 0.850, 0.5),
)
coolwarm = (coolwarm[2], coolwarm[1], coolwarm[0])
color = coolwarm
if isinstance(color, Colormap):
color_map = color
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,23 @@
from rdkit.Chem.Draw import rdMolDraw2D

from molpipeline.abstract_pipeline_elements.core import RDKitMol
from molpipeline.explainability.explanation import SHAPFeatureAndAtomExplanation
from molpipeline.explainability.visualization.gauss import GaussFunctor2D
from molpipeline.explainability.visualization.heatmaps import (
from molpipeline.experimental.explainability import (
SHAPFeatureAndAtomExplanation,
)

from molpipeline.experimental.explainability.visualization.gauss import GaussFunctor2D
from molpipeline.experimental.explainability.visualization.heatmaps import (
ValueGrid,
color_canvas,
get_color_normalizer_from_data,
)
from molpipeline.explainability.visualization.utils import (
RGBAtuple,
from molpipeline.experimental.explainability.visualization.utils import (
get_color_map_from_input,
get_mol_lims,
pad,
plt_to_pil,
to_png,
get_mol_lims,
pad,
RGBAtuple,
)


Expand Down Expand Up @@ -129,15 +132,14 @@ def _add_gaussians_for_atoms(
return v_map


# pylint: disable=too-many-locals
def _add_gaussians_for_bonds(
mol: Chem.Mol,
conf: Chem.Conformer,
v_map: ValueGrid,
bond_weights: npt.NDArray[np.float64],
bond_width: float,
bond_length: float,
) -> ValueGrid:
) -> ValueGrid: # pylint: disable=too-many-locals
"""Add Gauss-functions centered at bonds to the grid.
Parameters
Expand Down Expand Up @@ -475,7 +477,7 @@ def structure_heatmap_shap( # pylint: disable=too-many-branches
f"$Prediction = {explanation.prediction[-1]:.2f}$ ="
"\n"
"\n"
f" $expected \ value={explanation.expected_value[-1]:.2f}$ + " # noqa: W605 # pylint: disable=W1401
f" $expected \ value={explanation.expected_value[-1]:.2f}$ + " # noqa: W605 # pylint: disable=anomalous-backslash-in-string
f"$features_{{present}}= {sum_present_shap:.2f}$ + "
f"$features_{{absent}}={sum_absent_shap:.2f}$"
)
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,23 @@
from molpipeline import ErrorFilter, FilterReinserter, Pipeline, PostPredictionWrapper
from molpipeline.abstract_pipeline_elements.core import RDKitMol
from molpipeline.any2mol import SmilesToMol
from molpipeline.explainability.explainer import SHAPKernelExplainer, SHAPTreeExplainer
from molpipeline.explainability.explanation import (
AtomExplanationMixin,
from molpipeline.experimental.explainability import (
SHAPFeatureAndAtomExplanation,
SHAPFeatureExplanation,
SHAPTreeExplainer,
SHAPKernelExplainer,
)
from molpipeline.experimental.explainability.explanation import AtomExplanationMixin
from molpipeline.mol2any import (
MolToConcatenatedVector,
MolToMorganFP,
MolToRDKitPhysChem,
)
from molpipeline.mol2mol import SaltRemover
from molpipeline.utils.subpipeline import SubpipelineExtractor
from tests.test_explainability.utils import construct_kernel_shap_kwargs
from tests.test_experimental.test_explainability.utils import (
construct_kernel_shap_kwargs,
)

TEST_SMILES = ["CC", "CCO", "COC", "c1ccccc1(N)", "CCC(-O)O", "CCCN"]
CONTAINS_OX = [0, 1, 1, 0, 1, 0]
Expand Down Expand Up @@ -171,7 +174,9 @@ def test_explanations_fingerprint_pipeline( # pylint: disable=too-many-locals
]
explainer_estimators = [tree_estimators + other_estimators, tree_estimators]

for estimators, explainer_type in zip(explainer_estimators, explainer_types):
for estimators, explainer_type in zip(
explainer_estimators, explainer_types, strict=True
):

# test explanations with different estimators
for estimator in estimators:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
from rdkit.Chem import Draw

from molpipeline import Pipeline
from molpipeline.explainability import (
from molpipeline.experimental.explainability import (
SHAPFeatureAndAtomExplanation,
SHAPFeatureExplanation,
SHAPTreeExplainer,
)
from molpipeline.explainability.visualization.visualization import (
from molpipeline.experimental.explainability.visualization.visualization import (
make_sum_of_gaussians_grid,
)
from tests.test_explainability.test_visualization.test_visualization import (
from tests.test_experimental.test_explainability.test_visualization.test_visualization import (
_get_test_morgan_rf_pipeline,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,18 @@

from molpipeline import Pipeline
from molpipeline.any2mol import SmilesToMol
from molpipeline.explainability import (
from molpipeline.experimental.explainability import (
SHAPFeatureAndAtomExplanation,
SHAPFeatureExplanation,
SHAPKernelExplainer,
SHAPTreeExplainer,
structure_heatmap,
structure_heatmap_shap,
)
from molpipeline.explainability.explainer import SHAPKernelExplainer
from molpipeline.mol2any import MolToMorganFP
from tests.test_explainability.utils import construct_kernel_shap_kwargs
from tests.test_experimental.test_explainability.utils import (
construct_kernel_shap_kwargs,
)

TEST_SMILES = ["CC", "CCO", "COC", "c1ccccc1(N)", "CCC(-O)O", "CCCN"]
CONTAINS_OX = [0, 1, 1, 0, 1, 0] # classification labels
Expand Down
File renamed without changes.

0 comments on commit 52387f5

Please sign in to comment.