From 994fc3e23fa5987404a65f4988245e67c9b3f2be Mon Sep 17 00:00:00 2001 From: Jochen Sieg Date: Mon, 25 Nov 2024 14:41:19 +0100 Subject: [PATCH] linitng --- tests/test_explainability/test_shap_explainers.py | 11 ++++++----- .../test_visualization/test_gaussian_grid.py | 8 +++++--- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/tests/test_explainability/test_shap_explainers.py b/tests/test_explainability/test_shap_explainers.py index c9905903..36027763 100644 --- a/tests/test_explainability/test_shap_explainers.py +++ b/tests/test_explainability/test_shap_explainers.py @@ -76,7 +76,9 @@ def _construct_kernel_shap_kwargs( dict[str, Any] The kwargs for SHAPKernelExplainer """ - featurization_subpipeline = get_featurization_subpipeline(pipeline) + featurization_subpipeline = get_featurization_subpipeline( + pipeline, raise_not_found=True + ) data_transformed = featurization_subpipeline.transform(data) if scipy.sparse.issparse(data_transformed): data_transformed = data_transformed.toarray() @@ -151,8 +153,8 @@ def _test_valid_explanation( # SVC seems to be handled differently by SHAP. It returns only a one dimensional # feature array for binary classification. self.assertTrue( - (1,), explanation.prediction.shape - ) # type: ignore[union-attr] + (1,), explanation.prediction.shape # type: ignore[union-attr] + ) self.assertEqual( (nof_features,), explanation.feature_weights.shape # type: ignore[union-attr] ) @@ -165,7 +167,7 @@ def _test_valid_explanation( raise ValueError("Error in unittest. Unsupported estimator.") if issubclass(type(explainer), AtomExplanationMixin): - self.assertIsInstance(explanation.atom_weights, np.ndarray) + self.assertIsInstance(explanation.atom_weights, np.ndarray) # type: ignore[union-attr] self.assertEqual( explanation.atom_weights.shape, # type: ignore[union-attr] (explanation.molecule.GetNumAtoms(),), # type: ignore[union-attr] @@ -215,7 +217,6 @@ def test_explanations_fingerprint_pipeline(self) -> None: explanations = explainer.explain(TEST_SMILES) self.assertEqual(len(explanations), len(TEST_SMILES)) - self.assertTrue(explainer.has_atom_weights_) self.assertTrue( issubclass(explainer.return_element_type_, AtomExplanationMixin) ) diff --git a/tests/test_explainability/test_visualization/test_gaussian_grid.py b/tests/test_explainability/test_visualization/test_gaussian_grid.py index 40cdd20a..5ab84472 100644 --- a/tests/test_explainability/test_visualization/test_gaussian_grid.py +++ b/tests/test_explainability/test_visualization/test_gaussian_grid.py @@ -37,10 +37,12 @@ class TestSumOfGaussiansGrid(unittest.TestCase): @classmethod def setUpClass(cls) -> None: """Set up the tests.""" - cls.test_pipeline = _get_test_morgan_rf_pipeline() + cls.test_pipeline: Pipeline = _get_test_morgan_rf_pipeline() cls.test_pipeline.fit(TEST_SMILES, CONTAINS_OX) - cls.test_explainer = SHAPTreeExplainer(cls.test_pipeline) - cls.test_explanations = cls.test_explainer.explain(TEST_SMILES) + cls.test_explainer: SHAPTreeExplainer = SHAPTreeExplainer(cls.test_pipeline) + cls.test_explanations: list[SHAPFeatureAndAtomExplanation] = ( + cls.test_explainer.explain(TEST_SMILES) + ) def test_grid_with_shap_atom_weights(self) -> None: """Test grid with SHAP atom weights."""