Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update update_transformers validation #563

Merged
merged 5 commits into from
Oct 13, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions rdt/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ class NotFittedError(Exception):

class Error(Exception):
"""Error to raise when ``HyperTransformer`` produces a controlled error message."""


class TransformerInputError(Exception):
    """Error to raise when ``HyperTransformer`` receives incorrect input.

    For example, ``update_transformers`` raises it when a column is assigned a
    transformer whose supported sdtypes do not include the column's sdtype.
    """
16 changes: 6 additions & 10 deletions rdt/hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pandas as pd
import yaml

from rdt.errors import Error, NotFittedError
from rdt.errors import Error, NotFittedError, TransformerInputError
from rdt.transformers import (
BaseTransformer, get_class_by_transformer_name, get_default_transformer,
get_transformer_instance, get_transformers_by_type)
Expand Down Expand Up @@ -442,21 +442,17 @@ def update_transformers(self, column_name_to_transformer):
self._validate_update_columns(update_columns)
self._validate_transformers(column_name_to_transformer)

incompatible_sdtypes = []
for column_name, transformer in column_name_to_transformer.items():
if transformer is not None:
current_sdtype = self.field_sdtypes.get(column_name)
if current_sdtype and current_sdtype not in transformer.get_supported_sdtypes():
incompatible_sdtypes.append(column_name)
raise TransformerInputError(
f"Column '{column_name}' is a {current_sdtype} column, which is "
f"incompatible with the '{transformer.get_name()}' transformer."
)

self.field_transformers[column_name] = transformer

if incompatible_sdtypes:
warnings.warn(
"Some transformers you've assigned are not compatible with the sdtypes. "
f"Use 'update_sdtypes' to update: {incompatible_sdtypes}"
)

self._modified_config = True

def remove_transformers(self, column_names):
Expand Down Expand Up @@ -603,7 +599,7 @@ def _get_transformer_tree_yaml(self):
"""
modified_tree = deepcopy(self._transformers_tree)
for field in modified_tree:
class_name = modified_tree[field]['transformer'].__class__.__name__
class_name = modified_tree[field]['transformer'].__class__.get_name()
modified_tree[field]['transformer'] = class_name

return yaml.safe_dump(dict(modified_tree))
Expand Down
4 changes: 2 additions & 2 deletions rdt/performance/performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def evaluate_transformer_performance(transformer, dataset_generator, verbose=Fal
pandas.DataFrame:
The performance test results.
"""
transformer_args = TRANSFORMER_ARGS.get(transformer.__name__, {})
transformer_args = TRANSFORMER_ARGS.get(transformer.get_name(), {})
transformer_instance = transformer(**transformer_args)

