Skip to content

Commit

Permalink
feat: stabilize Row class (#980)
Browse files Browse the repository at this point in the history
Closes partially #977

### Summary of Changes

- Narrow types of parameters and results for better type checking.
- Improve tests and documentation.
  • Loading branch information
lars-reimann authored Jan 13, 2025
1 parent db85617 commit ca1ce3d
Show file tree
Hide file tree
Showing 111 changed files with 793 additions and 401 deletions.
2 changes: 1 addition & 1 deletion src/safeds/data/tabular/containers/_lazy_vectorized_row.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class _LazyVectorizedRow(Row):
up operations on the row.
Moreover, accessing a column only builds an expression that will be evaluated when needed. This is useful when later
operations remove more rows or columns, so we don't do unnecessary work upfront.
operations remove rows or columns, so we don't do unnecessary work upfront.
"""

# ------------------------------------------------------------------------------------------------------------------
Expand Down
35 changes: 18 additions & 17 deletions src/safeds/data/tabular/containers/_row.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,29 @@

from abc import ABC, abstractmethod
from collections.abc import Iterator, Mapping
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

from ._cell import Cell

if TYPE_CHECKING:
from safeds.data.tabular.typing import ColumnType, Schema

from ._cell import Cell


class Row(ABC, Mapping[str, Any]):
class Row(ABC, Mapping[str, Cell]):
"""
A one-dimensional collection of named, heterogeneous values.
This class cannot be instantiated directly. It is only used for arguments of callbacks.
You only need to interact with this class in callbacks passed to higher-order functions.
"""

# ------------------------------------------------------------------------------------------------------------------
# Dunder methods
# ------------------------------------------------------------------------------------------------------------------

def __contains__(self, name: Any) -> bool:
return self.has_column(name)
def __contains__(self, key: object, /) -> bool:
    """Check whether `key` is a string that names a column of this row."""
    # Non-string keys can never be column names, so short-circuit before asking `has_column`.
    return isinstance(key, str) and self.has_column(key)

@abstractmethod
def __eq__(self, other: object) -> bool: ...
Expand All @@ -33,7 +35,7 @@ def __getitem__(self, name: str) -> Cell:
@abstractmethod
def __hash__(self) -> int: ...

def __iter__(self) -> Iterator[Any]:
def __iter__(self) -> Iterator[str]:
    """Return an iterator over the column names of this row."""
    names = self.column_names
    return iter(names)

def __len__(self) -> int:
Expand All @@ -48,18 +50,18 @@ def __sizeof__(self) -> int: ...

@property
@abstractmethod
def column_names(self) -> list[str]:
"""The names of the columns in the row."""
def column_count(self) -> int:
"""The number of columns."""

@property
@abstractmethod
def column_count(self) -> int:
"""The number of columns in the row."""
def column_names(self) -> list[str]:
"""The names of the columns."""

@property
@abstractmethod
def schema(self) -> Schema:
"""The schema of the row."""
"""The schema, which is a mapping from column names to their types."""

# ------------------------------------------------------------------------------------------------------------------
# Column operations
Expand Down Expand Up @@ -98,7 +100,6 @@ def get_cell(self, name: str) -> Cell:
| 2 | 4 |
+------+------+
>>> table.remove_rows(lambda row: row["col1"] == 1)
+------+------+
| col1 | col2 |
Expand All @@ -112,7 +113,7 @@ def get_cell(self, name: str) -> Cell:
@abstractmethod
def get_column_type(self, name: str) -> ColumnType:
"""
Get the type of the specified column.
Get the type of a column. Note that the `[]` operator (indexed access) returns the cell of the column, not its type.
Parameters
----------
Expand All @@ -127,13 +128,13 @@ def get_column_type(self, name: str) -> ColumnType:
Raises
------
ColumnNotFoundError
If the column name does not exist.
If the column does not exist.
"""

@abstractmethod
def has_column(self, name: str) -> bool:
"""
Check if the row has a column with the specified name.
Check if the row has a column with a specific name. This is equivalent to using the `in` operator.
Parameters
----------
Expand Down
24 changes: 12 additions & 12 deletions src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,41 +393,41 @@ def _data_frame(self) -> pl.DataFrame:
return self.__data_frame_cache

@property
def column_names(self) -> list[str]:
def column_count(self) -> int:
"""
The names of the columns in the table.
The number of columns.
**Note:** This operation must compute the schema of the table, which can be expensive.
Examples
--------
>>> from safeds.data.tabular.containers import Table
>>> table = Table({"a": [1, 2, 3], "b": [4, 5, 6]})
>>> table.column_names
['a', 'b']
>>> table.column_count
2
"""
return self.schema.column_names
return len(self.column_names)

@property
def column_count(self) -> int:
def column_names(self) -> list[str]:
"""
The number of columns in the table.
The names of the columns.
**Note:** This operation must compute the schema of the table, which can be expensive.
Examples
--------
>>> from safeds.data.tabular.containers import Table
>>> table = Table({"a": [1, 2, 3], "b": [4, 5, 6]})
>>> table.column_count
2
>>> table.column_names
['a', 'b']
"""
return len(self.column_names)
return self.schema.column_names

@property
def row_count(self) -> int:
"""
The number of rows in the table.
The number of rows.
**Note:** This operation must fully load the data into memory, which can be expensive.
Expand Down Expand Up @@ -458,7 +458,7 @@ def plot(self) -> TablePlotter:
@property
def schema(self) -> Schema:
"""
The schema of the table.
The schema, which is a mapping from column names to their types.
Examples
--------
Expand Down
33 changes: 24 additions & 9 deletions src/safeds/data/tabular/typing/_schema.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,20 @@
from __future__ import annotations

import sys
from collections.abc import Iterator, Mapping
from typing import TYPE_CHECKING

from safeds._utils import _structural_hash
from safeds._validation import _check_columns_exist

from ._column_type import ColumnType
from ._polars_column_type import _PolarsColumnType

if TYPE_CHECKING:
from collections.abc import Mapping

import polars as pl

from ._column_type import ColumnType


class Schema:
class Schema(Mapping[str, ColumnType]):
"""The schema of a row or table."""

# ------------------------------------------------------------------------------------------------------------------
Expand All @@ -41,16 +39,30 @@ def __init__(self, schema: Mapping[str, ColumnType]) -> None:
check_dtypes=False,
)

