Skip to content

Commit

Permalink
Refactored - addressed some comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
RobertSamoilescu committed Apr 19, 2023
1 parent c64f96c commit 94dc01a
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 26 deletions.
77 changes: 55 additions & 22 deletions alibi/explainers/similarity/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union

import numpy as np
from alibi.api.interfaces import Explainer
from alibi.explainers.similarity.backends import _select_backend
from alibi.utils.frameworks import Framework, has_pytorch, has_tensorflow
from alibi.utils.missing_optional_dependency import import_optional
from tqdm import tqdm
from typing_extensions import Literal

from alibi.api.interfaces import Explainer
from alibi.explainers.similarity.backends import _select_backend
from alibi.utils.frameworks import Framework
_TfTensor = import_optional('tensorflow', ['Tensor'])
_PtTensor = import_optional('torch', ['Tensor'])

if TYPE_CHECKING:
import tensorflow
Expand Down Expand Up @@ -90,13 +93,31 @@ def fit(self,
grads = []
X: Union[np.ndarray, List[Any]]
for X, Y in tqdm(zip(self.X_train, self.Y_train), disable=not self.verbose):
X = X[None] if isinstance(self.X_train, np.ndarray) else [X] # type: ignore[call-overload]
grad_X_train = self._compute_grad(X, Y[None])
grad_X_train = self._compute_grad(self._format(X), Y[None])
grads.append(grad_X_train[None])

self.grad_X_train = np.concatenate(grads, axis=0)
return self

@staticmethod
def _is_tensor(x: Any) -> bool:
    """Check whether `x` is a tensor-like object.

    An object counts as tensor-like if it is a `tensorflow` tensor (when
    `tensorflow` is installed), a `torch` tensor (when `torch` is installed),
    or a `numpy` array.
    """
    tf_match = has_tensorflow and isinstance(x, _TfTensor)
    pt_match = has_pytorch and isinstance(x, _PtTensor)
    return bool(tf_match or pt_match or isinstance(x, np.ndarray))

@staticmethod
def _format(x: 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor, Any]'
            ) -> 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor, List[Any]]':
    """Add a batch dimension to `x`.

    Tensor-like inputs receive a leading axis via ``x[None]``; any other
    object is wrapped in a single-element list.
    """
    return x[None] if BaseSimilarityExplainer._is_tensor(x) else [x]

def _verify_fit(self) -> None:
"""Verify that the explainer has been fitted.
Expand All @@ -105,14 +126,15 @@ def _verify_fit(self) -> None:
ValueError
If the explainer has not been fitted.
"""

if not hasattr(self, 'X_train') or not hasattr(self, 'Y_train'):
raise ValueError('Training data not set. Call `fit` and pass training data first.')

def _match_shape_to_data(self,
data: 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor]',
target_type: Literal['X', 'Y']) -> 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor]':
"""Verify the shape of `data` against the shape of the training data.
data: 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor, Any, List[Any]]',
target_type: Literal['X', 'Y']
) -> 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor, List[Any]]':
"""
Verify the shape of `data` against the shape of the training data.
Used to ensure input is correct shape for gradient methods implemented in the backends. `data` will be the
features or label of the instance being explained. If the `data` is not a batch, reshape to be a single batch
Expand All @@ -134,18 +156,30 @@ def _match_shape_to_data(self,
If the shape of `data` does not match the shape of the training data, or fit has not been called prior to
calling this method.
"""
if hasattr(data, 'shape'):
target_shape = getattr(self, f'{target_type}_dims')
if data.shape == target_shape:
data = data[None]
if data.shape[1:] != target_shape:
raise ValueError((f'Input `{target_type}` has shape {data.shape[1:]}'
f' but training data has shape {target_shape}'))
elif not isinstance(data, list):
data = [data]

if self._is_tensor(data):
return self._match_shape_to_data_tensor(data, target_type)
return self._match_shape_to_data_any(data)

def _match_shape_to_data_tensor(self,
data: 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor]',
target_type: Literal['X', 'Y']
) -> 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor]':
""" Verify the shape of `data` against the shape of the training data for tensor like data."""
target_shape = getattr(self, f'{target_type}_dims')
if data.shape == target_shape:
data = data[None]
if data.shape[1:] != target_shape:
raise ValueError((f'Input `{target_type}` has shape {data.shape[1:]}'
f' but training data has shape {target_shape}'))
return data

@staticmethod
def _match_shape_to_data_any(data: Union[Any, List[Any]]) -> list:
""" Ensures that any other data type is a list."""
if isinstance(data, list):
return data
return [data]

def _compute_adhoc_similarity(self, grad_X: np.ndarray) -> np.ndarray:
"""
Computes the similarity between the gradients of the test instances and all the training instances. The method
Expand All @@ -159,16 +193,15 @@ def _compute_adhoc_similarity(self, grad_X: np.ndarray) -> np.ndarray:
scores = np.zeros((len(grad_X), len(self.X_train)))
X: Union[np.ndarray, List[Any]]
for i, (X, Y) in tqdm(enumerate(zip(self.X_train, self.Y_train)), disable=not self.verbose):
X = X[None] if hasattr(self.X_train, 'shape') else [X] # type: ignore[call-overload]
grad_X_train = self._compute_grad(X, Y[None])
grad_X_train = self._compute_grad(self._format(X), Y[None])
scores[:, i] = self.sim_fn(grad_X, grad_X_train[None])[:, 0]
return scores

def _compute_grad(self,
                  X: 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor, List[Any]]',
                  Y: 'Union[np.ndarray, tensorflow.Tensor, torch.Tensor]') \
        -> np.ndarray:
    """Computes predictor parameter gradients and returns a flattened `numpy` array.

    `numpy` inputs are first converted to backend-native tensors; inputs that
    are already tensors (or lists) are passed through unchanged.
    """
    if isinstance(X, np.ndarray):
        X = self.backend.to_tensor(X)
    if isinstance(Y, np.ndarray):
        Y = self.backend.to_tensor(Y)
    return self.backend.get_grads(self.predictor, X, Y, self.loss_fn)
Expand Down
6 changes: 2 additions & 4 deletions alibi/explainers/similarity/grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,13 @@
Union)

import numpy as np
from typing_extensions import Literal

from alibi.api.defaults import DEFAULT_DATA_SIM, DEFAULT_META_SIM
from alibi.api.interfaces import Explainer, Explanation
from alibi.explainers.similarity.base import BaseSimilarityExplainer
from alibi.explainers.similarity.metrics import asym_dot, cos, dot
from alibi.utils import _get_options_string
from alibi.utils.frameworks import Framework
from typing_extensions import Literal

if TYPE_CHECKING:
import tensorflow
Expand Down Expand Up @@ -250,8 +249,7 @@ def explain(
X, Y = self._preprocess_args(X, Y)
test_grads = []
for x, y in zip(X, Y):
x = x[None] if hasattr(x, 'shape') else [x]
test_grads.append(self._compute_grad(x, y[None])[None])
test_grads.append(self._compute_grad(self._format(x), y[None])[None])
grads_X_test = np.concatenate(np.array(test_grads), axis=0)
if not self.precompute_grads:
scores = self._compute_adhoc_similarity(grads_X_test)
Expand Down

0 comments on commit 94dc01a

Please sign in to comment.