From 3ec0041669ffe97640f96db345f3f43720d5c3f7 Mon Sep 17 00:00:00 2001 From: Alexander <47296670+Marsmaennchen221@users.noreply.github.com> Date: Tue, 18 Apr 2023 11:01:22 +0200 Subject: [PATCH] feat: `OneHotEncoder.inverse_transform` now maintains the column order from the original table (#195) Closes #109. ### Summary of Changes `OneHotEncoder.inverse_transform` now maintains the column order from the original table (#109) Fixed a bug that caused `OneHotEncoder.inverse_transform` to fail if not all columns were fitted New feature columns in `OneHotEncoder` will now be inserted where the combined columns were in the original table --------- Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com> Co-authored-by: Lars Reimann --- .../transformation/_one_hot_encoder.py | 52 ++++++++-- .../transformation/test_one_hot_encoder.py | 98 +++++++++++++++++-- 2 files changed, 133 insertions(+), 17 deletions(-) diff --git a/src/safeds/data/tabular/transformation/_one_hot_encoder.py b/src/safeds/data/tabular/transformation/_one_hot_encoder.py index a8a7af4c0..3c84ce5c0 100644 --- a/src/safeds/data/tabular/transformation/_one_hot_encoder.py +++ b/src/safeds/data/tabular/transformation/_one_hot_encoder.py @@ -15,7 +15,7 @@ class OneHotEncoder(InvertibleTableTransformer): def __init__(self) -> None: self._wrapped_transformer: sk_OneHotEncoder | None = None - self._column_names: list[str] | None = None + self._column_names: dict[str, list[str]] | None = None # noinspection PyProtectedMember def fit(self, table: Table, column_names: list[str] | None = None) -> OneHotEncoder: @@ -49,7 +49,10 @@ def fit(self, table: Table, column_names: list[str] | None = None) -> OneHotEnco result = OneHotEncoder() result._wrapped_transformer = wrapped_transformer - result._column_names = column_names + result._column_names = { + column: [f"{column}_{element}" for element in table.get_column(column).get_unique_values()] + for column in column_names + } return result @@ -78,19
+81,33 @@ def transform(self, table: Table) -> Table: raise TransformerNotFittedError # Input table does not contain all columns used to fit the transformer - missing_columns = set(self._column_names) - set(table.get_column_names()) + missing_columns = set(self._column_names.keys()) - set(table.get_column_names()) if len(missing_columns) > 0: raise UnknownColumnNameError(list(missing_columns)) original = table._data.copy() original.columns = table.schema.get_column_names() - one_hot_encoded = pd.DataFrame(self._wrapped_transformer.transform(original[self._column_names]).toarray()) + one_hot_encoded = pd.DataFrame( + self._wrapped_transformer.transform(original[self._column_names.keys()]).toarray(), + ) one_hot_encoded.columns = self._wrapped_transformer.get_feature_names_out() - unchanged = original.drop(self._column_names, axis=1) + unchanged = original.drop(self._column_names.keys(), axis=1) + + res = Table(pd.concat([unchanged, one_hot_encoded], axis=1)) + column_names = [] + + for name in table.get_column_names(): + if name not in self._column_names.keys(): + column_names.append(name) + else: + column_names.extend( + [f_name for f_name in self._wrapped_transformer.get_feature_names_out() if f_name.startswith(name)], + ) + res = res.sort_columns(lambda col1, col2: column_names.index(col1.name) - column_names.index(col2.name)) - return Table(pd.concat([unchanged, one_hot_encoded], axis=1)) + return res # noinspection PyProtectedMember def inverse_transform(self, transformed_table: Table) -> Table: @@ -120,12 +137,29 @@ def inverse_transform(self, transformed_table: Table) -> Table: data.columns = transformed_table.get_column_names() decoded = pd.DataFrame( - self._wrapped_transformer.inverse_transform(transformed_table._data), - columns=self._column_names, + self._wrapped_transformer.inverse_transform( + transformed_table.keep_only_columns(self._wrapped_transformer.get_feature_names_out())._data, + ), + columns=list(self._column_names.keys()), ) unchanged = 
data.drop(self._wrapped_transformer.get_feature_names_out(), axis=1) - return Table(pd.concat([unchanged, decoded], axis=1)) + res = Table(pd.concat([unchanged, decoded], axis=1)) + column_names = [ + name + if name not in [value for value_list in list(self._column_names.values()) for value in value_list] + else list(self._column_names.keys())[ + [ + list(self._column_names.values()).index(value) + for value in list(self._column_names.values()) + if name in value + ][0] + ] + for name in transformed_table.get_column_names() + ] + res = res.sort_columns(lambda col1, col2: column_names.index(col1.name) - column_names.index(col2.name)) + + return res def is_fitted(self) -> bool: """ diff --git a/tests/safeds/data/tabular/transformation/test_one_hot_encoder.py b/tests/safeds/data/tabular/transformation/test_one_hot_encoder.py index cf98feeb2..a57f3e739 100644 --- a/tests/safeds/data/tabular/transformation/test_one_hot_encoder.py +++ b/tests/safeds/data/tabular/transformation/test_one_hot_encoder.py @@ -107,14 +107,34 @@ class TestFitAndTransform: ["col1"], Table.from_dict( { + "col1_a": [1.0, 0.0, 0.0, 0.0], + "col1_b": [0.0, 1.0, 1.0, 0.0], + "col1_c": [0.0, 0.0, 0.0, 1.0], "col2": ["a", "b", "b", "c"], + }, + ), + ), + ( + Table.from_dict( + { + "col1": ["a", "b", "b", "c"], + "col2": ["a", "b", "b", "c"], + }, + ), + ["col1", "col2"], + Table.from_dict( + { "col1_a": [1.0, 0.0, 0.0, 0.0], "col1_b": [0.0, 1.0, 1.0, 0.0], "col1_c": [0.0, 0.0, 0.0, 1.0], + "col2_a": [1.0, 0.0, 0.0, 0.0], + "col2_b": [0.0, 1.0, 1.0, 0.0], + "col2_c": [0.0, 0.0, 0.0, 1.0], }, ), ), ], + ids=["all columns", "one column", "multiple columns"], ) def test_should_return_transformed_table( self, @@ -144,19 +164,81 @@ def test_should_not_change_original_table(self) -> None: class TestInverseTransform: @pytest.mark.parametrize( - "table", + ("table_to_fit", "column_names", "table_to_transform"), [ - Table.from_dict( - { - "col1": ["a", "b", "b", "c"], - }, + ( + Table.from_dict( + { + "a": [1.0, 
0.0, 0.0, 0.0], + "b": ["a", "b", "b", "c"], + "c": [0.0, 0.0, 0.0, 1.0], + }, + ), + ["b"], + Table.from_dict( + { + "a": [1.0, 0.0, 0.0, 0.0], + "b": ["a", "b", "b", "c"], + "c": [0.0, 0.0, 0.0, 1.0], + }, + ), ), + ( + Table.from_dict( + { + "a": [1.0, 0.0, 0.0, 0.0], + "b": ["a", "b", "b", "c"], + "c": [0.0, 0.0, 0.0, 1.0], + }, + ), + ["b"], + Table.from_dict( + { + "c": [0.0, 0.0, 0.0, 1.0], + "b": ["a", "b", "b", "c"], + "a": [1.0, 0.0, 0.0, 0.0], + }, + ), + ), + ( + Table.from_dict( + { + "a": [1.0, 0.0, 0.0, 0.0], + "b": ["a", "b", "b", "c"], + "bb": ["a", "b", "b", "c"], + }, + ), + ["b", "bb"], + Table.from_dict( + { + "a": [1.0, 0.0, 0.0, 0.0], + "b": ["a", "b", "b", "c"], + "bb": ["a", "b", "b", "c"], + }, + ), + ), + ], + ids=[ + "same table to fit and transform", + "different tables to fit and transform", + "one column name is a prefix of another column name", ], ) - def test_should_return_original_table(self, table: Table) -> None: - transformer = OneHotEncoder().fit(table) + def test_should_return_original_table( + self, + table_to_fit: Table, + column_names: list[str], + table_to_transform: Table, + ) -> None: + transformer = OneHotEncoder().fit(table_to_fit, column_names) + + result = transformer.inverse_transform(transformer.transform(table_to_transform)) - assert transformer.inverse_transform(transformer.transform(table)) == table + # This checks whether the columns are in the same order + assert result.get_column_names() == table_to_transform.get_column_names() + # This is subsumed by the next assertion, but we get a better error message + assert result.schema == table_to_transform.schema + assert result == table_to_transform def test_should_not_change_transformed_table(self) -> None: table = Table.from_dict(