From 76984aaa42ce5fa4df2b19ed942a96d70878e2bc Mon Sep 17 00:00:00 2001
From: Mike Knepper
Date: Thu, 22 Jun 2023 07:50:20 -0500
Subject: [PATCH] Core refactors (#126)

Elevates "producer tables" so that all tables, including JSON source
tables, are nodes on the graph. Removes the `RelationalJson` object;
JSON transforms are now stateless functions.
---
 src/gretel_trainer/relational/backup.py      |  71 ++--
 src/gretel_trainer/relational/core.py        | 335 +++++++++---------
 src/gretel_trainer/relational/json.py        | 268 +++++++-------
 src/gretel_trainer/relational/multi_table.py |  21 +-
 tests/relational/test_backup.py              |  24 +-
 .../test_relational_data_with_json.py        |  20 +-
 6 files changed, 364 insertions(+), 375 deletions(-)

diff --git a/src/gretel_trainer/relational/backup.py b/src/gretel_trainer/relational/backup.py
index 675c70ee..13acc6ab 100644
--- a/src/gretel_trainer/relational/backup.py
+++ b/src/gretel_trainer/relational/backup.py
@@ -1,16 +1,18 @@
 from __future__ import annotations
 
 from dataclasses import asdict, dataclass
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 from gretel_trainer.relational.artifacts import ArtifactCollection
-from gretel_trainer.relational.core import ForeignKey, RelationalData
+from gretel_trainer.relational.core import ForeignKey, RelationalData, Scope
+from gretel_trainer.relational.json import InventedTableMetadata, ProducerMetadata
 
 
 @dataclass
 class BackupRelationalDataTable:
     primary_key: list[str]
     invented_table_metadata: Optional[dict[str, Any]] = None
+    producer_metadata: Optional[dict[str, Any]] = None
 
 
 @dataclass
@@ -30,50 +32,45 @@ def from_fk(cls, fk: ForeignKey) -> BackupForeignKey:
         )
 
 
-@dataclass
-class BackupRelationalJson:
-    original_table_name: str
-    original_primary_key: list[str]
-    original_columns: list[str]
-    table_name_mappings: dict[str, str]
-
-
 @dataclass
 class BackupRelationalData:
     tables: dict[str, BackupRelationalDataTable]
     foreign_keys: list[BackupForeignKey]
-    relational_jsons: dict[str, BackupRelationalJson]
 
     @classmethod
     def from_relational_data(cls, rel_data: RelationalData) -> BackupRelationalData:
         tables = {}
         foreign_keys = []
-        relational_jsons = {}
-        for table in rel_data.list_all_tables():
-            backup_table = BackupRelationalDataTable(
+        for table in rel_data.list_all_tables(Scope.ALL):
+            tables[table] = BackupRelationalDataTable(
                 primary_key=rel_data.get_primary_key(table),
+                invented_table_metadata=_optionally_as_dict(
+                    rel_data.get_invented_table_metadata(table)
+                ),
+                producer_metadata=_optionally_as_dict(
+                    rel_data.get_producer_metadata(table)
+                ),
             )
-            if (
-                invented_table_metadata := rel_data.get_invented_table_metadata(table)
-            ) is not None:
-                backup_table.invented_table_metadata = asdict(invented_table_metadata)
-            tables[table] = backup_table
-            foreign_keys.extend(
-                [
-                    BackupForeignKey.from_fk(key)
-                    for key in rel_data.get_foreign_keys(table)
-                ]
-            )
-        for key, rel_json in rel_data.relational_jsons.items():
-            relational_jsons[key] = BackupRelationalJson(
-                original_table_name=rel_json.original_table_name,
-                original_primary_key=rel_json.original_primary_key,
-                original_columns=rel_json.original_columns,
-                table_name_mappings=rel_json.table_name_mappings,
-            )
-        return BackupRelationalData(
-            tables=tables, foreign_keys=foreign_keys, relational_jsons=relational_jsons
-        )
+
+            # Producer tables delegate their foreign keys to root invented tables.
+            # We exclude producers here to avoid adding duplicate foreign keys.
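+            # (e.g. with the `documents` fixture used in the tests below, the
+            # foreign key defined on JSON source table `purchases` is surfaced by
+            # get_foreign_keys for both `purchases` and its invented root
+            # `purchases-sfx`)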
+ if not rel_data.is_producer_of_invented_tables(table): + foreign_keys.extend( + [ + BackupForeignKey.from_fk(key) + for key in rel_data.get_foreign_keys(table) + ] + ) + return BackupRelationalData(tables=tables, foreign_keys=foreign_keys) + + +def _optionally_as_dict( + metadata: Optional[Union[InventedTableMetadata, ProducerMetadata]] +) -> Optional[dict[str, Any]]: + if metadata is None: + return None + + return asdict(metadata) @dataclass @@ -137,10 +134,6 @@ def from_dict(cls, b: dict[str, Any]): ) for fk in relational_data.get("foreign_keys", []) ], - relational_jsons={ - k: BackupRelationalJson(**v) - for k, v in relational_data.get("relational_jsons", {}).items() - }, ) backup = Backup( diff --git a/src/gretel_trainer/relational/core.py b/src/gretel_trainer/relational/core.py index a9fb3c43..a0d92386 100644 --- a/src/gretel_trainer/relational/core.py +++ b/src/gretel_trainer/relational/core.py @@ -25,11 +25,11 @@ from networkx.algorithms.dag import dag_longest_path_length, topological_sort from pandas.api.types import is_string_dtype +import gretel_trainer.relational.json as relational_json from gretel_trainer.relational.json import ( IngestResponseT, InventedTableMetadata, - RelationalJson, - get_json_columns, + ProducerMetadata, ) logger = logging.getLogger(__name__) @@ -90,9 +90,18 @@ class TableMetadata: data: pd.DataFrame columns: list[str] invented_table_metadata: Optional[InventedTableMetadata] = None + producer_metadata: Optional[ProducerMetadata] = None safe_ancestral_seed_columns: Optional[set[str]] = None +@dataclass +class _RemovedTableMetadata: + data: pd.DataFrame + primary_key: list[str] + fks_to_parents: list[ForeignKey] + fks_from_children: list[ForeignKey] + + class RelationalData: """ Stores information about multiple tables and their relationships. 
When @@ -111,7 +120,6 @@ class RelationalData: def __init__(self): self.graph = networkx.DiGraph() - self.relational_jsons: dict[str, RelationalJson] = {} @property def is_empty(self) -> bool: @@ -133,16 +141,28 @@ def restore(self, tableset: dict[str, pd.DataFrame]) -> dict[str, pd.DataFrame]: discarded = set() # Restore any invented tables to nested-JSON format - for table_name, rel_json in self.relational_jsons.items(): + producers = { + table: pmeta + for table in self.list_all_tables(Scope.ALL) + if (pmeta := self.get_producer_metadata(table)) is not None + } + for table_name, producer_metadata in producers.items(): tables = { table: data for table, data in tableset.items() - if table in rel_json.table_names + if table in producer_metadata.table_names } - data = rel_json.restore(tables, self) + data = relational_json.restore( + tables=tables, + rel_data=self, + root_table_name=producer_metadata.invented_root_table_name, + original_columns=self.get_table_columns(table_name), + table_name_mappings=producer_metadata.table_name_mappings, + original_table_name=table_name, + ) if data is not None: restored[table_name] = data - discarded.update(rel_json.table_names) + discarded.update(producer_metadata.table_names) # Add remaining tables for table, data in tableset.items(): @@ -167,7 +187,7 @@ def add_table( """ primary_key = self._format_key_column(primary_key) if (rj_ingest := self._check_for_json(name, primary_key, data)) is not None: - self._add_rel_json_and_tables(name, rj_ingest) + self._add_rel_json_and_tables(name, primary_key, data, rj_ingest) else: self._add_single_table(name=name, primary_key=primary_key, data=data) @@ -177,19 +197,32 @@ def _check_for_json( primary_key: list[str], data: pd.DataFrame, ) -> Optional[IngestResponseT]: - json_cols = get_json_columns(data) + json_cols = relational_json.get_json_columns(data) if len(json_cols) > 0: logger.info( f"Detected JSON data in table `{table}`. Running JSON normalization." 
        )
-        return RelationalJson.ingest(table, primary_key, data, json_cols)
+        return relational_json.ingest(table, primary_key, data, json_cols)
 
-    def _add_rel_json_and_tables(self, table: str, rj_ingest: IngestResponseT) -> None:
-        rel_json, commands = rj_ingest
+    def _add_rel_json_and_tables(
+        self,
+        table: str,
+        primary_key: list[str],
+        data: pd.DataFrame,
+        rj_ingest: IngestResponseT,
+    ) -> None:
+        commands, producer_metadata = rj_ingest
         tables, foreign_keys = commands
-        self.relational_jsons[table] = rel_json
+        # Add this table as a standalone node
+        self._add_single_table(
+            name=table,
+            primary_key=primary_key,
+            data=data,
+            producer_metadata=producer_metadata,
+        )
+        # Add the invented tables
         for tbl in tables:
             self._add_single_table(**tbl)
         for foreign_key in foreign_keys:
@@ -202,6 +235,7 @@ def _add_single_table(
         primary_key: UserFriendlyPrimaryKeyT,
         data: pd.DataFrame,
         invented_table_metadata: Optional[InventedTableMetadata] = None,
+        producer_metadata: Optional[ProducerMetadata] = None,
     ) -> None:
         primary_key = self._format_key_column(primary_key)
         metadata = TableMetadata(
@@ -209,9 +243,16 @@ def _add_single_table(
             data=data,
             columns=list(data.columns),
             invented_table_metadata=invented_table_metadata,
+            producer_metadata=producer_metadata,
         )
         self.graph.add_node(name, metadata=metadata)
 
+    def _get_table_metadata(self, table: str) -> TableMetadata:
+        try:
+            return self.graph.nodes[table]["metadata"]
+        except KeyError:
+            raise MultiTableException(f"Unrecognized table: `{table}`")
+
     def set_primary_key(
         self, *, table: str, primary_key: UserFriendlyPrimaryKeyT
     ) -> None:
@@ -229,29 +270,50 @@ def set_primary_key(
             if col not in known_columns:
                 raise MultiTableException(f"Unrecognized column name: `{col}`")
 
-        if self.relational_jsons.get(table) is not None:
-            original_data, _, original_fks = self._remove_relational_json(table)
-            if original_data is None:
-                raise MultiTableException("Original data with JSON is lost.")
+        # Prevent manual interference with invented tables
+        if self._is_invented(table):
+            raise MultiTableException("Cannot change primary key on invented tables")
 
-            new_rj_ingest = RelationalJson.ingest(table, primary_key, original_data)
+        # If `table` is a producer of invented tables, we redo JSON ingestion
+        # to ensure primary keys are set properly on invented tables
+        elif self.is_producer_of_invented_tables(table):
+            removal_metadata = self._remove_producer(table)
+            original_data = removal_metadata.data
+            new_rj_ingest = relational_json.ingest(table, primary_key, original_data)
             if new_rj_ingest is None:
                 raise MultiTableException(
                     "Failed to change primary key on tables invented from JSON data"
                 )
-            self._add_rel_json_and_tables(table, new_rj_ingest)
-            for fk in original_fks:
-                self.add_foreign_key_constraint(
-                    table=fk.table_name,
-                    constrained_columns=fk.columns,
-                    referred_table=fk.parent_table_name,
-                    referred_columns=fk.parent_columns,
-                )
+            self._add_rel_json_and_tables(
+                table, primary_key, original_data, new_rj_ingest
+            )
+            self._restore_fks_in_both_directions(table, removal_metadata)
+
+        # At this point we are working with a "normal" table
         else:
-            self.graph.nodes[table]["metadata"].primary_key = primary_key
+            self._get_table_metadata(table).primary_key = primary_key
 
         self._clear_safe_ancestral_seed_columns(table)
 
+    def _restore_fks_in_both_directions(
+        self, table: str, removal_metadata: _RemovedTableMetadata
+    ) -> None:
+        for fk in removal_metadata.fks_to_parents:
+            self.add_foreign_key_constraint(
+                table=table,
+                constrained_columns=fk.columns,
referred_table=fk.parent_table_name, + referred_columns=fk.parent_columns, + ) + + for fk in removal_metadata.fks_from_children: + self.add_foreign_key_constraint( + table=fk.table_name, + constrained_columns=fk.columns, + referred_table=table, + referred_columns=fk.parent_columns, + ) + def _get_user_defined_fks_to_table(self, table: str) -> list[ForeignKey]: return [ fk @@ -260,28 +322,36 @@ def _get_user_defined_fks_to_table(self, table: str) -> list[ForeignKey]: if fk.parent_table_name == table and not self._is_invented(fk.table_name) ] - def _remove_relational_json( - self, table: str - ) -> tuple[Optional[pd.DataFrame], list[str], list[ForeignKey]]: + def _remove_producer(self, table: str) -> _RemovedTableMetadata: """ - Removes the RelationalJson instance and removes all invented tables from the graph. + Removes the producer table and all its invented tables from the graph + (which in turn removes all edges (foreign keys) to/from other tables). - Returns the original data, primary key, and foreign keys. + Returns a _RemovedTableMetadata object for restoring metadata in broader "update" contexts. """ - rel_json = self.relational_jsons[table] + table_metadata = self._get_table_metadata(table) + producer_metadata = table_metadata.producer_metadata + + if producer_metadata is None: + raise MultiTableException( + "Cannot remove invented tables from non-producer table" + ) - original_data = rel_json.original_data - original_primary_key = rel_json.original_primary_key - original_foreign_keys = self._get_user_defined_fks_to_table( - rel_json.root_table_name + removal_metadata = _RemovedTableMetadata( + data=table_metadata.data, + primary_key=table_metadata.primary_key, + fks_to_parents=self.get_foreign_keys(table), + fks_from_children=self._get_user_defined_fks_to_table( + self._get_fk_delegate_table(table) + ), ) - for invented_table_name in rel_json.table_names: + for invented_table_name in producer_metadata.table_names: if invented_table_name in self.graph.nodes: self.graph.remove_node(invented_table_name) - del self.relational_jsons[table] + self.graph.remove_node(table) - return original_data, original_primary_key, original_foreign_keys + return removal_metadata def _format_key_column(self, key: Optional[Union[str, list[str]]]) -> list[str]: if key is None: @@ -340,9 +410,6 @@ def add_foreign_key_constraint( if abort: raise MultiTableException("Unrecognized table(s) in foreign key arguments") - table = self._get_table_in_graph(table) - referred_table = self._get_table_in_graph(referred_table) - if len(constrained_columns) != len(referred_columns): logger.warning( "Constrained and referred columns must be of the same length" @@ -369,18 +436,22 @@ def add_foreign_key_constraint( if abort: raise MultiTableException("Unrecognized column(s) in foreign key arguments") - self.graph.add_edge(table, referred_table) - edge = self.graph.edges[table, referred_table] + fk_delegate_table = self._get_fk_delegate_table(table) + fk_delegate_referred_table = self._get_fk_delegate_table(referred_table) + + self.graph.add_edge(fk_delegate_table, fk_delegate_referred_table) + edge = self.graph.edges[fk_delegate_table, fk_delegate_referred_table] via = edge.get("via", []) via.append( ForeignKey( - table_name=table, + table_name=fk_delegate_table, columns=constrained_columns, - parent_table_name=referred_table, + parent_table_name=fk_delegate_referred_table, parent_columns=referred_columns, ) ) edge["via"] = via + self._clear_safe_ancestral_seed_columns(fk_delegate_table) 
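+        # `table` and its delegate may be distinct nodes (a producer vs. its
+        # invented root), so the cached seed columns are cleared on both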
        self._clear_safe_ancestral_seed_columns(table)
 
     def remove_foreign_key(self, foreign_key: str) -> None:
@@ -405,8 +476,6 @@ def remove_foreign_key_constraint(
         if table not in self.list_all_tables(Scope.ALL):
             raise MultiTableException(f"Unrecognized table name: `{table}`")
 
-        table = self._get_table_in_graph(table)
-
         key_to_remove = None
         for fk in self.get_foreign_keys(table):
             if fk.columns == constrained_columns:
@@ -417,65 +486,37 @@ def remove_foreign_key_constraint(
                 f"`{table}` does not have a foreign key with constrained columns {constrained_columns}"
             )
 
-        edge = self.graph.edges[table, key_to_remove.parent_table_name]
+        fk_delegate_table = self._get_fk_delegate_table(table)
+
+        edge = self.graph.edges[fk_delegate_table, key_to_remove.parent_table_name]
         via = edge.get("via")
         via.remove(key_to_remove)
         if len(via) == 0:
-            self.graph.remove_edge(table, key_to_remove.parent_table_name)
+            self.graph.remove_edge(fk_delegate_table, key_to_remove.parent_table_name)
         else:
             edge["via"] = via
+        self._clear_safe_ancestral_seed_columns(fk_delegate_table)
         self._clear_safe_ancestral_seed_columns(table)
 
     def update_table_data(self, table: str, data: pd.DataFrame) -> None:
         """
         Set a DataFrame as the table data for a given table name.
         """
-        if table in self.relational_jsons:
-            _, original_pk, original_fks = self._remove_relational_json(table)
-            if (
-                new_rj_ingest := self._check_for_json(table, original_pk, data)
-            ) is not None:
-                self._add_rel_json_and_tables(table, new_rj_ingest)
-                parent_table_name = new_rj_ingest[0].root_table_name
-            else:
-                self._add_single_table(
-                    name=table,
-                    primary_key=original_pk,
-                    data=data,
-                )
-                parent_table_name = table
-            for fk in original_fks:
-                self.add_foreign_key_constraint(
-                    table=fk.table_name,
-                    constrained_columns=fk.columns,
-                    referred_table=parent_table_name,
-                    referred_columns=fk.parent_columns,
-                )
+        if self._is_invented(table):
+            raise MultiTableException("Cannot modify invented tables' data")
+        elif self.is_producer_of_invented_tables(table):
+            removal_metadata = self._remove_producer(table)
         else:
-            try:
-                metadata = self.graph.nodes[table]["metadata"]
-            except KeyError:
-                raise MultiTableException(
-                    f"Unrecognized table name: {table}. If this is a new table to add, use `add_table`."
- ) + removal_metadata = _RemovedTableMetadata( + data=pd.DataFrame(), # we don't care about the old data + primary_key=self.get_primary_key(table), + fks_to_parents=self.get_foreign_keys(table), + fks_from_children=self._get_user_defined_fks_to_table(table), + ) + self.graph.remove_node(table) - if ( - new_rj_ingest := self._check_for_json(table, metadata.primary_key, data) - ) is not None: - original_foreign_keys = self._get_user_defined_fks_to_table(table) - self.graph.remove_node(table) - self._add_rel_json_and_tables(table, new_rj_ingest) - for fk in original_foreign_keys: - self.add_foreign_key_constraint( - table=fk.table_name, - constrained_columns=fk.columns, - referred_table=new_rj_ingest[0].root_table_name, - referred_columns=fk.parent_columns, - ) - else: - metadata.data = data - metadata.columns = list(data.columns) - self._clear_safe_ancestral_seed_columns(table) + self.add_table(name=table, primary_key=removal_metadata.primary_key, data=data) + self._restore_fks_in_both_directions(table, removal_metadata) def list_all_tables(self, scope: Scope = Scope.MODELABLE) -> list[str]: """ @@ -486,19 +527,16 @@ def list_all_tables(self, scope: Scope = Scope.MODELABLE) -> list[str]: """ graph_nodes = list(self.graph.nodes) - json_source_tables = [ - rel_json.original_table_name - for _, rel_json in self.relational_jsons.items() + producer_tables = [ + t for t in graph_nodes if self.is_producer_of_invented_tables(t) ] - all_tables = graph_nodes + json_source_tables - modelable_tables = [] evaluatable_tables = [] invented_tables: list[str] = [] for n in graph_nodes: - meta = self.graph.nodes[n]["metadata"] + meta = self._get_table_metadata(n) if (invented_meta := meta.invented_table_metadata) is not None: invented_tables.append(n) if invented_meta.invented_root_table_name == n: @@ -506,8 +544,9 @@ def list_all_tables(self, scope: Scope = Scope.MODELABLE) -> list[str]: if not invented_meta.empty: modelable_tables.append(n) else: - modelable_tables.append(n) - evaluatable_tables.append(n) + if n not in producer_tables: + modelable_tables.append(n) + evaluatable_tables.append(n) if scope == Scope.MODELABLE: return modelable_tables @@ -516,15 +555,15 @@ def list_all_tables(self, scope: Scope = Scope.MODELABLE) -> list[str]: elif scope == Scope.INVENTED: return invented_tables elif scope == Scope.ALL: - return all_tables + return graph_nodes elif scope == Scope.PUBLIC: - return [t for t in all_tables if t not in invented_tables] + return [t for t in graph_nodes if t not in invented_tables] def _is_invented(self, table: str) -> bool: - return ( - table in self.graph.nodes - and self.graph.nodes[table]["metadata"].invented_table_metadata is not None - ) + return self.get_invented_table_metadata(table) is not None + + def is_producer_of_invented_tables(self, table: str) -> bool: + return self.get_producer_metadata(table) is not None def get_modelable_table_names(self, table: str) -> list[str]: """ @@ -533,10 +572,15 @@ def get_modelable_table_names(self, table: str) -> list[str]: If the provided table is itself modelable, returns that table name back. Otherwise returns an empty list. 
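        For example, with the `documents` test fixture, the producer table
        `purchases` maps to its modelable invented tables (`purchases-sfx` and
        `purchases-data-years-sfx`), while `purchases-sfx` maps to itself.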
""" - if (rel_json := self.relational_jsons.get(table)) is not None: + try: + table_metadata = self._get_table_metadata(table) + except MultiTableException: + return [] + + if (pmeta := table_metadata.producer_metadata) is not None: return [ t - for t in rel_json.table_names + for t in pmeta.table_names if t in self.list_all_tables(Scope.MODELABLE) ] elif table in self.list_all_tables(Scope.MODELABLE): @@ -545,12 +589,7 @@ def get_modelable_table_names(self, table: str) -> list[str]: return [] def get_public_name(self, table: str) -> Optional[str]: - if table in self.relational_jsons: - return table - - if ( - imeta := self.graph.nodes[table]["metadata"].invented_table_metadata - ) is not None: + if (imeta := self.get_invented_table_metadata(table)) is not None: return imeta.original_table_name return table @@ -558,10 +597,10 @@ def get_public_name(self, table: str) -> Optional[str]: def get_invented_table_metadata( self, table: str ) -> Optional[InventedTableMetadata]: - if table in self.relational_jsons: - return None + return self._get_table_metadata(table).invented_table_metadata - return self.graph.nodes[table]["metadata"].invented_table_metadata + def get_producer_metadata(self, table: str) -> Optional[ProducerMetadata]: + return self._get_table_metadata(table).producer_metadata def get_parents(self, table: str) -> list[str]: """ @@ -621,13 +660,7 @@ def get_primary_key(self, table: str) -> list[str]: Return the list of columns defining the primary key for a table. It may be a single column or multiple columns (composite key). """ - try: - return self.graph.nodes[table]["metadata"].primary_key - except KeyError: - if table in self.relational_jsons: - return self.relational_jsons[table].original_primary_key - else: - raise MultiTableException(f"Unrecognized table: `{table}`") + return self._get_table_metadata(table).primary_key def get_table_data( self, table: str, usecols: Optional[list[str]] = None @@ -636,30 +669,16 @@ def get_table_data( Return the table contents for a given table name as a DataFrame. """ usecols = usecols or self.get_table_columns(table) - try: - return self.graph.nodes[table]["metadata"].data[usecols] - except KeyError: - if table in self.relational_jsons: - if (df := self.relational_jsons[table].original_data) is None: - raise MultiTableException("Original data with JSON is lost.") - return df - else: - raise MultiTableException(f"Unrecognized table: `{table}`") + return self._get_table_metadata(table).data[usecols] def get_table_columns(self, table: str) -> list[str]: """ Return the column names for a provided table name. 
""" - try: - return self.graph.nodes[table]["metadata"].columns - except KeyError: - if table in self.relational_jsons: - return self.relational_jsons[table].original_columns - else: - raise MultiTableException(f"Unrecognized table: `{table}`") + return self._get_table_metadata(table).columns def get_safe_ancestral_seed_columns(self, table: str) -> set[str]: - safe_columns = self.graph.nodes[table]["metadata"].safe_ancestral_seed_columns + safe_columns = self._get_table_metadata(table).safe_ancestral_seed_columns if safe_columns is None: safe_columns = self._set_safe_ancestral_seed_columns(table) return safe_columns @@ -679,36 +698,29 @@ def _set_safe_ancestral_seed_columns(self, table: str) -> set[str]: if _ok_for_train_and_seed(col, data): cols.add(col) - self.graph.nodes[table]["metadata"].safe_ancestral_seed_columns = cols + self._get_table_metadata(table).safe_ancestral_seed_columns = cols return cols def _clear_safe_ancestral_seed_columns(self, table: str) -> None: - self.graph.nodes[table]["metadata"].safe_ancestral_seed_columns = None + self._get_table_metadata(table).safe_ancestral_seed_columns = None + + def _get_fk_delegate_table(self, table: str) -> str: + if (pmeta := self.get_producer_metadata(table)) is not None: + return pmeta.invented_root_table_name - def _get_table_in_graph(self, table: str) -> str: - if table in self.relational_jsons: - table = self.relational_jsons[table].root_table_name return table def get_foreign_keys( self, table: str, rename_invented_tables: bool = False ) -> list[ForeignKey]: def _rename_invented(fk: ForeignKey) -> ForeignKey: - table_name = fk.table_name - parent_table_name = fk.parent_table_name - if self._is_invented(table_name): - table_name = self.graph.nodes[table_name][ - "metadata" - ].invented_table_metadata.original_table_name - if self._is_invented(parent_table_name): - parent_table_name = self.graph.nodes[parent_table_name][ - "metadata" - ].invented_table_metadata.original_table_name + table_name = self.get_public_name(fk.table_name) + parent_table_name = self.get_public_name(fk.parent_table_name) return replace( fk, table_name=table_name, parent_table_name=parent_table_name ) - table = self._get_table_in_graph(table) + table = self._get_fk_delegate_table(table) foreign_keys = [] for parent in self.get_parents(table): fks = self.graph.edges[table, parent]["via"] @@ -737,8 +749,7 @@ def debug_summary(self) -> dict[str, Any]: for table in all_tables: this_table_foreign_key_count = 0 foreign_keys = [] - fk_lookup_table_name = self._get_table_in_graph(table) - for key in self.get_foreign_keys(fk_lookup_table_name): + for key in self.get_foreign_keys(table): total_foreign_key_count = total_foreign_key_count + 1 this_table_foreign_key_count = this_table_foreign_key_count + 1 foreign_keys.append( diff --git a/src/gretel_trainer/relational/json.py b/src/gretel_trainer/relational/json.py index ee658070..36621249 100644 --- a/src/gretel_trainer/relational/json.py +++ b/src/gretel_trainer/relational/json.py @@ -151,167 +151,67 @@ class InventedTableMetadata: empty: bool -class RelationalJson: - def __init__( - self, - original_table_name: str, - original_primary_key: list[str], - original_columns: list[str], - original_data: Optional[pd.DataFrame], - table_name_mappings: dict[str, str], - ): - self.original_table_name = original_table_name - self.original_primary_key = original_primary_key - self.original_columns = original_columns - self.original_data = original_data - self.table_name_mappings = table_name_mappings - - @classmethod - def 
ingest( - cls, - table_name: str, - primary_key: list[str], - df: pd.DataFrame, - json_columns: Optional[list[str]] = None, - ) -> Optional[IngestResponseT]: - tables = _normalize_json([(table_name, df.copy())], [], json_columns) - # If we created additional tables (from JSON lists) or added columns (from JSON dicts) - if len(tables) > 1 or len(tables[0][1].columns) > len(df.columns): - mappings = {name: sanitize_str(name) for name, _ in tables} - logger.info(f"Transformed JSON into {len(mappings)} tables for modeling.") - logger.debug(f"Invented table names: {list(mappings.values())}") - rel_json = RelationalJson( - original_table_name=table_name, - original_primary_key=primary_key, - original_data=df, - original_columns=list(df.columns), - table_name_mappings=mappings, - ) - commands = _generate_commands(rel_json, tables) - return (rel_json, commands) - - @property - def root_table_name(self) -> str: - return self.table_name_mappings[self.original_table_name] +@dataclass +class ProducerMetadata: + invented_root_table_name: str + table_name_mappings: dict[str, str] @property def table_names(self) -> list[str]: - """Returns sanitized, model-friendly table names, *including* those of empty invented tables.""" return list(self.table_name_mappings.values()) - @property - def inverse_table_name_mappings(self) -> dict[str, str]: - # Keys are sanitized, model-friendly names - # Values are "provenance" names (a^b>c) or the original table name - return {value: key for key, value in self.table_name_mappings.items()} - - def restore( - self, tables: dict[str, pd.DataFrame], rel_data: _RelationalData - ) -> Optional[pd.DataFrame]: - """Reduces a set of tables (assumed to match the shapes created on initialization) - to a single table matching the shape of the original source table - """ - # If the root invented table failed, we are completely out of luck - # (Missing invented child tables can be replaced with empty lists so we at least provide _something_) - if self.root_table_name not in tables: - logger.warning( - f"Cannot restore nested JSON data: root invented table `{self.root_table_name}` is missing from output tables." 
- ) - return None - - return self._denormalize_json(tables, rel_data)[self.original_columns] - - def _denormalize_json( - self, tables: dict[str, pd.DataFrame], rel_data: _RelationalData - ) -> pd.DataFrame: - table_dict: dict = { - self.inverse_table_name_mappings[k]: v for k, v in tables.items() - } - for table_name in list(reversed(self.table_names)): - table_provenance_name = self.inverse_table_name_mappings[table_name] - empty_fallback = pd.DataFrame( - data={col: [] for col in rel_data.get_table_columns(table_name)}, - ) - table_df = table_dict.get(table_provenance_name, empty_fallback) - if table_df.empty and _is_invented_child_table(table_name, rel_data): - p_name = rel_data.get_foreign_keys(table_name)[0].parent_table_name - parent_name = self.inverse_table_name_mappings[p_name] - col_name = get_parent_column_name_from_child_table_name( - table_provenance_name - ) - kwargs = {col_name: table_dict[parent_name].apply(lambda x: [], axis=1)} - table_dict[parent_name] = table_dict[parent_name].assign(**kwargs) - else: - col_names = [col for col in table_df.columns if FIELD_SEPARATOR in col] - new_col_names = [col.replace(FIELD_SEPARATOR, ".") for col in col_names] - flat_df = table_df[col_names].rename( - columns=dict(zip(col_names, new_col_names)) - ) - flat_dict = { - k: { - k1: v1 - for k1, v1 in v.items() - if v1 is not np.nan and v1 is not None - } - for k, v in flat_df.to_dict("index").items() - } - dict_df = nulls_to_empty_dicts( - pd.DataFrame.from_dict( - {k: unflatten(v) for k, v in flat_dict.items()}, orient="index" - ) - ) - nested_df = table_df.join(dict_df).drop(columns=col_names) - if _is_invented_child_table(table_name, rel_data): - # we know there is exactly one foreign key on invented child tables... - fk = rel_data.get_foreign_keys(table_name)[0] - # ...with exactly one column - fk_col = fk.columns[0] - p_name = fk.parent_table_name - parent_name = self.inverse_table_name_mappings[p_name] - nested_df = ( - nested_df.sort_values(ORDER_COLUMN) - .groupby(fk_col)[CONTENT_COLUMN] - .agg(list) - ) - col_name = get_parent_column_name_from_child_table_name( - table_provenance_name - ) - table_dict[parent_name] = table_dict[parent_name].join( - nested_df.rename(col_name) - ) - table_dict[parent_name][col_name] = nulls_to_empty_lists( - table_dict[parent_name][col_name] - ) - table_dict[table_provenance_name] = nested_df - return table_dict[self.original_table_name] +def ingest( + table_name: str, + primary_key: list[str], + df: pd.DataFrame, + json_columns: Optional[list[str]] = None, +) -> Optional[IngestResponseT]: + tables = _normalize_json([(table_name, df.copy())], [], json_columns) + # If we created additional tables (from JSON lists) or added columns (from JSON dicts) + if len(tables) > 1 or len(tables[0][1].columns) > len(df.columns): + mappings = {name: sanitize_str(name) for name, _ in tables} + logger.info(f"Transformed JSON into {len(mappings)} tables for modeling.") + logger.debug(f"Invented table names: {list(mappings.values())}") + commands = _generate_commands( + tables=tables, + table_name_mappings=mappings, + original_table_name=table_name, + original_primary_key=primary_key, + ) + producer_metadata = ProducerMetadata( + invented_root_table_name=mappings[table_name], + table_name_mappings=mappings, + ) + return (commands, producer_metadata) def _generate_commands( - rel_json: RelationalJson, tables: list[tuple[str, pd.DataFrame]] + tables: list[tuple[str, pd.DataFrame]], + table_name_mappings: dict[str, str], + original_table_name: str, + 
original_primary_key: list[str], ) -> CommandsT: """ Returns lists of keyword arguments designed to be passed to a RelationalData instance's _add_single_table and add_foreign_key methods """ - tables_to_add = {rel_json.table_name_mappings[name]: df for name, df in tables} + tables_to_add = {table_name_mappings[name]: df for name, df in tables} + root_table_name = table_name_mappings[original_table_name] _add_single_table = [] add_foreign_key = [] for table_name, table_df in tables_to_add.items(): - if table_name == rel_json.root_table_name: - table_pk = rel_json.original_primary_key + [PRIMARY_KEY_COLUMN] + if table_name == root_table_name: + table_pk = original_primary_key + [PRIMARY_KEY_COLUMN] else: table_pk = [PRIMARY_KEY_COLUMN] table_df.index.rename(PRIMARY_KEY_COLUMN, inplace=True) table_df.reset_index(inplace=True) - invented_root_table_name = rel_json.table_name_mappings[ - rel_json.original_table_name - ] metadata = InventedTableMetadata( - invented_root_table_name=invented_root_table_name, - original_table_name=rel_json.original_table_name, + invented_root_table_name=root_table_name, + original_table_name=original_table_name, empty=table_df.empty, ) _add_single_table.append( @@ -325,7 +225,7 @@ def _generate_commands( for table_name, table_df in tables_to_add.items(): for column in get_id_columns(table_df): - referred_table = rel_json.table_name_mappings[ + referred_table = table_name_mappings[ get_parent_table_name_from_child_id_column(column) ] add_foreign_key.append( @@ -339,6 +239,96 @@ def _generate_commands( return (_add_single_table, add_foreign_key) +def restore( + tables: dict[str, pd.DataFrame], + rel_data: _RelationalData, + root_table_name: str, + original_columns: list[str], + table_name_mappings: dict[str, str], + original_table_name: str, +) -> Optional[pd.DataFrame]: + # If the root invented table is not present, we are completely out of luck + # (Missing invented child tables can be replaced with empty lists so we at least provide _something_) + if root_table_name not in tables: + logger.warning( + f"Cannot restore nested JSON data: root invented table `{root_table_name}` is missing from output tables." 
+ ) + return None + + return _denormalize_json( + tables, rel_data, table_name_mappings, original_table_name + )[original_columns] + + +def _denormalize_json( + tables: dict[str, pd.DataFrame], + rel_data: _RelationalData, + table_name_mappings: dict[str, str], + original_table_name: str, +) -> pd.DataFrame: + table_names = list(table_name_mappings.values()) + inverse_table_name_mappings = {v: k for k, v in table_name_mappings.items()} + table_dict: dict = {inverse_table_name_mappings[k]: v for k, v in tables.items()} + for table_name in list(reversed(table_names)): + table_provenance_name = inverse_table_name_mappings[table_name] + empty_fallback = pd.DataFrame( + data={col: [] for col in rel_data.get_table_columns(table_name)}, + ) + table_df = table_dict.get(table_provenance_name, empty_fallback) + + if table_df.empty and _is_invented_child_table(table_name, rel_data): + p_name = rel_data.get_foreign_keys(table_name)[0].parent_table_name + parent_name = inverse_table_name_mappings[p_name] + col_name = get_parent_column_name_from_child_table_name( + table_provenance_name + ) + kwargs = {col_name: table_dict[parent_name].apply(lambda x: [], axis=1)} + table_dict[parent_name] = table_dict[parent_name].assign(**kwargs) + else: + col_names = [col for col in table_df.columns if FIELD_SEPARATOR in col] + new_col_names = [col.replace(FIELD_SEPARATOR, ".") for col in col_names] + flat_df = table_df[col_names].rename( + columns=dict(zip(col_names, new_col_names)) + ) + flat_dict = { + k: { + k1: v1 + for k1, v1 in v.items() + if v1 is not np.nan and v1 is not None + } + for k, v in flat_df.to_dict("index").items() + } + dict_df = nulls_to_empty_dicts( + pd.DataFrame.from_dict( + {k: unflatten(v) for k, v in flat_dict.items()}, orient="index" + ) + ) + nested_df = table_df.join(dict_df).drop(columns=col_names) + if _is_invented_child_table(table_name, rel_data): + # we know there is exactly one foreign key on invented child tables... 
+ fk = rel_data.get_foreign_keys(table_name)[0] + # ...with exactly one column + fk_col = fk.columns[0] + p_name = fk.parent_table_name + parent_name = inverse_table_name_mappings[p_name] + nested_df = ( + nested_df.sort_values(ORDER_COLUMN) + .groupby(fk_col)[CONTENT_COLUMN] + .agg(list) + ) + col_name = get_parent_column_name_from_child_table_name( + table_provenance_name + ) + table_dict[parent_name] = table_dict[parent_name].join( + nested_df.rename(col_name) + ) + table_dict[parent_name][col_name] = nulls_to_empty_lists( + table_dict[parent_name][col_name] + ) + table_dict[table_provenance_name] = nested_df + return table_dict[original_table_name] + + def get_json_columns(df: pd.DataFrame) -> list[str]: """ Samples non-null values from all columns and returns those that contain @@ -371,4 +361,4 @@ def get_json_columns(df: pd.DataFrame) -> list[str]: CommandsT = tuple[list[dict], list[dict]] -IngestResponseT = tuple[RelationalJson, CommandsT] +IngestResponseT = tuple[CommandsT, ProducerMetadata] diff --git a/src/gretel_trainer/relational/multi_table.py b/src/gretel_trainer/relational/multi_table.py index f20814c0..9bed5ee8 100644 --- a/src/gretel_trainer/relational/multi_table.py +++ b/src/gretel_trainer/relational/multi_table.py @@ -45,7 +45,7 @@ Scope, skip_table, ) -from gretel_trainer.relational.json import InventedTableMetadata, RelationalJson +from gretel_trainer.relational.json import InventedTableMetadata, ProducerMetadata from gretel_trainer.relational.log import silent_logs from gretel_trainer.relational.model_config import ( get_model_key, @@ -173,17 +173,17 @@ def _complete_init_from_backup(self, backup: Backup) -> None: for table_name, table_backup in backup.relational_data.tables.items(): source_data = pd.read_csv(self._working_dir / f"source_{table_name}.csv") invented_table_metadata = None + producer_metadata = None if (imeta := table_backup.invented_table_metadata) is not None: - invented_table_metadata = InventedTableMetadata( - invented_root_table_name=imeta["invented_root_table_name"], - original_table_name=imeta["original_table_name"], - empty=imeta["empty"], - ) + invented_table_metadata = InventedTableMetadata(**imeta) + if (pmeta := table_backup.producer_metadata) is not None: + producer_metadata = ProducerMetadata(**pmeta) self.relational_data._add_single_table( name=table_name, primary_key=table_backup.primary_key, data=source_data, invented_table_metadata=invented_table_metadata, + producer_metadata=producer_metadata, ) for fk_backup in backup.relational_data.foreign_keys: self.relational_data.add_foreign_key_constraint( @@ -192,15 +192,6 @@ def _complete_init_from_backup(self, backup: Backup) -> None: referred_table=fk_backup.referred_table, referred_columns=fk_backup.referred_columns, ) - for key, rel_json_backup in backup.relational_data.relational_jsons.items(): - relational_json = RelationalJson( - original_table_name=rel_json_backup.original_table_name, - original_primary_key=rel_json_backup.original_primary_key, - original_columns=rel_json_backup.original_columns, - original_data=None, - table_name_mappings=rel_json_backup.table_name_mappings, - ) - self.relational_data.relational_jsons[key] = relational_json # Debug summary debug_summary_id = backup.artifact_collection.gretel_debug_summary diff --git a/tests/relational/test_backup.py b/tests/relational/test_backup.py index 06933668..c9cb37eb 100644 --- a/tests/relational/test_backup.py +++ b/tests/relational/test_backup.py @@ -8,7 +8,6 @@ BackupGenerate, BackupRelationalData, 
BackupRelationalDataTable, - BackupRelationalJson, BackupSyntheticsTrain, BackupTransformsTrain, ) @@ -32,7 +31,6 @@ def test_backup_relational_data(trips): referred_columns=["id"], ) ], - relational_jsons={}, ) assert BackupRelationalData.from_relational_data(trips) == expected @@ -42,6 +40,16 @@ def test_backup_relational_data_with_json(documents): expected = BackupRelationalData( tables={ "users": BackupRelationalDataTable(primary_key=["id"]), + "purchases": BackupRelationalDataTable( + primary_key=["id"], + producer_metadata={ + "invented_root_table_name": "purchases-sfx", + "table_name_mappings": { + "purchases": "purchases-sfx", + "purchases^data>years": "purchases-data-years-sfx", + }, + }, + ), "purchases-sfx": BackupRelationalDataTable( primary_key=["id", "~PRIMARY_KEY_ID~"], invented_table_metadata={ @@ -80,17 +88,6 @@ def test_backup_relational_data_with_json(documents): referred_columns=["~PRIMARY_KEY_ID~"], ), ], - relational_jsons={ - "purchases": BackupRelationalJson( - original_table_name="purchases", - original_primary_key=["id"], - original_columns=["id", "user_id", "data"], - table_name_mappings={ - "purchases": "purchases-sfx", - "purchases^data>years": "purchases-data-years-sfx", - }, - ), - }, ) assert BackupRelationalData.from_relational_data(documents) == expected @@ -114,7 +111,6 @@ def test_backup(): referred_columns=["id"], ) ], - relational_jsons={}, ) backup_classify = BackupClassify( model_ids={ diff --git a/tests/relational/test_relational_data_with_json.py b/tests/relational/test_relational_data_with_json.py index 4b7bc481..96ab0019 100644 --- a/tests/relational/test_relational_data_with_json.py +++ b/tests/relational/test_relational_data_with_json.py @@ -134,17 +134,25 @@ def test_get_modelable_table_names(documents): def test_get_modelable_names_ignores_empty_mapped_tables(bball): # The `suspensions` column in the source data contained empty lists for all records. - # We need to hold onto that table name on the RelationalJson instance to support - # denormalizing back to the original source data shape. It is therefore exposed - # in the `table_names` attribute on RelationalJson... - assert set(bball.relational_jsons["bball"].table_names) == { + # The normalization process transforms that into a standalone, empty table. + # We need to hold onto that table name to support denormalizing back to the original + # source data shape. It is therefore present when listing ALL tables... + assert set(bball.list_all_tables(Scope.ALL)) == { + "bball", "bball-sfx", "bball-teams-sfx", "bball-suspensions-sfx", } - # ...BUT clients of RelationalData only care about invented tables that can be modeled - # (i.e. that contain data), so that class does not expose the empty table. + # ...and the producer metadata is aware of it... + assert set(bball.get_producer_metadata("bball").table_names) == { + "bball-sfx", + "bball-teams-sfx", + "bball-suspensions-sfx", + } + + # ...BUT most clients only care about invented tables that can be modeled + # (i.e. that contain data), so the empty table does not appear in these contexts: assert set(bball.get_modelable_table_names("bball")) == { "bball-sfx", "bball-teams-sfx",
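
A minimal sketch of the new producer-table flow, for orientation (not part of the patch; names mirror the `documents` test fixture above, and the `-sfx` fragments stand in for generated suffixes):

    import pandas as pd

    from gretel_trainer.relational.core import RelationalData, Scope

    rel_data = RelationalData()
    rel_data.add_table(
        name="users",
        primary_key="id",
        data=pd.DataFrame(data={"id": [1, 2], "name": ["a", "b"]}),
    )
    # JSON is detected in the `data` column, so `purchases` is added as a
    # producer node and its invented tables join it on the same graph.
    rel_data.add_table(
        name="purchases",
        primary_key="id",
        data=pd.DataFrame(
            data={
                "id": [1, 2],
                "user_id": [1, 2],
                # nested JSON values, shaped like the test fixtures
                "data": [{"years": [2021, 2022]}, {"years": [2023]}],
            }
        ),
    )
    # Foreign keys declared on the producer are delegated to its invented root.
    rel_data.add_foreign_key_constraint(
        table="purchases",
        constrained_columns=["user_id"],
        referred_table="users",
        referred_columns=["id"],
    )

    rel_data.list_all_tables(Scope.ALL)
    # -> users, purchases, purchases-sfx, purchases-data-years-sfx

    rel_data.get_modelable_table_names("purchases")
    # -> just the invented tables: purchases-sfx, purchases-data-years-sfx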