Skip to content

Commit

Permalink
PLAT-1492: Bypass join tables in independent strategy
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 32a2a10721dfbc1741b4080280fe36dce9dbdb76
  • Loading branch information
mikeknep committed Jan 11, 2024
1 parent 3cddc7c commit 8587e8c
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 5 deletions.
8 changes: 8 additions & 0 deletions src/gretel_trainer/relational/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,14 @@ def get_table_columns(self, table: str) -> list[str]:
"""
return self._get_table_metadata(table).columns

def get_table_row_count(self, table: str) -> int:
    """
    Return the number of rows in the table's source data.

    Counts physical lines in the source artifact and subtracts one for the
    header row. NOTE(review): assumes no quoted CSV field contains an
    embedded newline — confirm against how source files are written.
    """
    source = self.get_table_source(table)
    with open_artifact(source, "rb") as src:
        # max(..., 0) guards against an empty file (no header line at all)
        # producing a row count of -1.
        return max(sum(1 for _ in src) - 1, 0)

def get_safe_ancestral_seed_columns(self, table: str) -> set[str]:
safe_columns = self._get_table_metadata(table).safe_ancestral_seed_columns
if safe_columns is None:
Expand Down
11 changes: 9 additions & 2 deletions src/gretel_trainer/relational/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from typing import Any, cast, Optional, Union

import pandas as pd
import smart_open

import gretel_trainer.relational.ancestry as ancestry

Expand Down Expand Up @@ -719,9 +718,16 @@ def _train_synthetics_models(self, configs: dict[str, dict[str, Any]]) -> None:
for table in configs
}

self._strategy.prepare_training_data(self.relational_data, training_paths)
training_paths = self._strategy.prepare_training_data(
self.relational_data, training_paths
)

for table_name, config in configs.items():
if table_name not in training_paths:
logger.info(f"Bypassing model training for table `{table_name}`")
self._synthetics_train.bypass.append(table_name)
continue

synthetics_config = make_synthetics_config(table_name, config)
model = self._project.create_model_obj(
model_config=synthetics_config,
Expand Down Expand Up @@ -859,6 +865,7 @@ def generate(
table
for table in self.relational_data.list_all_tables()
if table not in self._synthetics_train.models
and table not in self._synthetics_train.bypass
]
)
self._strategy.validate_preserved_tables(
Expand Down
20 changes: 18 additions & 2 deletions src/gretel_trainer/relational/strategies/independent.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def prepare_training_data(
Writes tables' training data to provided paths.
Training data has primary and foreign key columns removed.
"""
prepared_tables = {}

for table, path in table_paths.items():
columns_to_drop = set()
columns_to_drop.update(rel_data.get_primary_key(table))
Expand All @@ -48,15 +50,27 @@ def prepare_training_data(
all_columns = rel_data.get_table_columns(table)
use_columns = [col for col in all_columns if col not in columns_to_drop]

# It's possible for *all columns* on a table to be part of a PK or FK,
# leaving no columns to send to a model for training. We omit such tables
# from the returned dictionary, indicating to MultiTable that it should
# "bypass" training and running a model for that table and instead leave
# it alone until post-processing (synthesizing key columns).
if len(use_columns) == 0:
logger.info(
f"All columns in table `{table}` are associated with key constraints"
)
continue

source_path = rel_data.get_table_source(table)
with open_artifact(source_path, "rb") as src, open_artifact(
path, "wb"
) as dest:
pd.DataFrame(columns=use_columns).to_csv(dest, index=False)
for chunk in pd.read_csv(src, usecols=use_columns, chunksize=10_000):
chunk.to_csv(dest, index=False, mode="a", header=False)
prepared_tables[table] = path

return table_paths
return prepared_tables

def tables_to_retrain(
self, tables: list[str], rel_data: RelationalData
Expand Down Expand Up @@ -268,7 +282,9 @@ def _collect_fk_values(
def _unique_not_null_values(values: list) -> list:
unique_values = {tuple(v) for v in values}
unique_values.discard((None,))
return list(unique_values)
vals = list(unique_values)
random.shuffle(vals)
return vals

# Collect final output values by adding non-null values to `new_values`
# (which has the requisite number of nulls already).
Expand Down
14 changes: 13 additions & 1 deletion src/gretel_trainer/relational/tasks/synthetics_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,29 @@ def _setup_working_tables(self) -> dict[str, Optional[pd.DataFrame]]:
all_tables = self.multitable.relational_data.list_all_tables()

for table in all_tables:
if table in self.synthetics_train.bypass:
source_row_count = self.multitable.relational_data.get_table_row_count(
table
)
out_row_count = int(
source_row_count * self.synthetics_run.record_size_ratio
)
working_tables[table] = pd.DataFrame(index=range(out_row_count))
continue

model = self.synthetics_train.models.get(table)

# Table was either omitted from training or marked as to-be-preserved during generation
if model is None or table in self.synthetics_run.preserved:
working_tables[table] = self.multitable._strategy.get_preserved_data(
table, self.multitable.relational_data
)
continue

# Table was included in training, but failed at that step
elif model.status != Status.COMPLETED:
if model.status != Status.COMPLETED:
working_tables[table] = None
continue

return working_tables

Expand Down
1 change: 1 addition & 0 deletions src/gretel_trainer/relational/workflow_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class TransformsTrain:
class SyntheticsTrain:
models: dict[str, Model] = field(default_factory=dict)
lost_contact: list[str] = field(default_factory=list)
bypass: list[str] = field(default_factory=list)


@dataclass
Expand Down
5 changes: 5 additions & 0 deletions tests/relational/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ def art(tmpdir) -> Generator[RelationalData, None, None]:
yield _rel_data_connector("art").extract(storage_dir=tmpdir)


@pytest.fixture()
def insurance(tmpdir) -> Generator[RelationalData, None, None]:
    """RelationalData extracted from the `insurance` example database."""
    connector = _rel_data_connector("insurance")
    yield connector.extract(storage_dir=tmpdir)


@pytest.fixture()
def documents(tmpdir) -> Generator[RelationalData, None, None]:
    """RelationalData extracted from the `documents` example database."""
    connector = _rel_data_connector("documents")
    yield connector.extract(storage_dir=tmpdir)
Expand Down
32 changes: 32 additions & 0 deletions tests/relational/example_dbs/insurance.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
-- Fixture schema exercising a pure "join table": every column on
-- insurance_policies is part of a primary or foreign key, leaving no
-- trainable columns, so the independent strategy must bypass the table.
create table if not exists beneficiary (
    id integer primary key,
    name text not null
);

create table if not exists insurance_policies (
    id integer primary key,
    primary_beneficiary integer not null,
    secondary_beneficiary integer not null,
    --
    foreign key (primary_beneficiary) references beneficiary (id),
    foreign key (secondary_beneficiary) references beneficiary (id)
);

-- String literals use single quotes: double quotes denote identifiers in
-- standard SQL, and SQLite only treats them as strings via a deprecated
-- compatibility misfeature.
insert into beneficiary (name) values
    ('John Doe'),
    ('Jane Smith'),
    ('Michael Johnson'),
    ('Emily Brown'),
    ('William Wilson');

insert into insurance_policies (primary_beneficiary, secondary_beneficiary) values
    (1, 2),
    (2, 3),
    (3, 4),
    (4, 5),
    (5, 1),
    (1, 3),
    (2, 4),
    (3, 5),
    (4, 1),
    (5, 2);
59 changes: 59 additions & 0 deletions tests/relational/test_independent_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,22 @@ def test_prepare_training_data_subset_of_tables(pets):
assert os.stat(pets_dest.name).st_size == 0


def test_prepare_training_data_join_table(insurance):
    """A table whose columns are all keys is omitted from training data."""
    strategy = IndependentStrategy()

    with tempfile.NamedTemporaryFile() as beneficiary_dest, tempfile.NamedTemporaryFile() as policies_dest:
        paths = {
            "beneficiary": beneficiary_dest.name,
            "insurance_policies": policies_dest.name,
        }
        training_data = strategy.prepare_training_data(insurance, paths)

        # Only beneficiary has non-key columns, so it alone is prepared...
        assert set(training_data) == {"beneficiary"}
        assert not pd.read_csv(training_data["beneficiary"]).empty
        # ...and nothing was written for the bypassed join table.
        assert os.stat(policies_dest.name).st_size == 0


def test_retraining_a_set_of_tables_only_retrains_those_tables(ecom):
strategy = IndependentStrategy()
assert set(strategy.tables_to_retrain(["users"], ecom)) == {"users"}
Expand Down Expand Up @@ -366,3 +382,46 @@ def test_post_processing_null_composite_foreign_key(tmpdir):
}
),
)


