Skip to content

Commit

Permalink
Add more algorithms (#8)
Browse files Browse the repository at this point in the history
* Change kwargs argument name, unify 'layer' argument name passed to explainers

* Fix docstrings of create_explainer function

* Add Saliency algorithm

* Add DeepLIFT algorithm

* Add DeepLIFT SHAP algorithm

* Add Deconvolution algorithm

* Add Input x Gradient algorithm

* Add Layer Conductance algorithm

* Add standardization of matrices, generate figure with respect to presence of negative attributions

* Revert visualization error handling due to #12 PR

* Mark model argument type with TODO

* Remove incorrect standardization method

* Bump patch version
  • Loading branch information
adamwawrzynski authored Dec 28, 2022
1 parent 8a26adf commit 4b54012
Show file tree
Hide file tree
Showing 18 changed files with 571 additions and 40 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ repos:
types_or: [ python, pyi ]
- id: pylint
name: pylint
entry: poetry run pylint --min-similarity-lines 6
entry: poetry run pylint --min-similarity-lines 8
language: system
types: [ python ]
require_serial: true
Expand Down
3 changes: 1 addition & 2 deletions autoxai/explainer/base_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class CVExplainer(ABC):
def calculate_features(
self,
model: torch.nn.Module,
input_data: torch.Tensor,
input_data: torch.Tensor, # TODO: add more generic way of passing model inputs # pylint: disable = (fixme)
pred_label_idx: int,
**kwargs,
) -> torch.Tensor:
Expand Down Expand Up @@ -94,5 +94,4 @@ def visualize(
show_colorbar=True,
use_pyplot=False,
)

return figure
69 changes: 69 additions & 0 deletions autoxai/explainer/conductance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""File with Conductance algorithm explainer classes."""

from typing import Optional

import torch
from captum.attr import LayerConductance

from autoxai.explainer.base_explainer import CVExplainer


class LayerConductanceCVExplainer(CVExplainer):
"""Layer Conductance algorithm explainer."""

def create_explainer(self, **kwargs) -> LayerConductance:
"""Create explainer object.
Raises:
RuntimeError: When passed arguments are invalid.
Returns:
Explainer object.
"""
model: Optional[torch.nn.Module] = kwargs.get("model", None)
layer: Optional[torch.nn.Module] = kwargs.get("layer", None)
if model is None or layer is None:
raise RuntimeError(
f"Missing or `None` arguments `model` or `layer` passed: {kwargs}"
)

conductance = LayerConductance(forward_func=model, layer=layer)

return conductance

def calculate_features(
self,
model: torch.nn.Module,
input_data: torch.Tensor,
pred_label_idx: int,
**kwargs,
) -> torch.Tensor:
"""Generate features image with Layer Conductance algorithm explainer.
Args:
model: Any DNN model You want to use.
input_data: Input image.
pred_label_idx: Predicted label.
Returns:
Features matrix.
"""
layer: Optional[torch.nn.Module] = kwargs.get("layer", None)

conductance = self.create_explainer(model=model, layer=layer)
attributions = conductance.attribute(
input_data,
baselines=torch.rand( # pylint: disable = (no-member)
1,
input_data.shape[1],
input_data.shape[2],
input_data.shape[3],
),
target=pred_label_idx,
)
if attributions.shape[0] == 0:
raise RuntimeError(
"Error occured during attribution calculation. "
+ "Make sure You are applying this method to CNN network.",
)
return attributions
78 changes: 78 additions & 0 deletions autoxai/explainer/deconv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""File with Deconvolution algorithm explainer classes."""

from abc import abstractmethod
from typing import Optional, Union

import torch
from captum.attr import Deconvolution, NeuronDeconvolution

from autoxai.explainer.base_explainer import CVExplainer
from autoxai.explainer.model_utils import modify_modules


class BaseDeconvolutionCVExplainer(CVExplainer):
"""Base Deconvolution algorithm explainer."""

@abstractmethod
def create_explainer(self, **kwargs) -> Union[Deconvolution, NeuronDeconvolution]:
"""Create explainer object.
Raises:
RuntimeError: When passed arguments are invalid.
Returns:
Explainer object.
"""

def calculate_features(
self,
model: torch.nn.Module,
input_data: torch.Tensor,
pred_label_idx: int,
**kwargs,
) -> torch.Tensor:
"""Generate features image with Deconvolution algorithm explainer.
Args:
model: Any DNN model You want to use.
input_data: Input image.
pred_label_idx: Predicted label.
Returns:
Features matrix.
"""
layer: Optional[torch.nn.Module] = kwargs.get("layer", None)

deconv = self.create_explainer(model=model, layer=layer)
attributions = deconv.attribute(
input_data,
target=pred_label_idx,
)
if attributions.shape[0] == 0:
raise RuntimeError(
"Error occured during attribution calculation. "
+ "Make sure You are applying this method to CNN network.",
)
return attributions


class DeconvolutionCVExplainer(BaseDeconvolutionCVExplainer):
"""Base Deconvolution algorithm explainer."""

def create_explainer(self, **kwargs) -> Union[Deconvolution, NeuronDeconvolution]:
"""Create explainer object.
Raises:
RuntimeError: When passed arguments are invalid.
Returns:
Explainer object.
"""
model: Optional[torch.nn.Module] = kwargs.get("model", None)
if model is None:
raise RuntimeError(f"Missing or `None` argument `model` passed: {kwargs}")

model = modify_modules(model=model)
deconv = Deconvolution(model=model)

return deconv
102 changes: 102 additions & 0 deletions autoxai/explainer/deeplift.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""File with DeepLIFT algorithm explainer classes."""

from abc import abstractmethod
from typing import Optional, Union

import torch
from captum.attr import DeepLift, LayerDeepLift

from autoxai.explainer.base_explainer import CVExplainer
from autoxai.explainer.model_utils import modify_modules


class BaseDeepLIFTCVExplainer(CVExplainer):
"""Base DeepLIFT algorithm explainer."""

@abstractmethod
def create_explainer(self, **kwargs) -> Union[DeepLift, LayerDeepLift]:
"""Create explainer object.
Raises:
RuntimeError: When passed arguments are invalid.
Returns:
Explainer object.
"""

def calculate_features(
self,
model: torch.nn.Module,
input_data: torch.Tensor,
pred_label_idx: int,
**kwargs,
) -> torch.Tensor:
"""Generate features image with DeepLIFT algorithm explainer.
Args:
model: Any DNN model You want to use.
input_data: Input image.
pred_label_idx: Predicted label.
Returns:
Features matrix.
"""
layer: Optional[torch.nn.Module] = kwargs.get("layer", None)

deeplift = self.create_explainer(model=model, layer=layer)

attributions = deeplift.attribute(
input_data,
target=pred_label_idx,
)
if attributions.shape[0] == 0:
raise RuntimeError(
"Error occured during attribution calculation. "
+ "Make sure You are applying this method to CNN network.",
)
return attributions


class DeepLIFTCVExplainer(BaseDeepLIFTCVExplainer):
"""DeepLIFTC algorithm explainer."""

def create_explainer(self, **kwargs) -> Union[DeepLift, LayerDeepLift]:
"""Create explainer object.
Raises:
RuntimeError: When passed arguments are invalid.
Returns:
Explainer object.
"""
model: Optional[torch.nn.Module] = kwargs.get("model", None)
if model is None:
raise RuntimeError(f"Missing or `None` argument `model` passed: {kwargs}")

model = modify_modules(model)

return DeepLift(model=model)


class LayerDeepLIFTCVExplainer(BaseDeepLIFTCVExplainer):
"""Layer DeepLIFT algorithm explainer."""

def create_explainer(self, **kwargs) -> Union[DeepLift, LayerDeepLift]:
"""Create explainer object.
Raises:
RuntimeError: When passed arguments are invalid.
Returns:
Explainer object.
"""
model: Optional[torch.nn.Module] = kwargs.get("model", None)
layer: Optional[torch.nn.Module] = kwargs.get("layer", None)
if model is None or layer is None:
raise RuntimeError(
f"Missing or `None` arguments `model` or `layer` passed: {kwargs}"
)

model = modify_modules(model)

return LayerDeepLift(model=model, layer=layer)
109 changes: 109 additions & 0 deletions autoxai/explainer/deeplift_shap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""File with DeepLIFT SHAP algorithm explainer classes."""

from abc import abstractmethod
from typing import Optional, Union

import torch
from captum.attr import DeepLiftShap, LayerDeepLiftShap

from autoxai.explainer.base_explainer import CVExplainer
from autoxai.explainer.model_utils import modify_modules


class BaseDeepLIFTSHAPCVExplainer(CVExplainer):
"""Base DeepLIFT SHAP algorithm explainer."""

@abstractmethod
def create_explainer(self, **kwargs) -> Union[DeepLiftShap, LayerDeepLiftShap]:
"""Create explainer object.
Raises:
RuntimeError: When passed arguments are invalid.
Returns:
Explainer object.
"""

def calculate_features(
self,
model: torch.nn.Module,
input_data: torch.Tensor,
pred_label_idx: int,
**kwargs,
) -> torch.Tensor:
"""Generate features image with DeepLIFT SHAP algorithm explainer.
Args:
model: Any DNN model You want to use.
input_data: Input image.
pred_label_idx: Predicted label.
Returns:
Features matrix.
"""
layer: Optional[torch.nn.Module] = kwargs.get("layer", None)
number_of_samples: int = kwargs.get("number_of_samples", 100)

deeplift = self.create_explainer(model=model, layer=layer)
baselines = torch.randn( # pylint: disable = (no-member)
number_of_samples,
input_data.shape[1],
input_data.shape[2],
input_data.shape[3],
)
attributions = deeplift.attribute(
input_data,
target=pred_label_idx,
baselines=baselines,
)
if attributions.shape[0] == 0:
raise RuntimeError(
"Error occured during attribution calculation. "
+ "Make sure You are applying this method to CNN network.",
)
return attributions


class DeepLIFTSHAPCVExplainer(BaseDeepLIFTSHAPCVExplainer):
"""DeepLIFTC SHAP algorithm explainer."""

def create_explainer(self, **kwargs) -> Union[DeepLiftShap, LayerDeepLiftShap]:
"""Create explainer object.
Raises:
RuntimeError: When passed arguments are invalid.
Returns:
Explainer object.
"""
model: Optional[torch.nn.Module] = kwargs.get("model", None)
if model is None:
raise RuntimeError(f"Missing or `None` argument `model` passed: {kwargs}")

model = modify_modules(model)

return DeepLiftShap(model=model)


class LayerDeepLIFTSHAPCVExplainer(BaseDeepLIFTSHAPCVExplainer):
"""Layer DeepLIFT SHAP algorithm explainer."""

def create_explainer(self, **kwargs) -> Union[DeepLiftShap, LayerDeepLiftShap]:
"""Create explainer object.
Raises:
RuntimeError: When passed arguments are invalid.
Returns:
Explainer object.
"""
model: Optional[torch.nn.Module] = kwargs.get("model", None)
layer: Optional[torch.nn.Module] = kwargs.get("layer", None)
if model is None or layer is None:
raise RuntimeError(
f"Missing or `None` arguments `model` or `layer` passed: {kwargs}"
)

model = modify_modules(model)

return LayerDeepLiftShap(model=model, layer=layer)
Loading

0 comments on commit 4b54012

Please sign in to comment.