Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add StandardScaler transformer #316

Merged
merged 35 commits into from
Jun 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
9cc0e35
added RangeScaler class
sibre28 May 19, 2023
f4233ac
Refactor naming scheme to camel case
zzril May 19, 2023
13d861c
added __init__ and imports
sibre28 May 19, 2023
267f0c6
Merge remote-tracking branch 'origin/141-normalize-table' into 141-no…
sibre28 May 19, 2023
2fee42f
added all methods
sibre28 May 19, 2023
f096b73
added ValueError
sibre28 May 19, 2023
7fdef9d
updated imports and __init__ method fields
sibre28 May 19, 2023
94d647a
Merge branch 'main' into 141-normalize-table
sibre28 May 19, 2023
1e3f11c
merged main
sibre28 May 19, 2023
d1a38d0
added docstring to RangeScaler
sibre28 May 19, 2023
fa0e5a5
Add message to ValueError in init
zzril May 19, 2023
8593b85
Updated Docstrings and imports
sibre28 May 19, 2023
724bffd
added test file
sibre28 May 19, 2023
14480ec
Add test for ValueError in init
zzril May 19, 2023
0e73126
style: apply automated linter fixes
megalinter-bot May 19, 2023
0873bbb
changed existing tests, copied from labelEncoder to fit rangeScaler
sibre28 May 19, 2023
21cc982
Merge remote-tracking branch 'origin/141-normalize-table' into 141-no…
sibre28 May 19, 2023
c932119
Fix test for should not change case
zzril May 19, 2023
f2f9d21
Add template for StandardScaler
zzril May 19, 2023
26c6a56
Fix init method
zzril May 19, 2023
abba5e4
Add template for tests
zzril May 19, 2023
ad248ac
Drop unnecessary constructor test
zzril May 19, 2023
8d29399
Start adding proper testcases
zzril May 19, 2023
1924673
Merge branch 'main' into 142-standardize-table
zzril May 19, 2023
4f3a5af
Start refactoring tests
zzril May 19, 2023
719f3d7
Fix tests
zzril May 26, 2023
fa5ad0b
Merge branch 'main' into 142-standardize-table
lars-reimann May 26, 2023
f514b48
Remove warnings
zzril May 26, 2023
7f41501
changed approximation method
sibre28 May 26, 2023
4e01f8e
Merge branch 'main' into 142-standardize-table
lars-reimann May 26, 2023
cf6abba
fix: linter issue
lars-reimann May 26, 2023
2a670cb
test: move new method to another file since it's unrelated to test re…
lars-reimann May 26, 2023
dfd4fc6
Merge branch 'main' into 142-standardize-table
zzril Jun 1, 2023
c59a9b0
Merge branch 'main' into 142-standardize-table
zzril Jun 2, 2023
30dc13a
Merge branch 'main' into 142-standardize-table
zzril Jun 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/safeds/data/tabular/transformation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ._label_encoder import LabelEncoder
from ._one_hot_encoder import OneHotEncoder
from ._range_scaler import RangeScaler
from ._standard_scaler import StandardScaler
from ._table_transformer import InvertibleTableTransformer, TableTransformer

__all__ = [
Expand All @@ -13,4 +14,5 @@
"InvertibleTableTransformer",
"TableTransformer",
"RangeScaler",
"StandardScaler",
]
180 changes: 180 additions & 0 deletions src/safeds/data/tabular/transformation/_standard_scaler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
from __future__ import annotations

from sklearn.preprocessing import StandardScaler as sk_StandardScaler

from safeds.data.tabular.containers import Table
from safeds.data.tabular.transformation._table_transformer import InvertibleTableTransformer
from safeds.exceptions import TransformerNotFittedError, UnknownColumnNameError


