Skip to content

Commit

Permalink
linitng
Browse files Browse the repository at this point in the history
  • Loading branch information
JochenSiegWork committed Nov 25, 2024
1 parent fe50d74 commit 994fc3e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
11 changes: 6 additions & 5 deletions tests/test_explainability/test_shap_explainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ def _construct_kernel_shap_kwargs(
dict[str, Any]
The kwargs for SHAPKernelExplainer
"""
featurization_subpipeline = get_featurization_subpipeline(pipeline)
featurization_subpipeline = get_featurization_subpipeline(
pipeline, raise_not_found=True
)
data_transformed = featurization_subpipeline.transform(data)
if scipy.sparse.issparse(data_transformed):
data_transformed = data_transformed.toarray()
Expand Down Expand Up @@ -151,8 +153,8 @@ def _test_valid_explanation(
# SVC seems to be handled differently by SHAP. It returns only a one dimensional
# feature array for binary classification.
self.assertTrue(
(1,), explanation.prediction.shape
) # type: ignore[union-attr]
(1,), explanation.prediction.shape # type: ignore[union-attr]
)
self.assertEqual(
(nof_features,), explanation.feature_weights.shape # type: ignore[union-attr]
)
Expand All @@ -165,7 +167,7 @@ def _test_valid_explanation(
raise ValueError("Error in unittest. Unsupported estimator.")

if issubclass(type(explainer), AtomExplanationMixin):
self.assertIsInstance(explanation.atom_weights, np.ndarray)
self.assertIsInstance(explanation.atom_weights, np.ndarray) # type: ignore[union-attr]
self.assertEqual(
explanation.atom_weights.shape, # type: ignore[union-attr]
(explanation.molecule.GetNumAtoms(),), # type: ignore[union-attr]
Expand Down Expand Up @@ -215,7 +217,6 @@ def test_explanations_fingerprint_pipeline(self) -> None:
explanations = explainer.explain(TEST_SMILES)
self.assertEqual(len(explanations), len(TEST_SMILES))

self.assertTrue(explainer.has_atom_weights_)
self.assertTrue(
issubclass(explainer.return_element_type_, AtomExplanationMixin)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ class TestSumOfGaussiansGrid(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
"""Set up the tests."""
cls.test_pipeline = _get_test_morgan_rf_pipeline()
cls.test_pipeline: Pipeline = _get_test_morgan_rf_pipeline()
cls.test_pipeline.fit(TEST_SMILES, CONTAINS_OX)
cls.test_explainer = SHAPTreeExplainer(cls.test_pipeline)
cls.test_explanations = cls.test_explainer.explain(TEST_SMILES)
cls.test_explainer: SHAPTreeExplainer = SHAPTreeExplainer(cls.test_pipeline)
cls.test_explanations: list[SHAPFeatureAndAtomExplanation] = (
cls.test_explainer.explain(TEST_SMILES)
)

def test_grid_with_shap_atom_weights(self) -> None:
"""Test grid with SHAP atom weights."""
Expand Down

0 comments on commit 994fc3e

Please sign in to comment.