sizes = _get_dataset_sizes(dataset_generator.SDTYPE)
Expand All @@ -102,7 +102,7 @@ def evaluate_transformer_performance(transformer, dataset_generator, verbose=Fal
performance['Number of fit rows'] = fit_size
performance['Number of transform rows'] = transform_size
performance['Dataset'] = dataset_generator.__name__
performance['Transformer'] = f'{transformer.__module__ }.{transformer.__name__}'
performance['Transformer'] = f'{transformer.__module__ }.{transformer.get_name()}'

out.append(performance)

Expand Down
4 changes: 2 additions & 2 deletions rdt/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def get_transformer_name(transformer):
The path of the transformer.
"""
if inspect.isclass(transformer):
return transformer.__module__ + '.' + transformer.__name__
return transformer.__module__ + '.' + transformer.get_name()

raise ValueError(f'The transformer {transformer} must be passed as a class.')

Expand Down Expand Up @@ -106,7 +106,7 @@ def get_class_by_transformer_name():
BaseTransformer:
BaseTransformer subclass class object.
"""
return {class_.__name__: class_ for class_ in BaseTransformer.get_subclasses()}
return {class_.get_name(): class_ for class_ in BaseTransformer.get_subclasses()}


def get_transformer_class(transformer):
Expand Down
12 changes: 11 additions & 1 deletion rdt/transformers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ class BaseTransformer:
column_prefix = None
output_columns = None

@classmethod
def get_name(cls):
    """Return the name of the transformer class.

    The name is the ``__name__`` of the class the method is called on, so each
    subclass reports its own class name.

    Returns:
        str:
            Transformer name.
    """
    return cls.__name__

@classmethod
def get_subclasses(cls):
"""Recursively find subclasses of this class.
Expand Down Expand Up @@ -208,7 +218,7 @@ def __repr__(self):
str:
The name of the transformer followed by any non-default parameters.
"""
class_name = self.__class__.__name__
class_name = self.__class__.get_name()
custom_args = []
args = inspect.getfullargspec(self.__init__)
keys = args.args[1:]
Expand Down
2 changes: 1 addition & 1 deletion rdt/transformers/pii/anonymizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def __repr__(self):
str:
The name of the transformer followed by any non-default parameters.
"""
class_name = self.__class__.__name__
class_name = self.__class__.get_name()
custom_args = []
args = inspect.getfullargspec(self.__init__)
keys = args.args[1:]
Expand Down
6 changes: 3 additions & 3 deletions tests/code_style.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def validate_transformer_addon(transformer):
module_py = True
elif document.match('config.json'):
config_json_exist = True
_validate_config_json(document, transformer.__name__)
_validate_config_json(document, transformer.get_name())

assert init_file_exist, 'Missing __init__.py file within the addon folder.'
assert config_json_exist, 'Missing the config.json file within the addon folder.'
Expand All @@ -85,7 +85,7 @@ def validate_transformer_addon(transformer):

def validate_transformer_importable_from_parent_module(transformer):
"""Validate whether the transformer can be imported from the parent module."""
name = transformer.__name__
name = transformer.get_name()
module = getattr(transformer, '__module__', '')
module = module.rsplit('.', 1)[0]
imported_transformer = getattr(importlib.import_module(module), name, None)
Expand Down Expand Up @@ -156,7 +156,7 @@ def validate_test_names(transformer):
test_file = get_test_location(transformer)
module = _load_module_from_path(test_file)

test_class = getattr(module, f'Test{transformer.__name__}', None)
test_class = getattr(module, f'Test{transformer.get_name()}', None)
assert test_class is not None, 'The expected test class was not found.'

test_functions = inspect.getmembers(test_class, predicate=inspect.isfunction)
Expand Down
16 changes: 8 additions & 8 deletions tests/contributing.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def validate_transformer_integration(transformer):
if isinstance(transformer, str):
transformer = get_transformer_class(transformer)

print(f'Validating Integration Tests for transformer {transformer.__name__}\n')
print(f'Validating Integration Tests for transformer {transformer.get_name()}\n')

steps = []
validation_error = None
Expand Down Expand Up @@ -384,14 +384,14 @@ def validate_transformer_quality(transformer):
if isinstance(transformer, str):
transformer = get_transformer_class(transformer)

print(f'Validating Quality Tests for transformer {transformer.__name__}\n')
print(f'Validating Quality Tests for transformer {transformer.get_name()}\n')

input_sdtype = transformer.get_input_sdtype()
test_cases = get_test_cases({input_sdtype})
regression_scores = get_regression_scores(test_cases, get_transformers_by_type())
results = get_results_table(regression_scores)

transformer_results = results[results['transformer_name'] == transformer.__name__]
transformer_results = results[results['transformer_name'] == transformer.get_name()]
transformer_results = transformer_results.drop('transformer_name', axis=1)
transformer_results['Acceptable'] = False
passing_relative_scores = transformer_results['score_relative_to_average'] > TEST_THRESHOLD
Expand Down Expand Up @@ -430,7 +430,7 @@ def validate_transformer_performance(transformer):
if isinstance(transformer, str):
transformer = get_transformer_class(transformer)

print(f'Validating Performance for transformer {transformer.__name__}\n')
print(f'Validating Performance for transformer {transformer.get_name()}\n')

sdtype = transformer.get_input_sdtype()
transformers = get_transformers_by_type().get(sdtype, [])
Expand All @@ -445,8 +445,8 @@ def validate_transformer_performance(transformer):
results = pd.DataFrame({
'Value': performance.to_numpy(),
'Valid': valid,
'transformer': current_transformer.__name__,
'dataset': dataset_generator.__name__,
'transformer': current_transformer.get_name(),
'dataset': dataset_generator.get_name(),
})
results['Evaluation Metric'] = performance.index
total_results = total_results.append(results)
Expand All @@ -456,10 +456,10 @@ def validate_transformer_performance(transformer):
else:
print('ERROR: One or more Performance Tests were NOT successful.')

other_results = total_results[total_results.transformer != transformer.__name__]
other_results = total_results[total_results.transformer != transformer.get_name()]
average = other_results.groupby('Evaluation Metric')['Value'].mean()

total_results = total_results[total_results.transformer == transformer.__name__]
total_results = total_results[total_results.transformer == transformer.get_name()]
final_results = total_results.groupby('Evaluation Metric').agg({
'Value': 'mean',
'Valid': 'any'
Expand Down
10 changes: 9 additions & 1 deletion tests/integration/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,16 @@ def _test_transformer_with_hypertransformer(transformer_class, input_data, steps
TEST_COL: transformer_class
}

sdtypes = {}
for field, transformer in field_transformers.items():
sdtypes[field] = transformer.get_supported_sdtypes()[0]

config = {
'sdtypes': sdtypes,
'transformers': field_transformers
}
hypertransformer.detect_initial_config(input_data)
Copy link
Contributor

@amontanez24 amontanez24 Oct 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can just delete this line now. You only need to call `detect_initial_config` if you don't set the config yourself with `set_config`.

hypertransformer.update_transformers(field_transformers)
hypertransformer.set_config(config)
hypertransformer.fit(input_data)

transformed = hypertransformer.transform(input_data)
Expand Down
2 changes: 1 addition & 1 deletion tests/quality/test_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def get_transformer_regression_scores(data, sdtype, dataset_name, transformers,
transformed_features = transformed_features[~nans]
score = get_regression_score(transformed_features, target)
row = pd.Series({
'transformer_name': transformer.__name__,
'transformer_name': transformer.get_name(),
'dataset_name': dataset_name,
'column': column,
'score': score
Expand Down
17 changes: 8 additions & 9 deletions tests/unit/test_hyper_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pytest

from rdt import HyperTransformer
from rdt.errors import Error, NotFittedError
from rdt.errors import Error, NotFittedError, TransformerInputError
from rdt.transformers import (
AnonymizedFaker, BinaryEncoder, FloatFormatter, FrequencyEncoder, GaussianNormalizer,
LabelEncoder, OneHotEncoder, RegexGenerator, UnixTimestampEncoder)
Expand Down Expand Up @@ -2484,16 +2484,15 @@ def test_update_transformers_missmatch_sdtypes(self, mock_warnings):
'my_column': transformer
}

# Run
instance.update_transformers(column_name_to_transformer)

# Assert
expected_call = (
"Some transformers you've assigned are not compatible with the sdtypes. "
f"Use 'update_sdtypes' to update: {'my_column'}"
# Run and Assert
err_msg = re.escape(
"Column 'my_column' is a categorical column, which is incompatible "
"with the 'BinaryEncoder' transformer."
)
with pytest.raises(TransformerInputError, match=err_msg):
instance.update_transformers(column_name_to_transformer)

assert mock_warnings.called_once_with(expected_call)
assert mock_warnings.called_once_with(err_msg)
instance._validate_transformers.assert_called_once_with(column_name_to_transformer)

def test_update_transformers_transformer_is_none(self):
Expand Down