class StandardScaler(InvertibleTableTransformer):
    """
    The StandardScaler standardizes column values.

    The actual computation is delegated to scikit-learn's `StandardScaler`, which removes the mean of each column and
    scales it to unit variance.
    """

    def __init__(self) -> None:
        # Both attributes stay `None` until `fit` produces a fitted copy of the transformer.
        self._column_names: list[str] | None = None
        self._wrapped_transformer: sk_StandardScaler | None = None

    @staticmethod
    def _raise_if_columns_missing(required_column_names: list[str], table: Table) -> None:
        """Raise an UnknownColumnNameError if any of the required columns is absent from the table."""
        missing_columns = set(required_column_names) - set(table.column_names)
        if missing_columns:
            # Sets are unordered; sort so that the reported names are deterministic.
            raise UnknownColumnNameError(sorted(missing_columns))

    def fit(self, table: Table, column_names: list[str] | None) -> StandardScaler:
        """
        Learn a transformation for a set of columns in a table.

        This transformer is not modified.

        Parameters
        ----------
        table : Table
            The table used to fit the transformer.
        column_names : Optional[list[str]]
            The list of columns from the table used to fit the transformer. If `None`, all columns are used.

        Returns
        -------
        fitted_transformer : StandardScaler
            The fitted transformer.

        Raises
        ------
        UnknownColumnNameError
            If `column_names` contains a column that is not part of the table.
        """
        if column_names is None:
            column_names = table.column_names
        else:
            self._raise_if_columns_missing(column_names, table)

        wrapped_transformer = sk_StandardScaler()
        wrapped_transformer.fit(table._data[column_names])

        # Return a new, fitted instance instead of mutating `self`.
        result = StandardScaler()
        result._wrapped_transformer = wrapped_transformer
        # Store a copy so later changes to the caller's list cannot alter the fitted state.
        result._column_names = list(column_names)

        return result

    def transform(self, table: Table) -> Table:
        """
        Apply the learned transformation to a table.

        The table is not modified.

        Parameters
        ----------
        table : Table
            The table to which the learned transformation is applied.

        Returns
        -------
        transformed_table : Table
            The transformed table.

        Raises
        ------
        TransformerNotFittedError
            If the transformer has not been fitted yet.
        UnknownColumnNameError
            If the table does not contain all columns used to fit the transformer.
        """
        # Transformer has not been fitted yet
        if self._wrapped_transformer is None or self._column_names is None:
            raise TransformerNotFittedError

        # Input table does not contain all columns used to fit the transformer
        self._raise_if_columns_missing(self._column_names, table)

        # Work on a copy so the data of the input table stays untouched.
        data = table._data.copy()
        data.columns = table.column_names
        data[self._column_names] = self._wrapped_transformer.transform(data[self._column_names])
        return Table._from_pandas_dataframe(data)

    def inverse_transform(self, transformed_table: Table) -> Table:
        """
        Undo the learned transformation.

        The table is not modified.

        Parameters
        ----------
        transformed_table : Table
            The table to be transformed back to the original version.

        Returns
        -------
        table : Table
            The original table.

        Raises
        ------
        TransformerNotFittedError
            If the transformer has not been fitted yet.
        UnknownColumnNameError
            If the table does not contain all columns used to fit the transformer.
        """
        # Transformer has not been fitted yet
        if self._wrapped_transformer is None or self._column_names is None:
            raise TransformerNotFittedError

        # Mirror the check done in `transform` so that a missing column is reported the same way in both directions
        # instead of surfacing as a lookup error from the underlying data frame.
        self._raise_if_columns_missing(self._column_names, transformed_table)

        # Work on a copy so the data of the input table stays untouched.
        data = transformed_table._data.copy()
        data.columns = transformed_table.column_names
        data[self._column_names] = self._wrapped_transformer.inverse_transform(data[self._column_names])
        return Table._from_pandas_dataframe(data)

    def is_fitted(self) -> bool:
        """
        Check if the transformer is fitted.

        Returns
        -------
        is_fitted : bool
            Whether the transformer is fitted.
        """
        return self._wrapped_transformer is not None

    def get_names_of_added_columns(self) -> list[str]:
        """
        Get the names of all new columns that have been added by the StandardScaler.

        The StandardScaler never adds columns, so the result is always empty.

        Returns
        -------
        added_columns : list[str]
            A list of names of the added columns, ordered as they will appear in the table.

        Raises
        ------
        TransformerNotFittedError
            If the transformer has not been fitted yet.
        """
        if not self.is_fitted():
            raise TransformerNotFittedError
        return []

    # (Must implement abstract method, cannot instantiate class otherwise.)
    def get_names_of_changed_columns(self) -> list[str]:
        """
        Get the names of all columns that may have been changed by the StandardScaler.

        Returns
        -------
        changed_columns : list[str]
            The list of (potentially) changed column names, as passed to fit.

        Raises
        ------
        TransformerNotFittedError
            If the transformer has not been fitted yet.
        """
        if self._column_names is None:
            raise TransformerNotFittedError
        # Hand out a copy so callers cannot modify the fitted state through the returned list.
        return list(self._column_names)

    def get_names_of_removed_columns(self) -> list[str]:
        """
        Get the names of all columns that have been removed by the StandardScaler.

        The StandardScaler never removes columns, so the result is always empty.

        Returns
        -------
        removed_columns : list[str]
            A list of names of the removed columns, ordered as they appear in the table the StandardScaler was fitted on.

        Raises
        ------
        TransformerNotFittedError
            If the transformer has not been fitted yet.
        """
        if not self.is_fitted():
            raise TransformerNotFittedError
        return []
3 changes: 2 additions & 1 deletion tests/helpers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ._assertions import assert_that_tables_are_close
from ._resources import resolve_resource_path

__all__ = ["resolve_resource_path"]
__all__ = ["assert_that_tables_are_close", "resolve_resource_path"]
24 changes: 24 additions & 0 deletions tests/helpers/_assertions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pytest
from safeds.data.tabular.containers import Table


def assert_that_tables_are_close(table1: Table, table2: Table) -> None:
    """
    Assert that two tables are almost equal.

    Both tables must have the same schema and the same number of rows, every column must be numeric, and each pair of
    corresponding values must be equal within the tolerance of `pytest.approx`.

    Parameters
    ----------
    table1: Table
        The first table.
    table2: Table
        The table to compare the first table to.

    Raises
    ------
    AssertionError
        If the tables differ in schema or number of rows, contain a non-numeric column, or hold values that are not
        approximately equal.
    """
    assert table1.schema == table2.schema
    # The value loop below only walks the rows of `table1`; without this check surplus rows in `table2` would be
    # silently ignored and the comparison could pass for tables of different length.
    assert table1.number_of_rows == table2.number_of_rows
    for column_name in table1.column_names:
        # Look each column up once instead of once per row.
        column_1 = table1.get_column(column_name)
        column_2 = table2.get_column(column_name)
        assert column_1.type == column_2.type
        assert column_1.type.is_numeric()
        assert column_2.type.is_numeric()
        for i in range(table1.number_of_rows):
            entry_1 = column_1.get_value(i)
            entry_2 = column_2.get_value(i)
            assert entry_1 == pytest.approx(entry_2)
Loading