-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #684 from dianna-ai/split_tests_modularize
676 Split tests for modularized code
- Loading branch information
Showing
7 changed files
with
213 additions
and
189 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
"""Unit tests for LIME image.""" | ||
from unittest import TestCase | ||
import numpy as np | ||
import dianna | ||
from dianna.methods.lime_image import LIMEImage | ||
from tests.methods.test_onnx_runner import generate_data | ||
from tests.utils import run_model | ||
|
||
|
||
class LimeOnImages(TestCase): | ||
"""Suite of Lime tests for the image case.""" | ||
|
||
@staticmethod | ||
def test_lime_function(): | ||
"""Test if lime runs and outputs are correct given some data and a model function.""" | ||
input_data = np.random.random((224, 224, 3)) | ||
heatmap_expected = np.load('tests/test_data/heatmap_lime_function.npy') | ||
labels = [1] | ||
|
||
explainer = LIMEImage(random_state=42) | ||
heatmap = explainer.explain(run_model, | ||
input_data, | ||
labels, | ||
num_samples=100) | ||
|
||
assert heatmap[0].shape == input_data.shape[:2] | ||
assert np.allclose(heatmap, heatmap_expected, atol=1e-5) | ||
|
||
@staticmethod | ||
def test_lime_filename(): | ||
"""Test if lime runs and outputs are correct given some data and a model file.""" | ||
model_filename = 'tests/test_data/mnist_model.onnx' | ||
input_data = generate_data(batch_size=1)[0].astype(np.float32) | ||
axis_labels = ('channels', 'y', 'x') | ||
labels = [1] | ||
|
||
heatmap = dianna.explain_image(model_filename, | ||
input_data, | ||
method='LIME', | ||
labels=labels, | ||
random_state=42, | ||
axis_labels=axis_labels) | ||
|
||
heatmap_expected = np.load('tests/test_data/heatmap_lime_filename.npy') | ||
assert heatmap[0].shape == input_data[0].shape | ||
assert np.allclose(heatmap, heatmap_expected, atol=1e-5) | ||
|
||
@staticmethod | ||
def test_lime_values(): | ||
"""Test if get_explanation_values function works correctly.""" | ||
input_data = np.random.random((224, 224, 3)) | ||
heatmap_expected = np.load('tests/test_data/heatmap_lime_values.npy') | ||
labels = [1] | ||
|
||
explainer = LIMEImage(random_state=42) | ||
heatmap = explainer.explain(run_model, | ||
input_data, | ||
labels, | ||
return_masks=False, | ||
num_samples=100) | ||
|
||
assert heatmap[0].shape == input_data.shape[:2] | ||
assert np.allclose(heatmap, heatmap_expected, atol=1e-5) | ||
|
||
def setUp(self) -> None: | ||
"""Set seed.""" | ||
np.random.seed(42) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
"""Unit tests for RISE image.""" | ||
from unittest import TestCase | ||
import numpy as np | ||
import dianna | ||
from dianna.methods.rise_image import RISEImage | ||
from dianna.utils import get_function | ||
from tests.methods.test_onnx_runner import generate_data | ||
from tests.utils import get_mnist_1_data | ||
from tests.utils import run_model | ||
|
||
|
||
class RiseOnImages(TestCase): | ||
"""Suite of RISE tests for the image case.""" | ||
|
||
@staticmethod | ||
def test_rise_function(): | ||
"""Test if rise runs and outputs the correct shape given some data and a model function.""" | ||
input_data = np.random.random((224, 224, 3)) | ||
axis_labels = ["y", "x", "channels"] | ||
labels = [1] | ||
heatmaps_expected = np.load( | ||
"tests/test_data/heatmap_rise_function.npy") | ||
|
||
heatmaps = dianna.explain_image( | ||
run_model, | ||
input_data, | ||
"RISE", | ||
labels, | ||
axis_labels=axis_labels, | ||
n_masks=200, | ||
p_keep=0.5, | ||
) | ||
|
||
assert heatmaps[0].shape == input_data.shape[:2] | ||
assert np.allclose(heatmaps, heatmaps_expected, atol=1e-5) | ||
|
||
@staticmethod | ||
def test_rise_filename(): | ||
"""Test if rise runs and outputs the correct shape given some data and a model file.""" | ||
model_filename = "tests/test_data/mnist_model.onnx" | ||
input_data = generate_data(batch_size=1).astype(np.float32)[0] | ||
heatmaps_expected = np.load( | ||
"tests/test_data/heatmap_rise_filename.npy") | ||
labels = [1] | ||
|
||
heatmaps = dianna.explain_image(model_filename, | ||
input_data, | ||
"RISE", | ||
labels, | ||
n_masks=200, | ||
p_keep=0.5) | ||
|
||
assert heatmaps[0].shape == input_data.shape[1:] | ||
print(heatmaps_expected.shape) | ||
assert np.allclose(heatmaps, heatmaps_expected, atol=1e-5) | ||
|
||
@staticmethod | ||
def test_rise_determine_p_keep_for_images(): | ||
"""Tests exact expected p_keep given an image and model.""" | ||
np.random.seed(0) | ||
expected_p_exact_keep = 0.4 | ||
model_filename = "tests/test_data/mnist_model.onnx" | ||
data = get_mnist_1_data().astype(np.float32) | ||
|
||
p_keep = RISEImage()._determine_p_keep(data, | ||
get_function(model_filename)) | ||
|
||
assert np.isclose(p_keep, expected_p_exact_keep) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
"""Unit tests for RISE text.""" | ||
from unittest import TestCase | ||
import numpy as np | ||
import dianna.visualization | ||
from dianna.methods.rise_text import RISEText | ||
from dianna.utils import get_function | ||
from tests.utils import assert_explanation_satisfies_expectations | ||
from tests.utils import load_movie_review_model | ||
|
||
|
||
class RiseOnText(TestCase): | ||
"""Suite of RISE tests for the text case.""" | ||
|
||
def test_rise_text(self): | ||
"""Tests exact expected output given a text and model.""" | ||
review = "such a bad movie" | ||
expected_words = ["such", "a", "bad", "movie"] | ||
expected_word_indices = [0, 1, 2, 3] | ||
expected_positive_scores = [0.30, 0.29, 0.04, 0.25] | ||
|
||
positive_explanation = dianna.explain_text( | ||
self.runner, | ||
review, | ||
tokenizer=self.runner.tokenizer, | ||
labels=(1, 0), | ||
method="RISE", | ||
p_keep=0.5, | ||
)[0] | ||
|
||
assert_explanation_satisfies_expectations( | ||
positive_explanation, | ||
expected_positive_scores, | ||
expected_word_indices, | ||
expected_words, | ||
) | ||
|
||
def test_rise_determine_p_keep_for_text(self): | ||
"""Tests exact expected p_keep given a text and model.""" | ||
expected_p_exact_keep = 0.7 | ||
input_text = "such a bad movie" | ||
runner = get_function(self.runner) | ||
input_tokens = np.asarray(runner.tokenizer.tokenize(input_text)) | ||
|
||
# pylint: disable=protected-access | ||
p_keep = RISEText()._determine_p_keep( | ||
input_tokens, | ||
runner, | ||
runner.tokenizer, | ||
n_masks=100, | ||
batch_size=100, | ||
) | ||
assert np.isclose(p_keep, expected_p_exact_keep) | ||
|
||
def setUp(self) -> None: | ||
"""Set seed and load runner.""" | ||
np.random.seed(0) | ||
self.runner = load_movie_review_model() |
Oops, something went wrong.