Skip to content

Commit

Permalink
perf: Faster plot_histograms and more reliable plots (#659)
Browse files Browse the repository at this point in the history
## Fixed problem that histogram get plottet a lot faster 

### Summary of Changes

- rewrote the entire funtion
- reduce numerical bins to a given parameter for more performance
(default 10)
- Bins are now also sorted
- sometimes specific columns had strange plots, these are also now fixed
- swaped the images for the test with new ones

---------

Co-authored-by: Lars Reimann <[email protected]>
  • Loading branch information
SamanHushi and lars-reimann authored May 1, 2024
1 parent 6013eb2 commit b5f0a12
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 17 deletions.
63 changes: 47 additions & 16 deletions src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2177,13 +2177,18 @@ def plot_boxplots(self) -> Image:
buffer.seek(0)
return Image.from_bytes(buffer.read())

def plot_histograms(self) -> Image:
def plot_histograms(self, *, number_of_bins : int = 10) -> Image:
"""
Plot a histogram for every column.
Parameters
----------
number_of_bins:
The number of bins to use in the histogram. Default is 10.
Returns
-------
plot: Image
plot:
The plot as an image.
Examples
Expand All @@ -2193,26 +2198,52 @@ def plot_histograms(self) -> Image:
>>> image = table.plot_histograms()
"""
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

col_wrap = min(self.number_of_columns, 3)
n_cols = min(3, self.number_of_columns)
n_rows = 1 + (self.number_of_columns - 1) // n_cols

data = pd.melt(self._data.map(lambda value: str(value)), value_vars=self.column_names)
grid = sns.FacetGrid(data=data, col="variable", col_wrap=col_wrap, sharex=False, sharey=False)
grid.map(sns.histplot, "value")
grid.set_xlabels("")
grid.set_ylabels("")
grid.set_titles("{col_name}")
for axes in grid.axes.flat:
axes.set_xticks(axes.get_xticks())
axes.set_xticklabels(axes.get_xticklabels(), rotation=45, horizontalalignment="right")
grid.tight_layout()
fig = grid.fig
one_col = n_cols == 1 and n_rows == 1
fig, axs = plt.subplots(n_rows, n_cols, tight_layout=True, figsize=(n_cols * 3, n_rows * 3))

col_names = self.column_names
for col_name, ax in zip(col_names, axs.flatten() if not one_col else [axs]):
np_col = np.array(self.get_column(col_name))
bins = min(number_of_bins, len(pd.unique(np_col)))

ax.set_title(col_name)
ax.set_xlabel("")
ax.set_ylabel("")

if self.get_column(col_name).type.is_numeric():
np_col = np_col[~np.isnan(np_col)]

if bins < len(pd.unique(np_col)):
min_val = np.min(np_col)
max_val = np.max(np_col)
hist, bin_edges = np.histogram(self.get_column(col_name), bins, range=(min_val, max_val))

bars = np.array([])
for i in range(len(hist)):
bars = np.append(bars, f'{round(bin_edges[i], 2)}-{round(bin_edges[i+1], 2)}')

ax.bar(bars, hist, edgecolor='black')
ax.set_xticks(np.arange(len(hist)), bars, rotation=45, horizontalalignment="right")
continue

np_col = np_col.astype(str)
unique_values = np.unique(np_col)
hist = np.array([np.sum(np_col == value) for value in unique_values])
ax.bar(unique_values, hist, edgecolor='black')
ax.set_xticks(np.arange(len(unique_values)), unique_values, rotation=45, horizontalalignment="right")

for i in range(len(col_names), n_rows * n_cols):
fig.delaxes(axs.flatten()[i]) # Remove empty subplots

buffer = io.BytesIO()
fig.savefig(buffer, format="png")
plt.close()
plt.close() # Prevents the figure from being displayed directly
buffer.seek(0)
return Image.from_bytes(buffer.read())

Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,14 @@
"D": [1.0, 2.1, 2.1, 2.1, 2.1, 3.0, 3.0],
},
),
Table(
{
"A": [3.8, 1.8, 3.2, 2.2, 1.0, 2.4, 3.5, 3.9, 1.9, 4.0, 1.4, 4.2, 4.5, 4.5, 1.4, 2.5, 2.8, 2.8, 1.9, 4.3],
"B": ["a", "b", "b", "c", "d", "f", "a", "f", "e", "a", "b", "b", "k", "j", "b", "i", "h", "g", "g", "a"],
}
),
],
ids=["one column", "four columns"],
ids=["one column", "four columns", "two columns with compressed visualization"],
)
def test_should_match_snapshot(table: Table, snapshot_png_image: SnapshotAssertion) -> None:
histograms = table.plot_histograms()
Expand Down

0 comments on commit b5f0a12

Please sign in to comment.