def __contains__(self, key: object, /) -> bool:
    """Check whether `key` is a string that names a column of this schema."""
    # Non-string keys can never be column names, so short-circuit before asking `has_column`.
    return isinstance(key, str) and self.has_column(key)

def __eq__(self, other: object) -> bool:
    """
    Compare this schema with another object.

    Returns `NotImplemented` for objects that are not a `Schema`, so Python can try the reflected comparison.
    """
    if isinstance(other, Schema):
        # The identity check skips the comparison of the underlying schemas for the same object.
        return self is other or self._schema == other._schema
    return NotImplemented

def __getitem__(self, key: str, /) -> ColumnType:
    """Return the type of the column named `key` by delegating to `get_column_type`."""
    # NOTE(review): `Mapping.get` only falls back to its default on `KeyError`; confirm which error
    # `get_column_type` raises for a missing column.
    column_type = self.get_column_type(key)
    return column_type

def __hash__(self) -> int:
    """Hash the schema from its column names and the string form of its column types."""
    names = tuple(self._schema.keys())
    type_names = [str(type_) for type_ in self._schema.values()]
    return _structural_hash(names, type_names)

def __iter__(self) -> Iterator[str]:
    """Return an iterator over the column names of this schema."""
    names = self._schema.keys()
    return iter(names)

def __len__(self) -> int:
    """Return the number of columns, as reported by `column_count`."""
    count = self.column_count
    return count

def __repr__(self) -> str:
    """Return the string form of this schema wrapped in `Schema(...)`."""
    return f"Schema({self!s})"

Expand Down Expand Up @@ -108,7 +120,7 @@ def column_names(self) -> list[str]:

