Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
JochenSiegWork committed Nov 25, 2024
1 parent 994fc3e commit e0c3d3c
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 12 deletions.
6 changes: 4 additions & 2 deletions tests/test_explainability/test_shap_explainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def _construct_kernel_shap_kwargs(
featurization_subpipeline = get_featurization_subpipeline(
pipeline, raise_not_found=True
)
data_transformed = featurization_subpipeline.transform(data)
data_transformed = featurization_subpipeline.transform(data) # type: ignore[union-attr]
if scipy.sparse.issparse(data_transformed):
data_transformed = data_transformed.toarray()
return {"data": data_transformed}
Expand Down Expand Up @@ -213,7 +213,9 @@ def test_explanations_fingerprint_pipeline(self) -> None:
pipeline, TEST_SMILES
)

explainer = explainer_type(pipeline, **explainer_kwargs)
explainer: SHAPKernelExplainer | SHAPTreeExplainer = explainer_type(
pipeline, **explainer_kwargs
)
explanations = explainer.explain(TEST_SMILES)
self.assertEqual(len(explanations), len(TEST_SMILES))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ def test_grid_with_shap_atom_weights(self) -> None:
"""Test grid with SHAP atom weights."""
for explanation in self.test_explanations:
self.assertTrue(explanation.is_valid())
self.assertIsInstance(explanation.atom_weights, np.ndarray)
self.assertIsInstance(explanation.atom_weights, np.ndarray) # type: ignore[union-attr]

mol_copy = Chem.Mol(explanation.molecule)
mol_copy = Draw.PrepareMolForDrawing(mol_copy)
value_grid = make_sum_of_gaussians_grid(
mol_copy,
atom_weights=explanation.atom_weights,
atom_weights=explanation.atom_weights, # type: ignore[union-attr]
atom_width=np.inf,
grid_resolution=[64, 64],
padding=[0.4, 0.4],
Expand Down
18 changes: 10 additions & 8 deletions tests/test_explainability/test_visualization/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ class TestExplainabilityVisualization(unittest.TestCase):

test_pipeline: ClassVar[Pipeline]
test_tree_explainer: ClassVar[SHAPTreeExplainer]
test_tree_explanations: ClassVar[
list[SHAPFeatureAndAtomExplanation] | list[SHAPFeatureExplanation]
]
test_tree_explanations: ClassVar[list[SHAPFeatureAndAtomExplanation]]
test_kernel_explainer: ClassVar[SHAPKernelExplainer]
test_kernel_explanations: ClassVar[list[SHAPFeatureAndAtomExplanation]]

@classmethod
def setUpClass(cls) -> None:
Expand All @@ -68,15 +68,17 @@ def setUpClass(cls) -> None:
cls.test_tree_explanations = cls.test_tree_explainer.explain(TEST_SMILES)

# kernel explainer
featurization_subpipeline = get_featurization_subpipeline(cls.test_pipeline)
X_data_transformed = featurization_subpipeline.transform(TEST_SMILES)
if scipy.sparse.issparse(X_data_transformed):
featurization_subpipeline = get_featurization_subpipeline(
cls.test_pipeline, raise_not_found=True
)
data_transformed = featurization_subpipeline.transform(TEST_SMILES) # type: ignore[union-attr]
if scipy.sparse.issparse(data_transformed):
# convert sparse matrix to dense array because SHAPKernelExplainer
# does not support sparse matrix as `data` and then explain dense matrices.
# We stick to dense matrices for simplicity.
X_data_transformed = X_data_transformed.toarray()
data_transformed = data_transformed.toarray()
cls.test_kernel_explainer = SHAPKernelExplainer(
cls.test_pipeline, data=X_data_transformed
cls.test_pipeline, data=data_transformed
)
cls.test_kernel_explanations = cls.test_kernel_explainer.explain(TEST_SMILES)

Expand Down

0 comments on commit e0c3d3c

Please sign in to comment.