diff --git a/src/gretel_trainer/relational/ancestry.py b/src/gretel_trainer/relational/ancestry.py index 6f8d9567..db034b30 100644 --- a/src/gretel_trainer/relational/ancestry.py +++ b/src/gretel_trainer/relational/ancestry.py @@ -21,6 +21,13 @@ def get_multigenerational_primary_key( ] +def get_all_key_columns(rel_data: RelationalData, table: str) -> list[str]: + tableset = _empty_data_tableset(rel_data) + return list( + get_table_data_with_ancestors(rel_data, table, tableset, keys_only=True).columns + ) + + def get_ancestral_foreign_key_maps( rel_data: RelationalData, table: str ) -> list[tuple[str, str]]: @@ -59,6 +66,13 @@ def _ancestral_fk_map(fk: ForeignKey) -> list[tuple[str, str]]: ] +def _empty_data_tableset(rel_data: RelationalData) -> dict[str, pd.DataFrame]: + return { + table: pd.DataFrame(columns=list(rel_data.get_table_columns(table))) + for table in rel_data.list_all_tables() + } + + def get_seed_safe_multigenerational_columns( rel_data: RelationalData, ) -> dict[str, list[str]]: @@ -68,10 +82,7 @@ def get_seed_safe_multigenerational_columns( a significantly faster / less resource-intensive way to get just the column names from the results of `get_table_data_with_ancestors` for all tables. """ - tableset = { - table: pd.DataFrame(columns=list(rel_data.get_table_columns(table))) - for table in rel_data.list_all_tables() - } + tableset = _empty_data_tableset(rel_data) return { table: list( get_table_data_with_ancestors( @@ -87,6 +98,7 @@ def get_table_data_with_ancestors( table: str, tableset: Optional[dict[str, pd.DataFrame]] = None, ancestral_seeding: bool = False, + keys_only: bool = False, ) -> pd.DataFrame: """ Returns a data frame with all ancestral data joined to each record. @@ -96,14 +108,26 @@ def get_table_data_with_ancestors( separated by periods. If `tableset` is provided, use it in place of the source data in `self.graph`. + + If `ancestral_seeding` is True, the returned dataframe only includes columns + that can be used as conditional seeds. + + If `keys_only` is True, the returned dataframe only includes columns that are primary + or foreign keys. 
""" - lineage = _START_LINEAGE if tableset is not None: df = tableset[table] else: df = rel_data.get_table_data(table) + + if keys_only: + df = df[rel_data.get_all_key_columns(table)] + + lineage = _START_LINEAGE df = df.add_prefix(f"{_START_LINEAGE}{_END_LINEAGE}") - return _join_parents(rel_data, df, table, lineage, tableset, ancestral_seeding) + return _join_parents( + rel_data, df, table, lineage, tableset, ancestral_seeding, keys_only + ) def _join_parents( @@ -113,6 +137,7 @@ def _join_parents( lineage: str, tableset: Optional[dict[str, pd.DataFrame]], ancestral_seeding: bool, + keys_only: bool, ) -> pd.DataFrame: for foreign_key in rel_data.get_foreign_keys(table): fk_lineage = _COL_DELIMITER.join(foreign_key.columns) @@ -122,6 +147,8 @@ def _join_parents( if ancestral_seeding: usecols = list(rel_data.get_safe_ancestral_seed_columns(parent_table_name)) + elif keys_only: + usecols = rel_data.get_all_key_columns(parent_table_name) else: usecols = rel_data.get_table_columns(parent_table_name) @@ -141,7 +168,13 @@ def _join_parents( ) df = _join_parents( - rel_data, df, parent_table_name, next_lineage, tableset, ancestral_seeding + rel_data, + df, + parent_table_name, + next_lineage, + tableset, + ancestral_seeding, + keys_only, ) return df diff --git a/src/gretel_trainer/relational/core.py b/src/gretel_trainer/relational/core.py index c3aacccf..bb668b65 100644 --- a/src/gretel_trainer/relational/core.py +++ b/src/gretel_trainer/relational/core.py @@ -754,7 +754,7 @@ def get_all_key_columns(self, table: str) -> list[str]: all_key_cols.extend(self.get_primary_key(table)) for fk in self.get_foreign_keys(table): all_key_cols.extend(fk.columns) - return all_key_cols + return sorted(list(set(all_key_cols))) def debug_summary(self) -> dict[str, Any]: max_depth = dag_longest_path_length(self.graph) diff --git a/src/gretel_trainer/relational/model_config.py b/src/gretel_trainer/relational/model_config.py index b76b84bb..b921a3dc 100644 --- a/src/gretel_trainer/relational/model_config.py +++ b/src/gretel_trainer/relational/model_config.py @@ -38,9 +38,9 @@ def make_classify_config(table: str, config: GretelModelConfig) -> dict[str, Any return tailored_config -def make_evaluate_config(table: str) -> dict[str, Any]: +def make_evaluate_config(table: str, sqs_type: str) -> dict[str, Any]: tailored_config = ingest("evaluate/default") - tailored_config["name"] = _model_name("evaluate", table) + tailored_config["name"] = _model_name(sqs_type, table) return tailored_config diff --git a/src/gretel_trainer/relational/multi_table.py b/src/gretel_trainer/relational/multi_table.py index 37eceb8d..e4e4803c 100644 --- a/src/gretel_trainer/relational/multi_table.py +++ b/src/gretel_trainer/relational/multi_table.py @@ -22,6 +22,8 @@ import pandas as pd import smart_open +import gretel_trainer.relational.ancestry as ancestry + from gretel_client.config import get_session_config, RunnerMode from gretel_client.projects import create_project, get_project, Project from gretel_client.projects.jobs import ACTIVE_STATES, END_STATES, Status @@ -270,27 +272,10 @@ def _complete_init_from_backup(self, backup: Backup) -> None: "Cannot restore while model training is actively in progress." 
) - training_succeeded = [ - table - for table, model in self._synthetics_train.models.items() - if model.status == Status.COMPLETED - ] - for table in training_succeeded: - if table in self.relational_data.list_all_tables(Scope.EVALUATABLE): - model = self._synthetics_train.models[table] - with silent_logs(): - self._strategy.update_evaluation_from_model( - table, - self._evaluations, - model, - self._working_dir, - self._extended_sdk, - ) - training_failed = [ table for table, model in self._synthetics_train.models.items() - if model.status in END_STATES and table not in training_succeeded + if model.status in END_STATES and model.status != Status.COMPLETED ] if len(training_failed) > 0: logger.info( @@ -720,17 +705,6 @@ def _train_synthetics_models(self, configs: dict[str, dict[str, Any]]) -> None: ) run_task(task, self._extended_sdk) - for table in task.completed: - if table in self.relational_data.list_all_tables(Scope.EVALUATABLE): - model = self._synthetics_train.models[table] - self._strategy.update_evaluation_from_model( - table, - self._evaluations, - model, - self._working_dir, - self._extended_sdk, - ) - def train_synthetics( self, *, @@ -901,7 +875,8 @@ def generate( evaluate_project = create_project( display_name=f"evaluate-{self._project.display_name}" ) - evaluate_models = {} + individual_evaluate_models = {} + cross_table_evaluate_models = {} for table, synth_df in output_tables.items(): if table in self._synthetics_run.preserved: continue @@ -912,49 +887,43 @@ def generate( if table not in self.relational_data.list_all_tables(Scope.EVALUATABLE): continue - evaluate_data = self._strategy.get_evaluate_model_data( - rel_data=self.relational_data, - table_name=table, + # Create an evaluate model for individual SQS + individual_data = self._get_individual_evaluate_data( + table=table, synthetic_tables=output_tables, ) - if evaluate_data is None: - continue + individual_sqs_job = evaluate_project.create_model_obj( + model_config=make_evaluate_config(table, "individual"), + data_source=individual_data["synthetic"], + ref_data=individual_data["source"], + ) + individual_evaluate_models[table] = individual_sqs_job - evaluate_models[table] = evaluate_project.create_model_obj( - model_config=make_evaluate_config(table), - data_source=evaluate_data["synthetic"], - ref_data=evaluate_data["source"], + # Create an evaluate model for cross-table SQS (if we can/should) + cross_table_data = self._get_cross_table_evaluate_data( + table=table, + synthetic_tables=output_tables, ) + if cross_table_data is not None: + cross_table_sqs_job = evaluate_project.create_model_obj( + model_config=make_evaluate_config(table, "cross_table"), + data_source=cross_table_data["synthetic"], + ref_data=cross_table_data["source"], + ) + cross_table_evaluate_models[table] = cross_table_sqs_job synthetics_evaluate_task = SyntheticsEvaluateTask( - evaluate_models=evaluate_models, + individual_evaluate_models=individual_evaluate_models, + cross_table_evaluate_models=cross_table_evaluate_models, project=evaluate_project, + run_dir=run_dir, + evaluations=self._evaluations, multitable=self, ) run_task(synthetics_evaluate_task, self._extended_sdk) - # Tables passed to task were already scoped to evaluatable tables - for table in synthetics_evaluate_task.completed: - self._strategy.update_evaluation_from_evaluate( - table_name=table, - evaluate_model=evaluate_models[table], - evaluations=self._evaluations, - working_dir=self._working_dir, - extended_sdk=self._extended_sdk, - ) - evaluate_project.delete() - for 
table_name in output_tables: - for eval_type in ["individual", "cross_table"]: - for ext in ["html", "json"]: - filename = f"synthetics_{eval_type}_evaluation_{table_name}.{ext}" - with suppress(FileNotFoundError): - shutil.copyfile( - src=self._working_dir / filename, - dst=run_dir / filename, - ) - logger.info("Creating relational report") self.create_relational_report( run_identifier=self._synthetics_run.identifier, @@ -975,6 +944,66 @@ def generate( self.synthetic_output_tables = reshaped_tables self._backup() + def _get_individual_evaluate_data( + self, table: str, synthetic_tables: dict[str, pd.DataFrame] + ) -> dict[str, pd.DataFrame]: + """ + Returns a dictionary containing source and synthetic versions of a table, + to be used in an Evaluate job. + + Removes all key columns to avoid artificially deflating the score + (key types may not match, and key values carry no semantic meaning). + """ + all_cols = self.relational_data.get_table_columns(table) + key_cols = self.relational_data.get_all_key_columns(table) + use_cols = [c for c in all_cols if c not in key_cols] + + return { + "source": self.relational_data.get_table_data(table, usecols=use_cols), + "synthetic": synthetic_tables[table].drop(columns=key_cols), + } + + def _get_cross_table_evaluate_data( + self, table: str, synthetic_tables: dict[str, pd.DataFrame] + ) -> Optional[dict[str, pd.DataFrame]]: + """ + Returns a dictionary containing source and synthetic versions of a table + with ancestral data attached, to be used in an Evaluate job for cross-table SQS. + + Removes all key columns to avoid artificially deflating the score + (key types may not match, and key values carry no semantic meaning). + + Returns None if a cross-table SQS job cannot or should not be performed. + """ + # Exit early if table does not have parents (no need for cross-table evaluation) + if len(self.relational_data.get_parents(table)) == 0: + return None + + # Exit early if we can't create synthetic cross-table data + # (e.g. parent data missing due to job failure) + missing_ancestors = [ + ancestor + for ancestor in self.relational_data.get_ancestors(table) + if ancestor not in synthetic_tables + ] + if len(missing_ancestors) > 0: + logger.info( + f"Cannot run cross_table evaluations for `{table}` because no synthetic data exists for ancestor tables {missing_ancestors}." 
+ ) + return None + + source_data = ancestry.get_table_data_with_ancestors( + self.relational_data, table + ) + synthetic_data = ancestry.get_table_data_with_ancestors( + self.relational_data, table, synthetic_tables + ) + key_cols = ancestry.get_all_key_columns(self.relational_data, table) + return { + "source": source_data.drop(columns=key_cols), + "synthetic": synthetic_data.drop(columns=key_cols), + } + def create_relational_report(self, run_identifier: str, target_dir: Path) -> None: presenter = ReportPresenter( rel_data=self.relational_data, diff --git a/src/gretel_trainer/relational/sdk_extras.py b/src/gretel_trainer/relational/sdk_extras.py index 6ea05285..844c679d 100644 --- a/src/gretel_trainer/relational/sdk_extras.py +++ b/src/gretel_trainer/relational/sdk_extras.py @@ -32,9 +32,12 @@ def get_job_id(self, job: Job) -> Optional[str]: else: raise MultiTableException("Unexpected job object") - def delete_data_source(self, project: Project, job: Job) -> None: - if not self._hybrid and job.data_source is not None: - project.delete_artifact(job.data_source) + def delete_data_sources(self, project: Project, job: Job) -> None: + if not self._hybrid: + if job.data_source is not None: + project.delete_artifact(job.data_source) + for ref_data in job.ref_data.values: + project.delete_artifact(ref_data) def cautiously_refresh_status( self, job: Job, key: str, refresh_attempts: dict[str, int] diff --git a/src/gretel_trainer/relational/strategies/ancestral.py b/src/gretel_trainer/relational/strategies/ancestral.py index 7f49fa86..99e1d649 100644 --- a/src/gretel_trainer/relational/strategies/ancestral.py +++ b/src/gretel_trainer/relational/strategies/ancestral.py @@ -1,21 +1,18 @@ import logging from pathlib import Path -from typing import Any, Optional, Union +from typing import Any, Union import pandas as pd import gretel_trainer.relational.ancestry as ancestry import gretel_trainer.relational.strategies.common as common -from gretel_client.projects.models import Model from gretel_trainer.relational.core import ( GretelModelConfig, MultiTableException, RelationalData, ) -from gretel_trainer.relational.sdk_extras import ExtendedGretelSDK -from gretel_trainer.relational.table_evaluation import TableEvaluation logger = logging.getLogger(__name__) @@ -298,51 +295,6 @@ def post_process_synthetic_results( for table_name, df in synth_tables.items() } - def update_evaluation_from_model( - self, - table_name: str, - evaluations: dict[str, TableEvaluation], - model: Model, - working_dir: Path, - extended_sdk: ExtendedGretelSDK, - ) -> None: - logger.info(f"Downloading cross_table evaluation reports for `{table_name}`.") - out_filepath = working_dir / f"synthetics_cross_table_evaluation_{table_name}" - common.download_artifacts(model, out_filepath, extended_sdk) - - evaluation = evaluations[table_name] - evaluation.cross_table_report_json = common.read_report_json_data( - model, out_filepath - ) - - def get_evaluate_model_data( - self, - table_name: str, - rel_data: RelationalData, - synthetic_tables: dict[str, pd.DataFrame], - ) -> Optional[dict[str, pd.DataFrame]]: - return { - "source": rel_data.get_table_data(table_name), - "synthetic": synthetic_tables[table_name], - } - - def update_evaluation_from_evaluate( - self, - table_name: str, - evaluations: dict[str, TableEvaluation], - evaluate_model: Model, - working_dir: Path, - extended_sdk: ExtendedGretelSDK, - ) -> None: - logger.info(f"Downloading individual evaluation reports for `{table_name}`.") - out_filepath = working_dir / 
f"synthetics_individual_evaluation_{table_name}" - common.download_artifacts(evaluate_model, out_filepath, extended_sdk) - - evaluation = evaluations[table_name] - evaluation.individual_report_json = common.read_report_json_data( - evaluate_model, out_filepath - ) - def _add_artifical_rows_for_seeding( rel_data: RelationalData, tables: dict[str, pd.DataFrame], omitted: list[str] diff --git a/src/gretel_trainer/relational/strategies/common.py b/src/gretel_trainer/relational/strategies/common.py index 8ed6bbc1..efdc496b 100644 --- a/src/gretel_trainer/relational/strategies/common.py +++ b/src/gretel_trainer/relational/strategies/common.py @@ -1,52 +1,17 @@ -import json import logging import random -from pathlib import Path from typing import Optional import pandas as pd -import smart_open from sklearn import preprocessing -from gretel_client.projects.models import Model from gretel_trainer.relational.core import MultiTableException, RelationalData -from gretel_trainer.relational.sdk_extras import ExtendedGretelSDK logger = logging.getLogger(__name__) -def download_artifacts( - model: Model, out_filepath: Path, extended_sdk: ExtendedGretelSDK -) -> None: - """ - Downloads all model artifacts to a subdirectory in the working directory. - """ - legend = {"html": "report", "json": "report_json"} - - for filetype, artifact_name in legend.items(): - out_path = f"{out_filepath}.{filetype}" - extended_sdk.download_file_artifact(model, artifact_name, out_path) - - -def read_report_json_data(model: Model, report_path: Path) -> Optional[dict]: - full_path = f"{report_path}.json" - try: - return json.loads(smart_open.open(full_path).read()) - except: - return _get_report_json(model) - - -def _get_report_json(model: Model) -> Optional[dict]: - try: - with model.get_artifact_handle("report_json") as report: - return json.loads(report.read()) - except: - logger.warning("Failed to fetch model evaluation report JSON.") - return None - - def label_encode_keys( rel_data: RelationalData, tables: dict[str, pd.DataFrame], diff --git a/src/gretel_trainer/relational/strategies/independent.py b/src/gretel_trainer/relational/strategies/independent.py index 78c60d6a..dfb4cdd1 100644 --- a/src/gretel_trainer/relational/strategies/independent.py +++ b/src/gretel_trainer/relational/strategies/independent.py @@ -2,17 +2,13 @@ import random from pathlib import Path -from typing import Any, Optional +from typing import Any import pandas as pd -import gretel_trainer.relational.ancestry as ancestry import gretel_trainer.relational.strategies.common as common -from gretel_client.projects.models import Model from gretel_trainer.relational.core import GretelModelConfig, RelationalData -from gretel_trainer.relational.sdk_extras import ExtendedGretelSDK -from gretel_trainer.relational.table_evaluation import TableEvaluation logger = logging.getLogger(__name__) @@ -147,66 +143,6 @@ def post_process_synthetic_results( synth_tables = _synthesize_foreign_keys(synth_tables, rel_data) return synth_tables - def update_evaluation_from_model( - self, - table_name: str, - evaluations: dict[str, TableEvaluation], - model: Model, - working_dir: Path, - extended_sdk: ExtendedGretelSDK, - ) -> None: - logger.info(f"Downloading individual evaluation reports for `{table_name}`.") - out_filepath = working_dir / f"synthetics_individual_evaluation_{table_name}" - common.download_artifacts(model, out_filepath, extended_sdk) - - evaluation = evaluations[table_name] - evaluation.individual_report_json = common.read_report_json_data( - model, 
out_filepath - ) - - def get_evaluate_model_data( - self, - table_name: str, - rel_data: RelationalData, - synthetic_tables: dict[str, pd.DataFrame], - ) -> Optional[dict[str, pd.DataFrame]]: - missing_ancestors = [ - ancestor - for ancestor in rel_data.get_ancestors(table_name) - if ancestor not in synthetic_tables - ] - if len(missing_ancestors) > 0: - logger.info( - f"Cannot run cross_table evaluations for `{table_name}` because no synthetic data exists for ancestor tables {missing_ancestors}." - ) - return None - - source_data = ancestry.get_table_data_with_ancestors(rel_data, table_name) - synthetic_data = ancestry.get_table_data_with_ancestors( - rel_data, table_name, synthetic_tables - ) - return { - "source": source_data, - "synthetic": synthetic_data, - } - - def update_evaluation_from_evaluate( - self, - table_name: str, - evaluations: dict[str, TableEvaluation], - evaluate_model: Model, - working_dir: Path, - extended_sdk: ExtendedGretelSDK, - ) -> None: - logger.info(f"Downloading cross table evaluation reports for `{table_name}`.") - out_filepath = working_dir / f"synthetics_cross_table_evaluation_{table_name}" - common.download_artifacts(evaluate_model, out_filepath, extended_sdk) - - evaluation = evaluations[table_name] - evaluation.cross_table_report_json = common.read_report_json_data( - evaluate_model, out_filepath - ) - def _synthesize_primary_keys( synth_tables: dict[str, pd.DataFrame], diff --git a/src/gretel_trainer/relational/tasks/common.py b/src/gretel_trainer/relational/tasks/common.py index 45ab051b..f896d12a 100644 --- a/src/gretel_trainer/relational/tasks/common.py +++ b/src/gretel_trainer/relational/tasks/common.py @@ -62,4 +62,4 @@ def log_lost_contact(table_name: str) -> None: def cleanup(sdk: ExtendedGretelSDK, project: Project, job: Job) -> None: - sdk.delete_data_source(project, job) + sdk.delete_data_sources(project, job) diff --git a/src/gretel_trainer/relational/tasks/synthetics_evaluate.py b/src/gretel_trainer/relational/tasks/synthetics_evaluate.py index 8943c417..f8924430 100644 --- a/src/gretel_trainer/relational/tasks/synthetics_evaluate.py +++ b/src/gretel_trainer/relational/tasks/synthetics_evaluate.py @@ -1,8 +1,20 @@ +import json +import logging + +from pathlib import Path +from typing import Optional + +import smart_open + import gretel_trainer.relational.tasks.common as common from gretel_client.projects.jobs import Job from gretel_client.projects.models import Model from gretel_client.projects.projects import Project +from gretel_trainer.relational.sdk_extras import ExtendedGretelSDK +from gretel_trainer.relational.table_evaluation import TableEvaluation + +logger = logging.getLogger(__name__) ACTION = "synthetic data evaluation" @@ -10,12 +22,21 @@ class SyntheticsEvaluateTask: def __init__( self, - evaluate_models: dict[str, Model], + individual_evaluate_models: dict[str, Model], + cross_table_evaluate_models: dict[str, Model], project: Project, + run_dir: Path, + evaluations: dict[str, TableEvaluation], multitable: common._MultiTable, ): - self.evaluate_models = evaluate_models + self.jobs = {} + for table, model in individual_evaluate_models.items(): + self.jobs[f"individual-{table}"] = model + for table, model in cross_table_evaluate_models.items(): + self.jobs[f"cross_table-{table}"] = model self.project = project + self.run_dir = run_dir + self.evaluations = evaluations self.multitable = multitable self.completed = [] self.failed = [] @@ -25,27 +46,44 @@ def action(self, job: Job) -> str: @property def table_collection(self) -> 
list[str]: - return list(self.evaluate_models.keys()) + return list(self.jobs.keys()) @property def artifacts_per_job(self) -> int: return 2 def more_to_do(self) -> bool: - return len(self.completed + self.failed) < len(self.evaluate_models) + return len(self.completed + self.failed) < len(self.jobs) def wait(self) -> None: - common.wait(20) + common.wait(self.multitable._refresh_interval) def is_finished(self, table: str) -> bool: return table in (self.completed + self.failed) def get_job(self, table: str) -> Job: - return self.evaluate_models[table] + return self.jobs[table] def handle_completed(self, table: str, job: Job) -> None: self.completed.append(table) common.log_success(table, ACTION) + + model = self.get_job(table) + if table.startswith("individual-"): + table_name = table.removeprefix("individual-") + out_filepath = ( + self.run_dir / f"synthetics_individual_evaluation_{table_name}" + ) + data = _get_reports(model, out_filepath, self.multitable._extended_sdk) + self.evaluations[table_name].individual_report_json = data + else: + table_name = table.removeprefix("cross_table-") + out_filepath = ( + self.run_dir / f"synthetics_cross_table_evaluation_{table_name}" + ) + data = _get_reports(model, out_filepath, self.multitable._extended_sdk) + self.evaluations[table_name].cross_table_report_json = data + common.cleanup(sdk=self.multitable._extended_sdk, project=self.project, job=job) def handle_failed(self, table: str, job: Job) -> None: @@ -63,3 +101,42 @@ def handle_in_progress(self, table: str, job: Job) -> None: def each_iteration(self) -> None: pass + + +def _get_reports( + model: Model, out_filepath: Path, extended_sdk: ExtendedGretelSDK +) -> Optional[dict]: + _download_reports(model, out_filepath, extended_sdk) + return _read_json_report(model, out_filepath) + + +def _download_reports( + model: Model, out_filepath: Path, extended_sdk: ExtendedGretelSDK +) -> None: + """ + Downloads model reports to the provided path. + """ + legend = {"html": "report", "json": "report_json"} + + for filetype, artifact_name in legend.items(): + out_path = f"{out_filepath}.{filetype}" + extended_sdk.download_file_artifact(model, artifact_name, out_path) + + +def _read_json_report(model: Model, out_filepath: Path) -> Optional[dict]: + """ + Reads the JSON report data in to a dictionary to be appended to the MultiTable + evaluations property. First try reading the file we just downloaded to the run + directory. If that fails, try reading the data remotely from the model. If that + also fails, log a warning and give up gracefully. 
+ """ + full_path = f"{out_filepath}.json" + try: + return json.loads(smart_open.open(full_path).read()) + except: + try: + with model.get_artifact_handle("report_json") as report: + return json.loads(report.read()) + except: + logger.warning("Failed to fetch model evaluation report JSON.") + return None diff --git a/tests/relational/test_ancestral_strategy.py b/tests/relational/test_ancestral_strategy.py index 92834f93..f18c869f 100644 --- a/tests/relational/test_ancestral_strategy.py +++ b/tests/relational/test_ancestral_strategy.py @@ -13,7 +13,6 @@ from gretel_trainer.relational.core import MultiTableException from gretel_trainer.relational.strategies.ancestral import AncestralStrategy -from gretel_trainer.relational.table_evaluation import TableEvaluation def test_preparing_training_data_does_not_mutate_source_data(pets): @@ -669,67 +668,3 @@ def test_post_process_synthetic_results(ecom): pdtest.assert_frame_equal(expected_events, processed_tables["events"]) pdtest.assert_frame_equal(expected_users, processed_tables["users"]) - - -def test_uses_trained_model_to_update_cross_table_scores( - report_json_dict, extended_sdk -): - strategy = AncestralStrategy() - evaluations = { - "table_1": TableEvaluation(), - "table_2": TableEvaluation(), - } - model = Mock() - - with tempfile.TemporaryDirectory() as working_dir, patch( - "gretel_trainer.relational.strategies.ancestral.common.download_artifacts" - ) as download_artifacts: - working_dir = Path(working_dir) - with open( - working_dir / "synthetics_cross_table_evaluation_table_1.json", "w" - ) as f: - f.write(json.dumps(report_json_dict)) - - strategy.update_evaluation_from_model( - "table_1", evaluations, model, working_dir, extended_sdk - ) - - evaluation = evaluations["table_1"] - - assert evaluation.cross_table_sqs == 95 - assert evaluation.cross_table_report_json == report_json_dict - - assert evaluation.individual_sqs is None - assert evaluation.individual_report_json is None - - -def test_falls_back_to_fetching_report_json_when_download_artifacts_fails( - report_json_dict, extended_sdk -): - strategy = AncestralStrategy() - evaluations = { - "table_1": TableEvaluation(), - "table_2": TableEvaluation(), - } - model = Mock() - - with tempfile.TemporaryDirectory() as working_dir, patch( - "gretel_trainer.relational.strategies.ancestral.common.download_artifacts" - ) as download_artifacts, patch( - "gretel_trainer.relational.strategies.ancestral.common._get_report_json" - ) as get_json: - working_dir = Path(working_dir) - download_artifacts.return_value = None - get_json.return_value = report_json_dict - - strategy.update_evaluation_from_model( - "table_1", evaluations, model, working_dir, extended_sdk - ) - - evaluation = evaluations["table_1"] - - assert evaluation.cross_table_sqs == 95 - assert evaluation.cross_table_report_json == report_json_dict - - assert evaluation.individual_sqs is None - assert evaluation.individual_report_json is None diff --git a/tests/relational/test_ancestry.py b/tests/relational/test_ancestry.py index db477f24..f2116d4b 100644 --- a/tests/relational/test_ancestry.py +++ b/tests/relational/test_ancestry.py @@ -139,6 +139,43 @@ def test_primary_key_in_multigenerational_format(mutagenesis): ] +def test_get_all_key_columns(ecom, mutagenesis): + assert set(ancestry.get_all_key_columns(ecom, "distribution_center")) == {"self|id"} + assert set(ancestry.get_all_key_columns(ecom, "events")) == { + "self|id", + "self|user_id", + "self.user_id|id", + } + assert set(ancestry.get_all_key_columns(ecom, 
"inventory_items")) == { + "self|id", + "self|product_id", + "self|product_distribution_center_id", + "self.product_id|id", + "self.product_id|distribution_center_id", + "self.product_distribution_center_id|id", + "self.product_id.distribution_center_id|id", + } + + assert set(ancestry.get_all_key_columns(mutagenesis, "molecule")) == { + "self|molecule_id" + } + assert set(ancestry.get_all_key_columns(mutagenesis, "atom")) == { + "self|atom_id", + "self|molecule_id", + "self.molecule_id|molecule_id", + } + assert set(ancestry.get_all_key_columns(mutagenesis, "bond")) == { + "self|atom1_id", + "self|atom2_id", + "self.atom1_id|atom_id", + "self.atom1_id|molecule_id", + "self.atom1_id.molecule_id|molecule_id", + "self.atom2_id|atom_id", + "self.atom2_id|molecule_id", + "self.atom2_id.molecule_id|molecule_id", + } + + def test_ancestral_foreign_key_maps(ecom): events_afk_maps = ancestry.get_ancestral_foreign_key_maps(ecom, "events") assert events_afk_maps == [("self|user_id", "self.user_id|id")] diff --git a/tests/relational/test_independent_strategy.py b/tests/relational/test_independent_strategy.py index 03547b95..52c901ed 100644 --- a/tests/relational/test_independent_strategy.py +++ b/tests/relational/test_independent_strategy.py @@ -10,7 +10,6 @@ import pandas.testing as pdtest from gretel_trainer.relational.strategies.independent import IndependentStrategy -from gretel_trainer.relational.table_evaluation import TableEvaluation def test_preparing_training_data_does_not_mutate_source_data(pets): @@ -171,65 +170,3 @@ def test_post_processing_foreign_keys_with_skewed_frequencies_and_different_size fk_value_counts = sorted(list(fk_value_counts.values())) assert fk_value_counts == [5, 5, 15, 30, 35, 60] - - -def test_uses_trained_model_to_update_individual_scores(report_json_dict, extended_sdk): - strategy = IndependentStrategy() - evaluations = { - "table_1": TableEvaluation(), - "table_2": TableEvaluation(), - } - model = Mock() - - with tempfile.TemporaryDirectory() as working_dir, patch( - "gretel_trainer.relational.strategies.independent.common.download_artifacts" - ) as download_artifacts: - working_dir = Path(working_dir) - with open( - working_dir / "synthetics_individual_evaluation_table_1.json", "w" - ) as f: - f.write(json.dumps(report_json_dict)) - - strategy.update_evaluation_from_model( - "table_1", evaluations, model, working_dir, extended_sdk - ) - - evaluation = evaluations["table_1"] - - assert evaluation.individual_sqs == 95 - assert evaluation.individual_report_json == report_json_dict - - assert evaluation.cross_table_sqs is None - assert evaluation.cross_table_report_json is None - - -def test_falls_back_to_fetching_report_json_when_download_artifacts_fails( - report_json_dict, extended_sdk -): - strategy = IndependentStrategy() - evaluations = { - "table_1": TableEvaluation(), - "table_2": TableEvaluation(), - } - model = Mock() - - with tempfile.TemporaryDirectory() as working_dir, patch( - "gretel_trainer.relational.strategies.independent.common.download_artifacts" - ) as download_artifacts, patch( - "gretel_trainer.relational.strategies.independent.common._get_report_json" - ) as get_json: - working_dir = Path(working_dir) - download_artifacts.return_value = None - get_json.return_value = report_json_dict - - strategy.update_evaluation_from_model( - "table_1", evaluations, model, working_dir, extended_sdk - ) - - evaluation = evaluations["table_1"] - - assert evaluation.individual_sqs == 95 - assert evaluation.individual_report_json == report_json_dict - - assert 
evaluation.cross_table_sqs is None - assert evaluation.cross_table_report_json is None diff --git a/tests/relational/test_model_config.py b/tests/relational/test_model_config.py index 3fbb3c8e..affc60ce 100644 --- a/tests/relational/test_model_config.py +++ b/tests/relational/test_model_config.py @@ -23,9 +23,9 @@ def test_get_model_key(): assert get_model_key({"models": ["wrong type"]}) is None -def test_evaluate_config_prepends_workflow(): - config = make_evaluate_config("users") - assert config["name"] == "evaluate-users" +def test_evaluate_config_prepends_evaluation_type(): + config = make_evaluate_config("users", "individual") + assert config["name"] == "individual-users" def test_synthetics_config_prepends_workflow(): diff --git a/tests/relational/test_multi_table_restore.py b/tests/relational/test_multi_table_restore.py index 872aaa5a..2890ef8a 100644 --- a/tests/relational/test_multi_table_restore.py +++ b/tests/relational/test_multi_table_restore.py @@ -493,24 +493,15 @@ def test_restore_training_complete( mt = MultiTable.restore(backup_file) # Backup + Debug summary + Source archive + (2) Source CSVs - # + Training archive + (2) Train CSVs + (4) Reports from models - assert len(os.listdir(working_dir)) == 12 + # + Training archive + (2) Train CSVs + assert len(os.listdir(working_dir)) == 8 # Training state is restored assert os.path.exists(local_file(working_dir, "synthetics_training_archive")) assert os.path.exists(local_file(working_dir, "train_humans")) - assert os.path.exists(working_dir / "synthetics_individual_evaluation_humans.json") - assert os.path.exists(working_dir / "synthetics_individual_evaluation_humans.html") assert os.path.exists(local_file(working_dir, "train_pets")) - assert os.path.exists(working_dir / "synthetics_individual_evaluation_pets.json") - assert os.path.exists(working_dir / "synthetics_individual_evaluation_pets.html") assert len(mt._synthetics_train.models) == 2 - assert mt.evaluations["humans"].individual_sqs == 95 - assert mt.evaluations["humans"].cross_table_sqs is None - assert mt.evaluations["pets"].individual_sqs == 95 - assert mt.evaluations["pets"].cross_table_sqs is None - def test_restore_training_one_failed( project, pets, report_json_dict, download_tar_artifact, working_dir, testsetup_dir @@ -547,31 +538,15 @@ def test_restore_training_one_failed( mt = MultiTable.restore(backup_file) # Backup + Debug summary + Source archive + (2) Source CSVs - # Training archive + (2) Train CSVs + (2) Reports from model - assert len(os.listdir(working_dir)) == 10 + # Training archive + (2) Train CSVs + assert len(os.listdir(working_dir)) == 8 # Training state is restored assert os.path.exists(local_file(working_dir, "synthetics_training_archive")) - ## We do expect the training CSV to be present, extracted from the training archive... assert os.path.exists(local_file(working_dir, "train_humans")) - ## ...but we should not see evaluation reports because the table failed to train. 
- - assert not os.path.exists( - working_dir / "synthetics_individual_evaluation_humans.json" - ) - assert not os.path.exists( - working_dir / "synthetics_individual_evaluation_humans.html" - ) - assert os.path.exists(local_file(working_dir, "train_pets")) - assert os.path.exists(working_dir / "synthetics_individual_evaluation_pets.json") - assert os.path.exists(working_dir / "synthetics_individual_evaluation_pets.html") - assert len(mt._synthetics_train.models) == 2 - assert mt.evaluations["humans"].individual_sqs is None - assert mt.evaluations["humans"].cross_table_sqs is None - assert mt.evaluations["pets"].individual_sqs == 95 - assert mt.evaluations["pets"].cross_table_sqs is None + assert len(mt._synthetics_train.models) == 2 def test_restore_generate_completed( @@ -619,9 +594,9 @@ def test_restore_generate_completed( mt = MultiTable.restore(backup_file) # Backup + Debug summary + Source archive + (2) Source CSVs - # + Training archive + (2) Train CSVs + (4) Reports from models + # + Training archive + (2) Train CSVs # + Outputs archive + Previous run subdirectory - assert len(os.listdir(working_dir)) == 14 + assert len(os.listdir(working_dir)) == 10 # Generate state is restored assert os.path.exists(working_dir / "run-id" / "synth_humans.csv") @@ -631,6 +606,12 @@ def test_restore_generate_completed( assert os.path.exists( working_dir / "run-id" / "synthetics_cross_table_evaluation_humans.html" ) + assert os.path.exists( + working_dir / "run-id" / "synthetics_individual_evaluation_humans.json" + ) + assert os.path.exists( + working_dir / "run-id" / "synthetics_individual_evaluation_humans.html" + ) assert os.path.exists(working_dir / "run-id" / "synth_pets.csv") assert os.path.exists( working_dir / "run-id" / "synthetics_cross_table_evaluation_pets.json" @@ -638,6 +619,12 @@ def test_restore_generate_completed( assert os.path.exists( working_dir / "run-id" / "synthetics_cross_table_evaluation_pets.html" ) + assert os.path.exists( + working_dir / "run-id" / "synthetics_individual_evaluation_pets.json" + ) + assert os.path.exists( + working_dir / "run-id" / "synthetics_individual_evaluation_pets.html" + ) assert mt._synthetics_run is not None assert len(mt.synthetic_output_tables) == 2 assert mt.evaluations["humans"].individual_sqs == 95 @@ -691,8 +678,8 @@ def test_restore_generate_in_progress( mt = MultiTable.restore(backup_file) # Backup + Debug summary + Source archive + (2) Source CSVs - # + Training archive + (2) Train CSVs + (4) Reports from models - assert len(os.listdir(working_dir)) == 12 + # + Training archive + (2) Train CSVs + assert len(os.listdir(working_dir)) == 8 # Generate state is partially restored assert mt._synthetics_run == SyntheticsRun( @@ -703,10 +690,6 @@ def test_restore_generate_in_progress( record_handlers=synthetics_record_handlers, ) assert len(mt.synthetic_output_tables) == 0 - assert mt.evaluations["humans"].individual_sqs == 95 - assert mt.evaluations["humans"].cross_table_sqs is None - assert mt.evaluations["pets"].individual_sqs == 95 - assert mt.evaluations["pets"].cross_table_sqs is None def test_restore_hybrid_run(project, pets, report_json_dict, working_dir): diff --git a/tests/relational/test_synthetics_run_task.py b/tests/relational/test_synthetics_run_task.py index 39298594..a0896491 100644 --- a/tests/relational/test_synthetics_run_task.py +++ b/tests/relational/test_synthetics_run_task.py @@ -124,7 +124,7 @@ def post_process_individual_synthetic_result( "gretel_trainer.relational.sdk_extras.ExtendedGretelSDK.get_record_handler_data" ) as 
get_rh_data: get_rh_data.return_value = raw_df - task.handle_completed("table", Mock()) + task.handle_completed("table", Mock(ref_data=Mock(values=[]))) post_processed = task.working_tables["table"] assert post_processed is not None
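
A minimal standalone sketch of the key-column pruning that the new `_get_individual_evaluate_data` helper performs before submitting an Evaluate job. This uses plain pandas with invented column names purely for illustration; `key_cols` stands in for whatever `RelationalData.get_all_key_columns(table)` would return. Dropping the key columns from both frames keeps key values (which carry no semantic meaning and may differ in type) from skewing the score.

```python
import pandas as pd

# Hypothetical source and synthetic tables; "id" and "user_id" stand in for key columns.
source = pd.DataFrame({"id": [1, 2], "user_id": [10, 20], "age": [34, 51]})
synthetic = pd.DataFrame({"id": [7, 8], "user_id": [30, 40], "age": [29, 48]})

key_cols = ["id", "user_id"]  # stand-in for RelationalData.get_all_key_columns(table)
use_cols = [c for c in source.columns if c not in key_cols]

evaluate_data = {
    "source": source[use_cols],                      # only non-key columns are scored
    "synthetic": synthetic.drop(columns=key_cols),
}
print(list(evaluate_data["source"].columns))  # ['age']
```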
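A toy illustration, again in plain pandas with invented tables, of the multigenerational column naming that the new `ancestry.get_all_key_columns` helper and the `keys_only` flag surface: a table's own columns carry the `self|` prefix, and each joined parent's columns carry a lineage prefix derived from the foreign key column, matching the expectations in `test_get_all_key_columns`.

```python
import pandas as pd

# Toy child/parent frames standing in for RelationalData tables.
events = pd.DataFrame({"id": [1, 2], "user_id": [10, 10]})
users = pd.DataFrame({"id": [10], "name": ["ann"]})

# The child table's columns get the "self|" prefix...
df = events.add_prefix("self|")
# ...and each joined parent's columns get a lineage prefix built from the FK column.
parent = users.add_prefix("self.user_id|")
df = df.merge(parent, how="left", left_on="self|user_id", right_on="self.user_id|id")

# Keeping only key columns mirrors what get_all_key_columns would report here.
print([c for c in df.columns if c != "self.user_id|name"])
# ['self|id', 'self|user_id', 'self.user_id|id']
```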