Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added moving average plot #551

Closed
wants to merge 15 commits into from
Closed
65 changes: 63 additions & 2 deletions src/safeds/data/tabular/containers/_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,7 @@ def transform_column(self, name: str, transformer: Callable[[Row], Any]) -> Time
time_name=self.time.name,
)

def plot_lagplot(self, lag: int) -> Image:
def plot_lag_plot(self, lag: int) -> Image:
"""
Plot a lagplot for the target column.

Expand All @@ -894,7 +894,7 @@ def plot_lagplot(self, lag: int) -> Image:
--------
>>> from safeds.data.tabular.containers import TimeSeries
>>> table = TimeSeries({"time":[1, 2], "target": [3, 4], "feature":[2,2]}, target_name= "target", time_name="time", feature_names=["feature"], )
>>> image = table.plot_lagplot(lag = 1)
>>> image = table.plot_lag_plot(lag = 1)

"""
if not self.target.type.is_numeric():
Expand All @@ -906,3 +906,64 @@ def plot_lagplot(self, lag: int) -> Image:
plt.close() # Prevents the figure from being displayed directly
buffer.seek(0)
return Image.from_bytes(buffer.read())

def plot_moving_average(
    self,
    window_size: int,
    feature_name: str | None = None,
) -> Image:
    """
    Plot a column of the time series together with its moving average.

    Parameters
    ----------
    window_size:
        The number of consecutive values that are averaged for each point of the moving average.

    feature_name:
        The name of the column to plot. If None, the target column is used.

    Returns
    -------
    plot:
        The moving average and the original column, drawn into one plot, as an image.

    Raises
    ------
    NonNumericColumnError
        If the plotted column contains non-numerical values.

    UnknownColumnNameError
        If the time series doesn't contain the given column name.

    Examples
    --------
    >>> from safeds.data.tabular.containers import TimeSeries
    >>> table = TimeSeries({"time":[1, 2], "target": [3, 4], "feature":[2,2]}, target_name= "target", time_name="time", feature_names=["feature"], )
    >>> image = table.plot_moving_average(window_size = 2)

    """
    # Resolve which column is plotted; the target column is the default.
    if feature_name is None or feature_name == self.target.name:
        series = self.target._data
        feature_name = self.target.name
    else:
        if feature_name not in self.column_names:
            raise UnknownColumnNameError([feature_name])
        series = self._data[feature_name]
    if not self.get_column(feature_name).type.is_numeric():
        raise NonNumericColumnError("This time series plotted column contains non-numerical columns.")

    # Rolling mean over `window_size` consecutive values.
    series_mvg = series.rolling(window_size).mean()

    # Draw on a dedicated figure: without an explicit `ax`, pandas reuses whatever
    # pyplot figure happens to be current, which could mix unrelated content
    # into the returned image.
    fig, ax = plt.subplots()
    try:
        # Plot order must match the order of the legend labels below.
        series_mvg.plot(ax=ax)
        series.plot(ax=ax)
        ax.legend(labels=["moving_average", feature_name])

        buffer = io.BytesIO()
        fig.savefig(buffer, format="png")
    finally:
        # Close exactly this figure (even on failure) so it is neither displayed
        # directly nor leaked, and no other open figure is closed by accident.
        plt.close(fig)
    buffer.seek(0)
    return Image.from_bytes(buffer.read())
Gerhardsa0 marked this conversation as resolved.
Show resolved Hide resolved
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_should_return_table(snapshot_png: SnapshotAssertion) -> None:
time_name="time",
feature_names=None,
)
lag_plot = table.plot_lagplot(lag=1)
lag_plot = table.plot_lag_plot(lag=1)
assert lag_plot == snapshot_png


Expand All @@ -38,4 +38,4 @@ def test_should_raise_if_column_contains_non_numerical_values() -> None:
r" non-numerical columns."
),
):
table.plot_lagplot(2)
table.plot_lag_plot(2)
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import pytest
from safeds.data.tabular.containers import TimeSeries
from safeds.exceptions import NonNumericColumnError, UnknownColumnNameError
from syrupy import SnapshotAssertion


def test_should_return_table(snapshot_png: SnapshotAssertion) -> None:
    """Check the moving average plot of the target column against the stored snapshot."""
    columns = {
        "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        "feature_1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        "target": [1, 2, 3, 4, 3, 2, 1, 2, 3, 4],
    }
    time_series = TimeSeries(columns, target_name="target", time_name="time", feature_names=None)

    # No feature name given, so the target column is plotted.
    plot = time_series.plot_moving_average(window_size=2)

    assert plot == snapshot_png


def test_optional_parameter(snapshot_png: SnapshotAssertion) -> None:
    """Check the moving average plot of an explicitly named feature column against the stored snapshot."""
    columns = {
        "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        "feature_1": [1, 2, 3, 4, 5, 4, 3, 2, 1, 0],
        "target": [1, 2, 3, 4, 3, 2, 1, 2, 3, 4],
    }
    time_series = TimeSeries(columns, target_name="target", time_name="time", feature_names=None)

    # Plot the feature column instead of the default target column.
    plot = time_series.plot_moving_average(window_size=2, feature_name="feature_1")

    assert plot == snapshot_png


def test_should_raise_if_column_contains_non_numerical_values() -> None:
    """Check that plotting a non-numerical target column raises NonNumericColumnError."""
    columns = {
        "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        "feature_1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        "target": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
    }
    time_series = TimeSeries(columns, target_name="target", time_name="time", feature_names=None)

    # Regex for the full error message (exception prefix followed by the plot-specific detail).
    expected_message = (
        r"Tried to do a numerical operation on one or multiple non-numerical columns: "
        r"\nThis time series plotted column contains non-numerical columns."
    )
    with pytest.raises(NonNumericColumnError, match=expected_message):
        time_series.plot_moving_average(2)


@pytest.mark.parametrize(
    ("time_series", "name", "error", "error_msg"),
    [
        # The named feature column exists but holds strings.
        pytest.param(
            TimeSeries(
                {
                    "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                    "feature_1": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
                    "target": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
                },
                target_name="target",
                time_name="time",
                feature_names=None,
            ),
            "feature_1",
            NonNumericColumnError,
            (
                r"Tried to do a numerical operation on one or multiple non-numerical columns: "
                r"\nThis time series plotted column contains non-numerical columns."
            ),
            id="feature_not_numerical",
        ),
        # The named feature column is not part of the time series.
        pytest.param(
            TimeSeries(
                {
                    "time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                    "feature_1": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
                    "target": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
                },
                target_name="target",
                time_name="time",
                feature_names=None,
            ),
            "feature_3",
            UnknownColumnNameError,
            r"Could not find column\(s\) 'feature_3'.",
            id="feature_does_not_exist",
        ),
    ],
)
def test_should_raise_error_optional_parameter(
    time_series: TimeSeries,
    name: str,
    error: type[Exception],
    error_msg: str,
) -> None:
    """Check that an invalid `feature_name` raises the expected error with the expected message."""
    with pytest.raises(error, match=error_msg):
        time_series.plot_moving_average(2, feature_name=name)