Skip to content

Commit

Permalink
Merge pull request #684 from dianna-ai/split_tests_modularize
Browse files Browse the repository at this point in the history
676 Split tests for modularized code
  • Loading branch information
Yang authored Jan 10, 2024
2 parents e5c9930 + 6b7df83 commit 25e6f9d
Show file tree
Hide file tree
Showing 7 changed files with 213 additions and 189 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Unit tests for KernelSHAP image."""
from unittest import TestCase
import numpy as np
from dianna.methods.kernelshap_image import KERNELSHAPImage


class ShapOnImages(TestCase):
"""Suite of Kernelshap tests for the image case."""

def test_shap_segment_image(self):
"""Test if the segmentation of images are correct given some data."""
input_data = np.random.random((28, 28, 1))
Expand Down Expand Up @@ -41,7 +43,10 @@ def test_shap_mask_image(self):
sigma,
)
masked_image = explainer._mask_image(
np.zeros((1, n_segments)), segments_slic, input_data, background,
np.zeros((1, n_segments)),
segments_slic,
input_data,
background,
)
# check if all points are masked
assert np.array_equal(masked_image[0], np.zeros(input_data.shape))
Expand Down
67 changes: 67 additions & 0 deletions tests/methods/test_lime_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""Unit tests for LIME image."""
from unittest import TestCase
import numpy as np
import dianna
from dianna.methods.lime_image import LIMEImage
from tests.methods.test_onnx_runner import generate_data
from tests.utils import run_model


class LimeOnImages(TestCase):
    """LIME test suite covering image inputs."""

    def setUp(self) -> None:
        """Seed NumPy's global random generator before each test."""
        np.random.seed(42)

    @staticmethod
    def test_lime_function():
        """Check the LIME heatmap for random data and a model function."""
        image = np.random.random((224, 224, 3))
        reference = np.load('tests/test_data/heatmap_lime_function.npy')
        target_labels = [1]

        lime = LIMEImage(random_state=42)
        result = lime.explain(
            run_model,
            image,
            target_labels,
            num_samples=100,
        )

        # One heatmap per label, covering the two leading axes of the input.
        assert result[0].shape == image.shape[:2]
        assert np.allclose(result, reference, atol=1e-5)

    @staticmethod
    def test_lime_filename():
        """Check the LIME heatmap for generated data and a model file."""
        onnx_path = 'tests/test_data/mnist_model.onnx'
        # Take the first (only) sample of the generated batch.
        sample = generate_data(batch_size=1)[0].astype(np.float32)
        axes = ('channels', 'y', 'x')
        target_labels = [1]

        result = dianna.explain_image(
            onnx_path,
            sample,
            method='LIME',
            labels=target_labels,
            random_state=42,
            axis_labels=axes,
        )

        reference = np.load('tests/test_data/heatmap_lime_filename.npy')
        assert result[0].shape == sample[0].shape
        assert np.allclose(result, reference, atol=1e-5)

    @staticmethod
    def test_lime_values():
        """Check the LIME heatmap when explain is called with return_masks=False."""
        image = np.random.random((224, 224, 3))
        reference = np.load('tests/test_data/heatmap_lime_values.npy')
        target_labels = [1]

        lime = LIMEImage(random_state=42)
        result = lime.explain(
            run_model,
            image,
            target_labels,
            return_masks=False,
            num_samples=100,
        )

        assert result[0].shape == image.shape[:2]
        assert np.allclose(result, reference, atol=1e-5)
86 changes: 13 additions & 73 deletions tests/test_lime.py → tests/methods/test_lime_text.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,10 @@
"""Unit tests for LIME text."""
from unittest import TestCase
import numpy as np
import pytest
import dianna
import dianna.visualization
from dianna.methods.lime_image import LIMEImage
from tests.test_onnx_runner import generate_data
from tests.utils import assert_explanation_satisfies_expectations
from tests.utils import load_movie_review_model
from tests.utils import run_model