def test_post_processing_with_bypass_table(insurance):
    """Post-processing synthesizes key columns for a bypassed join table."""
    strategy = IndependentStrategy()

    # insurance_policies was bypassed (all of its columns are keys), so its
    # raw synthetic output is an empty frame that only carries the row count.
    raw_synth_tables = {
        "beneficiary": pd.DataFrame(
            data={
                "name": ["Adam", "Beth", "Chris", "Demi", "Eric"],
            }
        ),
        "insurance_policies": pd.DataFrame(index=range(5)),
    }

    # Normally we shuffle synthesized keys for realism, but for deterministic testing we sort instead
    # NOTE(review): `wraps=sorted` calls sorted(vals) and *discards* the result —
    # unlike random.shuffle it does not mutate the list in place. Determinism here
    # presumably comes from a stable iteration order instead; confirm, or use
    # side_effect=list.sort if in-place sorting was intended.
    with patch("random.shuffle", wraps=sorted):
        processed = strategy.post_process_synthetic_results(
            raw_synth_tables, [], insurance, 1
        )

    # Primary keys are synthesized as a fresh 0..n-1 sequence.
    pdtest.assert_frame_equal(
        processed["beneficiary"],
        pd.DataFrame(
            data={
                "name": ["Adam", "Beth", "Chris", "Demi", "Eric"],
                "id": [0, 1, 2, 3, 4],
            }
        ),
    )
    # Given the particular values in this unit test and the patching of random.shuffle to use
    # sorted instead, we deterministically get the beneficiary ID values below. In production
    # use, we shuffle values to produce more realistic results (though it is still possible to
    # get "unusual" results like primary_ and secondary_ pointing to the same beneficiary record).
    pdtest.assert_frame_equal(
        processed["insurance_policies"],
        pd.DataFrame(
            data={
                "id": [0, 1, 2, 3, 4],
                "primary_beneficiary": [2, 2, 4, 4, 1],
                "secondary_beneficiary": [2, 2, 4, 4, 1],
            }
        ),
    )
5 changes: 5 additions & 0 deletions tests/relational/test_relational_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ def test_mutagenesis_relational_data(mutagenesis):
assert set(mutagenesis.get_all_key_columns("atom")) == {"atom_id", "molecule_id"}


def test_row_count(art):
    """Source row counts exclude the CSV header line."""
    expected_counts = {"artists": 4, "paintings": 7}
    for table, expected in expected_counts.items():
        assert art.get_table_row_count(table) == expected


def test_column_metadata(pets, tmpfile):
assert pets.get_table_columns("humans") == ["id", "name", "city"]

Expand Down

0 comments on commit 8587e8c

Please sign in to comment.