From e18b36250ac170e3364106ba1c59649e0b4aff21 Mon Sep 17 00:00:00 2001 From: Lars Reimann Date: Tue, 18 Apr 2023 22:51:15 +0200 Subject: [PATCH] feat: create column types for `polars` data types (#208) Closes partially #196. ### Summary of Changes * Add `polars` * Create `ColumnType` for `polars` data type * Create `Schema` for `polars` data frame --------- Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com> --- poetry.lock | 100 ++++++- pyproject.toml | 1 + src/safeds/data/tabular/containers/_column.py | 2 +- src/safeds/data/tabular/containers/_row.py | 6 +- src/safeds/data/tabular/containers/_table.py | 4 +- .../data/tabular/typing/_column_type.py | 93 +++++-- src/safeds/data/tabular/typing/_schema.py | 132 +++++---- .../containers/_column/test_boxplot.py | 2 +- .../data/tabular/containers/test_row.py | 2 +- .../data/tabular/typing/_schema/__init__.py | 0 .../tabular/typing/_schema/test__str__.py | 6 - .../_schema/test_get_column_index_by_name.py | 7 - .../typing/_schema/test_get_column_type.py | 8 - .../tabular/typing/_schema/test_has_column.py | 13 - .../typing/_schema/test_table_equals.py | 26 -- .../data/tabular/typing/test_column_type.py | 125 ++++++--- .../safeds/data/tabular/typing/test_schema.py | 253 ++++++++++++++++++ 17 files changed, 600 insertions(+), 180 deletions(-) delete mode 100644 tests/safeds/data/tabular/typing/_schema/__init__.py delete mode 100644 tests/safeds/data/tabular/typing/_schema/test__str__.py delete mode 100644 tests/safeds/data/tabular/typing/_schema/test_get_column_index_by_name.py delete mode 100644 tests/safeds/data/tabular/typing/_schema/test_get_column_type.py delete mode 100644 tests/safeds/data/tabular/typing/_schema/test_has_column.py delete mode 100644 tests/safeds/data/tabular/typing/_schema/test_table_equals.py create mode 100644 tests/safeds/data/tabular/typing/test_schema.py diff --git a/poetry.lock b/poetry.lock index 788b9252f..0f10b0300 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2185,6 +2185,42 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "polars" +version = "0.17.5" +description = "Blazingly fast DataFrame library" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "polars-0.17.5-cp37-abi3-macosx_10_7_x86_64.whl", hash = "sha256:3ec088bb68c2b833f1172b85dc1222ae88732ce0ae7de34590dd387204a84b1b"}, + {file = "polars-0.17.5-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:1f5389d29d5e5e993a9a361801d54dccd0399cedb72632274b341e27957c631c"}, + {file = "polars-0.17.5-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4518d2a70bf69eaae04437a241f6d81f2d576e4491d8d4b45c95eacb53415616"}, + {file = "polars-0.17.5-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6fbe7dc6be495d1805f8252bc6bcfb0372134595bea0ebf7c46db21bad86bf58"}, + {file = "polars-0.17.5-cp37-abi3-win_amd64.whl", hash = "sha256:7f112d6cefb37a32fc723195f0be1f62ec528b5f83905ad2e614bc78585a0313"}, + {file = "polars-0.17.5.tar.gz", hash = "sha256:7db2da068e983312238799ad8263e80544151304aac0bc2e6511f91cb56af54d"}, +] + +[package.dependencies] +pandas = {version = "*", optional = true, markers = "extra == \"pandas\""} +pyarrow = {version = ">=7.0.0", optional = true, markers = "extra == \"pyarrow\" or extra == \"pandas\""} +typing_extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""} +xlsx2csv = {version = ">=0.8.0", optional = true, markers = "extra == \"xlsx2csv\""} + +[package.extras] +all = ["polars[connectorx,deltalake,fsspec,matplotlib,numpy,pandas,pyarrow,sqlalchemy,timezone,xlsx2csv,xlsxwriter]"] +connectorx = ["connectorx"] +deltalake = ["deltalake (>=0.8.0)"] +fsspec = ["fsspec"] +matplotlib = ["matplotlib"] +numpy = ["numpy (>=1.16.0)"] +pandas = ["pandas", "pyarrow (>=7.0.0)"] +pyarrow = ["pyarrow (>=7.0.0)"] +sqlalchemy = ["pandas", "sqlalchemy"] +timezone = ["backports.zoneinfo", "tzdata"] +xlsx2csv = ["xlsx2csv (>=0.8.0)"] +xlsxwriter = ["xlsxwriter"] + [[package]] name = "prometheus-client" version = "0.16.0" @@ -2269,6 +2305,44 @@ files = [ [package.extras] tests = ["pytest"] +[[package]] +name = "pyarrow" +version = "11.0.0" +description = "Python library for Apache Arrow" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pyarrow-11.0.0-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:40bb42afa1053c35c749befbe72f6429b7b5f45710e85059cdd534553ebcf4f2"}, + {file = "pyarrow-11.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7c28b5f248e08dea3b3e0c828b91945f431f4202f1a9fe84d1012a761324e1ba"}, + {file = "pyarrow-11.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a37bc81f6c9435da3c9c1e767324ac3064ffbe110c4e460660c43e144be4ed85"}, + {file = "pyarrow-11.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad7c53def8dbbc810282ad308cc46a523ec81e653e60a91c609c2233ae407689"}, + {file = "pyarrow-11.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:25aa11c443b934078bfd60ed63e4e2d42461682b5ac10f67275ea21e60e6042c"}, + {file = "pyarrow-11.0.0-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:e217d001e6389b20a6759392a5ec49d670757af80101ee6b5f2c8ff0172e02ca"}, + {file = "pyarrow-11.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ad42bb24fc44c48f74f0d8c72a9af16ba9a01a2ccda5739a517aa860fa7e3d56"}, + {file = "pyarrow-11.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d942c690ff24a08b07cb3df818f542a90e4d359381fbff71b8f2aea5bf58841"}, + {file = "pyarrow-11.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f010ce497ca1b0f17a8243df3048055c0d18dcadbcc70895d5baf8921f753de5"}, + {file = "pyarrow-11.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:2f51dc7ca940fdf17893227edb46b6784d37522ce08d21afc56466898cb213b2"}, + {file = "pyarrow-11.0.0-cp37-cp37m-macosx_10_14_x86_64.whl", hash = "sha256:1cbcfcbb0e74b4d94f0b7dde447b835a01bc1d16510edb8bb7d6224b9bf5bafc"}, + {file = "pyarrow-11.0.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aaee8f79d2a120bf3e032d6d64ad20b3af6f56241b0ffc38d201aebfee879d00"}, + {file = "pyarrow-11.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:410624da0708c37e6a27eba321a72f29d277091c8f8d23f72c92bada4092eb5e"}, + {file = "pyarrow-11.0.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2d53ba72917fdb71e3584ffc23ee4fcc487218f8ff29dd6df3a34c5c48fe8c06"}, + {file = "pyarrow-11.0.0-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:f12932e5a6feb5c58192209af1d2607d488cb1d404fbc038ac12ada60327fa34"}, + {file = "pyarrow-11.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:41a1451dd895c0b2964b83d91019e46f15b5564c7ecd5dcb812dadd3f05acc97"}, + {file = "pyarrow-11.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:becc2344be80e5dce4e1b80b7c650d2fc2061b9eb339045035a1baa34d5b8f1c"}, + {file = "pyarrow-11.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f40be0d7381112a398b93c45a7e69f60261e7b0269cc324e9f739ce272f4f70"}, + {file = "pyarrow-11.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:362a7c881b32dc6b0eccf83411a97acba2774c10edcec715ccaab5ebf3bb0835"}, + {file = "pyarrow-11.0.0-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:ccbf29a0dadfcdd97632b4f7cca20a966bb552853ba254e874c66934931b9841"}, + {file = "pyarrow-11.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3e99be85973592051e46412accea31828da324531a060bd4585046a74ba45854"}, + {file = "pyarrow-11.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69309be84dcc36422574d19c7d3a30a7ea43804f12552356d1ab2a82a713c418"}, + {file = "pyarrow-11.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da93340fbf6f4e2a62815064383605b7ffa3e9eeb320ec839995b1660d69f89b"}, + {file = "pyarrow-11.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:caad867121f182d0d3e1a0d36f197df604655d0b466f1bc9bafa903aa95083e4"}, + {file = "pyarrow-11.0.0.tar.gz", hash = "sha256:5461c57dbdb211a632a48facb9b39bbeb8a7905ec95d768078525283caef5f6d"}, +] + +[package.dependencies] +numpy = ">=1.16.6" + [[package]] name = "pycparser" version = "2.21" @@ -3123,6 +3197,18 @@ files = [ docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] test = ["argcomplete (>=2.0)", "pre-commit", "pytest", "pytest-mock"] +[[package]] +name = "typing-extensions" +version = "4.5.0" +description = "Backported and Experimental Type Hints for Python 3.7+" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "typing_extensions-4.5.0-py3-none-any.whl", hash = "sha256:fb33085c39dd998ac16d1431ebc293a8b3eedd00fd4a32de0ff79002c19511b4"}, + {file = "typing_extensions-4.5.0.tar.gz", hash = "sha256:5cb5f4a79139d699607b3ef622a1dedafa84e115ab0024e0d9c044a9479ca7cb"}, +] + [[package]] name = "tzdata" version = "2023.3" @@ -3273,7 +3359,19 @@ files = [ {file = "widgetsnbextension-4.0.5.tar.gz", hash = "sha256:003f716d930d385be3fd9de42dd9bf008e30053f73bddde235d14fbeaeff19af"}, ] +[[package]] +name = "xlsx2csv" +version = "0.8.1" +description = "xlsx to csv converter" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "xlsx2csv-0.8.1-py3-none-any.whl", hash = "sha256:6c36c0295d64f231570479e514d6163ce135af3c431a1705b073230bedaef9f2"}, + {file = "xlsx2csv-0.8.1.tar.gz", hash = "sha256:7ecd6d2bc2426f2e432f4fdac12211e1976d3cbb65f9033e1eda65edda2045e3"}, +] + [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "dbd1f2f23c52bb6e883b6aa470410e039ff9bcd7f87fa163a099dfaa27585486" +content-hash = "98f6804ba729bd1e041efcdcecbb06c2704b688a5ffcf580c1ef73c4672a8143" diff --git a/pyproject.toml b/pyproject.toml index 94ff414d0..04898fe2e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ pandas = "^2.0.0" pillow = "^9.5.0" scikit-learn = "^1.2.0" seaborn = "^0.12.2" +polars = {extras = ["pandas", "pyarrow", "xlsx2csv"], version = "^0.17.5"} [tool.poetry.group.dev.dependencies] pytest = "^7.2.1" diff --git a/src/safeds/data/tabular/containers/_column.py b/src/safeds/data/tabular/containers/_column.py index 2a846e03e..2a7f23712 100644 --- a/src/safeds/data/tabular/containers/_column.py +++ b/src/safeds/data/tabular/containers/_column.py @@ -46,7 +46,7 @@ def __init__(self, name: str, data: Iterable, type_: ColumnType | None = None) - self._name: str = name self._data: pd.Series = data if isinstance(data, pd.Series) else pd.Series(data) # noinspection PyProtectedMember - self._type: ColumnType = type_ if type_ is not None else ColumnType._from_numpy_dtype(self._data.dtype) + self._type: ColumnType = type_ if type_ is not None else ColumnType._from_numpy_data_type(self._data.dtype) def __eq__(self, other: object) -> bool: if not isinstance(other, Column): diff --git a/src/safeds/data/tabular/containers/_row.py b/src/safeds/data/tabular/containers/_row.py index 5af250b41..a0049471b 100644 --- a/src/safeds/data/tabular/containers/_row.py +++ b/src/safeds/data/tabular/containers/_row.py @@ -47,7 +47,7 @@ def from_dict(data: dict[str, Any]) -> Row: """ row_frame = pd.DataFrame([data.values()], columns=list(data.keys())) # noinspection PyProtectedMember - return Row(data.values(), Schema._from_dataframe(row_frame)) + return Row(data.values(), Schema._from_pandas_dataframe(row_frame)) # ------------------------------------------------------------------------------------------------------------------ # Dunder methods @@ -65,7 +65,7 @@ def __init__(self, data: Iterable, schema: Schema | None = None): dataframe = self._data.to_frame().T dataframe.columns = column_names # noinspection PyProtectedMember - self._schema = Schema._from_dataframe(dataframe) + self._schema = Schema._from_pandas_dataframe(dataframe) def __eq__(self, other: Any) -> bool: if not isinstance(other, Row): @@ -136,7 +136,7 @@ def get_value(self, column_name: str) -> Any: if not self._schema.has_column(column_name): raise UnknownColumnNameError([column_name]) # noinspection PyProtectedMember - return self._data[self._schema._get_column_index_by_name(column_name)] + return self._data[self._schema._get_column_index(column_name)] def has_column(self, column_name: str) -> bool: """ diff --git a/src/safeds/data/tabular/containers/_table.py b/src/safeds/data/tabular/containers/_table.py index 26e8a1bc1..956ceb7d5 100644 --- a/src/safeds/data/tabular/containers/_table.py +++ b/src/safeds/data/tabular/containers/_table.py @@ -233,7 +233,7 @@ def __init__(self, data: Iterable, schema: Schema | None = None): | [from_rows][safeds.data.tabular.containers._table.Table.from_rows] | Create a table from a list of rows. | """ self._data: pd.DataFrame = data if isinstance(data, pd.DataFrame) else pd.DataFrame(data) - self._schema: Schema = Schema._from_dataframe(self._data) if schema is None else schema + self._schema: Schema = Schema._from_pandas_dataframe(self._data) if schema is None else schema if self._data.empty: self._data = pd.DataFrame(columns=self._schema.get_column_names()) @@ -305,7 +305,7 @@ def get_column(self, column_name: str) -> Column: if self._schema.has_column(column_name): output_column = Column( column_name, - self._data.iloc[:, [self._schema._get_column_index_by_name(column_name)]].squeeze(), + self._data.iloc[:, [self._schema._get_column_index(column_name)]].squeeze(), self._schema.get_type_of_column(column_name), ) return output_column diff --git a/src/safeds/data/tabular/typing/_column_type.py b/src/safeds/data/tabular/typing/_column_type.py index 8b2a3f7a1..670a16f51 100644 --- a/src/safeds/data/tabular/typing/_column_type.py +++ b/src/safeds/data/tabular/typing/_column_type.py @@ -4,6 +4,15 @@ from dataclasses import dataclass from typing import TYPE_CHECKING +from polars import FLOAT_DTYPES as POLARS_FLOAT_DTYPES +from polars import INTEGER_DTYPES as POLARS_INTEGER_DTYPES +from polars import TEMPORAL_DTYPES as POLARS_TEMPORAL_DTYPES +from polars import Boolean as PolarsBoolean +from polars import Decimal as PolarsDecimal +from polars import Object as PolarsObject +from polars import PolarsDataType +from polars import Utf8 as PolarsUtf8 + if TYPE_CHECKING: import numpy as np @@ -11,37 +20,47 @@ class ColumnType(ABC): """Abstract base class for column types.""" - @abstractmethod - def is_nullable(self) -> bool: + @staticmethod + def _from_numpy_data_type(data_type: np.dtype) -> ColumnType: """ - Return whether the given column type is nullable. + Return the column type for a given `numpy` data type. + + Parameters + ---------- + data_type : numpy.dtype + The `numpy` data type. Returns ------- - is_nullable : bool - True if the column is nullable. - """ + column_type : ColumnType + The ColumnType. - @abstractmethod - def is_numeric(self) -> bool: + Raises + ------ + NotImplementedError + If the given data type is not supported. """ - Return whether the given column type is numeric. + if data_type.kind in ("u", "i"): + return Integer() + if data_type.kind == "b": + return Boolean() + if data_type.kind == "f": + return RealNumber() + if data_type.kind in ("S", "U", "O", "M", "m"): + return String() - Returns - ------- - is_numeric : bool - True if the column is numeric. - """ + message = f"Unsupported numpy data type '{data_type}'." + raise NotImplementedError(message) @staticmethod - def _from_numpy_dtype(dtype: np.dtype) -> ColumnType: + def _from_polars_data_type(data_type: PolarsDataType) -> ColumnType: """ - Return the column type for a given numpy dtype. + Return the column type for a given `polars` data type. Parameters ---------- - dtype : numpy.dtype - The numpy dtype. + data_type : PolarsDataType + The `polars` data type. Returns ------- @@ -50,18 +69,42 @@ def _from_numpy_dtype(dtype: np.dtype) -> ColumnType: Raises ------ - TypeError - If the given dtype is not supported. + NotImplementedError + If the given data type is not supported. """ - if dtype.kind in ("u", "i"): + if data_type in POLARS_INTEGER_DTYPES: return Integer() - if dtype.kind == "b": + if data_type is PolarsBoolean: return Boolean() - if dtype.kind == "f": + if data_type in POLARS_FLOAT_DTYPES or data_type is PolarsDecimal: return RealNumber() - if dtype.kind in ("S", "U", "O", "M", "m"): + if data_type is PolarsUtf8 or data_type is PolarsObject or data_type in POLARS_TEMPORAL_DTYPES: return String() - raise TypeError("Unexpected column type") + + message = f"Unsupported polars data type '{data_type}'." + raise NotImplementedError(message) + + @abstractmethod + def is_nullable(self) -> bool: + """ + Return whether the given column type is nullable. + + Returns + ------- + is_nullable : bool + True if the column is nullable. + """ + + @abstractmethod + def is_numeric(self) -> bool: + """ + Return whether the given column type is numeric. + + Returns + ------- + is_numeric : bool + True if the column is numeric. + """ @dataclass diff --git a/src/safeds/data/tabular/typing/_schema.py b/src/safeds/data/tabular/typing/_schema.py index 135d709dd..5837a6ab8 100644 --- a/src/safeds/data/tabular/typing/_schema.py +++ b/src/safeds/data/tabular/typing/_schema.py @@ -8,6 +8,7 @@ if TYPE_CHECKING: import pandas as pd + import polars as pl @dataclass @@ -23,9 +24,83 @@ class Schema: _schema: dict[str, ColumnType] + @staticmethod + def _from_pandas_dataframe(dataframe: pd.DataFrame) -> Schema: + """ + Create a schema from a `pandas.DataFrame`. + + Parameters + ---------- + dataframe : pd.DataFrame + The dataframe. + + Returns + ------- + schema : Schema + The schema. + """ + names = dataframe.columns + # noinspection PyProtectedMember + types = (ColumnType._from_numpy_data_type(data_type) for data_type in dataframe.dtypes) + + return Schema(dict(zip(names, types, strict=True))) + + @staticmethod + def _from_polars_dataframe(dataframe: pl.DataFrame) -> Schema: + """ + Create a schema from a `polars.Dataframe`. + + Parameters + ---------- + dataframe : pl.DataFrame + The dataframe. + + Returns + ------- + schema : Schema + The schema. + """ + names = dataframe.columns + # noinspection PyProtectedMember + types = (ColumnType._from_polars_data_type(data_type) for data_type in dataframe.dtypes) + + return Schema(dict(zip(names, types, strict=True))) + def __init__(self, schema: dict[str, ColumnType]): self._schema = dict(schema) # Defensive copy + def __hash__(self) -> int: + """ + Return a hash value for the schema. + + Returns + ------- + hash : int + The hash value. + """ + column_names = self._schema.keys() + column_types = map(repr, self._schema.values()) + return hash(tuple(zip(column_names, column_types, strict=True))) + + def __str__(self) -> str: + """ + Return a user-friendly string representation of the schema. + + Returns + ------- + string : str + The string representation. + """ + match len(self._schema): + case 0: + return "{}" + case 1: + return str(self._schema) + case _: + lines = (f" {name!r}: {type_}" for name, type_ in self._schema.items()) + joined = ",\n".join(lines) + return f"{{\n{joined}\n}}" + def has_column(self, column_name: str) -> bool: """ Return whether the schema contains a given column. @@ -59,13 +134,13 @@ def get_type_of_column(self, column_name: str) -> ColumnType: Raises ------ ColumnNameError - If the specified target column name does not exist. + If the specified column name does not exist. """ if not self.has_column(column_name): raise UnknownColumnNameError([column_name]) return self._schema[column_name] - def _get_column_index_by_name(self, column_name: str) -> int: + def _get_column_index(self, column_name: str) -> int: """ Return the index of the column with specified column name. @@ -78,30 +153,16 @@ def _get_column_index_by_name(self, column_name: str) -> int: ------- index : int The index of the column. - """ - return list(self._schema.keys()).index(column_name) - - @staticmethod - def _from_dataframe(dataframe: pd.DataFrame) -> Schema: - """ - Construct a TableSchema from a Dataframe. This function is not supposed to be exposed to the user. - - Parameters - ---------- - dataframe : pd.DataFrame - The Dataframe used to construct the TableSchema. - - Returns - ------- - _from_dataframe: Schema - The constructed TableSchema. + Raises + ------ + ColumnNameError + If the specified column name does not exist. """ - names = dataframe.columns - # noinspection PyProtectedMember - types = (ColumnType._from_numpy_dtype(dtype) for dtype in dataframe.dtypes) + if not self.has_column(column_name): + raise UnknownColumnNameError([column_name]) - return Schema(dict(zip(names, types, strict=True))) + return list(self._schema.keys()).index(column_name) def get_column_names(self) -> list[str]: """ @@ -113,28 +174,3 @@ def get_column_names(self) -> list[str]: The column names. """ return list(self._schema.keys()) - - def __str__(self) -> str: - """ - Return a print-string for the TableSchema. - - Returns - ------- - output_string : str - The string. - """ - column_count = len(self._schema) - output_string = f"TableSchema:\nColumn Count: {column_count}\nColumns:\n" - for column_name, data_type in self._schema.items(): - output_string += f" {column_name}: {data_type}\n" - return output_string - - def __repr__(self) -> str: - return self.__str__() - - def __eq__(self, o: object) -> bool: - if not isinstance(o, Schema): - return NotImplemented - if self is o: - return True - return self._schema == o._schema diff --git a/tests/safeds/data/tabular/containers/_column/test_boxplot.py b/tests/safeds/data/tabular/containers/_column/test_boxplot.py index 841cb6891..059ab78ad 100644 --- a/tests/safeds/data/tabular/containers/_column/test_boxplot.py +++ b/tests/safeds/data/tabular/containers/_column/test_boxplot.py @@ -6,7 +6,7 @@ def test_boxplot_complex() -> None: - with pytest.raises(TypeError): # noqa: PT012 + with pytest.raises(NotImplementedError): # noqa: PT012 table = Table.from_dict({"A": [1, 2, complex(1, -2)]}) table.get_column("A").boxplot() diff --git a/tests/safeds/data/tabular/containers/test_row.py b/tests/safeds/data/tabular/containers/test_row.py index fcbf119de..2a71c1c28 100644 --- a/tests/safeds/data/tabular/containers/test_row.py +++ b/tests/safeds/data/tabular/containers/test_row.py @@ -65,7 +65,7 @@ class TestEq: ], ) def test_should_return_whether_two_rows_are_equal(self, row1: Row, row2: Row, expected: bool) -> None: - assert (row1 == row2) == expected + assert (row1.__eq__(row2)) == expected @pytest.mark.parametrize( ("row", "other"), diff --git a/tests/safeds/data/tabular/typing/_schema/__init__.py b/tests/safeds/data/tabular/typing/_schema/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/safeds/data/tabular/typing/_schema/test__str__.py b/tests/safeds/data/tabular/typing/_schema/test__str__.py deleted file mode 100644 index 9bdf89652..000000000 --- a/tests/safeds/data/tabular/typing/_schema/test__str__.py +++ /dev/null @@ -1,6 +0,0 @@ -from safeds.data.tabular.containers import Table - - -def test__str__() -> None: - table = Table.from_dict({"col1": ["col1_1"], "col2": [1]}) - assert str(table.schema) == "TableSchema:\nColumn Count: 2\nColumns:\n col1: String\n col2: Integer\n" diff --git a/tests/safeds/data/tabular/typing/_schema/test_get_column_index_by_name.py b/tests/safeds/data/tabular/typing/_schema/test_get_column_index_by_name.py deleted file mode 100644 index 3d445e8d0..000000000 --- a/tests/safeds/data/tabular/typing/_schema/test_get_column_index_by_name.py +++ /dev/null @@ -1,7 +0,0 @@ -from safeds.data.tabular.containers import Table - - -def test_get_column_index_by_name() -> None: - table = Table.from_dict({"col1": [1], "col2": [2]}) - assert table.schema._get_column_index_by_name("col1") == 0 - assert table.schema._get_column_index_by_name("col2") == 1 diff --git a/tests/safeds/data/tabular/typing/_schema/test_get_column_type.py b/tests/safeds/data/tabular/typing/_schema/test_get_column_type.py deleted file mode 100644 index f53752e50..000000000 --- a/tests/safeds/data/tabular/typing/_schema/test_get_column_type.py +++ /dev/null @@ -1,8 +0,0 @@ -from safeds.data.tabular.containers import Table -from safeds.data.tabular.typing import Integer - - -def test_get_type_of_column() -> None: - table = Table.from_dict({"A": [1], "B": [2]}) - table_column_type = table.schema.get_type_of_column("A") - assert table_column_type == Integer() diff --git a/tests/safeds/data/tabular/typing/_schema/test_has_column.py b/tests/safeds/data/tabular/typing/_schema/test_has_column.py deleted file mode 100644 index 13d304936..000000000 --- a/tests/safeds/data/tabular/typing/_schema/test_has_column.py +++ /dev/null @@ -1,13 +0,0 @@ -from safeds.data.tabular.containers import Table - - -def test_has_column_true() -> None: - table = Table.from_dict({"A": [1], "B": [2]}) - - assert table.schema.has_column("A") - - -def test_has_column_false() -> None: - table = Table.from_dict({"A": [1], "B": [2]}) - - assert not table.schema.has_column("C") diff --git a/tests/safeds/data/tabular/typing/_schema/test_table_equals.py b/tests/safeds/data/tabular/typing/_schema/test_table_equals.py deleted file mode 100644 index d132dad71..000000000 --- a/tests/safeds/data/tabular/typing/_schema/test_table_equals.py +++ /dev/null @@ -1,26 +0,0 @@ -from safeds.data.tabular.containers import Table -from safeds.data.tabular.typing import Integer, RealNumber, Schema - - -def test_table_equals_valid() -> None: - table = Table.from_dict({"A": [1], "B": [2]}) - schema_expected = Schema( - { - "A": Integer(), - "B": Integer(), - }, - ) - - assert table.schema == schema_expected - - -def test_table_equals_invalid() -> None: - table = Table.from_dict({"A": [1], "B": [2]}) - schema_not_expected = Schema( - { - "A": RealNumber(), - "C": Integer(), - }, - ) - - assert table.schema != schema_not_expected diff --git a/tests/safeds/data/tabular/typing/test_column_type.py b/tests/safeds/data/tabular/typing/test_column_type.py index 05c614b92..a8c0a33c9 100644 --- a/tests/safeds/data/tabular/typing/test_column_type.py +++ b/tests/safeds/data/tabular/typing/test_column_type.py @@ -1,5 +1,14 @@ import numpy as np import pytest +from polars import FLOAT_DTYPES as POLARS_FLOAT_DTYPES +from polars import INTEGER_DTYPES as POLARS_INTEGER_DTYPES +from polars import PolarsDataType +from polars.datatypes import TEMPORAL_DTYPES as POLARS_TEMPORAL_DTYPES +from polars.datatypes import Boolean as PolarsBoolean +from polars.datatypes import Decimal as PolarsDecimal +from polars.datatypes import Object as PolarsObject +from polars.datatypes import Unknown as PolarsUnknown +from polars.datatypes import Utf8 as PolarsUtf8 from safeds.data.tabular.typing import ( Anything, Boolean, @@ -10,7 +19,77 @@ ) -class TestColumnType: +class TestFromNumpyDataType: + # Test cases taken from https://numpy.org/doc/stable/reference/arrays.scalars.html#scalars + @pytest.mark.parametrize( + ("data_type", "expected"), + [ + # Boolean + (np.dtype(np.bool_), Boolean()), + # Number + (np.dtype(np.half), RealNumber()), + (np.dtype(np.single), RealNumber()), + (np.dtype(np.float_), RealNumber()), + (np.dtype(np.longfloat), RealNumber()), + # Int + (np.dtype(np.byte), Integer()), + (np.dtype(np.short), Integer()), + (np.dtype(np.intc), Integer()), + (np.dtype(np.int_), Integer()), + (np.dtype(np.longlong), Integer()), + (np.dtype(np.ubyte), Integer()), + (np.dtype(np.ushort), Integer()), + (np.dtype(np.uintc), Integer()), + (np.dtype(np.uint), Integer()), + (np.dtype(np.ulonglong), Integer()), + # String + (np.dtype(np.str_), String()), + (np.dtype(np.unicode_), String()), + (np.dtype(np.object_), String()), + (np.dtype(np.datetime64), String()), + (np.dtype(np.timedelta64), String()), + ], + ids=repr, + ) + def test_should_create_column_type_from_numpy_data_type(self, data_type: np.dtype, expected: ColumnType) -> None: + assert ColumnType._from_numpy_data_type(data_type) == expected + + def test_should_raise_if_data_type_is_not_supported(self) -> None: + with pytest.raises(NotImplementedError): + ColumnType._from_numpy_data_type(np.dtype(np.void)) + + +class TestFromPolarsDataType: + @pytest.mark.parametrize( + ("data_type", "expected"), + [ + # Boolean + (PolarsBoolean, Boolean()), + # Float + *((data_type, RealNumber()) for data_type in POLARS_FLOAT_DTYPES), + (PolarsDecimal, RealNumber()), + # Int + *((data_type, Integer()) for data_type in POLARS_INTEGER_DTYPES), + # String + (PolarsUtf8, String()), + (PolarsObject, String()), + *((data_type, String()) for data_type in POLARS_TEMPORAL_DTYPES), + ], + ids=repr, + ) + def test_should_create_column_type_from_polars_data_type( + self, + data_type: PolarsDataType, + expected: ColumnType, + ) -> None: + assert ColumnType._from_polars_data_type(data_type) == expected + + def test_should_raise_if_data_type_is_not_supported(self) -> None: + with pytest.raises(NotImplementedError): + ColumnType._from_polars_data_type(PolarsUnknown) + + +class TestRepr: @pytest.mark.parametrize( ("column_type", "expected"), [ @@ -27,9 +106,11 @@ class TestColumnType: ], ids=repr, ) - def test_repr(self, column_type: ColumnType, expected: str) -> None: + def test_should_create_a_printable_representation(self, column_type: ColumnType, expected: str) -> None: assert repr(column_type) == expected + +class TestIsNullable: @pytest.mark.parametrize( ("column_type", "expected"), [ @@ -46,9 +127,11 @@ def test_repr(self, column_type: ColumnType, expected: str) -> None: ], ids=repr, ) - def test_is_nullable(self, column_type: ColumnType, expected: bool) -> None: + def test_should_return_whether_the_column_type_is_nullable(self, column_type: ColumnType, expected: bool) -> None: assert column_type.is_nullable() == expected + +class TestIsNumeric: @pytest.mark.parametrize( ("column_type", "expected"), [ @@ -65,39 +148,5 @@ def test_is_nullable(self, column_type: ColumnType, expected: bool) -> None: ], ids=repr, ) - def test_is_numeric(self, column_type: ColumnType, expected: bool) -> None: + def test_should_return_whether_the_column_type_is_numeric(self, column_type: ColumnType, expected: bool) -> None: assert column_type.is_numeric() == expected - - # Test cases taken from https://numpy.org/doc/stable/reference/arrays.scalars.html#scalars - @pytest.mark.parametrize( - ("dtype", "expected"), - [ - # Boolean - (np.dtype(np.bool_), Boolean()), - # Number - (np.dtype(np.half), RealNumber()), - (np.dtype(np.single), RealNumber()), - (np.dtype(np.float_), RealNumber()), - (np.dtype(np.longfloat), RealNumber()), - # Int - (np.dtype(np.byte), Integer()), - (np.dtype(np.short), Integer()), - (np.dtype(np.intc), Integer()), - (np.dtype(np.int_), Integer()), - (np.dtype(np.longlong), Integer()), - (np.dtype(np.ubyte), Integer()), - (np.dtype(np.ushort), Integer()), - (np.dtype(np.uintc), Integer()), - (np.dtype(np.uint), Integer()), - (np.dtype(np.ulonglong), Integer()), - # String - (np.dtype(np.str_), String()), - (np.dtype(np.unicode_), String()), - (np.dtype(np.object_), String()), - (np.dtype(np.datetime64), String()), - (np.dtype(np.timedelta64), String()), - ], - ids=repr, - ) - def test_from_numpy_dtype(self, dtype: np.dtype, expected: ColumnType) -> None: - assert ColumnType._from_numpy_dtype(dtype) == expected diff --git a/tests/safeds/data/tabular/typing/test_schema.py b/tests/safeds/data/tabular/typing/test_schema.py new file mode 100644 index 000000000..9413c0606 --- /dev/null +++ b/tests/safeds/data/tabular/typing/test_schema.py @@ -0,0 +1,253 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pandas as pd +import polars as pl +import pytest +from safeds.data.tabular.exceptions import UnknownColumnNameError +from safeds.data.tabular.typing import Boolean, ColumnType, Integer, RealNumber, Schema, String + +if TYPE_CHECKING: + from typing import Any + + +class TestFromPandasDataFrame: + @pytest.mark.parametrize( + ("dataframe", "expected"), + [ + ( + pd.DataFrame({"A": [True, False, True]}), + Schema({"A": Boolean()}), + ), + ( + pd.DataFrame({"A": [1, 2, 3]}), + Schema({"A": Integer()}), + ), + ( + pd.DataFrame({"A": [1.0, 2.0, 3.0]}), + Schema({"A": RealNumber()}), + ), + ( + pd.DataFrame({"A": ["a", "b", "c"]}), + Schema({"A": String()}), + ), + ( + pd.DataFrame({"A": [1, 2, 3], "B": ["a", "b", "c"]}), + Schema({"A": Integer(), "B": String()}), + ), + ], + ids=[ + "integer", + "real number", + "string", + "boolean", + "multiple columns", + ], + ) + def test_should_create_schema_from_pandas_dataframe(self, dataframe: pd.DataFrame, expected: Schema) -> None: + assert Schema._from_pandas_dataframe(dataframe) == expected + + +class TestFromPolarsDataFrame: + @pytest.mark.parametrize( + ("dataframe", "expected"), + [ + ( + pl.DataFrame({"A": [True, False, True]}), + Schema({"A": Boolean()}), + ), + ( + pl.DataFrame({"A": [1, 2, 3]}), + Schema({"A": Integer()}), + ), + ( + pl.DataFrame({"A": [1.0, 2.0, 3.0]}), + Schema({"A": RealNumber()}), + ), + ( + pl.DataFrame({"A": ["a", "b", "c"]}), + Schema({"A": String()}), + ), + ( + pl.DataFrame({"A": [1, 2, 3], "B": ["a", "b", "c"]}), + Schema({"A": Integer(), "B": String()}), + ), + ], + ids=[ + "integer", + "real number", + "string", + "boolean", + "multiple columns", + ], + ) + def test_should_create_schema_from_polars_dataframe(self, dataframe: pl.DataFrame, expected: Schema) -> None: + assert Schema._from_polars_dataframe(dataframe) == expected + + +class TestStr: + @pytest.mark.parametrize( + ("schema", "expected"), + [ + (Schema({}), "{}"), + (Schema({"A": Integer()}), "{'A': Integer}"), + (Schema({"A": Integer(), "B": String()}), "{\n 'A': Integer,\n 'B': String\n}"), + ], + ids=[ + "empty", + "single column", + "multiple columns", + ], + ) + def test_should_create_a_printable_representation(self, schema: Schema, expected: str) -> None: + assert str(schema) == expected + + +class TestEq: + @pytest.mark.parametrize( + ("schema1", "schema2", "expected"), + [ + (Schema({}), Schema({}), True), + (Schema({"col1": Integer()}), Schema({"col1": Integer()}), True), + (Schema({"col1": Integer()}), Schema({"col1": String()}), False), + (Schema({"col1": Integer()}), Schema({"col2": Integer()}), False), + ( + Schema({"col1": Integer(), "col2": String()}), + Schema({"col2": String(), "col1": Integer()}), + True, + ), + ], + ids=[ + "empty", + "same name and type", + "same name but different type", + "different name but same type", + "flipped columns", + ], + ) + def test_should_return_whether_two_schema_are_equal(self, schema1: Schema, schema2: Schema, expected: bool) -> None: + assert (schema1.__eq__(schema2)) == expected + + @pytest.mark.parametrize( + ("schema", "other"), + [ + (Schema({"col1": Integer()}), None), + (Schema({"col1": Integer()}), {"col1": Integer()}), + ], + ) + def test_should_return_not_implemented_if_other_is_not_schema(self, schema: Schema, other: Any) -> None: + assert (schema.__eq__(other)) is NotImplemented + + +class TestHash: + @pytest.mark.parametrize( + ("schema1", "schema2"), + [ + (Schema({}), Schema({})), + (Schema({"col1": Integer()}), Schema({"col1": Integer()})), + ], + ids=[ + "empty", + "one column", + ], + ) + def test_should_return_same_hash_for_equal_schemas(self, schema1: Schema, schema2: Schema) -> None: + assert hash(schema1) == hash(schema2) + + @pytest.mark.parametrize( + ("schema1", "schema2"), + [ + (Schema({"col1": Integer()}), Schema({"col1": String()})), + (Schema({"col1": Integer()}), Schema({"col2": Integer()})), + ], + ids=[ + "same name but different type", + "different name but same type", + ], + ) + def test_should_return_different_hash_for_unequal_schemas(self, schema1: Schema, schema2: Schema) -> None: + assert hash(schema1) != hash(schema2) + + +class TestHasColumn: + @pytest.mark.parametrize( + ("schema", "column_name", "expected"), + [ + (Schema({}), "A", False), + (Schema({"A": Integer()}), "A", True), + (Schema({"A": Integer()}), "B", False), + ], + ids=[ + "empty", + "column exists", + "column does not exist", + ], + ) + def test_should_return_whether_column_exists(self, schema: Schema, column_name: str, expected: bool) -> None: + assert schema.has_column(column_name) == expected + + +class TestGetTypeOfColumn: + @pytest.mark.parametrize( + ("schema", "column_name", "expected"), + [ + (Schema({"A": Integer()}), "A", Integer()), + (Schema({"A": Integer(), "B": String()}), "B", String()), + ], + ids=[ + "one column", + "two columns", + ], + ) + def test_should_return_type_of_existing_column( + self, + schema: Schema, + column_name: str, + expected: ColumnType, + ) -> None: + assert schema.get_type_of_column(column_name) == expected + + def test_should_raise_if_column_does_not_exist(self) -> None: + schema = Schema({"A": Integer()}) + with pytest.raises(UnknownColumnNameError): + schema.get_type_of_column("B") + + +class TestGetColumnNames: + @pytest.mark.parametrize( + ("schema", "expected"), + [ + (Schema({}), []), + (Schema({"A": Integer()}), ["A"]), + (Schema({"A": Integer(), "B": RealNumber()}), ["A", "B"]), + ], + ids=[ + "empty", + "single column", + "multiple columns", + ], + ) + def test_should_return_column_names(self, schema: Schema, expected: list[str]) -> None: + assert schema.get_column_names() == expected + + +class TestGetColumnIndex: + @pytest.mark.parametrize( + ("schema", "column_name", "expected"), + [ + (Schema({"A": Integer()}), "A", 0), + (Schema({"A": Integer(), "B": RealNumber()}), "B", 1), + ], + ids=[ + "single column", + "multiple columns", + ], + ) + def test_should_return_column_index(self, schema: Schema, column_name: str, expected: int) -> None: + assert schema._get_column_index(column_name) == expected + + def test_should_raise_if_column_does_not_exist(self) -> None: + schema = Schema({"A": Integer()}) + with pytest.raises(UnknownColumnNameError): + schema._get_column_index("B")