class LimeOnImages(TestCase):
    """LIME test suite covering image inputs."""

    def setUp(self) -> None:
        """Seed NumPy's global random generator before each test."""
        np.random.seed(42)

    @staticmethod
    def test_lime_function():
        """Check the LIME heatmap for random data and a model function."""
        image = np.random.random((224, 224, 3))
        reference = np.load('tests/test_data/heatmap_lime_function.npy')
        target_labels = [1]

        lime = LIMEImage(random_state=42)
        result = lime.explain(
            run_model,
            image,
            target_labels,
            num_samples=100,
        )

        # One heatmap per label, covering the two leading axes of the input.
        assert result[0].shape == image.shape[:2]
        assert np.allclose(result, reference, atol=1e-5)

    @staticmethod
    def test_lime_filename():
        """Check the LIME heatmap for generated data and a model file."""
        onnx_path = 'tests/test_data/mnist_model.onnx'
        # Take the first (only) sample of the generated batch.
        sample = generate_data(batch_size=1)[0].astype(np.float32)
        axes = ('channels', 'y', 'x')
        target_labels = [1]

        result = dianna.explain_image(
            onnx_path,
            sample,
            method='LIME',
            labels=target_labels,
            random_state=42,
            axis_labels=axes,
        )

        reference = np.load('tests/test_data/heatmap_lime_filename.npy')
        assert result[0].shape == sample[0].shape
        assert np.allclose(result, reference, atol=1e-5)

    @staticmethod
    def test_lime_values():
        """Check the LIME heatmap when explain is called with return_masks=False."""
        image = np.random.random((224, 224, 3))
        reference = np.load('tests/test_data/heatmap_lime_values.npy')
        target_labels = [1]

        lime = LIMEImage(random_state=42)
        result = lime.explain(
            run_model,
            image,
            target_labels,
            return_masks=False,
            num_samples=100,
        )

        assert result[0].shape == image.shape[:2]
        assert np.allclose(result, reference, atol=1e-5)


