Skip to content

Commit

Permalink
feat: make Column and Row iterable (#55)
Browse files Browse the repository at this point in the history
Closes #47.

### Summary of Changes

* Add `__iter__` method to `Column` and `Row` to iterate over the
values:
  * Iterating over a `Column` returns the values.
* Iterating over a `Row` returns the column names, as specified in the
[documentation of
`__iter__`](https://docs.python.org/3/reference/datamodel.html#object.__iter__).
* Add `__len__` method to `Column` and `Row` to compute their length.
* Change superclasses of exceptions as needed for
[`__getitem__`](https://docs.python.org/3/reference/datamodel.html#object.__getitem__):
  * Change superclass of `IndexOutOfBoundsError` to `IndexError`.
  * Change superclass of `UnknownColumnNameError` to `KeyError`.

---------

Co-authored-by: lars-reimann <[email protected]>
  • Loading branch information
lars-reimann and lars-reimann authored Mar 21, 2023
1 parent c3fd3b5 commit 74eea1f
Show file tree
Hide file tree
Showing 14 changed files with 102 additions and 34 deletions.
58 changes: 31 additions & 27 deletions src/safeds/data/tabular/_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import typing
from numbers import Number
from typing import Any, Callable
from typing import Any, Callable, Iterator

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -34,6 +34,10 @@ def name(self) -> str:
"""
return self._name

@property
def statistics(self) -> ColumnStatistics:
    """
    Return a `ColumnStatistics` object constructed from this column.

    Returns
    -------
    statistics : ColumnStatistics
        A new `ColumnStatistics` instance wrapping this column.
    """
    column_statistics = ColumnStatistics(self)
    return column_statistics

@property
def type(self) -> ColumnType:
"""
Expand All @@ -46,9 +50,35 @@ def type(self) -> ColumnType:
"""
return self._type

def __eq__(self, other: object) -> bool:
    """
    Compare this column with another object.

    Returns `NotImplemented` when `other` is not a `Column`. Otherwise two columns are equal when they are the
    same object, or when their underlying data compares equal via `equals` and their names match.
    """
    if isinstance(other, Column):
        if other is self:
            return True
        same_data = self._data.equals(other._data)
        return same_data and self.name == other.name
    return NotImplemented

def __getitem__(self, index: int) -> Any:
    """
    Return the value at the given index, enabling `column[index]`.

    This delegates to `get_value`; anything `get_value` raises propagates unchanged.
    """
    value = self.get_value(index)
    return value

def __hash__(self) -> int:
    """
    Return a hash value for this column.

    The underlying data is a pandas Series, which is unhashable, so `hash(self._data)` always raises a
    `TypeError`. The hash is therefore derived from the column name and the number of values. Both are
    identical for any two columns that `__eq__` considers equal, so equal columns always share a hash.

    Returns
    -------
    hash : int
        The hash value.
    """
    # Deliberately coarse: only properties guaranteed to match for equal columns are used.
    return hash((self.name, len(self._data)))

def __iter__(self) -> Iterator[Any]:
    """
    Return an iterator over the values stored in this column.
    """
    values = self._data
    return iter(values)

def __len__(self) -> int:
    """
    Return the number of values in this column, enabling `len(column)`.
    """
    number_of_values = len(self._data)
    return number_of_values

def __repr__(self) -> str:
    """
    Return the representation of this column, rendered as a one-column data frame labelled with its name.
    """
    frame = self._data.to_frame()
    # Label the single column with this column's name before rendering.
    frame.columns = [self.name]
    return repr(frame)

def __str__(self) -> str:
    """
    Return the string form of this column, rendered as a one-column data frame labelled with its name.
    """
    frame = self._data.to_frame()
    # Label the single column with this column's name before rendering.
    frame.columns = [self.name]
    return str(frame)

def get_value(self, index: int) -> Any:
"""
Return column value at specified index, starting at 0.
Expand All @@ -73,10 +103,6 @@ def get_value(self, index: int) -> Any:

return self._data[index]

@property
def statistics(self) -> ColumnStatistics:
    """
    Return a `ColumnStatistics` object constructed from this column.

    Returns
    -------
    statistics : ColumnStatistics
        A new `ColumnStatistics` instance wrapping this column.
    """
    column_statistics = ColumnStatistics(self)
    return column_statistics

def count(self) -> int:
"""
Return the number of elements in the column.
Expand Down Expand Up @@ -223,26 +249,6 @@ def get_unique_values(self) -> list[typing.Any]:
"""
return list(self._data.unique())

def __eq__(self, other: object) -> bool:
    """
    Compare this column with another object.

    Returns `NotImplemented` when `other` is not a `Column`. Otherwise two columns are equal when they are the
    same object, or when their underlying data compares equal via `equals` and their names match.
    """
    if isinstance(other, Column):
        if other is self:
            return True
        same_data = self._data.equals(other._data)
        return same_data and self.name == other.name
    return NotImplemented

def __hash__(self) -> int:
    """
    Return a hash value for this column.

    The underlying data is a pandas Series, which is unhashable, so `hash(self._data)` always raises a
    `TypeError`. The hash is therefore derived from the column name and the number of values. Both are
    identical for any two columns that `__eq__` considers equal, so equal columns always share a hash.

    Returns
    -------
    hash : int
        The hash value.
    """
    # Deliberately coarse: only properties guaranteed to match for equal columns are used.
    return hash((self.name, len(self._data)))

def __str__(self) -> str:
    """
    Return the string form of this column, rendered as a one-column data frame labelled with its name.
    """
    frame = self._data.to_frame()
    # Label the single column with this column's name before rendering.
    frame.columns = [self.name]
    return str(frame)

def __repr__(self) -> str:
    """
    Return the representation of this column, rendered as a one-column data frame labelled with its name.
    """
    frame = self._data.to_frame()
    # Label the single column with this column's name before rendering.
    frame.columns = [self.name]
    return repr(frame)

def _ipython_display_(self) -> DisplayHandle:
"""
Return a display object for the column to be used in Jupyter Notebooks.
Expand Down Expand Up @@ -378,7 +384,6 @@ def sum(self) -> float:
return self._column._data.sum()

def variance(self) -> float:

"""
Return the variance of the column. The column has to be numerical.
Expand All @@ -401,7 +406,6 @@ def variance(self) -> float:
return self._column._data.var()

def standard_deviation(self) -> float:

"""
Return the standard deviation of the column. The column has to be numerical.
Expand Down
17 changes: 17 additions & 0 deletions src/safeds/data/tabular/_row.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ def __init__(self, data: typing.Iterable, schema: TableSchema):
def __getitem__(self, column_name: str) -> Any:
    """
    Return the value of the given column, enabling `row[column_name]`.

    This delegates to `get_value`; anything `get_value` raises propagates unchanged.
    """
    value = self.get_value(column_name)
    return value

def __iter__(self) -> typing.Iterator[Any]:
    """
    Return an iterator over the column names of this row, as returned by `get_column_names`.
    """
    column_names = self.get_column_names()
    return iter(column_names)

def __len__(self) -> int:
    """
    Return the number of entries in this row's data, enabling `len(row)`.
    """
    number_of_entries = len(self._data)
    return number_of_entries

def get_value(self, column_name: str) -> Any:
"""
Return the value of a specified column.
Expand All @@ -34,6 +40,17 @@ def get_value(self, column_name: str) -> Any:
raise UnknownColumnNameError([column_name])
return self._data[self.schema._get_column_index_by_name(column_name)]

def count(self) -> int:
    """
    Return the number of columns in this row.

    Returns
    -------
    count : int
        The number of columns.
    """
    number_of_columns = len(self._data)
    return number_of_columns

def has_column(self, column_name: str) -> bool:
"""
Return whether the row contains a given column.
Expand Down
4 changes: 2 additions & 2 deletions src/safeds/exceptions/_data_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
class UnknownColumnNameError(Exception):
class UnknownColumnNameError(KeyError):
"""
Exception raised for trying to access an invalid column name.
Expand Down Expand Up @@ -37,7 +37,7 @@ def __init__(self, column_name: str):
super().__init__(f"Column '{column_name}' already exists.")


class IndexOutOfBoundsError(Exception):
class IndexOutOfBoundsError(IndexError):
"""
Exception raised for trying to access an element by an index that does not exist in the underlying data.
Expand Down
2 changes: 1 addition & 1 deletion tests/safeds/data/tabular/_column/test_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def test_from_columns() -> None:
assert column1._type == column2._type


def negative_test_from_columns() -> None:
def test_from_columns_negative() -> None:
column1 = Column(pd.Series([1, 4]), "A")
column2 = Column(pd.Series(["2", "5"]), "B")

Expand Down
7 changes: 3 additions & 4 deletions tests/safeds/data/tabular/_column/test_count.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pandas as pd
from safeds.data.tabular import Table
from safeds.data.tabular import Column


def test_count_valid() -> None:
table = Table(pd.DataFrame(data={"col1": [1, 2, 3, 4, 5], "col2": [2, 3, 4, 5, 6]}))
assert table.get_column("col1").count() == 5
column = Column([1, 2, 3, 4, 5], "col1")
assert column.count() == 5
6 changes: 6 additions & 0 deletions tests/safeds/data/tabular/_column/test_iter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from safeds.data.tabular import Column


def test_iter() -> None:
    """Iterating over a column yields its values in order."""
    column = Column([0, "1"], "testColumn")
    iterated_values = list(column)
    assert iterated_values == [0, "1"]
6 changes: 6 additions & 0 deletions tests/safeds/data/tabular/_column/test_len.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from safeds.data.tabular import Column


def test_count_valid() -> None:
    """`len` of a column equals the number of values it was built from."""
    column = Column([1, 2, 3, 4, 5], "col1")
    length = len(column)
    assert length == 5
12 changes: 12 additions & 0 deletions tests/safeds/data/tabular/_row/test_count.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from safeds.data.tabular import Row
from safeds.data.tabular.typing import IntColumnType, StringColumnType, TableSchema


def test_count() -> None:
    """`Row.count` reports one entry per column of a two-column row."""
    schema = TableSchema(
        {"testColumn1": IntColumnType(), "testColumn2": StringColumnType()}
    )
    row = Row([0, "1"], schema)
    assert row.count() == 2
12 changes: 12 additions & 0 deletions tests/safeds/data/tabular/_row/test_iter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from safeds.data.tabular import Row
from safeds.data.tabular.typing import IntColumnType, StringColumnType, TableSchema


def test_iter() -> None:
    """Iterating over a row yields its column names in schema order."""
    schema = TableSchema(
        {"testColumn1": IntColumnType(), "testColumn2": StringColumnType()}
    )
    row = Row([0, "1"], schema)
    iterated_names = list(row)
    assert iterated_names == ["testColumn1", "testColumn2"]
12 changes: 12 additions & 0 deletions tests/safeds/data/tabular/_row/test_len.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from safeds.data.tabular import Row
from safeds.data.tabular.typing import IntColumnType, StringColumnType, TableSchema


def test_count() -> None:
    """`len` of a two-column row is 2."""
    schema = TableSchema(
        {"testColumn1": IntColumnType(), "testColumn2": StringColumnType()}
    )
    row = Row([0, "1"], schema)
    length = len(row)
    assert length == 2

0 comments on commit 74eea1f

Please sign in to comment.