add shapiq backend to ShapDP
hbaniecki committed Feb 14, 2025
1 parent f3a7c37 commit 74b18bd
Showing 6 changed files with 41 additions and 11 deletions.
36 changes: 30 additions & 6 deletions effector/global_effect_shap.py
@@ -5,7 +5,7 @@
from effector.global_effect import GlobalEffectBase
import numpy as np
import shap
- import scipy
+ import shapiq
import effector.axis_partitioning as ap
import effector.utils as utils
from scipy.interpolate import interp1d
@@ -21,6 +21,7 @@ def __init__(
feature_names: Optional[List[str]] = None,
target_name: Optional[str] = None,
shap_values: Optional[np.ndarray] = None,
+ backend: str = "shap",
):
"""
Constructor of the SHAPDependence class.
@@ -70,7 +71,7 @@
nof_instances: maximum number of instances to be used for SHAP estimation.
- use "all", for using all instances.
- use `"all"`, for using all instances.
- use an `int`, for using `nof_instances` instances.
avg_output: The average output of the model.
@@ -92,8 +93,14 @@ def __init__(
- if shap values are already computed, they can be passed here
- if `None`, the SHAP values will be computed using the `shap` package
+ backend: Package to compute SHAP values
+ - use `"shap"` for the `shap` package (default)
+ - use `"shapiq"` for the `shapiq` package
"""
self.shap_values = shap_values if shap_values is not None else None
+ self.backend = backend
super(ShapDP, self).__init__(
"SHAP DP",
data,
Expand All @@ -112,13 +119,25 @@ def _fit_feature(
binning_method: Union[str, ap.Greedy, ap.Fixed] = "greedy",
centering: typing.Union[bool, str] = False,
points_for_centering: int = 30,
+ budget: int = 512
) -> typing.Dict:

data = self.data
if self.shap_values is None:
- shap_explainer = shap.Explainer(self.model, data)
- shap_explainer_fitted = shap_explainer(data)
- self.shap_values = shap_explainer_fitted.values
+ if self.backend == "shap":
+ # by default shap uses 'permutation' with 100 background samples where
+ # either max_evals = 500 or max_evals = "auto" := 10 * 2 * (nfeatures + 1)
+ # (or 'exact' when <=10 features)
+ explainer = shap.Explainer(self.model, data)
+ explanation = explainer(data, max_evals=budget)
+ self.shap_values = explanation.values
+ elif self.backend == "shapiq":
+ # by default shapiq uses 'kernelshap' with 100 background samples where budget := 2 ** nfeatures
+ explainer = shapiq.Explainer(self.model, data, index="SV", max_order=1, approximator="permutation", imputer="marginal")
+ explanations = explainer.explain_X(data, budget=budget)
+ self.shap_values = np.stack([ex.get_n_order_values(1) for ex in explanations])
+ else:
+ raise ValueError("`backend` should be either 'shap' or 'shapiq'")

# extract x and y
yy = self.shap_values[:, feature]
@@ -183,6 +202,7 @@ def fit(
centering: Union[bool, str] = True,
points_for_centering: Union[int, str] = 30,
binning_method: Union[str, ap.Greedy, ap.Fixed] = "greedy",
+ budget: int = 512
) -> None:
"""Fit the SHAP Dependence Plot to the data.
@@ -205,6 +225,10 @@
- If set to `all`, all the dataset points will be used.
+ budget: Budget to use for the approximation. Defaults to 512.
+ - Increasing the budget improves the approximation at the cost of slower computation.
+ - Decrease the budget for faster computation at the cost of approximation error.
Notes:
SHAP values are by default centered, i.e., $\sum_{i=1}^N \hat{\phi}_s(x_s^i) = 0$. This does not mean that the SHAP _curve_ is centered around zero; this happens only if the $s$-th feature of the dataset instances, i.e., the set $\{x_s^i\}_{i=1}^N$ is uniformly distributed along the $s$-th axis. So, use:
@@ -224,7 +248,7 @@
# new implementation
for s in features:
self.feature_effect["feature_" + str(s)] = self._fit_feature(
- s, binning_method, centering, points_for_centering,
+ s, binning_method, centering, points_for_centering, budget
)
self.is_fitted[s] = True
self.fit_args["feature_" + str(s)] = {
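For orientation, a minimal usage sketch of the new `backend` and `budget` options. The toy data and model below are illustrative (not from the repository), and the `ShapDP(data, model, ...)` / `fit(features=..., ...)` calling convention is assumed to follow effector's existing API; only the `backend` and `budget` keywords come from this diff.

```python
import numpy as np
import effector

# toy dataset and model (illustrative only)
np.random.seed(0)
X = np.random.uniform(-1, 1, size=(500, 3))

def model(x):
    return x[:, 0] ** 2 + x[:, 1] * x[:, 2]

# default backend: the `shap` package (permutation explainer, max_evals=budget)
shap_dp = effector.ShapDP(X, model, backend="shap")
shap_dp.fit(features="all", budget=512)

# new backend: the `shapiq` package (first-order Shapley values, index="SV")
shapiq_dp = effector.ShapDP(X, model, backend="shapiq")
shapiq_dp.fit(features="all", budget=256)  # lower budget: faster, coarser approximation
```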
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -7,6 +7,7 @@ matplotlib
scipy
pytest
shap
+ shapiq
build
mkdocs
mkdocstrings
1 change: 1 addition & 0 deletions requirements-test.txt
@@ -4,6 +4,7 @@ matplotlib
scipy
pytest
shap
+ shapiq
build
mkdocs
mkdocstrings
3 changes: 2 additions & 1 deletion tests/test_functional_gam.py
@@ -29,7 +29,8 @@ def test_gam():
{"method": effector.PDP, "kwargs": {}},
{"method": effector.DerPDP, "kwargs": {"model_jac": None}},
{"method": effector.DerPDP, "kwargs": {"model_jac": model_jac}},
{"method": effector.ShapDP, "kwargs": {}},
{"method": effector.ShapDP, "kwargs": {"backend": "shap"}},
{"method": effector.ShapDP, "kwargs": {"backend": "shapiq"}},
{"method": effector.ALE, "kwargs": {}},
{"method": effector.RHALE, "kwargs": {"model_jac": None}},
{"method": effector.RHALE, "kwargs": {"model_jac": model_jac}}
3 changes: 2 additions & 1 deletion tests/test_functional_linear.py
@@ -37,7 +37,8 @@ def test_linear():
{"method": effector.PDP, "init_kwargs": {}},
{"method": effector.DerPDP, "init_kwargs": {}},
{"method": effector.DerPDP, "init_kwargs": {"model_jac": model_jac}},
{"method": effector.ShapDP, "init_kwargs": {}},
{"method": effector.ShapDP, "init_kwargs": {"backend": "shap"}},
{"method": effector.ShapDP, "init_kwargs": {"backend": "shapiq"}},
{"method": effector.ALE, "init_kwargs": {}},
{"method": effector.RHALE, "init_kwargs": {}},
{"method": effector.RHALE, "init_kwargs": {"model_jac": model_jac}}
8 changes: 5 additions & 3 deletions tests/test_regional_methods.py
@@ -30,7 +30,7 @@ def model_jac(x):
return y

xs = np.linspace(-1, 1, T)
- methods = ["pdp", "d-pdp", "ale", "rhale", "shap"]
+ methods = ["pdp", "d-pdp", "ale", "rhale", "shap", "shapiq"]
for method in methods:
if method == "pdp":
reg_eff = effector.RegionalPDP(data, model, nof_instances=1000)
@@ -40,8 +40,10 @@ def model_jac(x):
reg_eff = effector.RegionalALE(data, model, nof_instances=1000)
elif method == "rhale":
reg_eff = effector.RegionalRHALE(data, model, model_jac, nof_instances=1000)
- else:
- reg_eff = effector.RegionalShapDP(data, model, nof_instances=1000)
+ elif method == "shap":
+ reg_eff = effector.RegionalShapDP(data, model, nof_instances=1000, backend="shap")
+ elif method == "shapiq":
+ reg_eff = effector.RegionalShapDP(data, model, nof_instances=1000, backend="shapiq")

reg_eff.fit(0)

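A corresponding regional sketch, mirroring the updated test above. The data and model are again illustrative, and `RegionalShapDP` is presumed to forward the `backend` keyword to `ShapDP`, as the test exercises.

```python
import numpy as np
import effector

# illustrative data and model with a feature interaction
np.random.seed(1)
data = np.random.uniform(-1, 1, size=(1000, 2))

def model(x):
    return np.where(x[:, 1] > 0, x[:, 0], -x[:, 0])

# regional SHAP-DP using the shapiq backend, as exercised in the test above
reg_eff = effector.RegionalShapDP(data, model, nof_instances=1000, backend="shapiq")
reg_eff.fit(0)  # detect subregions for feature 0
```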
