Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add KNearestNeighborsImputer #864

Merged
merged 32 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
54ec649
KNN imputer implemented
SamanHushi Jun 21, 2024
ba42201
modified __init__
SamanHushi Jun 21, 2024
e80a651
added tests and change a bit
SamanHushi Jun 21, 2024
36640fd
more and better test
SamanHushi Jun 21, 2024
6450e3d
removed typechecking for init
SamanHushi Jun 21, 2024
83f7d92
end of day
SamanHushi Jun 21, 2024
a8e9fd9
wrote all tests and everything working accordingly
LIEeOoNn Jun 25, 2024
aa69dce
renamed a test and removed a wrong todo
LIEeOoNn Jun 28, 2024
8dfb8c4
style: apply automated linter fixes
megalinter-bot Jun 28, 2024
d4375b7
style: apply automated linter fixes
megalinter-bot Jun 28, 2024
054fba2
how should we test the __hash__ function?
LIEeOoNn Jun 28, 2024
1c1bfa4
style: apply automated linter fixes
megalinter-bot Jun 28, 2024
959a250
style: apply automated linter fixes
megalinter-bot Jun 28, 2024
f866aff
removed unreachable code
SamanHushi Jun 28, 2024
18a8e62
Merge branch 'main' into 743-add-knearestneighborsimputer
SamanHushi Jun 28, 2024
d568a1f
added missing word in KNN description
SamanHushi Jun 28, 2024
b704c03
adjusted tests
SamanHushi Jun 28, 2024
1412a12
Update src/safeds/data/tabular/transformation/_k_nearest_neighbors_im…
SamanHushi Jun 28, 2024
f6f1974
Update src/safeds/data/tabular/transformation/_k_nearest_neighbors_im…
SamanHushi Jun 28, 2024
9e68c09
Update src/safeds/data/tabular/transformation/_k_nearest_neighbors_im…
SamanHushi Jun 28, 2024
7a5a454
Update src/safeds/data/tabular/transformation/_k_nearest_neighbors_im…
SamanHushi Jun 28, 2024
6f5d4ed
Update tests/safeds/data/tabular/transformation/test_k_nearest_neighb…
SamanHushi Jun 28, 2024
e436937
added '_check_bounds' implementation
SamanHushi Jun 28, 2024
ef72b77
added neighbor_count to all tests
SamanHushi Jun 28, 2024
c4caca7
should have 100% conver now and hashing implemented like in SimpleImp…
LIEeOoNn Jul 1, 2024
21f3d0c
style: apply automated linter fixes
megalinter-bot Jul 1, 2024
64dc4a8
Merge branch 'main' into 743-add-knearestneighborsimputer
lars-reimann Jul 1, 2024
55d4a71
added property value_to_replace changed nan into fit and the import also
LIEeOoNn Jul 2, 2024
3b670ca
removed the import of nan into the if statement
LIEeOoNn Jul 2, 2024
3fc7a62
style: apply automated linter fixes
megalinter-bot Jul 2, 2024
489d329
now using var: value_to_replace for correct usage_
LIEeOoNn Jul 2, 2024
c63833a
style: apply automated linter fixes
megalinter-bot Jul 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/safeds/data/tabular/transformation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
if TYPE_CHECKING:
from ._discretizer import Discretizer
from ._invertible_table_transformer import InvertibleTableTransformer
from ._k_nearest_neighbors_imputer import KNearestNeighborsImputer
from ._label_encoder import LabelEncoder
from ._one_hot_encoder import OneHotEncoder
from ._range_scaler import RangeScaler
Expand All @@ -27,6 +28,7 @@
"SimpleImputer": "._simple_imputer:SimpleImputer",
"StandardScaler": "._standard_scaler:StandardScaler",
"TableTransformer": "._table_transformer:TableTransformer",
"KNearestNeighborsImputer": "._k_nearest_neighbors_imputer:KNearestNeighborsImputer",
},
)

Expand All @@ -40,4 +42,5 @@
"SimpleImputer",
"StandardScaler",
"TableTransformer",
"KNearestNeighborsImputer",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from safeds._utils import _structural_hash
from safeds._validation import _check_bounds, _check_columns_exist, _ClosedBound
from safeds.data.tabular.containers import Table
from safeds.exceptions import TransformerNotFittedError

from ._table_transformer import TableTransformer

if TYPE_CHECKING:
from sklearn.impute import KNNImputer as sk_KNNImputer


