From c08b3d970b3b5fea4962b9d2632a99c7c35d489a Mon Sep 17 00:00:00 2001 From: Mike Knepper Date: Tue, 30 May 2023 12:36:52 -0500 Subject: [PATCH] Preserve tables before training (#111) --- src/gretel_trainer/relational/ancestry.py | 38 ++++- src/gretel_trainer/relational/backup.py | 2 - src/gretel_trainer/relational/connectors.py | 2 +- src/gretel_trainer/relational/core.py | 2 +- src/gretel_trainer/relational/multi_table.py | 145 ++++++++++-------- .../relational/strategies/ancestral.py | 34 ++-- .../relational/strategies/common.py | 8 +- .../relational/strategies/independent.py | 5 +- .../relational/tasks/synthetics_run.py | 21 ++- .../relational/workflow_state.py | 2 - tests/relational/test_ancestral_strategy.py | 73 +++------ tests/relational/test_ancestry.py | 54 +++++++ tests/relational/test_backup.py | 5 - tests/relational/test_connectors.py | 6 +- tests/relational/test_independent_strategy.py | 14 +- tests/relational/test_multi_table_restore.py | 3 - tests/relational/test_synthetics_run_task.py | 40 +++-- tests/relational/test_train_synthetics.py | 114 ++++++++++++++ tests/relational/test_train_transforms.py | 14 +- 19 files changed, 406 insertions(+), 176 deletions(-) create mode 100644 tests/relational/test_train_synthetics.py diff --git a/src/gretel_trainer/relational/ancestry.py b/src/gretel_trainer/relational/ancestry.py index 7032822b..757c7ea2 100644 --- a/src/gretel_trainer/relational/ancestry.py +++ b/src/gretel_trainer/relational/ancestry.py @@ -14,6 +14,7 @@ def get_multigenerational_primary_key( rel_data: RelationalData, table: str ) -> list[str]: + "Returns the provided table's primary key with the ancestral lineage prefix appended" return [ f"{_START_LINEAGE}{_END_LINEAGE}{pk}" for pk in rel_data.get_primary_key(table) ] @@ -22,6 +23,17 @@ def get_multigenerational_primary_key( def get_ancestral_foreign_key_maps( rel_data: RelationalData, table: str ) -> list[tuple[str, str]]: + """ + Returns a list of two-element tuples where the first element is a foreign key column + with ancestral lineage prefix, and the second element is the ancestral-lineage-prefixed + referred column. This function ultimately provides a list of which columns are duplicates + in a fully-joined ancestral table (i.e. `get_table_data_with_ancestors`) (only between + the provided table and its direct parents, not between parents and grandparents). + + For example: given an events table with foreign key `events.user_id` => `users.id`, + this method returns: [("self|user_id", "self.user_id|id")] + """ + def _ancestral_fk_map(fk: ForeignKey) -> list[tuple[str, str]]: maps = [] fk_lineage = _COL_DELIMITER.join(fk.columns) @@ -46,6 +58,29 @@ def _ancestral_fk_map(fk: ForeignKey) -> list[tuple[str, str]]: ] +def get_seed_safe_multigenerational_columns( + rel_data: RelationalData, +) -> dict[str, list[str]]: + """ + Returns a dict with Scope.MODELABLE table names as keys and lists of columns to use + for conditional seeding as values. By using a tableset of empty dataframes, this provides + a significantly faster / less resource-intensive way to get just the column names + from the results of `get_table_data_with_ancestors` for all tables. + """ + tableset = { + table: pd.DataFrame(columns=list(rel_data.get_table_columns(table))) + for table in rel_data.list_all_tables() + } + return { + table: list( + get_table_data_with_ancestors( + rel_data, table, tableset, ancestral_seeding=True + ).columns + ) + for table in rel_data.list_all_tables() + } + + def get_table_data_with_ancestors( rel_data: RelationalData, table: str, @@ -93,10 +128,9 @@ def _join_parents( parent_data = tableset[parent_table_name][list(usecols)] else: parent_data = rel_data.get_table_data(parent_table_name, usecols=usecols) - parent_data = parent_data.add_prefix(f"{next_lineage}{_END_LINEAGE}") df = df.merge( - parent_data, + parent_data.add_prefix(f"{next_lineage}{_END_LINEAGE}"), how="left", left_on=[f"{lineage}{_END_LINEAGE}{col}" for col in foreign_key.columns], right_on=[ diff --git a/src/gretel_trainer/relational/backup.py b/src/gretel_trainer/relational/backup.py index 8cc39dce..fd12d2c4 100644 --- a/src/gretel_trainer/relational/backup.py +++ b/src/gretel_trainer/relational/backup.py @@ -93,7 +93,6 @@ class BackupTransformsTrain: class BackupSyntheticsTrain: model_ids: dict[str, str] lost_contact: list[str] - training_columns: dict[str, list[str]] @dataclass @@ -103,7 +102,6 @@ class BackupGenerate: record_size_ratio: float record_handler_ids: dict[str, str] lost_contact: list[str] - missing_model: list[str] @dataclass diff --git a/src/gretel_trainer/relational/connectors.py b/src/gretel_trainer/relational/connectors.py index 9d80a72c..535e40d5 100644 --- a/src/gretel_trainer/relational/connectors.py +++ b/src/gretel_trainer/relational/connectors.py @@ -42,7 +42,7 @@ def __init__(self, engine: Engine): logger.info("Successfully connected to db") def extract( - self, only: Optional[list[str]] = None, ignore: Optional[list[str]] = None + self, only: Optional[set[str]] = None, ignore: Optional[set[str]] = None ) -> RelationalData: """ Extracts table data and relationships from the database. diff --git a/src/gretel_trainer/relational/core.py b/src/gretel_trainer/relational/core.py index 393c1782..93511cf7 100644 --- a/src/gretel_trainer/relational/core.py +++ b/src/gretel_trainer/relational/core.py @@ -674,7 +674,7 @@ def debug_summary(self) -> dict[str, Any]: def skip_table( - table: str, only: Optional[list[str]], ignore: Optional[list[str]] + table: str, only: Optional[set[str]], ignore: Optional[set[str]] ) -> bool: skip = False if only is not None and table not in only: diff --git a/src/gretel_trainer/relational/multi_table.py b/src/gretel_trainer/relational/multi_table.py index d89d5a2d..0aaa2333 100644 --- a/src/gretel_trainer/relational/multi_table.py +++ b/src/gretel_trainer/relational/multi_table.py @@ -251,9 +251,6 @@ def _complete_init_from_backup(self, backup: Backup) -> None: with tarfile.open(synthetics_training_archive_path, "r:gz") as tar: tar.extractall(path=self._working_dir) - self._synthetics_train.training_columns = ( - backup_synthetics_train.training_columns - ) self._synthetics_train.lost_contact = backup_synthetics_train.lost_contact self._synthetics_train.models = { table: self._project.get_model(model_id) @@ -339,7 +336,6 @@ def _complete_init_from_backup(self, backup: Backup) -> None: identifier=backup_generate.identifier, record_size_ratio=backup_generate.record_size_ratio, preserved=backup_generate.preserved, - missing_model=backup_generate.missing_model, lost_contact=backup_generate.lost_contact, record_handlers=record_handlers, ) @@ -458,7 +454,6 @@ def _build_backup(self) -> Backup: for table, model in self._synthetics_train.models.items() }, lost_contact=self._synthetics_train.lost_contact, - training_columns=self._synthetics_train.training_columns, ) # Generate @@ -467,7 +462,6 @@ def _build_backup(self) -> Backup: identifier=self._synthetics_run.identifier, record_size_ratio=self._synthetics_run.record_size_ratio, preserved=self._synthetics_run.preserved, - missing_model=self._synthetics_run.missing_model, lost_contact=self._synthetics_run.lost_contact, record_handler_ids={ table: rh.record_id @@ -602,34 +596,15 @@ def train_transforms( self, config: GretelModelConfig, *, - only: Optional[list[str]] = None, - ignore: Optional[list[str]] = None, + only: Optional[set[str]] = None, + ignore: Optional[set[str]] = None, ) -> None: - if only is not None and ignore is not None: - raise MultiTableException("Cannot specify both `only` and `ignore`.") - - m_only = None - if only is not None: - m_only = [] - for table in only: - m_names = self.relational_data.get_modelable_table_names(table) - if len(m_names) == 0: - raise MultiTableException(f"Unrecognized table name: `{table}`") - m_only.extend(m_names) - - m_ignore = None - if ignore is not None: - m_ignore = [] - for table in ignore: - m_names = self.relational_data.get_modelable_table_names(table) - if len(m_names) == 0: - raise MultiTableException(f"Unrecognized table name: `{table}`") - m_ignore.extend(m_names) + only, ignore = self._get_only_and_ignore(only, ignore) configs = { table: config for table in self.relational_data.list_all_tables() - if not skip_table(table, m_only, m_ignore) + if not skip_table(table, only, ignore) } self._setup_transforms_train_state(configs) @@ -722,15 +697,35 @@ def run_transforms( self._backup() self.transform_output_tables = reshaped_tables + def _get_only_and_ignore( + self, only: Optional[set[str]], ignore: Optional[set[str]] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: + if only is not None and ignore is not None: + raise MultiTableException("Cannot specify both `only` and `ignore`.") + + modelable_tables = set() + for table in only or ignore or {}: + m_names = self.relational_data.get_modelable_table_names(table) + if len(m_names) == 0: + raise MultiTableException(f"Unrecognized table name: `{table}`") + modelable_tables.update(m_names) + + if only is None: + return (None, modelable_tables) + elif ignore is None: + return (modelable_tables, None) + else: + return (None, None) + def _prepare_training_data(self, tables: list[str]) -> dict[str, Path]: """ Exports a copy of each table prepared for training by the configured strategy to the working directory. Returns a dict with table names as keys and Paths to the CSVs as values. """ - training_data = self._strategy.prepare_training_data(self.relational_data) - for table, df in training_data.items(): - self._synthetics_train.training_columns[table] = list(df.columns) + training_data = self._strategy.prepare_training_data( + self.relational_data, tables + ) training_paths = {} for table_name in tables: @@ -748,6 +743,13 @@ def _train_synthetics_models(self, training_data: dict[str, Path]) -> None: ) self._synthetics_train.models[table_name] = model + archive_path = self._working_dir / "synthetics_training.tar.gz" + for table_name, csv_path in training_data.items(): + add_to_tar(archive_path, csv_path, csv_path.name) + self._artifact_collection.upload_synthetics_training_archive( + self._project, str(archive_path) + ) + self._backup() task = SyntheticsTrainTask( @@ -767,22 +769,51 @@ def _train_synthetics_models(self, training_data: dict[str, Path]) -> None: self._extended_sdk, ) - # TODO: consider moving this to before running the task - archive_path = self._working_dir / "synthetics_training.tar.gz" - for table_name, csv_path in training_data.items(): - add_to_tar(archive_path, csv_path, csv_path.name) - self._artifact_collection.upload_synthetics_training_archive( - self._project, str(archive_path) - ) - def train(self) -> None: - """Train synthetic data models on each table in the relational dataset""" + """ + DEPRECATED: Please use `train_synthetics` instead. + """ + logger.warning( + "This method is deprecated and will be removed in a future release. " + "Please use `train_synthetics` instead." + ) tables = self.relational_data.list_all_tables() self._synthetics_train = SyntheticsTrain() training_data = self._prepare_training_data(tables) self._train_synthetics_models(training_data) + def train_synthetics( + self, + *, + only: Optional[set[str]] = None, + ignore: Optional[set[str]] = None, + ) -> None: + """ + Train synthetic data models for the tables in the tableset, + optionally scoped by either `only` or `ignore`. + """ + only, ignore = self._get_only_and_ignore(only, ignore) + + all_tables = self.relational_data.list_all_tables() + include_tables = [] + omit_tables = [] + for table in all_tables: + if skip_table(table, only, ignore): + omit_tables.append(table) + else: + include_tables.append(table) + + # TODO: Ancestral strategy requires that for each table omitted from synthetics ("preserved"), + # all its ancestors must also be omitted. In the future, it'd be nice to either find a way to + # eliminate this requirement completely, or (less ideal) allow incremental training of tables, + # e.g. train a few in one "batch", then a few more before generating (perhaps with some logs + # along the way informing the user of which required tables are missing). + self._strategy.validate_preserved_tables(omit_tables, self.relational_data) + + training_data = self._prepare_training_data(include_tables) + self._train_synthetics_models(training_data) + def retrain_tables(self, tables: dict[str, pd.DataFrame]) -> None: """ Provide updated table data and retrain. This method overwrites the table data in the @@ -801,7 +832,6 @@ def retrain_tables(self, tables: dict[str, pd.DataFrame]) -> None: for table in tables_to_retrain: with suppress(KeyError): del self._synthetics_train.models[table] - del self._synthetics_train.training_columns[table] training_data = self._prepare_training_data(tables_to_retrain) self._train_synthetics_models(training_data) @@ -861,18 +891,23 @@ def generate( logger.info(f"Resuming synthetics run `{self._synthetics_run.identifier}`") else: preserve_tables = preserve_tables or [] + preserve_tables.extend( + [ + table + for table in self.relational_data.list_all_tables() + if table not in self._synthetics_train.models + ] + ) self._strategy.validate_preserved_tables( preserve_tables, self.relational_data ) identifier = identifier or f"synthetics_{_timestamp()}" - missing_model = self._list_tables_with_missing_models() self._synthetics_run = SyntheticsRun( identifier=identifier, record_size_ratio=record_size_ratio, preserved=preserve_tables, - missing_model=missing_model, record_handlers={}, lost_contact=[], ) @@ -909,6 +944,9 @@ def generate( if table in self._synthetics_run.preserved: continue + if table not in self._synthetics_train.models: + continue + if table not in self.relational_data.list_all_tables(Scope.EVALUATABLE): continue @@ -982,21 +1020,6 @@ def create_relational_report(self, run_identifier: str, target_dir: Path) -> Non html_content = ReportRenderer().render(presenter) report.write(html_content) - def _list_tables_with_missing_models(self) -> list[str]: - missing_model = set() - for table in self.relational_data.list_all_tables(): - if not _table_trained_successfully(self._synthetics_train, table): - logger.info( - f"Skipping synthetic data generation for `{table}` because it does not have a trained model" - ) - missing_model.add(table) - for descendant in self.relational_data.get_descendants(table): - logger.info( - f"Skipping synthetic data generation for `{descendant}` because it depends on `{table}`" - ) - missing_model.add(table) - return list(missing_model) - def _attach_existing_reports(self, run_id: str, table: str) -> None: individual_path = ( self._working_dir @@ -1046,9 +1069,7 @@ def _validate_strategy(strategy: str) -> Union[IndependentStrategy, AncestralStr raise MultiTableException(msg) -def _table_trained_successfully( - train_state: Union[TransformsTrain, SyntheticsTrain], table: str -) -> bool: +def _table_trained_successfully(train_state: TransformsTrain, table: str) -> bool: model = train_state.models.get(table) if model is None: return False diff --git a/src/gretel_trainer/relational/strategies/ancestral.py b/src/gretel_trainer/relational/strategies/ancestral.py index 7e73c4db..76e39dbf 100644 --- a/src/gretel_trainer/relational/strategies/ancestral.py +++ b/src/gretel_trainer/relational/strategies/ancestral.py @@ -33,7 +33,7 @@ def label_encode_keys( return common.label_encode_keys(rel_data, tables) def prepare_training_data( - self, rel_data: RelationalData + self, rel_data: RelationalData, tables: list[str] ) -> dict[str, pd.DataFrame]: """ Returns tables with: @@ -43,6 +43,7 @@ def prepare_training_data( - artificial min/max seed records added """ all_tables = rel_data.list_all_tables() + omitted_tables = [t for t in all_tables if t not in tables] altered_tableset = {} training_data = {} @@ -51,13 +52,17 @@ def prepare_training_data( altered_tableset[table_name] = rel_data.get_table_data(table_name).copy() # Translate all keys to a contiguous list of integers - altered_tableset = common.label_encode_keys(rel_data, altered_tableset) + altered_tableset = common.label_encode_keys( + rel_data, altered_tableset, omit=omitted_tables + ) # Add artificial rows to support seeding - altered_tableset = _add_artifical_rows_for_seeding(rel_data, altered_tableset) + altered_tableset = _add_artifical_rows_for_seeding( + rel_data, altered_tableset, omitted_tables + ) # Collect all data in multigenerational format - for table_name in all_tables: + for table_name in tables: data = ancestry.get_table_data_with_ancestors( rel_data=rel_data, table=table_name, @@ -133,7 +138,6 @@ def get_generation_job( record_size_ratio: float, output_tables: dict[str, pd.DataFrame], target_dir: Path, - training_columns: list[str], ) -> dict[str, Any]: """ Returns kwargs for creating a record handler job via the Gretel SDK. @@ -149,7 +153,7 @@ def get_generation_job( return {"params": {"num_records": synth_size}} else: seed_df = self._build_seed_data_for_table( - table, output_tables, rel_data, synth_size, training_columns + table, output_tables, rel_data, synth_size ) seed_path = target_dir / f"synthetics_seed_{table}.csv" seed_df.to_csv(seed_path, index=False) @@ -161,8 +165,8 @@ def _build_seed_data_for_table( output_tables: dict[str, pd.DataFrame], rel_data: RelationalData, synth_size: int, - training_columns: list[str], ) -> pd.DataFrame: + column_legend = ancestry.get_seed_safe_multigenerational_columns(rel_data) seed_df = pd.DataFrame() source_data = rel_data.get_table_data(table) @@ -202,7 +206,9 @@ def _build_seed_data_for_table( # Drop any columns that weren't used in training, as well as the temporary merge column columns_to_drop = [ - col for col in this_fk_seed_df.columns if col not in training_columns + col + for col in this_fk_seed_df.columns + if col not in column_legend[table] ] columns_to_drop.append(tmp_column_name) this_fk_seed_df = this_fk_seed_df.drop(columns=columns_to_drop) @@ -325,11 +331,14 @@ def update_evaluation_from_evaluate( def _add_artifical_rows_for_seeding( - rel_data: RelationalData, tables: dict[str, pd.DataFrame] + rel_data: RelationalData, tables: dict[str, pd.DataFrame], omitted: list[str] ) -> dict[str, pd.DataFrame]: # On each table, add an artifical row with the max possible PK value + # unless the table is omitted from synthetics max_pk_values = {} for table_name, data in tables.items(): + if table_name in omitted: + continue max_pk_values[table_name] = len(data) * 50 random_record = tables[table_name].sample().copy() @@ -343,6 +352,10 @@ def _add_artifical_rows_for_seeding( if len(foreign_keys) == 0: continue + # Skip if the parent table is omitted and is the only parent + if len(foreign_keys) == 1 and foreign_keys[0].parent_table_name in omitted: + continue + two_records = tables[table_name].sample(2) min_fk_record = two_records.head(1).copy() max_fk_record = two_records.tail(1).copy() @@ -354,6 +367,9 @@ def _add_artifical_rows_for_seeding( # This can potentially overwrite the auto-incremented primary keys above in the case of composite keys for foreign_key in foreign_keys: + # Treat FK columns to omitted parents as normal columns + if foreign_key.parent_table_name in omitted: + continue for fk_col in foreign_key.columns: min_fk_record[fk_col] = 0 max_fk_record[fk_col] = max_pk_values[foreign_key.parent_table_name] diff --git a/src/gretel_trainer/relational/strategies/common.py b/src/gretel_trainer/relational/strategies/common.py index 2932fb75..9c4c4595 100644 --- a/src/gretel_trainer/relational/strategies/common.py +++ b/src/gretel_trainer/relational/strategies/common.py @@ -48,13 +48,19 @@ def _get_report_json(model: Model) -> Optional[dict]: def label_encode_keys( - rel_data: RelationalData, tables: dict[str, pd.DataFrame] + rel_data: RelationalData, + tables: dict[str, pd.DataFrame], + omit: Optional[list[str]] = None, ) -> dict[str, pd.DataFrame]: """ Crawls tables for all key columns (primary and foreign). For each PK (and FK columns referencing it), runs all values through a LabelEncoder and updates tables' columns to use LE-transformed values. """ + omit = omit or [] for table_name in rel_data.list_tables_parents_before_children(): + if table_name in omit: + continue + df = tables.get(table_name) if df is None: continue diff --git a/src/gretel_trainer/relational/strategies/independent.py b/src/gretel_trainer/relational/strategies/independent.py index 93b38409..8c771aa4 100644 --- a/src/gretel_trainer/relational/strategies/independent.py +++ b/src/gretel_trainer/relational/strategies/independent.py @@ -34,14 +34,14 @@ def label_encode_keys( return common.label_encode_keys(rel_data, tables) def prepare_training_data( - self, rel_data: RelationalData + self, rel_data: RelationalData, tables: list[str] ) -> dict[str, pd.DataFrame]: """ Returns source tables with primary and foreign keys removed """ training_data = {} - for table_name in rel_data.list_all_tables(): + for table_name in tables: columns_to_drop = [] columns_to_drop.extend(rel_data.get_primary_key(table_name)) for foreign_key in rel_data.get_foreign_keys(table_name): @@ -100,7 +100,6 @@ def get_generation_job( record_size_ratio: float, output_tables: dict[str, pd.DataFrame], target_dir: Path, - training_columns: list[str], ) -> dict[str, Any]: """ Returns kwargs for a record handler job requesting an output record diff --git a/src/gretel_trainer/relational/tasks/synthetics_run.py b/src/gretel_trainer/relational/tasks/synthetics_run.py index 01dad099..82221ad6 100644 --- a/src/gretel_trainer/relational/tasks/synthetics_run.py +++ b/src/gretel_trainer/relational/tasks/synthetics_run.py @@ -3,7 +3,7 @@ from typing import Optional import pandas as pd -from gretel_client.projects.jobs import ACTIVE_STATES, Job +from gretel_client.projects.jobs import ACTIVE_STATES, Job, Status from gretel_client.projects.projects import Project from gretel_client.projects.records import RecordHandler @@ -31,14 +31,20 @@ def __init__( def _setup_working_tables(self) -> dict[str, Optional[pd.DataFrame]]: working_tables = {} + all_tables = self.multitable.relational_data.list_all_tables() - for table in self.synthetics_run.missing_model: - working_tables[table] = None + for table in all_tables: + model = self.synthetics_train.models.get(table) - for table in self.synthetics_run.preserved: - working_tables[table] = self.multitable._strategy.get_preserved_data( - table, self.multitable.relational_data - ) + # Table was either omitted from training or marked as to-be-preserved during generation + if model is None or table in self.synthetics_run.preserved: + working_tables[table] = self.multitable._strategy.get_preserved_data( + table, self.multitable.relational_data + ) + + # Table was included in training, but failed at that step + elif model.status != Status.COMPLETED: + working_tables[table] = None return working_tables @@ -139,7 +145,6 @@ def each_iteration(self) -> None: self.synthetics_run.record_size_ratio, present_working_tables, self.run_dir, - self.synthetics_train.training_columns[table_name], ) model = self.synthetics_train.models[table_name] record_handler = model.create_record_handler_obj(**table_job) diff --git a/src/gretel_trainer/relational/workflow_state.py b/src/gretel_trainer/relational/workflow_state.py index 1f3f1152..e93fe813 100644 --- a/src/gretel_trainer/relational/workflow_state.py +++ b/src/gretel_trainer/relational/workflow_state.py @@ -19,7 +19,6 @@ class TransformsTrain: class SyntheticsTrain: models: dict[str, Model] = field(default_factory=dict) lost_contact: list[str] = field(default_factory=list) - training_columns: dict[str, list[str]] = field(default_factory=dict) @dataclass @@ -29,4 +28,3 @@ class SyntheticsRun: preserved: list[str] record_handlers: dict[str, RecordHandler] lost_contact: list[str] - missing_model: list[str] diff --git a/tests/relational/test_ancestral_strategy.py b/tests/relational/test_ancestral_strategy.py index ebf37f84..6f2e6192 100644 --- a/tests/relational/test_ancestral_strategy.py +++ b/tests/relational/test_ancestral_strategy.py @@ -22,7 +22,7 @@ def test_preparing_training_data_does_not_mutate_source_data(pets, art): } strategy = AncestralStrategy() - strategy.prepare_training_data(rel_data) + strategy.prepare_training_data(rel_data, rel_data.list_all_tables()) for table in rel_data.list_all_tables(): pdtest.assert_frame_equal( @@ -30,10 +30,26 @@ def test_preparing_training_data_does_not_mutate_source_data(pets, art): ) +def test_prepare_training_data_subset_of_tables(pets): + strategy = AncestralStrategy() + + # We aren't synthesizing the "humans" table, so it is not in this list argument... + training_data = strategy.prepare_training_data(pets, ["pets"]) + # ...nor do we create training data for it + assert set(training_data.keys()) == {"pets"} + + # Since the humans table is omitted from synthetics, we leave the FK values alone; specifically: + # - they are not label-encoded (which would effectively zero-index them) + # - we do not add artificial min/max values + assert set(training_data["pets"]["self|human_id"].values) == {1, 2, 3, 4, 5} + # We do add the artificial max PK row, though, since this table is being synthesized + assert len(training_data["pets"]) == 6 + + def test_prepare_training_data_returns_multigenerational_data(pets): strategy = AncestralStrategy() - training_data = strategy.prepare_training_data(pets) + training_data = strategy.prepare_training_data(pets, pets.list_all_tables()) for expected_column in ["self|id", "self|name", "self.human_id|id"]: assert expected_column in training_data["pets"] @@ -61,7 +77,7 @@ def test_prepare_training_data_drops_highly_unique_categorical_ancestor_fields(a ) strategy = AncestralStrategy() - training_data = strategy.prepare_training_data(art) + training_data = strategy.prepare_training_data(art, art.list_all_tables()) # Does not contain `self.artist_id|name` because it is highly unique categorical assert set(training_data["paintings"].columns) == { @@ -100,7 +116,7 @@ def test_prepare_training_data_drops_highly_nan_ancestor_fields(art): ) strategy = AncestralStrategy() - training_data = strategy.prepare_training_data(art) + training_data = strategy.prepare_training_data(art, art.list_all_tables()) # Does not contain `self.artist_id|name` because it is highly NaN assert set(training_data["paintings"].columns) == { @@ -115,7 +131,7 @@ def test_prepare_training_data_translates_alphanumeric_keys_and_adds_min_max_rec art, ): strategy = AncestralStrategy() - training_data = strategy.prepare_training_data(art) + training_data = strategy.prepare_training_data(art, art.list_all_tables()) # Artists, a parent table, should have 1 additional row assert len(training_data["artists"]) == len(art.get_table_data("artists")) + 1 @@ -142,7 +158,7 @@ def test_prepare_training_data_translates_alphanumeric_keys_and_adds_min_max_rec def test_prepare_training_data_with_composite_keys(tpch): strategy = AncestralStrategy() - training_data = strategy.prepare_training_data(tpch) + training_data = strategy.prepare_training_data(tpch, tpch.list_all_tables()) l_max = len(tpch.get_table_data("lineitem")) * 50 ps_max = len(tpch.get_table_data("partsupp")) * 50 @@ -339,28 +355,11 @@ def test_table_generation_readiness(ecom): def test_generation_job(pets): strategy = AncestralStrategy() - training_columns = { - "humans": [ - "self|id", - "self|name", - "self|city", - ], - "pets": [ - "self|id", - "self|name", - "self|age", - "self|human_id", - "self.human_id|id", - # "self.human_id|name", # highly unique categorical - "self.human_id|city", - ], - } - # Table with no ancestors with tempfile.TemporaryDirectory() as tmp: working_dir = Path(tmp) parent_table_job = strategy.get_generation_job( - "humans", pets, 2.0, {}, working_dir, training_columns["humans"] + "humans", pets, 2.0, {}, working_dir ) assert len(os.listdir(working_dir)) == 0 assert parent_table_job == {"params": {"num_records": 10}} @@ -389,7 +388,7 @@ def test_generation_job(pets): with tempfile.TemporaryDirectory() as tmp: working_dir = Path(tmp) child_table_job = strategy.get_generation_job( - "pets", pets, 2.0, output_tables, working_dir, training_columns["pets"] + "pets", pets, 2.0, output_tables, working_dir ) assert len(os.listdir(working_dir)) == 1 @@ -429,29 +428,6 @@ def test_generation_job_seeds_go_back_multiple_generations(source_nba, synthetic "cities": ancestry.get_table_data_with_ancestors(synthetic_nba, "cities"), "states": ancestry.get_table_data_with_ancestors(synthetic_nba, "states"), } - training_columns = { - "teams": [ - "self|name", - "self|id", - "self|city_id", - "self.city_id|id", - "self.city_id|state_id", - # "self.city_id|name", # highly unique categorical - "self.city_id.state_id|id", - # "self.city_id.state_id|name", # highly unique categorical - ], - "cities": [ - "self|id", - "self|state_id", - # "self|name", # highly unique categorical - "self.state_id|id", - # "self.state_id|name", # highly unique categorical - ], - "states": [ - "self|id", - "self|name", - ], - } strategy = AncestralStrategy() @@ -463,7 +439,6 @@ def test_generation_job_seeds_go_back_multiple_generations(source_nba, synthetic 1.0, output_tables, working_dir, - training_columns["teams"], ) seed_df = pd.read_csv(job["data_source"]) diff --git a/tests/relational/test_ancestry.py b/tests/relational/test_ancestry.py index b32a259c..4bc62f50 100644 --- a/tests/relational/test_ancestry.py +++ b/tests/relational/test_ancestry.py @@ -187,3 +187,57 @@ def test_prepend_foreign_key_lineage(ecom): "self.inventory_item_id.product_id.distribution_center_id|id", "self.inventory_item_id.product_id.distribution_center_id|name", } + + +def test_get_seed_safe_multigenerational_columns_1(pets): + table_cols = ancestry.get_seed_safe_multigenerational_columns(pets) + + expected = { + "humans": {"self|id", "self|name", "self|city"}, + "pets": { + "self|id", + "self|name", + "self|age", + "self|human_id", + "self.human_id|id", + # "self.human_id|name", # highly unique categorical + "self.human_id|city", + }, + } + + assert set(table_cols.keys()) == set(expected.keys()) + for table, expected_cols in expected.items(): + assert set(table_cols[table]) == expected_cols + + +def test_get_seed_safe_multigenerational_columns_2(source_nba): + source_nba = source_nba[0] + table_cols = ancestry.get_seed_safe_multigenerational_columns(source_nba) + + expected = { + "teams": { + "self|name", + "self|id", + "self|city_id", + "self.city_id|id", + "self.city_id|state_id", + # "self.city_id|name", # highly unique categorical + "self.city_id.state_id|id", + # "self.city_id.state_id|name", # highly unique categorical + }, + "cities": { + "self|id", + "self|state_id", + "self|name", + "self.state_id|id", + # "self.state_id|name", # highly unique categorical + }, + "states": { + "self|id", + "self|name", + }, + } + + assert set(table_cols.keys()) == set(expected.keys()) + for table, expected_cols in expected.items(): + assert set(table_cols[table]) == expected_cols diff --git a/tests/relational/test_backup.py b/tests/relational/test_backup.py index a6b31ae3..05683bf9 100644 --- a/tests/relational/test_backup.py +++ b/tests/relational/test_backup.py @@ -130,10 +130,6 @@ def test_backup(): "customer": "1234567890", "address": "0987654321", }, - training_columns={ - "customer": ["id", "first", "last"], - "address": ["customer_id", "street", "city"], - }, lost_contact=[], ) backup_generate = BackupGenerate( @@ -141,7 +137,6 @@ def test_backup(): preserved=[], record_size_ratio=1.0, lost_contact=[], - missing_model=[], record_handler_ids={ "customer": "555444666", "address": "333111222", diff --git a/tests/relational/test_connectors.py b/tests/relational/test_connectors.py index dbee2014..0567a7d5 100644 --- a/tests/relational/test_connectors.py +++ b/tests/relational/test_connectors.py @@ -17,11 +17,11 @@ def test_extract_subsets_of_relational_data(example_dbs): connector = sqlite_conn(f.name) with pytest.raises(MultiTableException): - connector.extract(only=["users"], ignore=["events"]) + connector.extract(only={"users"}, ignore={"events"}) - only = connector.extract(only=["users", "events", "products"]) + only = connector.extract(only={"users", "events", "products"}) ignore = connector.extract( - ignore=["distribution_center", "order_items", "inventory_items"] + ignore={"distribution_center", "order_items", "inventory_items"} ) expected_tables = {"users", "events", "products"} diff --git a/tests/relational/test_independent_strategy.py b/tests/relational/test_independent_strategy.py index 6ea31c89..6ff22449 100644 --- a/tests/relational/test_independent_strategy.py +++ b/tests/relational/test_independent_strategy.py @@ -19,7 +19,7 @@ def test_preparing_training_data_does_not_mutate_source_data(pets, art): } strategy = IndependentStrategy() - strategy.prepare_training_data(rel_data) + strategy.prepare_training_data(rel_data, rel_data.list_all_tables()) for table in rel_data.list_all_tables(): pdtest.assert_frame_equal( @@ -30,11 +30,19 @@ def test_preparing_training_data_does_not_mutate_source_data(pets, art): def test_prepare_training_data_removes_primary_and_foreign_keys(pets): strategy = IndependentStrategy() - training_data = strategy.prepare_training_data(pets) + training_data = strategy.prepare_training_data(pets, pets.list_all_tables()) assert set(training_data["pets"].columns) == {"name", "age"} +def test_prepare_training_data_subset_of_tables(pets): + strategy = IndependentStrategy() + + training_data = strategy.prepare_training_data(pets, ["humans"]) + + assert set(training_data.keys()) == {"humans"} + + def test_retraining_a_set_of_tables_only_retrains_those_tables(ecom): strategy = IndependentStrategy() assert set(strategy.tables_to_retrain(["users"], ecom)) == {"users"} @@ -69,7 +77,7 @@ def test_table_generation_readiness(ecom): def test_generation_job_requests_num_records(pets): strategy = IndependentStrategy() - job = strategy.get_generation_job("pets", pets, 2.0, {}, Path("/working"), []) + job = strategy.get_generation_job("pets", pets, 2.0, {}, Path("/working")) assert job == {"params": {"num_records": 10}} diff --git a/tests/relational/test_multi_table_restore.py b/tests/relational/test_multi_table_restore.py index 9e35d593..571f214a 100644 --- a/tests/relational/test_multi_table_restore.py +++ b/tests/relational/test_multi_table_restore.py @@ -89,7 +89,6 @@ def make_backup( model_ids={ table: mock.model_id for table, mock in synthetics_models.items() }, - training_columns={table: ["col1", "col2"] for table in synthetics_models}, lost_contact=[], ) if len(synthetics_record_handlers) > 0: @@ -98,7 +97,6 @@ def make_backup( preserved=[], record_size_ratio=1.0, lost_contact=[], - missing_model=[], record_handler_ids={ table: mock.record_id for table, mock in synthetics_record_handlers.items() @@ -672,7 +670,6 @@ def test_restore_generate_in_progress( preserved=[], record_size_ratio=1.0, lost_contact=[], - missing_model=[], record_handlers=synthetics_record_handlers, ) assert len(mt.synthetic_output_tables) == 0 diff --git a/tests/relational/test_synthetics_run_task.py b/tests/relational/test_synthetics_run_task.py index 60918543..250ae358 100644 --- a/tests/relational/test_synthetics_run_task.py +++ b/tests/relational/test_synthetics_run_task.py @@ -7,6 +7,7 @@ import pandas as pd import pandas.testing as pdtest import pytest +from gretel_client.projects.jobs import Status from gretel_client.projects.projects import Project from gretel_trainer.relational.core import RelationalData @@ -42,8 +43,15 @@ def make_task( rel_data: RelationalData, run_dir: Path, preserved: Optional[list[str]] = None, - missing_model: Optional[list[str]] = None, + failed: Optional[list[str]] = None, + omitted: Optional[list[str]] = None, ) -> SyntheticsRunTask: + def _status_for_table(table: str, failed: list[str]) -> Status: + if table in failed: + return Status.ERROR + else: + return Status.COMPLETED + multitable = MockMultiTable(relational_data=rel_data) return SyntheticsRunTask( synthetics_run=SyntheticsRun( @@ -51,17 +59,16 @@ def make_task( record_handlers={}, lost_contact=[], preserved=preserved or [], - missing_model=missing_model or [], record_size_ratio=1.0, ), synthetics_train=SyntheticsTrain( - training_columns={ - table: list(rel_data.get_table_data(table).columns) - for table in rel_data.list_all_tables() - }, models={ - table: Mock(create_record_handler=Mock()) + table: Mock( + create_record_handler=Mock(), + status=_status_for_table(table, failed or []), + ) for table in rel_data.list_all_tables() + if table not in (omitted or []) }, ), run_dir=run_dir, @@ -72,26 +79,29 @@ def make_task( def test_ignores_preserved_tables(pets, tmpdir): task = make_task(pets, tmpdir, preserved=["pets"]) + # Source data is used assert task.working_tables["pets"] is not None assert "pets" in task.output_tables task.each_iteration() assert "pets" not in task.synthetics_run.record_handlers -def test_ignores_tables_that_failed_to_train(pets, tmpdir): - task = make_task(pets, tmpdir, missing_model=["pets"]) +def test_ignores_tables_that_were_omitted_from_training(pets, tmpdir): + task = make_task(pets, tmpdir, omitted=["pets"]) - assert task.working_tables["pets"] is None - assert "pets" not in task.output_tables + # Source data is used + assert task.working_tables["pets"] is not None + assert "pets" in task.output_tables task.each_iteration() assert "pets" not in task.synthetics_run.record_handlers -def test_preserve_takes_precedence_over_missing_model(pets, tmpdir): - task = make_task(pets, tmpdir, preserved=["pets"], missing_model=["pets"]) +def test_ignores_tables_that_failed_during_training(pets, tmpdir): + task = make_task(pets, tmpdir, failed=["pets"]) - assert task.working_tables["pets"] is not None - assert "pets" in task.output_tables + # We set tables that failed to explicit None + assert task.working_tables["pets"] is None + assert "pets" not in task.output_tables task.each_iteration() assert "pets" not in task.synthetics_run.record_handlers diff --git a/tests/relational/test_train_synthetics.py b/tests/relational/test_train_synthetics.py new file mode 100644 index 00000000..f02d0bb5 --- /dev/null +++ b/tests/relational/test_train_synthetics.py @@ -0,0 +1,114 @@ +import tempfile +from unittest.mock import ANY, patch + +import pytest + +from gretel_trainer.relational.core import MultiTableException +from gretel_trainer.relational.multi_table import MultiTable + + +# The assertions in this file are concerned with setting up the synthetics train +# workflow state properly, and stop short of kicking off the task. +@pytest.fixture(autouse=True) +def run_task(): + with patch("gretel_trainer.relational.multi_table.run_task"): + yield + + +@pytest.fixture(autouse=True) +def backup(): + with patch.object(MultiTable, "_backup", return_value=None): + yield + + +@pytest.fixture() +def tmpdir(project): + with tempfile.TemporaryDirectory() as tmpdir: + project.name = tmpdir + yield tmpdir + + +def test_train_synthetics_defaults_to_training_all_tables(ecom, tmpdir): + mt = MultiTable(ecom, project_display_name=tmpdir) + mt.train_synthetics() + + assert set(mt._synthetics_train.models.keys()) == set(ecom.list_all_tables()) + + +def test_train_synthetics_only_includes_specified_tables(ecom, tmpdir, project): + mt = MultiTable(ecom, project_display_name=tmpdir) + mt.train_synthetics(only={"users"}) + + assert set(mt._synthetics_train.models.keys()) == {"users"} + project.create_model_obj.assert_called_with( + model_config=ANY, # a tailored synthetics config, in dict form + data_source=f"{tmpdir}/synthetics_train_users.csv", + ) + + +def test_train_synthetics_ignore_excludes_specified_tables(ecom, tmpdir): + mt = MultiTable(ecom, project_display_name=tmpdir) + mt.train_synthetics(ignore={"distribution_center", "products"}) + + assert set(mt._synthetics_train.models.keys()) == { + "events", + "users", + "order_items", + "inventory_items", + } + + +def test_train_synthetics_exits_early_if_unrecognized_tables(ecom, tmpdir, project): + mt = MultiTable(ecom, project_display_name=tmpdir) + with pytest.raises(MultiTableException): + mt.train_synthetics(ignore={"nonsense"}) + + assert len(mt._synthetics_train.models) == 0 + project.create_model_obj.assert_not_called() + + +def test_train_synthetics_multiple_calls_additive(ecom, tmpdir): + mt = MultiTable(ecom, project_display_name=tmpdir) + mt.train_synthetics(only={"products"}) + mt.train_synthetics(only={"users"}) + + # We do not lose the first table model + assert set(mt._synthetics_train.models.keys()) == {"products", "users"} + + +def test_train_synthetics_models_for_dbs_with_invented_tables(documents, tmpdir): + mt = MultiTable(documents, project_display_name=tmpdir) + mt.train_synthetics() + + assert set(mt._synthetics_train.models.keys()) == { + "users", + "payments", + "purchases-sfx", + "purchases-data-years-sfx", + } + + +def test_train_synthetics_table_filters_cascade_to_invented_tables(documents, tmpdir): + # When a user provides the ("public") name of a table that contained JSON and led + # to the creation of invented tables, we recognize that as implicitly applying to + # all the tables internally created from that source table. + mt = MultiTable(documents, project_display_name=tmpdir) + mt.train_synthetics(ignore={"purchases"}) + + assert set(mt._synthetics_train.models.keys()) == {"users", "payments"} + + +def test_train_synthetics_multiple_calls_overwrite(ecom, tmpdir, project): + project.create_model_obj.return_value = "m1" + + mt = MultiTable(ecom, project_display_name=tmpdir) + mt.train_synthetics(only={"products"}) + + assert mt._synthetics_train.models["products"] == "m1" + + project.reset_mock() + project.create_model_obj.return_value = "m2" + + # calling a second time will create a new model for the table that overwrites the original + mt.train_synthetics(only={"products"}) + assert mt._synthetics_train.models["products"] == "m2" diff --git a/tests/relational/test_train_transforms.py b/tests/relational/test_train_transforms.py index 4c93cdbd..530c9626 100644 --- a/tests/relational/test_train_transforms.py +++ b/tests/relational/test_train_transforms.py @@ -38,7 +38,7 @@ def test_train_transforms_defaults_to_transforming_all_tables(ecom, tmpdir): def test_train_transforms_only_includes_specified_tables(ecom, tmpdir, project): mt = MultiTable(ecom, project_display_name=tmpdir) - mt.train_transforms("transform/default", only=["users"]) + mt.train_transforms("transform/default", only={"users"}) transforms_train = mt._transforms_train assert set(transforms_train.models.keys()) == {"users"} @@ -50,7 +50,7 @@ def test_train_transforms_only_includes_specified_tables(ecom, tmpdir, project): def test_train_transforms_ignore_excludes_specified_tables(ecom, tmpdir): mt = MultiTable(ecom, project_display_name=tmpdir) - mt.train_transforms("transform/default", ignore=["distribution_center", "products"]) + mt.train_transforms("transform/default", ignore={"distribution_center", "products"}) transforms_train = mt._transforms_train assert set(transforms_train.models.keys()) == { @@ -64,7 +64,7 @@ def test_train_transforms_ignore_excludes_specified_tables(ecom, tmpdir): def test_train_transforms_exits_early_if_unrecognized_tables(ecom, tmpdir, project): mt = MultiTable(ecom, project_display_name=tmpdir) with pytest.raises(MultiTableException): - mt.train_transforms("transform/default", ignore=["nonsense"]) + mt.train_transforms("transform/default", ignore={"nonsense"}) transforms_train = mt._transforms_train assert len(transforms_train.models) == 0 @@ -73,8 +73,8 @@ def test_train_transforms_exits_early_if_unrecognized_tables(ecom, tmpdir, proje def test_train_transforms_multiple_calls_additive(ecom, tmpdir): mt = MultiTable(ecom, project_display_name=tmpdir) - mt.train_transforms("transform/default", only=["products"]) - mt.train_transforms("transform/default", only=["users"]) + mt.train_transforms("transform/default", only={"products"}) + mt.train_transforms("transform/default", only={"users"}) # We do not lose the first table model assert set(mt._transforms_train.models.keys()) == {"products", "users"} @@ -84,7 +84,7 @@ def test_train_transforms_multiple_calls_overwrite(ecom, tmpdir, project): project.create_model_obj.return_value = "m1" mt = MultiTable(ecom, project_display_name=tmpdir) - mt.train_transforms("transform/default", only=["products"]) + mt.train_transforms("transform/default", only={"products"}) assert mt._transforms_train.models["products"] == "m1" @@ -92,7 +92,7 @@ def test_train_transforms_multiple_calls_overwrite(ecom, tmpdir, project): project.create_model_obj.return_value = "m2" # calling a second time will create a new model for the table that overwrites the original - mt.train_transforms("transform/default", only=["products"]) + mt.train_transforms("transform/default", only={"products"}) assert mt._transforms_train.models["products"] == "m2"