def get_column_type(self, name: str) -> ColumnType:
"""
Get the type of a column.
Get the type of a column. This is equivalent to using the `[]` operator (indexed access).
Parameters
----------
Expand All @@ -131,14 +143,17 @@ def get_column_type(self, name: str) -> ColumnType:
>>> schema = Schema({"a": ColumnType.int64(), "b": ColumnType.float32()})
>>> schema.get_column_type("a")
int64
>>> schema["b"]
float32
"""
_check_columns_exist(self, name)

return _PolarsColumnType(self._schema[name])

def has_column(self, name: str) -> bool:
"""
Check if the table has a column with a specific name.
Check if the schema has a column with a specific name. This is equivalent to using the `in` operator.
Parameters
----------
Expand All @@ -148,7 +163,7 @@ def has_column(self, name: str) -> bool:
Returns
-------
has_column:
Whether the table has a column with the specified name.
Whether the schema has a column with the specified name.
Examples
--------
Expand All @@ -157,7 +172,7 @@ def has_column(self, name: str) -> bool:
>>> schema.has_column("a")
True
>>> schema.has_column("c")
>>> "c" in schema
False
"""
return name in self._schema
Expand Down
65 changes: 46 additions & 19 deletions tests/helpers/_assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from polars.testing import assert_frame_equal

from safeds.data.labeled.containers import TabularDataset
from safeds.data.tabular.containers import Cell, Column, Table
from safeds.data.tabular.containers import Cell, Column, Row, Table


def assert_tables_are_equal(
Expand Down Expand Up @@ -62,44 +62,71 @@ def assert_that_tabular_datasets_are_equal(table1: TabularDataset, table2: Tabul


def assert_cell_operation_works(
input_value: Any,
value: Any,
transformer: Callable[[Cell], Cell],
expected_value: Any,
expected: Any,
) -> None:
"""
Assert that a cell operation works as expected.
Parameters
----------
input_value:
value:
The value in the input cell.
transformer:
The transformer to apply to the cells.
expected_value:
expected:
The expected value of the transformed cell.
"""
column = Column("A", [input_value])
column = Column("A", [value])
transformed_column = column.transform(transformer)
assert transformed_column == Column("A", [expected_value]), f"Expected: {expected_value}\nGot: {transformed_column}"
actual = transformed_column[0]
assert actual == expected


def assert_row_operation_works(
input_value: Any,
transformer: Callable[[Table], Table],
expected_value: Any,
table: Table,
computer: Callable[[Row], Cell],
expected: list[Any],
) -> None:
"""
Assert that a row operation works as expected.
Parameters
----------
input_value:
The value in the input row.
transformer:
The transformer to apply to the rows.
expected_value:
The expected value of the transformed row.
table:
The input table.
computer:
The function that computes the new column.
expected:
The expected values of the computed column.
"""
column_name = _find_free_column_name(table, "computed")

new_table = table.add_computed_column(column_name, computer)
actual = list(new_table.get_column(column_name))
assert actual == expected


def _find_free_column_name(table: Table, prefix: str) -> str:
    """
    Find a column name that is not yet used in the table.

    Underscores are appended to the prefix until the name no longer collides with an existing column name.

    Parameters
    ----------
    table:
        The table to search for a free column name.
    prefix:
        The prefix to use for the column name.

    Returns
    -------
    free_name:
        A free column name.
    """
    # Read the column names once: computing them requires the schema of the table, which can be expensive.
    taken = set(table.column_names)

    column_name = prefix
    while column_name in taken:
        column_name += "_"

    return column_name
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# serializer version: 1
# name: TestContract.test_should_return_same_hash_in_different_processes[empty]
1789859531466043636
# ---
# name: TestContract.test_should_return_same_hash_in_different_processes[no rows]
585695607399955642
# ---
# name: TestContract.test_should_return_same_hash_in_different_processes[with data]
909875695937937648
# ---
Loading

0 comments on commit ca1ce3d

Please sign in to comment.