class KNearestNeighborsImputer(TableTransformer):
    """
    The KNearestNeighborsImputer replaces missing values in given Columns with the mean value of the K-nearest neighbors.

    Parameters
    ----------
    neighbor_count:
        The number of neighbors to consider when imputing missing values. Must be at least 1.
    column_names:
        The list of columns used to impute missing values. If 'None', all columns are used.
    value_to_replace:
        The placeholder for the missing values. All occurrences of `value_to_replace` will be imputed. If 'None',
        NaN is used as the placeholder when fitting.
    """

    # ------------------------------------------------------------------------------------------------------------------
    # Dunder methods
    # ------------------------------------------------------------------------------------------------------------------

    def __init__(
        self,
        neighbor_count: int,
        *,
        column_names: str | list[str] | None = None,
        value_to_replace: float | str | None = None,
    ) -> None:
        super().__init__(column_names)

        # Reject neighbor counts below 1 before storing anything.
        _check_bounds(name="neighbor_count", actual=neighbor_count, lower_bound=_ClosedBound(1))

        # parameter
        self._neighbor_count: int = neighbor_count
        self._value_to_replace: float | str | None = value_to_replace

        # attributes
        # Set only on the copy returned by `fit`; `None` means "not fitted".
        self._wrapped_transformer: sk_KNNImputer | None = None

    def __hash__(self) -> int:
        return _structural_hash(
            super().__hash__(),
            self._neighbor_count,
            self._value_to_replace,
            # Leave out the internal state for faster hashing
        )

    # ------------------------------------------------------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------------------------------------------------------

    @property
    def is_fitted(self) -> bool:
        """Whether the transformer is fitted."""
        return self._wrapped_transformer is not None

    @property
    def neighbor_count(self) -> int:
        """The number of neighbors to consider when imputing missing values."""
        return self._neighbor_count

    @property
    def value_to_replace(self) -> float | str | None:
        """The placeholder value that is treated as missing and gets imputed."""
        return self._value_to_replace

    # ------------------------------------------------------------------------------------------------------------------
    # Learning and transformation
    # ------------------------------------------------------------------------------------------------------------------

    def fit(self, table: Table) -> KNearestNeighborsImputer:
        """
        Learn a transformation for a set of columns in a table.

        **Note:** This transformer is not modified.

        Parameters
        ----------
        table:
            The table used to fit the transformer.

        Returns
        -------
        fitted_transformer:
            The fitted transformer.

        Raises
        ------
        ValueError
            If the table contains 0 rows.
        ColumnNotFoundError
            If one of the columns, that should be fitted is not in the table.
        """
        from sklearn.impute import KNNImputer as sk_KNNImputer

        if table.row_count == 0:
            raise ValueError("The KNearestNeighborsImputer cannot be fitted because the table contains 0 rows.")

        if self._column_names is None:
            # No explicit selection: use every column of the table.
            column_names = table.column_names
        else:
            column_names = self._column_names
            _check_columns_exist(table, column_names)

        value_to_replace = self._value_to_replace

        if self._value_to_replace is None:
            # `None` means "use NaN as the missing-value marker" for the wrapped imputer.
            from numpy import nan

            value_to_replace = nan

        wrapped_transformer = sk_KNNImputer(n_neighbors=self._neighbor_count, missing_values=value_to_replace)
        wrapped_transformer.set_output(transform="polars")
        wrapped_transformer.fit(
            table.remove_columns_except(column_names)._data_frame,
        )

        # Carry the full configuration over to the fitted copy. `value_to_replace` must be forwarded as well,
        # otherwise the fitted transformer would report `None` and hash differently from its configuration.
        result = KNearestNeighborsImputer(
            self._neighbor_count,
            column_names=column_names,
            value_to_replace=self._value_to_replace,
        )
        result._wrapped_transformer = wrapped_transformer

        return result

    def transform(self, table: Table) -> Table:
        """
        Apply the learned transformation to a table.

        **Note:** The given table is not modified.

        Parameters
        ----------
        table:
            The table to which the learned transformation is applied.

        Returns
        -------
        transformed_table:
            The transformed table.

        Raises
        ------
        TransformerNotFittedError
            If the transformer is not fitted.
        ColumnNotFoundError
            If one of the columns, that should be transformed is not in the table.
        """
        if self._column_names is None or self._wrapped_transformer is None:
            raise TransformerNotFittedError

        _check_columns_exist(table, self._column_names)

        # Impute only the fitted columns; the wrapped transformer was configured to emit polars output.
        new_data = self._wrapped_transformer.transform(
            table.remove_columns_except(self._column_names)._data_frame,
        )

        # Merge the imputed columns back into the table's lazy frame.
        return Table._from_polars_lazy_frame(
            table._lazy_frame.with_columns(new_data),
        )
Loading