From f623c1fdb720f3e1772b5247f4f230255da3f9b5 Mon Sep 17 00:00:00 2001 From: Lars Reimann Date: Tue, 4 Apr 2023 13:24:19 +0200 Subject: [PATCH] feat: class to store image data --- src/safeds/data/image/__init__.py | 1 + src/safeds/data/image/containers/__init__.py | 7 + src/safeds/data/image/containers/_image.py | 152 ++++++++++++++++++ src/safeds/data/image/typing/__init__.py | 7 + src/safeds/data/image/typing/_image_format.py | 7 + src/safeds/data/tabular/containers/_column.py | 53 ++++-- src/safeds/data/tabular/containers/_row.py | 2 +- src/safeds/data/tabular/containers/_table.py | 38 +++-- .../data/tabular/containers/_tagged_table.py | 2 +- tests/resources/image/white_square.jpg | Bin 0 -> 4083 bytes tests/resources/image/white_square.png | Bin 0 -> 1798 bytes tests/safeds/data/image/__init__.py | 0 .../safeds/data/image/containers/__init__.py | 0 .../data/image/containers/test_image.py | 124 ++++++++++++++ tests/safeds/data/image/typing/__init__.py | 0 .../data/image/typing/test_image_format.py | 16 ++ 16 files changed, 382 insertions(+), 27 deletions(-) create mode 100644 src/safeds/data/image/__init__.py create mode 100644 src/safeds/data/image/containers/__init__.py create mode 100644 src/safeds/data/image/containers/_image.py create mode 100644 src/safeds/data/image/typing/__init__.py create mode 100644 src/safeds/data/image/typing/_image_format.py create mode 100644 tests/resources/image/white_square.jpg create mode 100644 tests/resources/image/white_square.png create mode 100644 tests/safeds/data/image/__init__.py create mode 100644 tests/safeds/data/image/containers/__init__.py create mode 100644 tests/safeds/data/image/containers/test_image.py create mode 100644 tests/safeds/data/image/typing/__init__.py create mode 100644 tests/safeds/data/image/typing/test_image_format.py diff --git a/src/safeds/data/image/__init__.py b/src/safeds/data/image/__init__.py new file mode 100644 index 000000000..88627d20e --- /dev/null +++ b/src/safeds/data/image/__init__.py @@ -0,0 +1 @@ +"""Work with image data.""" diff --git a/src/safeds/data/image/containers/__init__.py b/src/safeds/data/image/containers/__init__.py new file mode 100644 index 000000000..0ac99abda --- /dev/null +++ b/src/safeds/data/image/containers/__init__.py @@ -0,0 +1,7 @@ +"""Classes that can store image data.""" + +from ._image import Image + +__all__ = [ + 'Image', +] diff --git a/src/safeds/data/image/containers/_image.py b/src/safeds/data/image/containers/_image.py new file mode 100644 index 000000000..6b85747c6 --- /dev/null +++ b/src/safeds/data/image/containers/_image.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +import io +from pathlib import Path +from typing import BinaryIO + +from PIL.Image import open as open_image, Image as PillowImage + +from safeds.data.image.typing import ImageFormat + + +class Image: + """ + A container for image data. + + Parameters + ---------- + data : BinaryIO + The image data as bytes. + """ + + @staticmethod + def from_jpeg_file(path: str) -> Image: + """ + Create an image from a JPEG file. + + Parameters + ---------- + path : str + The path to the JPEG file. + + Returns + ------- + image : Image + The image. + """ + return Image( + data=Path(path).open("rb"), + format_=ImageFormat.JPEG, + ) + + @staticmethod + def from_png_file(path: str) -> Image: + """ + Create an image from a PNG file. + + Parameters + ---------- + path : str + The path to the PNG file. + + Returns + ------- + image : Image + The image. + """ + return Image( + data=Path(path).open("rb"), + format_=ImageFormat.PNG, + ) + + def __init__(self, data: BinaryIO, format_: ImageFormat): + data.seek(0) + + self._image: PillowImage = open_image(data, formats=[format_.value]) + self._format: ImageFormat = format_ + + # ------------------------------------------------------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------------------------------------------------------ + + @property + def format(self) -> ImageFormat: + """ + Get the image format. + + Returns + ------- + format : ImageFormat + The image format. + """ + return self._format + + # ------------------------------------------------------------------------------------------------------------------ + # Conversion + # ------------------------------------------------------------------------------------------------------------------ + + def to_jpeg_file(self, path: str) -> None: + """ + Save the image as a JPEG file. + + Parameters + ---------- + path : str + The path to the JPEG file. + """ + Path(path).parent.mkdir(parents=True, exist_ok=True) + self._image.save(path, format="jpeg") + + def to_png_file(self, path: str) -> None: + """ + Save the image as a PNG file. + + Parameters + ---------- + path : str + The path to the PNG file. + """ + Path(path).parent.mkdir(parents=True, exist_ok=True) + self._image.save(path, format="png") + + # ------------------------------------------------------------------------------------------------------------------ + # IPython integration + # ------------------------------------------------------------------------------------------------------------------ + + def _repr_jpeg_(self) -> bytes | None: + """ + Return a JPEG image as bytes. + + If the image is not a JPEG, return None. + + Returns + ------- + jpeg : bytes + The image as JPEG. + """ + if self._format != ImageFormat.JPEG: + return None + + buffer = io.BytesIO() + self._image.save(buffer, format="jpeg") + buffer.seek(0) + return buffer.read() + + def _repr_png_(self) -> bytes | None: + """ + Return a PNG image as bytes. + + If the image is not a PNG, return None. + + Returns + ------- + png : bytes + The image as PNG. + """ + if self._format != ImageFormat.PNG: + return None + + buffer = io.BytesIO() + self._image.save(buffer, format="png") + buffer.seek(0) + return buffer.read() diff --git a/src/safeds/data/image/typing/__init__.py b/src/safeds/data/image/typing/__init__.py new file mode 100644 index 000000000..5f68a11a4 --- /dev/null +++ b/src/safeds/data/image/typing/__init__.py @@ -0,0 +1,7 @@ +"""Types used to distinguish different image formats.""" + +from ._image_format import ImageFormat + +__all__ = [ + 'ImageFormat', +] diff --git a/src/safeds/data/image/typing/_image_format.py b/src/safeds/data/image/typing/_image_format.py new file mode 100644 index 000000000..f137c152a --- /dev/null +++ b/src/safeds/data/image/typing/_image_format.py @@ -0,0 +1,7 @@ +from enum import Enum + + +class ImageFormat(Enum): + """Images formats supported by us.""" + JPEG = 'jpeg' + PNG = 'png' diff --git a/src/safeds/data/tabular/containers/_column.py b/src/safeds/data/tabular/containers/_column.py index 03675223e..bcd9ae151 100644 --- a/src/safeds/data/tabular/containers/_column.py +++ b/src/safeds/data/tabular/containers/_column.py @@ -1,5 +1,6 @@ from __future__ import annotations +import io from numbers import Number from typing import TYPE_CHECKING, Any @@ -9,6 +10,8 @@ import seaborn as sns from IPython.core.display_functions import DisplayHandle, display +from safeds.data.image.containers import Image +from safeds.data.image.typing import ImageFormat from safeds.data.tabular.typing import ColumnType from safeds.exceptions import ( ColumnLengthMismatchError, @@ -470,7 +473,7 @@ def variance(self) -> float: # Plotting # ------------------------------------------------------------------------------------------------------------------ - def boxplot(self) -> None: + def boxplot(self) -> Image: """ Plot this column in a boxplot. This function can only plot real numerical data. @@ -487,13 +490,22 @@ def boxplot(self) -> None: "The column contains complex data. Boxplots cannot plot the imaginary part of complex " "data. Please provide a Column with only real numbers", ) + + fig = plt.figure() ax = sns.boxplot(data=self._data) ax.set(xlabel=self.name) plt.tight_layout() - plt.show() - def histogram(self) -> None: + buffer = io.BytesIO() + fig.savefig(buffer, format='png') + plt.close() # Prevents the figure from being displayed directly + buffer.seek(0) + return Image(buffer, ImageFormat.PNG) + + def histogram(self) -> Image: """Plot a column in a histogram.""" + + fig = plt.figure() ax = sns.histplot(data=self._data) ax.set_xticks(ax.get_xticks()) ax.set(xlabel=self.name) @@ -503,23 +515,17 @@ def histogram(self) -> None: horizontalalignment="right", ) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels plt.tight_layout() - plt.show() + + buffer = io.BytesIO() + fig.savefig(buffer, format='png') + plt.close() # Prevents the figure from being displayed directly + buffer.seek(0) + return Image(buffer, ImageFormat.PNG) # ------------------------------------------------------------------------------------------------------------------ - # Other + # IPython integration # ------------------------------------------------------------------------------------------------------------------ - def _count_missing_values(self) -> int: - """ - Return the number of null values in the column. - - Returns - ------- - count : int - The number of null values. - """ - return self._data.isna().sum() - def _ipython_display_(self) -> DisplayHandle: """ Return a display object for the column to be used in Jupyter Notebooks. @@ -534,3 +540,18 @@ def _ipython_display_(self) -> DisplayHandle: with pd.option_context("display.max_rows", tmp.shape[0], "display.max_columns", tmp.shape[1]): return display(tmp) + + # ------------------------------------------------------------------------------------------------------------------ + # Other + # ------------------------------------------------------------------------------------------------------------------ + + def _count_missing_values(self) -> int: + """ + Return the number of null values in the column. + + Returns + ------- + count : int + The number of null values. + """ + return self._data.isna().sum() diff --git a/src/safeds/data/tabular/containers/_row.py b/src/safeds/data/tabular/containers/_row.py index f913f003b..2d954dc60 100644 --- a/src/safeds/data/tabular/containers/_row.py +++ b/src/safeds/data/tabular/containers/_row.py @@ -165,7 +165,7 @@ def count(self) -> int: return len(self._data) # ------------------------------------------------------------------------------------------------------------------ - # Other + # IPython integration # ------------------------------------------------------------------------------------------------------------------ def _ipython_display_(self) -> DisplayHandle: diff --git a/src/safeds/data/tabular/containers/_table.py b/src/safeds/data/tabular/containers/_table.py index cfd093cb1..dc7bd83f1 100644 --- a/src/safeds/data/tabular/containers/_table.py +++ b/src/safeds/data/tabular/containers/_table.py @@ -1,6 +1,7 @@ from __future__ import annotations import functools +import io from pathlib import Path from typing import TYPE_CHECKING, Any @@ -12,6 +13,7 @@ from pandas import DataFrame, Series from scipy import stats +from safeds.data.image.containers import Image from safeds.data.tabular.typing import ColumnType, Schema from safeds.exceptions import ( ColumnLengthMismatchError, @@ -23,9 +25,9 @@ SchemaMismatchError, UnknownColumnNameError, ) - from ._column import Column from ._row import Row +from ...image.typing import ImageFormat if TYPE_CHECKING: from collections.abc import Callable, Iterable @@ -820,7 +822,7 @@ def slice_rows( def sort_columns( self, comparator: Callable[[Column, Column], int] = lambda col1, col2: (col1.name > col2.name) - - (col1.name < col2.name), + - (col1.name < col2.name), ) -> Table: """ Sort the columns of a `Table` with the given comparator and return a new `Table`. @@ -942,10 +944,11 @@ def transform_column(self, name: str, transformer: Callable[[Row], Any]) -> Tabl # Plotting # ------------------------------------------------------------------------------------------------------------------ - def correlation_heatmap(self) -> None: + def correlation_heatmap(self) -> Image: """Plot a correlation heatmap for all numerical columns of this `Table`.""" only_numerical = self.remove_columns_with_non_numerical_values() + fig = plt.figure() sns.heatmap( data=only_numerical._data.corr(), vmin=-1, @@ -955,9 +958,14 @@ def correlation_heatmap(self) -> None: cmap="vlag", ) plt.tight_layout() - plt.show() - def lineplot(self, x_column_name: str, y_column_name: str) -> None: + buffer = io.BytesIO() + fig.savefig(buffer, format='png') + plt.close() # Prevents the figure from being displayed directly + buffer.seek(0) + return Image(buffer, format_=ImageFormat.PNG) + + def lineplot(self, x_column_name: str, y_column_name: str) -> Image: """ Plot two columns against each other in a lineplot. @@ -981,6 +989,7 @@ def lineplot(self, x_column_name: str, y_column_name: str) -> None: if not self.has_column(y_column_name): raise UnknownColumnNameError([y_column_name]) + fig = plt.figure() ax = sns.lineplot( data=self._data, x=self._schema._get_column_index_by_name(x_column_name), @@ -994,9 +1003,14 @@ def lineplot(self, x_column_name: str, y_column_name: str) -> None: horizontalalignment="right", ) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels plt.tight_layout() - plt.show() - def scatterplot(self, x_column_name: str, y_column_name: str) -> None: + buffer = io.BytesIO() + fig.savefig(buffer, format='png') + plt.close() # Prevents the figure from being displayed directly + buffer.seek(0) + return Image(buffer, format_=ImageFormat.PNG) + + def scatterplot(self, x_column_name: str, y_column_name: str) -> Image: """ Plot two columns against each other in a scatterplot. @@ -1017,6 +1031,7 @@ def scatterplot(self, x_column_name: str, y_column_name: str) -> None: if not self.has_column(y_column_name): raise UnknownColumnNameError([y_column_name]) + fig = plt.figure() ax = sns.scatterplot( data=self._data, x=self._schema._get_column_index_by_name(x_column_name), @@ -1030,7 +1045,12 @@ def scatterplot(self, x_column_name: str, y_column_name: str) -> None: horizontalalignment="right", ) # rotate the labels of the x Axis to prevent the chance of overlapping of the labels plt.tight_layout() - plt.show() + + buffer = io.BytesIO() + fig.savefig(buffer, format='png') + plt.close() # Prevents the figure from being displayed directly + buffer.seek(0) + return Image(buffer, format_=ImageFormat.PNG) # ------------------------------------------------------------------------------------------------------------------ # Conversion @@ -1093,7 +1113,7 @@ def to_rows(self) -> list[Row]: return [Row(series_row, self._schema) for (_, series_row) in self._data.iterrows()] # ------------------------------------------------------------------------------------------------------------------ - # Other + # IPython integration # ------------------------------------------------------------------------------------------------------------------ def _ipython_display_(self) -> DisplayHandle: diff --git a/src/safeds/data/tabular/containers/_tagged_table.py b/src/safeds/data/tabular/containers/_tagged_table.py index 8c750b8ed..959f59856 100644 --- a/src/safeds/data/tabular/containers/_tagged_table.py +++ b/src/safeds/data/tabular/containers/_tagged_table.py @@ -73,7 +73,7 @@ def target(self) -> Column: return self._target # ------------------------------------------------------------------------------------------------------------------ - # Other + # IPython integration # ------------------------------------------------------------------------------------------------------------------ def _ipython_display_(self) -> DisplayHandle: diff --git a/tests/resources/image/white_square.jpg b/tests/resources/image/white_square.jpg new file mode 100644 index 0000000000000000000000000000000000000000..16b86aae8a2e41074df1a558aa02b7e62c7d73f8 GIT binary patch literal 4083 zcmeHKU2NM_6#ksFTG=|S5qscgNR~Gcg2r~7wlvn#YE9EE(=w_wh;1Mt$8i#?iEV6m zNs0tKAngT-Ng(m_LV*W(<^djfU=q9{Au$03gcvV?1VRFVkYF}5YST5=HqRFI&2>Fh78V!y<1Mw-a2kfM@vTPP^3+yZ2!pGH87&EXXyUJ@ zg>(5;K4ZHEKP%3PvOwN$xuO=9SGvCRD=l=BY&M%>Gb!3`RZ65%DM?l&MTwI|+-qCD z){0x+(O!gQ!_!^U@lD&}SwyX5Z}@3}-rPy4;q=4qy7r!-Q?q^BtJ!;)_oqX_dQP?8`+D7w+Hb>@NlWkr@dmU+8uR@%E-vbrHlwQowBRMvd$KaKm#SD9%i7pjiE6B#&U?7Hnz zPLy|2;0JszlUa4`idm-#uW)jS&t)@niPYR&ToDs&cSCX2Q>N$JZo6AAYsTYUR5$rh zPgT>|-lqk}waXj2Q5L$qXIK2D<{BreBw|`PQK^`g>9=_*7_({_F0b&4D2p>a`)19c zC_7d(B$z}mmRN;#a*%eW7tRe;W6zvQoWcHoWq}`!kpd$HMhc7+7%4DP;Qyz<;IwU6 zbckxwSvvR$)2qD`D4!H%%rE8gV|TODJEpGQ_3ZX@czf)Jr*C=b$3K>y48Dz&>G<^6 zYT?0^+?g!3B0!{A*BjoL3^Xj?W#_4L=P&U4zN5Zk3;^?Q=vg7&pidobra$@7 z80hL%l1fN%lS{z(*Fo?r?eV_nf$Nup;Lq2B;QH&-*nI_Du?GJ37BK%C@s|hOu~&gZ z&j9ayFyKlr0`Fb|ZrgS=R}0_6y3I(1dedLt1P-1DZv7b8|42tOnq%{L<1?W61Z~wv zz{{@y_mKUeM}sd(Bi7~K%wR!o#y|88t{^r+rO8FP7$Pw)8smZwsd465@EZBfTfg(7to?Q{!n+ZNKphA_&@ zRs&27g(0p`7z)KO<*5Y%QVskxV?@W!4x$L{HKMLO8(|BZK!CP1aKO<}>hv_!dh}kj zus~%HN$4O11PT#DL53$qBP#nP;eTlxM-iDK)EZH@&5oE@4j_409!8)AD#XBfePo$= z^NcZE8PR|sgh?EaMx$6%hq0U=*XZ?noY3N0tr{}ae2ft&QO)p~l8b~7GvGZO9TsSo zK_s74H5(C(C<<|8R#_-K=bYhZqKA3FMJkMIFalpl4hL9)l-1;Pu(^)L}LNnIppHg4V-3j@3i@N9(hfZ}?<2$^y& z;Ay0Z1C#*F@G+uVLaS2~x$xuE56$4vK>@=wz|;kkw?ebF7aJLX0mf*@Re&P5Byqe!g0Z zX=JW+Z={_TLC|fcLxAC7jTzCUSjD3;`D6!dEa*OABHb%#4N-pJqP;>Od3y_B{7d$w zy9q|5;(mjn;caSVVX`OpJO+a!2zIiIHvHqIRzJUlMXPjhB!FK zo>{HF4o77Z4XYuk3KWX(-A=gS&gV@>a`Ghpw)u>PX)Qr@irOH;2+giy6^7a@F}^fdyhTZc!SdXWjoTvY*}rZcYZPJ&+@R9{xAYWSUofv@naV_!~!{`BRk=YP%~E&U-p(9&_L zb9LvT+e1Fpt4G(2zx0in!CpT7W_bp=w)sJw?$SnO%X41m1oiK|Z}D{_7oVQ|Jl>yu zvUN1OIJex=o^{ameY5xVJ+{{FwyvJQ#57y=_it@|@BVW4t(J8c#6Ojn(;E-1xIKQA zICt>;L|Vp~n#mF6{(*hlkJohX3jDPDqbtYLhnvQ2cSO5q+o-F5s6XrNdkrZk)}N|0 zU8_Fh@7Q@@Y;8YQ6-hn2^J>nN`nAI8WByMLkAtt_3u$@SQo$<}Sf=zS#iNDwaC3## iT;xpa@;sQHE>xsWrA)c+*FGh^zO2Qi<`bKCH2(*v`e$DN literal 0 HcmV?d00001 diff --git a/tests/safeds/data/image/__init__.py b/tests/safeds/data/image/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/safeds/data/image/containers/__init__.py b/tests/safeds/data/image/containers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/safeds/data/image/containers/test_image.py b/tests/safeds/data/image/containers/test_image.py new file mode 100644 index 000000000..5f2b08590 --- /dev/null +++ b/tests/safeds/data/image/containers/test_image.py @@ -0,0 +1,124 @@ +from pathlib import Path +from tempfile import NamedTemporaryFile + +import pytest + +from helpers import resolve_resource_path +from safeds.data.image.containers import Image +from safeds.data.image.typing import ImageFormat + + +class TestFromJpegFile: + @pytest.mark.parametrize( + "path", + [ + "image/white_square.jpg" + ], + ) + def test_should_load_jpeg_file(self, path: str) -> None: + Image.from_jpeg_file(resolve_resource_path(path)) + + @pytest.mark.parametrize( + "path", + [ + "image/missing_file.jpg" + ], + ) + def test_should_raise_if_file_not_found(self, path: str) -> None: + with pytest.raises(FileNotFoundError): + Image.from_jpeg_file(resolve_resource_path(path)) + + +class TestFromPngFile: + @pytest.mark.parametrize( + "path", + [ + "image/white_square.png" + ], + ) + def test_should_load_png_file(self, path: str) -> None: + Image.from_png_file(resolve_resource_path(path)) + + @pytest.mark.parametrize( + "path", + [ + "image/missing_file.png" + ], + ) + def test_should_raise_if_file_not_found(self, path: str) -> None: + with pytest.raises(FileNotFoundError): + Image.from_png_file(resolve_resource_path(path)) + + +class TestFormat: + @pytest.mark.parametrize( + ("image", "format_"), + [ + (Image.from_jpeg_file(resolve_resource_path("image/white_square.jpg")), ImageFormat.JPEG), + (Image.from_png_file(resolve_resource_path("image/white_square.png")), ImageFormat.PNG) + ], + ) + def test_should_return_correct_format(self, image: Image, format_: ImageFormat) -> None: + assert image.format == format_ + + +class TestToJpegFile: + @pytest.mark.parametrize( + "path", + [ + "image/white_square.jpg" + ], + ) + def test_should_save_jpeg_file(self, path: str) -> None: + image = Image.from_jpeg_file(resolve_resource_path(path)) + + with NamedTemporaryFile() as tmp_file: + tmp_file.close() + with Path(tmp_file.name).open("wb") as tmp_write_file: + image.to_jpeg_file(tmp_write_file.name) + with Path(tmp_file.name).open("rb") as tmp_read_file: + image_read_back = Image.from_jpeg_file(tmp_read_file.name) + + assert image._image.tobytes() == image_read_back._image.tobytes() + + +class TestToPngFile: + @pytest.mark.parametrize( + "path", + [ + "image/white_square.png" + ], + ) + def test_should_save_png_file(self, path: str) -> None: + image = Image.from_png_file(resolve_resource_path(path)) + + with NamedTemporaryFile() as tmp_file: + tmp_file.close() + with Path(tmp_file.name).open("wb") as tmp_write_file: + image.to_png_file(tmp_write_file.name) + with Path(tmp_file.name).open("rb") as tmp_read_file: + image_read_back = Image.from_png_file(tmp_read_file.name) + + assert image._image.tobytes() == image_read_back._image.tobytes() + + +class TestReprJpeg: + @pytest.mark.parametrize( + "image", + [ + Image.from_png_file(resolve_resource_path("image/white_square.png")) + ], + ) + def test_should_return_none_if_image_is_not_jpeg(self, image: Image) -> None: + assert image._repr_jpeg_() is None + + +class TestReprPng: + @pytest.mark.parametrize( + "image", + [ + Image.from_jpeg_file(resolve_resource_path("image/white_square.jpg")) + ], + ) + def test_should_return_none_if_image_is_not_png(self, image: Image) -> None: + assert image._repr_png_() is None diff --git a/tests/safeds/data/image/typing/__init__.py b/tests/safeds/data/image/typing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/safeds/data/image/typing/test_image_format.py b/tests/safeds/data/image/typing/test_image_format.py new file mode 100644 index 000000000..dae2fa217 --- /dev/null +++ b/tests/safeds/data/image/typing/test_image_format.py @@ -0,0 +1,16 @@ +import pytest + +from safeds.data.image.typing import ImageFormat + + +class TestValue: + + @pytest.mark.parametrize( + "image_format, expected_value", + [ + (ImageFormat.JPEG, "jpeg"), + (ImageFormat.PNG, "png"), + ], + ) + def test_should_return_correct_value(self, image_format: ImageFormat, expected_value: str) -> None: + assert image_format.value == expected_value