Skip to content

Commit

Permalink
Synthetics config is optional, with strategy-specific defaults
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 5f8d152ce1ce36748e4fccb1ee5d0ad75ba0b912
  • Loading branch information
mikeknep committed Sep 15, 2023
1 parent 8feb616 commit ed4ead7
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 7 deletions.
7 changes: 2 additions & 5 deletions src/gretel_trainer/relational/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,7 @@ def _train_synthetics_models(self, configs: dict[str, dict[str, Any]]) -> None:
def train_synthetics(
self,
*,
config: GretelModelConfig,
config: Optional[GretelModelConfig] = None,
table_specific_configs: Optional[dict[str, GretelModelConfig]] = None,
only: Optional[set[str]] = None,
ignore: Optional[set[str]] = None,
Expand All @@ -743,11 +743,8 @@ def train_synthetics(
Train synthetic data models for the tables in the tableset,
optionally scoped by either `only` or `ignore`.
"""
# This argument used to be hinted as Optional with a None default and would
# fall back to a removed attribute on the MultiTable instance. Out of an
# abundance of caution, verify the client isn't relying on that old behavior.
if config is None:
raise MultiTableException(f"Must provide a default model config.")
config = self._strategy.default_config

configs = assemble_configs(
self.relational_data, config, table_specific_configs, only, ignore
Expand Down
10 changes: 9 additions & 1 deletion src/gretel_trainer/relational/strategies/ancestral.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
import gretel_trainer.relational.strategies.common as common

from gretel_client.projects.models import Model
from gretel_trainer.relational.core import MultiTableException, RelationalData
from gretel_trainer.relational.core import (
GretelModelConfig,
MultiTableException,
RelationalData,
)
from gretel_trainer.relational.sdk_extras import ExtendedGretelSDK
from gretel_trainer.relational.table_evaluation import TableEvaluation

Expand All @@ -25,6 +29,10 @@ def name(self) -> str:
def supported_model_keys(self) -> list[str]:
return ["amplify"]

@property
def default_config(self) -> GretelModelConfig:
return "synthetics/amplify"

def label_encode_keys(
self, rel_data: RelationalData, tables: dict[str, pd.DataFrame]
) -> dict[str, pd.DataFrame]:
Expand Down
6 changes: 5 additions & 1 deletion src/gretel_trainer/relational/strategies/independent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import gretel_trainer.relational.strategies.common as common

from gretel_client.projects.models import Model
from gretel_trainer.relational.core import RelationalData
from gretel_trainer.relational.core import GretelModelConfig, RelationalData
from gretel_trainer.relational.sdk_extras import ExtendedGretelSDK
from gretel_trainer.relational.table_evaluation import TableEvaluation

Expand All @@ -26,6 +26,10 @@ def name(self) -> str:
def supported_model_keys(self) -> list[str]:
return ["amplify", "actgan", "synthetics", "tabular_dp"]

@property
def default_config(self) -> GretelModelConfig:
return "synthetics/tabular-actgan"

def label_encode_keys(
self, rel_data: RelationalData, tables: dict[str, pd.DataFrame]
) -> dict[str, pd.DataFrame]:
Expand Down
24 changes: 24 additions & 0 deletions tests/relational/test_train_synthetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,30 @@ def tmpdir(project):
yield tmpdir


class ModelConfigMatcher:
def __init__(self, model_key: str):
self.model_key = model_key

def __eq__(self, other):
return list(other["models"][0])[0] == self.model_key


def test_train_synthetics_strategy_specific_default_configs(pets, tmpdir, project):
mt = MultiTable(pets, strategy="independent", project_display_name=tmpdir)
mt.train_synthetics()
project.create_model_obj.assert_called_with(
model_config=ModelConfigMatcher("actgan"),
data_source=f"{tmpdir}/synthetics_train_pets.csv",
)

mt = MultiTable(pets, strategy="ancestral", project_display_name=tmpdir)
mt.train_synthetics()
project.create_model_obj.assert_called_with(
model_config=ModelConfigMatcher("amplify"),
data_source=f"{tmpdir}/synthetics_train_pets.csv",
)


def test_train_synthetics_defaults_to_training_all_tables(ecom, tmpdir):
mt = MultiTable(ecom, project_display_name=tmpdir)
mt.train_synthetics(config="synthetics/amplify")
Expand Down

0 comments on commit ed4ead7

Please sign in to comment.