Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added method Table.inverse_transform_table which returns the original table #227

Merged
merged 9 commits into from
Apr 21, 2023
23 changes: 23 additions & 0 deletions src/safeds/data/tabular/containers/_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
if TYPE_CHECKING:
from collections.abc import Callable, Iterable

from safeds.data.tabular.transformation import InvertibleTableTransformer

from ._tagged_table import TaggedTable


Expand Down Expand Up @@ -991,6 +993,27 @@ def transform_column(self, name: str, transformer: Callable[[Row], Any]) -> Tabl
return self.replace_column(name, result)
raise UnknownColumnNameError([name])

def inverse_transform_table(self, transformer: InvertibleTableTransformer) -> Table:
    """
    Undo the transformation of this table using the given transformer.

    This table is handed to the transformer's `inverse_transform` method and the result is returned unchanged.

    Parameters
    ----------
    transformer : InvertibleTableTransformer
        The fitted transformer whose transformation should be reverted.

    Returns
    -------
    table : Table
        The original table.

    Raises
    ------
    TransformerNotFittedError
        If the transformer has not been fitted yet.
    """
    # All work (including the fitted-check) is delegated to the transformer.
    return transformer.inverse_transform(self)

# ------------------------------------------------------------------------------------------------------------------
# Plotting
# ------------------------------------------------------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import pytest
from safeds.data.tabular.containers import Table
from safeds.data.tabular.exceptions import TransformerNotFittedError
from safeds.data.tabular.transformation import OneHotEncoder


class TestInverseTransformTableOnOneHotEncoder:
    """Check `Table.inverse_transform_table` when the transformer is a `OneHotEncoder`."""

    @pytest.mark.parametrize(
        ("table_to_fit", "column_names", "table_to_transform"),
        [
            pytest.param(
                Table.from_dict(
                    {
                        "a": [1.0, 0.0, 0.0, 0.0],
                        "b": ["a", "b", "b", "c"],
                        "c": [0.0, 0.0, 0.0, 1.0],
                    },
                ),
                ["b"],
                Table.from_dict(
                    {
                        "a": [1.0, 0.0, 0.0, 0.0],
                        "b": ["a", "b", "b", "c"],
                        "c": [0.0, 0.0, 0.0, 1.0],
                    },
                ),
                id="same table to fit and transform",
            ),
            pytest.param(
                Table.from_dict(
                    {
                        "a": [1.0, 0.0, 0.0, 0.0],
                        "b": ["a", "b", "b", "c"],
                        "c": [0.0, 0.0, 0.0, 1.0],
                    },
                ),
                ["b"],
                # Same data as the fitted table, but with the column order reversed.
                Table.from_dict(
                    {
                        "c": [0.0, 0.0, 0.0, 1.0],
                        "b": ["a", "b", "b", "c"],
                        "a": [1.0, 0.0, 0.0, 0.0],
                    },
                ),
                id="different tables to fit and transform",
            ),
            pytest.param(
                Table.from_dict(
                    {
                        "a": [1.0, 0.0, 0.0, 0.0],
                        "b": ["a", "b", "b", "c"],
                        "bb": ["a", "b", "b", "c"],
                    },
                ),
                ["b", "bb"],
                Table.from_dict(
                    {
                        "a": [1.0, 0.0, 0.0, 0.0],
                        "b": ["a", "b", "b", "c"],
                        "bb": ["a", "b", "b", "c"],
                    },
                ),
                id="one column name is a prefix of another column name",
            ),
        ],
    )
    def test_should_return_original_table(
        self,
        table_to_fit: Table,
        column_names: list[str],
        table_to_transform: Table,
    ) -> None:
        """Transforming and then inverting must give back the table that was transformed."""
        encoder = OneHotEncoder().fit(table_to_fit, column_names)
        encoded = encoder.transform(table_to_transform)

        restored = encoded.inverse_transform_table(encoder)

        # Column order has to survive the round trip.
        assert restored.column_names == table_to_transform.column_names
        # Implied by the equality check below, but a schema mismatch yields a clearer failure message.
        assert restored.schema == table_to_transform.schema
        assert restored == table_to_transform

    def test_should_not_change_transformed_table(self) -> None:
        """Inverting must leave the transformed table it is called on untouched."""
        source = Table.from_dict(
            {
                "col1": ["a", "b", "b", "c"],
            },
        )

        encoder = OneHotEncoder().fit(source, None)
        encoded = encoder.transform(source)
        # The return value is deliberately discarded; only the effect on `encoded` matters here.
        encoded.inverse_transform_table(encoder)

        unchanged = Table.from_dict(
            {
                "col1_a": [1.0, 0.0, 0.0, 0.0],
                "col1_b": [0.0, 1.0, 1.0, 0.0],
                "col1_c": [0.0, 0.0, 0.0, 1.0],
            },
        )

        assert encoded == unchanged

    def test_should_raise_if_not_fitted(self) -> None:
        """An unfitted transformer must be rejected with `TransformerNotFittedError`."""
        encoded_like = Table.from_dict(
            {
                "a": [1.0, 0.0, 0.0, 0.0],
                "b": [0.0, 1.0, 1.0, 0.0],
                "c": [0.0, 0.0, 0.0, 1.0],
            },
        )

        unfitted_encoder = OneHotEncoder()

        with pytest.raises(TransformerNotFittedError):
            encoded_like.inverse_transform_table(unfitted_encoder)