Skip to content

Commit

Permalink
feat: make Column a subclass of Sequence
Browse files Browse the repository at this point in the history
  • Loading branch information
lars-reimann committed Apr 22, 2023
1 parent ad1cac5 commit 5bcc5bd
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 68 deletions.
77 changes: 68 additions & 9 deletions src/safeds/data/tabular/containers/_column.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import io
from collections.abc import Sequence
from numbers import Number
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, TypeVar

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -20,32 +21,90 @@
from safeds.data.tabular.typing import ColumnType

if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Iterator
from collections.abc import Callable, Iterator

_T = TypeVar("_T")

class Column:

class Column(Sequence[_T]):
"""
A column is a named collection of values.
Parameters
----------
name : str
The name of the column.
data : Iterable
data : Sequence
The data.
type_ : Optional[ColumnType]
The type of the column. If not specified, the type will be inferred from the data.
Examples
--------
>>> from safeds.data.tabular.containers import Column
>>> column = Column("test", [1, 2, 3])
"""

# ------------------------------------------------------------------------------------------------------------------
# Creation
# ------------------------------------------------------------------------------------------------------------------

@staticmethod
def _from_pandas_series(data: pd.Series, type_: ColumnType | None = None) -> Column:
    """
    Create a column from a `pandas.Series`.

    The series is stored as given (it is not copied here) and its `name` is used as the column name.

    Parameters
    ----------
    data : pd.Series
        The data.
    type_ : ColumnType | None
        The type. If None, the type is inferred from the dtype of the data.

    Returns
    -------
    column : Column
        The created column.

    Examples
    --------
    >>> import pandas as pd
    >>> from safeds.data.tabular.containers import Column
    >>> column = Column._from_pandas_series(pd.Series([1, 2, 3], name="test"))
    """
    # Allocate the instance without running __init__, then fill in the attributes directly.
    result = object.__new__(Column)
    result._name = data.name
    result._data = data
    # noinspection PyProtectedMember
    result._type = type_ if type_ is not None else ColumnType._from_numpy_data_type(data.dtype)

    return result

# ------------------------------------------------------------------------------------------------------------------
# Dunder methods
# ------------------------------------------------------------------------------------------------------------------

def __init__(self, name: str, data: Iterable, type_: ColumnType | None = None) -> None:
def __init__(self, name: str, data: Sequence) -> None:
"""
Create a column.
Parameters
----------
name
The name of the column.
data
The data.
Examples
--------
>>> from safeds.data.tabular.containers import Column
>>> column = Column("test", [1, 2, 3])
"""
self._name: str = name
self._data: pd.Series = data if isinstance(data, pd.Series) else pd.Series(data)
# noinspection PyProtectedMember
self._type: ColumnType = type_ if type_ is not None else ColumnType._from_numpy_data_type(self._data.dtype)
self._type: ColumnType = ColumnType._from_numpy_data_type(self._data.dtype)

def __contains__(self, item: Any) -> bool:
return item in self._data

def __eq__(self, other: object) -> bool:
if not isinstance(other, Column):
Expand Down Expand Up @@ -236,7 +295,7 @@ def rename(self, new_name: str) -> Column:
column : Column
A new column with the new name.
"""
return Column(new_name, self._data, self._type)
return Column._from_pandas_series(self._data.rename(new_name), self._type)

# ------------------------------------------------------------------------------------------------------------------
# Statistics
Expand Down
18 changes: 8 additions & 10 deletions src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,15 +338,13 @@ def get_column(self, column_name: str) -> Column:
UnknownColumnNameError
If the specified target column name does not exist.
"""
if self._schema.has_column(column_name):
output_column = Column(
column_name,
self._data.iloc[:, [self._schema._get_column_index(column_name)]].squeeze(),
self._schema.get_column_type(column_name),
)
return output_column
if not self.has_column(column_name):
raise UnknownColumnNameError([column_name])

raise UnknownColumnNameError([column_name])
return Column._from_pandas_series(
self._data[column_name],
self.get_column_type(column_name),
)

