Skip to content

Commit

Permalink
Deprecation warnings + bugfix (#122)
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeknep authored Jun 14, 2023
1 parent 0888bc8 commit 76486c3
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 28 deletions.
54 changes: 35 additions & 19 deletions src/gretel_trainer/relational/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,7 @@ class MultiTable:
Args:
relational_data (RelationalData): Core data structure representing the source tables and their relationships.
strategy (str, optional): The strategy to use. Supports "independent" (default) and "ancestral".
gretel_model (str, optional): The underlying Gretel model to use. Default and acceptable models vary based on strategy.
strategy (str, optional): The strategy to use for synthetics. Supports "independent" (default) and "ancestral".
project_display_name (str, optional): Display name in the console for a new Gretel project holding models and artifacts. Defaults to "multi-table".
refresh_interval (int, optional): Frequency in seconds to poll Gretel Cloud for job statuses. Must be at least 30. Defaults to 60 (1m).
backup (Backup, optional): Should not be supplied manually; instead use the `restore` classmethod.
Expand All @@ -103,9 +102,20 @@ def __init__(
backup: Optional[Backup] = None,
):
self._strategy = _validate_strategy(strategy)
model_name, model_config = self._validate_gretel_model(gretel_model)
self._gretel_model = model_name
self._model_config = model_config
if gretel_model is not None:
logger.warning(
"The `gretel_model` argument is deprecated and will be removed in a future release. "
"Going forward you should provide a config to `train_synthetics`."
)
model_name, model_config = self._validate_gretel_model(gretel_model)
self._gretel_model = model_name
self._model_config = model_config
else:
# Set these to the original default for backwards compatibility.
# When we completely remove the `gretel_model` init param, these attrs can be removed as well.
# We don't need to validate here because the default model (amplify) works with both strategies.
self._gretel_model = "amplify"
self._model_config = "synthetics/amplify"
self._set_refresh_interval(refresh_interval)
self.relational_data = relational_data
self._artifact_collection = ArtifactCollection(hybrid=self._hybrid)
Expand Down Expand Up @@ -851,17 +861,17 @@ def train_synthetics(
# along the way informing the user of which required tables are missing).
self._strategy.validate_preserved_tables(omit_tables, self.relational_data)

# Translate any JSON-source tables in table_config_overrides to invented tables
overrides = {}
# Translate any JSON-source tables in table_specific_configs to invented tables
all_table_specific_configs = {}
for table, conf in (table_specific_configs or {}).items():
m_names = self.relational_data.get_modelable_table_names(table)
if len(m_names) == 0:
raise MultiTableException(f"Unrecognized table name: `{table}`")
overrides.update({m: conf for m in m_names})
all_table_specific_configs.update({m: conf for m in m_names})

# Ensure compatibility between only/ignore and table_specific_configs
omitted_tables_with_overrides_specified = []
for table in overrides:
for table in all_table_specific_configs:
if table in omit_tables:
omitted_tables_with_overrides_specified.append(table)
if len(omitted_tables_with_overrides_specified) > 0:
Expand All @@ -870,21 +880,28 @@ def train_synthetics(
f"{omitted_tables_with_overrides_specified}"
)

# Validate the provided config, or fall back to the one configured on self
# (gretel_model, which is validated against the strategy on MultiTable initialization)
# Validate the provided config
# Currently an optional argument for backwards compatibility; if None, fall back to the one configured
# on the MultiTable instance via the deprecated `gretel_model` parameter
if config is not None:
default_config_dict = self._validate_synthetics_config(config)
else:
logger.warning(
"Calling `train_synthetics` without specifying a `config` is deprecated; "
"in a future release, this argument will be required. "
"For now, falling back to the model configured on the MultiTable instance "
"(which is also deprecated and scheduled for removal)."
)
default_config_dict = ingest(self._model_config)

# Validate any override configs
override_config_dicts = {
# Validate any table-specific configs
table_specific_config_dicts = {
table: self._validate_synthetics_config(conf)
for table, conf in overrides.items()
for table, conf in all_table_specific_configs.items()
}

configs = {
table: override_config_dicts.get(table, default_config_dict)
table: table_specific_config_dicts.get(table, default_config_dict)
for table in include_tables
}

Expand Down Expand Up @@ -1126,9 +1143,8 @@ def _attach_existing_reports(self, run_id: str, table: str) -> None:
self._evaluations[table].individual_report_json = individual_report_json
self._evaluations[table].cross_table_report_json = cross_table_report_json

def _validate_gretel_model(self, gretel_model: Optional[str]) -> tuple[str, str]:
gretel_model = (gretel_model or self._strategy.default_model).lower()
supported_models = self._strategy.supported_models
def _validate_gretel_model(self, gretel_model: str) -> tuple[str, str]:
supported_models = self._strategy.supported_gretel_models
if gretel_model not in supported_models:
msg = f"Invalid gretel model requested: {gretel_model}. The selected strategy supports: {supported_models}."
logger.warning(msg)
Expand All @@ -1155,7 +1171,7 @@ def _validate_synthetics_config(self, config: GretelModelConfig) -> dict[str, An
if (model_key := get_model_key(config_dict)) is None:
raise MultiTableException("Invalid config")
else:
supported_models = self._strategy.supported_models
supported_models = self._strategy.supported_model_keys
if model_key not in supported_models:
raise MultiTableException(
f"Invalid gretel model requested: {model_key}. "
Expand Down
7 changes: 4 additions & 3 deletions src/gretel_trainer/relational/strategies/ancestral.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ class AncestralStrategy:
def name(self) -> str:
return "ancestral"

# TODO: remove when `gretel_model` param is removed
@property
def default_model(self) -> str:
return "amplify"
def supported_gretel_models(self) -> list[str]:
return ["amplify"]

@property
def supported_models(self) -> list[str]:
def supported_model_keys(self) -> list[str]:
return ["amplify"]

def label_encode_keys(
Expand Down
9 changes: 5 additions & 4 deletions src/gretel_trainer/relational/strategies/independent.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ class IndependentStrategy:
def name(self) -> str:
return "independent"

# TODO: remove when `gretel_model` param is removed
@property
def default_model(self) -> str:
return "amplify"
def supported_gretel_models(self) -> list[str]:
return ["amplify", "actgan", "lstm", "tabular-dp"]

@property
def supported_models(self) -> list[str]:
return ["amplify", "actgan", "lstm", "tabular-dp"]
def supported_model_keys(self) -> list[str]:
return ["amplify", "actgan", "synthetics", "tabular_dp"]

def label_encode_keys(
self, rel_data: RelationalData, tables: dict[str, pd.DataFrame]
Expand Down
31 changes: 29 additions & 2 deletions tests/relational/test_train_synthetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def test_train_synthetics_custom_config_for_all_tables(ecom, tmpdir, project):

def test_train_synthetics_custom_configs_per_table(ecom, tmpdir, project):
mock_actgan_config = {"models": [{"actgan": {}}]}
mock_tabdp_config = {"models": [{"tabular-dp": {}}]}
mock_tabdp_config = {"models": [{"tabular_dp": {}}]}

# We set amplify on the MultiTable instance...
mt = MultiTable(ecom, project_display_name=tmpdir, gretel_model="amplify")
Expand All @@ -109,7 +109,7 @@ def test_train_synthetics_custom_configs_per_table(ecom, tmpdir, project):


def test_train_synthetics_table_config_and_mt_init_default(ecom, tmpdir, project):
mock_tabdp_config = {"models": [{"tabular-dp": {}}]}
mock_tabdp_config = {"models": [{"tabular_dp": {}}]}

# We set amplify on the MultiTable instance...
mt = MultiTable(ecom, project_display_name=tmpdir, gretel_model="amplify")
Expand Down Expand Up @@ -140,6 +140,33 @@ def __eq__(self, other):
return list(other["models"][0])[0] == "amplify"


def test_train_synthetics_validates_against_configured_strategy(pets, tmpdir):
    """Exercise `train_synthetics` config validation under each strategy.

    For each strategy, the calls listed as accepted must complete without
    raising, and the calls listed as rejected must raise
    `MultiTableException`.
    """
    lstm = "synthetics/tabular-lstm"
    actgan = "synthetics/tabular-actgan"
    amplify = "synthetics/amplify"
    tabular_dp = "synthetics/tabular-differential-privacy"
    time_series = "synthetics/time-series"

    # Independent strategy: only the time-series config is expected to fail.
    mt_independent = MultiTable(
        pets, project_display_name=tmpdir, strategy="independent"
    )
    for accepted in (lstm, actgan, amplify, tabular_dp):
        mt_independent.train_synthetics(config=accepted)
    with pytest.raises(MultiTableException):
        mt_independent.train_synthetics(config=time_series)

    # Ancestral strategy: only amplify is expected to pass.
    mt_ancestral = MultiTable(pets, project_display_name=tmpdir, strategy="ancestral")
    mt_ancestral.train_synthetics(config=amplify)
    for rejected in (lstm, actgan, tabular_dp, time_series):
        with pytest.raises(MultiTableException):
            mt_ancestral.train_synthetics(config=rejected)


def test_train_synthetics_errors(ecom, tmpdir):
actgan_config = {"models": [{"actgan": {}}]}
mt = MultiTable(ecom, project_display_name=tmpdir)
Expand Down

0 comments on commit 76486c3

Please sign in to comment.