Relational evaluation
GitOrigin-RevId: 92d2ddf6d7bccd26905fcb342aa01bf0151bd349
mikeknep committed Sep 22, 2023
1 parent ed4ead7 commit 5e2c2e5
Showing 16 changed files with 286 additions and 399 deletions.
47 changes: 40 additions & 7 deletions src/gretel_trainer/relational/ancestry.py
@@ -21,6 +21,13 @@ def get_multigenerational_primary_key(
]


def get_all_key_columns(rel_data: RelationalData, table: str) -> list[str]:
tableset = _empty_data_tableset(rel_data)
return list(
get_table_data_with_ancestors(rel_data, table, tableset, keys_only=True).columns
)


def get_ancestral_foreign_key_maps(
rel_data: RelationalData, table: str
) -> list[tuple[str, str]]:
@@ -59,6 +66,13 @@ def _ancestral_fk_map(fk: ForeignKey) -> list[tuple[str, str]]:
]


def _empty_data_tableset(rel_data: RelationalData) -> dict[str, pd.DataFrame]:
return {
table: pd.DataFrame(columns=list(rel_data.get_table_columns(table)))
for table in rel_data.list_all_tables()
}
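
A minimal sketch of the idea behind `_empty_data_tableset`: DataFrames that carry only column names (zero rows) flow through the same join logic, so callers can recover the multigenerational column names without touching any real data. The table names, columns, and lineage prefixes below are illustrative assumptions, not the library's actual values.

```python
import pandas as pd

# Hypothetical two-table schema; only the column names matter here.
users = pd.DataFrame(columns=["id", "name", "email"])
orders = pd.DataFrame(columns=["id", "user_id", "total"])

# Joining empty frames still yields the combined (prefixed) column set,
# with zero rows processed. The "self|" / "self.user_id|" prefixes are illustrative.
joined = orders.add_prefix("self|").merge(
    users.add_prefix("self.user_id|"),
    how="left",
    left_on="self|user_id",
    right_on="self.user_id|id",
)
print(len(joined))           # 0 rows
print(list(joined.columns))  # ['self|id', 'self|user_id', 'self|total',
                             #  'self.user_id|id', 'self.user_id|name', 'self.user_id|email']
```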


def get_seed_safe_multigenerational_columns(
rel_data: RelationalData,
) -> dict[str, list[str]]:
@@ -68,10 +82,7 @@ def get_seed_safe_multigenerational_columns(
a significantly faster / less resource-intensive way to get just the column names
from the results of `get_table_data_with_ancestors` for all tables.
"""
tableset = {
table: pd.DataFrame(columns=list(rel_data.get_table_columns(table)))
for table in rel_data.list_all_tables()
}
tableset = _empty_data_tableset(rel_data)
return {
table: list(
get_table_data_with_ancestors(
@@ -87,6 +98,7 @@ def get_table_data_with_ancestors(
table: str,
tableset: Optional[dict[str, pd.DataFrame]] = None,
ancestral_seeding: bool = False,
keys_only: bool = False,
) -> pd.DataFrame:
"""
Returns a data frame with all ancestral data joined to each record.
@@ -96,14 +108,26 @@ def get_table_data_with_ancestors(
separated by periods.
If `tableset` is provided, use it in place of the source data in `self.graph`.
If `ancestral_seeding` is True, the returned dataframe only includes columns
that can be used as conditional seeds.
If `keys_only` is True, the returned dataframe only includes columns that are primary
or foreign keys.
"""
lineage = _START_LINEAGE
if tableset is not None:
df = tableset[table]
else:
df = rel_data.get_table_data(table)

if keys_only:
df = df[rel_data.get_all_key_columns(table)]

lineage = _START_LINEAGE
df = df.add_prefix(f"{_START_LINEAGE}{_END_LINEAGE}")
return _join_parents(rel_data, df, table, lineage, tableset, ancestral_seeding)
return _join_parents(
rel_data, df, table, lineage, tableset, ancestral_seeding, keys_only
)
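
To make the new `keys_only` flag concrete, here is a hedged sketch of the filtering step in isolation: the table's frame is cut down to its primary/foreign key columns before the root-lineage prefix is applied and parents are joined. The column names and `self|` prefix are assumptions for illustration.

```python
import pandas as pd

# Hypothetical "orders" table with two key columns and two non-key columns.
df = pd.DataFrame(columns=["id", "user_id", "total", "notes"])
key_cols = ["id", "user_id"]   # stand-in for rel_data.get_all_key_columns("orders")

df = df[key_cols]              # keys_only: non-key columns never enter the joins
df = df.add_prefix("self|")    # illustrative root-lineage prefix
print(list(df.columns))        # ['self|id', 'self|user_id']
```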


def _join_parents(
@@ -113,6 +137,7 @@ def _join_parents(
lineage: str,
tableset: Optional[dict[str, pd.DataFrame]],
ancestral_seeding: bool,
keys_only: bool,
) -> pd.DataFrame:
for foreign_key in rel_data.get_foreign_keys(table):
fk_lineage = _COL_DELIMITER.join(foreign_key.columns)
@@ -122,6 +147,8 @@ def _join_parents(

if ancestral_seeding:
usecols = list(rel_data.get_safe_ancestral_seed_columns(parent_table_name))
elif keys_only:
usecols = rel_data.get_all_key_columns(parent_table_name)
else:
usecols = rel_data.get_table_columns(parent_table_name)

@@ -141,7 +168,13 @@
)

df = _join_parents(
rel_data, df, parent_table_name, next_lineage, tableset, ancestral_seeding
rel_data,
df,
parent_table_name,
next_lineage,
tableset,
ancestral_seeding,
keys_only,
)
return df

2 changes: 1 addition & 1 deletion src/gretel_trainer/relational/core.py
@@ -754,7 +754,7 @@ def get_all_key_columns(self, table: str) -> list[str]:
all_key_cols.extend(self.get_primary_key(table))
for fk in self.get_foreign_keys(table):
all_key_cols.extend(fk.columns)
return all_key_cols
return sorted(list(set(all_key_cols)))
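
A short sketch of why the return value is now de-duplicated and sorted: when a primary key column also participates in a foreign key (common for 1:1 or subtype tables), naive concatenation lists it twice, and a bare `set` alone would not give a deterministic order. The key definitions below are hypothetical.

```python
# Hypothetical table whose primary key column is also a foreign key to its parent.
primary_key = ["user_id"]
foreign_keys = [["user_id"], ["org_id"]]

all_key_cols = []
all_key_cols.extend(primary_key)
for fk_cols in foreign_keys:
    all_key_cols.extend(fk_cols)

print(all_key_cols)               # ['user_id', 'user_id', 'org_id'] -- duplicate
print(sorted(set(all_key_cols)))  # ['org_id', 'user_id'] -- unique and deterministic
```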

def debug_summary(self) -> dict[str, Any]:
max_depth = dag_longest_path_length(self.graph)
4 changes: 2 additions & 2 deletions src/gretel_trainer/relational/model_config.py
@@ -38,9 +38,9 @@ def make_classify_config(table: str, config: GretelModelConfig) -> dict[str, Any
return tailored_config


def make_evaluate_config(table: str) -> dict[str, Any]:
def make_evaluate_config(table: str, sqs_type: str) -> dict[str, Any]:
tailored_config = ingest("evaluate/default")
tailored_config["name"] = _model_name("evaluate", table)
tailored_config["name"] = _model_name(sqs_type, table)
return tailored_config
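
A hedged, self-contained sketch of what the new `sqs_type` parameter changes: the evaluate model name now encodes which kind of SQS the job measures, so the individual and cross-table evaluations for a table no longer collide. The `_model_name` format and the config skeleton below are assumptions, not the library's actual output.

```python
from typing import Any

def _model_name_sketch(action: str, table: str) -> str:
    # Hypothetical stand-in for the private _model_name helper.
    return f"{action}-{table}"

def make_evaluate_config_sketch(table: str, sqs_type: str) -> dict[str, Any]:
    config: dict[str, Any] = {"models": [{"evaluate": {}}]}  # placeholder for ingest("evaluate/default")
    config["name"] = _model_name_sketch(sqs_type, table)
    return config

print(make_evaluate_config_sketch("users", "individual")["name"])   # individual-users
print(make_evaluate_config_sketch("users", "cross_table")["name"])  # cross_table-users
```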


149 changes: 89 additions & 60 deletions src/gretel_trainer/relational/multi_table.py
@@ -22,6 +22,8 @@
import pandas as pd
import smart_open

import gretel_trainer.relational.ancestry as ancestry

from gretel_client.config import get_session_config, RunnerMode
from gretel_client.projects import create_project, get_project, Project
from gretel_client.projects.jobs import ACTIVE_STATES, END_STATES, Status
@@ -270,27 +272,10 @@ def _complete_init_from_backup(self, backup: Backup) -> None:
"Cannot restore while model training is actively in progress."
)

training_succeeded = [
table
for table, model in self._synthetics_train.models.items()
if model.status == Status.COMPLETED
]
for table in training_succeeded:
if table in self.relational_data.list_all_tables(Scope.EVALUATABLE):
model = self._synthetics_train.models[table]
with silent_logs():
self._strategy.update_evaluation_from_model(
table,
self._evaluations,
model,
self._working_dir,
self._extended_sdk,
)

training_failed = [
table
for table, model in self._synthetics_train.models.items()
if model.status in END_STATES and table not in training_succeeded
if model.status in END_STATES and model.status != Status.COMPLETED
]
if len(training_failed) > 0:
logger.info(
@@ -720,17 +705,6 @@ def _train_synthetics_models(self, configs: dict[str, dict[str, Any]]) -> None:
)
run_task(task, self._extended_sdk)

for table in task.completed:
if table in self.relational_data.list_all_tables(Scope.EVALUATABLE):
model = self._synthetics_train.models[table]
self._strategy.update_evaluation_from_model(
table,
self._evaluations,
model,
self._working_dir,
self._extended_sdk,
)

def train_synthetics(
self,
*,
@@ -901,7 +875,8 @@ def generate(
evaluate_project = create_project(
display_name=f"evaluate-{self._project.display_name}"
)
evaluate_models = {}
individual_evaluate_models = {}
cross_table_evaluate_models = {}
for table, synth_df in output_tables.items():
if table in self._synthetics_run.preserved:
continue
@@ -912,49 +887,43 @@
if table not in self.relational_data.list_all_tables(Scope.EVALUATABLE):
continue

evaluate_data = self._strategy.get_evaluate_model_data(
rel_data=self.relational_data,
table_name=table,
# Create an evaluate model for individual SQS
individual_data = self._get_individual_evaluate_data(
table=table,
synthetic_tables=output_tables,
)
if evaluate_data is None:
continue
individual_sqs_job = evaluate_project.create_model_obj(
model_config=make_evaluate_config(table, "individual"),
data_source=individual_data["synthetic"],
ref_data=individual_data["source"],
)
individual_evaluate_models[table] = individual_sqs_job

evaluate_models[table] = evaluate_project.create_model_obj(
model_config=make_evaluate_config(table),
data_source=evaluate_data["synthetic"],
ref_data=evaluate_data["source"],
# Create an evaluate model for cross-table SQS (if we can/should)
cross_table_data = self._get_cross_table_evaluate_data(
table=table,
synthetic_tables=output_tables,
)
if cross_table_data is not None:
cross_table_sqs_job = evaluate_project.create_model_obj(
model_config=make_evaluate_config(table, "cross_table"),
data_source=cross_table_data["synthetic"],
ref_data=cross_table_data["source"],
)
cross_table_evaluate_models[table] = cross_table_sqs_job

synthetics_evaluate_task = SyntheticsEvaluateTask(
evaluate_models=evaluate_models,
individual_evaluate_models=individual_evaluate_models,
cross_table_evaluate_models=cross_table_evaluate_models,
project=evaluate_project,
run_dir=run_dir,
evaluations=self._evaluations,
multitable=self,
)
run_task(synthetics_evaluate_task, self._extended_sdk)

# Tables passed to task were already scoped to evaluatable tables
for table in synthetics_evaluate_task.completed:
self._strategy.update_evaluation_from_evaluate(
table_name=table,
evaluate_model=evaluate_models[table],
evaluations=self._evaluations,
working_dir=self._working_dir,
extended_sdk=self._extended_sdk,
)

evaluate_project.delete()

for table_name in output_tables:
for eval_type in ["individual", "cross_table"]:
for ext in ["html", "json"]:
filename = f"synthetics_{eval_type}_evaluation_{table_name}.{ext}"
with suppress(FileNotFoundError):
shutil.copyfile(
src=self._working_dir / filename,
dst=run_dir / filename,
)

logger.info("Creating relational report")
self.create_relational_report(
run_identifier=self._synthetics_run.identifier,
@@ -975,6 +944,66 @@ def generate(
self.synthetic_output_tables = reshaped_tables
self._backup()

def _get_individual_evaluate_data(
self, table: str, synthetic_tables: dict[str, pd.DataFrame]
) -> dict[str, pd.DataFrame]:
"""
Returns a dictionary containing source and synthetic versions of a table,
to be used in an Evaluate job.
Removes all key columns to avoid artificially deflating the score
(key types may not match, and key values carry no semantic meaning).
"""
all_cols = self.relational_data.get_table_columns(table)
key_cols = self.relational_data.get_all_key_columns(table)
use_cols = [c for c in all_cols if c not in key_cols]

return {
"source": self.relational_data.get_table_data(table, usecols=use_cols),
"synthetic": synthetic_tables[table].drop(columns=key_cols),
}
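
The key-dropping step in isolation, as a runnable sketch with made-up data: key columns are removed from both the source and synthetic frames before they are handed to Evaluate, since mismatched key dtypes and values would drag the score down without saying anything about data quality.

```python
import pandas as pd

source = pd.DataFrame({"id": [1, 2], "user_id": [10, 11], "plan": ["free", "pro"]})
synthetic = pd.DataFrame({"id": ["a", "b"], "user_id": ["x", "y"], "plan": ["pro", "free"]})

key_cols = ["id", "user_id"]   # stand-in for relational_data.get_all_key_columns(table)

evaluate_data = {
    "source": source.drop(columns=key_cols),
    "synthetic": synthetic.drop(columns=key_cols),
}
print(list(evaluate_data["source"].columns))     # ['plan']
print(list(evaluate_data["synthetic"].columns))  # ['plan']
```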

def _get_cross_table_evaluate_data(
self, table: str, synthetic_tables: dict[str, pd.DataFrame]
) -> Optional[dict[str, pd.DataFrame]]:
"""
Returns a dictionary containing source and synthetic versions of a table
with ancestral data attached, to be used in an Evaluate job for cross-table SQS.
Removes all key columns to avoid artificially deflating the score
(key types may not match, and key values carry no semantic meaning).
Returns None if a cross-table SQS job cannot or should not be performed.
"""
# Exit early if table does not have parents (no need for cross-table evaluation)
if len(self.relational_data.get_parents(table)) == 0:
return None

# Exit early if we can't create synthetic cross-table data
# (e.g. parent data missing due to job failure)
missing_ancestors = [
ancestor
for ancestor in self.relational_data.get_ancestors(table)
if ancestor not in synthetic_tables
]
if len(missing_ancestors) > 0:
logger.info(
f"Cannot run cross_table evaluations for `{table}` because no synthetic data exists for ancestor tables {missing_ancestors}."
)
return None

source_data = ancestry.get_table_data_with_ancestors(
self.relational_data, table
)
synthetic_data = ancestry.get_table_data_with_ancestors(
self.relational_data, table, synthetic_tables
)
key_cols = ancestry.get_all_key_columns(self.relational_data, table)
return {
"source": source_data.drop(columns=key_cols),
"synthetic": synthetic_data.drop(columns=key_cols),
}
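
The two early exits above can be summarized in a small standalone helper; this is a hedged sketch with hypothetical table names, not the method's real signature.

```python
from typing import Optional

def cross_table_skip_reason(
    parents: list[str], ancestors: list[str], synthetic_tables: set[str]
) -> Optional[str]:
    """Return a reason to skip cross-table SQS, or None if it can proceed."""
    if len(parents) == 0:
        return "table has no parents; individual SQS is sufficient"
    missing = [a for a in ancestors if a not in synthetic_tables]
    if missing:
        return f"no synthetic data for ancestor tables {missing}"
    return None

# 'events' depends on 'users', whose synthesis failed upstream in this scenario.
print(cross_table_skip_reason(["users"], ["users"], {"events"}))
```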

def create_relational_report(self, run_identifier: str, target_dir: Path) -> None:
presenter = ReportPresenter(
rel_data=self.relational_data,
9 changes: 6 additions & 3 deletions src/gretel_trainer/relational/sdk_extras.py
@@ -32,9 +32,12 @@ def get_job_id(self, job: Job) -> Optional[str]:
else:
raise MultiTableException("Unexpected job object")

def delete_data_source(self, project: Project, job: Job) -> None:
if not self._hybrid and job.data_source is not None:
project.delete_artifact(job.data_source)
def delete_data_sources(self, project: Project, job: Job) -> None:
if not self._hybrid:
if job.data_source is not None:
project.delete_artifact(job.data_source)
for ref_data in job.ref_data.values:
project.delete_artifact(ref_data)
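
A hedged sketch of the widened cleanup: in non-hybrid (cloud) mode the method now deletes every reference-data artifact in addition to the primary data source. The fake job object below only mimics the attributes the real `Job` and `RefData` objects expose; it is not the SDK's API.

```python
from dataclasses import dataclass, field
from typing import Callable, Optional

@dataclass
class FakeJob:
    data_source: Optional[str] = "artifact-key-1"
    ref_data_values: list[str] = field(
        default_factory=lambda: ["artifact-key-2", "artifact-key-3"]
    )

def delete_data_sources_sketch(delete_artifact: Callable[[str], None], job: FakeJob, hybrid: bool) -> None:
    if not hybrid:
        if job.data_source is not None:
            delete_artifact(job.data_source)
        for ref in job.ref_data_values:   # stand-in for job.ref_data.values
            delete_artifact(ref)

deleted: list[str] = []
delete_data_sources_sketch(deleted.append, FakeJob(), hybrid=False)
print(deleted)  # ['artifact-key-1', 'artifact-key-2', 'artifact-key-3']
```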

def cautiously_refresh_status(
self, job: Job, key: str, refresh_attempts: dict[str, int]