Preserve tables before training #111

Merged · 14 commits · May 30, 2023
20 changes: 18 additions & 2 deletions src/gretel_trainer/relational/ancestry.py
@@ -46,6 +46,23 @@ def _ancestral_fk_map(fk: ForeignKey) -> list[tuple[str, str]]:
     ]
 
 
+def get_seed_safe_multigenerational_columns(
+    rel_data: RelationalData,
+) -> dict[str, list[str]]:
+    tableset = {
+        table: pd.DataFrame(columns=list(rel_data.get_table_columns(table)))
+        for table in rel_data.list_all_tables()
+    }
+    return {
+        table: list(
+            get_table_data_with_ancestors(
+                rel_data, table, tableset, ancestral_seeding=True
+            ).columns
+        )
+        for table in rel_data.list_all_tables()
+    }
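Usage sketch (illustrative, not part of the diff): the new helper enumerates the multigenerational column names each table would have after ancestral joins without reading any rows, because the tableset it passes to get_table_data_with_ancestors holds only empty DataFrames. Assuming `rel_data` is an already-populated RelationalData instance with a `users` table and an `events` child table, and that lineage prefixes follow the module's `self`/`|` conventions:

from gretel_trainer.relational.ancestry import get_seed_safe_multigenerational_columns

seed_columns = get_seed_safe_multigenerational_columns(rel_data)
# Hypothetical shape of the result:
# {
#     "users": ["self|id", "self|name"],
#     "events": ["self|id", "self|user_id", "self.user_id|id", "self.user_id|name"],
# }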


 def get_table_data_with_ancestors(
     rel_data: RelationalData,
     table: str,
@@ -93,10 +110,9 @@ def _join_parents(
         parent_data = tableset[parent_table_name][list(usecols)]
     else:
         parent_data = rel_data.get_table_data(parent_table_name, usecols=usecols)
-        parent_data = parent_data.add_prefix(f"{next_lineage}{_END_LINEAGE}")
 
     df = df.merge(
-        parent_data,
+        parent_data.add_prefix(f"{next_lineage}{_END_LINEAGE}"),
         how="left",
         left_on=[f"{lineage}{_END_LINEAGE}{col}" for col in foreign_key.columns],
         right_on=[
2 changes: 0 additions & 2 deletions src/gretel_trainer/relational/backup.py
@@ -93,7 +93,6 @@ class BackupTransformsTrain:
 class BackupSyntheticsTrain:
     model_ids: dict[str, str]
     lost_contact: list[str]
-    training_columns: dict[str, list[str]]
 
 
 @dataclass
@@ -103,7 +102,6 @@ class BackupGenerate:
     record_size_ratio: float
     record_handler_ids: dict[str, str]
     lost_contact: list[str]
-    missing_model: list[str]
 
 
 @dataclass
174 changes: 114 additions & 60 deletions src/gretel_trainer/relational/multi_table.py
@@ -251,9 +251,6 @@ def _complete_init_from_backup(self, backup: Backup) -> None:
             with tarfile.open(synthetics_training_archive_path, "r:gz") as tar:
                 tar.extractall(path=self._working_dir)
 
-        self._synthetics_train.training_columns = (
-            backup_synthetics_train.training_columns
-        )
         self._synthetics_train.lost_contact = backup_synthetics_train.lost_contact
         self._synthetics_train.models = {
             table: self._project.get_model(model_id)
@@ -339,7 +336,6 @@ def _complete_init_from_backup(self, backup: Backup) -> None:
             identifier=backup_generate.identifier,
             record_size_ratio=backup_generate.record_size_ratio,
             preserved=backup_generate.preserved,
-            missing_model=backup_generate.missing_model,
             lost_contact=backup_generate.lost_contact,
             record_handlers=record_handlers,
         )
@@ -458,7 +454,6 @@ def _build_backup(self) -> Backup:
                 for table, model in self._synthetics_train.models.items()
             },
             lost_contact=self._synthetics_train.lost_contact,
-            training_columns=self._synthetics_train.training_columns,
         )
 
         # Generate
@@ -467,7 +462,6 @@ def _build_backup(self) -> Backup:
                 identifier=self._synthetics_run.identifier,
                 record_size_ratio=self._synthetics_run.record_size_ratio,
                 preserved=self._synthetics_run.preserved,
-                missing_model=self._synthetics_run.missing_model,
                 lost_contact=self._synthetics_run.lost_contact,
                 record_handler_ids={
                     table: rh.record_id
@@ -516,6 +510,35 @@ def _create_debug_summary(self) -> None:
             self._project, str(debug_summary_path)
         )
 
+    def delete_models(
+        self,
+        workflow: str,
+        *,
+        only: Optional[list[str]] = None,
+        ignore: Optional[list[str]] = None,
+    ) -> None:
+        only, ignore = self._get_only_and_ignore(only, ignore)
+        tables = [
+            table
+            for table in self.relational_data.list_all_tables()
+            if not skip_table(table, only, ignore)
+        ]
+        if workflow == "classify":
+            state = self._classify
+        elif workflow == "transforms":
+            state = self._transforms_train
+        elif workflow == "synthetics":
+            state = self._synthetics_train
+        else:
+            raise MultiTableException(
+                "Unknown workflow; must specify `classify`, `transforms`, or `synthetics`."
+            )
+
+        for table in tables:
+            if (model := state.models.get(table)) is not None:
+                model.delete()
+                del state.models[table]
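Usage sketch (illustrative, not part of the diff) for the new delete_models method, assuming `mt` is a MultiTable instance with trained models. Since `only`/`ignore` are resolved through `_get_only_and_ignore`, they accept the same table names as the training methods:

mt.delete_models("synthetics", ignore=["users"])  # delete every synthetics model except the one for `users`
mt.delete_models("transforms", only=["events"])   # delete just the transforms model for `events`
mt.delete_models("tuning")                        # raises MultiTableException: unknown workflow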

     def classify(self, config: GretelModelConfig, all_rows: bool = False) -> None:
         classify_data_sources = {}
         for table in self.relational_data.list_all_tables():
@@ -605,31 +628,12 @@ def train_transforms(
         only: Optional[list[str]] = None,
         ignore: Optional[list[str]] = None,
     ) -> None:
-        if only is not None and ignore is not None:
-            raise MultiTableException("Cannot specify both `only` and `ignore`.")
-
-        m_only = None
-        if only is not None:
-            m_only = []
-            for table in only:
-                m_names = self.relational_data.get_modelable_table_names(table)
-                if len(m_names) == 0:
-                    raise MultiTableException(f"Unrecognized table name: `{table}`")
-                m_only.extend(m_names)
-
-        m_ignore = None
-        if ignore is not None:
-            m_ignore = []
-            for table in ignore:
-                m_names = self.relational_data.get_modelable_table_names(table)
-                if len(m_names) == 0:
-                    raise MultiTableException(f"Unrecognized table name: `{table}`")
-                m_ignore.extend(m_names)
+        only, ignore = self._get_only_and_ignore(only, ignore)
 
         configs = {
             table: config
             for table in self.relational_data.list_all_tables()
-            if not skip_table(table, m_only, m_ignore)
+            if not skip_table(table, only, ignore)
         }
 
         self._setup_transforms_train_state(configs)
@@ -722,15 +726,39 @@ def run_transforms(
         self._backup()
         self.transform_output_tables = reshaped_tables
 
+    def _get_only_and_ignore(
+        self, only: Optional[list[str]], ignore: Optional[list[str]]
+    ) -> tuple[Optional[list[str]], Optional[list[str]]]:
+        if only is not None and ignore is not None:
+            raise MultiTableException("Cannot specify both `only` and `ignore`.")
+
+        only_and_ignore = []
+
+        for given_tables in [only, ignore]:
+            if given_tables is None:
+                only_and_ignore.append(None)
+                continue
+
+            modelable_tables = []
+            for table in given_tables:
+                m_names = self.relational_data.get_modelable_table_names(table)
+                if len(m_names) == 0:
+                    raise MultiTableException(f"Unrecognized table name: `{table}`")
+                modelable_tables.extend(m_names)
+
+            only_and_ignore.append(modelable_tables)
+
+        return tuple(only_and_ignore)
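Contract sketch (illustrative, not part of the diff): the helper centralizes the validation that train_transforms previously did inline. Note that get_modelable_table_names may expand one user-facing name into several modelable tables, so the returned lists can be longer than the inputs:

mt._get_only_and_ignore(["users"], None)        # -> (["users"], None)
mt._get_only_and_ignore(None, None)             # -> (None, None)
mt._get_only_and_ignore(["users"], ["events"])  # raises: cannot specify both
mt._get_only_and_ignore(["nope"], None)         # raises: unrecognized table name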

     def _prepare_training_data(self, tables: list[str]) -> dict[str, Path]:
         """
         Exports a copy of each table prepared for training by the configured strategy
         to the working directory. Returns a dict with table names as keys and Paths
         to the CSVs as values.
         """
-        training_data = self._strategy.prepare_training_data(self.relational_data)
-        for table, df in training_data.items():
-            self._synthetics_train.training_columns[table] = list(df.columns)
+        training_data = self._strategy.prepare_training_data(
+            self.relational_data, tables
+        )
         training_paths = {}
 
         for table_name in tables:
@@ -748,6 +776,13 @@ def _train_synthetics_models(self, training_data: dict[str, Path]) -> None:
             )
             self._synthetics_train.models[table_name] = model
 
+        archive_path = self._working_dir / "synthetics_training.tar.gz"
+        for table_name, csv_path in training_data.items():
+            add_to_tar(archive_path, csv_path, csv_path.name)
+        self._artifact_collection.upload_synthetics_training_archive(
+            self._project, str(archive_path)
+        )
+
         self._backup()
 
         task = SyntheticsTrainTask(
@@ -767,22 +802,51 @@ def _train_synthetics_models(self, training_data: dict[str, Path]) -> None:
             self._extended_sdk,
         )
 
-        # TODO: consider moving this to before running the task
-        archive_path = self._working_dir / "synthetics_training.tar.gz"
-        for table_name, csv_path in training_data.items():
-            add_to_tar(archive_path, csv_path, csv_path.name)
-        self._artifact_collection.upload_synthetics_training_archive(
-            self._project, str(archive_path)
-        )
-
     def train(self) -> None:
-        """Train synthetic data models on each table in the relational dataset"""
+        """
+        DEPRECATED: Please use `train_synthetics` instead.
+        """
+        logger.warning(
+            "This method is deprecated and will be removed in a future release. "
+            "Please use `train_synthetics` instead."
+        )
         tables = self.relational_data.list_all_tables()
+        self._synthetics_train = SyntheticsTrain()
 
         training_data = self._prepare_training_data(tables)
         self._train_synthetics_models(training_data)
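Migration sketch (illustrative, not part of the diff): for existing callers the change is effectively a rename; with no arguments both entry points train models for every table, though train() also resets any prior synthetics training state first:

mt.train()             # deprecated: logs a warning, then trains all tables
mt.train_synthetics()  # preferred equivalent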

+    def train_synthetics(
+        self,
+        *,
+        only: Optional[list[str]] = None,
+        ignore: Optional[list[str]] = None,
+    ) -> None:
+        """
+        Train synthetic data models for the tables in the tableset,
+        optionally scoped by either `only` or `ignore`.
+        """
+        only, ignore = self._get_only_and_ignore(only, ignore)
+
+        all_tables = self.relational_data.list_all_tables()
+        include_tables = []
+        omit_tables = []
+        for table in all_tables:
+            if skip_table(table, only, ignore):
+                omit_tables.append(table)
+            else:
+                include_tables.append(table)
+
+        # TODO: Ancestral strategy requires that for each table omitted from synthetics ("preserved"),
+        # all its ancestors must also be omitted. In the future, it'd be nice to either find a way to
+        # eliminate this requirement completely, or (less ideal) allow incremental training of tables,
+        # e.g. train a few in one "batch", then a few more before generating (perhaps with some logs
+        # along the way informing the user of which required tables are missing).
+        self._strategy.validate_preserved_tables(omit_tables, self.relational_data)
+
+        training_data = self._prepare_training_data(include_tables)
+        self._train_synthetics_models(training_data)
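Usage sketch (illustrative, not part of the diff; table names are hypothetical):

mt.train_synthetics()                  # train models for every table
mt.train_synthetics(only=["events"])   # train only `events`
mt.train_synthetics(ignore=["users"])  # train everything except `users`
# Under the ancestral strategy, validate_preserved_tables raises unless every
# ancestor of an omitted table is omitted as well.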

     def retrain_tables(self, tables: dict[str, pd.DataFrame]) -> None:
         """
         Provide updated table data and retrain. This method overwrites the table data in the
@@ -801,7 +865,6 @@ def retrain_tables(self, tables: dict[str, pd.DataFrame]) -> None:
         for table in tables_to_retrain:
             with suppress(KeyError):
                 del self._synthetics_train.models[table]
-                del self._synthetics_train.training_columns[table]
         training_data = self._prepare_training_data(tables_to_retrain)
         self._train_synthetics_models(training_data)
 
@@ -861,18 +924,23 @@ def generate(
             logger.info(f"Resuming synthetics run `{self._synthetics_run.identifier}`")
         else:
             preserve_tables = preserve_tables or []
+            preserve_tables.extend(
+                [
+                    table
+                    for table in self.relational_data.list_all_tables()
+                    if table not in self._synthetics_train.models
+                ]
+            )
             self._strategy.validate_preserved_tables(
                 preserve_tables, self.relational_data
             )
 
             identifier = identifier or f"synthetics_{_timestamp()}"
-            missing_model = self._list_tables_with_missing_models()
 
             self._synthetics_run = SyntheticsRun(
                 identifier=identifier,
                 record_size_ratio=record_size_ratio,
                 preserved=preserve_tables,
-                missing_model=missing_model,
                 record_handlers={},
                 lost_contact=[],
             )
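Behavior sketch (illustrative, not part of the diff): folding untrained tables into preserve_tables here is what lets a scoped training run flow into generation without the old missing_model bookkeeping:

mt.train_synthetics(ignore=["users"])
mt.generate()
# `users` has no trained model, so it is treated as preserved and its existing
# data is carried through to the output rather than being generated.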
@@ -909,6 +977,9 @@ def generate(
             if table in self._synthetics_run.preserved:
                 continue
 
+            if table not in self._synthetics_train.models:
+                continue
+
             if table not in self.relational_data.list_all_tables(Scope.EVALUATABLE):
                 continue
 
@@ -982,21 +1053,6 @@ def create_relational_report(self, run_identifier: str, target_dir: Path) -> None:
         html_content = ReportRenderer().render(presenter)
         report.write(html_content)
 
-    def _list_tables_with_missing_models(self) -> list[str]:
-        missing_model = set()
-        for table in self.relational_data.list_all_tables():
-            if not _table_trained_successfully(self._synthetics_train, table):
-                logger.info(
-                    f"Skipping synthetic data generation for `{table}` because it does not have a trained model"
-                )
-                missing_model.add(table)
-                for descendant in self.relational_data.get_descendants(table):
-                    logger.info(
-                        f"Skipping synthetic data generation for `{descendant}` because it depends on `{table}`"
-                    )
-                    missing_model.add(table)
-        return list(missing_model)
-
     def _attach_existing_reports(self, run_id: str, table: str) -> None:
         individual_path = (
             self._working_dir
@@ -1046,9 +1102,7 @@ def _validate_strategy(strategy: str) -> Union[IndependentStrategy, AncestralStrategy]:
     raise MultiTableException(msg)
 
 
-def _table_trained_successfully(
-    train_state: Union[TransformsTrain, SyntheticsTrain], table: str
-) -> bool:
+def _table_trained_successfully(train_state: TransformsTrain, table: str) -> bool:
     model = train_state.models.get(table)
     if model is None:
         return False