From d6d5a4dd7c611c9dcc014326fe4eae0ae3abe417 Mon Sep 17 00:00:00 2001
From: Mike Knepper
Date: Fri, 2 Jun 2023 15:26:16 -0500
Subject: [PATCH 1/3] Don't hang on to training DFs longer than necessary

---
 src/gretel_trainer/relational/multi_table.py   |  13 +-
 .../relational/strategies/ancestral.py         |  18 +--
 .../relational/strategies/independent.py       |  21 ++--
 tests/relational/test_ancestral_strategy.py    | 113 +++++++++++++-----
 tests/relational/test_independent_strategy.py  |  42 ++++---
 5 files changed, 130 insertions(+), 77 deletions(-)

diff --git a/src/gretel_trainer/relational/multi_table.py b/src/gretel_trainer/relational/multi_table.py
index 5053c97b..8bb14956 100644
--- a/src/gretel_trainer/relational/multi_table.py
+++ b/src/gretel_trainer/relational/multi_table.py
@@ -724,15 +724,12 @@ def _prepare_training_data(self, tables: list[str]) -> dict[str, Path]:
         to the working directory. Returns a dict with table names as keys and Paths
         to the CSVs as values.
         """
-        training_data = self._strategy.prepare_training_data(
-            self.relational_data, tables
-        )
-        training_paths = {}
+        training_paths = {
+            table: self._working_dir / f"synthetics_train_{table}.csv"
+            for table in tables
+        }
 
-        for table_name in tables:
-            training_path = self._working_dir / f"synthetics_train_{table_name}.csv"
-            training_data[table_name].to_csv(training_path, index=False)
-            training_paths[table_name] = training_path
+        self._strategy.prepare_training_data(self.relational_data, training_paths)
 
         return training_paths
 
diff --git a/src/gretel_trainer/relational/strategies/ancestral.py b/src/gretel_trainer/relational/strategies/ancestral.py
index 76e39dbf..2c1c5ab9 100644
--- a/src/gretel_trainer/relational/strategies/ancestral.py
+++ b/src/gretel_trainer/relational/strategies/ancestral.py
@@ -33,19 +33,19 @@ def label_encode_keys(
         return common.label_encode_keys(rel_data, tables)
 
     def prepare_training_data(
-        self, rel_data: RelationalData, tables: list[str]
-    ) -> dict[str, pd.DataFrame]:
+        self, rel_data: RelationalData, table_paths: dict[str, Path]
+    ) -> dict[str, Path]:
         """
-        Returns tables with:
+        Writes tables' training data to provided paths.
+        Training data has:
         - all safe-for-seed ancestor fields added
         - columns in multigenerational format
         - all keys translated to contiguous integers
         - artificial min/max seed records added
         """
         all_tables = rel_data.list_all_tables()
-        omitted_tables = [t for t in all_tables if t not in tables]
+        omitted_tables = [t for t in all_tables if t not in table_paths]
         altered_tableset = {}
-        training_data = {}
 
         # Create a new table set identical to source data
         for table_name in all_tables:
@@ -62,16 +62,16 @@
         )
 
         # Collect all data in multigenerational format
-        for table_name in tables:
+        for table, path in table_paths.items():
             data = ancestry.get_table_data_with_ancestors(
                 rel_data=rel_data,
-                table=table_name,
+                table=table,
                 tableset=altered_tableset,
                 ancestral_seeding=True,
             )
-            training_data[table_name] = data
+            data.to_csv(path, index=False)
 
-        return training_data
+        return table_paths
 
     def tables_to_retrain(
         self, tables: list[str], rel_data: RelationalData
diff --git a/src/gretel_trainer/relational/strategies/independent.py b/src/gretel_trainer/relational/strategies/independent.py
index 8c771aa4..44bf1a28 100644
--- a/src/gretel_trainer/relational/strategies/independent.py
+++ b/src/gretel_trainer/relational/strategies/independent.py
@@ -34,25 +34,24 @@ def label_encode_keys(
         return common.label_encode_keys(rel_data, tables)
 
     def prepare_training_data(
-        self, rel_data: RelationalData, tables: list[str]
-    ) -> dict[str, pd.DataFrame]:
+        self, rel_data: RelationalData, table_paths: dict[str, Path]
+    ) -> dict[str, Path]:
         """
-        Returns source tables with primary and foreign keys removed
+        Writes tables' training data to provided paths.
+        Training data has primary and foreign key columns removed.
         """
-        training_data = {}
-
-        for table_name in tables:
+        for table, path in table_paths.items():
             columns_to_drop = []
-            columns_to_drop.extend(rel_data.get_primary_key(table_name))
-            for foreign_key in rel_data.get_foreign_keys(table_name):
+            columns_to_drop.extend(rel_data.get_primary_key(table))
+            for foreign_key in rel_data.get_foreign_keys(table):
                 columns_to_drop.extend(foreign_key.columns)
 
-            data = rel_data.get_table_data(table_name)
+            data = rel_data.get_table_data(table)
             data = data.drop(columns=columns_to_drop)
 
-            training_data[table_name] = data
+            data.to_csv(path, index=False)
 
-        return training_data
+        return table_paths
 
     def tables_to_retrain(
         self, tables: list[str], rel_data: RelationalData
diff --git a/tests/relational/test_ancestral_strategy.py b/tests/relational/test_ancestral_strategy.py
index 6f2e6192..c3cc9508 100644
--- a/tests/relational/test_ancestral_strategy.py
+++ b/tests/relational/test_ancestral_strategy.py
@@ -14,45 +14,56 @@
 from gretel_trainer.relational.table_evaluation import TableEvaluation
 
 
-def test_preparing_training_data_does_not_mutate_source_data(pets, art):
-    for rel_data in [pets, art]:
-        original_tables = {
-            table: rel_data.get_table_data(table).copy()
-            for table in rel_data.list_all_tables()
-        }
+def test_preparing_training_data_does_not_mutate_source_data(pets):
+    original_tables = {
+        table: pets.get_table_data(table).copy() for table in pets.list_all_tables()
+    }
 
-        strategy = AncestralStrategy()
-        strategy.prepare_training_data(rel_data, rel_data.list_all_tables())
+    strategy = AncestralStrategy()
 
-        for table in rel_data.list_all_tables():
-            pdtest.assert_frame_equal(
-                original_tables[table], rel_data.get_table_data(table)
-            )
+    with tempfile.NamedTemporaryFile() as pets_dest, tempfile.NamedTemporaryFile() as humans_dest:
+        strategy.prepare_training_data(
+            pets, {"pets": Path(pets_dest.name), "humans": Path(humans_dest.name)}
+        )
+
+    for table in pets.list_all_tables():
+        pdtest.assert_frame_equal(original_tables[table], pets.get_table_data(table))
 
 
 def test_prepare_training_data_subset_of_tables(pets):
     strategy = AncestralStrategy()
-    # We aren't synthesizing the "humans" table, so it is not in this list argument...
-    training_data = strategy.prepare_training_data(pets, ["pets"])
-    # ...nor do we create training data for it
-    assert set(training_data.keys()) == {"pets"}
+    with tempfile.NamedTemporaryFile() as pets_dest, tempfile.NamedTemporaryFile() as humans_dest:
+        # We aren't synthesizing the "humans" table, so it is not in this dict argument...
+        training_data = strategy.prepare_training_data(
+            pets, {"pets": Path(pets_dest.name)}
+        )
+
+        train_pets = pd.read_csv(training_data["pets"])
+
+        # ...nor do we create training data for it
+        assert not train_pets.empty
+        assert os.stat(humans_dest.name).st_size == 0
 
     # Since the humans table is omitted from synthetics, we leave the FK values alone; specifically:
     # - they are not label-encoded (which would effectively zero-index them)
     # - we do not add artificial min/max values
-    assert set(training_data["pets"]["self|human_id"].values) == {1, 2, 3, 4, 5}
+    assert set(train_pets["self|human_id"].values) == {1, 2, 3, 4, 5}
 
     # We do add the artificial max PK row, though, since this table is being synthesized
-    assert len(training_data["pets"]) == 6
+    assert len(train_pets) == 6
 
 
 def test_prepare_training_data_returns_multigenerational_data(pets):
     strategy = AncestralStrategy()
 
-    training_data = strategy.prepare_training_data(pets, pets.list_all_tables())
+    with tempfile.NamedTemporaryFile() as pets_dest, tempfile.NamedTemporaryFile() as humans_dest:
+        training_data = strategy.prepare_training_data(
+            pets, {"pets": Path(pets_dest.name), "humans": Path(humans_dest.name)}
+        )
+        train_pets = pd.read_csv(training_data["pets"])
 
     for expected_column in ["self|id", "self|name", "self.human_id|id"]:
-        assert expected_column in training_data["pets"]
+        assert expected_column in train_pets
 
 
 def test_prepare_training_data_drops_highly_unique_categorical_ancestor_fields(art):
@@ -77,10 +88,19 @@ def test_prepare_training_data_drops_highly_unique_categorical_ancestor_fields(a
     )
 
     strategy = AncestralStrategy()
-    training_data = strategy.prepare_training_data(art, art.list_all_tables())
+
+    with tempfile.NamedTemporaryFile() as artists_dest, tempfile.NamedTemporaryFile() as paintings_dest:
+        training_data = strategy.prepare_training_data(
+            art,
+            {
+                "artists": Path(artists_dest.name),
+                "paintings": Path(paintings_dest.name),
+            },
+        )
+        train_paintings = pd.read_csv(training_data["paintings"])
 
     # Does not contain `self.artist_id|name` because it is highly unique categorical
-    assert set(training_data["paintings"].columns) == {
+    assert set(train_paintings.columns) == {
         "self|id",
         "self|name",
         "self|artist_id",
@@ -116,10 +136,19 @@ def test_prepare_training_data_drops_highly_nan_ancestor_fields(art):
     )
 
     strategy = AncestralStrategy()
-    training_data = strategy.prepare_training_data(art, art.list_all_tables())
+
+    with tempfile.NamedTemporaryFile() as artists_dest, tempfile.NamedTemporaryFile() as paintings_dest:
+        training_data = strategy.prepare_training_data(
+            art,
+            {
+                "artists": Path(artists_dest.name),
+                "paintings": Path(paintings_dest.name),
+            },
+        )
+        train_paintings = pd.read_csv(training_data["paintings"])
 
     # Does not contain `self.artist_id|name` because it is highly NaN
-    assert set(training_data["paintings"].columns) == {
+    assert set(train_paintings.columns) == {
         "self|id",
         "self|name",
         "self|artist_id",
@@ -131,12 +160,22 @@ def test_prepare_training_data_translates_alphanumeric_keys_and_adds_min_max_rec
     art,
 ):
     strategy = AncestralStrategy()
-    training_data = strategy.prepare_training_data(art, art.list_all_tables())
+
+    with tempfile.NamedTemporaryFile() as artists_dest, tempfile.NamedTemporaryFile() as paintings_dest:
+        training_data = strategy.prepare_training_data(
+            art,
+            {
+                "artists": Path(artists_dest.name),
+                "paintings": Path(paintings_dest.name),
+            },
+        )
+        train_artists = pd.read_csv(training_data["artists"])
+        train_paintings = pd.read_csv(training_data["paintings"])
 
     # Artists, a parent table, should have 1 additional row
-    assert len(training_data["artists"]) == len(art.get_table_data("artists")) + 1
+    assert len(train_artists) == len(art.get_table_data("artists")) + 1
     # The last record has the artifical max PK
-    assert training_data["artists"]["self|id"].to_list() == [0, 1, 2, 3, 200]
+    assert train_artists["self|id"].to_list() == [0, 1, 2, 3, 200]
     # We do not assert on the value of "self|name" because the artificial max PK record is
     # randomly sampled from source and so the exact value is not deterministic
 
@@ -144,9 +183,9 @@ def test_prepare_training_data_translates_alphanumeric_keys_and_adds_min_max_rec
     # paintings, a child table, should have 3 additional rows:
    # - artificial max PK
     # - artificial min FKs
     # - artificial max FKs
-    assert len(training_data["paintings"]) == len(art.get_table_data("paintings")) + 3
+    assert len(train_paintings) == len(art.get_table_data("paintings")) + 3
 
-    last_three = training_data["paintings"].tail(3)
+    last_three = train_paintings.tail(3)
     last_two = last_three.tail(2)
     # PKs are max, +1, +2
@@ -158,7 +197,19 @@ def test_prepare_training_data_translates_alphanumeric_keys_and_adds_min_max_rec
 
 def test_prepare_training_data_with_composite_keys(tpch):
     strategy = AncestralStrategy()
-    training_data = strategy.prepare_training_data(tpch, tpch.list_all_tables())
+    with tempfile.NamedTemporaryFile() as supplier_dest, tempfile.NamedTemporaryFile() as part_dest, tempfile.NamedTemporaryFile() as partsupp_dest, tempfile.NamedTemporaryFile() as lineitem_dest:
+        training_data = strategy.prepare_training_data(
+            tpch,
+            {
+                "supplier": Path(supplier_dest.name),
+                "part": Path(part_dest.name),
+                "partsupp": Path(partsupp_dest.name),
+                "lineitem": Path(lineitem_dest.name),
+            },
+        )
+
+        train_partsupp = pd.read_csv(training_data["partsupp"])
+        train_lineitem = pd.read_csv(training_data["lineitem"])
 
     l_max = len(tpch.get_table_data("lineitem")) * 50
     ps_max = len(tpch.get_table_data("partsupp")) * 50
@@ -166,7 +217,6 @@ def test_prepare_training_data_with_composite_keys(tpch):
     s_max = len(tpch.get_table_data("supplier")) * 50
 
     # partsupp table, composite PK
-    train_partsupp = training_data["partsupp"]
     assert set(train_partsupp.columns) == {
         "self|ps_partkey",
         "self|ps_suppkey",
@@ -189,7 +239,6 @@ def test_prepare_training_data_with_composite_keys(tpch):
     )
 
     # lineitem table, composite FK to partsupp
-    train_lineitem = training_data["lineitem"]
     assert set(train_lineitem.columns) == {
         "self|l_partkey",
         "self|l_suppkey",
diff --git a/tests/relational/test_independent_strategy.py b/tests/relational/test_independent_strategy.py
index 6ff22449..662e1f32 100644
--- a/tests/relational/test_independent_strategy.py
+++ b/tests/relational/test_independent_strategy.py
@@ -1,4 +1,5 @@
 import json
+import os
 import tempfile
 from collections import defaultdict
 from pathlib import Path
@@ -11,36 +12,43 @@
 from gretel_trainer.relational.table_evaluation import TableEvaluation
 
 
-def test_preparing_training_data_does_not_mutate_source_data(pets, art):
-    for rel_data in [pets, art]:
-        original_tables = {
-            table: rel_data.get_table_data(table).copy()
-            for table in rel_data.list_all_tables()
-        }
+def test_preparing_training_data_does_not_mutate_source_data(pets):
+    original_tables = {
+        table: pets.get_table_data(table).copy() for table in pets.list_all_tables()
+    }
+
+    strategy = IndependentStrategy()
 
-        strategy = IndependentStrategy()
-        strategy.prepare_training_data(rel_data, rel_data.list_all_tables())
+    with tempfile.NamedTemporaryFile() as pets_dest, tempfile.NamedTemporaryFile() as humans_dest:
+        strategy.prepare_training_data(
+            pets, {"pets": Path(pets_dest.name), "humans": Path(humans_dest.name)}
+        )
 
-        for table in rel_data.list_all_tables():
-            pdtest.assert_frame_equal(
-                original_tables[table], rel_data.get_table_data(table)
-            )
+    for table in pets.list_all_tables():
+        pdtest.assert_frame_equal(original_tables[table], pets.get_table_data(table))
 
 
 def test_prepare_training_data_removes_primary_and_foreign_keys(pets):
     strategy = IndependentStrategy()
 
-    training_data = strategy.prepare_training_data(pets, pets.list_all_tables())
+    with tempfile.NamedTemporaryFile() as pets_dest, tempfile.NamedTemporaryFile() as humans_dest:
+        training_data = strategy.prepare_training_data(
+            pets, {"pets": Path(pets_dest.name), "humans": Path(humans_dest.name)}
+        )
+        train_pets = pd.read_csv(training_data["pets"])
 
-    assert set(training_data["pets"].columns) == {"name", "age"}
+    assert set(train_pets.columns) == {"name", "age"}
 
 
 def test_prepare_training_data_subset_of_tables(pets):
     strategy = IndependentStrategy()
 
-    training_data = strategy.prepare_training_data(pets, ["humans"])
-
-    assert set(training_data.keys()) == {"humans"}
+    with tempfile.NamedTemporaryFile() as pets_dest, tempfile.NamedTemporaryFile() as humans_dest:
+        training_data = strategy.prepare_training_data(
+            pets, {"humans": Path(humans_dest.name)}
+        )
+
+        assert not pd.read_csv(training_data["humans"]).empty
+        assert os.stat(pets_dest.name).st_size == 0
 
 
 def test_retraining_a_set_of_tables_only_retrains_those_tables(ecom):
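The contract change in this patch is worth spelling out: prepare_training_data now receives the destination paths up front and streams each prepared table straight to disk, so no dict of DataFrames ever accumulates in memory. A minimal sketch of the idea, assuming a toy stand-in for RelationalData (the ToyRelationalData class and the table contents below are hypothetical, invented for illustration; they are not the library's API):

    from pathlib import Path
    import tempfile

    import pandas as pd


    class ToyRelationalData:
        """Hypothetical stand-in for RelationalData, for illustration only."""

        def __init__(self, tables: dict[str, pd.DataFrame]):
            self._tables = tables

        def get_table_data(self, table: str) -> pd.DataFrame:
            return self._tables[table]


    def prepare_training_data(rel_data, table_paths: dict[str, Path]) -> dict[str, Path]:
        # Write each table as soon as it is prepared; nothing accumulates,
        # so peak memory stays around one table's worth of data.
        for table, path in table_paths.items():
            rel_data.get_table_data(table).to_csv(path, index=False)
        return table_paths


    rel_data = ToyRelationalData({"pets": pd.DataFrame({"name": ["Rex"], "age": [3]})})
    with tempfile.TemporaryDirectory() as tmp:
        paths = prepare_training_data(rel_data, {"pets": Path(tmp) / "pets.csv"})
        print(pd.read_csv(paths["pets"]))

The caller owning path allocation is what lets the tests above hand in NamedTemporaryFile destinations and assert on the bytes written, rather than on returned frames.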
""" for table, path in table_paths.items(): - columns_to_drop = [] - columns_to_drop.extend(rel_data.get_primary_key(table)) + columns_to_drop = set() + columns_to_drop.update(rel_data.get_primary_key(table)) for foreign_key in rel_data.get_foreign_keys(table): - columns_to_drop.extend(foreign_key.columns) + columns_to_drop.update(foreign_key.columns) - data = rel_data.get_table_data(table) - data = data.drop(columns=columns_to_drop) + all_columns = rel_data.get_table_columns(table) + use_columns = all_columns - columns_to_drop - data.to_csv(path, index=False) + rel_data.get_table_data(table, usecols=use_columns).to_csv( + path, index=False + ) return table_paths From 56c14011d71b8a60a1fa64013891495afc4a09ee Mon Sep 17 00:00:00 2001 From: Mike Knepper Date: Fri, 2 Jun 2023 16:12:51 -0500 Subject: [PATCH 3/3] Collapse private methods --- src/gretel_trainer/relational/multi_table.py | 26 +++++++++----------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/src/gretel_trainer/relational/multi_table.py b/src/gretel_trainer/relational/multi_table.py index 8bb14956..de9b703a 100644 --- a/src/gretel_trainer/relational/multi_table.py +++ b/src/gretel_trainer/relational/multi_table.py @@ -718,11 +718,12 @@ def _get_only_and_ignore( else: return (None, None) - def _prepare_training_data(self, tables: list[str]) -> dict[str, Path]: + def _train_synthetics_models(self, tables: list[str]) -> None: """ - Exports a copy of each table prepared for training by the configured strategy - to the working directory. Returns a dict with table names as keys and Paths - to the CSVs as values. + Uses the configured strategy to prepare training data sources for each table, + exported to the working directory. Creates a model for each table and submits + it for training. Upon completion, downloads the evaluation reports for each + table to the working directory. """ training_paths = { table: self._working_dir / f"synthetics_train_{table}.csv" @@ -731,10 +732,7 @@ def _prepare_training_data(self, tables: list[str]) -> dict[str, Path]: self._strategy.prepare_training_data(self.relational_data, training_paths) - return training_paths - - def _train_synthetics_models(self, training_data: dict[str, Path]) -> None: - for table_name, training_csv in training_data.items(): + for table_name, training_csv in training_paths.items(): synthetics_config = make_synthetics_config(table_name, self._model_config) model = self._project.create_model_obj( model_config=synthetics_config, data_source=str(training_csv) @@ -742,7 +740,7 @@ def _train_synthetics_models(self, training_data: dict[str, Path]) -> None: self._synthetics_train.models[table_name] = model archive_path = self._working_dir / "synthetics_training.tar.gz" - for table_name, csv_path in training_data.items(): + for table_name, csv_path in training_paths.items(): add_to_tar(archive_path, csv_path, csv_path.name) self._artifact_collection.upload_synthetics_training_archive( self._project, str(archive_path) @@ -778,8 +776,7 @@ def train(self) -> None: tables = self.relational_data.list_all_tables() self._synthetics_train = SyntheticsTrain() - training_data = self._prepare_training_data(tables) - self._train_synthetics_models(training_data) + self._train_synthetics_models(tables) def train_synthetics( self, @@ -809,8 +806,7 @@ def train_synthetics( # along the way informing the user of which required tables are missing). 
From 56c14011d71b8a60a1fa64013891495afc4a09ee Mon Sep 17 00:00:00 2001
From: Mike Knepper
Date: Fri, 2 Jun 2023 16:12:51 -0500
Subject: [PATCH 3/3] Collapse private methods

---
 src/gretel_trainer/relational/multi_table.py | 26 +++++++++-----------
 1 file changed, 11 insertions(+), 15 deletions(-)

diff --git a/src/gretel_trainer/relational/multi_table.py b/src/gretel_trainer/relational/multi_table.py
index 8bb14956..de9b703a 100644
--- a/src/gretel_trainer/relational/multi_table.py
+++ b/src/gretel_trainer/relational/multi_table.py
@@ -718,11 +718,12 @@ def _get_only_and_ignore(
         else:
             return (None, None)
 
-    def _prepare_training_data(self, tables: list[str]) -> dict[str, Path]:
+    def _train_synthetics_models(self, tables: list[str]) -> None:
         """
-        Exports a copy of each table prepared for training by the configured strategy
-        to the working directory. Returns a dict with table names as keys and Paths
-        to the CSVs as values.
+        Uses the configured strategy to prepare training data sources for each table,
+        exported to the working directory. Creates a model for each table and submits
+        it for training. Upon completion, downloads the evaluation reports for each
+        table to the working directory.
         """
         training_paths = {
             table: self._working_dir / f"synthetics_train_{table}.csv"
@@ -731,10 +732,7 @@ def _prepare_training_data(self, tables: list[str]) -> dict[str, Path]:
             for table in tables
         }
 
         self._strategy.prepare_training_data(self.relational_data, training_paths)
 
-        return training_paths
-
-    def _train_synthetics_models(self, training_data: dict[str, Path]) -> None:
-        for table_name, training_csv in training_data.items():
+        for table_name, training_csv in training_paths.items():
             synthetics_config = make_synthetics_config(table_name, self._model_config)
             model = self._project.create_model_obj(
                 model_config=synthetics_config, data_source=str(training_csv)
             )
             self._synthetics_train.models[table_name] = model
 
         archive_path = self._working_dir / "synthetics_training.tar.gz"
-        for table_name, csv_path in training_data.items():
+        for table_name, csv_path in training_paths.items():
             add_to_tar(archive_path, csv_path, csv_path.name)
         self._artifact_collection.upload_synthetics_training_archive(
             self._project, str(archive_path)
         )
@@ -778,8 +776,7 @@ def train(self) -> None:
         tables = self.relational_data.list_all_tables()
         self._synthetics_train = SyntheticsTrain()
-        training_data = self._prepare_training_data(tables)
-        self._train_synthetics_models(training_data)
+        self._train_synthetics_models(tables)
 
     def train_synthetics(
         self,
@@ -809,8 +806,7 @@ def train_synthetics(
         # along the way informing the user of which required tables are missing).
         self._strategy.validate_preserved_tables(omit_tables, self.relational_data)
 
-        training_data = self._prepare_training_data(include_tables)
-        self._train_synthetics_models(training_data)
+        self._train_synthetics_models(include_tables)
 
     def retrain_tables(self, tables: dict[str, pd.DataFrame]) -> None:
         """
@@ -830,8 +826,8 @@ def retrain_tables(self, tables: dict[str, pd.DataFrame]) -> None:
         for table in tables_to_retrain:
             with suppress(KeyError):
                 del self._synthetics_train.models[table]
-        training_data = self._prepare_training_data(tables_to_retrain)
-        self._train_synthetics_models(training_data)
+
+        self._train_synthetics_models(tables_to_retrain)
 
     def _upload_sources_to_project(self) -> None:
         archive_path = self._working_dir / "source_tables.tar.gz"
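Taken together, the three patches leave a single method owning the whole training pipeline: allocate a destination path per table, let the strategy write the CSVs, then create one model per prepared file. A rough sketch of that collapsed control flow, assuming stubs in place of the real strategy and model plumbing (StubStrategy and train_synthetics_models below are hypothetical names for illustration):

    from pathlib import Path
    import tempfile


    class StubStrategy:
        """Hypothetical stand-in for a prepare-training-data strategy."""

        def prepare_training_data(self, table_paths: dict[str, Path]) -> None:
            for path in table_paths.values():
                path.write_text("col\n1\n")  # pretend this is the prepared table


    def train_synthetics_models(strategy, working_dir: Path, tables: list[str]) -> None:
        # Single pass: allocate paths, have the strategy fill them, then
        # kick off one training job per prepared CSV.
        training_paths = {
            table: working_dir / f"synthetics_train_{table}.csv" for table in tables
        }
        strategy.prepare_training_data(training_paths)
        for table, csv_path in training_paths.items():
            print(f"submit training for {table!r} using {csv_path.name}")


    with tempfile.TemporaryDirectory() as tmp:
        train_synthetics_models(StubStrategy(), Path(tmp), ["humans", "pets"])

Collapsing the two private methods removes the intermediate dict handoff between them, so every public entry point (train, train_synthetics, retrain_tables) now passes a plain list of table names and the path bookkeeping stays internal.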