Skip to content

Commit

Permalink
parameterize rise time series test synthetic data
Browse files Browse the repository at this point in the history
introduce failing test
  • Loading branch information
cwmeijer committed Dec 4, 2023
1 parent 4b29c1f commit 9d0f821
Showing 1 changed file with 17 additions and 30 deletions.
47 changes: 17 additions & 30 deletions tests/methods/test_rise_timeseries.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
import pandas
import pytest
import dianna
from dianna.methods.rise_timeseries import RISETimeseries
from tests.methods.time_series_test_case import average_temperature_timeseries_with_1_cold_and_1_hot_day
Expand All @@ -24,42 +24,34 @@ def test_rise_timeseries_correct_output_shape():
assert heatmaps.shape == (len(labels), *input_data.shape)


def test_rise_timeseries_with_expert_model_for_correct_max_and_min():
@pytest.mark.parametrize('series_length', [
10,
3,
])
def test_rise_timeseries_with_expert_model_for_correct_max_and_min(
series_length):
"""Test if RISE highlights the correct areas for this artificial example."""
hot_day_index = 1
cold_day_index = 2
series_length = 4
temperature_timeseries = average_temperature_timeseries_with_1_cold_and_1_hot_day(
cold_day_index, hot_day_index, series_length=series_length)

# summer_explanation, winter_explanation = dianna.explain_timeseries(
# # run_expert_model,
# run_continuous_expert_model,
# timeseries_data=temperature_timeseries,
# method='rise',
# labels=[0, 1],
# p_keep=0.1,
# n_masks=50,
# feature_res=series_length,
# mask_type=input_train_mean)

explainer = RISETimeseries(n_masks=100000,
p_keep=0.5,
feature_res=series_length)
summer_explanation, winter_explanation = explainer.explain(
run_expert_model_3_step, temperature_timeseries, labels=[0, 1])
print('\n')
print(pandas.DataFrame(temperature_timeseries))
dot = (explainer.masks[:, :, 0].T * explainer.predictions[:, 0]).T
log = pandas.DataFrame(
np.column_stack((explainer.masks[:, :, 0], explainer.masked[:, :, 0],
explainer.predictions, dot)),
columns=[f'm{i}' for i in range(len(summer_explanation))] +
[f'd{i}' for i in range(len(summer_explanation))] + ['P(S)', 'P(W)'] +
[f'S{i}' for i in range(len(summer_explanation))])
print(log)
print(log.sum())

_visualize_explainer_output(hot_day_index, cold_day_index,
summer_explanation, winter_explanation)
assert np.argmin(winter_explanation) == hot_day_index
assert np.argmax(summer_explanation) == hot_day_index
assert np.argmin(summer_explanation) == cold_day_index
assert np.argmax(winter_explanation) == cold_day_index


def _visualize_explainer_output(hot_day_index, cold_day_index,
summer_explanation, winter_explanation):
print('\n')
length = len(summer_explanation)
margin = ' '
Expand All @@ -74,11 +66,6 @@ def test_rise_timeseries_with_expert_model_for_correct_max_and_min():
_print_series('summer', summer_explanation)
_print_series('winter', winter_explanation)

assert np.argmin(winter_explanation) == hot_day_index
assert np.argmax(summer_explanation) == hot_day_index
assert np.argmin(summer_explanation) == cold_day_index
assert np.argmax(winter_explanation) == cold_day_index


def _print_series(title, series):
mini = np.min(series)
Expand Down

0 comments on commit 9d0f821

Please sign in to comment.