Skip to content

Commit

Permalink
Make private attributes public in the metadata (#1246)
Browse files Browse the repository at this point in the history
* _columns -> columns

* _primary_key -> primary_key

* _alternate_keys -> alternate_keys

* _sequence_key -> sequence_key

* _sequence_index -> sequence_index

* _tables -> tables

* _relationships -> relationships

* Fix keys

* Fix rebase
  • Loading branch information
fealho authored and R-Palazzo committed Feb 13, 2023
1 parent a8af857 commit b95574a
Show file tree
Hide file tree
Showing 26 changed files with 404 additions and 404 deletions.
2 changes: 1 addition & 1 deletion sdv/constraints/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def _validate_metadata_columns(cls, metadata, **kwargs):
else:
column_names = kwargs.get('column_names')

missing_columns = set(column_names) - set(metadata._columns) - {None}
missing_columns = set(column_names) - set(metadata.columns) - {None}
if missing_columns:
article = 'An' if cls.__name__ == 'Inequality' else 'A'
raise ConstraintMetadataError(
Expand Down
32 changes: 16 additions & 16 deletions sdv/constraints/tabular.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,8 +351,8 @@ def _validate_metadata_columns(cls, metadata, **kwargs):
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
high = kwargs.get('high_column_name')
low = kwargs.get('low_column_name')
high_sdtype = metadata._columns.get(high, {}).get('sdtype')
low_sdtype = metadata._columns.get(low, {}).get('sdtype')
high_sdtype = metadata.columns.get(high, {}).get('sdtype')
low_sdtype = metadata.columns.get(low, {}).get('sdtype')
both_datetime = high_sdtype == low_sdtype == 'datetime'
both_numerical = high_sdtype == low_sdtype == 'numerical'
if not (both_datetime or both_numerical) and not (high is None or low is None):
Expand Down Expand Up @@ -520,14 +520,14 @@ def _validate_inputs(cls, **kwargs):
@staticmethod
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
column_name = kwargs.get('column_name')
sdtype = metadata._columns.get(column_name, {}).get('sdtype')
sdtype = metadata.columns.get(column_name, {}).get('sdtype')
value = kwargs.get('value')
if sdtype == 'numerical':
if not isinstance(value, (int, float)):
raise ConstraintMetadataError("'value' must be an int or float.")

elif sdtype == 'datetime':
datetime_format = metadata._columns.get(column_name).get('datetime_format')
datetime_format = metadata.columns.get(column_name).get('datetime_format')
matches_format = matches_datetime_format(value, datetime_format)
if not matches_format:
raise ConstraintMetadataError(
Expand Down Expand Up @@ -685,7 +685,7 @@ class Positive(ScalarInequality):
@staticmethod
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
column_name = kwargs.get('column_name')
sdtype = metadata._columns.get(column_name, {}).get('sdtype')
sdtype = metadata.columns.get(column_name, {}).get('sdtype')
if sdtype != 'numerical':
raise ConstraintMetadataError(
f'A Positive constraint is being applied to an invalid column '
Expand Down Expand Up @@ -714,7 +714,7 @@ class Negative(ScalarInequality):
@staticmethod
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
column_name = kwargs.get('column_name')
sdtype = metadata._columns.get(column_name, {}).get('sdtype')
sdtype = metadata.columns.get(column_name, {}).get('sdtype')
if sdtype != 'numerical':
raise ConstraintMetadataError(
f'A Negative constraint is being applied to an invalid column '
Expand Down Expand Up @@ -759,9 +759,9 @@ def _validate_metadata_specific_to_constraint(metadata, **kwargs):
high = kwargs.get('high_column_name')
low = kwargs.get('low_column_name')
middle = kwargs.get('middle_column_name')
high_sdtype = metadata._columns.get(high, {}).get('sdtype')
low_sdtype = metadata._columns.get(low, {}).get('sdtype')
middle_sdtype = metadata._columns.get(middle, {}).get('sdtype')
high_sdtype = metadata.columns.get(high, {}).get('sdtype')
low_sdtype = metadata.columns.get(low, {}).get('sdtype')
middle_sdtype = metadata.columns.get(middle, {}).get('sdtype')
all_datetime = high_sdtype == low_sdtype == middle_sdtype == 'datetime'
all_numerical = high_sdtype == low_sdtype == middle_sdtype == 'numerical'
if not (all_datetime or all_numerical) and \
Expand Down Expand Up @@ -942,12 +942,12 @@ def _validate_init_inputs(low_value, high_value):
@staticmethod
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
column_name = kwargs.get('column_name')
if column_name not in metadata._columns:
if column_name not in metadata.columns:
raise ConstraintMetadataError(
f'A ScalarRange constraint is being applied to invalid column names '
f'({column_name}). The columns must exist in the table.'
)
sdtype = metadata._columns.get(column_name).get('sdtype')
sdtype = metadata.columns.get(column_name).get('sdtype')
high_value = kwargs.get('high_value')
low_value = kwargs.get('low_value')
if sdtype == 'numerical':
Expand All @@ -957,7 +957,7 @@ def _validate_metadata_specific_to_constraint(metadata, **kwargs):
)

elif sdtype == 'datetime':
datetime_format = metadata._columns.get(column_name, {}).get('datetime_format')
datetime_format = metadata.columns.get(column_name, {}).get('datetime_format')
high_matches_format = matches_datetime_format(high_value, datetime_format)
low_matches_format = matches_datetime_format(low_value, datetime_format)
if not (low_matches_format and high_matches_format):
Expand Down Expand Up @@ -1275,12 +1275,12 @@ def __init__(self, column_names):
def _validate_metadata_specific_to_constraint(metadata, **kwargs):
column_names = kwargs.get('column_names')
keys = set()
if isinstance(metadata._primary_key, tuple):
keys.update(metadata._primary_key)
if isinstance(metadata.primary_key, tuple):
keys.update(metadata.primary_key)
else:
keys.add(metadata._primary_key)
keys.add(metadata.primary_key)

for key in metadata._alternate_keys:
for key in metadata.alternate_keys:
if isinstance(key, tuple):
keys.update(key)
else:
Expand Down
14 changes: 7 additions & 7 deletions sdv/data_processing/data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ def __init__(self, metadata, enforce_rounding=True, enforce_min_max_values=True,
self._dtypes = None
self.fitted = False
self.formatters = {}
self._primary_key = self.metadata._primary_key
self._primary_key = self.metadata.primary_key
self._prepared_for_fitting = False
self._keys = deepcopy(self.metadata._alternate_keys)
self._keys = deepcopy(self.metadata.alternate_keys)
self._keys_generators = {}
if self._primary_key:
self._keys.append(self._primary_key)
Expand Down Expand Up @@ -147,7 +147,7 @@ def get_sdtypes(self, primary_keys=False):
Dictionary that contains the column names and ``sdtypes``.
"""
sdtypes = {}
for name, column_metadata in self.metadata._columns.items():
for name, column_metadata in self.metadata.columns.items():
sdtype = column_metadata['sdtype']

if primary_keys or (name not in self._keys):
Expand Down Expand Up @@ -444,7 +444,7 @@ def _create_config(self, data, columns_created_by_constraints):
transformers = {}

for column in set(data.columns) - columns_created_by_constraints:
column_metadata = self.metadata._columns.get(column)
column_metadata = self.metadata.columns.get(column)
sdtype = column_metadata.get('sdtype')
pii = column_metadata.get('pii', sdtype not in self._DEFAULT_TRANSFORMERS_BY_SDTYPE)
sdtypes[column] = 'pii' if pii else sdtype
Expand Down Expand Up @@ -517,7 +517,7 @@ def _fit_hyper_transformer(self, data):
def _fit_formatters(self, data):
"""Fit ``NumericalFormatter`` and ``DatetimeFormatter`` for each column in the data."""
for column_name in data:
column_metadata = self.metadata._columns.get(column_name)
column_metadata = self.metadata.columns.get(column_name)
sdtype = column_metadata.get('sdtype')
if sdtype == 'numerical' and column_name != self._primary_key:
representation = column_metadata.get('computer_representation', 'Float')
Expand Down Expand Up @@ -741,7 +741,7 @@ def reverse_transform(self, data, reset_keys=False):
sampled_columns = list(reversed_data.columns)
missing_columns = [
column
for column in self.metadata._columns.keys() - set(sampled_columns + self._keys)
for column in self.metadata.columns.keys() - set(sampled_columns + self._keys)
if self._hyper_transformer.field_transformers.get(column)
]
if missing_columns:
Expand All @@ -760,7 +760,7 @@ def reverse_transform(self, data, reset_keys=False):
# And alternate keys. That's the reason for ensuring that the metadata column is within
# the sampled columns.
sampled_columns = [
column for column in self.metadata._columns.keys()
column for column in self.metadata.columns.keys()
if column in sampled_columns
]
for column_name in sampled_columns:
Expand Down
Loading

0 comments on commit b95574a

Please sign in to comment.