Preserve tables before training #111

Merged May 30, 2023 · 14 commits
38 changes: 36 additions & 2 deletions src/gretel_trainer/relational/ancestry.py
@@ -14,6 +14,7 @@
def get_multigenerational_primary_key(
rel_data: RelationalData, table: str
) -> list[str]:
"Returns the provided table's primary key with the ancestral lineage prefix appended"
return [
f"{_START_LINEAGE}{_END_LINEAGE}{pk}" for pk in rel_data.get_primary_key(table)
]
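
For orientation, a minimal standalone sketch of the lineage-prefix convention this function relies on (assuming `_START_LINEAGE` is `"self"` and `_END_LINEAGE` is `"|"`, consistent with the `self|user_id` example in the next docstring; the real values are private constants in ancestry.py):

# Illustrative only -- not the module's actual constants.
START_LINEAGE = "self"
END_LINEAGE = "|"

def multigenerational_primary_key(primary_key: list[str]) -> list[str]:
    # "id" becomes "self|id": the key's name in a fully-joined ancestral table.
    return [f"{START_LINEAGE}{END_LINEAGE}{pk}" for pk in primary_key]

assert multigenerational_primary_key(["id"]) == ["self|id"]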
@@ -22,6 +23,17 @@ def get_multigenerational_primary_key(
def get_ancestral_foreign_key_maps(
rel_data: RelationalData, table: str
) -> list[tuple[str, str]]:
"""
Returns a list of two-element tuples where the first element is a foreign key column
with its ancestral lineage prefix, and the second element is the ancestral-lineage-prefixed
column it references. Taken together, these tuples identify which columns are duplicates of
one another in a fully-joined ancestral table (i.e. the output of
`get_table_data_with_ancestors`), considering only the provided table and its direct
parents, not parents and grandparents.

For example: given an events table with foreign key `events.user_id` => `users.id`,
this method returns: [("self|user_id", "self.user_id|id")]
"""

def _ancestral_fk_map(fk: ForeignKey) -> list[tuple[str, str]]:
maps = []
fk_lineage = _COL_DELIMITER.join(fk.columns)
@@ -46,6 +58,29 @@ def _ancestral_fk_map(fk: ForeignKey) -> list[tuple[str, str]]:
]


def get_seed_safe_multigenerational_columns(
rel_data: RelationalData,
) -> dict[str, list[str]]:
"""
Returns a dict with Scope.MODELABLE table names as keys and lists of columns to use
for conditional seeding as values. By using a tableset of empty dataframes, this provides
a significantly faster and less resource-intensive way to get just the column names
from the results of `get_table_data_with_ancestors` for all tables.
"""
tableset = {
table: pd.DataFrame(columns=list(rel_data.get_table_columns(table)))
for table in rel_data.list_all_tables()
}
return {
table: list(
get_table_data_with_ancestors(
rel_data, table, tableset, ancestral_seeding=True
).columns
)
for table in rel_data.list_all_tables()
}
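
The empty-dataframe approach generalizes nicely: zero-row frames keep their full schema, so the same joins yield the final column set without moving any rows. A minimal sketch using pandas alone (the table names, columns, and `self.user_id|` prefix are hypothetical but follow the lineage format above):

import pandas as pd

# Zero-row frames preserve schema; merging them produces the joined
# column set at negligible cost.
users = pd.DataFrame(columns=["id", "name"])
events = pd.DataFrame(columns=["id", "user_id", "type"])

joined = events.merge(
    users.add_prefix("self.user_id|"),
    how="left",
    left_on=["user_id"],
    right_on=["self.user_id|id"],
)
print(list(joined.columns))
# ['id', 'user_id', 'type', 'self.user_id|id', 'self.user_id|name']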


def get_table_data_with_ancestors(
rel_data: RelationalData,
table: str,
@@ -93,10 +128,9 @@ def _join_parents(
parent_data = tableset[parent_table_name][list(usecols)]
else:
parent_data = rel_data.get_table_data(parent_table_name, usecols=usecols)
parent_data = parent_data.add_prefix(f"{next_lineage}{_END_LINEAGE}")

df = df.merge(
parent_data,
parent_data.add_prefix(f"{next_lineage}{_END_LINEAGE}"),
how="left",
left_on=[f"{lineage}{_END_LINEAGE}{col}" for col in foreign_key.columns],
right_on=[
2 changes: 0 additions & 2 deletions src/gretel_trainer/relational/backup.py
@@ -93,7 +93,6 @@ class BackupTransformsTrain:
class BackupSyntheticsTrain:
model_ids: dict[str, str]
lost_contact: list[str]
training_columns: dict[str, list[str]]


@dataclass
@@ -103,7 +102,6 @@ class BackupGenerate:
record_size_ratio: float
record_handler_ids: dict[str, str]
lost_contact: list[str]
missing_model: list[str]


@dataclass
2 changes: 1 addition & 1 deletion src/gretel_trainer/relational/connectors.py
@@ -42,7 +42,7 @@ def __init__(self, engine: Engine):
logger.info("Successfully connected to db")

def extract(
self, only: Optional[list[str]] = None, ignore: Optional[list[str]] = None
self, only: Optional[set[str]] = None, ignore: Optional[set[str]] = None
) -> RelationalData:
"""
Extracts table data and relationships from the database.
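
With `only` and `ignore` now typed as sets, a call site looks like the following sketch (the `sqlite_conn` helper and database path are assumptions, not taken from this diff):

from gretel_trainer.relational.connectors import sqlite_conn

# Hypothetical local database; scope extraction with sets, not lists.
connector = sqlite_conn("ecommerce.db")
relational_data = connector.extract(only={"users", "events", "orders"})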
2 changes: 1 addition & 1 deletion src/gretel_trainer/relational/core.py
@@ -665,7 +665,7 @@ def debug_summary(self) -> dict[str, Any]:


def skip_table(
table: str, only: Optional[list[str]], ignore: Optional[list[str]]
table: str, only: Optional[set[str]], ignore: Optional[set[str]]
) -> bool:
skip = False
if only is not None and table not in only:
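
The rest of the body is truncated by the diff view, but the predicate is simple enough to sketch in full (a reconstruction from the visible lines, not the verbatim source):

from typing import Optional

def skip_table(
    table: str, only: Optional[set[str]], ignore: Optional[set[str]]
) -> bool:
    # Skip when an `only` allowlist excludes the table, or an
    # `ignore` denylist includes it.
    skip = False
    if only is not None and table not in only:
        skip = True
    if ignore is not None and table in ignore:
        skip = True
    return skip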
145 changes: 83 additions & 62 deletions src/gretel_trainer/relational/multi_table.py
@@ -251,9 +251,6 @@ def _complete_init_from_backup(self, backup: Backup) -> None:
with tarfile.open(synthetics_training_archive_path, "r:gz") as tar:
tar.extractall(path=self._working_dir)

self._synthetics_train.training_columns = (
backup_synthetics_train.training_columns
)
self._synthetics_train.lost_contact = backup_synthetics_train.lost_contact
self._synthetics_train.models = {
table: self._project.get_model(model_id)
@@ -339,7 +336,6 @@ def _complete_init_from_backup(self, backup: Backup) -> None:
identifier=backup_generate.identifier,
record_size_ratio=backup_generate.record_size_ratio,
preserved=backup_generate.preserved,
missing_model=backup_generate.missing_model,
lost_contact=backup_generate.lost_contact,
record_handlers=record_handlers,
)
@@ -458,7 +454,6 @@ def _build_backup(self) -> Backup:
for table, model in self._synthetics_train.models.items()
},
lost_contact=self._synthetics_train.lost_contact,
training_columns=self._synthetics_train.training_columns,
)

# Generate
@@ -467,7 +462,6 @@
identifier=self._synthetics_run.identifier,
record_size_ratio=self._synthetics_run.record_size_ratio,
preserved=self._synthetics_run.preserved,
missing_model=self._synthetics_run.missing_model,
lost_contact=self._synthetics_run.lost_contact,
record_handler_ids={
table: rh.record_id
@@ -602,34 +596,15 @@ def train_transforms(
self,
config: GretelModelConfig,
*,
only: Optional[list[str]] = None,
ignore: Optional[list[str]] = None,
only: Optional[set[str]] = None,
ignore: Optional[set[str]] = None,
) -> None:
if only is not None and ignore is not None:
raise MultiTableException("Cannot specify both `only` and `ignore`.")

m_only = None
if only is not None:
m_only = []
for table in only:
m_names = self.relational_data.get_modelable_table_names(table)
if len(m_names) == 0:
raise MultiTableException(f"Unrecognized table name: `{table}`")
m_only.extend(m_names)

m_ignore = None
if ignore is not None:
m_ignore = []
for table in ignore:
m_names = self.relational_data.get_modelable_table_names(table)
if len(m_names) == 0:
raise MultiTableException(f"Unrecognized table name: `{table}`")
m_ignore.extend(m_names)
only, ignore = self._get_only_and_ignore(only, ignore)

configs = {
table: config
for table in self.relational_data.list_all_tables()
if not skip_table(table, m_only, m_ignore)
if not skip_table(table, only, ignore)
}

self._setup_transforms_train_state(configs)
@@ -722,15 +697,35 @@ def run_transforms(
self._backup()
self.transform_output_tables = reshaped_tables

def _get_only_and_ignore(
self, only: Optional[set[str]], ignore: Optional[set[str]]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
if only is not None and ignore is not None:
raise MultiTableException("Cannot specify both `only` and `ignore`.")

modelable_tables = set()
for table in only or ignore or {}:
m_names = self.relational_data.get_modelable_table_names(table)
if len(m_names) == 0:
raise MultiTableException(f"Unrecognized table name: `{table}`")
modelable_tables.update(m_names)

if only is None:
return (None, modelable_tables)
elif ignore is None:
return (modelable_tables, None)
else:
return (None, None)
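
One quirk worth noting: when neither filter is supplied, the helper returns an empty `ignore` set rather than `(None, None)`, which `skip_table` treats identically. A standalone sketch of the same dispatch, minus the expansion to modelable table names (`resolve_scope` is a hypothetical name):

from typing import Optional

def resolve_scope(
    only: Optional[set[str]], ignore: Optional[set[str]]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
    # At most one filter may be supplied; whichever survives comes back as a set.
    if only is not None and ignore is not None:
        raise ValueError("Cannot specify both `only` and `ignore`.")
    if only is None:
        return (None, ignore if ignore is not None else set())
    return (only, None)

assert resolve_scope({"users"}, None) == ({"users"}, None)
assert resolve_scope(None, {"users"}) == (None, {"users"})
assert resolve_scope(None, None) == (None, set())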

def _prepare_training_data(self, tables: list[str]) -> dict[str, Path]:
"""
Exports a copy of each table prepared for training by the configured strategy
to the working directory. Returns a dict with table names as keys and Paths
to the CSVs as values.
"""
training_data = self._strategy.prepare_training_data(self.relational_data)
for table, df in training_data.items():
self._synthetics_train.training_columns[table] = list(df.columns)
training_data = self._strategy.prepare_training_data(
self.relational_data, tables
)
training_paths = {}

for table_name in tables:
@@ -748,6 +743,13 @@ def _train_synthetics_models(self, training_data: dict[str, Path]) -> None:
)
self._synthetics_train.models[table_name] = model

archive_path = self._working_dir / "synthetics_training.tar.gz"
for table_name, csv_path in training_data.items():
add_to_tar(archive_path, csv_path, csv_path.name)
self._artifact_collection.upload_synthetics_training_archive(
self._project, str(archive_path)
)

self._backup()

task = SyntheticsTrainTask(
@@ -767,22 +769,51 @@ def _train_synthetics_models(self, training_data: dict[str, Path]) -> None:
self._extended_sdk,
)

# TODO: consider moving this to before running the task
archive_path = self._working_dir / "synthetics_training.tar.gz"
for table_name, csv_path in training_data.items():
add_to_tar(archive_path, csv_path, csv_path.name)
self._artifact_collection.upload_synthetics_training_archive(
self._project, str(archive_path)
)

def train(self) -> None:
"""Train synthetic data models on each table in the relational dataset"""
"""
DEPRECATED: Please use `train_synthetics` instead.
"""
logger.warning(
"This method is deprecated and will be removed in a future release. "
"Please use `train_synthetics` instead."
)
tables = self.relational_data.list_all_tables()
self._synthetics_train = SyntheticsTrain()

training_data = self._prepare_training_data(tables)
self._train_synthetics_models(training_data)

def train_synthetics(
self,
*,
only: Optional[set[str]] = None,
ignore: Optional[set[str]] = None,
) -> None:
"""
Train synthetic data models for the tables in the tableset,
optionally scoped by either `only` or `ignore`.
"""
only, ignore = self._get_only_and_ignore(only, ignore)

all_tables = self.relational_data.list_all_tables()
include_tables = []
omit_tables = []
for table in all_tables:
if skip_table(table, only, ignore):
omit_tables.append(table)
else:
include_tables.append(table)

# TODO: Ancestral strategy requires that for each table omitted from synthetics ("preserved"),
# all its ancestors must also be omitted. In the future, it'd be nice to either find a way to
# eliminate this requirement completely, or (less ideal) allow incremental training of tables,
# e.g. train a few in one "batch", then a few more before generating (perhaps with some logs
# along the way informing the user of which required tables are missing).
self._strategy.validate_preserved_tables(omit_tables, self.relational_data)

training_data = self._prepare_training_data(include_tables)
self._train_synthetics_models(training_data)
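
Typical usage of the new scoped API (a sketch; `relational_data` is assumed to come from an earlier extraction step):

from gretel_trainer.relational.multi_table import MultiTable

mt = MultiTable(relational_data)

# Train models for everything except one high-volume table...
mt.train_synthetics(ignore={"event_logs"})

# ...or train only a named subset. Passing both raises MultiTableException.
# mt.train_synthetics(only={"users", "orders"})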

def retrain_tables(self, tables: dict[str, pd.DataFrame]) -> None:
"""
Provide updated table data and retrain. This method overwrites the table data in the
@@ -801,7 +832,6 @@ def retrain_tables(self, tables: dict[str, pd.DataFrame]) -> None:
for table in tables_to_retrain:
with suppress(KeyError):
del self._synthetics_train.models[table]
del self._synthetics_train.training_columns[table]
training_data = self._prepare_training_data(tables_to_retrain)
self._train_synthetics_models(training_data)

@@ -861,18 +891,23 @@ def generate(
logger.info(f"Resuming synthetics run `{self._synthetics_run.identifier}`")
else:
preserve_tables = preserve_tables or []
preserve_tables.extend(
[
table
for table in self.relational_data.list_all_tables()
if table not in self._synthetics_train.models
]
)
self._strategy.validate_preserved_tables(
preserve_tables, self.relational_data
)

identifier = identifier or f"synthetics_{_timestamp()}"
missing_model = self._list_tables_with_missing_models()

self._synthetics_run = SyntheticsRun(
identifier=identifier,
record_size_ratio=record_size_ratio,
preserved=preserve_tables,
missing_model=missing_model,
record_handlers={},
lost_contact=[],
)
@@ -909,6 +944,9 @@
if table in self._synthetics_run.preserved:
continue

if table not in self._synthetics_train.models:
continue

if table not in self.relational_data.list_all_tables(Scope.EVALUATABLE):
continue
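
With `missing_model` gone, untrained tables flow through the existing `preserved` mechanism: `generate` extends `preserve_tables` with every table lacking a trained model, and the per-table loop skips them. A usage sketch (continuing the hypothetical `mt` instance from above):

# Tables omitted from train_synthetics are treated like preserved tables,
# so generation proceeds for the rest without a separate missing_model list.
mt.generate(record_size_ratio=2.0, preserve_tables=["users"])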

@@ -982,21 +1020,6 @@ def create_relational_report(self, run_identifier: str, target_dir: Path) -> None:
html_content = ReportRenderer().render(presenter)
report.write(html_content)

def _list_tables_with_missing_models(self) -> list[str]:
missing_model = set()
for table in self.relational_data.list_all_tables():
if not _table_trained_successfully(self._synthetics_train, table):
logger.info(
f"Skipping synthetic data generation for `{table}` because it does not have a trained model"
)
missing_model.add(table)
for descendant in self.relational_data.get_descendants(table):
logger.info(
f"Skipping synthetic data generation for `{descendant}` because it depends on `{table}`"
)
missing_model.add(table)
return list(missing_model)

def _attach_existing_reports(self, run_id: str, table: str) -> None:
individual_path = (
self._working_dir
@@ -1046,9 +1069,7 @@ def _validate_strategy(strategy: str) -> Union[IndependentStrategy, AncestralStrategy]:
raise MultiTableException(msg)


def _table_trained_successfully(
train_state: Union[TransformsTrain, SyntheticsTrain], table: str
) -> bool:
def _table_trained_successfully(train_state: TransformsTrain, table: str) -> bool:
model = train_state.models.get(table)
if model is None:
return False