From 9d0f821f77d8ab73580be826f4e91c1822a93e47 Mon Sep 17 00:00:00 2001 From: Christiaan Meijer Date: Mon, 4 Dec 2023 12:03:51 +0100 Subject: [PATCH] parameterize rise time series test synthetic data introduce failing test --- tests/methods/test_rise_timeseries.py | 47 ++++++++++----------------- 1 file changed, 17 insertions(+), 30 deletions(-) diff --git a/tests/methods/test_rise_timeseries.py b/tests/methods/test_rise_timeseries.py index 0b6e4402..4691970b 100644 --- a/tests/methods/test_rise_timeseries.py +++ b/tests/methods/test_rise_timeseries.py @@ -1,5 +1,5 @@ import numpy as np -import pandas +import pytest import dianna from dianna.methods.rise_timeseries import RISETimeseries from tests.methods.time_series_test_case import average_temperature_timeseries_with_1_cold_and_1_hot_day @@ -24,42 +24,34 @@ def test_rise_timeseries_correct_output_shape(): assert heatmaps.shape == (len(labels), *input_data.shape) -def test_rise_timeseries_with_expert_model_for_correct_max_and_min(): +@pytest.mark.parametrize('series_length', [ + 10, + 3, +]) +def test_rise_timeseries_with_expert_model_for_correct_max_and_min( + series_length): """Test if RISE highlights the correct areas for this artificial example.""" hot_day_index = 1 cold_day_index = 2 - series_length = 4 temperature_timeseries = average_temperature_timeseries_with_1_cold_and_1_hot_day( cold_day_index, hot_day_index, series_length=series_length) - # summer_explanation, winter_explanation = dianna.explain_timeseries( - # # run_expert_model, - # run_continuous_expert_model, - # timeseries_data=temperature_timeseries, - # method='rise', - # labels=[0, 1], - # p_keep=0.1, - # n_masks=50, - # feature_res=series_length, - # mask_type=input_train_mean) - explainer = RISETimeseries(n_masks=100000, p_keep=0.5, feature_res=series_length) summer_explanation, winter_explanation = explainer.explain( run_expert_model_3_step, temperature_timeseries, labels=[0, 1]) - print('\n') - print(pandas.DataFrame(temperature_timeseries)) - dot = (explainer.masks[:, :, 0].T * explainer.predictions[:, 0]).T - log = pandas.DataFrame( - np.column_stack((explainer.masks[:, :, 0], explainer.masked[:, :, 0], - explainer.predictions, dot)), - columns=[f'm{i}' for i in range(len(summer_explanation))] + - [f'd{i}' for i in range(len(summer_explanation))] + ['P(S)', 'P(W)'] + - [f'S{i}' for i in range(len(summer_explanation))]) - print(log) - print(log.sum()) + _visualize_explainer_output(hot_day_index, cold_day_index, + summer_explanation, winter_explanation) + assert np.argmin(winter_explanation) == hot_day_index + assert np.argmax(summer_explanation) == hot_day_index + assert np.argmin(summer_explanation) == cold_day_index + assert np.argmax(winter_explanation) == cold_day_index + + +def _visualize_explainer_output(hot_day_index, cold_day_index, + summer_explanation, winter_explanation): print('\n') length = len(summer_explanation) margin = ' ' @@ -74,11 +66,6 @@ def test_rise_timeseries_with_expert_model_for_correct_max_and_min(): _print_series('summer', summer_explanation) _print_series('winter', winter_explanation) - assert np.argmin(winter_explanation) == hot_day_index - assert np.argmax(summer_explanation) == hot_day_index - assert np.argmin(summer_explanation) == cold_day_index - assert np.argmax(winter_explanation) == cold_day_index - def _print_series(title, series): mini = np.min(series)