-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Change kwargs argument name, unify 'layer' argument name passed to explainers * Fix docstrings of create_explainer function * Add Saliency algorithm * Add DeepLIFT algorithm * Add DeepLIFT SHAP algorithm * Add Deconvolution algorithm * Add Input x Gradient algorithm * Add Layer Conductance algorithm * Add standardization of matrices, generate figure with respect to presence of negative attributions * Revert visualization error handling due to #12 PR * Mark model argument type with TODO * Remove incorrect standardization method * Bump patch version
- Loading branch information
1 parent
8a26adf
commit 4b54012
Showing
18 changed files
with
571 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
"""File with Conductance algorithm explainer classes.""" | ||
|
||
from typing import Optional | ||
|
||
import torch | ||
from captum.attr import LayerConductance | ||
|
||
from autoxai.explainer.base_explainer import CVExplainer | ||
|
||
|
||
class LayerConductanceCVExplainer(CVExplainer): | ||
"""Layer Conductance algorithm explainer.""" | ||
|
||
def create_explainer(self, **kwargs) -> LayerConductance: | ||
"""Create explainer object. | ||
Raises: | ||
RuntimeError: When passed arguments are invalid. | ||
Returns: | ||
Explainer object. | ||
""" | ||
model: Optional[torch.nn.Module] = kwargs.get("model", None) | ||
layer: Optional[torch.nn.Module] = kwargs.get("layer", None) | ||
if model is None or layer is None: | ||
raise RuntimeError( | ||
f"Missing or `None` arguments `model` or `layer` passed: {kwargs}" | ||
) | ||
|
||
conductance = LayerConductance(forward_func=model, layer=layer) | ||
|
||
return conductance | ||
|
||
def calculate_features( | ||
self, | ||
model: torch.nn.Module, | ||
input_data: torch.Tensor, | ||
pred_label_idx: int, | ||
**kwargs, | ||
) -> torch.Tensor: | ||
"""Generate features image with Layer Conductance algorithm explainer. | ||
Args: | ||
model: Any DNN model You want to use. | ||
input_data: Input image. | ||
pred_label_idx: Predicted label. | ||
Returns: | ||
Features matrix. | ||
""" | ||
layer: Optional[torch.nn.Module] = kwargs.get("layer", None) | ||
|
||
conductance = self.create_explainer(model=model, layer=layer) | ||
attributions = conductance.attribute( | ||
input_data, | ||
baselines=torch.rand( # pylint: disable = (no-member) | ||
1, | ||
input_data.shape[1], | ||
input_data.shape[2], | ||
input_data.shape[3], | ||
), | ||
target=pred_label_idx, | ||
) | ||
if attributions.shape[0] == 0: | ||
raise RuntimeError( | ||
"Error occured during attribution calculation. " | ||
+ "Make sure You are applying this method to CNN network.", | ||
) | ||
return attributions |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
"""File with Deconvolution algorithm explainer classes.""" | ||
|
||
from abc import abstractmethod | ||
from typing import Optional, Union | ||
|
||
import torch | ||
from captum.attr import Deconvolution, NeuronDeconvolution | ||
|
||
from autoxai.explainer.base_explainer import CVExplainer | ||
from autoxai.explainer.model_utils import modify_modules | ||
|
||
|
||
class BaseDeconvolutionCVExplainer(CVExplainer): | ||
"""Base Deconvolution algorithm explainer.""" | ||
|
||
@abstractmethod | ||
def create_explainer(self, **kwargs) -> Union[Deconvolution, NeuronDeconvolution]: | ||
"""Create explainer object. | ||
Raises: | ||
RuntimeError: When passed arguments are invalid. | ||
Returns: | ||
Explainer object. | ||
""" | ||
|
||
def calculate_features( | ||
self, | ||
model: torch.nn.Module, | ||
input_data: torch.Tensor, | ||
pred_label_idx: int, | ||
**kwargs, | ||
) -> torch.Tensor: | ||
"""Generate features image with Deconvolution algorithm explainer. | ||
Args: | ||
model: Any DNN model You want to use. | ||
input_data: Input image. | ||
pred_label_idx: Predicted label. | ||
Returns: | ||
Features matrix. | ||
""" | ||
layer: Optional[torch.nn.Module] = kwargs.get("layer", None) | ||
|
||
deconv = self.create_explainer(model=model, layer=layer) | ||
attributions = deconv.attribute( | ||
input_data, | ||
target=pred_label_idx, | ||
) | ||
if attributions.shape[0] == 0: | ||
raise RuntimeError( | ||
"Error occured during attribution calculation. " | ||
+ "Make sure You are applying this method to CNN network.", | ||
) | ||
return attributions | ||
|
||
|
||
class DeconvolutionCVExplainer(BaseDeconvolutionCVExplainer): | ||
"""Base Deconvolution algorithm explainer.""" | ||
|
||
def create_explainer(self, **kwargs) -> Union[Deconvolution, NeuronDeconvolution]: | ||
"""Create explainer object. | ||
Raises: | ||
RuntimeError: When passed arguments are invalid. | ||
Returns: | ||
Explainer object. | ||
""" | ||
model: Optional[torch.nn.Module] = kwargs.get("model", None) | ||
if model is None: | ||
raise RuntimeError(f"Missing or `None` argument `model` passed: {kwargs}") | ||
|
||
model = modify_modules(model=model) | ||
deconv = Deconvolution(model=model) | ||
|
||
return deconv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
"""File with DeepLIFT algorithm explainer classes.""" | ||
|
||
from abc import abstractmethod | ||
from typing import Optional, Union | ||
|
||
import torch | ||
from captum.attr import DeepLift, LayerDeepLift | ||
|
||
from autoxai.explainer.base_explainer import CVExplainer | ||
from autoxai.explainer.model_utils import modify_modules | ||
|
||
|
||
class BaseDeepLIFTCVExplainer(CVExplainer): | ||
"""Base DeepLIFT algorithm explainer.""" | ||
|
||
@abstractmethod | ||
def create_explainer(self, **kwargs) -> Union[DeepLift, LayerDeepLift]: | ||
"""Create explainer object. | ||
Raises: | ||
RuntimeError: When passed arguments are invalid. | ||
Returns: | ||
Explainer object. | ||
""" | ||
|
||
def calculate_features( | ||
self, | ||
model: torch.nn.Module, | ||
input_data: torch.Tensor, | ||
pred_label_idx: int, | ||
**kwargs, | ||
) -> torch.Tensor: | ||
"""Generate features image with DeepLIFT algorithm explainer. | ||
Args: | ||
model: Any DNN model You want to use. | ||
input_data: Input image. | ||
pred_label_idx: Predicted label. | ||
Returns: | ||
Features matrix. | ||
""" | ||
layer: Optional[torch.nn.Module] = kwargs.get("layer", None) | ||
|
||
deeplift = self.create_explainer(model=model, layer=layer) | ||
|
||
attributions = deeplift.attribute( | ||
input_data, | ||
target=pred_label_idx, | ||
) | ||
if attributions.shape[0] == 0: | ||
raise RuntimeError( | ||
"Error occured during attribution calculation. " | ||
+ "Make sure You are applying this method to CNN network.", | ||
) | ||
return attributions | ||
|
||
|
||
class DeepLIFTCVExplainer(BaseDeepLIFTCVExplainer): | ||
"""DeepLIFTC algorithm explainer.""" | ||
|
||
def create_explainer(self, **kwargs) -> Union[DeepLift, LayerDeepLift]: | ||
"""Create explainer object. | ||
Raises: | ||
RuntimeError: When passed arguments are invalid. | ||
Returns: | ||
Explainer object. | ||
""" | ||
model: Optional[torch.nn.Module] = kwargs.get("model", None) | ||
if model is None: | ||
raise RuntimeError(f"Missing or `None` argument `model` passed: {kwargs}") | ||
|
||
model = modify_modules(model) | ||
|
||
return DeepLift(model=model) | ||
|
||
|
||
class LayerDeepLIFTCVExplainer(BaseDeepLIFTCVExplainer): | ||
"""Layer DeepLIFT algorithm explainer.""" | ||
|
||
def create_explainer(self, **kwargs) -> Union[DeepLift, LayerDeepLift]: | ||
"""Create explainer object. | ||
Raises: | ||
RuntimeError: When passed arguments are invalid. | ||
Returns: | ||
Explainer object. | ||
""" | ||
model: Optional[torch.nn.Module] = kwargs.get("model", None) | ||
layer: Optional[torch.nn.Module] = kwargs.get("layer", None) | ||
if model is None or layer is None: | ||
raise RuntimeError( | ||
f"Missing or `None` arguments `model` or `layer` passed: {kwargs}" | ||
) | ||
|
||
model = modify_modules(model) | ||
|
||
return LayerDeepLift(model=model, layer=layer) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
"""File with DeepLIFT SHAP algorithm explainer classes.""" | ||
|
||
from abc import abstractmethod | ||
from typing import Optional, Union | ||
|
||
import torch | ||
from captum.attr import DeepLiftShap, LayerDeepLiftShap | ||
|
||
from autoxai.explainer.base_explainer import CVExplainer | ||
from autoxai.explainer.model_utils import modify_modules | ||
|
||
|
||
class BaseDeepLIFTSHAPCVExplainer(CVExplainer): | ||
"""Base DeepLIFT SHAP algorithm explainer.""" | ||
|
||
@abstractmethod | ||
def create_explainer(self, **kwargs) -> Union[DeepLiftShap, LayerDeepLiftShap]: | ||
"""Create explainer object. | ||
Raises: | ||
RuntimeError: When passed arguments are invalid. | ||
Returns: | ||
Explainer object. | ||
""" | ||
|
||
def calculate_features( | ||
self, | ||
model: torch.nn.Module, | ||
input_data: torch.Tensor, | ||
pred_label_idx: int, | ||
**kwargs, | ||
) -> torch.Tensor: | ||
"""Generate features image with DeepLIFT SHAP algorithm explainer. | ||
Args: | ||
model: Any DNN model You want to use. | ||
input_data: Input image. | ||
pred_label_idx: Predicted label. | ||
Returns: | ||
Features matrix. | ||
""" | ||
layer: Optional[torch.nn.Module] = kwargs.get("layer", None) | ||
number_of_samples: int = kwargs.get("number_of_samples", 100) | ||
|
||
deeplift = self.create_explainer(model=model, layer=layer) | ||
baselines = torch.randn( # pylint: disable = (no-member) | ||
number_of_samples, | ||
input_data.shape[1], | ||
input_data.shape[2], | ||
input_data.shape[3], | ||
) | ||
attributions = deeplift.attribute( | ||
input_data, | ||
target=pred_label_idx, | ||
baselines=baselines, | ||
) | ||
if attributions.shape[0] == 0: | ||
raise RuntimeError( | ||
"Error occured during attribution calculation. " | ||
+ "Make sure You are applying this method to CNN network.", | ||
) | ||
return attributions | ||
|
||
|
||
class DeepLIFTSHAPCVExplainer(BaseDeepLIFTSHAPCVExplainer): | ||
"""DeepLIFTC SHAP algorithm explainer.""" | ||
|
||
def create_explainer(self, **kwargs) -> Union[DeepLiftShap, LayerDeepLiftShap]: | ||
"""Create explainer object. | ||
Raises: | ||
RuntimeError: When passed arguments are invalid. | ||
Returns: | ||
Explainer object. | ||
""" | ||
model: Optional[torch.nn.Module] = kwargs.get("model", None) | ||
if model is None: | ||
raise RuntimeError(f"Missing or `None` argument `model` passed: {kwargs}") | ||
|
||
model = modify_modules(model) | ||
|
||
return DeepLiftShap(model=model) | ||
|
||
|
||
class LayerDeepLIFTSHAPCVExplainer(BaseDeepLIFTSHAPCVExplainer): | ||
"""Layer DeepLIFT SHAP algorithm explainer.""" | ||
|
||
def create_explainer(self, **kwargs) -> Union[DeepLiftShap, LayerDeepLiftShap]: | ||
"""Create explainer object. | ||
Raises: | ||
RuntimeError: When passed arguments are invalid. | ||
Returns: | ||
Explainer object. | ||
""" | ||
model: Optional[torch.nn.Module] = kwargs.get("model", None) | ||
layer: Optional[torch.nn.Module] = kwargs.get("layer", None) | ||
if model is None or layer is None: | ||
raise RuntimeError( | ||
f"Missing or `None` arguments `model` or `layer` passed: {kwargs}" | ||
) | ||
|
||
model = modify_modules(model) | ||
|
||
return LayerDeepLiftShap(model=model, layer=layer) |
Oops, something went wrong.