From b95574a607a8ad2612340bfdcba93b33c4c5c7a1 Mon Sep 17 00:00:00 2001 From: Felipe Alex Hofmann Date: Fri, 10 Feb 2023 12:14:55 -0800 Subject: [PATCH] Make private attributes public in the metadata (#1246) * _columns -> columns * _primary_key -> primary_key * _alternate_keys -> alternate_keys * _sequence_key -> sequence_key * _sequence_index -> sequence_index * _tables -> tables * _relationships -> relationships * Fix keys * Fix rebase --- sdv/constraints/base.py | 2 +- sdv/constraints/tabular.py | 32 ++-- sdv/data_processing/data_processor.py | 14 +- sdv/metadata/multi_table.py | 86 ++++----- sdv/metadata/single_table.py | 84 ++++---- sdv/multi_table/base.py | 10 +- sdv/multi_table/hma.py | 12 +- sdv/sequential/par.py | 6 +- sdv/single_table/base.py | 22 +-- sdv/single_table/copulagan.py | 4 +- sdv/single_table/copulas.py | 6 +- sdv/single_table/utils.py | 4 +- .../integration/metadata/test_multi_table.py | 4 +- .../integration/metadata/test_single_table.py | 20 +- tests/integration/multi_table/test_hma.py | 4 +- tests/unit/constraints/test_tabular.py | 100 +++++----- .../data_processing/test_data_processor.py | 12 +- tests/unit/metadata/test_multi_table.py | 168 ++++++++-------- tests/unit/metadata/test_single_table.py | 180 +++++++++--------- tests/unit/multi_table/test_base.py | 6 +- tests/unit/multi_table/test_hma.py | 12 +- tests/unit/sequential/test_par.py | 2 +- tests/unit/single_table/test_base.py | 12 +- tests/unit/single_table/test_copulagan.py | 2 +- tests/unit/single_table/test_copulas.py | 2 +- tests/unit/single_table/test_utils.py | 2 +- 26 files changed, 404 insertions(+), 404 deletions(-) diff --git a/sdv/constraints/base.py b/sdv/constraints/base.py index bca03bc7d..82e50927d 100644 --- a/sdv/constraints/base.py +++ b/sdv/constraints/base.py @@ -139,7 +139,7 @@ def _validate_metadata_columns(cls, metadata, **kwargs): else: column_names = kwargs.get('column_names') - missing_columns = set(column_names) - set(metadata._columns) - {None} + missing_columns = set(column_names) - set(metadata.columns) - {None} if missing_columns: article = 'An' if cls.__name__ == 'Inequality' else 'A' raise ConstraintMetadataError( diff --git a/sdv/constraints/tabular.py b/sdv/constraints/tabular.py index 1d2a3b01b..e56cc7925 100644 --- a/sdv/constraints/tabular.py +++ b/sdv/constraints/tabular.py @@ -351,8 +351,8 @@ def _validate_metadata_columns(cls, metadata, **kwargs): def _validate_metadata_specific_to_constraint(metadata, **kwargs): high = kwargs.get('high_column_name') low = kwargs.get('low_column_name') - high_sdtype = metadata._columns.get(high, {}).get('sdtype') - low_sdtype = metadata._columns.get(low, {}).get('sdtype') + high_sdtype = metadata.columns.get(high, {}).get('sdtype') + low_sdtype = metadata.columns.get(low, {}).get('sdtype') both_datetime = high_sdtype == low_sdtype == 'datetime' both_numerical = high_sdtype == low_sdtype == 'numerical' if not (both_datetime or both_numerical) and not (high is None or low is None): @@ -520,14 +520,14 @@ def _validate_inputs(cls, **kwargs): @staticmethod def _validate_metadata_specific_to_constraint(metadata, **kwargs): column_name = kwargs.get('column_name') - sdtype = metadata._columns.get(column_name, {}).get('sdtype') + sdtype = metadata.columns.get(column_name, {}).get('sdtype') value = kwargs.get('value') if sdtype == 'numerical': if not isinstance(value, (int, float)): raise ConstraintMetadataError("'value' must be an int or float.") elif sdtype == 'datetime': - datetime_format = metadata._columns.get(column_name).get('datetime_format') + datetime_format = metadata.columns.get(column_name).get('datetime_format') matches_format = matches_datetime_format(value, datetime_format) if not matches_format: raise ConstraintMetadataError( @@ -685,7 +685,7 @@ class Positive(ScalarInequality): @staticmethod def _validate_metadata_specific_to_constraint(metadata, **kwargs): column_name = kwargs.get('column_name') - sdtype = metadata._columns.get(column_name, {}).get('sdtype') + sdtype = metadata.columns.get(column_name, {}).get('sdtype') if sdtype != 'numerical': raise ConstraintMetadataError( f'A Positive constraint is being applied to an invalid column ' @@ -714,7 +714,7 @@ class Negative(ScalarInequality): @staticmethod def _validate_metadata_specific_to_constraint(metadata, **kwargs): column_name = kwargs.get('column_name') - sdtype = metadata._columns.get(column_name, {}).get('sdtype') + sdtype = metadata.columns.get(column_name, {}).get('sdtype') if sdtype != 'numerical': raise ConstraintMetadataError( f'A Negative constraint is being applied to an invalid column ' @@ -759,9 +759,9 @@ def _validate_metadata_specific_to_constraint(metadata, **kwargs): high = kwargs.get('high_column_name') low = kwargs.get('low_column_name') middle = kwargs.get('middle_column_name') - high_sdtype = metadata._columns.get(high, {}).get('sdtype') - low_sdtype = metadata._columns.get(low, {}).get('sdtype') - middle_sdtype = metadata._columns.get(middle, {}).get('sdtype') + high_sdtype = metadata.columns.get(high, {}).get('sdtype') + low_sdtype = metadata.columns.get(low, {}).get('sdtype') + middle_sdtype = metadata.columns.get(middle, {}).get('sdtype') all_datetime = high_sdtype == low_sdtype == middle_sdtype == 'datetime' all_numerical = high_sdtype == low_sdtype == middle_sdtype == 'numerical' if not (all_datetime or all_numerical) and \ @@ -942,12 +942,12 @@ def _validate_init_inputs(low_value, high_value): @staticmethod def _validate_metadata_specific_to_constraint(metadata, **kwargs): column_name = kwargs.get('column_name') - if column_name not in metadata._columns: + if column_name not in metadata.columns: raise ConstraintMetadataError( f'A ScalarRange constraint is being applied to invalid column names ' f'({column_name}). The columns must exist in the table.' ) - sdtype = metadata._columns.get(column_name).get('sdtype') + sdtype = metadata.columns.get(column_name).get('sdtype') high_value = kwargs.get('high_value') low_value = kwargs.get('low_value') if sdtype == 'numerical': @@ -957,7 +957,7 @@ def _validate_metadata_specific_to_constraint(metadata, **kwargs): ) elif sdtype == 'datetime': - datetime_format = metadata._columns.get(column_name, {}).get('datetime_format') + datetime_format = metadata.columns.get(column_name, {}).get('datetime_format') high_matches_format = matches_datetime_format(high_value, datetime_format) low_matches_format = matches_datetime_format(low_value, datetime_format) if not (low_matches_format and high_matches_format): @@ -1275,12 +1275,12 @@ def __init__(self, column_names): def _validate_metadata_specific_to_constraint(metadata, **kwargs): column_names = kwargs.get('column_names') keys = set() - if isinstance(metadata._primary_key, tuple): - keys.update(metadata._primary_key) + if isinstance(metadata.primary_key, tuple): + keys.update(metadata.primary_key) else: - keys.add(metadata._primary_key) + keys.add(metadata.primary_key) - for key in metadata._alternate_keys: + for key in metadata.alternate_keys: if isinstance(key, tuple): keys.update(key) else: diff --git a/sdv/data_processing/data_processor.py b/sdv/data_processing/data_processor.py index 7f64a2007..419ad4e1f 100644 --- a/sdv/data_processing/data_processor.py +++ b/sdv/data_processing/data_processor.py @@ -103,9 +103,9 @@ def __init__(self, metadata, enforce_rounding=True, enforce_min_max_values=True, self._dtypes = None self.fitted = False self.formatters = {} - self._primary_key = self.metadata._primary_key + self._primary_key = self.metadata.primary_key self._prepared_for_fitting = False - self._keys = deepcopy(self.metadata._alternate_keys) + self._keys = deepcopy(self.metadata.alternate_keys) self._keys_generators = {} if self._primary_key: self._keys.append(self._primary_key) @@ -147,7 +147,7 @@ def get_sdtypes(self, primary_keys=False): Dictionary that contains the column names and ``sdtypes``. """ sdtypes = {} - for name, column_metadata in self.metadata._columns.items(): + for name, column_metadata in self.metadata.columns.items(): sdtype = column_metadata['sdtype'] if primary_keys or (name not in self._keys): @@ -444,7 +444,7 @@ def _create_config(self, data, columns_created_by_constraints): transformers = {} for column in set(data.columns) - columns_created_by_constraints: - column_metadata = self.metadata._columns.get(column) + column_metadata = self.metadata.columns.get(column) sdtype = column_metadata.get('sdtype') pii = column_metadata.get('pii', sdtype not in self._DEFAULT_TRANSFORMERS_BY_SDTYPE) sdtypes[column] = 'pii' if pii else sdtype @@ -517,7 +517,7 @@ def _fit_hyper_transformer(self, data): def _fit_formatters(self, data): """Fit ``NumericalFormatter`` and ``DatetimeFormatter`` for each column in the data.""" for column_name in data: - column_metadata = self.metadata._columns.get(column_name) + column_metadata = self.metadata.columns.get(column_name) sdtype = column_metadata.get('sdtype') if sdtype == 'numerical' and column_name != self._primary_key: representation = column_metadata.get('computer_representation', 'Float') @@ -741,7 +741,7 @@ def reverse_transform(self, data, reset_keys=False): sampled_columns = list(reversed_data.columns) missing_columns = [ column - for column in self.metadata._columns.keys() - set(sampled_columns + self._keys) + for column in self.metadata.columns.keys() - set(sampled_columns + self._keys) if self._hyper_transformer.field_transformers.get(column) ] if missing_columns: @@ -760,7 +760,7 @@ def reverse_transform(self, data, reset_keys=False): # And alternate keys. Thats the reason of ensuring that the metadata column is within # The sampled columns. sampled_columns = [ - column for column in self.metadata._columns.keys() + column for column in self.metadata.columns.keys() if column in sampled_columns ] for column_name in sampled_columns: diff --git a/sdv/metadata/multi_table.py b/sdv/metadata/multi_table.py index dcbbf2734..904b3e0a9 100644 --- a/sdv/metadata/multi_table.py +++ b/sdv/metadata/multi_table.py @@ -22,14 +22,14 @@ class MultiTableMetadata: METADATA_SPEC_VERSION = 'MULTI_TABLE_V1' def __init__(self): - self._tables = {} - self._relationships = [] + self.tables = {} + self.relationships = [] def _validate_missing_relationship_keys(self, parent_table_name, parent_primary_key, child_table_name, child_foreign_key): - parent_table = self._tables.get(parent_table_name) - child_table = self._tables.get(child_table_name) - if parent_table._primary_key is None: + parent_table = self.tables.get(parent_table_name) + child_table = self.tables.get(child_table_name) + if parent_table.primary_key is None: raise InvalidMetadataError( f"The parent table '{parent_table_name}' does not have a primary key set. " "Please use 'set_primary_key' in order to set one." @@ -37,7 +37,7 @@ def _validate_missing_relationship_keys(self, parent_table_name, parent_primary_ missing_keys = set() parent_primary_key = cast_to_iterable(parent_primary_key) - table_primary_keys = set(cast_to_iterable(parent_table._primary_key)) + table_primary_keys = set(cast_to_iterable(parent_table.primary_key)) for key in parent_primary_key: if key not in table_primary_keys: missing_keys.add(key) @@ -49,7 +49,7 @@ def _validate_missing_relationship_keys(self, parent_table_name, parent_primary_ ) for key in set(cast_to_iterable(child_foreign_key)): - if key not in child_table._columns: + if key not in child_table.columns: missing_keys.add(key) if missing_keys: @@ -83,8 +83,8 @@ def _validate_relationship_key_length(parent_table_name, parent_primary_key, def _validate_relationship_sdtypes(self, parent_table_name, parent_primary_key, child_table_name, child_foreign_key): - parent_table_columns = self._tables.get(parent_table_name)._columns - child_table_columns = self._tables.get(child_table_name)._columns + parent_table_columns = self.tables.get(parent_table_name).columns + child_table_columns = self.tables.get(child_table_name).columns parent_primary_key = cast_to_iterable(parent_primary_key) child_foreign_key = cast_to_iterable(child_foreign_key) for pk, fk in zip(parent_primary_key, child_foreign_key): @@ -119,7 +119,7 @@ def _validate_circular_relationships(self, parent, children=None, def _validate_child_map_circular_relationship(self, child_map): errors = [] - for table_name in self._tables.keys(): + for table_name in self.tables.keys(): self._validate_circular_relationships(table_name, child_map=child_map, errors=errors) if errors: @@ -129,7 +129,7 @@ def _validate_child_map_circular_relationship(self, child_map): ) def _validate_foreign_child_key(self, child_table_name, child_foreign_key): - child_primary_key = cast_to_iterable(self._tables[child_table_name]._primary_key) + child_primary_key = cast_to_iterable(self.tables[child_table_name].primary_key) child_foreign_key = cast_to_iterable(child_foreign_key) if set(child_foreign_key).intersection(set(child_primary_key)): raise InvalidMetadataError( @@ -137,7 +137,7 @@ def _validate_foreign_child_key(self, child_table_name, child_foreign_key): def _validate_relationship_does_not_exist(self, parent_table_name, parent_primary_key, child_table_name, child_foreign_key): - for relationship in self._relationships: + for relationship in self.relationships: already_exists = ( relationship['parent_table_name'] == parent_table_name and relationship['parent_primary_key'] == parent_primary_key and @@ -150,7 +150,7 @@ def _validate_relationship_does_not_exist(self, parent_table_name, parent_primar def _validate_relationship(self, parent_table_name, child_table_name, parent_primary_key, child_foreign_key): self._validate_no_missing_tables_in_relationship( - parent_table_name, child_table_name, self._tables.keys()) + parent_table_name, child_table_name, self.tables.keys()) self._validate_missing_relationship_keys( parent_table_name, @@ -176,7 +176,7 @@ def _validate_relationship(self, parent_table_name, child_table_name, def _get_child_map(self): child_map = defaultdict(set) - for relation in self._relationships: + for relation in self.relationships: parent_name = relation['parent_table_name'] child_name = relation['child_table_name'] child_map[parent_name].add(child_name) @@ -223,7 +223,7 @@ def add_relationship(self, parent_table_name, child_table_name, ) self._validate_child_map_circular_relationship(child_map) - self._relationships.append({ + self.relationships.append({ 'parent_table_name': parent_table_name, 'child_table_name': child_table_name, 'parent_primary_key': deepcopy(parent_primary_key), @@ -231,7 +231,7 @@ def add_relationship(self, parent_table_name, child_table_name, }) def _validate_table_exists(self, table_name): - if table_name not in self._tables: + if table_name not in self.tables: raise InvalidMetadataError(f"Unknown table name ('{table_name}').") def add_column(self, table_name, column_name, **kwargs): @@ -253,7 +253,7 @@ def add_column(self, table_name, column_name, **kwargs): - ``InvalidMetadataError`` if the table doesn't exist in the ``MultiTableMetadata``. """ self._validate_table_exists(table_name) - table = self._tables.get(table_name) + table = self.tables.get(table_name) table.add_column(column_name, **kwargs) def update_column(self, table_name, column_name, **kwargs): @@ -275,7 +275,7 @@ def update_column(self, table_name, column_name, **kwargs): - ``InvalidMetadataError`` if the table doesn't exist in the ``MultiTableMetadata``. """ self._validate_table_exists(table_name) - table = self._tables.get(table_name) + table = self.tables.get(table_name) table.update_column(column_name, **kwargs) def add_constraint(self, table_name, constraint_name, **kwargs): @@ -290,11 +290,11 @@ def add_constraint(self, table_name, constraint_name, **kwargs): Any other arguments the constraint requires. """ self._validate_table_exists(table_name) - table = self._tables.get(table_name) + table = self.tables.get(table_name) table.add_constraint(constraint_name, **kwargs) def _validate_table_not_detected(self, table_name): - if table_name in self._tables: + if table_name in self.tables: raise InvalidMetadataError( f"Metadata for table '{table_name}' already exists. Specify a new table name or " 'create a new MultiTableMetadata object for other data sources.' @@ -319,7 +319,7 @@ def detect_table_from_dataframe(self, table_name, data): self._validate_table_not_detected(table_name) table = SingleTableMetadata() table._detect_columns(data) - self._tables[table_name] = table + self.tables[table_name] = table self._log_detected_table(table) def detect_table_from_csv(self, table_name, filepath): @@ -335,7 +335,7 @@ def detect_table_from_csv(self, table_name, filepath): table = SingleTableMetadata() data = table._load_data_from_csv(filepath) table._detect_columns(data) - self._tables[table_name] = table + self.tables[table_name] = table self._log_detected_table(table) def set_primary_key(self, table_name, column_name): @@ -348,7 +348,7 @@ def set_primary_key(self, table_name, column_name): Name (or tuple of names) of the primary key column(s). """ self._validate_table_exists(table_name) - self._tables[table_name].set_primary_key(column_name) + self.tables[table_name].set_primary_key(column_name) def set_sequence_key(self, table_name, column_name): """Set the sequence key of a table. @@ -361,7 +361,7 @@ def set_sequence_key(self, table_name, column_name): """ self._validate_table_exists(table_name) warnings.warn('Sequential modeling is not yet supported on SDV Multi Table models.') - self._tables[table_name].set_sequence_key(column_name) + self.tables[table_name].set_sequence_key(column_name) def add_alternate_keys(self, table_name, column_names): """Set the alternate keys of a table. @@ -373,7 +373,7 @@ def add_alternate_keys(self, table_name, column_names): List of names (or tuple of names) of the alternate key columns. """ self._validate_table_exists(table_name) - self._tables[table_name].add_alternate_keys(column_names) + self.tables[table_name].add_alternate_keys(column_names) def set_sequence_index(self, table_name, column_name): """Set the sequence index of a table. @@ -386,11 +386,11 @@ def set_sequence_index(self, table_name, column_name): """ self._validate_table_exists(table_name) warnings.warn('Sequential modeling is not yet supported on SDV Multi Table models.') - self._tables[table_name].set_sequence_index(column_name) + self.tables[table_name].set_sequence_index(column_name) def _validate_single_table(self, errors): - for table_name, table in self._tables.items(): - if len(table._columns) == 0: + for table_name, table in self.tables.items(): + if len(table.columns) == 0: error_message = ( f"Table '{table_name}' has 0 columns. Use 'add_column' to specify its columns." ) @@ -405,7 +405,7 @@ def _validate_single_table(self, errors): errors.append(error) def _validate_all_tables_connected(self, parent_map, child_map): - nodes = list(self._tables.keys()) + nodes = list(self.tables.keys()) queue = [nodes[0]] connected = {table_name: False for table_name in nodes} @@ -441,7 +441,7 @@ def _append_relationships_errors(self, errors, method, *args, **kwargs): def _get_parent_map(self): parent_map = defaultdict(set) - for relation in self._relationships: + for relation in self.relationships: parent_name = relation['parent_table_name'] child_name = relation['child_table_name'] parent_map[child_name].add(parent_name) @@ -456,7 +456,7 @@ def validate(self): """ errors = [] self._validate_single_table(errors) - for relation in self._relationships: + for relation in self.relationships: self._append_relationships_errors(errors, self._validate_relationship, **relation) parent_map = self._get_parent_map() @@ -487,13 +487,13 @@ def add_table(self, table_name): "Invalid table name (''). The table name must be a non-empty string." ) - if table_name in self._tables: + if table_name in self.tables: raise InvalidMetadataError( f"Cannot add a table named '{table_name}' because it already exists in the " 'metadata. Please choose a different name.' ) - self._tables[table_name] = SingleTableMetadata() + self.tables[table_name] = SingleTableMetadata() def visualize(self, show_table_details=True, show_relationship_labels=True, output_filepath=None): @@ -516,22 +516,22 @@ def visualize(self, show_table_details=True, show_relationship_labels=True, nodes = {} edges = [] if show_table_details: - for table_name, table_meta in self._tables.items(): - column_dict = table_meta._columns.items() + for table_name, table_meta in self.tables.items(): + column_dict = table_meta.columns.items() columns = [f"{name} : {meta.get('sdtype')}" for name, meta in column_dict] nodes[table_name] = { 'columns': r'\l'.join(columns), - 'primary_key': f'Primary key: {table_meta._primary_key}' + 'primary_key': f'Primary key: {table_meta.primary_key}' } else: - nodes = {table_name: None for table_name in self._tables} + nodes = {table_name: None for table_name in self.tables} - for relationship in self._relationships: + for relationship in self.relationships: parent = relationship.get('parent_table_name') child = relationship.get('child_table_name') foreign_key = relationship.get('child_foreign_key') - primary_key = self._tables.get(parent)._primary_key + primary_key = self.tables.get(parent).primary_key edge_label = f' {foreign_key} → {primary_key}' if show_relationship_labels else '' edges.append((parent, child, edge_label)) @@ -559,12 +559,12 @@ def visualize(self, show_table_details=True, show_relationship_labels=True, def to_dict(self): """Return a python ``dict`` representation of the ``MultiTableMetadata``.""" metadata = {'tables': {}, 'relationships': []} - for table_name, single_table_metadata in self._tables.items(): + for table_name, single_table_metadata in self.tables.items(): table_dict = single_table_metadata.to_dict() table_dict.pop('METADATA_SPEC_VERSION', None) metadata['tables'][table_name] = table_dict - metadata['relationships'] = deepcopy(self._relationships) + metadata['relationships'] = deepcopy(self.relationships) metadata['METADATA_SPEC_VERSION'] = self.METADATA_SPEC_VERSION return metadata @@ -576,10 +576,10 @@ def _set_metadata_dict(self, metadata): Python dictionary representing a ``MultiTableMetadata`` object. """ for table_name, table_dict in metadata.get('tables', {}).items(): - self._tables[table_name] = SingleTableMetadata._load_from_dict(table_dict) + self.tables[table_name] = SingleTableMetadata._load_from_dict(table_dict) for relationship in metadata.get('relationships', []): - self._relationships.append(relationship) + self.relationships.append(relationship) @classmethod def _load_from_dict(cls, metadata): diff --git a/sdv/metadata/single_table.py b/sdv/metadata/single_table.py index a73e1daaf..a92b74acd 100644 --- a/sdv/metadata/single_table.py +++ b/sdv/metadata/single_table.py @@ -117,11 +117,11 @@ def _validate_pii(column_name, **kwargs): ) def __init__(self): - self._columns = {} - self._primary_key = None - self._alternate_keys = [] - self._sequence_key = None - self._sequence_index = None + self.columns = {} + self.primary_key = None + self.alternate_keys = [] + self.sequence_key = None + self.sequence_index = None self._version = self.METADATA_SPEC_VERSION def _validate_unexpected_kwargs(self, column_name, sdtype, **kwargs): @@ -173,7 +173,7 @@ def add_column(self, column_name, **kwargs): - ``InvalidMetadataError`` if the ``pii`` value is not ``True`` or ``False`` when present. """ - if column_name in self._columns: + if column_name in self.columns: raise InvalidMetadataError( f"Column name '{column_name}' already exists. Use 'update_column' " 'to update an existing column.' @@ -189,10 +189,10 @@ def add_column(self, column_name, **kwargs): pii = column_kwargs.get('pii', True) column_kwargs['pii'] = pii - self._columns[column_name] = column_kwargs + self.columns[column_name] = column_kwargs def _validate_column_exists(self, column_name): - if column_name not in self._columns: + if column_name not in self.columns: raise InvalidMetadataError( f"Column name ('{column_name}') does not exist in the table. " "Use 'add_column' to add new column." @@ -221,17 +221,17 @@ def update_column(self, column_name, **kwargs): if 'sdtype' in kwargs: sdtype = kwargs.pop('sdtype') else: - sdtype = self._columns[column_name]['sdtype'] + sdtype = self.columns[column_name]['sdtype'] _kwargs['sdtype'] = sdtype self._validate_column(column_name, sdtype, **kwargs) - self._columns[column_name] = _kwargs + self.columns[column_name] = _kwargs def to_dict(self): """Return a python ``dict`` representation of the ``SingleTableMetadata``.""" metadata = {} for key in self._KEYS: - value = getattr(self, f'_{key}') if key != 'METADATA_SPEC_VERSION' else self._version + value = getattr(self, f'{key}') if key != 'METADATA_SPEC_VERSION' else self._version if value: metadata[key] = value @@ -241,7 +241,7 @@ def _detect_columns(self, data): for field in data: clean_data = data[field].dropna() kind = clean_data.infer_objects().dtype.kind - self._columns[field] = {'sdtype': self._DTYPES_TO_SDTYPES.get(kind, 'categorical')} + self.columns[field] = {'sdtype': self._DTYPES_TO_SDTYPES.get(kind, 'categorical')} def detect_from_dataframe(self, data): """Detect the metadata from a ``pd.DataFrame`` object. @@ -252,7 +252,7 @@ def detect_from_dataframe(self, data): data (pandas.DataFrame): ``pandas.DataFrame`` to detect the metadata from. """ - if self._columns: + if self.columns: raise InvalidMetadataError( 'Metadata already exists. Create a new ``SingleTableMetadata`` ' 'object to detect from other data sources.' @@ -275,7 +275,7 @@ def detect_from_csv(self, filepath, pandas_kwargs=None): A python dictionary of with string and value accepted by ``pandas.read_csv`` function. Defaults to ``None``. """ - if self._columns: + if self.columns: raise InvalidMetadataError( 'Metadata already exists. Create a new ``SingleTableMetadata`` ' 'object to detect from other data sources.' @@ -294,7 +294,7 @@ def _validate_keys_sdtype(self, keys, key_type): """Validate that no key is of type 'categorical'.""" bad_sdtypes = ('boolean', 'categorical') categorical_keys = sorted( - {key for key in keys if self._columns[key]['sdtype'] in bad_sdtypes} + {key for key in keys if self.columns[key]['sdtype'] in bad_sdtypes} ) if categorical_keys: raise InvalidMetadataError( @@ -310,7 +310,7 @@ def _validate_key(self, column_name, key_type): f"'{key_type}_key' must be a string or tuple of strings.") keys = {column_name} if isinstance(column_name, str) else set(column_name) - invalid_ids = keys - set(self._columns) + invalid_ids = keys - set(self.columns) if invalid_ids: raise InvalidMetadataError( f'Unknown {key_type} key values {invalid_ids}.' @@ -327,20 +327,20 @@ def set_primary_key(self, column_name): Name (or tuple of names) of the primary key column(s). """ self._validate_key(column_name, 'primary') - if column_name in self._alternate_keys: + if column_name in self.alternate_keys: warnings.warn( f'{column_name} is currently set as an alternate key and will be removed from ' 'that list.' ) - self._alternate_keys.remove(column_name) + self.alternate_keys.remove(column_name) - if self._primary_key is not None: + if self.primary_key is not None: warnings.warn( - f'There is an existing primary key {self._primary_key}.' + f'There is an existing primary key {self.primary_key}.' ' This key will be removed.' ) - self._primary_key = column_name + self.primary_key = column_name def set_sequence_key(self, column_name): """Set the metadata sequence key. @@ -350,13 +350,13 @@ def set_sequence_key(self, column_name): Name (or tuple of names) of the sequence key column(s). """ self._validate_key(column_name, 'sequence') - if self._sequence_key is not None: + if self.sequence_key is not None: warnings.warn( - f'There is an existing sequence key {self._sequence_key}.' + f'There is an existing sequence key {self.sequence_key}.' ' This key will be removed.' ) - self._sequence_key = column_name + self.sequence_key = column_name def _validate_alternate_keys(self, column_names): if not isinstance(column_names, list) or \ @@ -369,16 +369,16 @@ def _validate_alternate_keys(self, column_names): for column_name in column_names: keys.update({column_name} if isinstance(column_name, str) else set(column_name)) - invalid_ids = keys - set(self._columns) + invalid_ids = keys - set(self.columns) if invalid_ids: raise InvalidMetadataError( f'Unknown alternate key values {invalid_ids}.' ' Keys should be columns that exist in the table.' ) - if self._primary_key in column_names: + if self.primary_key in column_names: raise InvalidMetadataError( - f"Invalid alternate key '{self._primary_key}'. The key is " + f"Invalid alternate key '{self.primary_key}'. The key is " 'already specified as a primary key.' ) @@ -393,23 +393,23 @@ def add_alternate_keys(self, column_names): """ self._validate_alternate_keys(column_names) for column in column_names: - if column in self._alternate_keys: + if column in self.alternate_keys: warnings.warn(f'{column} is already an alternate key.') else: - self._alternate_keys.append(column) + self.alternate_keys.append(column) def _validate_sequence_index(self, column_name): if not isinstance(column_name, str): raise InvalidMetadataError("'sequence_index' must be a string.") - if column_name not in self._columns: + if column_name not in self.columns: column_name = {column_name} raise InvalidMetadataError( f'Unknown sequence index value {column_name}.' ' Keys should be columns that exist in the table.' ) - sdtype = self._columns[column_name].get('sdtype') + sdtype = self.columns[column_name].get('sdtype') if sdtype not in ['datetime', 'numerical']: raise InvalidMetadataError( "The sequence_index must be of type 'datetime' or 'numerical'.") @@ -422,14 +422,14 @@ def set_sequence_index(self, column_name): Name of the sequence index column. """ self._validate_sequence_index(column_name) - self._sequence_index = column_name + self.sequence_index = column_name def _validate_sequence_index_not_in_sequence_key(self): """Check that ``_sequence_index`` and ``_sequence_key`` don't overlap.""" - seq_key = self._sequence_key + seq_key = self.sequence_key sequence_key = set(cast_to_iterable(seq_key)) - if self._sequence_index in sequence_key or seq_key is None: - index = {self._sequence_index} + if self.sequence_index in sequence_key or seq_key is None: + index = {self.sequence_index} raise InvalidMetadataError( f"'sequence_index' and 'sequence_key' have the same value {index}." ' These columns must be different.' @@ -450,16 +450,16 @@ def validate(self): """ errors = [] # Validate keys - self._append_error(errors, self._validate_key, self._primary_key, 'primary') - self._append_error(errors, self._validate_key, self._sequence_key, 'sequence') - if self._sequence_index: - self._append_error(errors, self._validate_sequence_index, self._sequence_index) + self._append_error(errors, self._validate_key, self.primary_key, 'primary') + self._append_error(errors, self._validate_key, self.sequence_key, 'sequence') + if self.sequence_index: + self._append_error(errors, self._validate_sequence_index, self.sequence_index) self._append_error(errors, self._validate_sequence_index_not_in_sequence_key) - self._append_error(errors, self._validate_alternate_keys, self._alternate_keys) + self._append_error(errors, self._validate_alternate_keys, self.alternate_keys) # Validate columns - for column, kwargs in self._columns.items(): + for column, kwargs in self.columns.items(): self._append_error(errors, self._validate_column, column, **kwargs) if errors: @@ -499,7 +499,7 @@ def _load_from_dict(cls, metadata): for key in instance._KEYS: value = deepcopy(metadata.get(key)) if value: - setattr(instance, f'_{key}', value) + setattr(instance, f'{key}', value) return instance diff --git a/sdv/multi_table/base.py b/sdv/multi_table/base.py index 3a14d30a3..de8f3e202 100644 --- a/sdv/multi_table/base.py +++ b/sdv/multi_table/base.py @@ -44,7 +44,7 @@ def _set_temp_numpy_seed(self): np.random.set_state(initial_state) def _initialize_models(self): - for table_name, table_metadata in self.metadata._tables.items(): + for table_name, table_metadata in self.metadata.tables.items(): synthesizer_parameters = self._table_parameters.get(table_name, {}) self._table_synthesizers[table_name] = self._synthesizer( metadata=table_metadata, @@ -101,7 +101,7 @@ def set_table_parameters(self, table_name, table_parameters): the table's synthesizer. """ self._table_synthesizers[table_name] = self._synthesizer( - metadata=self.metadata._tables[table_name], + metadata=self.metadata.tables[table_name], **table_parameters ) self._table_parameters[table_name].update(deepcopy(table_parameters)) @@ -112,7 +112,7 @@ def get_metadata(self): def _get_all_foreign_keys(self, table_name): foreign_keys = [] - for relation in self.metadata._relationships: + for relation in self.metadata.relationships: if table_name == relation['child_table_name']: foreign_keys.append(deepcopy(relation['child_foreign_key'])) @@ -121,7 +121,7 @@ def _get_all_foreign_keys(self, table_name): def _validate_foreign_keys(self, data): error_msg = None errors = [] - for relation in self.metadata._relationships: + for relation in self.metadata.relationships: child_table = data.get(relation['child_table_name']) parent_table = data.get(relation['parent_table_name']) @@ -169,7 +169,7 @@ def validate(self, data): * values of a column don't satisfy their sdtype """ errors = [] - missing_tables = set(self.metadata._tables) - set(data) + missing_tables = set(self.metadata.tables) - set(data) if missing_tables: errors.append(f'The provided data is missing the tables {missing_tables}.') diff --git a/sdv/multi_table/hma.py b/sdv/multi_table/hma.py index dd2bd6ce9..f653714c8 100644 --- a/sdv/multi_table/hma.py +++ b/sdv/multi_table/hma.py @@ -30,7 +30,7 @@ def __init__(self, metadata, synthesizer_kwargs=None): self._table_sizes = {} self._max_child_rows = {} self._modeled_tables = [] - for table_name in self.metadata._tables: + for table_name in self.metadata.tables: self.set_table_parameters(table_name, self._synthesizer_kwargs) def _get_extension(self, child_name, child_table, foreign_key): @@ -88,7 +88,7 @@ def _get_extension(self, child_name, child_table, foreign_key): def _get_foreign_keys(self, table_name, child_name): foreign_keys = [] - for relation in self.metadata._relationships: + for relation in self.metadata.relationships: if table_name == relation['parent_table_name'] and\ child_name == relation['child_table_name']: foreign_keys.append(deepcopy(relation['child_foreign_key'])) @@ -314,7 +314,7 @@ def _sample_rows(self, synthesizer, table_name, num_rows=None): def _get_child_synthesizer(self, parent_row, table_name, foreign_key): parameters = self._extract_parameters(parent_row, table_name, foreign_key) - table_meta = self.metadata._tables[table_name] + table_meta = self.metadata.tables[table_name] synthesizer = self._synthesizer(table_meta, **self._synthesizer_kwargs) synthesizer._set_parameters(parameters) @@ -340,7 +340,7 @@ def _sample_child_rows(self, table_name, parent_name, parent_row, sampled_data): table_rows = self._sample_rows(synthesizer, table_name) if len(table_rows): - parent_key = self.metadata._tables[parent_name]._primary_key + parent_key = self.metadata.tables[parent_name].primary_key table_rows[foreign_key] = parent_row[parent_key] previous = sampled_data.get(table_name) @@ -473,7 +473,7 @@ def _find_parent_ids(self, table_name, parent_name, foreign_key, sampled_data): parent_model = self._table_synthesizers[parent_name] parent_rows = self._sample_rows(parent_model, parent_name, num_parent_rows) - primary_key = self.metadata._tables[parent_name]._primary_key + primary_key = self.metadata.tables[parent_name].primary_key parent_rows = parent_rows.set_index(primary_key) num_rows = parent_rows[f'__{table_name}__{foreign_key}__num_rows'].fillna(0).clip(0) @@ -527,7 +527,7 @@ def _sample(self, scale=1.0): A ``NotFittedError`` is raised when the ``SDV`` instance has not been fitted yet. """ sampled_data = {} - for table in self.metadata._tables: + for table in self.metadata.tables: if not self.metadata._get_parent_map().get(table): self._sample_table(table, scale=scale, sampled_data=sampled_data) diff --git a/sdv/sequential/par.py b/sdv/sequential/par.py index 2b35b5d49..4127f9ba8 100644 --- a/sdv/sequential/par.py +++ b/sdv/sequential/par.py @@ -71,7 +71,7 @@ def _get_context_metadata(self): context_columns += self._sequence_key for column in context_columns: - context_columns_dict[column] = self.metadata._columns[column] + context_columns_dict[column] = self.metadata.columns[column] context_metadata_dict = {'columns': context_columns_dict} return SingleTableMetadata._load_from_dict(context_metadata_dict) @@ -85,7 +85,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=False enforce_rounding=enforce_rounding, ) - sequence_key = self.metadata._sequence_key + sequence_key = self.metadata.sequence_key self._sequence_key = list(cast_to_iterable(sequence_key)) if sequence_key else None if context_columns and not self._sequence_key: raise SynthesizerInputError( @@ -93,7 +93,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=False "model 'context_columns' in this case." ) - self._sequence_index = self.metadata._sequence_index + self._sequence_index = self.metadata.sequence_index self.context_columns = context_columns or [] self.segment_size = segment_size self._model_kwargs = { diff --git a/sdv/single_table/base.py b/sdv/single_table/base.py index d195bc282..d7509eaa3 100644 --- a/sdv/single_table/base.py +++ b/sdv/single_table/base.py @@ -77,7 +77,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True) def _validate_metadata_matches_data(self, columns): errors = [] - metadata_columns = self.metadata._columns or [] + metadata_columns = self.metadata.columns or [] missing_data_columns = set(columns).difference(metadata_columns) if missing_data_columns: errors.append( @@ -94,18 +94,18 @@ def _validate_metadata_matches_data(self, columns): raise InvalidDataError(errors) def _get_primary_and_alternate_keys(self): - keys = set(self.metadata._alternate_keys) - if self.metadata._primary_key: - keys.update({self.metadata._primary_key}) + keys = set(self.metadata.alternate_keys) + if self.metadata.primary_key: + keys.update({self.metadata.primary_key}) return keys def _get_set_of_sequence_keys(self): - if isinstance(self.metadata._sequence_key, tuple): - return set(self.metadata._sequence_key) + if isinstance(self.metadata.sequence_key, tuple): + return set(self.metadata.sequence_key) - if isinstance(self.metadata._sequence_key, str): - return {self.metadata._sequence_key} + if isinstance(self.metadata.sequence_key, str): + return {self.metadata.sequence_key} return set() @@ -150,7 +150,7 @@ def _validate_sdtype(self, sdtype, column, validation): def _validate_column(self, column): """Validate values of the column satisfy its sdtype properties.""" errors = [] - sdtype = self.metadata._columns[column.name]['sdtype'] + sdtype = self.metadata.columns[column.name]['sdtype'] # boolean values must be True/False, None or missing values # int/str are not allowed @@ -242,7 +242,7 @@ def _warn_for_update_transformers(self, column_name_to_transformer): Dict mapping column names to transformers to be used for that column. """ for column in column_name_to_transformer: - sdtype = self.metadata._columns[column]['sdtype'] + sdtype = self.metadata.columns[column]['sdtype'] if sdtype in {'categorical', 'boolean'}: warnings.warn( f"Replacing the default transformer for column '{column}' " @@ -354,7 +354,7 @@ def get_transformers(self): field_transformers = { column_name: field_transformers.get(column_name) - for column_name in self.metadata._columns + for column_name in self.metadata.columns } return field_transformers diff --git a/sdv/single_table/copulagan.py b/sdv/single_table/copulagan.py index f98aec237..062929e1e 100644 --- a/sdv/single_table/copulagan.py +++ b/sdv/single_table/copulagan.py @@ -140,7 +140,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, cuda=cuda, ) - validate_numerical_distributions(numerical_distributions, self.metadata._columns) + validate_numerical_distributions(numerical_distributions, self.metadata.columns) self.numerical_distributions = numerical_distributions or {} self.default_distribution = default_distribution or 'beta' @@ -153,7 +153,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, } def _create_gaussian_normalizer_config(self, processed_data): - columns = self.metadata._columns + columns = self.metadata.columns transformers = {} sdtypes = {} for column in processed_data.columns: diff --git a/sdv/single_table/copulas.py b/sdv/single_table/copulas.py index f7d1aa4ad..b8825b458 100644 --- a/sdv/single_table/copulas.py +++ b/sdv/single_table/copulas.py @@ -95,7 +95,7 @@ def __init__(self, metadata, enforce_min_max_values=True, enforce_rounding=True, enforce_min_max_values=enforce_min_max_values, enforce_rounding=enforce_rounding, ) - validate_numerical_distributions(numerical_distributions, self.metadata._columns) + validate_numerical_distributions(numerical_distributions, self.metadata.columns) self.numerical_distributions = numerical_distributions or {} self.default_distribution = default_distribution or 'beta' @@ -139,7 +139,7 @@ def _warn_for_update_transformers(self, column_name_to_transformer): Dict mapping column names to transformers to be used for that column. """ for column, transformer in column_name_to_transformer.items(): - sdtype = self.metadata._columns[column]['sdtype'] + sdtype = self.metadata.columns[column]['sdtype'] if sdtype == 'categorical' and isinstance(transformer, OneHotEncoder): warnings.warn( f"Using a OneHotEncoder transformer for column '{column}' " @@ -166,7 +166,7 @@ def _sample(self, num_rows, conditions=None): def _get_valid_columns_from_metadata(self, columns): valid_columns = [] for column in columns: - for valid_column in self.metadata._columns: + for valid_column in self.metadata.columns: if column.startswith(valid_column): valid_columns.append(column) break diff --git a/sdv/single_table/utils.py b/sdv/single_table/utils.py index 081dca685..69b36a720 100644 --- a/sdv/single_table/utils.py +++ b/sdv/single_table/utils.py @@ -29,8 +29,8 @@ def detect_discrete_columns(metadata, data): discrete_columns = [] for column in data.columns: - if column in metadata._columns: - if metadata._columns[column]['sdtype'] not in ['numerical', 'datetime']: + if column in metadata.columns: + if metadata.columns[column]['sdtype'] not in ['numerical', 'datetime']: discrete_columns.append(column) else: diff --git a/tests/integration/metadata/test_multi_table.py b/tests/integration/metadata/test_multi_table.py index bf12771a4..dd5cb5877 100644 --- a/tests/integration/metadata/test_multi_table.py +++ b/tests/integration/metadata/test_multi_table.py @@ -22,8 +22,8 @@ def test_multi_table_metadata(): 'relationships': [], 'METADATA_SPEC_VERSION': 'MULTI_TABLE_V1' } - assert instance._tables == {} - assert instance._relationships == [] + assert instance.tables == {} + assert instance.relationships == [] def test_upgrade_metadata(): diff --git a/tests/integration/metadata/test_single_table.py b/tests/integration/metadata/test_single_table.py index 8f26fda5d..0ee9465bb 100644 --- a/tests/integration/metadata/test_single_table.py +++ b/tests/integration/metadata/test_single_table.py @@ -24,12 +24,12 @@ def test_single_table_metadata(): assert result == { 'METADATA_SPEC_VERSION': 'SINGLE_TABLE_V1' } - assert instance._columns == {} + assert instance.columns == {} assert instance._version == 'SINGLE_TABLE_V1' - assert instance._primary_key is None - assert instance._sequence_key is None - assert instance._alternate_keys == [] - assert instance._sequence_index is None + assert instance.primary_key is None + assert instance.sequence_key is None + assert instance.alternate_keys == [] + assert instance.sequence_index is None def test_validate(): @@ -54,7 +54,7 @@ def test_validate_errors(): """Test ``SingleTableMetadata.validate`` raises the correct errors.""" # Setup instance = SingleTableMetadata() - instance._columns = { + instance.columns = { 'col1': {'sdtype': 'numerical'}, 'col2': {'sdtype': 'numerical'}, 'col4': {'sdtype': 'categorical', 'invalid1': 'value'}, @@ -65,10 +65,10 @@ def test_validate_errors(): 'col9': {'sdtype': 'datetime', 'datetime_format': '%1-%Y-%m-%d-%'}, 'col10': {'sdtype': 'text', 'regex_format': '[A-{6}'}, } - instance._primary_key = 10 - instance._alternate_keys = 'col1' - instance._sequence_key = ('col3', 'col1') - instance._sequence_index = 'col3' + instance.primary_key = 10 + instance.alternate_keys = 'col1' + instance.sequence_key = ('col3', 'col1') + instance.sequence_index = 'col3' err_msg = re.escape( 'The following errors were found in the metadata:' diff --git a/tests/integration/multi_table/test_hma.py b/tests/integration/multi_table/test_hma.py index bd3d8654b..a0d3f17aa 100644 --- a/tests/integration/multi_table/test_hma.py +++ b/tests/integration/multi_table/test_hma.py @@ -53,8 +53,8 @@ def test_hma_reset_sampling(tmpdir): sdtype='ssn', ) data['characters']['ssn'] = [faker.lexify() for _ in range(len(data['characters']))] - for table in metadata._tables.values(): - table._alternate_keys = [] + for table in metadata.tables.values(): + table.alternate_keys = [] hmasynthesizer = HMASynthesizer(metadata) diff --git a/tests/unit/constraints/test_tabular.py b/tests/unit/constraints/test_tabular.py index 189c432ee..c8930a3b6 100644 --- a/tests/unit/constraints/test_tabular.py +++ b/tests/unit/constraints/test_tabular.py @@ -163,7 +163,7 @@ def test__validate_metadata_columns(self): # Setup constraint_class = create_custom_constraint_class(sorted, sorted, sorted) metadata = Mock() - metadata._columns = {'a': 1, 'b': 2} + metadata.columns = {'a': 1, 'b': 2} # Run constraint_class._validate_metadata_columns(metadata, column_names=['a', 'b']) @@ -185,7 +185,7 @@ def test__validate_metadata_columns_raises_error(self): # Setup constraint_class = create_custom_constraint_class(sorted, sorted, sorted) metadata = Mock() - metadata._columns = {'a': 1, 'b': 2} + metadata.columns = {'a': 1, 'b': 2} # Run error_message = re.escape( @@ -504,7 +504,7 @@ def test__validate_metadata_columns(self): """ # Setup metadata = Mock() - metadata._columns = {'a': 1, 'b': 2} + metadata.columns = {'a': 1, 'b': 2} # Run FixedCombinations._validate_metadata_columns(metadata, column_names=['a', 'b']) @@ -522,7 +522,7 @@ def test__validate_metadata_columns_raises_error(self): """ # Setup metadata = Mock() - metadata._columns = {'a': 1, 'b': 2} + metadata.columns = {'a': 1, 'b': 2} # Run error_message = re.escape( @@ -942,7 +942,7 @@ def test__validate_metadata_columns(self): """ # Setup metadata = Mock() - metadata._columns = {'a': 1, 'b': 2} + metadata.columns = {'a': 1, 'b': 2} # Run Inequality._validate_metadata_columns(metadata, low_column_name='a', high_column_name='b') @@ -960,7 +960,7 @@ def test__validate_metadata_columns_raises_error(self): """ # Setup metadata = Mock() - metadata._columns = {'a': 1, 'b': 2} + metadata.columns = {'a': 1, 'b': 2} # Run error_message = re.escape( @@ -982,7 +982,7 @@ def test__validate_metadata_specific_to_constraint_datetime(self): """ # Setup metadata = Mock() - metadata._columns = {'a': {'sdtype': 'datetime'}, 'b': {'sdtype': 'datetime'}} + metadata.columns = {'a': {'sdtype': 'datetime'}, 'b': {'sdtype': 'datetime'}} # Run Inequality._validate_metadata_specific_to_constraint( @@ -1002,7 +1002,7 @@ def test__validate_metadata_specific_to_constraint_datetime_error(self): """ # Setup metadata = Mock() - metadata._columns = {'a': {'sdtype': 'datetime'}, 'b': {'sdtype': 'categorical'}} + metadata.columns = {'a': {'sdtype': 'datetime'}, 'b': {'sdtype': 'categorical'}} # Run error_message = re.escape( @@ -1027,7 +1027,7 @@ def test__validate_metadata_specific_to_constraint_numerical(self): """ # Setup metadata = Mock() - metadata._columns = {'a': {'sdtype': 'numerical'}, 'b': {'sdtype': 'numerical'}} + metadata.columns = {'a': {'sdtype': 'numerical'}, 'b': {'sdtype': 'numerical'}} # Run Inequality._validate_metadata_specific_to_constraint( @@ -1047,7 +1047,7 @@ def test__validate_metadata_specific_to_constraint_numerical_error(self): """ # Setup metadata = Mock() - metadata._columns = {'a': {'sdtype': 'numerical'}, 'b': {'sdtype': 'categorical'}} + metadata.columns = {'a': {'sdtype': 'numerical'}, 'b': {'sdtype': 'categorical'}} # Run error_message = re.escape( @@ -1642,7 +1642,7 @@ def test__validate_metadata_columns(self): """ # Setup metadata = Mock() - metadata._columns = {'a': 1, 'b': 2} + metadata.columns = {'a': 1, 'b': 2} # Run ScalarInequality._validate_metadata_columns(metadata, column_name='a') @@ -1660,7 +1660,7 @@ def test__validate_metadata_columns_raises_error(self): """ # Setup metadata = Mock() - metadata._columns = {'a': 1, 'b': 2} + metadata.columns = {'a': 1, 'b': 2} # Run error_message = re.escape( @@ -1684,7 +1684,7 @@ def test__validate_metadata_specific_to_constraint_numerical(self): """ # Setup metadata = Mock() - metadata._columns = {'a': {'sdtype': 'numerical'}} + metadata.columns = {'a': {'sdtype': 'numerical'}} # Run ScalarInequality._validate_metadata_specific_to_constraint( @@ -1708,7 +1708,7 @@ def test__validate_metadata_specific_to_constraint_numerical_error(self): """ # Setup metadata = Mock() - metadata._columns = {'a': {'sdtype': 'numerical'}} + metadata.columns = {'a': {'sdtype': 'numerical'}} # Run error_message = "'value' must be an int or float." @@ -1736,7 +1736,7 @@ def test__validate_metadata_specific_to_constraint_datetime(self, datetime_forma """ # Setup metadata = Mock() - metadata._columns = {'a': {'sdtype': 'datetime', 'datetime_format': 'm/d/y'}} + metadata.columns = {'a': {'sdtype': 'datetime', 'datetime_format': 'm/d/y'}} datetime_format_mock.return_value = True # Run @@ -1766,7 +1766,7 @@ def test__validate_metadata_specific_to_constraint_datetime_error(self, datetime """ # Setup metadata = Mock() - metadata._columns = {'a': {'sdtype': 'datetime', 'datetime_format': 'm/d/y'}} + metadata.columns = {'a': {'sdtype': 'datetime', 'datetime_format': 'm/d/y'}} datetime_format_mock.return_value = False # Run @@ -1796,7 +1796,7 @@ def test__validate_metadata_specific_to_constraint_bad_type(self): """ # Setup metadata = Mock() - metadata._columns = {'a': {'sdtype': 'categorical'}} + metadata.columns = {'a': {'sdtype': 'categorical'}} # Run error_message = ( @@ -2352,7 +2352,7 @@ def test__validate_metadata_columns(self): """ # Setup metadata = Mock() - metadata._columns = {'a': 1, 'b': 2} + metadata.columns = {'a': 1, 'b': 2} # Run Positive._validate_metadata_columns(metadata, column_name='a') @@ -2370,7 +2370,7 @@ def test__validate_metadata_columns_raises_error(self): """ # Setup metadata = Mock() - metadata._columns = {'a': 1, 'b': 2} + metadata.columns = {'a': 1, 'b': 2} # Run error_message = re.escape( @@ -2391,7 +2391,7 @@ def test__validate_metadata_specific_to_constraint(self): """ # Setup metadata = Mock() - metadata._columns = { + metadata.columns = { 'a': {'sdtype': 'numerical'} } @@ -2415,7 +2415,7 @@ def test__validate_metadata_specific_to_constraint_error(self): """ # Setup metadata = Mock() - metadata._columns = { + metadata.columns = { 'a': {'sdtype': 'datetime'} } @@ -2493,7 +2493,7 @@ def test__validate_metadata_columns(self): """ # Setup metadata = Mock() - metadata._columns = {'a': 1, 'b': 2} + metadata.columns = {'a': 1, 'b': 2} # Run Negative._validate_metadata_columns(metadata, column_name='a') @@ -2511,7 +2511,7 @@ def test__validate_metadata_columns_raises_error(self): """ # Setup metadata = Mock() - metadata._columns = {'a': 1, 'b': 2} + metadata.columns = {'a': 1, 'b': 2} # Run error_message = re.escape( @@ -2532,7 +2532,7 @@ def test__validate_metadata_specific_to_constraint(self): """ # Setup metadata = Mock() - metadata._columns = { + metadata.columns = { 'a': {'sdtype': 'numerical'} } @@ -2556,7 +2556,7 @@ def test__validate_metadata_specific_to_constraint_error(self): """ # Setup metadata = Mock() - metadata._columns = { + metadata.columns = { 'a': {'sdtype': 'datetime'} } @@ -2637,7 +2637,7 @@ def test__validate_metadata_columns(self): """ # Setup metadata = Mock() - metadata._columns = {'a': 1, 'b': 2, 'c': 3} + metadata.columns = {'a': 1, 'b': 2, 'c': 3} # Run Range._validate_metadata_columns( @@ -2660,7 +2660,7 @@ def test__validate_metadata_columns_raises_error(self): """ # Setup metadata = Mock() - metadata._columns = {'a': 1, 'b': 2} + metadata.columns = {'a': 1, 'b': 2} # Run error_message = re.escape( @@ -2686,7 +2686,7 @@ def test__validate_metadata_specific_to_constraint_datetime(self): """ # Setup metadata = Mock() - metadata._columns = { + metadata.columns = { 'a': {'sdtype': 'datetime'}, 'b': {'sdtype': 'datetime'}, 'c': {'sdtype': 'datetime'} @@ -2712,7 +2712,7 @@ def test__validate_metadata_specific_to_constraint_datetime_error(self): """ # Setup metadata = Mock() - metadata._columns = { + metadata.columns = { 'a': {'sdtype': 'datetime'}, 'b': {'sdtype': 'datetime'}, 'c': {'sdtype': 'numerical'} @@ -2742,7 +2742,7 @@ def test__validate_metadata_specific_to_constraint_numerical(self): """ # Setup metadata = Mock() - metadata._columns = { + metadata.columns = { 'a': {'sdtype': 'numerical'}, 'b': {'sdtype': 'numerical'}, 'c': {'sdtype': 'numerical'} @@ -2768,7 +2768,7 @@ def test__validate_metadata_specific_to_constraint_numerical_error(self): """ # Setup metadata = Mock() - metadata._columns = { + metadata.columns = { 'a': {'sdtype': 'numerical'}, 'b': {'sdtype': 'numerical'}, 'c': {'sdtype': 'datetime'} @@ -3240,7 +3240,7 @@ def test__validate_metadata_columns(self): """ # Setup metadata = Mock() - metadata._columns = {'a': 1, 'b': 2} + metadata.columns = {'a': 1, 'b': 2} # Run ScalarRange._validate_metadata_columns(metadata, column_name='a') @@ -3258,7 +3258,7 @@ def test__validate_metadata_columns_raises_error(self): """ # Setup metadata = Mock() - metadata._columns = {'a': 1, 'b': 2} + metadata.columns = {'a': 1, 'b': 2} # Run error_message = re.escape( @@ -3282,7 +3282,7 @@ def test__validate_metadata_specific_to_constraint_numerical(self): """ # Setup metadata = Mock() - metadata._columns = {'a': {'sdtype': 'numerical'}} + metadata.columns = {'a': {'sdtype': 'numerical'}} # Run ScalarRange._validate_metadata_specific_to_constraint( @@ -3306,7 +3306,7 @@ def test__validate_metadata_specific_to_constraint_numerical_high_not_numerical_ """ # Setup metadata = Mock() - metadata._columns = {'a': {'sdtype': 'numerical'}} + metadata.columns = {'a': {'sdtype': 'numerical'}} # Run error_message = "Both 'high_value' and 'low_value' must be ints or floats" @@ -3332,7 +3332,7 @@ def test__validate_metadata_specific_to_constraint_numerical_low_not_numerical_e """ # Setup metadata = Mock() - metadata._columns = {'a': {'sdtype': 'numerical'}} + metadata.columns = {'a': {'sdtype': 'numerical'}} # Run error_message = "Both 'high_value' and 'low_value' must be ints or floats" @@ -3361,7 +3361,7 @@ def test__validate_metadata_specific_to_constraint_datetime(self, datetime_forma """ # Setup metadata = Mock() - metadata._columns = {'a': {'sdtype': 'datetime', 'datetime_format': 'm/d/y'}} + metadata.columns = {'a': {'sdtype': 'datetime', 'datetime_format': 'm/d/y'}} datetime_format_mock.return_value = True # Run @@ -3393,7 +3393,7 @@ def test__validate_metadata_specific_to_constraint_high_datetime_error( """ # Setup metadata = Mock() - metadata._columns = {'a': {'sdtype': 'datetime', 'datetime_format': 'm/d/y'}} + metadata.columns = {'a': {'sdtype': 'datetime', 'datetime_format': 'm/d/y'}} datetime_format_mock.side_effect = [False, True] # Run @@ -3429,7 +3429,7 @@ def test__validate_metadata_specific_to_constraint_low_datetime_error( """ # Setup metadata = Mock() - metadata._columns = {'a': {'sdtype': 'datetime', 'datetime_format': 'm/d/y'}} + metadata.columns = {'a': {'sdtype': 'datetime', 'datetime_format': 'm/d/y'}} datetime_format_mock.side_effect = [True, False] # Run @@ -3461,7 +3461,7 @@ def test__validate_metadata_specific_to_constraint_bad_type(self): """ # Setup metadata = Mock() - metadata._columns = {'a': {'sdtype': 'categorical'}} + metadata.columns = {'a': {'sdtype': 'categorical'}} # Run error_message = ( @@ -3946,7 +3946,7 @@ def test__validate_metadata_columns(self): """ # Setup metadata = Mock() - metadata._columns = {'a': 1, 'b': 2} + metadata.columns = {'a': 1, 'b': 2} # Run OneHotEncoding._validate_metadata_columns(metadata, column_names=['a', 'b']) @@ -3964,7 +3964,7 @@ def test__validate_metadata_columns_raises_error(self): """ # Setup metadata = Mock() - metadata._columns = {'a': 1, 'b': 2} + metadata.columns = {'a': 1, 'b': 2} # Run error_message = re.escape( @@ -4069,7 +4069,7 @@ def test__validate_metadata_columns(self): """ # Setup metadata = Mock() - metadata._columns = {'a': 1, 'b': 2} + metadata.columns = {'a': 1, 'b': 2} # Run Unique._validate_metadata_columns(metadata, column_names=['a', 'b']) @@ -4087,7 +4087,7 @@ def test__validate_metadata_columns_raises_error(self): """ # Setup metadata = Mock() - metadata._columns = {'a': 1, 'b': 2} + metadata.columns = {'a': 1, 'b': 2} # Run error_message = re.escape( @@ -4109,8 +4109,8 @@ def test__validate_metadata_specific_to_constraint(self): """ # Setup metadata = Mock() - metadata._primary_key = 'a' - metadata._alternate_keys = ['b', 'c'] + metadata.primary_key = 'a' + metadata.alternate_keys = ['b', 'c'] # Run Unique._validate_metadata_specific_to_constraint( @@ -4130,8 +4130,8 @@ def test__validate_metadata_specific_to_constraint_error(self): """ # Setup metadata = Mock() - metadata._primary_key = 'a' - metadata._alternate_keys = [('b', 'c'), 'd'] + metadata.primary_key = 'a' + metadata.alternate_keys = [('b', 'c'), 'd'] # Run error_message = re.escape( @@ -4432,7 +4432,7 @@ def test__validate_metadata_columns(self): """ # Setup metadata = Mock() - metadata._columns = {'a': 1, 'b': 2} + metadata.columns = {'a': 1, 'b': 2} # Run FixedIncrements._validate_metadata_columns(metadata, column_name='a') @@ -4450,7 +4450,7 @@ def test__validate_metadata_columns_raises_error(self): """ # Setup metadata = Mock() - metadata._columns = {'a': 1, 'b': 2} + metadata.columns = {'a': 1, 'b': 2} # Run error_message = re.escape( diff --git a/tests/unit/data_processing/test_data_processor.py b/tests/unit/data_processing/test_data_processor.py index 04390b779..779bc226a 100644 --- a/tests/unit/data_processing/test_data_processor.py +++ b/tests/unit/data_processing/test_data_processor.py @@ -111,7 +111,7 @@ def test___init___without_mocks(self): # Assert assert isinstance(instance.metadata, SingleTableMetadata) - assert instance.metadata._columns == {'col': {'sdtype': 'numerical'}} + assert instance.metadata.columns == {'col': {'sdtype': 'numerical'}} def test_filter_valid(self): """Test that we are calling the ``filter_valid`` of each constraint over the data.""" @@ -1137,10 +1137,10 @@ def test__create_config(self): dp.create_key_transformer = Mock() dp.create_anonymized_transformer.return_value = 'AnonymizedFaker' dp.create_key_transformer.return_value = 'RegexGenerator' - dp.metadata._primary_key = 'id' + dp.metadata.primary_key = 'id' dp._primary_key = 'id' dp._keys = ['id'] - dp.metadata._columns = { + dp.metadata.columns = { 'int': {'sdtype': 'numerical'}, 'float': {'sdtype': 'numerical'}, 'bool': {'sdtype': 'boolean'}, @@ -1914,7 +1914,7 @@ def test_reverse_transform(self): dp = DataProcessor(SingleTableMetadata()) dp.fitted = True dp.metadata = Mock() - dp.metadata._columns = {'a': None, 'b': None, 'c': None, 'd': None} + dp.metadata.columns = {'a': None, 'b': None, 'c': None, 'd': None} data = pd.DataFrame({ 'a': [1, 2, 3], 'b': [True, True, False], @@ -1980,7 +1980,7 @@ def test_reverse_transform_hyper_transformer_errors(self, log_mock): dp = DataProcessor(SingleTableMetadata(), table_name='table_name') dp.fitted = True dp.metadata = Mock() - dp.metadata._columns = {'a': None, 'b': None, 'c': None} + dp.metadata.columns = {'a': None, 'b': None, 'c': None} data = pd.DataFrame({ 'a': [1, 2, 3], 'b': [True, True, False], @@ -2060,7 +2060,7 @@ def test_reverse_transform_integer_rounding(self): dp._constraints_to_reverse = [] dp._dtypes = {'bar': 'int'} dp.metadata = Mock() - dp.metadata._columns = {'bar': None} + dp.metadata.columns = {'bar': None} # Run output = dp.reverse_transform(data) diff --git a/tests/unit/metadata/test_multi_table.py b/tests/unit/metadata/test_multi_table.py index 0b3555a0a..dc3d25a8b 100644 --- a/tests/unit/metadata/test_multi_table.py +++ b/tests/unit/metadata/test_multi_table.py @@ -83,8 +83,8 @@ def test___init__(self): instance = MultiTableMetadata() # Assert - assert instance._tables == {} - assert instance._relationships == [] + assert instance.tables == {} + assert instance.relationships == [] def test__validate_missing_relationship_keys_foreign_key(self): """Test the ``_validate_missing_relationship_keys`` method of ``MultiTableMetadata``. @@ -111,8 +111,8 @@ def test__validate_missing_relationship_keys_foreign_key(self): """ # Setup parent_table = Mock() - parent_table._primary_key = 'id' - parent_table._columns = { + parent_table.primary_key = 'id' + parent_table.columns = { 'id': {'sdtype': 'numerical'}, 'session': {'sdtype': 'numerical'}, 'transactions': {'sdtype': 'numerical'}, @@ -121,8 +121,8 @@ def test__validate_missing_relationship_keys_foreign_key(self): parent_primary_key = 'id' child_table = Mock() - child_table._primary_key = 'session_id' - child_table._columns = { + child_table.primary_key = 'session_id' + child_table.columns = { 'user_id': {'sdtype': 'numerical'}, 'session_id': {'sdtype': 'numerical'}, 'transactions': {'sdtype': 'numerical'}, @@ -131,7 +131,7 @@ def test__validate_missing_relationship_keys_foreign_key(self): child_foreign_key = 'id' instance = Mock() - instance._tables = { + instance.tables = { 'users': parent_table, 'sessions': child_table, } @@ -154,7 +154,7 @@ def test__validate_missing_relationship_keys_primary_key(self): """Test the ``_validate_missing_relationship_keys`` method of ``MultiTableMetadata``. Test that when the provided ``child_foreign_key`` key is not in the - ``parent_table._columns``, this raises an error. + ``parent_table.columns``, this raises an error. Setup: - Create ``parent_table``. @@ -168,14 +168,14 @@ def test__validate_missing_relationship_keys_primary_key(self): - ``parent_table_name`` string. - ``parent_primary_key`` a string that is the parent primary key. - ``child_table_name`` string. - - ``child_foreign_key`` a string that is not in the ``parent_table._columns``. + - ``child_foreign_key`` a string that is not in the ``parent_table.columns``. Side Effects: - Raises ``InvalidMetadataError`` stating that primary key is unknown. """ # Setup parent_table = Mock() - parent_table._primary_key = 'users_id' + parent_table.primary_key = 'users_id' parent_table_name = 'users' parent_primary_key = 'primary_key' child_table_name = 'sessions' @@ -279,8 +279,8 @@ def test__validate_relationship_sdtype(self): """ # Setup parent_table = Mock() - parent_table._primary_key = 'id' - parent_table._columns = { + parent_table.primary_key = 'id' + parent_table.columns = { 'id': {'sdtype': 'numerical'}, 'user_name': {'sdtype': 'categorical'}, 'transactions': {'sdtype': 'numerical'}, @@ -289,8 +289,8 @@ def test__validate_relationship_sdtype(self): parent_primary_key = ['id', 'user_name'] child_table = Mock() - child_table._primary_key = 'session_id' - child_table._columns = { + child_table.primary_key = 'session_id' + child_table.columns = { 'user_id': {'sdtype': 'numerical'}, 'session_id': {'sdtype': 'numerical'}, 'timestamp': {'sdtype': 'datetime'}, @@ -299,7 +299,7 @@ def test__validate_relationship_sdtype(self): child_foreign_key = ['user_id', 'timestamp'] instance = Mock() - instance._tables = { + instance.tables = { 'users': parent_table, 'sessions': child_table, } @@ -322,7 +322,7 @@ def test__validate_relationship_does_not_exist(self): """Test the method raises an error if an existing relationship is added.""" # Setup metadata = MultiTableMetadata() - metadata._relationships = [ + metadata.relationships = [ { 'parent_table_name': 'users', 'child_table_name': 'sessions', @@ -402,7 +402,7 @@ def test__validate_child_map_circular_relationship(self): # Setup instance = MultiTableMetadata() parent_table = Mock() - instance._tables = { + instance.tables = { 'users': parent_table, 'sessions': Mock(), 'transactions': Mock() @@ -447,26 +447,26 @@ def test__validate_relationship(self, instance = MultiTableMetadata() parent_table = Mock() - parent_table._primary_key = 'id' - parent_table._columns = { + parent_table.primary_key = 'id' + parent_table.columns = { 'id': {'sdtype': 'numerical'}, 'session': {'sdtype': 'numerical'}, 'transactions': {'sdtype': 'numerical'}, } child_table = Mock() - child_table._primary_key = 'session_id' - child_table._columns = { + child_table.primary_key = 'session_id' + child_table.columns = { 'user_id': {'sdtype': 'numerical'}, 'session_id': {'sdtype': 'numerical'}, 'transactions': {'sdtype': 'numerical'}, } - instance._tables = { + instance.tables = { 'users': parent_table, 'sessions': child_table, } - instance._relationships = [ + instance.relationships = [ { 'parent_table_name': 'users', 'child_table_name': 'sessions', @@ -483,7 +483,7 @@ def test__validate_relationship(self, # Assert mock_validate_no_missing_tables_in_relationship.assert_called_once_with( - 'users', 'sessions', instance._tables.keys()) + 'users', 'sessions', instance.tables.keys()) instance._validate_missing_relationship_keys.assert_called_once_with( 'users', 'id', 'sessions', 'user_id') mock_validate_relationship_key_length.assert_called_once_with( @@ -495,13 +495,13 @@ def test_add_relationship(self): """Test the ``add_relationship`` method of ``MultiTableMetadata``. Test that when passing a valid ``relationship`` this is being added to the - ``instance._relationships``. + ``instance.relationships``. Setup: - Instance of ``MultiTableMetadata``. - Mock of ``parent_table`` simulating a ``SingleTableMetadata``. - Mock of ``child_table`` simulating a ``SingleTableMetadata``. - - Add those to ``instance._tables``. + - Add those to ``instance.tables``. Mock: - Mock ``validate_child_map_circular_relationship``. @@ -513,29 +513,29 @@ def test_add_relationship(self): - ``child_foreign_key`` string representing the ``foreing_Key`` of the child table. Side Effects: - - ``instance._relationships`` has been updated. + - ``instance.relationships`` has been updated. """ # Setup instance = MultiTableMetadata() instance._validate_child_map_circular_relationship = Mock() instance._validate_relationship_does_not_exist = Mock() parent_table = Mock() - parent_table._primary_key = 'id' - parent_table._columns = { + parent_table.primary_key = 'id' + parent_table.columns = { 'id': {'sdtype': 'numerical'}, 'session': {'sdtype': 'numerical'}, 'transactions': {'sdtype': 'numerical'}, } child_table = Mock() - child_table._primary_key = 'session_id' - child_table._columns = { + child_table.primary_key = 'session_id' + child_table.columns = { 'user_id': {'sdtype': 'numerical'}, 'session_id': {'sdtype': 'numerical'}, 'transactions': {'sdtype': 'numerical'}, } - instance._tables = { + instance.tables = { 'users': parent_table, 'sessions': child_table, } @@ -544,7 +544,7 @@ def test_add_relationship(self): instance.add_relationship('users', 'sessions', 'id', 'user_id') # Assert - instance._relationships == [ + instance.relationships == [ { 'parent_table_name': 'users', 'child_table_name': 'sessions', @@ -580,7 +580,7 @@ def test_add_relationship_child_key_is_primary_key(self): def test__validate_single_table(self): """Test ``_validate_single_table``. - Test that ``_validate_single_table`` iterates over the ``self._tables`` items and + Test that ``_validate_single_table`` iterates over the ``self.tables`` items and calls their ``validate()`` method, catches the error if raised and parses it to ``MultiTableMetadata`` error message. @@ -605,8 +605,8 @@ def test__validate_single_table(self): instance = Mock() users_mock = Mock() - users_mock._columns = {} - instance._tables = { + users_mock.columns = {} + instance.tables = { 'accounts': table_accounts, 'users': users_mock } @@ -624,7 +624,7 @@ def test__validate_single_table(self): "Table 'users' has 0 columns. Use 'add_column' to specify its columns." ) assert errors == ['\n', expected_error_msg, empty_table_error_message] - instance._tables['users'].validate.assert_called_once() + instance.tables['users'].validate.assert_called_once() def test__validate_all_tables_connected_connected(self): """Test ``_validate_all_tables_connected``. @@ -646,7 +646,7 @@ def test__validate_all_tables_connected_connected(self): """ # Setup instance = Mock() - instance._tables = { + instance.tables = { 'users': Mock(), 'sessions': Mock(), 'transactions': Mock(), @@ -700,7 +700,7 @@ def test__validate_all_tables_connected_not_connected(self): """ # Setup instance = Mock() - instance._tables = { + instance.tables = { 'users': Mock(), 'sessions': Mock(), 'transactions': Mock(), @@ -755,7 +755,7 @@ def test__validate_all_tables_connected_multiple_not_connected(self): """ # Setup instance = Mock() - instance._tables = { + instance.tables = { 'users': Mock(), 'sessions': Mock(), 'transactions': Mock(), @@ -815,11 +815,11 @@ def test_validate_raises_errors(self): """ # Setup instance = self.get_metadata() - instance._tables['users']._primary_key = None - instance._tables['transactions']._columns['session_id']['sdtype'] = 'datetime' - instance._tables['payments']._columns['date']['sdtype'] = 'text' - instance._tables['payments']._columns['date']['regex_format'] = '[A-z{' - instance._relationships.pop(-1) + instance.tables['users'].primary_key = None + instance.tables['transactions'].columns['session_id']['sdtype'] = 'datetime' + instance.tables['payments'].columns['date']['sdtype'] = 'text' + instance.tables['payments'].columns['date']['regex_format'] = '[A-z{' + instance.relationships.pop(-1) # Run error_msg = re.escape( @@ -853,7 +853,7 @@ def test_validate_child_key_is_primary_key(self): metadata.detect_table_from_dataframe('table2', table) metadata.set_primary_key('table2', 'pk') - metadata._relationships = [ + metadata.relationships = [ { 'parent_table_name': 'table', 'parent_primary_key': 'pk', @@ -873,7 +873,7 @@ def test_validate_child_key_is_primary_key(self): @patch('sdv.metadata.multi_table.SingleTableMetadata') def test_add_table(self, table_metadata_mock): - """Test that the method adds the table name to ``instance._tables``.""" + """Test that the method adds the table name to ``instance.tables``.""" # Setup instance = MultiTableMetadata() @@ -881,7 +881,7 @@ def test_add_table(self, table_metadata_mock): instance.add_table('users') # Assert - assert instance._tables == {'users': table_metadata_mock.return_value} + assert instance.tables == {'users': table_metadata_mock.return_value} def test_add_table_empty_string(self): """Test that the method raises an error if the table name is an empty string.""" @@ -911,7 +911,7 @@ def test_add_table_table_already_exists(self): """Test that the method raises an error if the table already exists.""" # Setup instance = MultiTableMetadata() - instance._tables = {'users': Mock()} + instance.tables = {'users': Mock()} # Run and Assert error_message = re.escape( @@ -926,9 +926,9 @@ def test_to_dict(self): Setup: - Instance of ``MultiTableMetadata``. - - Add mocked values to ``instance._tables`` and ``instance._relationships``. + - Add mocked values to ``instance.tables`` and ``instance.relationships``. Mock: - - Mock ``SingleTableMetadata`` like object to ``instance._tables``. + - Mock ``SingleTableMetadata`` like object to ``instance.tables``. Output: - A dict representation containing ``tables`` and ``relationships`` has to be returned @@ -949,11 +949,11 @@ def test_to_dict(self): 'name': {'sdtype': 'text'}, } instance = MultiTableMetadata() - instance._tables = { + instance.tables = { 'accounts': table_accounts, 'branches': table_branches } - instance._relationships = [ + instance.relationships = [ { 'parent_table_name': 'accounts', 'parent_primary_key': 'id', @@ -1004,7 +1004,7 @@ def test__set_metadata(self, mock_singletablemetadata): - Mock ``SingleTableMetadata`` from ``sdv.metadata.multi_table`` Side Effects: - - ``instance`` now contains ``instance._tables`` and ``instance._relationships``. + - ``instance`` now contains ``instance.tables`` and ``instance.relationships``. - ``SingleTableMetadata._load_from_dict`` has been called. """ # Setup @@ -1045,12 +1045,12 @@ def test__set_metadata(self, mock_singletablemetadata): instance._set_metadata_dict(multitable_metadata) # Assert - assert instance._tables == { + assert instance.tables == { 'accounts': single_table_accounts, 'branches': single_table_branches } - assert instance._relationships == [ + assert instance.relationships == [ { 'parent_table_name': 'accounts', 'parent_primary_key': 'id', @@ -1073,7 +1073,7 @@ def test__load_from_dict(self, mock_singletablemetadata): - Mock ``SingleTableMetadata`` from ``sdv.metadata.multi_table`` Output: - - ``instance`` that contains ``instance._tables`` and ``instance._relationships``. + - ``instance`` that contains ``instance.tables`` and ``instance.relationships``. Side Effects: - ``SingleTableMetadata._load_from_dict`` has been called. @@ -1114,12 +1114,12 @@ def test__load_from_dict(self, mock_singletablemetadata): instance = MultiTableMetadata._load_from_dict(multitable_metadata) # Assert - assert instance._tables == { + assert instance.tables == { 'accounts': single_table_accounts, 'branches': single_table_branches } - assert instance._relationships == [ + assert instance.relationships == [ { 'parent_table_name': 'accounts', 'parent_primary_key': 'id', @@ -1312,7 +1312,7 @@ def test_add_column(self): # Setup metadata = MultiTableMetadata() table = Mock() - metadata._tables = {'table': table} + metadata.tables = {'table': table} # Run metadata.add_column('table', 'column', sdtype='numerical', pii=False) @@ -1360,7 +1360,7 @@ def test_update_column(self): # Setup metadata = MultiTableMetadata() table = Mock() - metadata._tables = {'table': table} + metadata.tables = {'table': table} # Run metadata.update_column('table', 'column', sdtype='numerical', pii=False) @@ -1401,7 +1401,7 @@ def test_detect_table_from_csv(self, single_table_mock, log_mock): - Mock the ``SingleTableMetadata`` class and the print function. Assert: - - Table should be added to ``self._tables``. + - Table should be added to ``self.tables``. """ # Setup metadata = MultiTableMetadata() @@ -1417,7 +1417,7 @@ def test_detect_table_from_csv(self, single_table_mock, log_mock): # Assert single_table_mock.return_value._load_data_from_csv.assert_called_once_with('path.csv') single_table_mock.return_value._detect_columns.assert_called_once_with(fake_data) - assert metadata._tables == {'table': single_table_mock.return_value} + assert metadata.tables == {'table': single_table_mock.return_value} expected_log_calls = call( 'Detected metadata:\n' @@ -1448,7 +1448,7 @@ def test_detect_table_from_csv_table_already_exists(self): """ # Setup metadata = MultiTableMetadata() - metadata._tables = {'table': Mock()} + metadata.tables = {'table': Mock()} # Run error_message = ( @@ -1470,7 +1470,7 @@ def test_detect_table_from_dataframe(self, single_table_mock, log_mock): - Mock the ``SingleTableMetadata`` class and print function. Assert: - - Table should be added to ``self._tables``. + - Table should be added to ``self.tables``. """ # Setup metadata = MultiTableMetadata() @@ -1484,7 +1484,7 @@ def test_detect_table_from_dataframe(self, single_table_mock, log_mock): # Assert single_table_mock.return_value._detect_columns.assert_called_once_with(data) - assert metadata._tables == {'table': single_table_mock.return_value} + assert metadata.tables == {'table': single_table_mock.return_value} expected_log_calls = call( 'Detected metadata:\n' @@ -1515,7 +1515,7 @@ def test_detect_table_from_dataframe_table_already_exists(self): """ # Setup metadata = MultiTableMetadata() - metadata._tables = {'table': Mock()} + metadata.tables = {'table': Mock()} # Run error_message = ( @@ -1539,7 +1539,7 @@ def test__validate_table_exists(self): """ # Setup metadata = MultiTableMetadata() - metadata._tables = {'table1': 'val', 'table2': 'val'} + metadata.tables = {'table1': 'val', 'table2': 'val'} # Run metadata._validate_table_exists('table1') @@ -1565,7 +1565,7 @@ def test_set_primary_key(self): """ # Setup metadata = MultiTableMetadata() - metadata._tables = {'table1': Mock(), 'table2': 'val'} + metadata.tables = {'table1': Mock(), 'table2': 'val'} metadata._validate_table_exists = Mock() # Run @@ -1573,7 +1573,7 @@ def test_set_primary_key(self): # Assert metadata._validate_table_exists.assert_called_once_with('table1') - metadata._tables['table1'].set_primary_key.assert_called_once_with('col') + metadata.tables['table1'].set_primary_key.assert_called_once_with('col') def test_set_sequence_key(self): """Test ``set_sequence_key``. @@ -1591,7 +1591,7 @@ def test_set_sequence_key(self): """ # Setup metadata = MultiTableMetadata() - metadata._tables = {'table1': Mock(), 'table2': 'val'} + metadata.tables = {'table1': Mock(), 'table2': 'val'} metadata._validate_table_exists = Mock() # Run @@ -1601,7 +1601,7 @@ def test_set_sequence_key(self): # Assert metadata._validate_table_exists.assert_called_once_with('table1') - metadata._tables['table1'].set_sequence_key.assert_called_once_with('col') + metadata.tables['table1'].set_sequence_key.assert_called_once_with('col') def test_add_alternate_keys(self): """Test ``add_alternate_keys``. @@ -1619,7 +1619,7 @@ def test_add_alternate_keys(self): """ # Setup metadata = MultiTableMetadata() - metadata._tables = {'table1': Mock(), 'table2': 'val'} + metadata.tables = {'table1': Mock(), 'table2': 'val'} metadata._validate_table_exists = Mock() # Run @@ -1627,7 +1627,7 @@ def test_add_alternate_keys(self): # Assert metadata._validate_table_exists.assert_called_once_with('table1') - metadata._tables['table1'].add_alternate_keys.assert_called_once_with(['col1', 'col2']) + metadata.tables['table1'].add_alternate_keys.assert_called_once_with(['col1', 'col2']) def test_set_sequence_index(self): """Test ``set_sequence_index``. @@ -1645,7 +1645,7 @@ def test_set_sequence_index(self): """ # Setup metadata = MultiTableMetadata() - metadata._tables = {'table1': Mock(), 'table2': 'val'} + metadata.tables = {'table1': Mock(), 'table2': 'val'} metadata._validate_table_exists = Mock() # Run @@ -1655,7 +1655,7 @@ def test_set_sequence_index(self): # Assert metadata._validate_table_exists.assert_called_once_with('table1') - metadata._tables['table1'].set_sequence_index.assert_called_once_with('col') + metadata.tables['table1'].set_sequence_index.assert_called_once_with('col') def test_add_constraint(self): """Test the ``add_constraint`` method. @@ -1676,7 +1676,7 @@ def test_add_constraint(self): # Setup metadata = MultiTableMetadata() table = Mock() - metadata._tables = {'table': table} + metadata.tables = {'table': table} # Run metadata.add_constraint('table', 'Inequality', low_column_name='a', high_column_name='b') @@ -1778,13 +1778,13 @@ def test_load_from_json(self, mock_json, mock_path, mock_open): instance = MultiTableMetadata.load_from_json('filepath.json') # Asserts - assert list(instance._tables.keys()) == ['table1'] - assert instance._tables['table1']._columns == {'animals': {'type': 'categorical'}} - assert instance._tables['table1']._primary_key == 'animals' - assert instance._tables['table1']._sequence_key is None - assert instance._tables['table1']._alternate_keys == [] - assert instance._tables['table1']._sequence_index is None - assert instance._tables['table1']._version == 'SINGLE_TABLE_V1' + assert list(instance.tables.keys()) == ['table1'] + assert instance.tables['table1'].columns == {'animals': {'type': 'categorical'}} + assert instance.tables['table1'].primary_key == 'animals' + assert instance.tables['table1'].sequence_key is None + assert instance.tables['table1'].alternate_keys == [] + assert instance.tables['table1'].sequence_index is None + assert instance.tables['table1']._version == 'SINGLE_TABLE_V1' @patch('sdv.metadata.utils.Path') def test_save_to_json_file_exists(self, mock_path): diff --git a/tests/unit/metadata/test_single_table.py b/tests/unit/metadata/test_single_table.py index c4989b6ad..c269779bb 100644 --- a/tests/unit/metadata/test_single_table.py +++ b/tests/unit/metadata/test_single_table.py @@ -74,11 +74,11 @@ def test___init__(self): instance = SingleTableMetadata() # Assert - assert instance._columns == {} - assert instance._primary_key is None - assert instance._sequence_key is None - assert instance._alternate_keys == [] - assert instance._sequence_index is None + assert instance.columns == {} + assert instance.primary_key is None + assert instance.sequence_key is None + assert instance.alternate_keys == [] + assert instance.sequence_index is None assert instance._version == 'SINGLE_TABLE_V1' def test__validate_numerical_default_and_invalid(self): @@ -247,11 +247,11 @@ def test__validate_column_exists(self): - Column name. Side Effects: - - ``InvalidMetadataError`` when the column is not in the ``instance._columns``. + - ``InvalidMetadataError`` when the column is not in the ``instance.columns``. """ # Setup instance = SingleTableMetadata() - instance._columns = { + instance.columns = { 'name': {'sdtype': 'categorical'}, 'age': {'sdtype': 'numerical'}, 'start_date': {'sdtype': 'datetime'}, @@ -504,25 +504,25 @@ def test_update_column_sdtype(self): """Test that ``update_column`` updates the sdtype and keyword args for the given column.""" # Setup instance = SingleTableMetadata() - instance._columns = {'a': {'sdtype': 'numerical'}} + instance.columns = {'a': {'sdtype': 'numerical'}} # Run instance.update_column('a', sdtype='categorical', order_by='alphabetical') # Assert - assert instance._columns == {'a': {'sdtype': 'categorical', 'order_by': 'alphabetical'}} + assert instance.columns == {'a': {'sdtype': 'categorical', 'order_by': 'alphabetical'}} def test_update_column_add_extra_value(self): """Test that ``update_column`` updates only the keyword args for the given column.""" # Setup instance = SingleTableMetadata() - instance._columns = {'a': {'sdtype': 'numerical'}} + instance.columns = {'a': {'sdtype': 'numerical'}} # Run instance.update_column('a', computer_representation='Int64') # Assert - assert instance._columns == { + assert instance.columns == { 'a': { 'sdtype': 'numerical', 'computer_representation': 'Int64' @@ -533,7 +533,7 @@ def test_add_column_column_name_in_columns(self): """Test ``add_column`` method. Test that when calling ``add_column`` with a column that is already in - ``instance._columns`` raises an ``InvalidMetadataError`` stating to use the + ``instance.columns`` raises an ``InvalidMetadataError`` stating to use the ``update_column`` instead. Setup: @@ -541,14 +541,14 @@ def test_add_column_column_name_in_columns(self): - ``_columns`` with some values. Input: - - A column name that is already in ``instance._columns``. + - A column name that is already in ``instance.columns``. Side Effects: - ``InvalidMetadataError`` is being raised stating that the column exists. """ # Setup instance = SingleTableMetadata() - instance._columns = {'age': {'sdtype': 'numerical'}} + instance.columns = {'age': {'sdtype': 'numerical'}} # Run / Assert error_msg = re.escape( @@ -600,7 +600,7 @@ def test_add_column(self): """Test ``add_column`` method. Test that when calling ``add_column`` method with a ``sdtype`` and the proper ``kwargs`` - this is being added to the ``instance._columns``. + this is being added to the ``instance.columns``. Setup: - Instance of ``SingleTableMetadata``. @@ -610,7 +610,7 @@ def test_add_column(self): - An ``sdtype``. Side Effects: - - ``instance._columns[column_name]`` now exists. + - ``instance.columns[column_name]`` now exists. """ # Setup instance = SingleTableMetadata() @@ -619,7 +619,7 @@ def test_add_column(self): instance.add_column('age', sdtype='numerical', computer_representation='Int8') # Assert - assert instance._columns['age'] == { + assert instance.columns['age'] == { 'sdtype': 'numerical', 'computer_representation': 'Int8' } @@ -640,7 +640,7 @@ def test_add_column_other_sdtype(self): instance.add_column('number', sdtype='phone_number') # Assert - assert instance._columns['number'] == {'sdtype': 'phone_number', 'pii': True} + assert instance.columns['number'] == {'sdtype': 'phone_number', 'pii': True} @patch('sdv.metadata.single_table.SingleTableMetadata._validate_column') @patch('sdv.metadata.single_table.SingleTableMetadata._validate_column_exists') @@ -664,13 +664,13 @@ def test_upate_column_sdtype_in_kwargs(self, """ # Setup instance = SingleTableMetadata() - instance._columns = {'age': {'sdtype': 'numerical'}} + instance.columns = {'age': {'sdtype': 'numerical'}} # Run instance.update_column('age', sdtype='categorical', order_by='numerical_value') # Assert - assert instance._columns['age'] == { + assert instance.columns['age'] == { 'sdtype': 'categorical', 'order_by': 'numerical_value' } @@ -699,13 +699,13 @@ def test_upate_column_no_sdtype(self, mock__validate_column_exists, mock__valida """ # Setup instance = SingleTableMetadata() - instance._columns = {'age': {'sdtype': 'numerical'}} + instance.columns = {'age': {'sdtype': 'numerical'}} # Run instance.update_column('age', computer_representation='Float') # Assert - assert instance._columns['age'] == { + assert instance.columns['age'] == { 'sdtype': 'numerical', 'computer_representation': 'Float' } @@ -721,14 +721,14 @@ def test_detect_from_dataframe_raises_error(self): Setup: - instance of ``SingleTableMetadata``. - - Add some value to ``instance._columns``. + - Add some value to ``instance.columns``. Side Effects: Raises an ``InvalidMetadataError`` stating that ``metadata`` already exists. """ # Setup instance = SingleTableMetadata() - instance._columns = {'column': {'sdtype': 'categorical'}} + instance.columns = {'column': {'sdtype': 'categorical'}} # Run / Assert err_msg = ( @@ -754,7 +754,7 @@ def test_detect_from_dataframe(self, mock_log): - ``pandas.DataFrame`` with multiple data types. Side Effects: - - ``instance._columns`` has been updated with the expected ``sdtypes``. + - ``instance.columns`` has been updated with the expected ``sdtypes``. - A message is being printed. """ # Setup @@ -771,7 +771,7 @@ def test_detect_from_dataframe(self, mock_log): instance.detect_from_dataframe(data) # Assert - assert instance._columns == { + assert instance.columns == { 'categorical': {'sdtype': 'categorical'}, 'date': {'sdtype': 'datetime'}, 'int': {'sdtype': 'numerical'}, @@ -793,14 +793,14 @@ def test_detect_from_csv_raises_error(self): Setup: - instance of ``SingleTableMetadata``. - - Add some value to ``instance._columns``. + - Add some value to ``instance.columns``. Side Effects: Raises an ``InvalidMetadataError`` stating that ``metadata`` already exists. """ # Setup instance = SingleTableMetadata() - instance._columns = {'column': {'sdtype': 'categorical'}} + instance.columns = {'column': {'sdtype': 'categorical'}} # Run / Assert err_msg = ( @@ -827,7 +827,7 @@ def test_detect_from_csv(self, mock_log): - String that represents the ``path`` to the ``csv`` file. Side Effects: - - ``instance._columns`` has been updated with the expected ``sdtypes``. + - ``instance.columns`` has been updated with the expected ``sdtypes``. - A message is being printed. """ # Setup @@ -847,7 +847,7 @@ def test_detect_from_csv(self, mock_log): instance.detect_from_csv(filepath) # Assert - assert instance._columns == { + assert instance.columns == { 'categorical': {'sdtype': 'categorical'}, 'date': {'sdtype': 'categorical'}, 'int': {'sdtype': 'numerical'}, @@ -877,7 +877,7 @@ def test_detect_from_csv_with_kwargs(self, mock_log): - String that represents the ``path`` to the ``csv`` file. Side Effects: - - ``instance._columns`` has been updated with the expected ``sdtypes``. + - ``instance.columns`` has been updated with the expected ``sdtypes``. - one of the columns must be datetime - A message is being printed. """ @@ -898,7 +898,7 @@ def test_detect_from_csv_with_kwargs(self, mock_log): instance.detect_from_csv(filepath, pandas_kwargs={'parse_dates': ['date']}) # Assert - assert instance._columns == { + assert instance.columns == { 'categorical': {'sdtype': 'categorical'}, 'date': {'sdtype': 'datetime'}, 'int': {'sdtype': 'numerical'}, @@ -1017,7 +1017,7 @@ def test_set_primary_key_validation_columns(self): """ # Setup instance = SingleTableMetadata() - instance._columns = {'a', 'd'} + instance.columns = {'a', 'd'} err_msg = ( "Unknown primary key values {'b'}." @@ -1054,25 +1054,25 @@ def test_set_primary_key(self): """Test that ``set_primary_key`` sets the ``_primary_key`` value.""" # Setup instance = SingleTableMetadata() - instance._columns = {'column': {'sdtype': 'numerical'}} + instance.columns = {'column': {'sdtype': 'numerical'}} # Run instance.set_primary_key('column') # Assert - assert instance._primary_key == 'column' + assert instance.primary_key == 'column' def test_set_primary_key_tuple(self): """Test that ``set_primary_key`` sets the ``_primary_key`` value for tuples.""" # Setup instance = SingleTableMetadata() - instance._columns = {'col1': {'sdtype': 'numerical'}, 'col2': {'sdtype': 'numerical'}} + instance.columns = {'col1': {'sdtype': 'numerical'}, 'col2': {'sdtype': 'numerical'}} # Run instance.set_primary_key(('col1', 'col2')) # Assert - assert instance._primary_key == ('col1', 'col2') + assert instance.primary_key == ('col1', 'col2') @patch('sdv.metadata.single_table.warnings') def test_set_primary_key_already_exists_warning(self, warning_mock): @@ -1089,8 +1089,8 @@ def test_set_primary_key_already_exists_warning(self, warning_mock): """ # Setup instance = SingleTableMetadata() - instance._columns = {'column1': {'sdtype': 'numerical'}} - instance._primary_key = 'column0' + instance.columns = {'column1': {'sdtype': 'numerical'}} + instance.primary_key = 'column0' # Run instance.set_primary_key('column1') @@ -1098,20 +1098,20 @@ def test_set_primary_key_already_exists_warning(self, warning_mock): # Assert warning_msg = "There is an existing primary key 'column0'. This key will be removed." assert warning_mock.warn.called_once_with(warning_msg) - assert instance._primary_key == 'column1' + assert instance.primary_key == 'column1' @patch('sdv.metadata.single_table.warnings') def test_set_primary_key_in_alternate_keys_warning(self, warning_mock): - """Test that ``set_primary_key`` raises a warning the key is in ``self._alternate_keys``. + """Test that ``set_primary_key`` raises a warning the key is in ``self.alternate_keys``. Setup: - Set the ``self._alternate_keys`` list to contain the key being added. + Set the ``self.alternate_keys`` list to contain the key being added. """ # Setup instance = SingleTableMetadata() - instance._columns = {'column1': {'sdtype': 'numerical'}} - instance._primary_key = 'column0' - instance._alternate_keys = ['column1', 'column2'] + instance.columns = {'column1': {'sdtype': 'numerical'}} + instance.primary_key = 'column0' + instance.alternate_keys = ['column1', 'column2'] # Run instance.set_primary_key('column1') @@ -1121,8 +1121,8 @@ def test_set_primary_key_in_alternate_keys_warning(self, warning_mock): 'column1 is currently set as an alternate key and will be removed from that list.' ) assert warning_mock.warn.called_once_with(warning_msg) - assert instance._primary_key == 'column1' - assert instance._alternate_keys == ['column2'] + assert instance.primary_key == 'column1' + assert instance.alternate_keys == ['column2'] def test_set_sequence_key_validation_dtype(self): """Test that ``set_sequence_key`` crashes for invalid arguments. @@ -1155,7 +1155,7 @@ def test_set_sequence_key_validation_columns(self): """ # Setup instance = SingleTableMetadata() - instance._columns = {'a', 'd'} + instance.columns = {'a', 'd'} err_msg = ( "Unknown sequence key values {'b'}." @@ -1192,25 +1192,25 @@ def test_set_sequence_key(self): """Test that ``set_sequence_key`` sets the ``_sequence_key`` value.""" # Setup instance = SingleTableMetadata() - instance._columns = {'column': {'sdtype': 'numerical'}} + instance.columns = {'column': {'sdtype': 'numerical'}} # Run instance.set_sequence_key('column') # Assert - assert instance._sequence_key == 'column' + assert instance.sequence_key == 'column' def test_set_sequence_key_tuple(self): """Test that ``set_sequence_key`` sets ``_sequence_key`` for tuples.""" # Setup instance = SingleTableMetadata() - instance._columns = {'col1': {'sdtype': 'numerical'}, 'col2': {'sdtype': 'numerical'}} + instance.columns = {'col1': {'sdtype': 'numerical'}, 'col2': {'sdtype': 'numerical'}} # Run instance.set_sequence_key(('col1', 'col2')) # Assert - assert instance._sequence_key == ('col1', 'col2') + assert instance.sequence_key == ('col1', 'col2') @patch('sdv.metadata.single_table.warnings') def test_set_sequence_key_warning(self, warning_mock): @@ -1227,8 +1227,8 @@ def test_set_sequence_key_warning(self, warning_mock): """ # Setup instance = SingleTableMetadata() - instance._columns = {'column1': {'sdtype': 'numerical'}} - instance._sequence_key = 'column0' + instance.columns = {'column1': {'sdtype': 'numerical'}} + instance.sequence_key = 'column0' # Run instance.set_sequence_key('column1') @@ -1236,7 +1236,7 @@ def test_set_sequence_key_warning(self, warning_mock): # Assert warning_msg = "There is an existing sequence key 'column0'. This key will be removed." assert warning_mock.warn.called_once_with(warning_msg) - assert instance._sequence_key == 'column1' + assert instance.sequence_key == 'column1' def test_add_alternate_keys_validation_dtype(self): """Test that ``add_alternate_keys`` crashes for invalid arguments. @@ -1269,7 +1269,7 @@ def test_add_alternate_keys_validation_columns(self): """ # Setup instance = SingleTableMetadata() - instance._columns = {'abc', '213', '312'} + instance.columns = {'abc', '213', '312'} err_msg = ( "Unknown alternate key values {'123'}." @@ -1311,8 +1311,8 @@ def test_add_alternate_keys_validation_primary_key(self): """ # Setup instance = SingleTableMetadata() - instance._columns = {'column1': {'sdtype': 'numerical'}} - instance._primary_key = 'column1' + instance.columns = {'column1': {'sdtype': 'numerical'}} + instance.primary_key = 'column1' err_msg = re.escape( "Invalid alternate key 'column1'. The key is already specified as a primary key." @@ -1325,7 +1325,7 @@ def test_add_alternate_keys(self): """Test that ``add_alternate_keys`` adds the columns to the ``_alternate_keys``.""" # Setup instance = SingleTableMetadata() - instance._columns = { + instance.columns = { 'column1': {'sdtype': 'numerical'}, 'column2': {'sdtype': 'numerical'}, 'column3': {'sdtype': 'numerical'} @@ -1335,25 +1335,25 @@ def test_add_alternate_keys(self): instance.add_alternate_keys(['column1', ('column2', 'column3')]) # Assert - assert instance._alternate_keys == ['column1', ('column2', 'column3')] + assert instance.alternate_keys == ['column1', ('column2', 'column3')] @patch('sdv.metadata.single_table.warnings') def test_add_alternate_keys_duplicate(self, warnings_mock): """Test that the method does not add columns that are already in ``_alternate_keys``.""" # Setup instance = SingleTableMetadata() - instance._columns = { + instance.columns = { 'column1': {'sdtype': 'numerical'}, 'column2': {'sdtype': 'numerical'}, 'column3': {'sdtype': 'numerical'} } - instance._alternate_keys = ['column3'] + instance.alternate_keys = ['column3'] # Run instance.add_alternate_keys(['column1', 'column2', 'column3']) # Assert - assert instance._alternate_keys == ['column3', 'column1', 'column2'] + assert instance.alternate_keys == ['column3', 'column1', 'column2'] message = 'column3 is already an alternate key.' warnings_mock.warn.assert_called_once_with(message) @@ -1388,7 +1388,7 @@ def test_set_sequence_index_validation_columns(self): """ # Setup instance = SingleTableMetadata() - instance._columns = {'a', 'd'} + instance.columns = {'a', 'd'} err_msg = ( "Unknown sequence index value {'column'}." @@ -1402,7 +1402,7 @@ def test_set_sequence_index_column_not_numerical_or_datetime(self): """Test that the method errors if the column is not numerical or datetime.""" # Setup instance = SingleTableMetadata() - instance._columns = { + instance.columns = { 'a': {'sdtype': 'numerical'}, 'd': {'sdtype': 'categorical'} } @@ -1416,20 +1416,20 @@ def test_set_sequence_index(self): """Test that ``set_sequence_index`` sets the ``_sequence_index`` value.""" # Setup instance = SingleTableMetadata() - instance._columns = {'column': {'sdtype': 'numerical'}} + instance.columns = {'column': {'sdtype': 'numerical'}} # Run instance.set_sequence_index('column') # Assert - assert instance._sequence_index == 'column' + assert instance.sequence_index == 'column' def test_validate_sequence_index_not_in_sequence_key(self): """Test the ``_validate_sequence_index_not_in_sequence_key`` method.""" # Setup instance = SingleTableMetadata() - instance._sequence_key = ('abc', 'def') - instance._sequence_index = 'abc' + instance.sequence_key = ('abc', 'def') + instance.sequence_index = 'abc' err_msg = ( "'sequence_index' and 'sequence_key' have the same value {'abc'}." @@ -1453,11 +1453,11 @@ def test_validate(self): """ # Setup instance = SingleTableMetadata() - instance._columns = {'col1': {'sdtype': 'numerical'}, 'col2': {'sdtype': 'numerical'}} - instance._primary_key = 'col1' - instance._alternate_keys = ['col2'] - instance._sequence_key = 'col1' - instance._sequence_index = 'col2' + instance.columns = {'col1': {'sdtype': 'numerical'}, 'col2': {'sdtype': 'numerical'}} + instance.primary_key = 'col1' + instance.alternate_keys = ['col2'] + instance.sequence_key = 'col1' + instance.sequence_index = 'col2' instance._validate_key = Mock() instance._validate_alternate_keys = Mock() instance._validate_sequence_index = Mock() @@ -1475,20 +1475,20 @@ def test_validate(self): # Assert instance._validate_key.assert_has_calls( - [call(instance._primary_key, 'primary'), call(instance._sequence_key, 'sequence')] + [call(instance.primary_key, 'primary'), call(instance.sequence_key, 'sequence')] ) instance._validate_column.assert_has_calls( [call('col1', sdtype='numerical'), call('col2', sdtype='numerical')] ) - instance._validate_alternate_keys.assert_called_once_with(instance._alternate_keys) - instance._validate_sequence_index.assert_called_once_with(instance._sequence_index) + instance._validate_alternate_keys.assert_called_once_with(instance.alternate_keys) + instance._validate_sequence_index.assert_called_once_with(instance.sequence_index) instance._validate_sequence_index_not_in_sequence_key.assert_called_once() def test_to_dict(self): """Test the ``to_dict`` method from ``SingleTableMetadata``. Setup: - - Instance of ``SingleTableMetadata`` and modify the ``instance._columns`` to ensure + - Instance of ``SingleTableMetadata`` and modify the ``instance.columns`` to ensure that ``to_dict`` works properly. Output: - A dictionary representation of the ``instance`` that does not modify the @@ -1496,7 +1496,7 @@ def test_to_dict(self): """ # Setup instance = SingleTableMetadata() - instance._columns['my_column'] = 'value' + instance.columns['my_column'] = 'value' # Run result = instance.to_dict() @@ -1509,7 +1509,7 @@ def test_to_dict(self): # Ensure that the output object does not alterate the inside object result['columns']['my_column'] = 1 - assert instance._columns['my_column'] == 'value' + assert instance.columns['my_column'] == 'value' def test__load_from_dict(self): """Test that ``_load_from_dict`` returns a instance with the ``dict`` updated objects.""" @@ -1527,11 +1527,11 @@ def test__load_from_dict(self): instance = SingleTableMetadata._load_from_dict(my_metadata) # Assert - assert instance._columns == {'my_column': 'value'} - assert instance._primary_key == 'pk' - assert instance._sequence_key is None - assert instance._alternate_keys == [] - assert instance._sequence_index is None + assert instance.columns == {'my_column': 'value'} + assert instance.primary_key == 'pk' + assert instance.sequence_key is None + assert instance.alternate_keys == [] + assert instance.sequence_index is None assert instance._version == 'SINGLE_TABLE_V1' @patch('sdv.metadata.utils.Path') @@ -1641,11 +1641,11 @@ def test_load_from_json(self, mock_json, mock_path, mock_open): instance = SingleTableMetadata.load_from_json('filepath.json') # Assert - assert instance._columns == {'animals': {'type': 'categorical'}} - assert instance._primary_key == 'animals' - assert instance._sequence_key is None - assert instance._alternate_keys == [] - assert instance._sequence_index is None + assert instance.columns == {'animals': {'type': 'categorical'}} + assert instance.primary_key == 'animals' + assert instance.sequence_key is None + assert instance.alternate_keys == [] + assert instance.sequence_index is None assert instance._version == 'SINGLE_TABLE_V1' @patch('sdv.metadata.utils.Path') diff --git a/tests/unit/multi_table/test_base.py b/tests/unit/multi_table/test_base.py index 8aed387e5..1916049d8 100644 --- a/tests/unit/multi_table/test_base.py +++ b/tests/unit/multi_table/test_base.py @@ -44,9 +44,9 @@ def test__initialize_models(self): 'upravna_enota': instance._synthesizer.return_value } instance._synthesizer.assert_has_calls([ - call(metadata=instance.metadata._tables['nesreca'], default_distribution='gamma'), - call(metadata=instance.metadata._tables['oseba']), - call(metadata=instance.metadata._tables['upravna_enota']) + call(metadata=instance.metadata.tables['nesreca'], default_distribution='gamma'), + call(metadata=instance.metadata.tables['oseba']), + call(metadata=instance.metadata.tables['upravna_enota']) ]) def test___init__(self): diff --git a/tests/unit/multi_table/test_hma.py b/tests/unit/multi_table/test_hma.py index 27430f030..a1f80e0ca 100644 --- a/tests/unit/multi_table/test_hma.py +++ b/tests/unit/multi_table/test_hma.py @@ -404,7 +404,7 @@ def test__get_child_synthesizer(self): table_name = 'users' foreign_key = 'session_id' table_meta = Mock() - instance.metadata._tables = {'users': table_meta} + instance.metadata.tables = {'users': table_meta} instance._synthesizer_kwargs = {'a': 1} # Run @@ -439,8 +439,8 @@ def test__sample_child_rows(self): metadata = Mock() sessions_meta = Mock() users_meta = Mock() - users_meta._primary_key.return_value = 'user_id' - metadata._tables = { + users_meta.primary_key.return_value = 'user_id' + metadata.tables = { 'users': users_meta, 'sessions': sessions_meta } @@ -489,8 +489,8 @@ def test__sample_child_rows_with_sampled_data(self): metadata = Mock() sessions_meta = Mock() users_meta = Mock() - users_meta._primary_key.return_value = 'user_id' - metadata._tables = { + users_meta.primary_key.return_value = 'user_id' + metadata.tables = { 'users': users_meta, 'sessions': sessions_meta } @@ -627,7 +627,7 @@ def test__sample(self): 'sessions': ['users'], 'transactions': ['sessions'] } - instance.metadata._tables = { + instance.metadata.tables = { 'users': Mock(), 'sessions': Mock(), 'transactions': Mock(), diff --git a/tests/unit/sequential/test_par.py b/tests/unit/sequential/test_par.py index db4948c0a..a3299b086 100644 --- a/tests/unit/sequential/test_par.py +++ b/tests/unit/sequential/test_par.py @@ -74,7 +74,7 @@ def test___init__(self): assert isinstance(synthesizer._data_processor, DataProcessor) assert synthesizer._data_processor.metadata == metadata assert isinstance(synthesizer._context_synthesizer, GaussianCopulaSynthesizer) - assert synthesizer._context_synthesizer.metadata._columns == { + assert synthesizer._context_synthesizer.metadata.columns == { 'gender': {'sdtype': 'categorical'}, 'name': {'sdtype': 'text'} } diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index 598aa3189..891e5fe6b 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -146,7 +146,7 @@ def test_get_transformers(self): 'name': 'FrequencyEncoder', 'salary': 'FloatFormatter' } - instance.metadata._columns = { + instance.metadata.columns = { 'salary': {'sdtype': 'numerical'}, 'name': {'sdtype': 'categorical'} } @@ -1765,11 +1765,11 @@ def test_load(self): # Assert assert isinstance(loaded_instance, BaseSingleTableSynthesizer) - assert instance.metadata._columns == {} - assert instance.metadata._primary_key is None - assert instance.metadata._alternate_keys == [] - assert instance.metadata._sequence_key is None - assert instance.metadata._sequence_index is None + assert instance.metadata.columns == {} + assert instance.metadata.primary_key is None + assert instance.metadata.alternate_keys == [] + assert instance.metadata.sequence_key is None + assert instance.metadata.sequence_index is None assert instance.metadata._version == 'SINGLE_TABLE_V1' def test_load_custom_constraint_classes(self): diff --git a/tests/unit/single_table/test_copulagan.py b/tests/unit/single_table/test_copulagan.py index e3df9d3f4..940c5ab9d 100644 --- a/tests/unit/single_table/test_copulagan.py +++ b/tests/unit/single_table/test_copulagan.py @@ -188,7 +188,7 @@ def test__create_gaussian_normalizer_config(self, mock_rdt): # Setup numerical_distributions = {'age': 'gamma'} metadata = SingleTableMetadata() - metadata._columns = { + metadata.columns = { 'name': { 'sdtype': 'categorical', }, diff --git a/tests/unit/single_table/test_copulas.py b/tests/unit/single_table/test_copulas.py index 713a04317..c44595afc 100644 --- a/tests/unit/single_table/test_copulas.py +++ b/tests/unit/single_table/test_copulas.py @@ -449,7 +449,7 @@ def test__get_valid_columns_from_metadata(self): """Test that it returns a list with columns that are from the metadata.""" # Seutp instance = Mock() - instance.metadata._columns = { + instance.metadata.columns = { 'a_value': object(), 'n_value': object(), 'b_value': object() diff --git a/tests/unit/single_table/test_utils.py b/tests/unit/single_table/test_utils.py index 4f8d204af..619d907f8 100644 --- a/tests/unit/single_table/test_utils.py +++ b/tests/unit/single_table/test_utils.py @@ -15,7 +15,7 @@ def test_detect_discrete_columns(): """Test that the detect discrete columns returns a list columns that are not continuum.""" # Setup metadata = SingleTableMetadata() - metadata._columns = { + metadata.columns = { 'name': { 'sdtype': 'categorical', },