From 58e8891a7fd9615754f1f00f4adaa33742ff07bd Mon Sep 17 00:00:00 2001 From: Mike Knepper Date: Tue, 16 May 2023 15:23:05 -0500 Subject: [PATCH] Support for JSON (#107) --- requirements.txt | 1 + src/gretel_trainer/relational/__init__.py | 1 + src/gretel_trainer/relational/backup.py | 35 +- src/gretel_trainer/relational/core.py | 409 +++++++-- src/gretel_trainer/relational/json.py | 340 ++++++++ src/gretel_trainer/relational/log.py | 5 + src/gretel_trainer/relational/multi_table.py | 92 +- .../relational/report/report.py | 6 +- tests/relational/conftest.py | 12 + tests/relational/example_dbs/documents.sql | 48 ++ tests/relational/test_backup.py | 61 ++ tests/relational/test_connectors.py | 6 +- tests/relational/test_relational_data.py | 109 +-- .../test_relational_data_with_json.py | 804 ++++++++++++++++++ tests/relational/test_report.py | 35 +- 15 files changed, 1756 insertions(+), 208 deletions(-) create mode 100644 src/gretel_trainer/relational/json.py create mode 100644 tests/relational/example_dbs/documents.sql create mode 100644 tests/relational/test_relational_data_with_json.py diff --git a/requirements.txt b/requirements.txt index 5d84ae7d..3e4d4949 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ requests~=2.25 scikit-learn~=1.0 smart-open[s3]~=5.2 sqlalchemy~=1.4 +unflatten==0.1.1 diff --git a/src/gretel_trainer/relational/__init__.py b/src/gretel_trainer/relational/__init__.py index 534fff4d..01532e12 100644 --- a/src/gretel_trainer/relational/__init__.py +++ b/src/gretel_trainer/relational/__init__.py @@ -8,4 +8,5 @@ sqlite_conn, ) from gretel_trainer.relational.core import RelationalData +from gretel_trainer.relational.log import set_log_level from gretel_trainer.relational.multi_table import MultiTable diff --git a/src/gretel_trainer/relational/backup.py b/src/gretel_trainer/relational/backup.py index de996a51..3ff9e216 100644 --- a/src/gretel_trainer/relational/backup.py +++ b/src/gretel_trainer/relational/backup.py @@ -10,6 +10,7 @@ @dataclass class BackupRelationalDataTable: primary_key: List[str] + invented_table_metadata: Optional[dict[str, str]] = None @dataclass @@ -29,26 +30,52 @@ def from_fk(cls, fk: ForeignKey) -> BackupForeignKey: ) +@dataclass +class BackupRelationalJson: + original_table_name: str + original_primary_key: list[str] + original_columns: list[str] + table_name_mappings: dict[str, str] + invented_table_names: list[str] + + @dataclass class BackupRelationalData: tables: Dict[str, BackupRelationalDataTable] foreign_keys: List[BackupForeignKey] + relational_jsons: Dict[str, BackupRelationalJson] @classmethod def from_relational_data(cls, rel_data: RelationalData) -> BackupRelationalData: tables = {} foreign_keys = [] + relational_jsons = {} for table in rel_data.list_all_tables(): - tables[table] = BackupRelationalDataTable( + backup_table = BackupRelationalDataTable( primary_key=rel_data.get_primary_key(table), ) + if ( + invented_table_metadata := rel_data.get_invented_table_metadata(table) + ) is not None: + backup_table.invented_table_metadata = asdict(invented_table_metadata) + tables[table] = backup_table foreign_keys.extend( [ BackupForeignKey.from_fk(key) for key in rel_data.get_foreign_keys(table) ] ) - return BackupRelationalData(tables=tables, foreign_keys=foreign_keys) + for key, rel_json in rel_data.relational_jsons.items(): + relational_jsons[key] = BackupRelationalJson( + original_table_name=rel_json.original_table_name, + original_primary_key=rel_json.original_primary_key, + 
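# These dataclasses are persisted as plain JSON: dataclasses.asdict serializes
# them when the backup is written, and Backup.from_dict rebuilds each entry via
# BackupRelationalJson(**v). A minimal sketch of that round trip (illustrative
# only; the field values here are hypothetical):
#
#   brj = BackupRelationalJson(
#       original_table_name="purchases",
#       original_primary_key=["id"],
#       original_columns=["id", "user_id", "data"],
#       table_name_mappings={"purchases": "purchases-sfx"},
#       invented_table_names=["purchases-sfx"],
#   )
#   assert BackupRelationalJson(**asdict(brj)) == brj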
original_columns=rel_json.original_columns, + table_name_mappings=rel_json.table_name_mappings, + invented_table_names=rel_json.table_names, + ) + return BackupRelationalData( + tables=tables, foreign_keys=foreign_keys, relational_jsons=relational_jsons + ) @dataclass @@ -114,6 +141,10 @@ def from_dict(cls, b: Dict[str, Any]): ) for fk in relational_data.get("foreign_keys", []) ], + relational_jsons={ + k: BackupRelationalJson(**v) + for k, v in relational_data.get("relational_jsons", {}).items() + }, ) backup = Backup( diff --git a/src/gretel_trainer/relational/core.py b/src/gretel_trainer/relational/core.py index 47a623f5..bc3a8e4a 100644 --- a/src/gretel_trainer/relational/core.py +++ b/src/gretel_trainer/relational/core.py @@ -1,8 +1,9 @@ from __future__ import annotations -import json import logging -from dataclasses import dataclass +from contextlib import suppress +from dataclasses import dataclass, replace +from enum import Enum from pathlib import Path from typing import Any, Dict, List, Optional, Union @@ -11,6 +12,12 @@ from networkx.algorithms.dag import dag_longest_path_length, topological_sort from pandas.api.types import is_string_dtype +from gretel_trainer.relational.json import ( + IngestResponseT, + InventedTableMetadata, + RelationalJson, +) + logger = logging.getLogger(__name__) @@ -32,17 +39,80 @@ class ForeignKey: UserFriendlyPrimaryKeyT = Optional[Union[str, List[str]]] +class Scope(str, Enum): + """ + Various non-mutually-exclusive sets of tables known to the system + """ + + ALL = "all" + """ + Every known table (all user-supplied tables, all invented tables) + """ + + PUBLIC = "public" + """ + Includes all user-supplied tables, omits invented tables + """ + + MODELABLE = "modelable" + """ + Includes flat source tables and all invented tables, omits source tables that led to invented tables + """ + + EVALUATABLE = "evaluatable" + """ + A subset of MODELABLE that additionally omits invented child tables (but includes invented root tables) + """ + + INVENTED = "invented" + """ + Includes all tables invented from un-modelable user source tables + """ + + @dataclass class TableMetadata: primary_key: list[str] data: pd.DataFrame columns: set[str] + invented_table_metadata: Optional[InventedTableMetadata] = None safe_ancestral_seed_columns: Optional[set[str]] = None class RelationalData: def __init__(self): self.graph = networkx.DiGraph() + self.relational_jsons: dict[str, RelationalJson] = {} + + def restore(self, tableset: dict[str, pd.DataFrame]) -> dict[str, pd.DataFrame]: + """Restores a given tableset (presumably output from some MultiTable workflow, + i.e. transforms or synthetics) to its original shape (specifically, "re-nests" + any JSON that had been expanded out. + + Users should rely on MultiTable calling this internally when appropriate and not + need to do so themselves. + """ + restored = {} + discarded = set() + + # Restore any invented tables to nested-JSON format + for table_name, rel_json in self.relational_jsons.items(): + tables = { + table: data + for table, data in tableset.items() + if table in rel_json.table_names + } + data = rel_json.restore(tables, self) + if data is not None: + restored[table_name] = data + discarded.update(rel_json.table_names) + + # Add remaining tables + for table, data in tableset.items(): + if table not in discarded: + restored[table] = data + + return restored def add_table( self, @@ -54,10 +124,42 @@ def add_table( """ Add a table. 
The primary key can be None (if one is not defined on the table), a string column name (most common), or a list of multiple string column names (composite key). + + This call MAY result in multiple tables getting "registered," specifically if + the table includes nested JSON data. """ + primary_key = self._format_key_column(primary_key) + rj_ingest = RelationalJson.ingest(name, primary_key, data) + if rj_ingest is not None: + self._add_rel_json_and_tables(name, rj_ingest) + else: + self._add_single_table(name=name, primary_key=primary_key, data=data) + + def _add_rel_json_and_tables(self, table: str, rj_ingest: IngestResponseT) -> None: + rel_json, commands = rj_ingest + tables, foreign_keys = commands + + self.relational_jsons[table] = rel_json + + for tbl in tables: + self._add_single_table(**tbl) + for foreign_key in foreign_keys: + self.add_foreign_key_constraint(**foreign_key) + + def _add_single_table( + self, + *, + name: str, + primary_key: UserFriendlyPrimaryKeyT, + data: pd.DataFrame, + invented_table_metadata: Optional[InventedTableMetadata] = None, + ) -> None: primary_key = self._format_key_column(primary_key) metadata = TableMetadata( - primary_key=primary_key, data=data, columns=set(data.columns) + primary_key=primary_key, + data=data, + columns=set(data.columns), + invented_table_metadata=invented_table_metadata, ) self.graph.add_node(name, metadata=metadata) @@ -68,7 +170,7 @@ def set_primary_key( (Re)set the primary key on an existing table. If the table does not yet exist in the instance's collection, add it via `add_table`. """ - if table not in self.list_all_tables(): + if table not in self.list_all_tables(Scope.ALL): raise MultiTableException(f"Unrecognized table name: `{table}`") primary_key = self._format_key_column(primary_key) @@ -76,10 +178,61 @@ def set_primary_key( known_columns = self.get_table_columns(table) for col in primary_key: if col not in known_columns: - raise MultiTableException(f"Unrecognized column name: `{primary_key}`") + raise MultiTableException(f"Unrecognized column name: `{col}`") - self.graph.nodes[table]["metadata"].primary_key = primary_key - self._clear_safe_ancestral_seed_columns(table) + if self.relational_jsons.get(table) is not None: + original_data, _, original_fks = self._remove_relational_json(table) + if original_data is None: + raise MultiTableException("Original data with JSON is lost.") + + new_rj_ingest = RelationalJson.ingest(table, primary_key, original_data) + if new_rj_ingest is None: + raise MultiTableException( + "Failed to change primary key on tables invented from JSON data" + ) + + self._add_rel_json_and_tables(table, new_rj_ingest) + for fk in original_fks: + self.add_foreign_key_constraint( + table=fk.table_name, + constrained_columns=fk.columns, + referred_table=fk.parent_table_name, + referred_columns=fk.parent_columns, + ) + else: + self.graph.nodes[table]["metadata"].primary_key = primary_key + self._clear_safe_ancestral_seed_columns(table) + + def _get_user_defined_fks_to_table(self, table: str) -> list[ForeignKey]: + return [ + fk + for child in self.graph.predecessors(table) + for fk in self.get_foreign_keys(child) + if fk.parent_table_name == table and not self._is_invented(fk.table_name) + ] + + def _remove_relational_json( + self, table: str + ) -> tuple[Optional[pd.DataFrame], list[str], list[ForeignKey]]: + """ + Removes the RelationalJson instance and removes all invented tables from the graph. + + Returns the original data, primary key, and foreign keys. 
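Callers are responsible for re-adding the returned foreign keys once the table has been re-registered in its new form.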
+ """ + rel_json = self.relational_jsons[table] + + original_data = rel_json.original_data + original_primary_key = rel_json.original_primary_key + original_foreign_keys = self._get_user_defined_fks_to_table( + rel_json.root_table_name + ) + + for invented_table_name in rel_json.table_names: + with suppress(KeyError): + self.graph.remove_node(invented_table_name) + del self.relational_jsons[table] + + return original_data, original_primary_key, original_foreign_keys def _format_key_column(self, key: Optional[Union[str, List[str]]]) -> List[str]: if key is None: @@ -119,7 +272,7 @@ def add_foreign_key_constraint( """ Add a foreign key relationship between two tables. """ - known_tables = self.list_all_tables() + known_tables = self.list_all_tables(Scope.ALL) abort = False if table not in known_tables: @@ -132,6 +285,9 @@ def add_foreign_key_constraint( if abort: raise MultiTableException("Unrecognized table(s) in foreign key arguments") + table = self._get_table_in_graph(table) + referred_table = self._get_table_in_graph(referred_table) + if len(constrained_columns) != len(referred_columns): logger.warning( "Constrained and referred columns must be of the same length" @@ -191,9 +347,11 @@ def remove_foreign_key_constraint( """ Remove an existing foreign key. """ - if table not in self.list_all_tables(): + if table not in self.list_all_tables(Scope.ALL): raise MultiTableException(f"Unrecognized table name: `{table}`") + table = self._get_table_in_graph(table) + key_to_remove = None for fk in self.get_foreign_keys(table): if fk.columns == constrained_columns: @@ -214,17 +372,108 @@ def remove_foreign_key_constraint( self._clear_safe_ancestral_seed_columns(table) def update_table_data(self, table: str, data: pd.DataFrame) -> None: - try: - self.graph.nodes[table]["metadata"].data = data - self.graph.nodes[table]["metadata"].columns = set(data.columns) - self._clear_safe_ancestral_seed_columns(table) - except KeyError: - raise MultiTableException( - f"Unrecognized table name: {table}. If this is a new table to add, use `add_table`." - ) + if table in self.relational_jsons: + _, original_pk, original_fks = self._remove_relational_json(table) + new_rj_ingest = RelationalJson.ingest(table, original_pk, data) + if new_rj_ingest is not None: + self._add_rel_json_and_tables(table, new_rj_ingest) + parent_table_name = new_rj_ingest[0].root_table_name + else: + self._add_single_table( + name=table, + primary_key=original_pk, + data=data, + ) + parent_table_name = table + for fk in original_fks: + self.add_foreign_key_constraint( + table=fk.table_name, + constrained_columns=fk.columns, + referred_table=parent_table_name, + referred_columns=fk.parent_columns, + ) + else: + try: + metadata = self.graph.nodes[table]["metadata"] + except KeyError: + raise MultiTableException( + f"Unrecognized table name: {table}. If this is a new table to add, use `add_table`." 
+ ) + + if ( + new_rj_ingest := RelationalJson.ingest( + table, metadata.primary_key, data + ) + ) is not None: + original_foreign_keys = self._get_user_defined_fks_to_table(table) + self.graph.remove_node(table) + self._add_rel_json_and_tables(table, new_rj_ingest) + for fk in original_foreign_keys: + self.add_foreign_key_constraint( + table=fk.table_name, + constrained_columns=fk.columns, + referred_table=new_rj_ingest[0].root_table_name, + referred_columns=fk.parent_columns, + ) + else: + metadata.data = data + metadata.columns = set(data.columns) + self._clear_safe_ancestral_seed_columns(table) + + def list_all_tables(self, scope: Scope = Scope.MODELABLE) -> List[str]: + modelable_nodes = self.graph.nodes + + json_source_tables = [ + rel_json.original_table_name + for _, rel_json in self.relational_jsons.items() + ] + + if scope == Scope.MODELABLE: + return list(modelable_nodes) + elif scope == Scope.EVALUATABLE: + e = [] + for n in modelable_nodes: + meta = self.graph.nodes[n]["metadata"] + if ( + meta.invented_table_metadata is None + or meta.invented_table_metadata.invented_root_table_name == n + ): + e.append(n) + return e + elif scope == Scope.INVENTED: + return [n for n in modelable_nodes if self._is_invented(n)] + elif scope == Scope.ALL: + return list(modelable_nodes) + json_source_tables + elif scope == Scope.PUBLIC: + non_invented_nodes = [ + n for n in modelable_nodes if not self._is_invented(n) + ] + return json_source_tables + non_invented_nodes + + def _is_invented(self, table: str) -> bool: + return ( + table in self.graph.nodes + and self.graph.nodes[table]["metadata"].invented_table_metadata is not None + ) + + def get_public_name(self, table: str) -> Optional[str]: + if table in self.relational_jsons: + return table + + if ( + imeta := self.graph.nodes[table]["metadata"].invented_table_metadata + ) is not None: + return imeta.original_table_name - def list_all_tables(self) -> List[str]: - return list(self.graph.nodes) + return table + + def get_invented_table_metadata( + self, table: str + ) -> Optional[InventedTableMetadata]: + if table in self.relational_jsons: + return None + + return self.graph.nodes[table]["metadata"].invented_table_metadata def get_parents(self, table: str) -> List[str]: return list(self.graph.successors(table)) @@ -265,16 +514,36 @@ def list_tables_parents_before_children(self) -> List[str]: return list(reversed(list(topological_sort(self.graph)))) def get_primary_key(self, table: str) -> List[str]: - return self.graph.nodes[table]["metadata"].primary_key + try: + return self.graph.nodes[table]["metadata"].primary_key + except KeyError: + if table in self.relational_jsons: + return self.relational_jsons[table].original_primary_key + else: + raise MultiTableException(f"Unrecognized table: `{table}`") def get_table_data( self, table: str, usecols: Optional[set[str]] = None ) -> pd.DataFrame: usecols = usecols or self.get_table_columns(table) - return self.graph.nodes[table]["metadata"].data[list(usecols)] + try: + return self.graph.nodes[table]["metadata"].data[list(usecols)] + except KeyError: + if table in self.relational_jsons: + if (df := self.relational_jsons[table].original_data) is None: + raise MultiTableException("Original data with JSON is lost.") + return df + else: + raise MultiTableException(f"Unrecognized table: `{table}`") def get_table_columns(self, table: str) -> set[str]: - return self.graph.nodes[table]["metadata"].columns + try: + return self.graph.nodes[table]["metadata"].columns + except KeyError: + if table in 
self.relational_jsons: + return set(self.relational_jsons[table].original_columns) + else: + raise MultiTableException(f"Unrecognized table: `{table}`") def get_safe_ancestral_seed_columns(self, table: str) -> set[str]: safe_columns = self.graph.nodes[table]["metadata"].safe_ancestral_seed_columns @@ -303,12 +572,39 @@ def _set_safe_ancestral_seed_columns(self, table: str) -> set[str]: def _clear_safe_ancestral_seed_columns(self, table: str) -> None: self.graph.nodes[table]["metadata"].safe_ancestral_seed_columns = None - def get_foreign_keys(self, table: str) -> List[ForeignKey]: + def _get_table_in_graph(self, table: str) -> str: + if table in self.relational_jsons: + table = self.relational_jsons[table].root_table_name + return table + + def get_foreign_keys( + self, table: str, rename_invented_tables: bool = False + ) -> List[ForeignKey]: + def _rename_invented(fk: ForeignKey) -> ForeignKey: + table_name = fk.table_name + parent_table_name = fk.parent_table_name + if self._is_invented(table_name): + table_name = self.graph.nodes[table_name][ + "metadata" + ].invented_table_metadata.original_table_name + if self._is_invented(parent_table_name): + parent_table_name = self.graph.nodes[parent_table_name][ + "metadata" + ].invented_table_metadata.original_table_name + return replace( + fk, table_name=table_name, parent_table_name=parent_table_name + ) + + table = self._get_table_in_graph(table) foreign_keys = [] for parent in self.get_parents(table): fks = self.graph.edges[table, parent]["via"] foreign_keys.extend(fks) - return foreign_keys + + if rename_invented_tables: + return [_rename_invented(fk) for fk in foreign_keys] + else: + return foreign_keys def get_all_key_columns(self, table: str) -> List[str]: all_key_cols = [] @@ -319,14 +615,17 @@ def get_all_key_columns(self, table: str) -> List[str]: def debug_summary(self) -> Dict[str, Any]: max_depth = dag_longest_path_length(self.graph) - all_tables = self.list_all_tables() - table_count = len(all_tables) + public_table_count = len(self.list_all_tables(Scope.PUBLIC)) + invented_table_count = len(self.list_all_tables(Scope.INVENTED)) + + all_tables = self.list_all_tables(Scope.ALL) total_foreign_key_count = 0 tables = {} for table in all_tables: this_table_foreign_key_count = 0 foreign_keys = [] - for key in self.get_foreign_keys(table): + fk_lookup_table_name = self._get_table_in_graph(table) + for key in self.get_foreign_keys(fk_lookup_table_name): total_foreign_key_count = total_foreign_key_count + 1 this_table_foreign_key_count = this_table_foreign_key_count + 1 foreign_keys.append( @@ -337,68 +636,20 @@ def debug_summary(self) -> Dict[str, Any]: } ) tables[table] = { - "column_count": len(self.get_table_data(table).columns), + "column_count": len(self.get_table_columns(table)), "primary_key": self.get_primary_key(table), "foreign_key_count": this_table_foreign_key_count, "foreign_keys": foreign_keys, + "is_invented_table": self._is_invented(table), } return { "foreign_key_count": total_foreign_key_count, "max_depth": max_depth, - "table_count": table_count, "tables": tables, + "public_table_count": public_table_count, + "invented_table_count": invented_table_count, } - def as_dict(self, out_dir: str) -> Dict[str, Any]: - d = {"tables": {}, "foreign_keys": []} - for table in self.list_all_tables(): - d["tables"][table] = { - "primary_key": self.get_primary_key(table), - "csv_path": f"{out_dir}/{table}.csv", - } - keys = [ - { - "table": table, - "constrained_columns": key.columns, - "referred_table": key.parent_table_name, - 
"referred_columns": key.parent_columns, - } - for key in self.get_foreign_keys(table) - ] - d["foreign_keys"].extend(keys) - return d - - def to_filesystem(self, out_dir: str) -> str: - d = self.as_dict(out_dir) - for table_name, details in d["tables"].items(): - self.get_table_data(table_name).to_csv(details["csv_path"], index=False) - metadata_path = f"{out_dir}/metadata.json" - with open(metadata_path, "w") as metadata_file: - json.dump(d, metadata_file) - return metadata_path - - @classmethod - def from_filesystem(cls, metadata_filepath: str) -> RelationalData: - with open(metadata_filepath, "r") as metadata_file: - d = json.load(metadata_file) - relational_data = RelationalData() - - for table_name, details in d["tables"].items(): - primary_key = details["primary_key"] - data = pd.read_csv(details["csv_path"]) - relational_data.add_table( - name=table_name, primary_key=primary_key, data=data - ) - for foreign_key in d["foreign_keys"]: - relational_data.add_foreign_key_constraint( - table=foreign_key["table"], - constrained_columns=foreign_key["constrained_columns"], - referred_table=foreign_key["referred_table"], - referred_columns=foreign_key["referred_columns"], - ) - - return relational_data - def _ok_for_train_and_seed(col: str, df: pd.DataFrame) -> bool: if _is_highly_nan(col, df): diff --git a/src/gretel_trainer/relational/json.py b/src/gretel_trainer/relational/json.py new file mode 100644 index 00000000..be74f213 --- /dev/null +++ b/src/gretel_trainer/relational/json.py @@ -0,0 +1,340 @@ +from __future__ import annotations + +import hashlib +import logging +import re +from dataclasses import dataclass +from json import JSONDecodeError, loads +from typing import Any, Optional, Protocol, Union + +import numpy as np +import pandas as pd +from unflatten import unflatten + +logger = logging.getLogger(__name__) + +# JSON dict to multi-column and list to multi-table + +FIELD_SEPARATOR = ">" +TABLE_SEPARATOR = "^" +ID_SUFFIX = "~id" +ORDER_COLUMN = "array~order" +CONTENT_COLUMN = "content" +PRIMARY_KEY_COLUMN = "~PRIMARY_KEY_ID~" + + +def load_json(obj: Any) -> Union[dict, list]: + if isinstance(obj, (dict, list)): + return obj + else: + return loads(obj) + + +def is_json(obj: Any, json_type=(dict, list)) -> bool: + try: + obj = load_json(obj) + except (ValueError, TypeError, JSONDecodeError): + return False + else: + return isinstance(obj, json_type) + + +def is_dict(obj: Any) -> bool: + return is_json(obj, dict) + + +def is_list(obj: Any) -> bool: + return isinstance(obj, np.ndarray) or is_json(obj, list) + + +def pandas_json_normalize(series: pd.Series) -> pd.DataFrame: + return pd.json_normalize(series.apply(load_json).to_list(), sep=FIELD_SEPARATOR) + + +def nulls_to_empty_dicts(df: pd.DataFrame) -> pd.DataFrame: + return df.applymap(lambda x: {} if pd.isnull(x) else x) + + +def nulls_to_empty_lists(series: pd.Series) -> pd.Series: + return series.apply(lambda x: x if isinstance(x, list) or not pd.isnull(x) else []) + + +def _normalize_json( + nested_dfs: list[tuple[str, pd.DataFrame]], flat_dfs: list[tuple[str, pd.DataFrame]] +) -> list[tuple[str, pd.DataFrame]]: + if not nested_dfs: + return flat_dfs + name, df = nested_dfs.pop() + dict_cols = [ + col + for col in df.columns + if df[col].apply(is_dict).any() and df[col].dropna().apply(is_dict).all() + ] + list_cols = [ + col + for col in df.columns + if df[col].apply(is_list).any() and df[col].dropna().apply(is_list).all() + ] + if dict_cols: + df[dict_cols] = nulls_to_empty_dicts(df[dict_cols]) + for col in dict_cols: + 
new_cols = pandas_json_normalize(df[col]).add_prefix(col + FIELD_SEPARATOR) + df = pd.concat([df, new_cols], axis="columns") + df = df.drop(columns=new_cols.columns[new_cols.isnull().all()]) + nested_dfs.append((name, df.drop(columns=dict_cols))) + elif list_cols: + for col in list_cols: + new_table = df[col].explode().dropna().rename(CONTENT_COLUMN).to_frame() + new_table[ORDER_COLUMN] = new_table.groupby(level=0).cumcount() + nested_dfs.append( + ( + name + TABLE_SEPARATOR + col, + new_table.reset_index(names=name + ID_SUFFIX), + ) + ) + nested_dfs.append((name, df.drop(columns=list_cols))) + else: + flat_dfs.append((name, df)) + return _normalize_json(nested_dfs, flat_dfs) + + +# Multi-table and multi-column back to single-table with JSON + + +def get_id_columns(df: pd.DataFrame) -> list[str]: + return [col for col in df.columns if col.endswith(ID_SUFFIX)] + + +def get_parent_table_name_from_child_id_column(id_column_name: str) -> str: + return id_column_name[: -len(ID_SUFFIX)] + + +def get_parent_column_name_from_child_table_name(table_name: str) -> str: + return table_name.split(TABLE_SEPARATOR)[-1] + + +def _is_invented_child_table(table: str, rel_data: _RelationalData) -> bool: + imeta = rel_data.get_invented_table_metadata(table) + return imeta is not None and imeta.invented_root_table_name != table + + +def sanitize_str(s): + sanitized_str = "-".join(re.findall(r"[a-zA-Z_0-9]+", s)) + # Generate suffix from original string, in case of sanitized_str collision + unique_suffix = make_suffix(s) + return f"{sanitized_str}-{unique_suffix}" + + +def make_suffix(s): + return hashlib.sha256(s.encode("utf-8")).hexdigest()[:10] + + +class _RelationalData(Protocol): + def get_foreign_keys( + self, table: str + ) -> list: # can't specify element type (ForeignKey) without cyclic dependency + ... + + def get_table_columns(self, table: str) -> set[str]: + ... + + def get_invented_table_metadata( + self, table: str + ) -> Optional[InventedTableMetadata]: + ... + + +@dataclass +class InventedTableMetadata: + invented_root_table_name: str + original_table_name: str + + +class RelationalJson: + def __init__( + self, + original_table_name: str, + original_primary_key: list[str], + original_columns: list[str], + original_data: Optional[pd.DataFrame], + table_name_mappings: dict[str, str], + ): + self.original_table_name = original_table_name + self.original_primary_key = original_primary_key + self.original_columns = original_columns + self.original_data = original_data + self.table_name_mappings = table_name_mappings + + @classmethod + def ingest( + cls, table_name: str, primary_key: list[str], df: pd.DataFrame + ) -> Optional[IngestResponseT]: + logger.debug(f"Checking table `{table_name}` for JSON columns") + tables = _normalize_json([(table_name, df.copy())], []) + # If we created additional tables (from JSON lists) or added columns (from JSON dicts) + if len(tables) > 1 or len(tables[0][1].columns) > len(df.columns): + mappings = {name: sanitize_str(name) for name, _ in tables} + logger.info( + f"Found JSON data in table `{table_name}`, transformed into {len(mappings)} tables for modeling." 
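# The condition above is the JSON-detection heuristic: normalization either
# split off child tables (from JSON lists) or widened the root table with
# flattened columns (from JSON dicts). A table exhibiting neither is left
# untouched, and ingest falls through to return None.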
+ ) + logger.debug(f"Invented table names: {list(mappings.values())}") + rel_json = RelationalJson( + original_table_name=table_name, + original_primary_key=primary_key, + original_data=df, + original_columns=list(df.columns), + table_name_mappings=mappings, + ) + commands = _generate_commands(rel_json, tables) + return (rel_json, commands) + + @property + def root_table_name(self) -> str: + return self.table_name_mappings[self.original_table_name] + + @property + def table_names(self) -> list[str]: + return list(self.table_name_mappings.values()) + + @property + def inverse_table_name_mappings(self) -> dict[str, str]: + # Keys are sanitized, model-friendly names + # Values are "provenance" names (a^b>c) or the original table name + return {value: key for key, value in self.table_name_mappings.items()} + + def restore( + self, tables: dict[str, pd.DataFrame], rel_data: _RelationalData + ) -> Optional[pd.DataFrame]: + """Reduces a set of tables (assumed to match the shapes created on initialization) + to a single table matching the shape of the original source table + """ + # If the root invented table failed, we are completely out of luck + # (Missing invented child tables can be replaced with empty lists so we at least provide _something_) + if self.root_table_name not in tables: + logger.warning( + f"Cannot restore nested JSON data: root invented table `{self.root_table_name}` is missing from output tables." + ) + return None + + return self._denormalize_json(tables, rel_data)[self.original_columns] + + def _denormalize_json( + self, tables: dict[str, pd.DataFrame], rel_data: _RelationalData + ) -> pd.DataFrame: + table_dict: dict = { + self.inverse_table_name_mappings[k]: v for k, v in tables.items() + } + for table_name in list(reversed(self.table_names)): + table_provenance_name = self.inverse_table_name_mappings[table_name] + empty_fallback = pd.DataFrame( + data={col: [] for col in rel_data.get_table_columns(table_name)}, + ) + table_df = table_dict.get(table_provenance_name, empty_fallback) + + if table_df.empty and _is_invented_child_table(table_name, rel_data): + p_name = rel_data.get_foreign_keys(table_name)[0].parent_table_name + parent_name = self.inverse_table_name_mappings[p_name] + col_name = get_parent_column_name_from_child_table_name( + table_provenance_name + ) + kwargs = {col_name: table_dict[parent_name].apply(lambda x: [], axis=1)} + table_dict[parent_name] = table_dict[parent_name].assign(**kwargs) + else: + col_names = [col for col in table_df.columns if FIELD_SEPARATOR in col] + new_col_names = [col.replace(FIELD_SEPARATOR, ".") for col in col_names] + flat_df = table_df[col_names].rename( + columns=dict(zip(col_names, new_col_names)) + ) + flat_dict = { + k: { + k1: v1 + for k1, v1 in v.items() + if v1 is not np.nan and v1 is not None + } + for k, v in flat_df.to_dict("index").items() + } + dict_df = nulls_to_empty_dicts( + pd.DataFrame.from_dict( + {k: unflatten(v) for k, v in flat_dict.items()}, orient="index" + ) + ) + nested_df = table_df.join(dict_df).drop(columns=col_names) + if _is_invented_child_table(table_name, rel_data): + # we know there is exactly one foreign key on invented child tables... 
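# To re-nest a list column, child rows are grouped by that foreign key,
# ordered by their original array position, and aggregated back into one
# Python list per parent row; conceptually (mirroring the code below):
#
#   child.sort_values(ORDER_COLUMN).groupby(fk_col)[CONTENT_COLUMN].agg(list)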
+ fk = rel_data.get_foreign_keys(table_name)[0] + # ...with exactly one column + fk_col = fk.columns[0] + p_name = fk.parent_table_name + parent_name = self.inverse_table_name_mappings[p_name] + nested_df = ( + nested_df.sort_values(ORDER_COLUMN) + .groupby(fk_col)[CONTENT_COLUMN] + .agg(list) + ) + col_name = get_parent_column_name_from_child_table_name( + table_provenance_name + ) + table_dict[parent_name] = table_dict[parent_name].join( + nested_df.rename(col_name) + ) + table_dict[parent_name][col_name] = nulls_to_empty_lists( + table_dict[parent_name][col_name] + ) + table_dict[table_provenance_name] = nested_df + return table_dict[self.original_table_name] + + +def _generate_commands( + rel_json: RelationalJson, tables: list[tuple[str, pd.DataFrame]] +) -> CommandsT: + """ + Returns lists of keyword arguments designed to be passed to a + RelationalData instance's _add_single_table and add_foreign_key methods + """ + tables = [(rel_json.table_name_mappings[name], df) for name, df in tables] + non_empty_tables = [t for t in tables if not t[1].empty] + + _add_single_table = [] + add_foreign_key = [] + + for table_name, table_df in non_empty_tables: + if table_name == rel_json.root_table_name: + table_pk = rel_json.original_primary_key + [PRIMARY_KEY_COLUMN] + else: + table_pk = [PRIMARY_KEY_COLUMN] + table_df.index.rename(PRIMARY_KEY_COLUMN, inplace=True) + table_df.reset_index(inplace=True) + invented_root_table_name = rel_json.table_name_mappings[ + rel_json.original_table_name + ] + metadata = InventedTableMetadata( + invented_root_table_name=invented_root_table_name, + original_table_name=rel_json.original_table_name, + ) + _add_single_table.append( + { + "name": table_name, + "primary_key": table_pk, + "data": table_df, + "invented_table_metadata": metadata, + } + ) + + for table_name, table_df in non_empty_tables: + for column in get_id_columns(table_df): + referred_table = rel_json.table_name_mappings[ + get_parent_table_name_from_child_id_column(column) + ] + add_foreign_key.append( + { + "table": table_name, + "constrained_columns": [column], + "referred_table": referred_table, + "referred_columns": [PRIMARY_KEY_COLUMN], + } + ) + return (_add_single_table, add_foreign_key) + + +CommandsT = tuple[list[dict], list[dict]] +IngestResponseT = tuple[RelationalJson, CommandsT] diff --git a/src/gretel_trainer/relational/log.py b/src/gretel_trainer/relational/log.py index ee777cd7..ea50840f 100644 --- a/src/gretel_trainer/relational/log.py +++ b/src/gretel_trainer/relational/log.py @@ -20,6 +20,11 @@ logging.root.removeHandler(root_handler) +def set_log_level(level: str): + logger = logging.getLogger(RELATIONAL) + logger.setLevel(level) + + @contextmanager def silent_logs(): logger = logging.getLogger(RELATIONAL) diff --git a/src/gretel_trainer/relational/multi_table.py b/src/gretel_trainer/relational/multi_table.py index 3dfe261c..d6c397f1 100644 --- a/src/gretel_trainer/relational/multi_table.py +++ b/src/gretel_trainer/relational/multi_table.py @@ -32,7 +32,9 @@ GretelModelConfig, MultiTableException, RelationalData, + Scope, ) +from gretel_trainer.relational.json import InventedTableMetadata, RelationalJson from gretel_trainer.relational.log import silent_logs from gretel_trainer.relational.model_config import ( make_classify_config, @@ -102,7 +104,7 @@ def __init__( self._synthetics_train = SyntheticsTrain() self._synthetics_run: Optional[SyntheticsRun] = None self.synthetic_output_tables: Dict[str, pd.DataFrame] = {} - self.evaluations = defaultdict(lambda: TableEvaluation()) + 
self._evaluations = defaultdict(lambda: TableEvaluation()) if backup is None: self._complete_fresh_init(project_display_name) @@ -146,8 +148,17 @@ def _complete_init_from_backup(self, backup: Backup) -> None: tar.extractall(path=self._working_dir) for table_name, table_backup in backup.relational_data.tables.items(): source_data = pd.read_csv(self._working_dir / f"source_{table_name}.csv") - self.relational_data.add_table( - name=table_name, primary_key=table_backup.primary_key, data=source_data + invented_table_metadata = None + if (imeta := table_backup.invented_table_metadata) is not None: + invented_table_metadata = InventedTableMetadata( + invented_root_table_name=imeta["invented_root_table_name"], + original_table_name=imeta["original_table_name"], + ) + self.relational_data._add_single_table( + name=table_name, + primary_key=table_backup.primary_key, + data=source_data, + invented_table_metadata=invented_table_metadata, ) for fk_backup in backup.relational_data.foreign_keys: self.relational_data.add_foreign_key_constraint( @@ -156,6 +167,15 @@ def _complete_init_from_backup(self, backup: Backup) -> None: referred_table=fk_backup.referred_table, referred_columns=fk_backup.referred_columns, ) + for key, rel_json_backup in backup.relational_data.relational_jsons.items(): + relational_json = RelationalJson( + original_table_name=rel_json_backup.original_table_name, + original_primary_key=rel_json_backup.original_primary_key, + original_columns=rel_json_backup.original_columns, + original_data=None, + table_name_mappings=rel_json_backup.table_name_mappings, + ) + self.relational_data.relational_jsons[key] = relational_json # Debug summary debug_summary_id = backup.artifact_collection.gretel_debug_summary @@ -258,15 +278,16 @@ def _complete_init_from_backup(self, backup: Backup) -> None: if model.status == Status.COMPLETED ] for table in training_succeeded: - model = self._synthetics_train.models[table] - with silent_logs(): - self._strategy.update_evaluation_from_model( - table, - self.evaluations, - model, - self._working_dir, - self._extended_sdk, - ) + if table in self.relational_data.list_all_tables(Scope.EVALUATABLE): + model = self._synthetics_train.models[table] + with silent_logs(): + self._strategy.update_evaluation_from_model( + table, + self._evaluations, + model, + self._working_dir, + self._extended_sdk, + ) training_failed = [ table @@ -459,6 +480,16 @@ def _build_backup(self) -> Backup: def _hybrid(self) -> bool: return get_session_config().default_runner == RunnerMode.HYBRID + @property + def evaluations(self) -> dict[str, TableEvaluation]: + evaluations = defaultdict(lambda: TableEvaluation()) + + for table, evaluation in self._evaluations.items(): + if (public_name := self.relational_data.get_public_name(table)) is not None: + evaluations[public_name] = evaluation + + return evaluations + def _set_refresh_interval(self, interval: Optional[int]) -> None: if interval is None: self._refresh_interval = 60 @@ -617,7 +648,9 @@ def run_transforms( self.relational_data.update_table_data(table_name, transformed_table) self._upload_sources_to_project() - for table, df in output_tables.items(): + reshaped_tables = self.relational_data.restore(output_tables) + + for table, df in reshaped_tables.items(): filename = f"transformed_{table}.csv" out_path = run_dir / filename df.to_csv(out_path, index=False) @@ -629,7 +662,7 @@ def run_transforms( self._project, str(archive_path) ) self._backup() - self.transform_output_tables = output_tables + self.transform_output_tables = 
reshaped_tables def _prepare_training_data(self, tables: List[str]) -> Dict[str, Path]: """ @@ -666,10 +699,15 @@ def _train_synthetics_models(self, training_data: Dict[str, Path]) -> None: run_task(task, self._extended_sdk) for table in task.completed: - model = self._synthetics_train.models[table] - self._strategy.update_evaluation_from_model( - table, self.evaluations, model, self._working_dir, self._extended_sdk - ) + if table in self.relational_data.list_all_tables(Scope.EVALUATABLE): + model = self._synthetics_train.models[table] + self._strategy.update_evaluation_from_model( + table, + self._evaluations, + model, + self._working_dir, + self._extended_sdk, + ) # TODO: consider moving this to before running the task archive_path = self._working_dir / "synthetics_training.tar.gz" @@ -712,7 +750,7 @@ def retrain_tables(self, tables: Dict[str, pd.DataFrame]) -> None: def _upload_sources_to_project(self) -> None: archive_path = self._working_dir / "source_tables.tar.gz" with tarfile.open(archive_path, "w:gz") as tar: - for table in self.relational_data.list_all_tables(): + for table in self.relational_data.list_all_tables(Scope.ALL): filename = f"source_{table}.csv" out_path = self._working_dir / filename df = self.relational_data.get_table_data(table) @@ -799,7 +837,9 @@ def generate( record_size_ratio=self._synthetics_run.record_size_ratio, ) - for table, synth_df in output_tables.items(): + reshaped_tables = self.relational_data.restore(output_tables) + + for table, synth_df in reshaped_tables.items(): synth_csv_path = run_dir / f"synth_{table}.csv" synth_df.to_csv(synth_csv_path, index=False) @@ -811,6 +851,9 @@ def generate( if table in self._synthetics_run.preserved: continue + if table not in self.relational_data.list_all_tables(Scope.EVALUATABLE): + continue + evaluate_data = self._strategy.get_evaluate_model_data( rel_data=self.relational_data, table_name=table, @@ -832,11 +875,12 @@ def generate( ) run_task(synthetics_evaluate_task, self._extended_sdk) + # Tables passed to task were already scoped to evaluatable tables for table in synthetics_evaluate_task.completed: self._strategy.update_evaluation_from_evaluate( table_name=table, evaluate_model=evaluate_models[table], - evaluations=self.evaluations, + evaluations=self._evaluations, working_dir=self._working_dir, extended_sdk=self._extended_sdk, ) @@ -865,7 +909,7 @@ def generate( self._artifact_collection.upload_synthetics_outputs_archive( self._project, str(archive_path) ) - self.synthetic_output_tables = output_tables + self.synthetic_output_tables = reshaped_tables self._backup() def create_relational_report(self, run_identifier: str, target_dir: Path) -> None: @@ -910,8 +954,8 @@ def _attach_existing_reports(self, run_id: str, table: str) -> None: individual_report_json = json.loads(smart_open.open(individual_path).read()) cross_table_report_json = json.loads(smart_open.open(cross_table_path).read()) - self.evaluations[table].individual_report_json = individual_report_json - self.evaluations[table].cross_table_report_json = cross_table_report_json + self._evaluations[table].individual_report_json = individual_report_json + self._evaluations[table].cross_table_report_json = cross_table_report_json def _validate_gretel_model(self, gretel_model: Optional[str]) -> Tuple[str, str]: gretel_model = (gretel_model or self._strategy.default_model).lower() diff --git a/src/gretel_trainer/relational/report/report.py b/src/gretel_trainer/relational/report/report.py index 8f11b2f4..1036f2c3 100644 --- 
a/src/gretel_trainer/relational/report/report.py +++ b/src/gretel_trainer/relational/report/report.py @@ -10,7 +10,7 @@ import plotly.graph_objects as go from jinja2 import Environment, FileSystemLoader -from gretel_trainer.relational.core import ForeignKey, RelationalData +from gretel_trainer.relational.core import ForeignKey, RelationalData, Scope from gretel_trainer.relational.report.figures import ( PRIVACY_LEVEL_VALUES, gauge_and_needle_chart, @@ -135,9 +135,9 @@ def composite_ppl_figure(self) -> go.Figure: @property def report_table_data(self) -> List[ReportTableData]: table_data = [] - for table in self.rel_data.list_all_tables(): + for table in self.rel_data.list_all_tables(Scope.PUBLIC): pk = self.rel_data.get_primary_key(table) - fks = self.rel_data.get_foreign_keys(table) + fks = self.rel_data.get_foreign_keys(table, rename_invented_tables=True) table_data.append(ReportTableData(table=table, pk=pk, fks=fks)) # Sort tables alphabetically because that's nice. diff --git a/tests/relational/conftest.py b/tests/relational/conftest.py index 83b65e0c..873758d4 100644 --- a/tests/relational/conftest.py +++ b/tests/relational/conftest.py @@ -19,6 +19,13 @@ def extended_sdk(): return ExtendedGretelSDK(hybrid=False) +@pytest.fixture(autouse=True) +def static_suffix(): + with patch("gretel_trainer.relational.json.make_suffix") as make_suffix: + make_suffix.return_value = "sfx" + yield + + @pytest.fixture() def project(): with patch( @@ -77,6 +84,11 @@ def art() -> RelationalData: return rel_data_from_example_db("art") +@pytest.fixture() +def documents() -> RelationalData: + return rel_data_from_example_db("documents") + + @pytest.fixture() def trips() -> RelationalData: rel_data = rel_data_from_example_db("trips") diff --git a/tests/relational/example_dbs/documents.sql b/tests/relational/example_dbs/documents.sql new file mode 100644 index 00000000..ebc861bf --- /dev/null +++ b/tests/relational/example_dbs/documents.sql @@ -0,0 +1,48 @@ +create table if not exists users ( + id integer primary key, + name text not null +); +delete from users; + +create table if not exists purchases ( + id integer primary key, + user_id integer not null, + data text not null, + -- + foreign key (user_id) references users (id) +); +delete from purchases; + +create table if not exists payments ( + id integer primary key, + purchase_id integer not null, + amount integer not null, + -- + foreign key (purchase_id) references purchases (id) +); +delete from payments; + +insert into users (id, name) values + (1, "Andy"), + (2, "Bob"), + (3, "Charlie"), + (4, "David"); + +insert into purchases (id, user_id, data) values + (1, 1, '{"item": "pen", "cost": 100, "details": {"color": "red"}, "years": [2023]}'), + (2, 2, '{"item": "paint", "cost": 100, "details": {"color": "red"}, "years": [2023, 2022]}'), + (3, 2, '{"item": "ink", "cost": 100, "details": {"color": "red"}, "years": [2020, 2019]}'), + (4, 3, '{"item": "pen", "cost": 100, "details": {"color": "blue"}, "years": []}'), + (5, 3, '{"item": "paint", "cost": 100, "details": {"color": "blue"}, "years": [2021]}'), + (6, 3, '{"item": "ink", "cost": 100, "details": {"color": "blue"}, "years": []}'); + +insert into payments (id, purchase_id, amount) values + (1, 1, 42), + (2, 1, 42), + (3, 2, 42), + (4, 2, 42), + (5, 2, 42), + (6, 3, 42), + (7, 4, 42), + (8, 4, 42), + (9, 5, 42); diff --git a/tests/relational/test_backup.py b/tests/relational/test_backup.py index a4b12931..a6b31ae3 100644 --- a/tests/relational/test_backup.py +++ b/tests/relational/test_backup.py @@ 
-8,9 +8,11 @@ BackupGenerate, BackupRelationalData, BackupRelationalDataTable, + BackupRelationalJson, BackupSyntheticsTrain, BackupTransformsTrain, ) +from gretel_trainer.relational.json import InventedTableMetadata def test_backup_relational_data(trips): @@ -27,11 +29,69 @@ def test_backup_relational_data(trips): referred_columns=["id"], ) ], + relational_jsons={}, ) assert BackupRelationalData.from_relational_data(trips) == expected +def test_backup_relational_data_with_json(documents): + expected = BackupRelationalData( + tables={ + "users": BackupRelationalDataTable(primary_key=["id"]), + "purchases-sfx": BackupRelationalDataTable( + primary_key=["id", "~PRIMARY_KEY_ID~"], + invented_table_metadata={ + "invented_root_table_name": "purchases-sfx", + "original_table_name": "purchases", + }, + ), + "purchases-data-years-sfx": BackupRelationalDataTable( + primary_key=["~PRIMARY_KEY_ID~"], + invented_table_metadata={ + "invented_root_table_name": "purchases-sfx", + "original_table_name": "purchases", + }, + ), + "payments": BackupRelationalDataTable(primary_key=["id"]), + }, + foreign_keys=[ + BackupForeignKey( + table="payments", + constrained_columns=["purchase_id"], + referred_table="purchases-sfx", + referred_columns=["id"], + ), + BackupForeignKey( + table="purchases-sfx", + constrained_columns=["user_id"], + referred_table="users", + referred_columns=["id"], + ), + BackupForeignKey( + table="purchases-data-years-sfx", + constrained_columns=["purchases~id"], + referred_table="purchases-sfx", + referred_columns=["~PRIMARY_KEY_ID~"], + ), + ], + relational_jsons={ + "purchases": BackupRelationalJson( + original_table_name="purchases", + original_primary_key=["id"], + original_columns=["id", "user_id", "data"], + table_name_mappings={ + "purchases": "purchases-sfx", + "purchases^data>years": "purchases-data-years-sfx", + }, + invented_table_names=["purchases-sfx", "purchases-data-years-sfx"], + ), + }, + ) + + assert BackupRelationalData.from_relational_data(documents) == expected + + def test_backup(): backup_relational = BackupRelationalData( tables={ @@ -50,6 +110,7 @@ def test_backup(): referred_columns=["id"], ) ], + relational_jsons={}, ) backup_classify = BackupClassify( model_ids={ diff --git a/tests/relational/test_connectors.py b/tests/relational/test_connectors.py index cca417ae..dbee2014 100644 --- a/tests/relational/test_connectors.py +++ b/tests/relational/test_connectors.py @@ -4,7 +4,7 @@ import pytest from gretel_trainer.relational.connectors import sqlite_conn -from gretel_trainer.relational.core import MultiTableException +from gretel_trainer.relational.core import MultiTableException, Scope def test_extract_subsets_of_relational_data(example_dbs): @@ -25,8 +25,8 @@ def test_extract_subsets_of_relational_data(example_dbs): ) expected_tables = {"users", "events", "products"} - assert set(only.list_all_tables()) == expected_tables - assert set(ignore.list_all_tables()) == expected_tables + assert set(only.list_all_tables(Scope.ALL)) == expected_tables + assert set(ignore.list_all_tables(Scope.ALL)) == expected_tables # `products` has a foreign key to `distribution_center` in the source, but because the # latter table was not extracted, the relationship is not recognized diff --git a/tests/relational/test_relational_data.py b/tests/relational/test_relational_data.py index 38d24f8c..755bf8b1 100644 --- a/tests/relational/test_relational_data.py +++ b/tests/relational/test_relational_data.py @@ -222,111 +222,19 @@ def in_order(col, t1, t2): assert in_order(tables, "users", 
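# (in_order, defined above, presumably checks that the two named tables
# appear in the given relative order within `tables`.)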
"order_items") -def test_relational_data_as_dict(ecom): - as_dict = ecom.as_dict("test_out") - - assert as_dict["tables"] == { - "users": {"primary_key": ["id"], "csv_path": "test_out/users.csv"}, - "events": {"primary_key": ["id"], "csv_path": "test_out/events.csv"}, - "distribution_center": { - "primary_key": ["id"], - "csv_path": "test_out/distribution_center.csv", - }, - "products": {"primary_key": ["id"], "csv_path": "test_out/products.csv"}, - "inventory_items": { - "primary_key": ["id"], - "csv_path": "test_out/inventory_items.csv", - }, - "order_items": {"primary_key": ["id"], "csv_path": "test_out/order_items.csv"}, - } - expected_foreign_keys = [ - { - "table": "events", - "constrained_columns": ["user_id"], - "referred_table": "users", - "referred_columns": ["id"], - }, - { - "table": "order_items", - "constrained_columns": ["user_id"], - "referred_table": "users", - "referred_columns": ["id"], - }, - { - "table": "order_items", - "constrained_columns": ["inventory_item_id"], - "referred_table": "inventory_items", - "referred_columns": ["id"], - }, - { - "table": "inventory_items", - "constrained_columns": ["product_id"], - "referred_table": "products", - "referred_columns": ["id"], - }, - { - "table": "inventory_items", - "constrained_columns": ["product_distribution_center_id"], - "referred_table": "distribution_center", - "referred_columns": ["id"], - }, - { - "table": "products", - "constrained_columns": ["distribution_center_id"], - "referred_table": "distribution_center", - "referred_columns": ["id"], - }, - ] - for expected_fk in expected_foreign_keys: - assert expected_fk in as_dict["foreign_keys"] - - -def test_ecommerce_filesystem_serde(ecom): - with tempfile.TemporaryDirectory() as tmp: - ecom.to_filesystem(tmp) - - expected_files = [ - f"{tmp}/metadata.json", - f"{tmp}/events.csv", - f"{tmp}/users.csv", - f"{tmp}/distribution_center.csv", - f"{tmp}/products.csv", - f"{tmp}/inventory_items.csv", - f"{tmp}/order_items.csv", - ] - for expected_file in expected_files: - assert os.path.exists(expected_file) - - from_json = RelationalData.from_filesystem(f"{tmp}/metadata.json") - - for table in ecom.list_all_tables(): - assert set(ecom.get_table_data(table).columns) == set( - from_json.get_table_data(table).columns - ) - assert ecom.get_parents(table) == from_json.get_parents(table) - assert ecom.get_foreign_keys(table) == from_json.get_foreign_keys(table) - - -def test_filesystem_serde_accepts_composite_primary_keys(mutagenesis): - with tempfile.TemporaryDirectory() as tmp: - mutagenesis.to_filesystem(tmp) - from_json = RelationalData.from_filesystem(f"{tmp}/metadata.json") - - assert from_json.get_primary_key("bond") == ["atom1_id", "atom2_id"] - assert from_json.get_primary_key("atom") == ["atom_id"] - - def test_debug_summary(ecom, mutagenesis): assert ecom.debug_summary() == { "foreign_key_count": 6, "max_depth": 3, - "table_count": 6, + "public_table_count": 6, + "invented_table_count": 0, "tables": { "users": { "column_count": 3, "primary_key": ["id"], "foreign_key_count": 0, "foreign_keys": [], + "is_invented_table": False, }, "events": { "column_count": 4, @@ -339,12 +247,14 @@ def test_debug_summary(ecom, mutagenesis): "parent_columns": ["id"], } ], + "is_invented_table": False, }, "distribution_center": { "column_count": 2, "primary_key": ["id"], "foreign_key_count": 0, "foreign_keys": [], + "is_invented_table": False, }, "products": { "column_count": 4, @@ -357,6 +267,7 @@ def test_debug_summary(ecom, mutagenesis): "parent_columns": ["id"], } ], + 
"is_invented_table": False, }, "inventory_items": { "column_count": 5, @@ -374,6 +285,7 @@ def test_debug_summary(ecom, mutagenesis): "parent_columns": ["id"], }, ], + "is_invented_table": False, }, "order_items": { "column_count": 5, @@ -391,6 +303,7 @@ def test_debug_summary(ecom, mutagenesis): "parent_columns": ["id"], }, ], + "is_invented_table": False, }, }, } @@ -398,7 +311,8 @@ def test_debug_summary(ecom, mutagenesis): assert mutagenesis.debug_summary() == { "foreign_key_count": 3, "max_depth": 2, - "table_count": 3, + "public_table_count": 3, + "invented_table_count": 0, "tables": { "bond": { "column_count": 3, @@ -416,6 +330,7 @@ def test_debug_summary(ecom, mutagenesis): "parent_columns": ["atom_id"], }, ], + "is_invented_table": False, }, "atom": { "column_count": 4, @@ -428,12 +343,14 @@ def test_debug_summary(ecom, mutagenesis): "parent_columns": ["molecule_id"], } ], + "is_invented_table": False, }, "molecule": { "column_count": 2, "primary_key": ["molecule_id"], "foreign_key_count": 0, "foreign_keys": [], + "is_invented_table": False, }, }, } diff --git a/tests/relational/test_relational_data_with_json.py b/tests/relational/test_relational_data_with_json.py new file mode 100644 index 00000000..9ef30c60 --- /dev/null +++ b/tests/relational/test_relational_data_with_json.py @@ -0,0 +1,804 @@ +import pandas as pd +import pandas.testing as pdtest +import pytest + +from gretel_trainer.relational.core import ( + ForeignKey, + MultiTableException, + RelationalData, + Scope, +) + + +@pytest.fixture +def bball(): + bball_jsonl = """ + {"name": "LeBron James", "age": 38, "draft": {"year": 2003}, "teams": ["Cavaliers", "Heat", "Lakers"]} + {"name": "Steph Curry", "age": 35, "draft": {"year": 2009, "college": "Davidson"}, "teams": ["Warriors"]} + """ + bball_df = pd.read_json(bball_jsonl, lines=True) + + rel_data = RelationalData() + rel_data.add_table(name="bball", primary_key=None, data=bball_df) + + return rel_data + + +def test_json_columns_produce_invented_flattened_tables(documents): + pdtest.assert_frame_equal( + documents.get_table_data("purchases-sfx"), + pd.DataFrame( + data={ + "~PRIMARY_KEY_ID~": [0, 1, 2, 3, 4, 5], + "id": [1, 2, 3, 4, 5, 6], + "user_id": [1, 2, 2, 3, 3, 3], + "data>item": ["pen", "paint", "ink", "pen", "paint", "ink"], + "data>cost": [100, 100, 100, 100, 100, 100], + "data>details>color": ["red", "red", "red", "blue", "blue", "blue"], + } + ), + check_like=True, + ) + + pdtest.assert_frame_equal( + documents.get_table_data("purchases-data-years-sfx"), + pd.DataFrame( + data={ + "content": [2023, 2023, 2022, 2020, 2019, 2021], + "array~order": [0, 0, 1, 0, 1, 0], + "~PRIMARY_KEY_ID~": [0, 1, 2, 3, 4, 5], + "purchases~id": [0, 1, 1, 2, 2, 4], + } + ), + check_like=True, + check_dtype=False, # Without this, test fails asserting dtype mismatch in `content` field (object vs. 
int) + ) + + assert documents.get_foreign_keys("purchases-data-years-sfx") == [ + ForeignKey( + table_name="purchases-data-years-sfx", + columns=["purchases~id"], + parent_table_name="purchases-sfx", + parent_columns=["~PRIMARY_KEY_ID~"], + ) + ] + + +def test_list_tables_accepts_various_scopes(documents): + # PUBLIC reflects the user's source + assert set(documents.list_all_tables(Scope.PUBLIC)) == { + "users", + "purchases", + "payments", + } + + # MODELABLE replaces any source tables containing JSON with the invented tables + assert set(documents.list_all_tables(Scope.MODELABLE)) == { + "users", + "payments", + "purchases-sfx", + "purchases-data-years-sfx", + } + + # EVALUATABLE is similar to MODELABLE, but omits invented child tables—we only evaluate the root invented table + assert set(documents.list_all_tables(Scope.EVALUATABLE)) == { + "users", + "payments", + "purchases-sfx", + } + + # INVENTED returns only tables invented from source tables with JSON + assert set(documents.list_all_tables(Scope.INVENTED)) == { + "purchases-sfx", + "purchases-data-years-sfx", + } + + # ALL returns every table name, including both source-with-JSON tables and those invented from such tables + assert set(documents.list_all_tables(Scope.ALL)) == { + "users", + "purchases", + "payments", + "purchases-sfx", + "purchases-data-years-sfx", + } + + # Default scope is MODELABLE + assert set(documents.list_all_tables()) == set( + documents.list_all_tables(Scope.MODELABLE) + ) + + +def test_invented_json_column_names(documents, bball): + # The root invented table adds columns for dictionary properties lifted from nested JSON objects + assert set(documents.get_table_columns("purchases-sfx")) == { + "~PRIMARY_KEY_ID~", + "id", + "user_id", + "data>item", + "data>cost", + "data>details>color", + } + + # JSON lists lead to invented child tables. 
These tables store the original content, + # a new primary key, a foreign key back to the parent, and the original array index + assert set(documents.get_table_columns("purchases-data-years-sfx")) == { + "content", + "~PRIMARY_KEY_ID~", + "purchases~id", + "array~order", + } + + # If the source table does not have a primary key defined, one is created on the root invented table + assert set(bball.get_table_columns("bball-sfx")) == { + "name", + "age", + "draft>year", + "draft>college", + "~PRIMARY_KEY_ID~", + } + + +def test_primary_key(documents, bball): + # The root invented table's primary key is a composite key that includes the source PK and an invented column + assert documents.get_primary_key("purchases") == ["id"] + assert documents.get_primary_key("purchases-sfx") == ["id", "~PRIMARY_KEY_ID~"] + + assert bball.get_primary_key("bball") == [] + assert bball.get_primary_key("bball-sfx") == ["~PRIMARY_KEY_ID~"] + + # Setting an existing primary key to None puts us in the correct state + assert len(documents.list_all_tables(Scope.ALL)) == 5 + original_payments_fks = documents.get_foreign_keys("payments") + documents.set_primary_key(table="purchases", primary_key=None) + assert len(documents.list_all_tables(Scope.ALL)) == 5 + assert documents.get_primary_key("purchases") == [] + assert documents.get_primary_key("purchases-sfx") == ["~PRIMARY_KEY_ID~"] + assert documents.get_foreign_keys("purchases-data-years-sfx") == [ + ForeignKey( + table_name="purchases-data-years-sfx", + columns=["purchases~id"], + parent_table_name="purchases-sfx", + parent_columns=["~PRIMARY_KEY_ID~"], + ) + ] + assert documents.get_foreign_keys("payments") == original_payments_fks + + # Setting a None primary key to some column puts us in the correct state + assert len(bball.list_all_tables(Scope.ALL)) == 3 + bball.set_primary_key(table="bball", primary_key="name") + assert len(bball.list_all_tables(Scope.ALL)) == 3 + assert bball.get_primary_key("bball") == ["name"] + assert bball.get_primary_key("bball-sfx") == ["name", "~PRIMARY_KEY_ID~"] + assert bball.get_foreign_keys("bball-teams-sfx") == [ + ForeignKey( + table_name="bball-teams-sfx", + columns=["bball~id"], + parent_table_name="bball-sfx", + parent_columns=["~PRIMARY_KEY_ID~"], + ) + ] + + +def test_foreign_keys(documents): + # Foreign keys from the source-with-JSON table are present on the root invented table + assert documents.get_foreign_keys("purchases") == documents.get_foreign_keys( + "purchases-sfx" + ) + + # The root invented table name is used in the ForeignKey + assert documents.get_foreign_keys("purchases") == [ + ForeignKey( + table_name="purchases-sfx", + columns=["user_id"], + parent_table_name="users", + parent_columns=["id"], + ) + ] + + # Invented children point to invented parents + assert documents.get_foreign_keys("purchases-data-years-sfx") == [ + ForeignKey( + table_name="purchases-data-years-sfx", + columns=["purchases~id"], + parent_table_name="purchases-sfx", + parent_columns=["~PRIMARY_KEY_ID~"], + ) + ] + + # Source children of the source-with-JSON table point to the root invented table + assert documents.get_foreign_keys("payments") == [ + ForeignKey( + table_name="payments", + columns=["purchase_id"], + parent_table_name="purchases-sfx", + parent_columns=["id"], + ) + ] + + # You can request public/user-supplied names instead of the default invented table names + assert documents.get_foreign_keys("payments", rename_invented_tables=True) == [ + ForeignKey( + table_name="payments", + columns=["purchase_id"], + 
parent_table_name="purchases", + parent_columns=["id"], + ) + ] + assert documents.get_foreign_keys("purchases", rename_invented_tables=True) == [ + ForeignKey( + table_name="purchases", + columns=["user_id"], + parent_table_name="users", + parent_columns=["id"], + ) + ] + + # Removing a foreign key from the source-with-JSON table updates the root invented table + documents.remove_foreign_key_constraint( + table="purchases", constrained_columns=["user_id"] + ) + assert documents.get_foreign_keys("purchases") == [] + assert documents.get_foreign_keys("purchases-sfx") == [] + + +def test_update_data_with_existing_json_to_new_json(documents): + new_purchases_jsonl = """ + {"id": 1, "user_id": 1, "data": {"item": "watercolor", "cost": 200, "details": {"color": "aquamarine"}, "years": [1999]}} + {"id": 2, "user_id": 2, "data": {"item": "watercolor", "cost": 200, "details": {"color": "aquamarine"}, "years": [1999]}} + {"id": 3, "user_id": 2, "data": {"item": "watercolor", "cost": 200, "details": {"color": "aquamarine"}, "years": [1999]}} + {"id": 4, "user_id": 3, "data": {"item": "charcoal", "cost": 200, "details": {"color": "aquamarine"}, "years": [1998]}} + {"id": 5, "user_id": 3, "data": {"item": "charcoal", "cost": 200, "details": {"color": "aquamarine"}, "years": [1998]}} + {"id": 6, "user_id": 3, "data": {"item": "charcoal", "cost": 200, "details": {"color": "aquamarine"}, "years": [1998]}} + """ + new_purchases_df = pd.read_json(new_purchases_jsonl, lines=True) + + documents.update_table_data("purchases", data=new_purchases_df) + + assert len(documents.list_all_tables(Scope.ALL)) == 5 + assert len(documents.list_all_tables(Scope.MODELABLE)) == 4 + + expected = { + "purchases-sfx": pd.DataFrame( + data={ + "~PRIMARY_KEY_ID~": [0, 1, 2, 3, 4, 5], + "id": [1, 2, 3, 4, 5, 6], + "user_id": [1, 2, 2, 3, 3, 3], + "data>item": [ + "watercolor", + "watercolor", + "watercolor", + "charcoal", + "charcoal", + "charcoal", + ], + "data>cost": [200, 200, 200, 200, 200, 200], + "data>details>color": [ + "aquamarine", + "aquamarine", + "aquamarine", + "aquamarine", + "aquamarine", + "aquamarine", + ], + } + ), + "purchases-data-years-sfx": pd.DataFrame( + data={ + "content": [1999, 1999, 1999, 1998, 1998, 1998], + "array~order": [0, 0, 0, 0, 0, 0], + "~PRIMARY_KEY_ID~": [0, 1, 2, 3, 4, 5], + "purchases~id": [0, 1, 2, 3, 4, 5], + } + ), + } + + pdtest.assert_frame_equal( + documents.get_table_data("purchases-sfx"), + expected["purchases-sfx"], + check_like=True, + ) + + pdtest.assert_frame_equal( + documents.get_table_data("purchases-data-years-sfx"), + expected["purchases-data-years-sfx"], + check_like=True, + check_dtype=False, # Without this, test fails asserting dtype mismatch in `content` field (object vs. 
int) + ) + + # User-supplied child table FK still exists + assert documents.get_foreign_keys("payments") == [ + ForeignKey( + table_name="payments", + columns=["purchase_id"], + parent_table_name="purchases-sfx", + parent_columns=["id"], + ) + ] + + +def test_update_data_existing_json_to_no_json(documents): + new_purchases_df = pd.DataFrame( + data={ + "id": [1, 2, 3, 4, 5, 6], + "user_id": [1, 2, 2, 3, 3, 3], + "data": ["pen", "paint", "ink", "pen", "paint", "ink"], + } + ) + + documents.update_table_data("purchases", data=new_purchases_df) + + assert len(documents.list_all_tables(Scope.ALL)) == 3 + + pdtest.assert_frame_equal( + documents.get_table_data("purchases"), + new_purchases_df, + check_like=True, + ) + + assert documents.get_foreign_keys("payments") == [ + ForeignKey( + table_name="payments", + columns=["purchase_id"], + parent_table_name="purchases", + parent_columns=["id"], + ) + ] + + +def test_update_data_existing_flat_to_json(documents): + # Build up a RelationalData instance that basically mirrors documents, + # but purchases is flat to start and thus there are no RelationalJson instances + flat_purchases_df = pd.DataFrame( + data={ + "id": [1, 2, 3, 4, 5, 6], + "user_id": [1, 2, 2, 3, 3, 3], + "data": ["pen", "paint", "ink", "pen", "paint", "ink"], + } + ) + rel_data = RelationalData() + rel_data.add_table( + name="users", primary_key="id", data=documents.get_table_data("users") + ) + rel_data.add_table(name="purchases", primary_key="id", data=flat_purchases_df) + rel_data.add_table( + name="payments", primary_key="id", data=documents.get_table_data("payments") + ) + rel_data.add_foreign_key_constraint( + table="purchases", + constrained_columns=["user_id"], + referred_table="users", + referred_columns=["id"], + ) + rel_data.add_foreign_key_constraint( + table="payments", + constrained_columns=["purchase_id"], + referred_table="purchases", + referred_columns=["id"], + ) + assert len(rel_data.list_all_tables(Scope.ALL)) == 3 + assert len(rel_data.list_all_tables(Scope.MODELABLE)) == 3 + + rel_data.update_table_data("purchases", documents.get_table_data("purchases")) + + assert set(rel_data.list_all_tables(Scope.ALL)) == { + "users", + "purchases", + "purchases-sfx", + "purchases-data-years-sfx", + "payments", + } + # the original purchases table is no longer flat, nor (therefore) MODELABLE + assert set(rel_data.list_all_tables(Scope.MODELABLE)) == { + "users", + "purchases-sfx", + "purchases-data-years-sfx", + "payments", + } + assert rel_data.get_foreign_keys("payments") == [ + ForeignKey( + table_name="payments", + columns=["purchase_id"], + parent_table_name="purchases-sfx", # The foreign key now points to the root invented table + parent_columns=["id"], + ) + ] + + +# Simulates output tables from MultiTable transforms or synthetics, which will only include the MODELABLE tables +@pytest.fixture() +def mt_output_tables(): + return { + "users": pd.DataFrame( + data={ + "id": [1, 2, 3], + "name": ["Rob", "Sam", "Tim"], + } + ), + "payments": pd.DataFrame( + data={ + "id": [1, 2, 3, 4], + "amount": [10, 10, 10, 10], + "purchase_id": [1, 2, 3, 4], + } + ), + "purchases-sfx": pd.DataFrame( + data={ + "~PRIMARY_KEY_ID~": [0, 1, 2, 3], + "id": [1, 2, 3, 4], + "user_id": [1, 1, 2, 3], + "data>item": ["pen", "paint", "ink", "ink"], + "data>cost": [18, 19, 20, 21], + "data>details>color": ["blue", "yellow", "pink", "orange"], + } + ), + "purchases-data-years-sfx": pd.DataFrame( + data={ + "content": [2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007], + "~PRIMARY_KEY_ID~": [0, 1, 2, 
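+
+
+# Sketch of the ">" column-name convention asserted throughout this file: nested
+# dictionary keys are joined with ">" on the way out, so a simple split-and-nest
+# reverses it. (Illustrative only; the real logic lives in relational/json.py,
+# and the ">" separator is inferred from the column names asserted above.)
+def test_nested_key_delimiter_round_trip_sketch():
+    flat = {"data>item": "pen", "data>details>color": "blue"}
+    nested: dict = {}
+    for key, value in flat.items():
+        *parents, leaf = key.split(">")
+        cursor = nested
+        for part in parents:
+            cursor = cursor.setdefault(part, {})
+        cursor[leaf] = value
+    assert nested == {"data": {"item": "pen", "details": {"color": "blue"}}}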
+
+
+# Simulates output tables from MultiTable transforms or synthetics, which will only include the MODELABLE tables
+@pytest.fixture()
+def mt_output_tables():
+    return {
+        "users": pd.DataFrame(
+            data={
+                "id": [1, 2, 3],
+                "name": ["Rob", "Sam", "Tim"],
+            }
+        ),
+        "payments": pd.DataFrame(
+            data={
+                "id": [1, 2, 3, 4],
+                "amount": [10, 10, 10, 10],
+                "purchase_id": [1, 2, 3, 4],
+            }
+        ),
+        "purchases-sfx": pd.DataFrame(
+            data={
+                "~PRIMARY_KEY_ID~": [0, 1, 2, 3],
+                "id": [1, 2, 3, 4],
+                "user_id": [1, 1, 2, 3],
+                "data>item": ["pen", "paint", "ink", "ink"],
+                "data>cost": [18, 19, 20, 21],
+                "data>details>color": ["blue", "yellow", "pink", "orange"],
+            }
+        ),
+        "purchases-data-years-sfx": pd.DataFrame(
+            data={
+                "content": [2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007],
+                "~PRIMARY_KEY_ID~": [0, 1, 2, 3, 4, 5, 6, 7],
+                "purchases~id": [0, 0, 0, 1, 2, 2, 3, 3],
+                "array~order": [0, 1, 2, 0, 0, 1, 0, 1],
+            }
+        ),
+    }
+
+
+def test_restoring_output_tables_to_original_shape(documents, mt_output_tables):
+    restored_tables = documents.restore(mt_output_tables)
+
+    # We expect our restored tables to match the PUBLIC tables
+    assert len(restored_tables) == 3
+    expected = {
+        "users": mt_output_tables["users"],
+        "payments": mt_output_tables["payments"],
+        "purchases": pd.DataFrame(
+            data={
+                "id": [1, 2, 3, 4],
+                "user_id": [1, 1, 2, 3],
+                "data": [
+                    {
+                        "item": "pen",
+                        "cost": 18,
+                        "details": {"color": "blue"},
+                        "years": [2000, 2001, 2002],
+                    },
+                    {
+                        "item": "paint",
+                        "cost": 19,
+                        "details": {"color": "yellow"},
+                        "years": [2003],
+                    },
+                    {
+                        "item": "ink",
+                        "cost": 20,
+                        "details": {"color": "pink"},
+                        "years": [2004, 2005],
+                    },
+                    {
+                        "item": "ink",
+                        "cost": 21,
+                        "details": {"color": "orange"},
+                        "years": [2006, 2007],
+                    },
+                ],
+            }
+        ),
+    }
+
+    for t, df in restored_tables.items():
+        pdtest.assert_frame_equal(df, expected[t])
+
+
+def test_restore_with_incomplete_tableset(documents, mt_output_tables):
+    without_invented_root = {
+        k: v for k, v in mt_output_tables.items() if k != "purchases-sfx"
+    }
+
+    without_invented_child = {
+        k: v for k, v in mt_output_tables.items() if k != "purchases-data-years-sfx"
+    }
+
+    restored_without_invented_root = documents.restore(without_invented_root)
+    restored_without_invented_child = documents.restore(without_invented_child)
+
+    # Tables unrelated to JSON are unaffected
+    pdtest.assert_frame_equal(
+        restored_without_invented_child["users"], mt_output_tables["users"]
+    )
+    pdtest.assert_frame_equal(
+        restored_without_invented_child["payments"], mt_output_tables["payments"]
+    )
+    pdtest.assert_frame_equal(
+        restored_without_invented_root["users"], mt_output_tables["users"]
+    )
+    pdtest.assert_frame_equal(
+        restored_without_invented_root["payments"], mt_output_tables["payments"]
+    )
+
+    # If the invented root is missing, the table is omitted from the result dict entirely
+    assert "purchases" not in restored_without_invented_root
+
+    # If an invented child is missing, we restore the shape but populate the list column with empty lists
+    pdtest.assert_frame_equal(
+        restored_without_invented_child["purchases"],
+        pd.DataFrame(
+            data={
+                "id": [1, 2, 3, 4],
+                "user_id": [1, 1, 2, 3],
+                "data": [
+                    {
+                        "item": "pen",
+                        "cost": 18,
+                        "details": {"color": "blue"},
+                        "years": [],
+                    },
+                    {
+                        "item": "paint",
+                        "cost": 19,
+                        "details": {"color": "yellow"},
+                        "years": [],
+                    },
+                    {
+                        "item": "ink",
+                        "cost": 20,
+                        "details": {"color": "pink"},
+                        "years": [],
+                    },
+                    {
+                        "item": "ink",
+                        "cost": 21,
+                        "details": {"color": "orange"},
+                        "years": [],
+                    },
+                ],
+            }
+        ),
+    )
+
+
+def test_flatten_and_restore_all_sorts_of_json():
+    json = """
+[
+    {
+        "a": 1,
+        "b": {"bb": 1},
+        "c": {"cc": {"ccc": 1}},
+        "d": [1, 2, 3],
+        "e": [
+            {"ee": 1},
+            {"ee": 2}
+        ],
+        "f": [
+            {
+                "ff": [
+                    {"fff": 1},
+                    {"fff": 2}
+                ]
+            }
+        ]
+    }
+]
+"""
+    json_df = pd.read_json(json, orient="records")
+    rel_data = RelationalData()
+    rel_data.add_table(name="demo", primary_key=None, data=json_df)
+
+    assert set(rel_data.list_all_tables(Scope.ALL)) == {
+        "demo",
+        "demo-sfx",
+        "demo-d-sfx",
+        "demo-e-sfx",
+        "demo-f-sfx",
+        "demo-f-content-ff-sfx",
+    }
+
+    assert rel_data.get_table_columns("demo-sfx") == {
+        "a",
+        "b>bb",
+        "c>cc>ccc",
+        "~PRIMARY_KEY_ID~",
+    }
+    assert rel_data.get_table_columns("demo-d-sfx") == {
+        "content",
+        "~PRIMARY_KEY_ID~",
+        "demo~id",
+        "array~order",
+    }
+    assert rel_data.get_table_columns("demo-e-sfx") == {
+        "content>ee",
+        "~PRIMARY_KEY_ID~",
+        "demo~id",
+        "array~order",
+    }
+    assert rel_data.get_table_columns("demo-f-sfx") == {
+        "~PRIMARY_KEY_ID~",
+        "demo~id",
+        "array~order",
+    }
+    assert rel_data.get_table_columns("demo-f-content-ff-sfx") == {
+        "content>fff",
+        "~PRIMARY_KEY_ID~",
+        "demo^f~id",
+        "array~order",
+    }
+
+    output_tables = {
+        "demo-sfx": pd.DataFrame(
+            data={
+                "a": [1, 2],
+                "b>bb": [3, 4],
+                "c>cc>ccc": [5, 6],
+                "~PRIMARY_KEY_ID~": [0, 1],
+            }
+        ),
+        "demo-d-sfx": pd.DataFrame(
+            data={
+                "content": [10, 11, 12, 13],
+                "~PRIMARY_KEY_ID~": [0, 1, 2, 3],
+                "demo~id": [0, 0, 0, 1],
+                "array~order": [0, 1, 2, 0],
+            }
+        ),
+        "demo-e-sfx": pd.DataFrame(
+            data={
+                "content>ee": [100, 200, 300],
+                "~PRIMARY_KEY_ID~": [0, 1, 2],
+                "demo~id": [0, 1, 1],
+                "array~order": [0, 0, 1],
+            }
+        ),
+        "demo-f-sfx": pd.DataFrame(
+            data={"~PRIMARY_KEY_ID~": [0, 1], "demo~id": [0, 1], "array~order": [0, 0]}
+        ),
+        "demo-f-content-ff-sfx": pd.DataFrame(
+            data={
+                "content>fff": [10, 11, 12],
+                "~PRIMARY_KEY_ID~": [0, 1, 2],
+                "demo^f~id": [0, 0, 0],
+                "array~order": [0, 1, 2],
+            }
+        ),
+    }
+
+    restored = rel_data.restore(output_tables)
+
+    expected = pd.DataFrame(
+        data={
+            "a": [1, 2],
+            "b": [{"bb": 3}, {"bb": 4}],
+            "c": [{"cc": {"ccc": 5}}, {"cc": {"ccc": 6}}],
+            "d": [[10, 11, 12], [13]],
+            "e": [[{"ee": 100}], [{"ee": 200}, {"ee": 300}]],
+            "f": [[{"ff": [{"fff": 10}, {"fff": 11}, {"fff": 12}]}], [{"ff": []}]],
+        }
+    )
+
+    assert len(restored) == 1
+    pdtest.assert_frame_equal(restored["demo"], expected)
+
+
+def test_only_lists_edge_case():
+    # Smallest reproduction: a dataframe with just one row and one column, and the value is a list
+    list_df = pd.DataFrame(data={"l": [[1, 2, 3, 4]]})
+    rel_data = RelationalData()
+
+    # Since there are no flat fields on the source, we cannot create an invented root table.
+    # Internally, the call below fails when trying to create a foreign key from the invented child table up to a nonexistent root.
+    with pytest.raises(MultiTableException):
+        rel_data.add_table(name="list", primary_key=None, data=list_df)
+
+    # TODO: Ideally we should "roll back the transaction" and the call below should return an empty set,
+    # but currently the instance is left in an inconsistent state.
+    assert set(rel_data.list_all_tables(Scope.ALL)) == {"list", "list-l-sfx"}
+
+
+def test_lists_of_lists():
+    # The source has enough flat data to create a root invented table. The special
+    # value here is a list of lists, which is handled correctly.
+    lol_df = pd.DataFrame(data={"a": [1], "l": [[[1, 2], [3, 4]]]})
+    rel_data = RelationalData()
+    rel_data.add_table(name="lol", primary_key=None, data=lol_df)
+
+    assert set(rel_data.list_all_tables(Scope.ALL)) == {
+        "lol",
+        "lol-sfx",
+        "lol-l-sfx",
+        "lol-l-content-sfx",
+    }
+
+    output = {
+        "lol-sfx": pd.DataFrame(data={"a": [1, 2], "~PRIMARY_KEY_ID~": [0, 1]}),
+        "lol-l-sfx": pd.DataFrame(
+            data={"~PRIMARY_KEY_ID~": [0, 1], "lol~id": [0, 0], "array~order": [0, 1]}
+        ),
+        "lol-l-content-sfx": pd.DataFrame(
+            data={
+                "content": [10, 20, 30, 40],
+                "~PRIMARY_KEY_ID~": [0, 1, 2, 3],
+                "lol^l~id": [0, 0, 1, 1],
+                "array~order": [0, 1, 0, 1],
+            }
+        ),
+    }
+    restored = rel_data.restore(output)
+
+    assert len(restored) == 1
+    pdtest.assert_frame_equal(
+        restored["lol"],
+        pd.DataFrame(
+            data={
+                "a": [1, 2],
+                "l": [[[10, 20], [30, 40]], []],
+            }
+        ),
+    )
+
+
+def test_all_tables_are_present_in_debug_summary(documents):
+    assert documents.debug_summary() == {
+        "foreign_key_count": 4,
+        "max_depth": 2,
+        "public_table_count": 3,
+        "invented_table_count": 2,
+        "tables": {
+            "users": {
+                "column_count": 2,
+                "primary_key": ["id"],
+                "foreign_key_count": 0,
+                "foreign_keys": [],
+                "is_invented_table": False,
+            },
+            "payments": {
+                "column_count": 3,
+                "primary_key": ["id"],
+                "foreign_key_count": 1,
+                "foreign_keys": [
+                    {
+                        "columns": ["purchase_id"],
+                        "parent_table_name": "purchases-sfx",
+                        "parent_columns": ["id"],
+                    }
+                ],
+                "is_invented_table": False,
+            },
+            "purchases": {
+                "column_count": 3,
+                "primary_key": ["id"],
+                "foreign_key_count": 1,
+                "foreign_keys": [
+                    {
+                        "columns": ["user_id"],
+                        "parent_table_name": "users",
+                        "parent_columns": ["id"],
+                    }
+                ],
+                "is_invented_table": False,
+            },
+            "purchases-sfx": {
+                "column_count": 6,
+                "primary_key": ["id", "~PRIMARY_KEY_ID~"],
+                "foreign_key_count": 1,
+                "foreign_keys": [
+                    {
+                        "columns": ["user_id"],
+                        "parent_table_name": "users",
+                        "parent_columns": ["id"],
+                    }
+                ],
+                "is_invented_table": True,
+            },
+            "purchases-data-years-sfx": {
+                "column_count": 4,
+                "primary_key": ["~PRIMARY_KEY_ID~"],
+                "foreign_key_count": 1,
+                "foreign_keys": [
+                    {
+                        "columns": ["purchases~id"],
+                        "parent_table_name": "purchases-sfx",
+                        "parent_columns": ["~PRIMARY_KEY_ID~"],
+                    }
+                ],
+                "is_invented_table": True,
+            },
+        },
+    }
diff --git a/tests/relational/test_report.py b/tests/relational/test_report.py
index 96ff4ca6..7d6c3c28 100644
--- a/tests/relational/test_report.py
+++ b/tests/relational/test_report.py
@@ -2,6 +2,7 @@
 
 from lxml import html
 
+from gretel_trainer.relational.core import Scope
 from gretel_trainer.relational.report.report import ReportPresenter, ReportRenderer
 from gretel_trainer.relational.table_evaluation import TableEvaluation
 
@@ -12,7 +13,7 @@ def _evals_from_rel_data(rel_data):
         "privacy_protection_level": {"score": 2, "grade": "Good"},
     }
     evals = {}
-    for table in rel_data.list_all_tables():
+    for table in rel_data.list_all_tables(Scope.PUBLIC):
         eval = TableEvaluation(cross_table_report_json=d, individual_report_json=d)
         evals[table] = eval
     return evals
@@ -240,3 +241,35 @@ def test_mutagenesis_relational_data_report(mutagenesis):
         )[0]
         == "synthetics_individual_evaluation_atom.html"
     )
+
+
+def test_source_data_including_json(documents):
+    # Fake the evaluations; rendering only needs the scores and grades
+    evaluations = _evals_from_rel_data(documents)
+
+    presenter = ReportPresenter(
+        rel_data=documents,
+        evaluations=evaluations,
+        now=datetime.utcnow(),
+        run_identifier="run_identifier",
+    )
+
+    html_content = ReportRenderer().render(presenter)
+
+    # DEV ONLY: uncomment to save a local copy of the report for inspection
+    # with open("report.html", 'w') as f:
+    #     f.write(html_content)
+
+    tree = html.fromstring(html_content)
+
+    relations_data_rows = tree.xpath(
+        '//section[contains(@class, "test-table-relationships")]'
+        "//tr"
+    )[1:]
+
+    # Ensure public names, not invented table names, are displayed
+    table_names = [
+        # Row -> table-name <td> -> bold tag wrapping the table name
+        row.getchildren()[0].getchildren()[0].text
+        for row in relations_data_rows
+    ]
+    assert table_names == ["payments", "purchases", "users"]
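+
+    # Extra guard (assumes the "-sfx" suffix convention used for invented tables
+    # in this tree): invented names must not leak into the rendered report
+    assert not any(name.endswith("-sfx") for name in table_names)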