class LimeOnText(TestCase):
Expand Down Expand Up @@ -165,21 +102,24 @@ def tokenizer():
('UNKWORDZ a bad UNKWORDZ UNKWORDZ!?\'"', 9),
('such UNKWORDZ UNKWORDZ movie "UNKWORDZUNKWORDZ\'UNKWORDZ', 9),
('such a bad UNKWORDZ UNKWORDZ!UNKWORDZ\'UNKWORDZ', 9),
pytest.param('its own self-UNKWORDZ universe.', 7,
pytest.param('its own self-UNKWORDZ universe.',
7,
marks=pytest.mark.xfail(reason='poor handling of -')),
pytest.param('its own UNKWORDZ-contained universe.', 7,
pytest.param('its own UNKWORDZ-contained universe.',
7,
marks=pytest.mark.xfail(reason='poor handling of -')),
pytest.param('Backslashes are UNKWORDZ/cool.', 6,
pytest.param('Backslashes are UNKWORDZ/cool.',
6,
marks=pytest.mark.xfail(reason='/ poor handling of /')),
pytest.param('Backslashes are fun/UNKWORDZ.', 6,
pytest.param('Backslashes are fun/UNKWORDZ.',
6,
marks=pytest.mark.xfail(reason='poor handling of /')),
pytest.param(' ', 0,
marks=pytest.mark.xfail(reason='Repeated whitespaces')),
pytest.param('I like whitespaces.', 4,
pytest.param(
' ', 0, marks=pytest.mark.xfail(reason='Repeated whitespaces')),
pytest.param('I like whitespaces.',
4,
marks=pytest.mark.xfail(reason='Repeated whitespaces')),
])


def test_spacytokenizer_length(text, length, tokenizer):
"""Test that tokenizer returns strings of the correct length."""
tokens = tokenizer.tokenize(text)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

def generate_data(batch_size):
    """Generate a batch of random data.

    Args:
        batch_size: number of samples in the batch (size of the first axis).

    Returns:
        Integer array of shape (batch_size, 1, 28, 28) with values drawn
        from the half-open range [0, 256) by ``np.random.randint``.
    """
    return np.random.randint(0, 256,
                             size=(batch_size, 1, 28, 28))  # MNIST shape


def test_onnx_runner():
Expand Down
68 changes: 68 additions & 0 deletions tests/methods/test_rise_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Unit tests for RISE image."""
from unittest import TestCase
import numpy as np
import dianna
from dianna.methods.rise_image import RISEImage
from dianna.utils import get_function
from tests.methods.test_onnx_runner import generate_data
from tests.utils import get_mnist_1_data
from tests.utils import run_model


class RiseOnImages(TestCase):
    """Suite of RISE tests for the image case."""

    # NOTE(review): test_rise_function and test_rise_filename draw inputs from
    # np.random without seeding it in this class; presumably determinism is
    # provided elsewhere -- confirm before relying on the stored heatmaps.

    @staticmethod
    def test_rise_function():
        """Test if rise runs and outputs the correct shape given some data and a model function."""
        input_data = np.random.random((224, 224, 3))
        axis_labels = ["y", "x", "channels"]
        labels = [1]
        heatmaps_expected = np.load(
            "tests/test_data/heatmap_rise_function.npy")

        heatmaps = dianna.explain_image(
            run_model,
            input_data,
            "RISE",
            labels,
            axis_labels=axis_labels,
            n_masks=200,
            p_keep=0.5,
        )

        # One heatmap per label, covering the two leading axes of the input.
        assert heatmaps[0].shape == input_data.shape[:2]
        assert np.allclose(heatmaps, heatmaps_expected, atol=1e-5)

    @staticmethod
    def test_rise_filename():
        """Test if rise runs and outputs the correct shape given some data and a model file."""
        model_filename = "tests/test_data/mnist_model.onnx"
        # Take the first (only) sample of the generated batch.
        input_data = generate_data(batch_size=1).astype(np.float32)[0]
        heatmaps_expected = np.load(
            "tests/test_data/heatmap_rise_filename.npy")
        labels = [1]

        heatmaps = dianna.explain_image(model_filename,
                                        input_data,
                                        "RISE",
                                        labels,
                                        n_masks=200,
                                        p_keep=0.5)

        # The heatmap drops the leading axis of the sample.
        assert heatmaps[0].shape == input_data.shape[1:]
        assert np.allclose(heatmaps, heatmaps_expected, atol=1e-5)

    @staticmethod
    def test_rise_determine_p_keep_for_images():
        """Tests exact expected p_keep given an image and model."""
        np.random.seed(0)
        expected_p_exact_keep = 0.4
        model_filename = "tests/test_data/mnist_model.onnx"
        data = get_mnist_1_data().astype(np.float32)

        # pylint: disable=protected-access
        p_keep = RISEImage()._determine_p_keep(data,
                                               get_function(model_filename))

        assert np.isclose(p_keep, expected_p_exact_keep)
57 changes: 57 additions & 0 deletions tests/methods/test_rise_text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""Unit tests for RISE text."""
from unittest import TestCase
import numpy as np
import dianna.visualization
from dianna.methods.rise_text import RISEText
from dianna.utils import get_function
from tests.utils import assert_explanation_satisfies_expectations
from tests.utils import load_movie_review_model


class RiseOnText(TestCase):
    """RISE test suite covering text inputs."""

    def setUp(self) -> None:
        """Seed NumPy's global random generator and load the model runner."""
        np.random.seed(0)
        self.runner = load_movie_review_model()

    def test_rise_text(self):
        """Check the RISE explanation of a short review against expected values."""
        review = "such a bad movie"
        words = ["such", "a", "bad", "movie"]
        word_indices = [0, 1, 2, 3]
        positive_scores = [0.30, 0.29, 0.04, 0.25]

        explanations = dianna.explain_text(
            self.runner,
            review,
            tokenizer=self.runner.tokenizer,
            labels=(1, 0),
            method="RISE",
            p_keep=0.5,
        )
        # Only the first returned explanation is checked.
        first_explanation = explanations[0]

        assert_explanation_satisfies_expectations(
            first_explanation,
            positive_scores,
            word_indices,
            words,
        )

    def test_rise_determine_p_keep_for_text(self):
        """Check the p_keep value determined for a short text and the model."""
        expected = 0.7
        text = "such a bad movie"
        model = get_function(self.runner)
        tokens = np.asarray(model.tokenizer.tokenize(text))

        explainer = RISEText()
        # pylint: disable=protected-access
        found = explainer._determine_p_keep(
            tokens,
            model,
            model.tokenizer,
            n_masks=100,
            batch_size=100,
        )
        assert np.isclose(found, expected)
Loading

0 comments on commit 25e6f9d

Please sign in to comment.