Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update update_transformers validation #563

Merged
merged 5 commits into from
Oct 13, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions rdt/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ class NotFittedError(Exception):

class Error(Exception):
"""Error to raise when ``HyperTransformer`` produces a controlled error message."""


class InvalidSdtypeForTransformerError(Exception):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should rename this error to match the names we defined in the doc. You won't need to catch this error in the integration tests anymore if you make the other change.

"""Error to raise when the sdtype is not supported by the transformer."""
14 changes: 5 additions & 9 deletions rdt/hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pandas as pd
import yaml

from rdt.errors import Error, NotFittedError
from rdt.errors import Error, InvalidSdtypeForTransformerError, NotFittedError
from rdt.transformers import (
BaseTransformer, get_class_by_transformer_name, get_default_transformer,
get_transformer_instance, get_transformers_by_type)
Expand Down Expand Up @@ -442,21 +442,17 @@ def update_transformers(self, column_name_to_transformer):
self._validate_update_columns(update_columns)
self._validate_transformers(column_name_to_transformer)

incompatible_sdtypes = []
for column_name, transformer in column_name_to_transformer.items():
if transformer is not None:
current_sdtype = self.field_sdtypes.get(column_name)
if current_sdtype and current_sdtype not in transformer.get_supported_sdtypes():
incompatible_sdtypes.append(column_name)
raise InvalidSdtypeForTransformerError(
f"Column '{column_name}' is a {current_sdtype} column, which is "
f"incompatible with the '{type(transformer).__name__}' transformer."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤔 Should we add a `get_name` method to the `BaseTransformer` and use that whenever we want to get the `__name__`?

)

self.field_transformers[column_name] = transformer

if incompatible_sdtypes:
warnings.warn(
"Some transformers you've assigned are not compatible with the sdtypes. "
f"Use 'update_sdtypes' to update: {incompatible_sdtypes}"
)

self._modified_config = True

def remove_transformers(self, column_names):
Expand Down
7 changes: 6 additions & 1 deletion tests/integration/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

from rdt import HyperTransformer
from rdt.errors import InvalidSdtypeForTransformerError
from rdt.performance.datasets import BaseDatasetGenerator
from rdt.transformers import BaseTransformer

Expand Down Expand Up @@ -255,7 +256,11 @@ def _test_transformer_with_hypertransformer(transformer_class, input_data, steps
}

hypertransformer.detect_initial_config(input_data)
Copy link
Contributor

@amontanez24 amontanez24 Oct 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can just delete this line now. You only need to detect the config if you don't set it.

hypertransformer.update_transformers(field_transformers)
try:
hypertransformer.update_transformers(field_transformers)
except InvalidSdtypeForTransformerError:
pass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we could instead use the `set_config` method and avoid having this error raised altogether.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will still run into an error, since we have a version of it validating set_config as well:

if mismatched_columns:
raise Error(
"Some transformers you've assigned are not compatible with the sdtypes. "
f'Please change the following columns: {mismatched_columns}'
)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When we set the config, can't we force the sdtype to match?


hypertransformer.fit(input_data)

transformed = hypertransformer.transform(input_data)
Expand Down
17 changes: 8 additions & 9 deletions tests/unit/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pytest

from rdt import HyperTransformer
from rdt.errors import Error, NotFittedError
from rdt.errors import Error, InvalidSdtypeForTransformerError, NotFittedError
from rdt.transformers import (
AnonymizedFaker, BinaryEncoder, FloatFormatter, FrequencyEncoder, GaussianNormalizer,
LabelEncoder, OneHotEncoder, RegexGenerator, UnixTimestampEncoder)
Expand Down Expand Up @@ -2484,16 +2484,15 @@ def test_update_transformers_missmatch_sdtypes(self, mock_warnings):
'my_column': transformer
}

# Run
instance.update_transformers(column_name_to_transformer)

# Assert
expected_call = (
"Some transformers you've assigned are not compatible with the sdtypes. "
f"Use 'update_sdtypes' to update: {'my_column'}"
# Run and Assert
err_msg = re.escape(
"Column 'my_column' is a categorical column, which is incompatible "
"with the 'BinaryEncoder' transformer."
)
with pytest.raises(InvalidSdtypeForTransformerError, match=err_msg):
instance.update_transformers(column_name_to_transformer)

assert mock_warnings.called_once_with(expected_call)
assert mock_warnings.called_once_with(err_msg)
instance._validate_transformers.assert_called_once_with(column_name_to_transformer)

def test_update_transformers_transformer_is_none(self):
Expand Down