adds the TabPFNExplainer and closes #301
mmschlk committed Jan 14, 2025
1 parent f5f08a1 commit 01d827e
Showing 14 changed files with 400 additions and 41 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -21,3 +21,4 @@ xgboost==2.1.3
numpy==1.26.4
requests==2.32.3
lightgbm==4.5.0
tabpfn==2.0.3
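
The new pin can be checked at runtime with the standard library; a minimal sketch (assumes the dependency above is installed):

from importlib.metadata import version

print(version("tabpfn"))  # expected to print 2.0.3 per the pin above
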
6 changes: 4 additions & 2 deletions shapiq/__init__.py
@@ -39,14 +39,14 @@
from .datasets import load_adult_census, load_bike_sharing, load_california_housing

# explainer classes
from .explainer import Explainer, TabularExplainer, TreeExplainer
from .explainer import Explainer, TabPFNExplainer, TabularExplainer, TreeExplainer

# exact computer classes
from .game_theory.exact import ExactComputer

# game classes
# imputer classes
from .games import BaselineImputer, ConditionalImputer, Game, MarginalImputer
from .games import BaselineImputer, ConditionalImputer, Game, MarginalImputer, TabPFNImputer

# base classes
from .interaction_values import InteractionValues
@@ -97,10 +97,12 @@
"Explainer",
"TabularExplainer",
"TreeExplainer",
"TabPFNExplainer",
# imputers
"MarginalImputer",
"BaselineImputer",
"ConditionalImputer",
"TabPFNImputer",
# plots
"network_plot",
"stacked_bar_plot",
3 changes: 2 additions & 1 deletion shapiq/explainer/__init__.py
@@ -1,7 +1,8 @@
"""Explainer objects, including TreeSHAP-IQ."""

from ._base import Explainer
from .tabpfn import TabPFNExplainer
from .tabular import TabularExplainer
from .tree import TreeExplainer

__all__ = ["Explainer", "TabularExplainer", "TreeExplainer"]
__all__ = ["Explainer", "TabularExplainer", "TreeExplainer", "TabPFNExplainer"]
50 changes: 36 additions & 14 deletions shapiq/explainer/_base.py
@@ -1,6 +1,7 @@
"""The base Explainer classes for the shapiq package."""

from typing import Optional
from warnings import warn

import numpy as np

@@ -32,24 +33,14 @@ def __init__(
) -> None:

self._model_class = print_class(model)
self._predict_function, self._model_type = get_predict_function_and_model_type(
self._shapiq_predict_function, self._model_type = get_predict_function_and_model_type(
model, self._model_class, class_index
)
self.model = model

if data is not None:
if not isinstance(data, np.ndarray):
raise TypeError("`data` must be a NumPy array.")
try:
pred = self.predict(data)
if isinstance(pred, np.ndarray):
if len(pred.shape) > 1:
raise ValueError()
else:
raise ValueError()
except Exception as e:
print(f"Error: The `data` provided is not compatible with the model. {e}")
pass
if self._model_type != "tabpfn":
self._validate_data(data)
self.data = data

# not super()
@@ -59,6 +50,37 @@ def __init__(
self.__class__ = _explainer
_explainer.__init__(self, model=model, data=data, class_index=class_index, **kwargs)

def _validate_data(self, data: np.ndarray, raise_error: bool = False) -> None:
"""Validate the data for compatibility with the model.
Args:
data: A 2-dimensional matrix of inputs to be explained.
raise_error: Whether to raise an error if the data is not compatible with the model or
only emit a warning. Defaults to ``False``.
Raises:
TypeError: If the data is not a NumPy array.
ValueError: If the model's prediction on the data is not a 1-dimensional NumPy array
and ``raise_error`` is ``True``.
"""
message = "The `data` and the model must be compatible."
if not isinstance(data, np.ndarray):
message += " The `data` must be a NumPy array."
raise TypeError(message)
try:
# TODO (mmschlk): This can take a long time for large datasets and slow models
pred = self.predict(data)
if isinstance(pred, np.ndarray):
if len(pred.shape) > 1:
message += " The model's prediction must be a 1-dimensional array."
raise ValueError()
else:
message += " The model's prediction must be a NumPy array."
raise ValueError()
except Exception as e:
if raise_error:
raise ValueError(message) from e
else:
warn(message)

def explain(self, x: np.ndarray) -> InteractionValues:
"""Explain the model's prediction in terms of interaction values.
@@ -104,4 +126,4 @@ def predict(self, x: np.ndarray) -> np.ndarray:
Args:
x: An instance/point/sample/observation to be explained.
"""
return self._predict_function(self.model, x)
return self._shapiq_predict_function(self.model, x)
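
The net effect of this refactor is that incompatible data now produces a warning by default instead of a bare print, with an opt-in hard failure. A standalone sketch of the same control flow (an illustrative re-implementation, not the shapiq source):

import warnings

import numpy as np


def validate_data_sketch(predict, data, raise_error=False):
    """Mirror of the diff's logic: type-check first, then probe the model once."""
    message = "The `data` and the model must be compatible."
    if not isinstance(data, np.ndarray):
        raise TypeError(message + " The `data` must be a NumPy array.")
    try:
        pred = predict(data)
        # the prediction must be a 1-dimensional NumPy array
        if not isinstance(pred, np.ndarray) or pred.ndim > 1:
            raise ValueError(message)
    except Exception as error:
        if raise_error:
            raise ValueError(message) from error
        warnings.warn(message)


validate_data_sketch(lambda x: x.sum(axis=1), np.ones((3, 2)))  # ok: 1-D output
validate_data_sketch(lambda x: x, np.ones((3, 2)))  # 2-D output -> UserWarning
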
66 changes: 61 additions & 5 deletions shapiq/explainer/tabpfn.py
@@ -7,6 +7,7 @@

from ..approximator._base import Approximator
from .tabular import TabularExplainer
from .utils import ModelType, get_predict_function_and_model_type


class TabPFNExplainer(TabularExplainer):
@@ -21,6 +22,41 @@ class TabPFNExplainer(TabularExplainer):
Args:
model: Either a TabPFNClassifier or TabPFNRegressor model to be explained.
data: The background data to use for the explainer as a 2-dimensional array with shape
``(n_samples, n_features)``. This data is used to contextualize the model.
labels: The labels for the background data as a 1-dimensional array with shape
``(n_samples,)``. These labels are used to contextualize the model.
index: The index to explain the model with. Defaults to ``"k-SII"`` which computes the
k-Shapley Interaction Index. If ``max_order`` is set to 1, this corresponds to the
Shapley value (``index="SV"``). Options are:
- ``"SV"``: Shapley value
- ``"k-SII"``: k-Shapley Interaction Index
- ``"FSII"``: Faithful Shapley Interaction Index
- ``"STII"``: Shapley Taylor Interaction Index
- ``"SII"``: Shapley Interaction Index
x_test: An optional test data set to compute the model's empty prediction (average
prediction) on. If no test data is provided and ``empty_prediction`` is ``None``, the
last 20% of the background data is used as test data and the remaining 80% as training
data for contextualization. Defaults to ``None``.
empty_prediction: Optional value for the model's average prediction on an empty data point
(all features missing). If provided, this value overrides ``x_test`` and skips the
computation of the empty prediction. Defaults to ``None``.
class_index: The class index of the model to explain. Defaults to ``None``, which sets
the class index to ``1`` by default for classification models; it is ignored for
regression models.
approximator: The approximator to use for calculating the Shapley values or Shapley
interactions. Can be a string or an instance of an approximator. Defaults to ``"auto"``.
verbose: Whether to show a progress bar during the computation. Defaults to ``False``.
Note that verbosity can slow down the computation for large datasets.
References:
.. [1] Rundel, D., Kobialka, J., von Crailsheim, C., Feurer, M., Nagler, T., Rügamer, D. (2024). Interpretable Machine Learning for TabPFN. In: Longo, L., Lapuschkin, S., Seifert, C. (eds) Explainable Artificial Intelligence. xAI 2024. Communications in Computer and Information Science, vol 2154. Springer, Cham. https://doi.org/10.1007/978-3-031-63797-1_23
.. [2] Hollmann, N., Müller, S., Purucker, L. et al. Accurate predictions on small data with a tabular foundation model. Nature 637, 319–326 (2025). https://doi.org/10.1038/s41586-024-08328-6
@@ -29,24 +65,44 @@ class TabPFNExplainer(TabularExplainer):
def __init__(
self,
*,
model,
x_train,
y_train,
model: ModelType,
data: np.ndarray,
labels: np.ndarray,
index: str = "k-SII",
max_order: int = 2,
x_test: Optional[np.ndarray] = None,
empty_prediction: Optional[float] = None,
class_index: Optional[int] = None,
approximator: Union[str, Approximator] = "auto",
index: str = "k-SII",
max_order: int = 2,
verbose: bool = False,
):
from ..games.imputer.tabpfn_imputer import TabPFNImputer

_predict_function, _ = get_predict_function_and_model_type(model, class_index=class_index)
model._shapiq_predict_function = _predict_function

# check that data and labels have the same number of samples
if data.shape[0] != labels.shape[0]:
raise ValueError(
f"The number of samples in `data` and `labels` must be equal (got data.shape= "
f"{data.shape} and labels.shape={labels.shape})."
)
n_samples = data.shape[0]
x_train = data
y_train = labels

if x_test is None and empty_prediction is None:
sections = [int(0.8 * n_samples)]
x_train, x_test = np.split(data, sections)
y_train, _ = np.split(labels, sections)
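# e.g. n_samples = 100 -> sections = [80]: rows 0-79 contextualize the model, rows 80-99 form x_test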

imputer = TabPFNImputer(
model=model,
x_train=x_train,
y_train=y_train,
x_test=x_test,
empty_prediction=empty_prediction,
verbose=verbose,
)

super().__init__(
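
A hedged usage sketch of the new keyword-only signature (the synthetic data and variable names are illustrative; fitting the TabPFN model is assumed to be handled internally via the TabPFNImputer):

import numpy as np
from tabpfn import TabPFNClassifier

from shapiq import TabPFNExplainer

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 6))
y = (X[:, 0] + X[:, 1] > 0).astype(int)

explainer = TabPFNExplainer(
    model=TabPFNClassifier(),
    data=X,  # background data; the last 20% becomes x_test since x_test=None
    labels=y,
    index="k-SII",
    max_order=2,
)
interaction_values = explainer.explain(X[0])
print(interaction_values)
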
13 changes: 13 additions & 0 deletions shapiq/explainer/tabular.py
@@ -2,6 +2,7 @@

import warnings
from typing import Optional, Union
from warnings import warn

import numpy as np

@@ -105,6 +106,7 @@ def __init__(
index: str = "k-SII",
max_order: int = 2,
random_state: Optional[int] = None,
verbose: bool = False,
**kwargs,
) -> None:
from shapiq.games.imputer import (
@@ -119,6 +121,16 @@

super().__init__(model, data, class_index)

# warn if a TabPFN model is used with the plain TabularExplainer
class_name = self.__class__.__name__
if self._model_type == "tabpfn" and class_name == "TabularExplainer":
warn(
"You are using a TabPFN model with the ``shapiq.TabularExplainer`` directly. This "
"is not recommended as it uses missing value imputation and not contextualization. "
"Consider using the ``shapiq.TabPFNExplainer`` instead. For more information see "
"the documentation and the example notebooks."
)

self._random_state = random_state
if imputer == "marginal":
self._imputer = MarginalImputer(
@@ -146,6 +158,7 @@ def __init__(
f"object."
)
self._n_features: int = self.data.shape[1]
self._imputer.verbose = verbose # set the verbose flag for the imputer

self.index = index
self._max_order: int = max_order
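
A sketch of the new guard in action (assumes a fitted TabPFN model; the warning text is taken from the diff above):

import warnings

import numpy as np
from tabpfn import TabPFNClassifier

from shapiq import TabularExplainer

X = np.random.rand(60, 4)
y = (X[:, 0] > 0.5).astype(int)
model = TabPFNClassifier()
model.fit(X, y)  # assumption: the plain TabularExplainer expects a ready-to-predict model

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    TabularExplainer(model=model, data=X)  # marginal imputation, hence the warning
for w in caught:
    print(w.message)  # recommends shapiq.TabPFNExplainer instead
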
35 changes: 33 additions & 2 deletions shapiq/explainer/utils.py
@@ -21,17 +21,42 @@ def get_explainers() -> dict[str, Any]:
Returns:
A dictionary of all available explainer classes.
"""
from shapiq.explainer.tabpfn import TabPFNExplainer
from shapiq.explainer.tabular import TabularExplainer
from shapiq.explainer.tree.explainer import TreeExplainer

return {"tabular": TabularExplainer, "tree": TreeExplainer}
return {"tabular": TabularExplainer, "tree": TreeExplainer, "tabpfn": TabPFNExplainer}


def get_predict_function_and_model_type(
model: ModelType, model_class: str, class_index: Optional[int] = None
model: ModelType,
model_class: Optional[str] = None,
class_index: Optional[int] = None,
) -> tuple[Callable[[ModelType, np.ndarray], np.ndarray], str]:
"""Get the predict function and model type for a given model.
The prediction function is used in the explainer to predict the model's output for a given data
point. The function has the following signature: ``predict_function(model, data)``.
Args:
model: The model to explain. Can be any model object or callable function. We try to infer
the model type from the model object.
model_class: The class of the model as a string. If not provided, it is inferred from
the model object.
class_index: The class index of the model to explain. Defaults to ``None``, which sets
the class index to ``1`` by default for classification models; it is ignored for
regression models.
Returns:
A tuple of the predict function and the model type.
"""
from . import tree

if model_class is None:
model_class = print_class(model)

_model_type = "tabular" # default
_predict_function = None

@@ -96,6 +121,12 @@ def get_predict_function_and_model_type(
_model_type = "tabular"
_predict_function = predict_tensorflow

if model_class in [
"tabpfn.classifier.TabPFNClassifier",
"tabpfn.regressor.TabPFNRegressor",
]:
_model_type = "tabpfn"

# default extraction (sklearn api)
if _predict_function is None and hasattr(model, "predict_proba"):
_predict_function = predict_proba
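
Since model_class is now optional, the model type can be resolved from the model object alone. A brief sketch (the class-path string matches the list above):

from tabpfn import TabPFNClassifier

from shapiq.explainer.utils import get_predict_function_and_model_type

_, model_type = get_predict_function_and_model_type(TabPFNClassifier())
assert model_type == "tabpfn"  # routed to TabPFNExplainer via get_explainers()
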
4 changes: 2 additions & 2 deletions shapiq/games/__init__.py
@@ -2,8 +2,8 @@

# from . import benchmark # not imported here to avoid circular imports and long import times
from .base import Game
from .imputer import BaselineImputer, ConditionalImputer, MarginalImputer
from .imputer import BaselineImputer, ConditionalImputer, MarginalImputer, TabPFNImputer

__all__ = ["Game", "MarginalImputer", "ConditionalImputer", "BaselineImputer"]
__all__ = ["Game", "MarginalImputer", "ConditionalImputer", "BaselineImputer", "TabPFNImputer"]

# Path: shapiq/games/__init__.py
18 changes: 15 additions & 3 deletions shapiq/games/imputer/base.py
@@ -15,15 +15,23 @@ class Imputer(Game):
Args:
model: The model to explain as a callable function expecting data points as input and
returning the model's predictions.
data: The background data to use for the explainer as a 2-dimensional array
with shape ``(n_samples, n_features)``.
x: The explanation point to use the imputer on either as a 2-dimensional array with
shape ``(1, n_features)`` or as a vector with shape ``(n_features,)``.
sample_size: The number of samples to draw from the background data. Defaults to ``100``
but is usually overwritten in the subclasses.
categorical_features: A list of indices of the categorical features in the background data.
random_state: The random state to use for sampling. Defaults to ``None``.
verbose: A flag to enable verbose imputation, which will print a progress bar for model
evaluation. Note that this can slow down the imputation process. Defaults to ``False``.
Attributes:
n_features: The number of features in the data (equals the number of players in the game).
data: The background data to use for the imputer.
@@ -45,11 +53,15 @@ def __init__(
sample_size: int = 100,
categorical_features: list[int] = None,
random_state: Optional[int] = None,
verbose: bool = False,
) -> None:
if callable(model) and not hasattr(model, "_predict_function"):
self._predict_function = utils.predict_callable
else: # shapiq.Explainer adds a predict function to the model to make it callable
self._predict_function = model._predict_function
# shapiq.Explainer adds a _shapiq_predict_function to the model to make it callable
elif hasattr(model, "_shapiq_predict_function"):
self._predict_function = model._shapiq_predict_function
else:
raise ValueError("The model must be callable or have a predict function.")
self.model = model
# check if data is a vector
if data.ndim == 1:
@@ -69,7 +81,7 @@ def __init__(

# init the game
# developer note: the normalization_value needs to be set in the subclass
super().__init__(n_players=self.n_features, normalize=False)
super().__init__(n_players=self.n_features, normalize=False, verbose=verbose)

@property
def x(self) -> Optional[np.ndarray]:
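
The constructor now resolves the predict function in three explicit branches instead of an unchecked else. A standalone sketch of that branch order (illustrative re-implementation; the inner predict_callable stands in for shapiq.utils.predict_callable, assumed to simply call the model):

def resolve_predict_function(model):
    """Mirror of the new branch order in Imputer.__init__ (attribute names as in the diff)."""

    def predict_callable(m, x):  # stand-in for utils.predict_callable
        return m(x)

    if callable(model) and not hasattr(model, "_predict_function"):
        return predict_callable  # a bare callable such as a plain function
    elif hasattr(model, "_shapiq_predict_function"):
        return model._shapiq_predict_function  # attached by shapiq.Explainer
    raise ValueError("The model must be callable or have a predict function.")
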