Write and drop (#117)
* Don't hang on to training DFs longer than necessary (see the write-and-drop sketch below)

* Only load the columns we need, instead of all followed by drop (see the usecols sketch below)

* Collapse private methods
mikeknep authored Jun 5, 2023
1 parent 19a08ea commit 7e462e2
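The core of the change: instead of the strategy returning in-memory DataFrames that the caller then writes out, the caller now hands the strategy a table-to-path mapping, and each table is written to disk as soon as it is prepared. A minimal, self-contained sketch of that pattern (simplified names; `load_table` here is a hypothetical stand-in for reading one source table, not the gretel_trainer API):

```python
from pathlib import Path

import pandas as pd

def load_table(table: str) -> pd.DataFrame:
    # Hypothetical stand-in for reading one source table.
    return pd.DataFrame({"id": [1, 2, 3], "name": [f"{table}_{i}" for i in range(3)]})

def prepare_training_data(tables: list[str], working_dir: Path) -> dict[str, Path]:
    # Allocate every destination path up front...
    training_paths = {
        table: working_dir / f"synthetics_train_{table}.csv" for table in tables
    }
    # ...then write each table immediately; the DataFrame goes out of scope
    # at the end of each iteration, so no table's frame outlives its own
    # preparation ("write and drop").
    for table, path in training_paths.items():
        load_table(table).to_csv(path, index=False)
    return training_paths
```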
Showing 5 changed files with 146 additions and 95 deletions.
39 changes: 16 additions & 23 deletions src/gretel_trainer/relational/multi_table.py
@@ -718,34 +718,29 @@ def _get_only_and_ignore(
else:
return (None, None)

def _prepare_training_data(self, tables: list[str]) -> dict[str, Path]:
def _train_synthetics_models(self, tables: list[str]) -> None:
"""
Exports a copy of each table prepared for training by the configured strategy
to the working directory. Returns a dict with table names as keys and Paths
to the CSVs as values.
Uses the configured strategy to prepare training data sources for each table,
exported to the working directory. Creates a model for each table and submits
it for training. Upon completion, downloads the evaluation reports for each
table to the working directory.
"""
training_data = self._strategy.prepare_training_data(
self.relational_data, tables
)
training_paths = {}

for table_name in tables:
training_path = self._working_dir / f"synthetics_train_{table_name}.csv"
training_data[table_name].to_csv(training_path, index=False)
training_paths[table_name] = training_path
training_paths = {
table: self._working_dir / f"synthetics_train_{table}.csv"
for table in tables
}

return training_paths
self._strategy.prepare_training_data(self.relational_data, training_paths)

def _train_synthetics_models(self, training_data: dict[str, Path]) -> None:
for table_name, training_csv in training_data.items():
for table_name, training_csv in training_paths.items():
synthetics_config = make_synthetics_config(table_name, self._model_config)
model = self._project.create_model_obj(
model_config=synthetics_config, data_source=str(training_csv)
)
self._synthetics_train.models[table_name] = model

archive_path = self._working_dir / "synthetics_training.tar.gz"
for table_name, csv_path in training_data.items():
for table_name, csv_path in training_paths.items():
add_to_tar(archive_path, csv_path, csv_path.name)
self._artifact_collection.upload_synthetics_training_archive(
self._project, str(archive_path)
@@ -781,8 +776,7 @@ def train(self) -> None:
tables = self.relational_data.list_all_tables()
self._synthetics_train = SyntheticsTrain()

training_data = self._prepare_training_data(tables)
self._train_synthetics_models(training_data)
self._train_synthetics_models(tables)

def train_synthetics(
self,
@@ -812,8 +806,7 @@ def train_synthetics(
# along the way informing the user of which required tables are missing).
self._strategy.validate_preserved_tables(omit_tables, self.relational_data)

training_data = self._prepare_training_data(include_tables)
self._train_synthetics_models(training_data)
self._train_synthetics_models(include_tables)

def retrain_tables(self, tables: dict[str, pd.DataFrame]) -> None:
"""
@@ -833,8 +826,8 @@ def retrain_tables(self, tables: dict[str, pd.DataFrame]) -> None:
for table in tables_to_retrain:
with suppress(KeyError):
del self._synthetics_train.models[table]
training_data = self._prepare_training_data(tables_to_retrain)
self._train_synthetics_models(training_data)

self._train_synthetics_models(tables_to_retrain)

def _upload_sources_to_project(self) -> None:
archive_path = self._working_dir / "source_tables.tar.gz"
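The archive step at the end of `_train_synthetics_models` bundles every training CSV into one gzipped tarball before upload. A rough stdlib-only equivalent of that loop (a sketch, not the library's actual `add_to_tar` helper, which appends file-by-file):

```python
import tarfile
from pathlib import Path

def archive_training_csvs(archive_path: Path, training_paths: dict[str, Path]) -> None:
    # One-shot equivalent of the per-file add_to_tar calls above: add each
    # training CSV to a single .tar.gz under its bare file name.
    with tarfile.open(archive_path, "w:gz") as tar:
        for csv_path in training_paths.values():
            tar.add(csv_path, arcname=csv_path.name)
```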
18 changes: 9 additions & 9 deletions src/gretel_trainer/relational/strategies/ancestral.py
@@ -33,19 +33,19 @@ def label_encode_keys(
return common.label_encode_keys(rel_data, tables)

def prepare_training_data(
self, rel_data: RelationalData, tables: list[str]
) -> dict[str, pd.DataFrame]:
self, rel_data: RelationalData, table_paths: dict[str, Path]
) -> dict[str, Path]:
"""
Returns tables with:
Writes tables' training data to provided paths.
Training data has:
- all safe-for-seed ancestor fields added
- columns in multigenerational format
- all keys translated to contiguous integers
- artificial min/max seed records added
"""
all_tables = rel_data.list_all_tables()
omitted_tables = [t for t in all_tables if t not in tables]
omitted_tables = [t for t in all_tables if t not in table_paths]
altered_tableset = {}
training_data = {}

# Create a new table set identical to source data
for table_name in all_tables:
@@ -62,16 +62,16 @@ def prepare_training_data(
)

# Collect all data in multigenerational format
for table_name in tables:
for table, path in table_paths.items():
data = ancestry.get_table_data_with_ancestors(
rel_data=rel_data,
table=table_name,
table=table,
tableset=altered_tableset,
ancestral_seeding=True,
)
training_data[table_name] = data
data.to_csv(path, index=False)

return training_data
return table_paths

def tables_to_retrain(
self, tables: list[str], rel_data: RelationalData
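The "multigenerational format" in the docstring above flattens each table together with its ancestors, prefixing columns by lineage. A toy illustration of that naming scheme (inferred from column names like `self|id` and `self.human_id|id` in the tests below; not the library's actual implementation, which also handles key encoding and seed records):

```python
import pandas as pd

def to_multigenerational(child: pd.DataFrame, parent: pd.DataFrame,
                         fk_column: str) -> pd.DataFrame:
    # Child columns get a "self|" prefix; columns joined in from the parent
    # through fk_column get a "self.<fk>|" prefix.
    child = child.rename(columns=lambda c: f"self|{c}")
    parent = parent.rename(columns=lambda c: f"self.{fk_column}|{c}")
    return child.merge(
        parent, left_on=f"self|{fk_column}", right_on=f"self.{fk_column}|id"
    )

# Usage: pets joined with their parent humans rows
humans = pd.DataFrame({"id": [1, 2], "name": ["a", "b"]})
pets = pd.DataFrame({"id": [10, 11], "name": ["x", "y"], "human_id": [1, 2]})
print(to_multigenerational(pets, humans, "human_id").columns.tolist())
# ['self|id', 'self|name', 'self|human_id', 'self.human_id|id', 'self.human_id|name']
```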
29 changes: 15 additions & 14 deletions src/gretel_trainer/relational/strategies/independent.py
@@ -34,25 +34,26 @@ def label_encode_keys(
return common.label_encode_keys(rel_data, tables)

def prepare_training_data(
self, rel_data: RelationalData, tables: list[str]
) -> dict[str, pd.DataFrame]:
self, rel_data: RelationalData, table_paths: dict[str, Path]
) -> dict[str, Path]:
"""
Returns source tables with primary and foreign keys removed
Writes tables' training data to provided paths.
Training data has primary and foreign key columns removed.
"""
training_data = {}

for table_name in tables:
columns_to_drop = []
columns_to_drop.extend(rel_data.get_primary_key(table_name))
for foreign_key in rel_data.get_foreign_keys(table_name):
columns_to_drop.extend(foreign_key.columns)
for table, path in table_paths.items():
columns_to_drop = set()
columns_to_drop.update(rel_data.get_primary_key(table))
for foreign_key in rel_data.get_foreign_keys(table):
columns_to_drop.update(foreign_key.columns)

data = rel_data.get_table_data(table_name)
data = data.drop(columns=columns_to_drop)
all_columns = rel_data.get_table_columns(table)
use_columns = all_columns - columns_to_drop

training_data[table_name] = data
rel_data.get_table_data(table, usecols=use_columns).to_csv(
path, index=False
)

return training_data
return table_paths

def tables_to_retrain(
self, tables: list[str], rel_data: RelationalData
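The `usecols` change above is the commit's "only load the columns we need" bullet: rather than reading the full table and dropping key columns afterwards, the key columns are excluded at parse time. With plain pandas (toy in-memory CSV, hypothetical column names) the difference looks like:

```python
import io

import pandas as pd

csv_text = "id,human_id,name,age\n1,1,Tabby,3\n2,1,Rex,5\n"
key_columns = {"id", "human_id"}  # primary + foreign key columns to exclude

# Before: read everything, then drop the key columns -- the full table is
# materialized in memory first.
before = pd.read_csv(io.StringIO(csv_text)).drop(columns=list(key_columns))

# After: discover the header cheaply, then ask the parser for only the
# columns we actually train on, so key columns are never materialized.
header = pd.read_csv(io.StringIO(csv_text), nrows=0).columns
wanted = [c for c in header if c not in key_columns]
after = pd.read_csv(io.StringIO(csv_text), usecols=wanted)

assert list(before.columns) == list(after.columns) == ["name", "age"]
```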
113 changes: 81 additions & 32 deletions tests/relational/test_ancestral_strategy.py
@@ -14,45 +14,56 @@
from gretel_trainer.relational.table_evaluation import TableEvaluation


def test_preparing_training_data_does_not_mutate_source_data(pets, art):
for rel_data in [pets, art]:
original_tables = {
table: rel_data.get_table_data(table).copy()
for table in rel_data.list_all_tables()
}
def test_preparing_training_data_does_not_mutate_source_data(pets):
original_tables = {
table: pets.get_table_data(table).copy() for table in pets.list_all_tables()
}

strategy = AncestralStrategy()
strategy.prepare_training_data(rel_data, rel_data.list_all_tables())
strategy = AncestralStrategy()

for table in rel_data.list_all_tables():
pdtest.assert_frame_equal(
original_tables[table], rel_data.get_table_data(table)
)
with tempfile.NamedTemporaryFile() as pets_dest, tempfile.NamedTemporaryFile() as humans_dest:
strategy.prepare_training_data(
pets, {"pets": Path(pets_dest.name), "humans": Path(humans_dest.name)}
)

for table in pets.list_all_tables():
pdtest.assert_frame_equal(original_tables[table], pets.get_table_data(table))


def test_prepare_training_data_subset_of_tables(pets):
strategy = AncestralStrategy()

# We aren't synthesizing the "humans" table, so it is not in this list argument...
training_data = strategy.prepare_training_data(pets, ["pets"])
# ...nor do we create training data for it
assert set(training_data.keys()) == {"pets"}
with tempfile.NamedTemporaryFile() as pets_dest, tempfile.NamedTemporaryFile() as humans_dest:
# We aren't synthesizing the "humans" table, so it is not in this list argument...
training_data = strategy.prepare_training_data(
pets, {"pets": Path(pets_dest.name)}
)

train_pets = pd.read_csv(training_data["pets"])

# ...nor do we create training data for it
assert not train_pets.empty
assert os.stat(humans_dest.name).st_size == 0

# Since the humans table is omitted from synthetics, we leave the FK values alone; specifically:
# - they are not label-encoded (which would effectively zero-index them)
# - we do not add artificial min/max values
assert set(training_data["pets"]["self|human_id"].values) == {1, 2, 3, 4, 5}
assert set(train_pets["self|human_id"].values) == {1, 2, 3, 4, 5}
# We do add the artificial max PK row, though, since this table is being synthesized
assert len(training_data["pets"]) == 6
assert len(train_pets) == 6


def test_prepare_training_data_returns_multigenerational_data(pets):
strategy = AncestralStrategy()

training_data = strategy.prepare_training_data(pets, pets.list_all_tables())
with tempfile.NamedTemporaryFile() as pets_dest, tempfile.NamedTemporaryFile() as humans_dest:
training_data = strategy.prepare_training_data(
pets, {"pets": Path(pets_dest.name), "humans": Path(humans_dest.name)}
)
train_pets = pd.read_csv(training_data["pets"])

for expected_column in ["self|id", "self|name", "self.human_id|id"]:
assert expected_column in training_data["pets"]
assert expected_column in train_pets


def test_prepare_training_data_drops_highly_unique_categorical_ancestor_fields(art):
@@ -77,10 +88,19 @@ def test_prepare_training_data_drops_highly_unique_categorical_ancestor_fields(a
)

strategy = AncestralStrategy()
training_data = strategy.prepare_training_data(art, art.list_all_tables())

with tempfile.NamedTemporaryFile() as artists_dest, tempfile.NamedTemporaryFile() as paintings_dest:
training_data = strategy.prepare_training_data(
art,
{
"artists": Path(artists_dest.name),
"paintings": Path(paintings_dest.name),
},
)
train_paintings = pd.read_csv(training_data["paintings"])

# Does not contain `self.artist_id|name` because it is highly unique categorical
assert set(training_data["paintings"].columns) == {
assert set(train_paintings.columns) == {
"self|id",
"self|name",
"self|artist_id",
Expand Down Expand Up @@ -116,10 +136,19 @@ def test_prepare_training_data_drops_highly_nan_ancestor_fields(art):
)

strategy = AncestralStrategy()
training_data = strategy.prepare_training_data(art, art.list_all_tables())

with tempfile.NamedTemporaryFile() as artists_dest, tempfile.NamedTemporaryFile() as paintings_dest:
training_data = strategy.prepare_training_data(
art,
{
"artists": Path(artists_dest.name),
"paintings": Path(paintings_dest.name),
},
)
train_paintings = pd.read_csv(training_data["paintings"])

# Does not contain `self.artist_id|name` because it is highly NaN
assert set(training_data["paintings"].columns) == {
assert set(train_paintings.columns) == {
"self|id",
"self|name",
"self|artist_id",
@@ -131,22 +160,32 @@ def test_prepare_training_data_translates_alphanumeric_keys_and_adds_min_max_rec
art,
):
strategy = AncestralStrategy()
training_data = strategy.prepare_training_data(art, art.list_all_tables())

with tempfile.NamedTemporaryFile() as artists_dest, tempfile.NamedTemporaryFile() as paintings_dest:
training_data = strategy.prepare_training_data(
art,
{
"artists": Path(artists_dest.name),
"paintings": Path(paintings_dest.name),
},
)
train_artists = pd.read_csv(training_data["artists"])
train_paintings = pd.read_csv(training_data["paintings"])

# Artists, a parent table, should have 1 additional row
assert len(training_data["artists"]) == len(art.get_table_data("artists")) + 1
assert len(train_artists) == len(art.get_table_data("artists")) + 1
# The last record has the artificial max PK
assert training_data["artists"]["self|id"].to_list() == [0, 1, 2, 3, 200]
assert train_artists["self|id"].to_list() == [0, 1, 2, 3, 200]
# We do not assert on the value of "self|name" because the artificial max PK record is
# randomly sampled from source and so the exact value is not deterministic

# Paintings, as a child table, should have 3 additional rows
# - artificial max PK
# - artificial min FKs
# - artificial max FKs
assert len(training_data["paintings"]) == len(art.get_table_data("paintings")) + 3
assert len(train_paintings) == len(art.get_table_data("paintings")) + 3

last_three = training_data["paintings"].tail(3)
last_three = train_paintings.tail(3)
last_two = last_three.tail(2)

# PKs are max, +1, +2
@@ -158,15 +197,26 @@

def test_prepare_training_data_with_composite_keys(tpch):
strategy = AncestralStrategy()
training_data = strategy.prepare_training_data(tpch, tpch.list_all_tables())
with tempfile.NamedTemporaryFile() as supplier_dest, tempfile.NamedTemporaryFile() as part_dest, tempfile.NamedTemporaryFile() as partsupp_dest, tempfile.NamedTemporaryFile() as lineitem_dest:
training_data = strategy.prepare_training_data(
tpch,
{
"supplier": Path(supplier_dest.name),
"part": Path(part_dest.name),
"partsupp": Path(partsupp_dest.name),
"lineitem": Path(lineitem_dest.name),
},
)

train_partsupp = pd.read_csv(training_data["partsupp"])
train_lineitem = pd.read_csv(training_data["lineitem"])

l_max = len(tpch.get_table_data("lineitem")) * 50
ps_max = len(tpch.get_table_data("partsupp")) * 50
p_max = len(tpch.get_table_data("part")) * 50
s_max = len(tpch.get_table_data("supplier")) * 50

# partsupp table, composite PK
train_partsupp = training_data["partsupp"]
assert set(train_partsupp.columns) == {
"self|ps_partkey",
"self|ps_suppkey",
Expand All @@ -189,7 +239,6 @@ def test_prepare_training_data_with_composite_keys(tpch):
)

# lineitem table, composite FK to partsupp
train_lineitem = training_data["lineitem"]
assert set(train_lineitem.columns) == {
"self|l_partkey",
"self|l_suppkey",
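All of the rewritten tests above share one scaffold: allocate a temporary file per table, hand the strategy a table-to-path mapping, and read the CSVs back to make assertions against what actually hit disk. Boiled down to a reusable helper (a sketch using a temporary directory instead of per-table NamedTemporaryFiles; `strategy` and `rel_data` are whatever the test fixtures provide):

```python
import tempfile
from pathlib import Path

import pandas as pd

def run_prepare(strategy, rel_data, tables: list[str]) -> dict[str, pd.DataFrame]:
    # Allocate one destination path per table, let the strategy write into
    # them, then read the CSVs back so assertions run against what was
    # actually written to disk.
    with tempfile.TemporaryDirectory() as tmp:
        paths = {t: Path(tmp) / f"{t}.csv" for t in tables}
        strategy.prepare_training_data(rel_data, paths)
        return {t: pd.read_csv(p) for t, p in paths.items()}
```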