Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added moving average plot #551

Closed
wants to merge 15 commits into from
Closed
5 changes: 5 additions & 0 deletions src/safeds/_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Tools to work internally with plots."""

from ._plotting import _create_image_for_plot

__all__ = ["_create_image_for_plot"]
13 changes: 13 additions & 0 deletions src/safeds/_utils/_plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import io

import matplotlib.pyplot as plt

from safeds.data.image.containers import Image


def _create_image_for_plot(fig: plt.Figure) -> Image:
buffer = io.BytesIO()
fig.savefig(buffer, format="png")
plt.close() # Prevents the figure from being displayed directly
buffer.seek(0)
return Image.from_bytes(buffer.read())
40 changes: 12 additions & 28 deletions src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from __future__ import annotations

import sys
import functools
import io
import sys
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar
Expand All @@ -16,7 +15,6 @@
from pandas import DataFrame
from scipy import stats

from safeds.data.image.containers import Image
from safeds.data.tabular.typing import ColumnType, Schema
from safeds.exceptions import (
ColumnLengthMismatchError,
Expand All @@ -34,6 +32,7 @@
if TYPE_CHECKING:
from collections.abc import Callable, Mapping, Sequence

from safeds.data.image.containers import Image
from safeds.data.tabular.transformation import InvertibleTableTransformer, TableTransformer

from ._tagged_table import TaggedTable
Expand Down Expand Up @@ -1933,12 +1932,9 @@ def plot_correlation_heatmap(self) -> Image:
cmap="vlag",
)
plt.tight_layout()
from safeds._utils._plotting import _create_image_for_plot

buffer = io.BytesIO()
fig.savefig(buffer, format="png")
plt.close() # Prevents the figure from being displayed directly
buffer.seek(0)
return Image.from_bytes(buffer.read())
return _create_image_for_plot(fig)

def plot_lineplot(self, x_column_name: str, y_column_name: str) -> Image:
"""
Expand Down Expand Up @@ -1994,12 +1990,9 @@ def plot_lineplot(self, x_column_name: str, y_column_name: str) -> Image:
horizontalalignment="right",
) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels
plt.tight_layout()
from safeds._utils._plotting import _create_image_for_plot

buffer = io.BytesIO()
fig.savefig(buffer, format="png")
plt.close() # Prevents the figure from being displayed directly
buffer.seek(0)
return Image.from_bytes(buffer.read())
return _create_image_for_plot(fig)

def plot_scatterplot(self, x_column_name: str, y_column_name: str) -> Image:
"""
Expand Down Expand Up @@ -2052,12 +2045,9 @@ def plot_scatterplot(self, x_column_name: str, y_column_name: str) -> Image:
horizontalalignment="right",
) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels
plt.tight_layout()
from safeds._utils._plotting import _create_image_for_plot

buffer = io.BytesIO()
fig.savefig(buffer, format="png")
plt.close() # Prevents the figure from being displayed directly
buffer.seek(0)
return Image.from_bytes(buffer.read())
return _create_image_for_plot(fig)

def plot_boxplots(self) -> Image:
"""
Expand Down Expand Up @@ -2099,12 +2089,9 @@ def plot_boxplots(self) -> Image:
axes.set_xticks([])
plt.tight_layout()
fig = grid.fig
from safeds._utils._plotting import _create_image_for_plot

buffer = io.BytesIO()
fig.savefig(buffer, format="png")
plt.close() # Prevents the figure from being displayed directly
buffer.seek(0)
return Image.from_bytes(buffer.read())
return _create_image_for_plot(fig)

def plot_histograms(self) -> Image:
"""
Expand Down Expand Up @@ -2134,12 +2121,9 @@ def plot_histograms(self) -> Image:
axes.set_xticklabels(axes.get_xticklabels(), rotation=45, horizontalalignment="right")
grid.tight_layout()
fig = grid.fig
from safeds._utils._plotting import _create_image_for_plot

buffer = io.BytesIO()
fig.savefig(buffer, format="png")
plt.close()
buffer.seek(0)
return Image.from_bytes(buffer.read())
return _create_image_for_plot(fig)

# ------------------------------------------------------------------------------------------------------------------
# Conversion
Expand Down
75 changes: 64 additions & 11 deletions src/safeds/data/tabular/containers/_time_series.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from __future__ import annotations

import io
import sys
from typing import TYPE_CHECKING

import matplotlib.pyplot as plt
import pandas as pd

from safeds.data.image.containers import Image
from safeds.data.tabular.containers import Column, Row, Table, TaggedTable
from safeds.exceptions import (
ColumnIsTargetError,
Expand All @@ -21,6 +18,8 @@
from collections.abc import Callable, Mapping, Sequence
from typing import Any

from safeds.data.image.containers import Image


class TimeSeries(TaggedTable):

Expand Down Expand Up @@ -871,7 +870,7 @@ def transform_column(self, name: str, transformer: Callable[[Row], Any]) -> Time
time_name=self.time.name,
)

def plot_lagplot(self, lag: int) -> Image:
def plot_lag_plot(self, lag: int) -> Image:
"""
Plot a lagplot for the target column.

