diff --git a/src/safeds/data/tabular/plotting/_column_plotter.py b/src/safeds/data/tabular/plotting/_column_plotter.py index f47dc8003..1ba568525 100644 --- a/src/safeds/data/tabular/plotting/_column_plotter.py +++ b/src/safeds/data/tabular/plotting/_column_plotter.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal from safeds._utils import _figure_to_image from safeds._validation._check_columns_are_numeric import _check_column_is_numeric @@ -29,10 +29,15 @@ class ColumnPlotter: def __init__(self, column: Column): self._column: Column = column - def box_plot(self) -> Image: + def box_plot(self, *, theme: Literal["dark", "light"] = "light") -> Image: """ Create a box plot for the values in the column. This is only possible for numeric columns. + Parameters + ---------- + theme: + The color theme of the plot. Default is "light". + Returns ------- plot: @@ -54,20 +59,75 @@ def box_plot(self) -> Image: import matplotlib.pyplot as plt - fig, ax = plt.subplots() - ax.boxplot( - self._column._series.drop_nulls(), - patch_artist=True, - ) - - ax.set(title=self._column.name) - ax.set_xticks([]) - ax.yaxis.grid(visible=True) - fig.tight_layout() - - return _figure_to_image(fig) - - def histogram(self, *, max_bin_count: int = 10) -> Image: + def _set_boxplot_colors(box: dict, theme: str) -> None: + if theme == "dark": + for median in box["medians"]: + median.set(color="orange", linewidth=1.5) + + for box_part in box["boxes"]: + box_part.set(color="white", linewidth=1.5, facecolor="cyan") + + for whisker in box["whiskers"]: + whisker.set(color="white", linewidth=1.5) + + for cap in box["caps"]: + cap.set(color="white", linewidth=1.5) + + for flier in box["fliers"]: + flier.set(marker="o", color="white", alpha=0.5) + else: + for median in box["medians"]: + median.set(color="orange", linewidth=1.5) + + for box_part in box["boxes"]: + box_part.set(color="black", linewidth=1.5, facecolor="blue") + + for whisker in box["whiskers"]: + whisker.set(color="black", linewidth=1.5) + + for cap in box["caps"]: + cap.set(color="black", linewidth=1.5) + + for flier in box["fliers"]: + flier.set(marker="o", color="black", alpha=0.5) + + style = "dark_background" if theme == "dark" else "default" + with plt.style.context(style): + if theme == "dark": + plt.rcParams.update( + { + "text.color": "white", + "axes.labelcolor": "white", + "axes.edgecolor": "white", + "xtick.color": "white", + "ytick.color": "white", + "grid.color": "gray", + "grid.linewidth": 0.5, + }, + ) + else: + plt.rcParams.update( + { + "grid.linewidth": 0.5, + }, + ) + + fig, ax = plt.subplots() + box = ax.boxplot( + self._column._series.drop_nulls(), + patch_artist=True, + ) + + _set_boxplot_colors(box, theme) + + ax.set(title=self._column.name) + ax.set_xticks([]) + ax.yaxis.grid(visible=True) + fig.tight_layout() + + return _figure_to_image(fig) + + def histogram(self, *, max_bin_count: int = 10, theme: Literal["dark", "light"] = "light") -> Image: """ Create a histogram for the values in the column. @@ -75,6 +135,8 @@ def histogram(self, *, max_bin_count: int = 10) -> Image: ---------- max_bin_count: The maximum number of bins to use in the histogram. Default is 10. + theme: + The color theme of the plot. Default is "light". Returns ------- @@ -87,9 +149,24 @@ def histogram(self, *, max_bin_count: int = 10) -> Image: >>> column = Column("test", [1, 2, 3]) >>> histogram = column.plot.histogram() """ - return self._column.to_table().plot.histograms(max_bin_count=max_bin_count) + import matplotlib.pyplot as plt - def lag_plot(self, lag: int) -> Image: + style = "dark_background" if theme == "dark" else "default" + with plt.style.context(style): + if theme == "dark": + plt.rcParams.update( + { + "text.color": "white", + "axes.labelcolor": "white", + "axes.edgecolor": "white", + "xtick.color": "white", + "ytick.color": "white", + }, + ) + + return self._column.to_table().plot.histograms(max_bin_count=max_bin_count) + + def lag_plot(self, lag: int, *, theme: Literal["dark", "light"] = "light") -> Image: """ Create a lag plot for the values in the column. @@ -97,6 +174,8 @@ def lag_plot(self, lag: int) -> Image: ---------- lag: The amount of lag. + theme: + The color theme of the plot. Default is "light". Returns ------- @@ -114,21 +193,34 @@ def lag_plot(self, lag: int) -> Image: >>> column = Column("values", [1, 2, 3, 4]) >>> image = column.plot.lag_plot(2) """ - if self._column.row_count > 0: - _check_column_is_numeric(self._column, operation="create a lag plot") - import matplotlib.pyplot as plt - fig, ax = plt.subplots() - series = self._column._series - ax.scatter( - x=series.slice(0, max(len(self._column) - lag, 0)), - y=series.slice(lag), - ) - ax.set( - xlabel="y(t)", - ylabel=f"y(t + {lag})", - ) - fig.tight_layout() - - return _figure_to_image(fig) + style = "dark_background" if theme == "dark" else "default" + with plt.style.context(style): + if theme == "dark": + plt.rcParams.update( + { + "text.color": "white", + "axes.labelcolor": "white", + "axes.edgecolor": "white", + "xtick.color": "white", + "ytick.color": "white", + }, + ) + + if self._column.row_count > 0: + _check_column_is_numeric(self._column, operation="create a lag plot") + + fig, ax = plt.subplots() + series = self._column._series + ax.scatter( + x=series.slice(0, max(len(self._column) - lag, 0)), + y=series.slice(lag), + ) + ax.set( + xlabel="y(t)", + ylabel=f"y(t + {lag})", + ) + fig.tight_layout() + + return _figure_to_image(fig) diff --git a/src/safeds/data/tabular/plotting/_table_plotter.py b/src/safeds/data/tabular/plotting/_table_plotter.py index e7bf679cb..6294a0827 100644 --- a/src/safeds/data/tabular/plotting/_table_plotter.py +++ b/src/safeds/data/tabular/plotting/_table_plotter.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal from safeds._utils import _figure_to_image from safeds._validation import _check_bounds, _check_columns_exist, _ClosedBound @@ -34,10 +34,15 @@ class TablePlotter: def __init__(self, table: Table): self._table: Table = table - def box_plots(self) -> Image: + def box_plots(self, *, theme: Literal["dark", "light"] = "light") -> Image: """ Create a box plot for every numerical column. + Parameters + ---------- + theme: + The color theme of the plot. Default is "light". + Returns ------- plot: @@ -61,53 +66,83 @@ def box_plots(self) -> Image: import matplotlib.pyplot as plt - columns = numerical_table.to_columns() - columns = [column._series.drop_nulls() for column in columns] - max_width = 3 - number_of_columns = len(columns) if len(columns) <= max_width else max_width - number_of_rows = ceil(len(columns) / number_of_columns) - - fig, axs = plt.subplots(nrows=number_of_rows, ncols=number_of_columns) - line = 0 - for i, column in enumerate(columns): - if i % number_of_columns == 0 and i != 0: - line += 1 - - if number_of_columns == 1: - axs.boxplot( - column, - patch_artist=True, - labels=[numerical_table.column_names[i]], + style = "dark_background" if theme == "dark" else "default" + with plt.style.context(style): + if theme == "dark": + plt.rcParams.update( + { + "text.color": "white", + "axes.labelcolor": "white", + "axes.edgecolor": "white", + "xtick.color": "white", + "ytick.color": "white", + }, ) - break - - if number_of_rows == 1: - axs[i].boxplot( - column, - patch_artist=True, - labels=[numerical_table.column_names[i]], - ) - - else: - axs[line, i % number_of_columns].boxplot( - column, - patch_artist=True, - labels=[numerical_table.column_names[i]], - ) - - # removes unused ax indices, so there wont be empty plots - last_filled_ax_index = len(columns) % number_of_columns - for i in range(last_filled_ax_index, number_of_columns): - if number_of_rows != 1 and last_filled_ax_index != 0: - fig.delaxes(axs[number_of_rows - 1, i]) + columns = numerical_table.to_columns() + columns = [column._series.drop_nulls() for column in columns] + max_width = 3 + number_of_columns = len(columns) if len(columns) <= max_width else max_width + number_of_rows = ceil(len(columns) / number_of_columns) + + fig, axs = plt.subplots(nrows=number_of_rows, ncols=number_of_columns) + line = 0 + for i, column in enumerate(columns): + if i % number_of_columns == 0 and i != 0: + line += 1 + + if number_of_columns == 1: + axs.boxplot( + column, + patch_artist=True, + labels=[numerical_table.column_names[i]], + ) + break + + if number_of_rows == 1: + axs[i].boxplot( + column, + patch_artist=True, + labels=[numerical_table.column_names[i]], + ) + + else: + axs[line, i % number_of_columns].boxplot( + column, + patch_artist=True, + labels=[numerical_table.column_names[i]], + ) + + # removes unused ax indices, so there wont be empty plots + last_filled_ax_index = len(columns) % number_of_columns + for i in range(last_filled_ax_index, number_of_columns): + if number_of_rows != 1 and last_filled_ax_index != 0: + fig.delaxes(axs[number_of_rows - 1, i]) - fig.tight_layout() - return _figure_to_image(fig) + fig.tight_layout() - def correlation_heatmap(self) -> Image: + style = "dark_background" if theme == "dark" else "default" + with plt.style.context(style): + if theme == "dark": + plt.rcParams.update( + { + "text.color": "white", + "axes.labelcolor": "white", + "axes.edgecolor": "white", + "xtick.color": "white", + "ytick.color": "white", + }, + ) + return _figure_to_image(fig) + + def correlation_heatmap(self, *, theme: Literal["dark", "light"] = "light") -> Image: """ Plot a correlation heatmap for all numerical columns of this `Table`. + Parameters + ---------- + theme: + The color theme of the plot. Default is "light". + Returns ------- plot: @@ -119,44 +154,59 @@ def correlation_heatmap(self) -> Image: >>> table = Table({"temperature": [10, 15, 20, 25, 30], "sales": [54, 74, 90, 206, 210]}) >>> image = table.plot.correlation_heatmap() """ - # TODO: implement using matplotlib and polars - # https://stackoverflow.com/questions/33282368/plotting-a-2d-heatmap import matplotlib.pyplot as plt import numpy as np - only_numerical = self._table.remove_non_numeric_columns()._data_frame.fill_null(0) + style = "dark_background" if theme == "dark" else "default" + with plt.style.context(style): + if theme == "dark": + plt.rcParams.update( + { + "text.color": "white", + "axes.labelcolor": "white", + "axes.edgecolor": "white", + "xtick.color": "white", + "ytick.color": "white", + }, + ) - if self._table.row_count == 0: - warnings.warn( - "An empty table has been used. A correlation heatmap on an empty table will show nothing.", - stacklevel=2, - ) + only_numerical = self._table.remove_non_numeric_columns()._data_frame.fill_null(0) - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", - message=( - "Attempting to set identical low and high (xlims|ylims) makes transformation singular;" - " automatically expanding." - ), - ) + if self._table.row_count == 0: + warnings.warn( + "An empty table has been used. A correlation heatmap on an empty table will show nothing.", + stacklevel=2, + ) - fig, ax = plt.subplots() - heatmap = plt.imshow( - only_numerical.corr().to_numpy(), - vmin=-1, - vmax=1, - cmap="coolwarm", - ) - ax.set_xticks(np.arange(len(only_numerical.columns)), rotation="vertical", labels=only_numerical.columns) - ax.set_yticks(np.arange(len(only_numerical.columns)), labels=only_numerical.columns) - fig.colorbar(heatmap) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message=( + "Attempting to set identical low and high (xlims|ylims) makes transformation singular;" + " automatically expanding." + ), + ) + + fig, ax = plt.subplots() + heatmap = plt.imshow( + only_numerical.corr().to_numpy(), + vmin=-1, + vmax=1, + cmap="coolwarm", + ) + ax.set_xticks( + np.arange(len(only_numerical.columns)), + rotation="vertical", + labels=only_numerical.columns, + ) + ax.set_yticks(np.arange(len(only_numerical.columns)), labels=only_numerical.columns) + fig.colorbar(heatmap) - plt.tight_layout() + plt.tight_layout() - return _figure_to_image(fig) + return _figure_to_image(fig) - def histograms(self, *, max_bin_count: int = 10) -> Image: + def histograms(self, *, max_bin_count: int = 10, theme: Literal["dark", "light"] = "light") -> Image: """ Plot a histogram for every column. @@ -164,6 +214,8 @@ def histograms(self, *, max_bin_count: int = 10) -> Image: ---------- max_bin_count: The maximum number of bins to use in the histogram. Default is 10. + theme: + The color theme of the plot. Default is "light". Returns ------- @@ -177,55 +229,80 @@ def histograms(self, *, max_bin_count: int = 10) -> Image: >>> image = table.plot.histograms() """ import matplotlib.pyplot as plt - import polars as pl - - n_cols = min(3, self._table.column_count) - n_rows = 1 + (self._table.column_count - 1) // n_cols - if n_cols == 1 and n_rows == 1: - fig, axs = plt.subplots(1, 1, tight_layout=True) - one_col = True - else: - fig, axs = plt.subplots(n_rows, n_cols, tight_layout=True, figsize=(n_cols * 3, n_rows * 3)) - one_col = False - - col_names = self._table.column_names - for col_name, ax in zip(col_names, axs.flatten() if not one_col else [axs], strict=False): - column = self._table.get_column(col_name) - distinct_values = column.get_distinct_values() - - ax.set_title(col_name) - ax.set_xlabel("") - ax.set_ylabel("") - - if column.is_numeric and len(distinct_values) > max_bin_count: - min_val = (column.min() or 0) - 1e-6 # Otherwise the minimum value is not included in the first bin - max_val = column.max() or 0 - bin_count = min(max_bin_count, len(distinct_values)) - bins = [ - *(pl.Series(range(bin_count + 1)) / bin_count * (max_val - min_val) + min_val), - ] - - bars = [f"{round((bins[i] + bins[i + 1]) / 2, 2)}" for i in range(len(bins) - 1)] - hist = column._series.hist(bins=bins).slice(1, length=max_bin_count).get_column("count").to_numpy() - - ax.bar(bars, hist, edgecolor="black") - ax.set_xticks(range(len(hist)), bars, rotation=45, horizontalalignment="right") - else: - value_counts = ( - column._series.drop_nulls().value_counts().sort(column.name).slice(0, length=max_bin_count) + style = "dark_background" if theme == "dark" else "default" + with plt.style.context(style): + if theme == "dark": + plt.rcParams.update( + { + "text.color": "white", + "axes.labelcolor": "white", + "axes.edgecolor": "white", + "xtick.color": "white", + "ytick.color": "white", + }, ) - distinct_values = value_counts.get_column(column.name).cast(pl.String).to_numpy() - hist = value_counts.get_column("count").to_numpy() - ax.bar(distinct_values, hist, edgecolor="black") - ax.set_xticks(range(len(distinct_values)), distinct_values, rotation=45, horizontalalignment="right") + import polars as pl - for i in range(len(col_names), n_rows * n_cols): - fig.delaxes(axs.flatten()[i]) # Remove empty subplots + n_cols = min(3, self._table.column_count) + n_rows = 1 + (self._table.column_count - 1) // n_cols - return _figure_to_image(fig) + if n_cols == 1 and n_rows == 1: + fig, axs = plt.subplots(1, 1, tight_layout=True) + one_col = True + else: + fig, axs = plt.subplots(n_rows, n_cols, tight_layout=True, figsize=(n_cols * 3, n_rows * 3)) + one_col = False + + col_names = self._table.column_names + for col_name, ax in zip(col_names, axs.flatten() if not one_col else [axs], strict=False): + column = self._table.get_column(col_name) + distinct_values = column.get_distinct_values() + + ax.set_title(col_name) + ax.set_xlabel("") + ax.set_ylabel("") + + if column.is_numeric and len(distinct_values) > max_bin_count: + min_val = (column.min() or 0) - 1e-6 # Otherwise the minimum value is not included in the first bin + max_val = column.max() or 0 + bin_count = min(max_bin_count, len(distinct_values)) + bins = [ + *(pl.Series(range(bin_count + 1)) / bin_count * (max_val - min_val) + min_val), + ] + + bars = [f"{round((bins[i] + bins[i + 1]) / 2, 2)}" for i in range(len(bins) - 1)] + hist = column._series.hist(bins=bins).slice(1, length=max_bin_count).get_column("count").to_numpy() + + ax.bar(bars, hist, edgecolor="black") + ax.set_xticks(range(len(hist)), bars, rotation=45, horizontalalignment="right") + else: + value_counts = ( + column._series.drop_nulls().value_counts().sort(column.name).slice(0, length=max_bin_count) + ) + distinct_values = value_counts.get_column(column.name).cast(pl.String).to_numpy() + hist = value_counts.get_column("count").to_numpy() + ax.bar(distinct_values, hist, edgecolor="black") + ax.set_xticks( + range(len(distinct_values)), + distinct_values, + rotation=45, + horizontalalignment="right", + ) + + for i in range(len(col_names), n_rows * n_cols): + fig.delaxes(axs.flatten()[i]) # Remove empty subplots + + return _figure_to_image(fig) - def line_plot(self, x_name: str, y_names: list[str], show_confidence_interval: bool = True) -> Image: + def line_plot( + self, + x_name: str, + y_names: list[str], + *, + show_confidence_interval: bool = True, + theme: Literal["dark", "light"] = "light", + ) -> Image: """ Create a line plot for two columns in the table. @@ -237,6 +314,8 @@ def line_plot(self, x_name: str, y_names: list[str], show_confidence_interval: b The name(s) of the column(s) to be plotted on the y-axis. show_confidence_interval: If the confidence interval is shown, per default True. + theme: + The color theme of the plot. Default is "light". Returns ------- @@ -264,57 +343,70 @@ def line_plot(self, x_name: str, y_names: list[str], show_confidence_interval: b _plot_validation(self._table, x_name, y_names) import matplotlib.pyplot as plt - import polars as pl - - agg_list = [] - for name in y_names: - agg_list.append(pl.col(name).mean().alias(f"{name}_mean")) - agg_list.append(pl.count(name).alias(f"{name}_count")) - agg_list.append(pl.std(name, ddof=0).alias(f"{name}_std")) - grouped = self._table._lazy_frame.sort(x_name).group_by(x_name).agg(agg_list).collect() - - x = grouped.get_column(x_name) - y_s = [] - confidence_intervals = [] - for name in y_names: - y_s.append(grouped.get_column(name + "_mean")) - confidence_intervals.append( - 1.96 * grouped.get_column(name + "_std") / grouped.get_column(name + "_count").sqrt(), - ) - fig, ax = plt.subplots() - for name, y in zip(y_names, y_s, strict=False): - ax.plot(x, y, label=name) - - if show_confidence_interval: - for y, conf in zip(y_s, confidence_intervals, strict=False): - ax.fill_between( - x, - y - conf, - y + conf, - color="lightblue", - alpha=0.15, + style = "dark_background" if theme == "dark" else "default" + with plt.style.context(style): + if theme == "dark": + plt.rcParams.update( + { + "text.color": "white", + "axes.labelcolor": "white", + "axes.edgecolor": "white", + "xtick.color": "white", + "ytick.color": "white", + }, ) - if len(y_names) > 1: - name = "values" - else: - name = y_names[0] - ax.set( - xlabel=x_name, - ylabel=name, - ) - ax.legend() - ax.set_xticks(ax.get_xticks()) - ax.set_xticklabels( - ax.get_xticklabels(), - rotation=45, - horizontalalignment="right", - ) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels - fig.tight_layout() - - return _figure_to_image(fig) - - def scatter_plot(self, x_name: str, y_names: list[str]) -> Image: + import polars as pl + + agg_list = [] + for name in y_names: + agg_list.append(pl.col(name).mean().alias(f"{name}_mean")) + agg_list.append(pl.count(name).alias(f"{name}_count")) + agg_list.append(pl.std(name, ddof=0).alias(f"{name}_std")) + grouped = self._table._lazy_frame.sort(x_name).group_by(x_name).agg(agg_list).collect() + + x = grouped.get_column(x_name) + y_s = [] + confidence_intervals = [] + for name in y_names: + y_s.append(grouped.get_column(name + "_mean")) + confidence_intervals.append( + 1.96 * grouped.get_column(name + "_std") / grouped.get_column(name + "_count").sqrt(), + ) + + fig, ax = plt.subplots() + for name, y in zip(y_names, y_s, strict=False): + ax.plot(x, y, label=name) + + if show_confidence_interval: + for y, conf in zip(y_s, confidence_intervals, strict=False): + ax.fill_between( + x, + y - conf, + y + conf, + color="lightblue", + alpha=0.15, + ) + if len(y_names) > 1: + name = "values" + else: + name = y_names[0] + ax.set( + xlabel=x_name, + ylabel=name, + ) + ax.legend() + ax.set_xticks(ax.get_xticks()) + ax.set_xticklabels( + ax.get_xticklabels(), + rotation=45, + horizontalalignment="right", + ) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels + fig.tight_layout() + + return _figure_to_image(fig) + + def scatter_plot(self, x_name: str, y_names: list[str], *, theme: Literal["dark", "light"] = "light") -> Image: """ Create a scatter plot for two columns in the table. @@ -324,6 +416,8 @@ def scatter_plot(self, x_name: str, y_names: list[str]) -> Image: The name of the column to be plotted on the x-axis. y_names: The name(s) of the column(s) to be plotted on the y-axis. + theme: + The color theme of the plot. Default is "light". Returns ------- @@ -352,36 +446,55 @@ def scatter_plot(self, x_name: str, y_names: list[str]) -> Image: import matplotlib.pyplot as plt - fig, ax = plt.subplots() - for y_name in y_names: - ax.scatter( - x=self._table.get_column(x_name)._series, - y=self._table.get_column(y_name)._series, - s=64, # marker size - linewidth=1, - edgecolor="white", - label=y_name, + style = "dark_background" if theme == "dark" else "default" + with plt.style.context(style): + if theme == "dark": + plt.rcParams.update( + { + "text.color": "white", + "axes.labelcolor": "white", + "axes.edgecolor": "white", + "xtick.color": "white", + "ytick.color": "white", + }, + ) + fig, ax = plt.subplots() + for y_name in y_names: + ax.scatter( + x=self._table.get_column(x_name)._series, + y=self._table.get_column(y_name)._series, + s=64, # marker size + linewidth=1, + edgecolor="white", + label=y_name, + ) + if len(y_names) > 1: + name = "values" + else: + name = y_names[0] + ax.set( + xlabel=x_name, + ylabel=name, ) - if len(y_names) > 1: - name = "values" - else: - name = y_names[0] - ax.set( - xlabel=x_name, - ylabel=name, - ) - ax.legend() - ax.set_xticks(ax.get_xticks()) - ax.set_xticklabels( - ax.get_xticklabels(), - rotation=45, - horizontalalignment="right", - ) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels - fig.tight_layout() - - return _figure_to_image(fig) - - def moving_average_plot(self, x_name: str, y_name: str, window_size: int) -> Image: + ax.legend() + ax.set_xticks(ax.get_xticks()) + ax.set_xticklabels( + ax.get_xticklabels(), + rotation=45, + horizontalalignment="right", + ) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels + fig.tight_layout() + + return _figure_to_image(fig) + + def moving_average_plot( + self, + x_name: str, + y_name: str, + window_size: int, + *, + theme: Literal["dark", "light"] = "light", + ) -> Image: """ Create a moving average plot for the y column and plot it by the x column in the table. @@ -391,6 +504,8 @@ def moving_average_plot(self, x_name: str, y_name: str, window_size: int) -> Ima The name of the column to be plotted on the x-axis. y_name: The name of the column to be plotted on the y-axis. + theme: + The color theme of the plot. Default is "light". Returns ------- @@ -416,45 +531,58 @@ def moving_average_plot(self, x_name: str, y_name: str, window_size: int) -> Ima >>> image = table.plot.moving_average_plot("a", "b", window_size = 2) """ import matplotlib.pyplot as plt - import numpy as np - import polars as pl - _plot_validation(self._table, x_name, [y_name]) - for name in [x_name, y_name]: - if self._table.get_column(name).missing_value_count() >= 1: - raise ValueError( - f"there are missing values in column '{name}', use transformation to fill missing values " - f"or drop the missing values. For a moving average no missing values are allowed.", + style = "dark_background" if theme == "dark" else "default" + with plt.style.context(style): + if theme == "dark": + plt.rcParams.update( + { + "text.color": "white", + "axes.labelcolor": "white", + "axes.edgecolor": "white", + "xtick.color": "white", + "ytick.color": "white", + }, ) + import numpy as np + import polars as pl + + _plot_validation(self._table, x_name, [y_name]) + for name in [x_name, y_name]: + if self._table.get_column(name).missing_value_count() >= 1: + raise ValueError( + f"there are missing values in column '{name}', use transformation to fill missing values " + f"or drop the missing values. For a moving average no missing values are allowed.", + ) + + # Calculate the moving average + mean_col = pl.col(y_name).mean().alias(y_name) + grouped = self._table._lazy_frame.sort(x_name).group_by(x_name).agg(mean_col).collect() + data = grouped + moving_average = data.select([pl.col(y_name).rolling_mean(window_size).alias("moving_average")]) + # set up the arrays for plotting + y_data_with_nan = moving_average["moving_average"].to_numpy() + nan_mask = ~np.isnan(y_data_with_nan) + y_data = y_data_with_nan[nan_mask] + x_data = data[x_name].to_numpy()[nan_mask] + fig, ax = plt.subplots() + ax.plot(x_data, y_data, label="moving average") + ax.set( + xlabel=x_name, + ylabel=y_name, + ) + ax.legend() + if self._table.get_column(x_name).is_temporal: + ax.set_xticks(x_data) # Set x-ticks to the x data points + ax.set_xticks(ax.get_xticks()) + ax.set_xticklabels( + ax.get_xticklabels(), + rotation=45, + horizontalalignment="right", + ) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels + fig.tight_layout() - # Calculate the moving average - mean_col = pl.col(y_name).mean().alias(y_name) - grouped = self._table._lazy_frame.sort(x_name).group_by(x_name).agg(mean_col).collect() - data = grouped - moving_average = data.select([pl.col(y_name).rolling_mean(window_size).alias("moving_average")]) - # set up the arrays for plotting - y_data_with_nan = moving_average["moving_average"].to_numpy() - nan_mask = ~np.isnan(y_data_with_nan) - y_data = y_data_with_nan[nan_mask] - x_data = data[x_name].to_numpy()[nan_mask] - fig, ax = plt.subplots() - ax.plot(x_data, y_data, label="moving average") - ax.set( - xlabel=x_name, - ylabel=y_name, - ) - ax.legend() - if self._table.get_column(x_name).is_temporal: - ax.set_xticks(x_data) # Set x-ticks to the x data points - ax.set_xticks(ax.get_xticks()) - ax.set_xticklabels( - ax.get_xticklabels(), - rotation=45, - horizontalalignment="right", - ) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels - fig.tight_layout() - - return _figure_to_image(fig) + return _figure_to_image(fig) def histogram_2d( self, diff --git a/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_box_plot/test_should_match_snapshot[empty].png b/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_box_plot/test_should_match_snapshot[empty].png index f1f0e93bc..5451c1ddf 100644 Binary files a/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_box_plot/test_should_match_snapshot[empty].png and b/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_box_plot/test_should_match_snapshot[empty].png differ diff --git a/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_box_plot/test_should_match_snapshot[multiple rows].png b/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_box_plot/test_should_match_snapshot[multiple rows].png index ef17a72d5..ede1f8073 100644 Binary files a/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_box_plot/test_should_match_snapshot[multiple rows].png and b/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_box_plot/test_should_match_snapshot[multiple rows].png differ diff --git a/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_box_plot/test_should_match_snapshot[one row].png b/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_box_plot/test_should_match_snapshot[one row].png index 82e8c9187..5e0f824ba 100644 Binary files a/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_box_plot/test_should_match_snapshot[one row].png and b/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_box_plot/test_should_match_snapshot[one row].png differ diff --git a/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_box_plot/test_should_match_snapshot_dark[empty].png b/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_box_plot/test_should_match_snapshot_dark[empty].png new file mode 100644 index 000000000..a6303ecc5 Binary files /dev/null and b/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_box_plot/test_should_match_snapshot_dark[empty].png differ diff --git a/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_box_plot/test_should_match_snapshot_dark[multiple rows].png b/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_box_plot/test_should_match_snapshot_dark[multiple rows].png new file mode 100644 index 000000000..1b5c5db21 Binary files /dev/null and b/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_box_plot/test_should_match_snapshot_dark[multiple rows].png differ diff --git a/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_box_plot/test_should_match_snapshot_dark[one row].png b/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_box_plot/test_should_match_snapshot_dark[one row].png new file mode 100644 index 000000000..6de7d77ae Binary files /dev/null and b/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_box_plot/test_should_match_snapshot_dark[one row].png differ diff --git a/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_histogram/test_should_match_snapshot_dark[empty].png b/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_histogram/test_should_match_snapshot_dark[empty].png new file mode 100644 index 000000000..3958a4aeb Binary files /dev/null and b/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_histogram/test_should_match_snapshot_dark[empty].png differ diff --git a/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_histogram/test_should_match_snapshot_dark[multiple rows (numeric)].png b/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_histogram/test_should_match_snapshot_dark[multiple rows (numeric)].png new file mode 100644 index 000000000..2a8a53eef Binary files /dev/null and b/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_histogram/test_should_match_snapshot_dark[multiple rows (numeric)].png differ diff --git a/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_histogram/test_should_match_snapshot_dark[non-numeric].png b/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_histogram/test_should_match_snapshot_dark[non-numeric].png new file mode 100644 index 000000000..52676b5df Binary files /dev/null and b/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_histogram/test_should_match_snapshot_dark[non-numeric].png differ diff --git a/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_histogram/test_should_match_snapshot_dark[one row (numeric)].png b/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_histogram/test_should_match_snapshot_dark[one row (numeric)].png new file mode 100644 index 000000000..4f91625ab Binary files /dev/null and b/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_histogram/test_should_match_snapshot_dark[one row (numeric)].png differ diff --git a/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_lag_plot/test_should_match_snapshot_dark[empty].png b/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_lag_plot/test_should_match_snapshot_dark[empty].png new file mode 100644 index 000000000..b69bb6e78 Binary files /dev/null and b/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_lag_plot/test_should_match_snapshot_dark[empty].png differ diff --git a/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_lag_plot/test_should_match_snapshot_dark[multiple rows].png b/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_lag_plot/test_should_match_snapshot_dark[multiple rows].png new file mode 100644 index 000000000..15a45f475 Binary files /dev/null and b/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_lag_plot/test_should_match_snapshot_dark[multiple rows].png differ diff --git a/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_lag_plot/test_should_match_snapshot_dark[one row].png b/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_lag_plot/test_should_match_snapshot_dark[one row].png new file mode 100644 index 000000000..b69bb6e78 Binary files /dev/null and b/tests/safeds/data/tabular/containers/_column/__snapshots__/test_plot_lag_plot/test_should_match_snapshot_dark[one row].png differ diff --git a/tests/safeds/data/tabular/containers/_column/test_plot_box_plot.py b/tests/safeds/data/tabular/containers/_column/test_plot_box_plot.py index fbfbd21ad..089e4f7fe 100644 --- a/tests/safeds/data/tabular/containers/_column/test_plot_box_plot.py +++ b/tests/safeds/data/tabular/containers/_column/test_plot_box_plot.py @@ -26,3 +26,21 @@ def test_should_raise_if_column_contains_non_numerical_values() -> None: column = Column("a", ["A", "B", "C"]) with pytest.raises(ColumnTypeError): column.plot.box_plot() + + +@pytest.mark.parametrize( + "column", + [ + Column("a", []), + Column("a", [0]), + Column("a", [0, 1]), + ], + ids=[ + "empty", + "one row", + "multiple rows", + ], +) +def test_should_match_snapshot_dark(column: Column, snapshot_png_image: SnapshotAssertion) -> None: + box_plot = column.plot.box_plot(theme="dark") + assert box_plot == snapshot_png_image diff --git a/tests/safeds/data/tabular/containers/_column/test_plot_histogram.py b/tests/safeds/data/tabular/containers/_column/test_plot_histogram.py index 1f6984144..032fe2155 100644 --- a/tests/safeds/data/tabular/containers/_column/test_plot_histogram.py +++ b/tests/safeds/data/tabular/containers/_column/test_plot_histogram.py @@ -21,3 +21,23 @@ def test_should_match_snapshot(column: Column, snapshot_png_image: SnapshotAssertion) -> None: histogram = column.plot.histogram() assert histogram == snapshot_png_image + + +@pytest.mark.parametrize( + "column", + [ + Column("a", []), + Column("a", [0]), + Column("a", [0, 1]), + Column("a", ["A", "B", "C"]), + ], + ids=[ + "empty", + "one row (numeric)", + "multiple rows (numeric)", + "non-numeric", + ], +) +def test_should_match_snapshot_dark(column: Column, snapshot_png_image: SnapshotAssertion) -> None: + histogram = column.plot.histogram(theme="dark") + assert histogram == snapshot_png_image diff --git a/tests/safeds/data/tabular/containers/_column/test_plot_lag_plot.py b/tests/safeds/data/tabular/containers/_column/test_plot_lag_plot.py index aa5abbfa0..429aa3f97 100644 --- a/tests/safeds/data/tabular/containers/_column/test_plot_lag_plot.py +++ b/tests/safeds/data/tabular/containers/_column/test_plot_lag_plot.py @@ -26,3 +26,21 @@ def test_should_raise_if_column_contains_non_numerical_values() -> None: column = Column("a", ["A", "B", "C"]) with pytest.raises(ColumnTypeError): column.plot.lag_plot(1) + + +@pytest.mark.parametrize( + "column", + [ + Column("a", []), + Column("a", [0]), + Column("a", [0, 1]), + ], + ids=[ + "empty", + "one row", + "multiple rows", + ], +) +def test_should_match_snapshot_dark(column: Column, snapshot_png_image: SnapshotAssertion) -> None: + lag_plot = column.plot.lag_plot(1, theme="dark") + assert lag_plot == snapshot_png_image diff --git a/tests/safeds/data/tabular/plotting/__snapshots__/test_moving_average_plot/test_should_match_snapshot_dark[date grouped].png b/tests/safeds/data/tabular/plotting/__snapshots__/test_moving_average_plot/test_should_match_snapshot_dark[date grouped].png new file mode 100644 index 000000000..34418ac99 Binary files /dev/null and b/tests/safeds/data/tabular/plotting/__snapshots__/test_moving_average_plot/test_should_match_snapshot_dark[date grouped].png differ diff --git a/tests/safeds/data/tabular/plotting/__snapshots__/test_moving_average_plot/test_should_match_snapshot_dark[date].png b/tests/safeds/data/tabular/plotting/__snapshots__/test_moving_average_plot/test_should_match_snapshot_dark[date].png new file mode 100644 index 000000000..8cc81001d Binary files /dev/null and b/tests/safeds/data/tabular/plotting/__snapshots__/test_moving_average_plot/test_should_match_snapshot_dark[date].png differ diff --git a/tests/safeds/data/tabular/plotting/__snapshots__/test_moving_average_plot/test_should_match_snapshot_dark[numerical].png b/tests/safeds/data/tabular/plotting/__snapshots__/test_moving_average_plot/test_should_match_snapshot_dark[numerical].png new file mode 100644 index 000000000..dbbce1eb2 Binary files /dev/null and b/tests/safeds/data/tabular/plotting/__snapshots__/test_moving_average_plot/test_should_match_snapshot_dark[numerical].png differ diff --git a/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_boxplots/test_should_match_snapshot_dark[four columns (all numeric)].png b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_boxplots/test_should_match_snapshot_dark[four columns (all numeric)].png new file mode 100644 index 000000000..0c93ca8af Binary files /dev/null and b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_boxplots/test_should_match_snapshot_dark[four columns (all numeric)].png differ diff --git a/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_boxplots/test_should_match_snapshot_dark[four columns (some non-numeric)].png b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_boxplots/test_should_match_snapshot_dark[four columns (some non-numeric)].png new file mode 100644 index 000000000..3016746e8 Binary files /dev/null and b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_boxplots/test_should_match_snapshot_dark[four columns (some non-numeric)].png differ diff --git a/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_boxplots/test_should_match_snapshot_dark[one column].png b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_boxplots/test_should_match_snapshot_dark[one column].png new file mode 100644 index 000000000..be46ce227 Binary files /dev/null and b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_boxplots/test_should_match_snapshot_dark[one column].png differ diff --git a/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_correlation_heatmap/test_should_match_snapshot_dark[normal].png b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_correlation_heatmap/test_should_match_snapshot_dark[normal].png new file mode 100644 index 000000000..7957cc10d Binary files /dev/null and b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_correlation_heatmap/test_should_match_snapshot_dark[normal].png differ diff --git a/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_histograms/test_should_match_snapshot_dark[four columns].png b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_histograms/test_should_match_snapshot_dark[four columns].png new file mode 100644 index 000000000..dc45552d3 Binary files /dev/null and b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_histograms/test_should_match_snapshot_dark[four columns].png differ diff --git a/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_histograms/test_should_match_snapshot_dark[one column].png b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_histograms/test_should_match_snapshot_dark[one column].png new file mode 100644 index 000000000..751e163ee Binary files /dev/null and b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_histograms/test_should_match_snapshot_dark[one column].png differ diff --git a/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_histograms/test_should_match_snapshot_dark[two columns with compressed visualization].png b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_histograms/test_should_match_snapshot_dark[two columns with compressed visualization].png new file mode 100644 index 000000000..2d0f23832 Binary files /dev/null and b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_histograms/test_should_match_snapshot_dark[two columns with compressed visualization].png differ diff --git a/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_lineplot/test_should_match_snapshot_dark[functional multiple columns].png b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_lineplot/test_should_match_snapshot_dark[functional multiple columns].png new file mode 100644 index 000000000..d7580a7f9 Binary files /dev/null and b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_lineplot/test_should_match_snapshot_dark[functional multiple columns].png differ diff --git a/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_lineplot/test_should_match_snapshot_dark[functional].png b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_lineplot/test_should_match_snapshot_dark[functional].png new file mode 100644 index 000000000..4f42ab4a2 Binary files /dev/null and b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_lineplot/test_should_match_snapshot_dark[functional].png differ diff --git a/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_lineplot/test_should_match_snapshot_dark[sorted grouped multiple columns].png b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_lineplot/test_should_match_snapshot_dark[sorted grouped multiple columns].png new file mode 100644 index 000000000..c18a53839 Binary files /dev/null and b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_lineplot/test_should_match_snapshot_dark[sorted grouped multiple columns].png differ diff --git a/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_lineplot/test_should_match_snapshot_dark[sorted grouped].png b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_lineplot/test_should_match_snapshot_dark[sorted grouped].png new file mode 100644 index 000000000..78f2b4a5c Binary files /dev/null and b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_lineplot/test_should_match_snapshot_dark[sorted grouped].png differ diff --git a/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_lineplot/test_should_match_snapshot_dark[unsorted grouped multiple columns].png b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_lineplot/test_should_match_snapshot_dark[unsorted grouped multiple columns].png new file mode 100644 index 000000000..94dfdf588 Binary files /dev/null and b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_lineplot/test_should_match_snapshot_dark[unsorted grouped multiple columns].png differ diff --git a/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_lineplot/test_should_match_snapshot_dark[unsorted grouped].png b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_lineplot/test_should_match_snapshot_dark[unsorted grouped].png new file mode 100644 index 000000000..51b62c352 Binary files /dev/null and b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_lineplot/test_should_match_snapshot_dark[unsorted grouped].png differ diff --git a/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_scatterplot/test_should_match_snapshot_dark[functional].png b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_scatterplot/test_should_match_snapshot_dark[functional].png new file mode 100644 index 000000000..e03ffc868 Binary files /dev/null and b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_scatterplot/test_should_match_snapshot_dark[functional].png differ diff --git a/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_scatterplot/test_should_match_snapshot_dark[multiple].png b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_scatterplot/test_should_match_snapshot_dark[multiple].png new file mode 100644 index 000000000..51cb83530 Binary files /dev/null and b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_scatterplot/test_should_match_snapshot_dark[multiple].png differ diff --git a/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_scatterplot/test_should_match_snapshot_dark[overlapping].png b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_scatterplot/test_should_match_snapshot_dark[overlapping].png new file mode 100644 index 000000000..96ed6018a Binary files /dev/null and b/tests/safeds/data/tabular/plotting/__snapshots__/test_plot_scatterplot/test_should_match_snapshot_dark[overlapping].png differ diff --git a/tests/safeds/data/tabular/plotting/test_moving_average_plot.py b/tests/safeds/data/tabular/plotting/test_moving_average_plot.py index a92c2edc2..eed04e567 100644 --- a/tests/safeds/data/tabular/plotting/test_moving_average_plot.py +++ b/tests/safeds/data/tabular/plotting/test_moving_average_plot.py @@ -10,7 +10,6 @@ ("table", "x_name", "y_name", "window_size"), [ (Table({"A": [1, 2, 3], "B": [2, 4, 7]}), "A", "B", 2), - # (Table({"A": [1, 1, 2, 2, 3, 3, 4, 4, 5, 5], "B": [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]}), "A", "B", 2), ( Table( { @@ -59,6 +58,59 @@ def test_should_match_snapshot( assert line_plot == snapshot_png_image +@pytest.mark.parametrize( + ("table", "x_name", "y_name", "window_size"), + [ + (Table({"A": [1, 2, 3], "B": [2, 4, 7]}), "A", "B", 2), + # (Table({"A": [1, 1, 2, 2, 3, 3, 4, 4, 5, 5], "B": [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]}), "A", "B", 2), + ( + Table( + { + "time": [ + datetime.date(2022, 1, 10), + datetime.date(2022, 1, 10), + datetime.date(2022, 1, 11), + datetime.date(2022, 1, 11), + datetime.date(2022, 1, 12), + datetime.date(2022, 1, 12), + ], + "A": [10, 5, 20, 2, 1, 1], + }, + ), + "time", + "A", + 2, + ), + ( + Table( + { + "time": [ + datetime.date(2022, 1, 9), + datetime.date(2022, 1, 10), + datetime.date(2022, 1, 11), + datetime.date(2022, 1, 12), + ], + "A": [10, 5, 20, 2], + }, + ), + "time", + "A", + 2, + ), + ], + ids=["numerical", "date grouped", "date"], +) +def test_should_match_snapshot_dark( + table: Table, + x_name: str, + y_name: str, + window_size: int, + snapshot_png_image: SnapshotAssertion, +) -> None: + line_plot = table.plot.moving_average_plot(x_name, y_name, window_size, theme="dark") + assert line_plot == snapshot_png_image + + @pytest.mark.parametrize( ("x", "y"), [ diff --git a/tests/safeds/data/tabular/plotting/test_plot_boxplots.py b/tests/safeds/data/tabular/plotting/test_plot_boxplots.py index f9f5b9bf5..1867efe75 100644 --- a/tests/safeds/data/tabular/plotting/test_plot_boxplots.py +++ b/tests/safeds/data/tabular/plotting/test_plot_boxplots.py @@ -33,3 +33,17 @@ def test_should_raise_if_column_contains_non_numerical_values() -> None: def test_should_fail_on_empty_table() -> None: with pytest.raises(NonNumericColumnError): Table().plot.box_plots() + + +@pytest.mark.parametrize( + "table", + [ + Table({"A": [1, 2, 3]}), + Table({"A": [1, 2, 3], "B": ["A", "A", "Bla"], "C": [True, True, False], "D": [1.0, 2.1, 4.5]}), + Table({"A": [1, 2, 3], "B": [1.0, 2.1, 4.5], "C": [1, 2, 3], "D": [1.0, 2.1, 4.5]}), + ], + ids=["one column", "four columns (some non-numeric)", "four columns (all numeric)"], +) +def test_should_match_snapshot_dark(table: Table, snapshot_png_image: SnapshotAssertion) -> None: + boxplots = table.plot.box_plots(theme="dark") + assert boxplots == snapshot_png_image diff --git a/tests/safeds/data/tabular/plotting/test_plot_correlation_heatmap.py b/tests/safeds/data/tabular/plotting/test_plot_correlation_heatmap.py index deb55c9f1..f2f04db96 100644 --- a/tests/safeds/data/tabular/plotting/test_plot_correlation_heatmap.py +++ b/tests/safeds/data/tabular/plotting/test_plot_correlation_heatmap.py @@ -22,3 +22,15 @@ def test_should_match_snapshot(table: Table, snapshot_png_image: SnapshotAsserti # match=r"An empty table has been used. A correlation heatmap on an empty table will show nothing.", # ): # Table().plot.correlation_heatmap() + + +@pytest.mark.parametrize( + "table", + [ + Table({"A": [1, 2, 3.5], "B": [0.2, 4, 77]}), + ], + ids=["normal"], +) +def test_should_match_snapshot_dark(table: Table, snapshot_png_image: SnapshotAssertion) -> None: + correlation_heatmap = table.plot.correlation_heatmap(theme="dark") + assert correlation_heatmap == snapshot_png_image diff --git a/tests/safeds/data/tabular/plotting/test_plot_histograms.py b/tests/safeds/data/tabular/plotting/test_plot_histograms.py index 2b50b775f..0e6f10a93 100644 --- a/tests/safeds/data/tabular/plotting/test_plot_histograms.py +++ b/tests/safeds/data/tabular/plotting/test_plot_histograms.py @@ -74,3 +74,71 @@ def test_should_match_snapshot(table: Table, snapshot_png_image: SnapshotAsserti def test_should_fail_on_empty_table() -> None: with pytest.raises(ZeroDivisionError): Table().plot.histograms() + + +@pytest.mark.parametrize( + "table", + [ + Table({"A": [1, 2, 3]}), + Table( + { + "A": [1, 2, 3, 3, 2, 4, 2], + "B": ["a", "b", "b", "b", "b", "b", "a"], + "C": [True, True, False, True, False, None, True], + "D": [1.0, 2.1, 2.1, 2.1, 2.1, 3.0, 3.0], + }, + ), + Table( + { + "A": [ + 3.8, + 1.8, + 3.2, + 2.2, + 1.0, + 2.4, + 3.5, + 3.9, + 1.9, + 4.0, + 1.4, + 4.2, + 4.5, + 4.5, + 1.4, + 2.5, + 2.8, + 2.8, + 1.9, + 4.3, + ], + "B": [ + "a", + "b", + "b", + "c", + "d", + "f", + "a", + "f", + "e", + "a", + "b", + "b", + "k", + "j", + "b", + "i", + "h", + "g", + "g", + "a", + ], + }, + ), + ], + ids=["one column", "four columns", "two columns with compressed visualization"], +) +def test_should_match_snapshot_dark(table: Table, snapshot_png_image: SnapshotAssertion) -> None: + histograms = table.plot.histograms(theme="dark") + assert histograms == snapshot_png_image diff --git a/tests/safeds/data/tabular/plotting/test_plot_lineplot.py b/tests/safeds/data/tabular/plotting/test_plot_lineplot.py index 5015abf4c..7ca2a2cbf 100644 --- a/tests/safeds/data/tabular/plotting/test_plot_lineplot.py +++ b/tests/safeds/data/tabular/plotting/test_plot_lineplot.py @@ -70,3 +70,34 @@ def test_should_raise_if_column_does_not_exist_error_message(x: str, y: str) -> table = Table({"A": [1, 2, 3], "B": [2, 4, 7]}) with pytest.raises(ColumnNotFoundError): table.plot.line_plot(x, [y]) + + +@pytest.mark.parametrize( + ("table", "x_name", "y_names"), + [ + (Table({"A": [1, 2, 3], "B": [2, 4, 7]}), "A", ["B"]), + (Table({"A": [1, 1, 2, 2], "B": [2, 4, 6, 8]}), "A", ["B"]), + (Table({"A": [2, 1, 3, 3, 1, 2], "B": [6, 2, 5, 5, 4, 8]}), "A", ["B"]), + (Table({"A": [1, 2, 3], "B": [2, 4, 7], "C": [1, 3, 5]}), "A", ["B", "C"]), + (Table({"A": [1, 1, 2, 2], "B": [2, 4, 6, 8], "C": [1, 3, 5, 6]}), "A", ["B", "C"]), + (Table({"A": [2, 1, 3, 3, 1, 2], "B": [6, 2, 5, 5, 4, 8], "C": [9, 7, 5, 3, 2, 1]}), "A", ["B", "C"]), + ], + ids=[ + "functional", + "sorted grouped", + "unsorted grouped", + "functional multiple columns", + "sorted grouped multiple columns", + "unsorted grouped multiple columns", + ], +) +def test_should_match_snapshot_dark( + table: Table, + x_name: str, + y_names: list[str], + snapshot_png_image: SnapshotAssertion, +) -> None: + skip_if_os([os_mac]) + + line_plot = table.plot.line_plot(x_name, y_names, theme="dark") + assert line_plot == snapshot_png_image diff --git a/tests/safeds/data/tabular/plotting/test_plot_scatterplot.py b/tests/safeds/data/tabular/plotting/test_plot_scatterplot.py index 33d0170b5..4a1452274 100644 --- a/tests/safeds/data/tabular/plotting/test_plot_scatterplot.py +++ b/tests/safeds/data/tabular/plotting/test_plot_scatterplot.py @@ -42,6 +42,44 @@ def test_should_match_snapshot( assert scatterplot == snapshot_png_image +@pytest.mark.parametrize( + ("table", "x_name", "y_names"), + [ + (Table({"A": [1, 2, 3], "B": [2, 4, 7]}), "A", ["B"]), + ( + Table( + { + "A": [1, 0.99, 0.99, 2], + "B": [1, 0.99, 1.01, 2], + }, + ), + "A", + ["B"], + ), + ( + Table( + {"A": [1, 0.99, 0.99, 2], "B": [1, 0.99, 1.01, 2], "C": [2, 2.99, 2.01, 3]}, + ), + "A", + ["B", "C"], + ), + ], + ids=[ + "functional", + "overlapping", + "multiple", + ], +) +def test_should_match_snapshot_dark( + table: Table, + x_name: str, + y_names: list[str], + snapshot_png_image: SnapshotAssertion, +) -> None: + scatterplot = table.plot.scatter_plot(x_name, y_names, theme="dark") + assert scatterplot == snapshot_png_image + + @pytest.mark.parametrize( ("table", "col1", "col2"), [