HyperTransformer API name changes
amontanez24 committed Oct 21, 2021
1 parent 08d9f28 commit dfdabc0
Showing 4 changed files with 45 additions and 45 deletions.
34 changes: 17 additions & 17 deletions rdt/hyper_transformer.py
@@ -19,10 +19,10 @@ class HyperTransformer:
names. Keys can also specify transformers for fields derived by other transformers.
This can be done by concatenating the name of the original field to the output name
using ``.`` as a separator (eg. {field_name}.{transformer_output_name}).
field_types (dict or None):
field_data_types (dict or None):
Dict mapping field names to their data types. If not provided, the data type is
inferred using the column's Pandas ``dtype``.
data_type_transformers (dict or None):
default_data_type_transformers (dict or None):
Dict used to overwrite the default transformer for a data type. The keys are
data types and the values are Transformers or Transformer instances.
copy (bool):
@@ -40,11 +40,11 @@ class HyperTransformer:
Create a ``HyperTransformer`` passing a dict mapping fields to data types.
>>> field_types = {
>>> field_data_types = {
... 'a': 'categorical',
... 'b': 'numerical'
... }
>>> ht = HyperTransformer(field_types=field_types)
>>> ht = HyperTransformer(field_data_types=field_data_types)
Create a ``HyperTransformer`` passing a ``field_transformers`` dict.
(Note: The transformers used in this example may not exist and are just used
@@ -58,11 +58,11 @@ class HyperTransformer:
>>> ht = HyperTransformer(field_transformers=field_transformers)
Create a ``HyperTransformer`` passing a dict mapping data types to transformers.
>>> data_type_transformers = {
>>> default_data_type_transformers = {
... 'categorical': LabelEncodingTransformer(),
... 'numerical': NumericalTransformer()
... }
>>> ht = HyperTransformer(data_type_transformers=data_type_transformers)
>>> ht = HyperTransformer(default_data_type_transformers=default_data_type_transformers)
"""

_DTYPES_TO_DATA_TYPES = {
@@ -102,7 +102,7 @@ def _subset(input_list, other_list, not_in=False):

def _create_multi_column_fields(self):
multi_column_fields = {}
for field in list(self.field_types) + list(self.field_transformers):
for field in list(self.field_data_types) + list(self.field_transformers):
if isinstance(field, tuple):
for column in field:
multi_column_fields[column] = field
@@ -117,11 +117,11 @@ def _validate_field_transformers(self):

self._add_field_to_set(field, self._specified_fields)

def __init__(self, copy=True, field_types=None, data_type_transformers=None,
def __init__(self, copy=True, field_data_types=None, default_data_type_transformers=None,
field_transformers=None, transform_output_types=None):
self.copy = copy
self.field_types = field_types or {}
self.data_type_transformers = data_type_transformers or {}
self.field_data_types = field_data_types or {}
self.default_data_type_transformers = default_data_type_transformers or {}
self.field_transformers = field_transformers or {}
self._specified_fields = set()
self._validate_field_transformers()
@@ -137,15 +137,15 @@ def _field_in_data(field, data):
all_columns_in_data = isinstance(field, tuple) and all(col in data for col in field)
return field in data or all_columns_in_data

def _update_field_types(self, data):
def _update_field_data_types(self, data):
# get set of provided fields including multi-column fields
provided_fields = set()
for field in self.field_types.keys():
for field in self.field_data_types.keys():
self._add_field_to_set(field, provided_fields)

for field in data:
if field not in provided_fields:
self.field_types[field] = self._DTYPES_TO_DATA_TYPES[data[field].dtype.kind]
self.field_data_types[field] = self._DTYPES_TO_DATA_TYPES[data[field].dtype.kind]

def _get_next_transformer(self, output_field, output_type, next_transformers):
next_transformer = None
@@ -215,17 +215,17 @@ def fit(self, data):
Data to fit the transformers to.
"""
self._input_columns = list(data.columns)
self._update_field_types(data)
self._update_field_data_types(data)

# Loop through field_transformers that are first level
for field in self.field_transformers:
if self._field_in_data(field, data):
data = self._fit_field_transformer(data, field, self.field_transformers[field])

for (field, data_type) in self.field_types.items():
for (field, data_type) in self.field_data_types.items():
if not self._field_in_set(field, self._fitted_fields):
if data_type in self.data_type_transformers:
transformer = self.data_type_transformers[data_type]
if data_type in self.default_data_type_transformers:
transformer = self.default_data_type_transformers[data_type]
else:
transformer = get_default_transformer(data_type)

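For reference, a minimal sketch of calling the renamed constructor after this commit, in the same doctest style as the docstring examples above. Only ``field_data_types`` and ``default_data_type_transformers`` come from the diff; the import locations and the sample frame are assumptions for illustration.

>>> import pandas as pd
>>> from rdt import HyperTransformer
>>> from rdt.transformers import CategoricalTransformer  # assumed import path
>>> data = pd.DataFrame({'a': ['x', 'y', 'x'], 'b': [1.0, 2.0, 3.0]})
>>> ht = HyperTransformer(
...     field_data_types={'a': 'categorical', 'b': 'numerical'},  # formerly ``field_types``
...     default_data_type_transformers={'categorical': CategoricalTransformer()}  # formerly ``data_type_transformers``
... )
>>> ht.fit(data)
>>> transformed = ht.transform(data)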
6 changes: 3 additions & 3 deletions tests/integration/test_hyper_transformer.py
@@ -195,7 +195,7 @@ def test_hypertransformer_default_inputs():
transformers to use for each field.
Setup:
- `data_type_transformers` will be set to use the `CategoricalTransformer`
- `default_data_type_transformers` will be set to use the `CategoricalTransformer`
for categorical data types so that the output is predictable.
Input:
@@ -218,7 +218,7 @@ def test_hypertransformer_default_inputs():
expected_transformed = get_transformed_data(index)
expected_reversed = get_reversed(index)

ht = HyperTransformer(data_type_transformers={'categorical': CategoricalTransformer})
ht = HyperTransformer(default_data_type_transformers={'categorical': CategoricalTransformer})
ht.fit(data)
transformed = ht.transform(data)

@@ -397,7 +397,7 @@ def test_with_unfitted_columns():
"""HyperTransform should be able to transform even if there are unseen columns in data."""
data = get_input_data_without_nan()

ht = HyperTransformer(data_type_transformers={'categorical': CategoricalTransformer})
ht = HyperTransformer(default_data_type_transformers={'categorical': CategoricalTransformer})
ht.fit(data)

new_data = get_input_data_without_nan()
2 changes: 1 addition & 1 deletion tests/quality/test_quality.py
@@ -70,7 +70,7 @@ def get_transformer_regression_scores(data, data_type, dataset_name, transformer
target = numerical_transformer.fit_transform(target, list(target.columns))
target = format_array(target)
for transformer in transformers:
ht = HyperTransformer(data_type_transformers={data_type: transformer})
ht = HyperTransformer(default_data_type_transformers={data_type: transformer})
ht.fit(features)
transformed_features = ht.transform(features).to_numpy()
score = get_regression_score(transformed_features, target)
48 changes: 24 additions & 24 deletions tests/unit/test_hyper_transformer.py
@@ -56,21 +56,21 @@ def test___init__(self, validation_mock, multi_column_mock):
# Asserts
assert ht.copy is True
assert ht.field_transformers == {}
assert ht.data_type_transformers == {}
assert ht.field_types == {}
assert ht.default_data_type_transformers == {}
assert ht.field_data_types == {}
multi_column_mock.assert_called_once()
validation_mock.assert_called_once()

def test__create_multi_column_fields(self):
"""Test the ``_create_multi_column_fields`` method.
This tests that the method goes through both the ``field_transformers``
dict and ``field_types`` dict to find multi_column fields and map
dict and ``field_data_types`` dict to find multi_column fields and map
each column to its corresponding tuple.
Setup:
- instance.field_transformers will be populated with multi-column fields
- instance.field_types will be populated with multi-column fields
- instance.field_data_types will be populated with multi-column fields
Output:
- A dict mapping each column name that is part of a multi-column
@@ -84,7 +84,7 @@ def test__create_multi_column_fields(self):
('c', 'd'): DatetimeTransformer,
'e': NumericalTransformer
}
ht.field_types = {
ht.field_data_types = {
'f': 'categorical',
('g', 'h'): 'datetime'
}
@@ -111,7 +111,7 @@ def test__get_next_transformer_field_transformer(self):
Setup:
- field_transformers is given a transformer for the
output field.
- data_type_transformers will be given a different transformer
- default_data_type_transformers will be given a different transformer
for the output type of the output field.
Input:
@@ -126,7 +126,7 @@ def test__get_next_transformer_field_transformer(self):
transformer = NumericalTransformer()
ht = HyperTransformer(
field_transformers={'a.out': transformer},
data_type_transformers={'numerical': GaussianCopulaTransformer()}
default_data_type_transformers={'numerical': GaussianCopulaTransformer()}
)

# Run
@@ -144,7 +144,7 @@ def test__get_next_transformer_final_output_type(self):
is returned.
Setup:
- data_type_transformers will be given a transformer
- default_data_type_transformers will be given a transformer
for the output type of the output field.
Input:
@@ -157,7 +157,7 @@ def test__get_next_transformer_final_output_type(self):
"""
# Setup
ht = HyperTransformer(
data_type_transformers={'numerical': GaussianCopulaTransformer()}
default_data_type_transformers={'numerical': GaussianCopulaTransformer()}
)

# Run
@@ -176,7 +176,7 @@ def test__get_next_transformer_next_transformers(self):
field, then it is used.
Setup:
- data_type_transformers will be given a transformer
- default_data_type_transformers will be given a transformer
for the output type of the output field.
Input:
@@ -191,7 +191,7 @@ def test__get_next_transformer_next_transformers(self):
# Setup
transformer = CategoricalTransformer()
ht = HyperTransformer(
data_type_transformers={'categorical': OneHotEncodingTransformer()}
default_data_type_transformers={'categorical': OneHotEncodingTransformer()}
)
next_transformers = {'a.out': transformer}

@@ -225,7 +225,7 @@ def test__get_next_transformer_default_transformer(self, mock):
transformer = CategoricalTransformer(fuzzy=True)
mock.return_value = transformer
ht = HyperTransformer(
data_type_transformers={'categorical': OneHotEncodingTransformer()}
default_data_type_transformers={'categorical': OneHotEncodingTransformer()}
)

# Run
@@ -235,15 +235,15 @@ def test__get_next_transformer_default_transformer(self, mock):
assert isinstance(next_transformer, CategoricalTransformer)
assert next_transformer.fuzzy is True

def test__update_field_types(self):
"""Test the ``_update_field_types`` method.
def test__update_field_data_types(self):
"""Test the ``_update_field_data_types`` method.
This tests that if any field types are missing in the
provided field_types dict, that the rest of the values
provided field_data_types dict, that the rest of the values
are filled in using the data types for the dtype.
Setup:
- field_types will only define a few of the fields.
- field_data_types will only define a few of the fields.
Input:
- A DataFrame of various types.
@@ -253,7 +253,7 @@ def test__update_field_types(self):
the data.
"""
# Setup
ht = HyperTransformer(field_types={'a': 'numerical', 'b': 'categorical'})
ht = HyperTransformer(field_data_types={'a': 'numerical', 'b': 'categorical'})
data = pd.DataFrame({
'a': [1, 2, 3],
'b': ['category1', 'category2', 'category3'],
@@ -262,11 +262,11 @@ def test__update_field_types(self):
})

# Run
ht._update_field_types(data)
ht._update_field_data_types(data)

# Assert
expected = {'a': 'numerical', 'b': 'categorical', 'c': 'boolean', 'd': 'float'}
assert ht.field_types == expected
assert ht.field_data_types == expected

@patch('rdt.hyper_transformer.load_transformer')
def test__fit_field_transformer(self, load_transformer_mock):
@@ -558,9 +558,9 @@ def test_fit(self, get_default_transformer_mock):
"""Test the ``fit`` method.
Tests that the ``fit`` method loops through the fields in ``field_transformers``
and ``field_types`` that are in the data. It should try to find a transformer
in ``data_type_transformers`` and then use the default if it doesn't find one
when looping through ``field_types``. It should then call ``_fit_field_transformer``
and ``field_data_types`` that are in the data. It should try to find a transformer
in ``default_data_type_transformers`` and then use the default if it doesn't find one
when looping through ``field_data_types``. It should then call ``_fit_field_transformer``
with the correct arguments.
Setup:
@@ -589,14 +589,14 @@ def test_fit(self, get_default_transformer_mock):
'float': float_transformer,
'integer.out': int_out_transformer
}
data_type_transformers = {
default_data_type_transformers = {
'boolean': bool_transformer,
'categorical': categorical_transformer
}
get_default_transformer_mock.return_value = datetime_transformer
ht = HyperTransformer(
field_transformers=field_transformers,
data_type_transformers=data_type_transformers
default_data_type_transformers=default_data_type_transformers
)
ht._fit_field_transformer = Mock()
ht._fit_field_transformer.return_value = data
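As the ``test_fit`` docstring above describes, ``fit`` resolves a transformer for each field in a fixed order under the renamed attributes. A simplified, hypothetical sketch of that lookup follows; the helper name ``resolve_transformer`` and the import location of ``get_default_transformer`` are illustrative assumptions, only the attribute names and the lookup order come from the diff.

from rdt.transformers import get_default_transformer  # assumed import location


def resolve_transformer(ht, field, data_type):
    """Illustrative only: mirror the order ``fit`` uses to pick a transformer."""
    if field in ht.field_transformers:
        # An explicit per-field transformer always wins.
        return ht.field_transformers[field]
    if data_type in ht.default_data_type_transformers:
        # Otherwise use the caller's override for this data type.
        return ht.default_data_type_transformers[data_type]
    # Fall back to the library default for the data type.
    return get_default_transformer(data_type)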