Expand All @@ -894,15 +893,69 @@ def plot_lagplot(self, lag: int) -> Image:
--------
>>> from safeds.data.tabular.containers import TimeSeries
>>> table = TimeSeries({"time":[1, 2], "target": [3, 4], "feature":[2,2]}, target_name= "target", time_name="time", feature_names=["feature"], )
>>> image = table.plot_lagplot(lag = 1)
>>> image = table.plot_lag_plot(lag = 1)

"""
if not self.target.type.is_numeric():
raise NonNumericColumnError("This time series target contains non-numerical columns.")
ax = pd.plotting.lag_plot(self.target._data, lag=lag)
fig = ax.figure
buffer = io.BytesIO()
fig.savefig(buffer, format="png")
plt.close() # Prevents the figure from being displayed directly
buffer.seek(0)
return Image.from_bytes(buffer.read())
from safeds._utils._plotting import _create_image_for_plot

return _create_image_for_plot(ax.figure)

def plot_moving_average(
self,
window_size: int,
column_name: str | None = None,
) -> Image:
"""
Plot the moving average for the target column.

Parameters
----------
window_size:
The size of the windows, which the average gets calculated for

column_name:
The name of the column which will be used to calculate the moving average, if None the target column will be taken

Returns
-------
plot:
The moving avereage plot and the normal plot as an image.

Raises
------
NonNumericColumnError
If the time series targets contains non-numerical values.

UnknownColumnNameError
If the time series doesn't contain the given column name

Examples
--------
>>> from safeds.data.tabular.containers import TimeSeries
>>> table = TimeSeries({"time":[1, 2], "target": [3, 4], "feature":[2,2]}, target_name= "target", time_name="time", feature_names=["feature"], )
>>> image = table.plot_moving_average(window_size = 2)

"""
if column_name is None or column_name == self.target.name:
series = self.target._data
column_name = self.target.name
else:
if column_name not in self.column_names:
raise UnknownColumnNameError([column_name])
series = self._data[column_name]
if not self.get_column(column_name).type.is_numeric():
raise NonNumericColumnError("This time series plotted column contains non-numerical columns.")

# create moving average series
series_mvg = series.rolling(window_size).mean()

# plot both series and put them together
ax_temp = series_mvg.plot()
ax = series.plot(ax=ax_temp)
ax.legend(labels=["moving_average", column_name])
from safeds._utils._plotting import _create_image_for_plot

return _create_image_for_plot(ax.figure)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_should_return_table(snapshot_png: SnapshotAssertion) -> None:
time_name="time",
feature_names=None,
)
lag_plot = table.plot_lagplot(lag=1)
lag_plot = table.plot_lag_plot(lag=1)
assert lag_plot == snapshot_png


Expand All @@ -38,4 +38,4 @@ def test_should_raise_if_column_contains_non_numerical_values() -> None:
r" non-numerical columns."
),
):
table.plot_lagplot(2)
table.plot_lag_plot(2)
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import pytest
from safeds.data.tabular.containers import TimeSeries
from safeds.exceptions import NonNumericColumnError, UnknownColumnNameError
from syrupy import SnapshotAssertion


def test_should_return_table(snapshot_png: SnapshotAssertion) -> None:
table = TimeSeries(
{
"time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
"feature_1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
"target": [1, 2, 3, 4, 3, 2, 1, 2, 3, 4],
},
target_name="target",
time_name="time",
feature_names=None,
)
moving_average_plot = table.plot_moving_average(window_size=2)
assert moving_average_plot == snapshot_png


def test_optional_parameter(snapshot_png: SnapshotAssertion) -> None:
table = TimeSeries(
{
"time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
"feature_1": [1, 2, 3, 4, 5, 4, 3, 2, 1, 0],
"target": [1, 2, 3, 4, 3, 2, 1, 2, 3, 4],
},
target_name="target",
time_name="time",
feature_names=None,
)
moving_average_plot = table.plot_moving_average(window_size=2, column_name="feature_1")
assert moving_average_plot == snapshot_png


def test_should_raise_if_column_contains_non_numerical_values() -> None:
table = TimeSeries(
{
"time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
"feature_1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
"target": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
},
target_name="target",
time_name="time",
feature_names=None,
)
with pytest.raises(
NonNumericColumnError,
match=(
r"Tried to do a numerical operation on one or multiple non-numerical columns: \nThis time series plotted"
r" column"
r" contains"
r" non-numerical columns."
),
):
table.plot_moving_average(2)


@pytest.mark.parametrize(
("time_series", "name", "error", "error_msg"),
[
(
TimeSeries(
{
"time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
"feature_1": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
"target": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
},
target_name="target",
time_name="time",
feature_names=None,
),
"feature_1",
NonNumericColumnError,
r"Tried to do a numerical operation on one or multiple non-numerical columns: \nThis time series plotted"
r" column"
r" contains"
r" non-numerical columns.",
),
(
TimeSeries(
{
"time": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
"feature_1": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
"target": ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
},
target_name="target",
time_name="time",
feature_names=None,
),
"feature_3",
UnknownColumnNameError,
r"Could not find column\(s\) 'feature_3'.",
),
],
ids=["feature_not_numerical", "feature_does_not_exist"],
)
def test_should_raise_error_optional_parameter(
time_series: TimeSeries,
name: str,
error: type[Exception],
error_msg: str,
) -> None:
with pytest.raises(
error,
match=error_msg,
):
time_series.plot_moving_average(2, column_name=name)