From 6494ba337e71214fc0f29c676285c94643933ecf Mon Sep 17 00:00:00 2001 From: Mike Knepper Date: Fri, 12 May 2023 16:28:25 -0500 Subject: [PATCH] Go back to table_name_mappings as a list to preserve order --- src/gretel_trainer/relational/backup.py | 2 +- src/gretel_trainer/relational/json.py | 28 ++++++++++++++++--------- tests/relational/test_backup.py | 8 +++---- 3 files changed, 23 insertions(+), 15 deletions(-) diff --git a/src/gretel_trainer/relational/backup.py b/src/gretel_trainer/relational/backup.py index 3ff9e216..04eaf81b 100644 --- a/src/gretel_trainer/relational/backup.py +++ b/src/gretel_trainer/relational/backup.py @@ -35,7 +35,7 @@ class BackupRelationalJson: original_table_name: str original_primary_key: list[str] original_columns: list[str] - table_name_mappings: dict[str, str] + table_name_mappings: list[tuple[str, str]] invented_table_names: list[str] diff --git a/src/gretel_trainer/relational/json.py b/src/gretel_trainer/relational/json.py index 740bde70..14022fae 100644 --- a/src/gretel_trainer/relational/json.py +++ b/src/gretel_trainer/relational/json.py @@ -154,7 +154,7 @@ def __init__( original_primary_key: list[str], original_columns: list[str], original_data: Optional[pd.DataFrame], - table_name_mappings: dict[str, str], + table_name_mappings: list[tuple[str, str]], ): self.original_table_name = original_table_name self.original_primary_key = original_primary_key @@ -169,7 +169,7 @@ def ingest( tables = _normalize_json([(table_name, df.copy())], []) # If we created additional tables (from JSON lists) or added columns (from JSON dicts) if len(tables) > 1 or len(tables[0][1].columns) > len(df.columns): - mappings = {name: sanitize_str(name) for name, _ in tables} + mappings = [(name, sanitize_str(name)) for name, _ in tables] rel_json = RelationalJson( original_table_name=table_name, original_primary_key=primary_key, @@ -182,17 +182,25 @@ def ingest( @property def root_table_name(self) -> str: - return self.table_name_mappings[self.original_table_name] + return self._mapping_dict[self.original_table_name] @property def table_names(self) -> list[str]: - return list(self.table_name_mappings.values()) + # We need to keep the order intact for restoring + return [m[1] for m in self.table_name_mappings] + + def get_sanitized_name(self, t: str) -> str: + return self._mapping_dict[t] + + @property + def _mapping_dict(self) -> dict[str, str]: + return dict(self.table_name_mappings) @property def inverse_table_name_mappings(self) -> dict[str, str]: # Keys are sanitized, model-friendly names # Values are "provenance" names (a^b>c) or the original table name - return {value: key for key, value in self.table_name_mappings.items()} + return {value: key for key, value in self._mapping_dict.items()} def restore( self, tables: dict[str, pd.DataFrame], rel_data: _RelationalData @@ -280,7 +288,7 @@ def _generate_commands( Returns lists of keyword arguments designed to be passed to a RelationalData instance's _add_single_table and add_foreign_key methods """ - tables = [(rel_json.table_name_mappings[name], df) for name, df in tables] + tables = [(rel_json.get_sanitized_name(name), df) for name, df in tables] non_empty_tables = [t for t in tables if not t[1].empty] _add_single_table = [] @@ -293,9 +301,9 @@ def _generate_commands( table_pk = [PRIMARY_KEY_COLUMN] table_df.index.rename(PRIMARY_KEY_COLUMN, inplace=True) table_df.reset_index(inplace=True) - invented_root_table_name = rel_json.table_name_mappings[ + invented_root_table_name = rel_json.get_sanitized_name( rel_json.original_table_name - ] + ) metadata = InventedTableMetadata( invented_root_table_name=invented_root_table_name, original_table_name=rel_json.original_table_name, @@ -311,9 +319,9 @@ def _generate_commands( for table_name, table_df in non_empty_tables: for column in get_id_columns(table_df): - referred_table = rel_json.table_name_mappings[ + referred_table = rel_json.get_sanitized_name( get_parent_table_name_from_child_id_column(column) - ] + ) add_foreign_key.append( { "table": table_name, diff --git a/tests/relational/test_backup.py b/tests/relational/test_backup.py index a6b31ae3..7182b065 100644 --- a/tests/relational/test_backup.py +++ b/tests/relational/test_backup.py @@ -80,10 +80,10 @@ def test_backup_relational_data_with_json(documents): original_table_name="purchases", original_primary_key=["id"], original_columns=["id", "user_id", "data"], - table_name_mappings={ - "purchases": "purchases-sfx", - "purchases^data>years": "purchases-data-years-sfx", - }, + table_name_mappings=[ + ("purchases", "purchases-sfx"), + ("purchases^data>years", "purchases-data-years-sfx"), + ], invented_table_names=["purchases-sfx", "purchases-data-years-sfx"], ), },