Skip to content

Commit

Permalink
Preserve tables before training (#111)
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeknep authored May 30, 2023
1 parent e89359a commit c08b3d9
Show file tree
Hide file tree
Showing 19 changed files with 406 additions and 176 deletions.
38 changes: 36 additions & 2 deletions src/gretel_trainer/relational/ancestry.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
def get_multigenerational_primary_key(
    rel_data: RelationalData, table: str
) -> list[str]:
    """Return the table's primary key columns, each prepended with the root lineage prefix."""
    root_prefix = f"{_START_LINEAGE}{_END_LINEAGE}"
    return [root_prefix + pk_column for pk_column in rel_data.get_primary_key(table)]
Expand All @@ -22,6 +23,17 @@ def get_multigenerational_primary_key(
def get_ancestral_foreign_key_maps(
rel_data: RelationalData, table: str
) -> list[tuple[str, str]]:
"""
Returns a list of two-element tuples where the first element is a foreign key column
with ancestral lineage prefix, and the second element is the ancestral-lineage-prefixed
referred column. This function ultimately provides a list of which columns are duplicates
in a fully-joined ancestral table (i.e. `get_table_data_with_ancestors`) (only between
the provided table and its direct parents, not between parents and grandparents).
For example: given an events table with foreign key `events.user_id` => `users.id`,
this method returns: [("self|user_id", "self.user_id|id")]
"""

def _ancestral_fk_map(fk: ForeignKey) -> list[tuple[str, str]]:
maps = []
fk_lineage = _COL_DELIMITER.join(fk.columns)
Expand All @@ -46,6 +58,29 @@ def _ancestral_fk_map(fk: ForeignKey) -> list[tuple[str, str]]:
]


def get_seed_safe_multigenerational_columns(
    rel_data: RelationalData,
) -> dict[str, list[str]]:
    """
    Returns a dict with Scope.MODELABLE table names as keys and lists of columns to use
    for conditional seeding as values. Running `get_table_data_with_ancestors` over a
    tableset of empty dataframes yields just the column names, which is significantly
    faster and less resource-intensive than joining the real table data.
    """
    # Stand-in tableset: one empty dataframe per table, carrying only column names.
    empty_tableset = {}
    for name in rel_data.list_all_tables():
        empty_tableset[name] = pd.DataFrame(
            columns=list(rel_data.get_table_columns(name))
        )

    columns_by_table = {}
    for name in rel_data.list_all_tables():
        ancestral_frame = get_table_data_with_ancestors(
            rel_data, name, empty_tableset, ancestral_seeding=True
        )
        columns_by_table[name] = list(ancestral_frame.columns)
    return columns_by_table


def get_table_data_with_ancestors(
rel_data: RelationalData,
table: str,
Expand Down Expand Up @@ -93,10 +128,9 @@ def _join_parents(
parent_data = tableset[parent_table_name][list(usecols)]
else:
parent_data = rel_data.get_table_data(parent_table_name, usecols=usecols)
parent_data = parent_data.add_prefix(f"{next_lineage}{_END_LINEAGE}")

df = df.merge(
parent_data,
parent_data.add_prefix(f"{next_lineage}{_END_LINEAGE}"),
how="left",
left_on=[f"{lineage}{_END_LINEAGE}{col}" for col in foreign_key.columns],
right_on=[
Expand Down
2 changes: 0 additions & 2 deletions src/gretel_trainer/relational/backup.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ class BackupTransformsTrain:
class BackupSyntheticsTrain:
model_ids: dict[str, str]
lost_contact: list[str]
training_columns: dict[str, list[str]]


@dataclass
Expand All @@ -103,7 +102,6 @@ class BackupGenerate:
record_size_ratio: float
record_handler_ids: dict[str, str]
lost_contact: list[str]
missing_model: list[str]


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion src/gretel_trainer/relational/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, engine: Engine):
logger.info("Successfully connected to db")

def extract(
self, only: Optional[list[str]] = None, ignore: Optional[list[str]] = None
self, only: Optional[set[str]] = None, ignore: Optional[set[str]] = None
) -> RelationalData:
"""
Extracts table data and relationships from the database.
Expand Down
2 changes: 1 addition & 1 deletion src/gretel_trainer/relational/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,7 @@ def debug_summary(self) -> dict[str, Any]:


def skip_table(
table: str, only: Optional[list[str]], ignore: Optional[list[str]]
table: str, only: Optional[set[str]], ignore: Optional[set[str]]
) -> bool:
skip = False
if only is not None and table not in only:
Expand Down
145 changes: 83 additions & 62 deletions src/gretel_trainer/relational/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,9 +251,6 @@ def _complete_init_from_backup(self, backup: Backup) -> None:
with tarfile.open(synthetics_training_archive_path, "r:gz") as tar:
tar.extractall(path=self._working_dir)

self._synthetics_train.training_columns = (
backup_synthetics_train.training_columns
)
self._synthetics_train.lost_contact = backup_synthetics_train.lost_contact
self._synthetics_train.models = {
table: self._project.get_model(model_id)
Expand Down Expand Up @@ -339,7 +336,6 @@ def _complete_init_from_backup(self, backup: Backup) -> None:
identifier=backup_generate.identifier,
record_size_ratio=backup_generate.record_size_ratio,
preserved=backup_generate.preserved,
missing_model=backup_generate.missing_model,
lost_contact=backup_generate.lost_contact,
record_handlers=record_handlers,
)
Expand Down Expand Up @@ -458,7 +454,6 @@ def _build_backup(self) -> Backup:
for table, model in self._synthetics_train.models.items()
},
lost_contact=self._synthetics_train.lost_contact,
training_columns=self._synthetics_train.training_columns,
)

# Generate
Expand All @@ -467,7 +462,6 @@ def _build_backup(self) -> Backup:
identifier=self._synthetics_run.identifier,
record_size_ratio=self._synthetics_run.record_size_ratio,
preserved=self._synthetics_run.preserved,
missing_model=self._synthetics_run.missing_model,
lost_contact=self._synthetics_run.lost_contact,
record_handler_ids={
table: rh.record_id
Expand Down Expand Up @@ -602,34 +596,15 @@ def train_transforms(
self,
config: GretelModelConfig,
*,
only: Optional[list[str]] = None,
ignore: Optional[list[str]] = None,
only: Optional[set[str]] = None,
ignore: Optional[set[str]] = None,
) -> None:
if only is not None and ignore is not None:
raise MultiTableException("Cannot specify both `only` and `ignore`.")

m_only = None
if only is not None:
m_only = []
for table in only:
m_names = self.relational_data.get_modelable_table_names(table)
if len(m_names) == 0:
raise MultiTableException(f"Unrecognized table name: `{table}`")
m_only.extend(m_names)

m_ignore = None
if ignore is not None:
m_ignore = []
for table in ignore:
m_names = self.relational_data.get_modelable_table_names(table)
if len(m_names) == 0:
raise MultiTableException(f"Unrecognized table name: `{table}`")
m_ignore.extend(m_names)
only, ignore = self._get_only_and_ignore(only, ignore)

configs = {
table: config
for table in self.relational_data.list_all_tables()
if not skip_table(table, m_only, m_ignore)
if not skip_table(table, only, ignore)
}

self._setup_transforms_train_state(configs)
Expand Down Expand Up @@ -722,15 +697,35 @@ def run_transforms(
self._backup()
self.transform_output_tables = reshaped_tables

def _get_only_and_ignore(
    self, only: Optional[set[str]], ignore: Optional[set[str]]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
    """
    Validates and expands the `only`/`ignore` scoping arguments.

    At most one of `only`/`ignore` may be provided (both None is fine). Each
    user-supplied name is expanded to its modelable table names via
    `RelationalData.get_modelable_table_names`.

    Returns:
        A `(only, ignore)` tuple in which the provided side has been expanded to
        modelable table names and the other side is None. When neither argument
        is provided, returns `(None, empty set)` — `skip_table` treats an empty
        ignore set the same as None, so no table is skipped.

    Raises:
        MultiTableException: if both arguments are given, or if a supplied name
            does not correspond to any modelable table.
    """
    if only is not None and ignore is not None:
        raise MultiTableException("Cannot specify both `only` and `ignore`.")

    # Expand whichever argument was supplied (iterates nothing if neither was).
    modelable_tables = set()
    for table in only or ignore or {}:
        m_names = self.relational_data.get_modelable_table_names(table)
        if len(m_names) == 0:
            raise MultiTableException(f"Unrecognized table name: `{table}`")
        modelable_tables.update(m_names)

    # The guard above ensures at most one argument is non-None, so the original
    # trailing `return (None, None)` branch was unreachable and has been removed.
    if only is not None:
        return (modelable_tables, None)
    return (None, modelable_tables)

def _prepare_training_data(self, tables: list[str]) -> dict[str, Path]:
"""
Exports a copy of each table prepared for training by the configured strategy
to the working directory. Returns a dict with table names as keys and Paths
to the CSVs as values.
"""
training_data = self._strategy.prepare_training_data(self.relational_data)
for table, df in training_data.items():
self._synthetics_train.training_columns[table] = list(df.columns)
training_data = self._strategy.prepare_training_data(
self.relational_data, tables
)
training_paths = {}

for table_name in tables:
Expand All @@ -748,6 +743,13 @@ def _train_synthetics_models(self, training_data: dict[str, Path]) -> None:
)
self._synthetics_train.models[table_name] = model

archive_path = self._working_dir / "synthetics_training.tar.gz"
for table_name, csv_path in training_data.items():
add_to_tar(archive_path, csv_path, csv_path.name)
self._artifact_collection.upload_synthetics_training_archive(
self._project, str(archive_path)
)

self._backup()

task = SyntheticsTrainTask(
Expand All @@ -767,22 +769,51 @@ def _train_synthetics_models(self, training_data: dict[str, Path]) -> None:
self._extended_sdk,
)

# TODO: consider moving this to before running the task
archive_path = self._working_dir / "synthetics_training.tar.gz"
for table_name, csv_path in training_data.items():
add_to_tar(archive_path, csv_path, csv_path.name)
self._artifact_collection.upload_synthetics_training_archive(
self._project, str(archive_path)
)

def train(self) -> None:
    """
    DEPRECATED: Please use `train_synthetics` instead.

    Trains synthetic data models on every table in the relational dataset.
    """
    # NOTE: the stale pre-deprecation docstring that was stacked above the new
    # one (leftover diff residue) has been removed so the deprecation notice is
    # the method's actual docstring.
    logger.warning(
        "This method is deprecated and will be removed in a future release. "
        "Please use `train_synthetics` instead."
    )
    tables = self.relational_data.list_all_tables()
    # Reset any previous synthetics training state before training from scratch.
    self._synthetics_train = SyntheticsTrain()

    training_data = self._prepare_training_data(tables)
    self._train_synthetics_models(training_data)

def train_synthetics(
    self,
    *,
    only: Optional[set[str]] = None,
    ignore: Optional[set[str]] = None,
) -> None:
    """
    Train synthetic data models for the tables in the tableset,
    optionally scoped by either `only` or `ignore`.
    """
    only, ignore = self._get_only_and_ignore(only, ignore)

    # Partition every table into those to train and those to preserve as-is.
    include_tables: list[str] = []
    omit_tables: list[str] = []
    for table in self.relational_data.list_all_tables():
        bucket = omit_tables if skip_table(table, only, ignore) else include_tables
        bucket.append(table)

    # TODO: Ancestral strategy requires that for each table omitted from synthetics
    # ("preserved"), all its ancestors must also be omitted. In the future, it'd be
    # nice to either find a way to eliminate this requirement completely, or (less
    # ideal) allow incremental training of tables, e.g. train a few in one "batch",
    # then a few more before generating (perhaps with some logs along the way
    # informing the user of which required tables are missing).
    self._strategy.validate_preserved_tables(omit_tables, self.relational_data)

    training_data = self._prepare_training_data(include_tables)
    self._train_synthetics_models(training_data)

def retrain_tables(self, tables: dict[str, pd.DataFrame]) -> None:
"""
Provide updated table data and retrain. This method overwrites the table data in the
Expand All @@ -801,7 +832,6 @@ def retrain_tables(self, tables: dict[str, pd.DataFrame]) -> None:
for table in tables_to_retrain:
with suppress(KeyError):
del self._synthetics_train.models[table]
del self._synthetics_train.training_columns[table]
training_data = self._prepare_training_data(tables_to_retrain)
self._train_synthetics_models(training_data)

Expand Down Expand Up @@ -861,18 +891,23 @@ def generate(
logger.info(f"Resuming synthetics run `{self._synthetics_run.identifier}`")
else:
preserve_tables = preserve_tables or []
preserve_tables.extend(
[
table
for table in self.relational_data.list_all_tables()
if table not in self._synthetics_train.models
]
)
self._strategy.validate_preserved_tables(
preserve_tables, self.relational_data
)

identifier = identifier or f"synthetics_{_timestamp()}"
missing_model = self._list_tables_with_missing_models()

self._synthetics_run = SyntheticsRun(
identifier=identifier,
record_size_ratio=record_size_ratio,
preserved=preserve_tables,
missing_model=missing_model,
record_handlers={},
lost_contact=[],
)
Expand Down Expand Up @@ -909,6 +944,9 @@ def generate(
if table in self._synthetics_run.preserved:
continue

if table not in self._synthetics_train.models:
continue

if table not in self.relational_data.list_all_tables(Scope.EVALUATABLE):
continue

Expand Down Expand Up @@ -982,21 +1020,6 @@ def create_relational_report(self, run_identifier: str, target_dir: Path) -> Non
html_content = ReportRenderer().render(presenter)
report.write(html_content)

def _list_tables_with_missing_models(self) -> list[str]:
    """
    Returns the names of tables to skip during synthetic data generation:
    tables without a successfully trained model, plus all of their descendants
    (which cannot be generated without their ancestors' synthetic data).
    """
    missing_model = set()
    for table in self.relational_data.list_all_tables():
        if not _table_trained_successfully(self._synthetics_train, table):
            logger.info(
                f"Skipping synthetic data generation for `{table}` because it does not have a trained model"
            )
            missing_model.add(table)
            for descendant in self.relational_data.get_descendants(table):
                logger.info(
                    f"Skipping synthetic data generation for `{descendant}` because it depends on `{table}`"
                )
                # Bug fix: previously this re-added `table` instead of the
                # descendant, so descendants were logged as skipped but never
                # actually included in the returned list.
                missing_model.add(descendant)
    return list(missing_model)

def _attach_existing_reports(self, run_id: str, table: str) -> None:
individual_path = (
self._working_dir
Expand Down Expand Up @@ -1046,9 +1069,7 @@ def _validate_strategy(strategy: str) -> Union[IndependentStrategy, AncestralStr
raise MultiTableException(msg)


def _table_trained_successfully(
train_state: Union[TransformsTrain, SyntheticsTrain], table: str
) -> bool:
def _table_trained_successfully(train_state: TransformsTrain, table: str) -> bool:
model = train_state.models.get(table)
if model is None:
return False
Expand Down
Loading

0 comments on commit c08b3d9

Please sign in to comment.