Included ALE plot tests. #816

Merged
206 changes: 203 additions & 3 deletions alibi/explainers/tests/test_ale.py
@@ -1,9 +1,16 @@
-import pytest
-from pytest_lazyfixture import lazy_fixture
+import re
+from copy import deepcopy
 
+import matplotlib.pyplot as plt
 import numpy as np
+import pytest
 from numpy.testing import assert_allclose
-from alibi.explainers.ale import ale_num, adaptive_grid, get_quantiles, minimum_satisfied
+from pytest_lazyfixture import lazy_fixture
+
+from alibi.api.defaults import DEFAULT_DATA_ALE, DEFAULT_META_ALE
+from alibi.api.interfaces import Explanation
+from alibi.explainers.ale import (_plot_one_ale_num, adaptive_grid, ale_num,
+                                  get_quantiles, minimum_satisfied, plot_ale)


@pytest.mark.parametrize('min_bin_points', [1, 4, 10])
@@ -212,3 +219,196 @@ def test_grid_points_stress(num_bins, perc_bins, size_data, outside_grid):
    selected_bins = np.unique(selected_bins)
    expected_q = np.array([grid[selected_bins[0]]] + [grid[b + 1] for b in selected_bins])
    np.testing.assert_allclose(q, expected_q)


@pytest.fixture(scope='module')
def explanation():
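    """ A mock two-feature, two-target ALE explanation used by the plotting tests below. """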
    meta = deepcopy(DEFAULT_META_ALE)
    data = deepcopy(DEFAULT_DATA_ALE)

    params = {
        'check_feature_resolution': True,
        'low_resolution_threshold': 10,
        'extrapolate_constant': True,
        'extrapolate_constant_perc': 10.0,
        'extrapolate_constant_min': 0.1,
        'min_bin_points': 4
    }

    meta.update(name='ALE', params=params)

    data['ale_values'] = [
        np.array([[0.19286464, -0.19286464],
                  [0.19154015, -0.19154015],
                  [0.10176856, -0.10176856],
                  [-0.27236635, 0.27236635],
                  [-0.61915417, 0.61915417]]),

        np.array([[0.02562475, -0.02562475],
                  [0.01918759, -0.01918759],
                  [0.01825016, -0.01825016],
                  [-0.0422308, 0.0422308],
                  [-0.060851, 0.060851]])
    ]
    data['constant_value'] = 0.5
    data['ale0'] = [
        np.array([-0.19286464, 0.19286464]),
        np.array([-0.02562475, 0.02562475])
    ]
    data['feature_values'] = [
        np.array([-1.1492519, -1.13838482, -0.4544552, 1.00972385, 2.66536596]),
        np.array([-1.16729904, -0.96888455, -0.94992442, 0.72328773, 0.98771278])
    ]
    data['feature_names'] = np.array(['f_0', 'f_1'], dtype='<U3')
    data['target_names'] = np.array(['c_0', 'c_1'], dtype='<U3')
    data['feature_deciles'] = [
        np.array([-1.1492519, -1.14490507, -1.14055824, -1.00159889, -0.72802705, -0.4544552, 0.13121642,
                  0.71688804, 1.34085227, 2.00310912, 2.66536596]),
        np.array([-1.16729904, -1.08793324, -1.00856744, -0.96509252, -0.95750847, -0.94992442, -0.28063956,
                  0.3886453, 0.77617274, 0.88194276, 0.98771278])
    ]
    return Explanation(meta=meta, data=data)


@pytest.mark.parametrize('constant', [False, True])
def test__plot_one_ale_num(explanation, constant):
""" Test the `_plot_one_ale_num` function. """
feature = 0
targets = [0, 1]

fig, ax = plt.subplots()
ax = _plot_one_ale_num(exp=explanation,
feature=feature,
targets=targets,
constant=constant,
ax=ax,
legend=True,
line_kw={'label': None})

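    # ax.lines[0] is assumed to be the zero-reference line drawn by `_plot_one_ale_num`;
    # lines 1 and 2 are the ALE curves for the two requested targets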
    x1, y1 = ax.lines[1].get_xydata().T
    x2, y2 = ax.lines[2].get_xydata().T

    assert np.allclose(x1, explanation.data['feature_values'][feature])
    assert np.allclose(x2, explanation.data['feature_values'][feature])

    expected_ale_values = explanation.data['ale_values'][feature] + constant * explanation.data['constant_value']
    assert np.allclose(y1, expected_ale_values[:, targets[0]])
    assert np.allclose(y2, expected_ale_values[:, targets[1]])

    assert ax.get_legend().texts[0].get_text() == f'c_{targets[0]}'
    assert ax.get_legend().texts[1].get_text() == f'c_{targets[1]}'
    assert ax.get_xlabel() == f'f_{feature}'

    # extract the deciles from the plot; they are drawn as a collection of vertical
    # segments, so the x-coordinate of each segment's first point is a decile
    segments = ax.collections[0].get_segments()
    deciles = [segment[0][0] for segment in segments]
    assert np.allclose(deciles, explanation.data['feature_deciles'][feature][1:])


@pytest.mark.parametrize('feats', [['f5']])
def test_plot_ale_features_error(feats, explanation):
""" Test if an error is raised when the name of the feature does not exist. """
with pytest.raises(ValueError) as err:
plot_ale(exp=explanation, features=feats)
assert f"Feature name {feats[0]} does not exist." == str(err.value)


@pytest.mark.parametrize('feats', [[0], [0, 1], ['f_0'], ['f_0', 'f_1'], 'all'])
def test_plot_ale_features(feats, explanation, mocker):
""" Test if `plot_ale` returns the expected number of plots given by the number of features. """
    m = mocker.patch('alibi.explainers.ale._plot_one_ale_num')
    axes = plot_ale(exp=explanation, features=feats).ravel()

    if feats == 'all':
        expected_features = list(range(len(explanation.data['feature_names'])))
    else:
        expected_features = [int(re.findall(r'\d+', f)[0]) if isinstance(f, str) else f for f in feats]

    call_features = [kwargs['feature'] for _, kwargs in m.call_args_list]
    assert np.allclose(call_features, expected_features)
    assert len(axes) == len(expected_features)


@pytest.mark.parametrize('targets', [['c_5']])
def test_plot_ale_targets_error(targets, explanation):
""" Test if an error is raised when the name of the target does not exist. """
with pytest.raises(ValueError) as err:
plot_ale(exp=explanation, targets=targets)
assert f"Target name {targets[0]} does not exist." == str(err.value)


@pytest.mark.parametrize('targets', [[0], [0, 1], ['c_0'], ['c_0', 'c_1'], 'all'])
def test_plot_ale_targets(targets, explanation, mocker):
""" Test if `plot_ale` plots all the given targets. """
m = mocker.patch('alibi.explainers.ale._plot_one_ale_num')
plot_ale(exp=explanation, targets=targets).ravel()

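    # `call_args` holds the arguments of the most recent call; the same `targets` list
    # is expected to be passed for every plotted feature, so the last call suffices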
    _, kwargs = m.call_args
    call_targets = kwargs['targets']

    if targets == 'all':
        expected_targets = list(range(len(explanation.data['target_names'])))
    else:
        expected_targets = [int(re.findall(r'\d+', t)[0]) if isinstance(t, str) else t for t in targets]

    assert np.allclose(call_targets, expected_targets)


@pytest.mark.parametrize('n_cols', [1, 2, 3])
def test_plot_ale_n_cols(n_cols, explanation, mocker):
""" Test if the number of plot columns matches the expectation. """
mocker.patch('alibi.explainers.ale._plot_one_ale_num')
axes = plot_ale(exp=explanation, features=[0, 0, 0], n_cols=n_cols)
assert axes.shape[-1] == n_cols


def test_plot_ale_sharey_all(explanation):
""" Test if all axes have the same y limits when ``sharey='all'``. """
axes = plot_ale(exp=explanation, features=[0, 1], n_cols=1, sharey='all').ravel()
assert len(set([ax.get_ylim() for ax in axes])) == 1


@pytest.mark.parametrize('n_cols', [1, 2])
def test_plot_ale_sharey_row(n_cols, explanation):
""" Test if all axes on the same rows have the same y limits and axes on different rows have different y limits
when ``sharey=row``. """
axes = plot_ale(exp=explanation, features=[0, 1], n_cols=n_cols, sharey='row')

if n_cols == 1:
# different rows should have different y-limits
assert axes[0, 0].get_ylim() != axes[1, 0].get_ylim()
else:
# same row should have the same y-limits
assert axes[0, 0].get_ylim() == axes[0, 1].get_ylim()


@pytest.mark.parametrize('n_cols', [1, 2])
def test_plot_ale_sharey_none(n_cols, explanation):
""" Test if all axes have different y limits when ``sharey=None``. """
axes = plot_ale(exp=explanation, features=[0, 1], n_cols=n_cols, sharey=None).ravel()
assert axes[0].get_ylim() != axes[1].get_ylim()


def test_plot_ale_axes_error(explanation):
""" Test if an error is raised when the number of provided axes is less that the number of features. """
import matplotlib.pyplot as plt
fig, axes = plt.subplots(nrows=1, ncols=2)
feats = [0, 0, 0]

with pytest.raises(ValueError) as err:
plot_ale(exp=explanation, features=feats, ax=axes)
assert f"Expected ax to have {len(features)} axes, got {axes.size}" == str(err.value)


@pytest.mark.parametrize('label', [None, ['target_1', 'target_2']])
def test_plot_ale_legend(label, explanation):
""" Test if the legend is displayed only for the first ax object with the expected text. """
axes = plot_ale(exp=explanation, line_kw={'label': label}).ravel()
assert axes[0].get_legend() is not None
assert all([ax.get_legend() is None for ax in axes[1:]])

texts = [text.get_text() for text in axes[0].get_legend().get_texts()]
if label is None:
assert texts == explanation.data['target_names'].tolist()
else:
assert texts == label
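
A minimal sketch of the workflow these tests exercise, for context (the scikit-learn model and synthetic data are illustrative assumptions, not part of this diff):

import numpy as np
from sklearn.linear_model import LogisticRegression
from alibi.explainers import ALE, plot_ale

# toy binary-classification problem with two features
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 2))
y = (X[:, 0] + X[:, 1] > 0).astype(int)
model = LogisticRegression().fit(X, y)

# explain the class probabilities, then plot both features with a shared y-axis,
# covering the `features`, `targets`, `n_cols` and `sharey` parameters tested above
ale = ALE(model.predict_proba, feature_names=['f_0', 'f_1'], target_names=['c_0', 'c_1'])
exp = ale.explain(X)
axes = plot_ale(exp, features='all', targets='all', n_cols=2, sharey='all')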
1 change: 1 addition & 0 deletions requirements/dev.txt
@@ -10,6 +10,7 @@ pytest-xdist>=1.28.0, <3.0.0 # for distributed testing, currently unused (see se
 pytest-lazy-fixture>=0.6.3, <0.7.0
 pytest-custom_exit_code>=0.3.0 # for notebook tests
 pytest-timeout>=1.4.2, <3.0.0 # for notebook tests
+pytest-mock>=3.10.0, <4.0.0
 jupytext>=1.12.0, <2.0.0 # for notebook tests
 ipykernel>=5.1.0, <7.0.0 # for notebook tests
 nbconvert>=6.0.7, <8.0.0 # for notebook tests