From 8587e8cf947734a07edca6799de1921543acb427 Mon Sep 17 00:00:00 2001 From: Mike Knepper Date: Thu, 11 Jan 2024 14:16:07 -0600 Subject: [PATCH] PLAT-1492: Bypass join tables in independent strategy GitOrigin-RevId: 32a2a10721dfbc1741b4080280fe36dce9dbdb76 --- src/gretel_trainer/relational/core.py | 8 +++ src/gretel_trainer/relational/multi_table.py | 11 +++- .../relational/strategies/independent.py | 20 ++++++- .../relational/tasks/synthetics_run.py | 14 ++++- .../relational/workflow_state.py | 1 + tests/relational/conftest.py | 5 ++ tests/relational/example_dbs/insurance.sql | 32 ++++++++++ tests/relational/test_independent_strategy.py | 59 +++++++++++++++++++ tests/relational/test_relational_data.py | 5 ++ 9 files changed, 150 insertions(+), 5 deletions(-) create mode 100644 tests/relational/example_dbs/insurance.sql diff --git a/src/gretel_trainer/relational/core.py b/src/gretel_trainer/relational/core.py index 087fae91..e9a0422d 100644 --- a/src/gretel_trainer/relational/core.py +++ b/src/gretel_trainer/relational/core.py @@ -771,6 +771,14 @@ def get_table_columns(self, table: str) -> list[str]: """ return self._get_table_metadata(table).columns + def get_table_row_count(self, table: str) -> int: + """ + Return the number of rows in the table. + """ + source = self.get_table_source(table) + with open_artifact(source, "rb") as src: + return sum(1 for line in src) - 1 + def get_safe_ancestral_seed_columns(self, table: str) -> set[str]: safe_columns = self._get_table_metadata(table).safe_ancestral_seed_columns if safe_columns is None: diff --git a/src/gretel_trainer/relational/multi_table.py b/src/gretel_trainer/relational/multi_table.py index 3ae58864..5ae5b786 100644 --- a/src/gretel_trainer/relational/multi_table.py +++ b/src/gretel_trainer/relational/multi_table.py @@ -19,7 +19,6 @@ from typing import Any, cast, Optional, Union import pandas as pd -import smart_open import gretel_trainer.relational.ancestry as ancestry @@ -719,9 +718,16 @@ def _train_synthetics_models(self, configs: dict[str, dict[str, Any]]) -> None: for table in configs } - self._strategy.prepare_training_data(self.relational_data, training_paths) + training_paths = self._strategy.prepare_training_data( + self.relational_data, training_paths + ) for table_name, config in configs.items(): + if table_name not in training_paths: + logger.info(f"Bypassing model training for table `{table_name}`") + self._synthetics_train.bypass.append(table_name) + continue + synthetics_config = make_synthetics_config(table_name, config) model = self._project.create_model_obj( model_config=synthetics_config, @@ -859,6 +865,7 @@ def generate( table for table in self.relational_data.list_all_tables() if table not in self._synthetics_train.models + and table not in self._synthetics_train.bypass ] ) self._strategy.validate_preserved_tables( diff --git a/src/gretel_trainer/relational/strategies/independent.py b/src/gretel_trainer/relational/strategies/independent.py index 2820c39d..1361fec6 100644 --- a/src/gretel_trainer/relational/strategies/independent.py +++ b/src/gretel_trainer/relational/strategies/independent.py @@ -39,6 +39,8 @@ def prepare_training_data( Writes tables' training data to provided paths. Training data has primary and foreign key columns removed. """ + prepared_tables = {} + for table, path in table_paths.items(): columns_to_drop = set() columns_to_drop.update(rel_data.get_primary_key(table)) @@ -48,6 +50,17 @@ def prepare_training_data( all_columns = rel_data.get_table_columns(table) use_columns = [col for col in all_columns if col not in columns_to_drop] + # It's possible for *all columns* on a table to be part of a PK or FK, + # leaving no columns to send to a model for training. We omit such tables + # from the returned dictionary, indicating to MultiTable that it should + # "bypass" training and running a model for that table and instead leave + # it alone until post-processing (synthesizing key columns). + if len(use_columns) == 0: + logger.info( + f"All columns in table `{table}` are associated with key constraints" + ) + continue + source_path = rel_data.get_table_source(table) with open_artifact(source_path, "rb") as src, open_artifact( path, "wb" @@ -55,8 +68,9 @@ def prepare_training_data( pd.DataFrame(columns=use_columns).to_csv(dest, index=False) for chunk in pd.read_csv(src, usecols=use_columns, chunksize=10_000): chunk.to_csv(dest, index=False, mode="a", header=False) + prepared_tables[table] = path - return table_paths + return prepared_tables def tables_to_retrain( self, tables: list[str], rel_data: RelationalData @@ -268,7 +282,9 @@ def _collect_fk_values( def _unique_not_null_values(values: list) -> list: unique_values = {tuple(v) for v in values} unique_values.discard((None,)) - return list(unique_values) + vals = list(unique_values) + random.shuffle(vals) + return vals # Collect final output values by adding non-null values to `new_values` # (which has the requisite number of nulls already). diff --git a/src/gretel_trainer/relational/tasks/synthetics_run.py b/src/gretel_trainer/relational/tasks/synthetics_run.py index d5cedaf3..20336ea4 100644 --- a/src/gretel_trainer/relational/tasks/synthetics_run.py +++ b/src/gretel_trainer/relational/tasks/synthetics_run.py @@ -38,6 +38,16 @@ def _setup_working_tables(self) -> dict[str, Optional[pd.DataFrame]]: all_tables = self.multitable.relational_data.list_all_tables() for table in all_tables: + if table in self.synthetics_train.bypass: + source_row_count = self.multitable.relational_data.get_table_row_count( + table + ) + out_row_count = int( + source_row_count * self.synthetics_run.record_size_ratio + ) + working_tables[table] = pd.DataFrame(index=range(out_row_count)) + continue + model = self.synthetics_train.models.get(table) # Table was either omitted from training or marked as to-be-preserved during generation @@ -45,10 +55,12 @@ def _setup_working_tables(self) -> dict[str, Optional[pd.DataFrame]]: working_tables[table] = self.multitable._strategy.get_preserved_data( table, self.multitable.relational_data ) + continue # Table was included in training, but failed at that step - elif model.status != Status.COMPLETED: + if model.status != Status.COMPLETED: working_tables[table] = None + continue return working_tables diff --git a/src/gretel_trainer/relational/workflow_state.py b/src/gretel_trainer/relational/workflow_state.py index e93fe813..ff1004f2 100644 --- a/src/gretel_trainer/relational/workflow_state.py +++ b/src/gretel_trainer/relational/workflow_state.py @@ -19,6 +19,7 @@ class TransformsTrain: class SyntheticsTrain: models: dict[str, Model] = field(default_factory=dict) lost_contact: list[str] = field(default_factory=list) + bypass: list[str] = field(default_factory=list) @dataclass diff --git a/tests/relational/conftest.py b/tests/relational/conftest.py index 1980b062..c18b1c2d 100644 --- a/tests/relational/conftest.py +++ b/tests/relational/conftest.py @@ -135,6 +135,11 @@ def art(tmpdir) -> Generator[RelationalData, None, None]: yield _rel_data_connector("art").extract(storage_dir=tmpdir) +@pytest.fixture() +def insurance(tmpdir) -> Generator[RelationalData, None, None]: + yield _rel_data_connector("insurance").extract(storage_dir=tmpdir) + + @pytest.fixture() def documents(tmpdir) -> Generator[RelationalData, None, None]: yield _rel_data_connector("documents").extract(storage_dir=tmpdir) diff --git a/tests/relational/example_dbs/insurance.sql b/tests/relational/example_dbs/insurance.sql new file mode 100644 index 00000000..c00049b7 --- /dev/null +++ b/tests/relational/example_dbs/insurance.sql @@ -0,0 +1,32 @@ +create table if not exists beneficiary ( + id integer primary key, + name text not null +); + +create table if not exists insurance_policies ( + id integer primary key, + primary_beneficiary integer not null, + secondary_beneficiary integer not null, + -- + foreign key (primary_beneficiary) references beneficiary (id), + foreign key (secondary_beneficiary) references beneficiary (id) +); + +insert into beneficiary (name) values + ("John Doe"), + ("Jane Smith"), + ("Michael Johnson"), + ("Emily Brown"), + ("William Wilson"); + +insert into insurance_policies (primary_beneficiary, secondary_beneficiary) values + (1, 2), + (2, 3), + (3, 4), + (4, 5), + (5, 1), + (1, 3), + (2, 4), + (3, 5), + (4, 1), + (5, 2); diff --git a/tests/relational/test_independent_strategy.py b/tests/relational/test_independent_strategy.py index 23ff4c6c..99b47e94 100644 --- a/tests/relational/test_independent_strategy.py +++ b/tests/relational/test_independent_strategy.py @@ -50,6 +50,22 @@ def test_prepare_training_data_subset_of_tables(pets): assert os.stat(pets_dest.name).st_size == 0 +def test_prepare_training_data_join_table(insurance): + strategy = IndependentStrategy() + + with tempfile.NamedTemporaryFile() as beneficiary_dest, tempfile.NamedTemporaryFile() as policies_dest: + training_data = strategy.prepare_training_data( + insurance, + { + "beneficiary": beneficiary_dest.name, + "insurance_policies": policies_dest.name, + }, + ) + assert set(training_data.keys()) == {"beneficiary"} + assert not pd.read_csv(training_data["beneficiary"]).empty + assert os.stat(policies_dest.name).st_size == 0 + + def test_retraining_a_set_of_tables_only_retrains_those_tables(ecom): strategy = IndependentStrategy() assert set(strategy.tables_to_retrain(["users"], ecom)) == {"users"} @@ -366,3 +382,46 @@ def test_post_processing_null_composite_foreign_key(tmpdir): } ), ) + + +def test_post_processing_with_bypass_table(insurance): + strategy = IndependentStrategy() + + raw_synth_tables = { + "beneficiary": pd.DataFrame( + data={ + "name": ["Adam", "Beth", "Chris", "Demi", "Eric"], + } + ), + "insurance_policies": pd.DataFrame(index=range(5)), + } + + # Normally we shuffle synthesized keys for realism, but for deterministic testing we sort instead + with patch("random.shuffle", wraps=sorted): + processed = strategy.post_process_synthetic_results( + raw_synth_tables, [], insurance, 1 + ) + + pdtest.assert_frame_equal( + processed["beneficiary"], + pd.DataFrame( + data={ + "name": ["Adam", "Beth", "Chris", "Demi", "Eric"], + "id": [0, 1, 2, 3, 4], + } + ), + ) + # Given the particular values in this unit test and the patching of random.shuffle to use + # sorted instead, we deterministically get the beneficiary ID values below. In production + # use, we shuffle values to produce more realistic results (though it is still possible to + # get "unusual" results like primary_ and secondary_ pointing to the same beneficiary record). + pdtest.assert_frame_equal( + processed["insurance_policies"], + pd.DataFrame( + data={ + "id": [0, 1, 2, 3, 4], + "primary_beneficiary": [2, 2, 4, 4, 1], + "secondary_beneficiary": [2, 2, 4, 4, 1], + } + ), + ) diff --git a/tests/relational/test_relational_data.py b/tests/relational/test_relational_data.py index 0464a2de..b1a94e69 100644 --- a/tests/relational/test_relational_data.py +++ b/tests/relational/test_relational_data.py @@ -43,6 +43,11 @@ def test_mutagenesis_relational_data(mutagenesis): assert set(mutagenesis.get_all_key_columns("atom")) == {"atom_id", "molecule_id"} +def test_row_count(art): + assert art.get_table_row_count("artists") == 4 + assert art.get_table_row_count("paintings") == 7 + + def test_column_metadata(pets, tmpfile): assert pets.get_table_columns("humans") == ["id", "name", "city"]