Skip to content

Commit

Permalink
feat: Table.count_row_if (#788)
Browse files Browse the repository at this point in the history
Closes #786 

### Summary of Changes

Add a new method `Table.count_row_if` to count the rows in a table that
satisfy a given predicate.
  • Loading branch information
lars-reimann authored May 17, 2024
1 parent 1c3ea59 commit 4137131
Show file tree
Hide file tree
Showing 13 changed files with 218 additions and 42 deletions.
27 changes: 26 additions & 1 deletion src/safeds/_validation/_check_columns_are_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,35 @@
if TYPE_CHECKING:
from collections.abc import Container

from safeds.data.tabular.containers import Table
from safeds.data.tabular.containers import Column, Table
from safeds.data.tabular.typing import Schema


def _check_column_is_numeric(
    column: Column,
    *,
    operation: str = "do a numeric operation",
) -> None:
    """
    Check if the column is numeric and raise an error if it is not.

    Parameters
    ----------
    column:
        The column to check.
    operation:
        The operation that is performed on the column. This is used in the error message.

    Raises
    ------
    ColumnTypeError
        If the column is not numeric.
    """
    if not column.type.is_numeric:
        # The single column name is wrapped in a list before it is handed to the message builder.
        message = _build_error_message([column.name], operation)
        raise ColumnTypeError(message)


def _check_columns_are_numeric(
table_or_schema: Table | Schema,
column_names: str | list[str],
Expand Down
31 changes: 13 additions & 18 deletions src/safeds/data/tabular/containers/_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload

from safeds._utils import _structural_hash
from safeds._validation._check_columns_are_numeric import _check_column_is_numeric
from safeds.data.tabular.plotting import ColumnPlotter
from safeds.data.tabular.typing._polars_data_type import _PolarsDataType
from safeds.exceptions import (
ColumnLengthMismatchError,
IndexOutOfBoundsError,
MissingValuesColumnError,
NonNumericColumnError,
)

from ._lazy_cell import _LazyCell
Expand Down Expand Up @@ -223,7 +223,7 @@ def get_distinct_values(

def get_value(self, index: int) -> T_co:
"""
Return the column value at specified index.
Return the column value at specified index. Equivalent to the `[]` operator (indexed access).
Nonnegative indices are counted from the beginning (starting at 0), negative indices from the end (starting at
-1).
Expand All @@ -249,6 +249,9 @@ def get_value(self, index: int) -> T_co:
>>> column = Column("test", [1, 2, 3])
>>> column.get_value(1)
2
>>> column[1]
2
"""
if index < -self.row_count or index >= self.row_count:
raise IndexOutOfBoundsError(index)
Expand Down Expand Up @@ -434,7 +437,7 @@ def count_if(
"""
Return how many values in the column satisfy the predicate.
The predicate can return one of three values:
The predicate can return one of three results:
* True, if the value satisfies the predicate.
* False, if the value does not satisfy the predicate.
Expand All @@ -458,11 +461,6 @@ def count_if(
count:
The number of values in the column that satisfy the predicate.
Raises
------
TypeError
If the predicate does not return a boolean cell.
Examples
--------
>>> from safeds.data.tabular.containers import Column
Expand Down Expand Up @@ -764,8 +762,9 @@ def correlation_with(self, other: Column) -> float:
"""
import polars as pl

if not self.is_numeric or not other.is_numeric:
raise NonNumericColumnError("") # TODO: Add column names to error message
_check_column_is_numeric(self, operation="calculate the correlation")
_check_column_is_numeric(other, operation="calculate the correlation")

if self.row_count != other.row_count:
raise ColumnLengthMismatchError("") # TODO: Add column names to error message
if self.missing_value_count() > 0 or other.missing_value_count() > 0:
Expand Down Expand Up @@ -881,8 +880,7 @@ def mean(self) -> T_co:
>>> column.mean()
2.0
"""
if not self.is_numeric:
raise NonNumericColumnError("") # TODO: Add column name to error message
_check_column_is_numeric(self, operation="calculate the mean")

return self._series.mean()

Expand Down Expand Up @@ -910,8 +908,7 @@ def median(self) -> T_co:
>>> column.median()
2.0
"""
if not self.is_numeric:
raise NonNumericColumnError("") # TODO: Add column name to error message
_check_column_is_numeric(self, operation="calculate the median")

return self._series.median()

Expand Down Expand Up @@ -1087,8 +1084,7 @@ def standard_deviation(self) -> float:
>>> column.standard_deviation()
1.0
"""
if not self.is_numeric:
raise NonNumericColumnError("") # TODO: Add column name to error message
_check_column_is_numeric(self, operation="calculate the standard deviation")

return self._series.std()

Expand Down Expand Up @@ -1116,8 +1112,7 @@ def variance(self) -> float:
>>> column.variance()
1.0
"""
if not self.is_numeric:
raise NonNumericColumnError("") # TODO: Add column name to error message
_check_column_is_numeric(self, operation="calculate the variance")

return self._series.var()

Expand Down
25 changes: 24 additions & 1 deletion src/safeds/data/tabular/containers/_row.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def schema(self) -> Schema:
@abstractmethod
def get_value(self, name: str) -> Cell:
"""
Get the value of the specified column.
Get the value of the specified column. This is equivalent to using the `[]` operator (indexed access).
Parameters
----------
Expand All @@ -84,6 +84,29 @@ def get_value(self, name: str) -> Cell:
------
ColumnNotFoundError
If the column name does not exist.
Examples
--------
>>> from safeds.data.tabular.containers import Table
>>> table = Table({"col1": [1, 2], "col2": [3, 4]})
>>> table.remove_rows(lambda row: row.get_value("col1") == 1)
+------+------+
| col1 | col2 |
| --- | --- |
| i64 | i64 |
+=============+
| 2 | 4 |
+------+------+
>>> table.remove_rows(lambda row: row["col1"] == 1)
+------+------+
| col1 | col2 |
| --- | --- |
| i64 | i64 |
+=============+
| 2 | 4 |
+------+------+
"""

@abstractmethod
Expand Down
69 changes: 68 additions & 1 deletion src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, overload

from safeds._config import _get_device, _init_default_device
from safeds._config._polars import _get_polars_config
Expand Down Expand Up @@ -1008,6 +1008,73 @@ def transform_column(
# Row operations
# ------------------------------------------------------------------------------------------------------------------

@overload
def count_row_if(
    self,
    predicate: Callable[[Row], Cell[bool | None]],
    *,
    ignore_unknown: Literal[True] = ...,
) -> int: ...

@overload
def count_row_if(
    self,
    predicate: Callable[[Row], Cell[bool | None]],
    *,
    ignore_unknown: bool,
) -> int | None: ...

def count_row_if(
    self,
    predicate: Callable[[Row], Cell[bool | None]],
    *,
    ignore_unknown: bool = True,
) -> int | None:
    """
    Return how many rows in the table satisfy the predicate.

    The predicate can return one of three results:

    * True, if the row satisfies the predicate.
    * False, if the row does not satisfy the predicate.
    * None, if the truthiness of the predicate is unknown, e.g. due to missing values.

    By default, cases where the truthiness of the predicate is unknown are ignored and this method returns how
    often the predicate returns True.

    You can instead enable Kleene logic by setting `ignore_unknown=False`. In this case, this method returns None if
    the predicate returns None at least once. Otherwise, it still returns how often the predicate returns True.

    Parameters
    ----------
    predicate:
        The predicate to apply to each row.
    ignore_unknown:
        Whether to ignore cases where the truthiness of the predicate is unknown.

    Returns
    -------
    count:
        The number of rows in the table that satisfy the predicate. None if `ignore_unknown` is False and the
        predicate result is unknown for at least one row.

    Examples
    --------
    >>> from safeds.data.tabular.containers import Table
    >>> table = Table({"col1": [1, 2, 3], "col2": [1, 3, 3]})
    >>> table.count_row_if(lambda row: row["col1"] == row["col2"])
    2

    >>> table.count_row_if(lambda row: row["col1"] > row["col2"])
    0
    """
    # The predicate is called once with a row wrapper around the whole table; the resulting cell carries the
    # polars expression that is then evaluated against the table's lazy frame.
    expression = predicate(_LazyVectorizedRow(self))._polars_expression
    series = self._lazy_frame.select(expression.alias("count")).collect().get_column("count")

    # Only report a count if unknown results may be ignored or none occurred.
    if ignore_unknown or series.null_count() == 0:
        return series.sum()
    else:
        return None

# TODO: Rethink group_rows/group_rows_by_column. They should not return a dict.

def remove_duplicate_rows(self) -> Table:
Expand Down
12 changes: 5 additions & 7 deletions src/safeds/data/tabular/plotting/_column_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import TYPE_CHECKING

from safeds._utils import _figure_to_image
from safeds.exceptions import NonNumericColumnError
from safeds._validation._check_columns_are_numeric import _check_column_is_numeric

if TYPE_CHECKING:
from safeds.data.image.containers import Image
Expand Down Expand Up @@ -49,9 +49,8 @@ def box_plot(self) -> Image:
>>> column = Column("test", [1, 2, 3])
>>> boxplot = column.plot.box_plot()
"""
if self._column.row_count > 0 and not self._column.is_numeric:
# TODO better error message
raise NonNumericColumnError(f"{self._column.name} is of type {self._column.type}.")
if self._column.row_count > 0:
_check_column_is_numeric(self._column, operation="create a box plot")

import matplotlib.pyplot as plt

Expand Down Expand Up @@ -115,9 +114,8 @@ def lag_plot(self, lag: int) -> Image:
>>> column = Column("values", [1, 2, 3, 4])
>>> image = column.plot.lag_plot(2)
"""
if self._column.row_count > 0 and not self._column.is_numeric:
# TODO better error message
raise NonNumericColumnError("This time series target contains non-numerical columns.")
if self._column.row_count > 0:
_check_column_is_numeric(self._column, operation="create a lag plot")

import matplotlib.pyplot as plt

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import pytest
from safeds.data.tabular.containers import Column
from safeds.exceptions import ColumnLengthMismatchError, MissingValuesColumnError, NonNumericColumnError
from safeds.exceptions import (
ColumnLengthMismatchError,
ColumnTypeError,
MissingValuesColumnError,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -38,7 +42,7 @@ def test_should_return_correlation_between_two_columns(values1: list, values2: l
def test_should_raise_if_columns_are_not_numeric(values1: list, values2: list) -> None:
column1 = Column("A", values1)
column2 = Column("B", values2)
with pytest.raises(NonNumericColumnError):
with pytest.raises(ColumnTypeError):
column1.correlation_with(column2)


Expand Down
4 changes: 2 additions & 2 deletions tests/safeds/data/tabular/containers/_column/test_mean.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from safeds.data.tabular.containers import Column
from safeds.exceptions import NonNumericColumnError
from safeds.exceptions import ColumnTypeError


@pytest.mark.parametrize(
Expand Down Expand Up @@ -36,5 +36,5 @@ def test_should_return_mean_value(values: list, expected: int) -> None:
)
def test_should_raise_if_column_is_not_numeric(values: list) -> None:
column = Column("col", values)
with pytest.raises(NonNumericColumnError):
with pytest.raises(ColumnTypeError):
column.mean()
4 changes: 2 additions & 2 deletions tests/safeds/data/tabular/containers/_column/test_median.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from safeds.data.tabular.containers import Column
from safeds.exceptions import NonNumericColumnError
from safeds.exceptions import ColumnTypeError


@pytest.mark.parametrize(
Expand Down Expand Up @@ -36,5 +36,5 @@ def test_should_return_median_value(values: list, expected: int) -> None:
)
def test_should_raise_if_column_is_not_numeric(values: list) -> None:
column = Column("A", values)
with pytest.raises(NonNumericColumnError):
with pytest.raises(ColumnTypeError):
column.median()
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from safeds.data.tabular.containers import Column
from safeds.exceptions import NonNumericColumnError
from safeds.exceptions import ColumnTypeError
from syrupy import SnapshotAssertion


Expand All @@ -24,5 +24,5 @@ def test_should_match_snapshot(column: Column, snapshot_png_image: SnapshotAsser

def test_should_raise_if_column_contains_non_numerical_values() -> None:
column = Column("a", ["A", "B", "C"])
with pytest.raises(NonNumericColumnError):
with pytest.raises(ColumnTypeError):
column.plot.box_plot()
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from safeds.data.tabular.containers import Column
from safeds.exceptions import NonNumericColumnError
from safeds.exceptions import ColumnTypeError
from syrupy import SnapshotAssertion


Expand All @@ -24,5 +24,5 @@ def test_should_match_snapshot(column: Column, snapshot_png_image: SnapshotAsser

def test_should_raise_if_column_contains_non_numerical_values() -> None:
column = Column("a", ["A", "B", "C"])
with pytest.raises(NonNumericColumnError):
with pytest.raises(ColumnTypeError):
column.plot.lag_plot(1)
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from safeds.data.tabular.containers import Column
from safeds.exceptions import NonNumericColumnError
from safeds.exceptions import ColumnTypeError


@pytest.mark.parametrize(
Expand Down Expand Up @@ -34,5 +34,5 @@ def test_should_return_standard_deviation(values: list, expected: int) -> None:
)
def test_should_raise_if_column_is_not_numeric(values: list) -> None:
column = Column("A", values)
with pytest.raises(NonNumericColumnError):
with pytest.raises(ColumnTypeError):
column.standard_deviation()
Loading

0 comments on commit 4137131

Please sign in to comment.