From 5bcc5bdf1ef77bcc9adc602351622581eb3f745a Mon Sep 17 00:00:00 2001 From: Lars Reimann Date: Sat, 22 Apr 2023 21:20:03 +0200 Subject: [PATCH] feat: make `Column` a subclass of `Sequence` --- src/safeds/data/tabular/containers/_column.py | 77 ++++++++++++++++--- src/safeds/data/tabular/containers/_table.py | 18 ++--- src/safeds/data/tabular/typing/_schema.py | 28 ------- .../data/tabular/containers/test_column.py | 20 +++++ .../safeds/data/tabular/typing/test_schema.py | 21 ----- 5 files changed, 96 insertions(+), 68 deletions(-) diff --git a/src/safeds/data/tabular/containers/_column.py b/src/safeds/data/tabular/containers/_column.py index 963841049..2ec2cf795 100644 --- a/src/safeds/data/tabular/containers/_column.py +++ b/src/safeds/data/tabular/containers/_column.py @@ -1,8 +1,9 @@ from __future__ import annotations import io +from collections.abc import Sequence from numbers import Number -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeVar import matplotlib.pyplot as plt import numpy as np @@ -20,10 +21,12 @@ from safeds.data.tabular.typing import ColumnType if TYPE_CHECKING: - from collections.abc import Callable, Iterable, Iterator + from collections.abc import Callable, Iterator +_T = TypeVar("_T") -class Column: + +class Column(Sequence[_T]): """ A column is a named collection of values. @@ -31,21 +34,77 @@ class Column: ---------- name : str The name of the column. - data : Iterable + data : Sequence The data. - type_ : Optional[ColumnType] - The type of the column. If not specified, the type will be inferred from the data. + + Examples + -------- + >>> from safeds.data.tabular.containers import Column + >>> column = Column("test", [1, 2, 3]) """ + # ------------------------------------------------------------------------------------------------------------------ + # Creation + # ------------------------------------------------------------------------------------------------------------------ + + @staticmethod + def _from_pandas_series(data: pd.Series, type_: ColumnType | None = None) -> Column: + """ + Create a column from a `pandas.Series`. + + Parameters + ---------- + data : pd.Series + The data. + type_ : ColumnType | None + The type. If None, the schema is inferred from the data. + + Returns + ------- + column : Column + The created column. + + Examples + -------- + >>> import pandas as pd + >>> from safeds.data.tabular.containers import Row + >>> row = Column._from_pandas_series(pd.Series([1, 2, 3], name="test")) + """ + result = object.__new__(Column) + result._name = data.name + result._data = data + # noinspection PyProtectedMember + result._type = type_ if type_ is not None else ColumnType._from_numpy_data_type(data.dtype) + + return result + # ------------------------------------------------------------------------------------------------------------------ # Dunder methods # ------------------------------------------------------------------------------------------------------------------ - def __init__(self, name: str, data: Iterable, type_: ColumnType | None = None) -> None: + def __init__(self, name: str, data: Sequence) -> None: + """ + Create a column. + + Parameters + ---------- + name + The name of the column. + data + The data. + + Examples + -------- + >>> from safeds.data.tabular.containers import Column + >>> column = Column("test", [1, 2, 3]) + """ self._name: str = name self._data: pd.Series = data if isinstance(data, pd.Series) else pd.Series(data) # noinspection PyProtectedMember - self._type: ColumnType = type_ if type_ is not None else ColumnType._from_numpy_data_type(self._data.dtype) + self._type: ColumnType = ColumnType._from_numpy_data_type(self._data.dtype) + + def __contains__(self, item: Any) -> bool: + return item in self._data def __eq__(self, other: object) -> bool: if not isinstance(other, Column): @@ -236,7 +295,7 @@ def rename(self, new_name: str) -> Column: column : Column A new column with the new name. """ - return Column(new_name, self._data, self._type) + return Column._from_pandas_series(self._data.rename(new_name), self._type) # ------------------------------------------------------------------------------------------------------------------ # Statistics diff --git a/src/safeds/data/tabular/containers/_table.py b/src/safeds/data/tabular/containers/_table.py index 1f4cab0ee..a18e32831 100644 --- a/src/safeds/data/tabular/containers/_table.py +++ b/src/safeds/data/tabular/containers/_table.py @@ -338,15 +338,13 @@ def get_column(self, column_name: str) -> Column: UnknownColumnNameError If the specified target column name does not exist. """ - if self._schema.has_column(column_name): - output_column = Column( - column_name, - self._data.iloc[:, [self._schema._get_column_index(column_name)]].squeeze(), - self._schema.get_column_type(column_name), - ) - return output_column + if not self.has_column(column_name): + raise UnknownColumnNameError([column_name]) - raise UnknownColumnNameError([column_name]) + return Column._from_pandas_series( + self._data[column_name], + self.get_column_type(column_name), + ) def has_column(self, column_name: str) -> bool: """ @@ -866,7 +864,7 @@ def slice_rows( def sort_columns( self, comparator: Callable[[Column, Column], int] = lambda col1, col2: (col1.name > col2.name) - - (col1.name < col2.name), + - (col1.name < col2.name), ) -> Table: """ Sort the columns of a `Table` with the given comparator and return a new `Table`. @@ -980,7 +978,7 @@ def transform_column(self, name: str, transformer: Callable[[Row], Any]) -> Tabl """ if self.has_column(name): items: list = [transformer(item) for item in self.to_rows()] - result: Column = Column(name, pd.Series(items)) + result: Column = Column._from_pandas_series(pd.Series(items, name=name)) return self.replace_column(name, result) raise UnknownColumnNameError([name]) diff --git a/src/safeds/data/tabular/typing/_schema.py b/src/safeds/data/tabular/typing/_schema.py index 8ab03dacd..c9bd1996c 100644 --- a/src/safeds/data/tabular/typing/_schema.py +++ b/src/safeds/data/tabular/typing/_schema.py @@ -239,31 +239,3 @@ def _repr_markdown_(self) -> str: lines = (f"| {name} | {type_} |" for name, type_ in self._schema.items()) joined = "\n".join(lines) return f"| Column Name | Column Type |\n| --- | --- |\n{joined}" - - # ------------------------------------------------------------------------------------------------------------------ - # Other - # ------------------------------------------------------------------------------------------------------------------ - - def _get_column_index(self, column_name: str) -> int: - """ - Return the index of the column with specified column name. - - Parameters - ---------- - column_name : str - The name of the column. - - Returns - ------- - index : int - The index of the column. - - Raises - ------ - ColumnNameError - If the specified column name does not exist. - """ - if not self.has_column(column_name): - raise UnknownColumnNameError([column_name]) - - return list(self._schema.keys()).index(column_name) diff --git a/tests/safeds/data/tabular/containers/test_column.py b/tests/safeds/data/tabular/containers/test_column.py index 8e97c213d..fee51e1e5 100644 --- a/tests/safeds/data/tabular/containers/test_column.py +++ b/tests/safeds/data/tabular/containers/test_column.py @@ -1,8 +1,28 @@ +from typing import Any + import pytest import regex as re from safeds.data.tabular.containers import Column +class TestContains: + @pytest.mark.parametrize( + ("column", "value", "expected"), + [ + (Column("a", []), 1, False), + (Column("a", [1, 2, 3]), 1, True), + (Column("a", [1, 2, 3]), 4, False), + ], + ids=[ + "empty", + "value exists", + "value does not exist", + ], + ) + def test_should_check_whether_the_value_exists(self, column: Column, value: Any, expected: bool) -> None: + assert (value in column) == expected + + class TestToHtml: @pytest.mark.parametrize( "column", diff --git a/tests/safeds/data/tabular/typing/test_schema.py b/tests/safeds/data/tabular/typing/test_schema.py index ea262131c..7b8bfb636 100644 --- a/tests/safeds/data/tabular/typing/test_schema.py +++ b/tests/safeds/data/tabular/typing/test_schema.py @@ -212,27 +212,6 @@ def test_should_return_column_names(self, schema: Schema, expected: list[str]) - assert schema.column_names == expected -class TestGetColumnIndex: - @pytest.mark.parametrize( - ("schema", "column_name", "expected"), - [ - (Schema({"A": Integer()}), "A", 0), - (Schema({"A": Integer(), "B": RealNumber()}), "B", 1), - ], - ids=[ - "single column", - "multiple columns", - ], - ) - def test_should_return_column_index(self, schema: Schema, column_name: str, expected: int) -> None: - assert schema._get_column_index(column_name) == expected - - def test_should_raise_if_column_does_not_exist(self) -> None: - schema = Schema({"A": Integer()}) - with pytest.raises(UnknownColumnNameError): - schema._get_column_index("B") - - class TestToDict: @pytest.mark.parametrize( ("schema", "expected"),