def has_column(self, column_name: str) -> bool:
"""
Expand Down Expand Up @@ -866,7 +864,7 @@ def slice_rows(
def sort_columns(
self,
comparator: Callable[[Column, Column], int] = lambda col1, col2: (col1.name > col2.name)
- (col1.name < col2.name),
- (col1.name < col2.name),
) -> Table:
"""
Sort the columns of a `Table` with the given comparator and return a new `Table`.
Expand Down Expand Up @@ -980,7 +978,7 @@ def transform_column(self, name: str, transformer: Callable[[Row], Any]) -> Tabl
"""
if self.has_column(name):
items: list = [transformer(item) for item in self.to_rows()]
result: Column = Column(name, pd.Series(items))
result: Column = Column._from_pandas_series(pd.Series(items, name=name))
return self.replace_column(name, result)
raise UnknownColumnNameError([name])

Expand Down
28 changes: 0 additions & 28 deletions src/safeds/data/tabular/typing/_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,31 +239,3 @@ def _repr_markdown_(self) -> str:
lines = (f"| {name} | {type_} |" for name, type_ in self._schema.items())
joined = "\n".join(lines)
return f"| Column Name | Column Type |\n| --- | --- |\n{joined}"

# ------------------------------------------------------------------------------------------------------------------
# Other
# ------------------------------------------------------------------------------------------------------------------

def _get_column_index(self, column_name: str) -> int:
"""
Return the index of the column with specified column name.
Parameters
----------
column_name : str
The name of the column.
Returns
-------
index : int
The index of the column.
Raises
------
ColumnNameError
If the specified column name does not exist.
"""
if not self.has_column(column_name):
raise UnknownColumnNameError([column_name])

return list(self._schema.keys()).index(column_name)
20 changes: 20 additions & 0 deletions tests/safeds/data/tabular/containers/test_column.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,28 @@
from typing import Any

import pytest
import regex as re
from safeds.data.tabular.containers import Column


class TestContains:
    """Tests for the `in` operator on a `Column`."""

    # NOTE(review): with the default positional index that `pd.Series([1, 2, 3])` gets, the probe 1 is both a stored
    # value and an index label, and 4 is neither. These cases therefore cannot tell a value lookup apart from an
    # index-label lookup; a probe such as 3 (value only) or 0 (label only) would - TODO confirm and extend.
    @pytest.mark.parametrize(
        ("column", "value", "expected"),
        [
            (Column("a", []), 1, False),
            (Column("a", [1, 2, 3]), 1, True),
            (Column("a", [1, 2, 3]), 4, False),
        ],
        ids=[
            "empty",
            "value exists",
            "value does not exist",
        ],
    )
    def test_should_check_whether_the_value_exists(self, column: Column, value: Any, expected: bool) -> None:
        """Assert that `value in column` matches the expected result for each parametrized case."""
        assert (value in column) == expected


class TestToHtml:
@pytest.mark.parametrize(
"column",
Expand Down
21 changes: 0 additions & 21 deletions tests/safeds/data/tabular/typing/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,27 +212,6 @@ def test_should_return_column_names(self, schema: Schema, expected: list[str]) -
assert schema.column_names == expected


class TestGetColumnIndex:
    """Tests for `Schema._get_column_index`."""

    @pytest.mark.parametrize(
        ("schema", "column_name", "expected"),
        [
            (Schema({"A": Integer()}), "A", 0),
            (Schema({"A": Integer(), "B": RealNumber()}), "B", 1),
        ],
        ids=[
            "single column",
            "multiple columns",
        ],
    )
    def test_should_return_column_index(self, schema: Schema, column_name: str, expected: int) -> None:
        """Assert that an existing column name maps to its expected index."""
        assert schema._get_column_index(column_name) == expected

    def test_should_raise_if_column_does_not_exist(self) -> None:
        """Assert that looking up a name that is not in the schema raises `UnknownColumnNameError`."""
        schema = Schema({"A": Integer()})
        # "B" was never added to the schema above, so the lookup must fail.
        with pytest.raises(UnknownColumnNameError):
            schema._get_column_index("B")


class TestToDict:
@pytest.mark.parametrize(
("schema", "expected"),
Expand Down

0 comments on commit 5bcc5bd

Please sign in to comment.