Skip to content

Commit

Permalink
add test for tabular plot
Browse files Browse the repository at this point in the history
  • Loading branch information
Yang committed Jan 5, 2024
1 parent 2566e25 commit ca470ff
Showing 1 changed file with 20 additions and 7 deletions.
27 changes: 20 additions & 7 deletions tests/test_visualization.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,29 @@
"""Unit tests for visualization modules."""
from pathlib import Path
import numpy as np
import pytest
from dianna.visualization import plot_tabular
from dianna.visualization import plot_timeseries


def test_plot_tabular(tmpdir):
"""Test plot tabular data."""
x = np.linspace(-5, 5, 3)
y = [f"Feature {i}" for i in range(len(x))]
output_path = Path(tmpdir) / "temp_visualization_test_tabular.png"

plot_tabular(x=x, y=y, show_plot=False, output_filename=output_path)

assert output_path.exists()


def test_plot_timeseries_univariate(tmpdir, random):
"""Test plot univariate time series."""
x = np.linspace(0, 10, 20)
y = np.sin(x)
segments = get_test_segments(data=np.expand_dims(y, 0))

output_path = Path(tmpdir) / 'temp_visualization_test_univariate.png'
output_path = Path(tmpdir) / "temp_visualization_test_univariate.png"

plot_timeseries(x=x,
y=y,
Expand All @@ -26,7 +39,7 @@ def test_plot_timeseries_multivariate(tmpdir, random):
x = np.linspace(start=0, stop=10, num=20)
ys = np.stack((np.sin(x), np.cos(x), np.tan(0.4 * x)))
segments = get_test_segments(data=ys)
output_path = Path(tmpdir) / 'temp_visualization_test_multivariate.png'
output_path = Path(tmpdir) / "temp_visualization_test_multivariate.png"

plot_timeseries(x=x,
y=ys.T,
Expand All @@ -48,13 +61,13 @@ def get_test_segments(data):
for i_segment in range(n_segments):
for i_channel in range(n_channels):
segment = {
'index': i_segment + i_channel * n_segments,
'start': i_segment,
'stop': i_segment + 1,
'weight': data[i_channel, factor * i_segment],
"index": i_segment + i_channel * n_segments,
"start": i_segment,
"stop": i_segment + 1,
"weight": data[i_channel, factor * i_segment],
}
if n_channels > 1:
segment['channel'] = i_channel
segment["channel"] = i_channel
segments.append(segment)

return segments
Expand Down

0 comments on commit ca470ff

Please sign in to comment.