Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecation warnings + bugfix #122

Merged
merged 3 commits into from
Jun 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 35 additions & 19 deletions src/gretel_trainer/relational/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,7 @@ class MultiTable:

Args:
relational_data (RelationalData): Core data structure representing the source tables and their relationships.
strategy (str, optional): The strategy to use. Supports "independent" (default) and "ancestral".
gretel_model (str, optional): The underlying Gretel model to use. Default and acceptable models vary based on strategy.
strategy (str, optional): The strategy to use for synthetics. Supports "independent" (default) and "ancestral".
project_display_name (str, optional): Display name in the console for a new Gretel project holding models and artifacts. Defaults to "multi-table".
refresh_interval (int, optional): Frequency in seconds to poll Gretel Cloud for job statuses. Must be at least 30. Defaults to 60 (1m).
backup (Backup, optional): Should not be supplied manually; instead use the `restore` classmethod.
Expand All @@ -103,9 +102,20 @@ def __init__(
backup: Optional[Backup] = None,
):
self._strategy = _validate_strategy(strategy)
model_name, model_config = self._validate_gretel_model(gretel_model)
self._gretel_model = model_name
self._model_config = model_config
if gretel_model is not None:
logger.warning(
"The `gretel_model` argument is deprecated and will be removed in a future release. "
"Going forward you should provide a config to `train_synthetics`."
)
model_name, model_config = self._validate_gretel_model(gretel_model)
self._gretel_model = model_name
self._model_config = model_config
else:
# Set these to the original default for backwards compatibility.
# When we completely remove the `gretel_model` init param, these attrs can be removed as well.
# We don't need to validate here because the default model (amplify) works with both strategies.
self._gretel_model = "amplify"
self._model_config = "synthetics/amplify"
self._set_refresh_interval(refresh_interval)
self.relational_data = relational_data
self._artifact_collection = ArtifactCollection(hybrid=self._hybrid)
Expand Down Expand Up @@ -851,17 +861,17 @@ def train_synthetics(
# along the way informing the user of which required tables are missing).
self._strategy.validate_preserved_tables(omit_tables, self.relational_data)

# Translate any JSON-source tables in table_config_overrides to invented tables
overrides = {}
# Translate any JSON-source tables in table_specific_configs to invented tables
all_table_specific_configs = {}
for table, conf in (table_specific_configs or {}).items():
m_names = self.relational_data.get_modelable_table_names(table)
if len(m_names) == 0:
raise MultiTableException(f"Unrecognized table name: `{table}`")
overrides.update({m: conf for m in m_names})
all_table_specific_configs.update({m: conf for m in m_names})

# Ensure compatibility between only/ignore and table_specific_configs
omitted_tables_with_overrides_specified = []
for table in overrides:
for table in all_table_specific_configs:
if table in omit_tables:
omitted_tables_with_overrides_specified.append(table)
if len(omitted_tables_with_overrides_specified) > 0:
Expand All @@ -870,21 +880,28 @@ def train_synthetics(
f"{omitted_tables_with_overrides_specified}"
)

# Validate the provided config, or fall back to the one configured on self
# (gretel_model, which is validated against the strategy on MultiTable initialization)
# Validate the provided config
# Currently an optional argument for backwards compatibility; if None, fall back to the one configured
# on the MultiTable instance via the deprecated `gretel_model` parameter
if config is not None:
default_config_dict = self._validate_synthetics_config(config)
else:
logger.warning(
"Calling `train_synthetics` without specifying a `config` is deprecated; "
"in a future release, this argument will be required. "
"For now, falling back to the model configured on the MultiTable instance "
"(which is also deprecated and scheduled for removal)."
)
default_config_dict = ingest(self._model_config)

# Validate any override configs
override_config_dicts = {
# Validate any table-specific configs
table_specific_config_dicts = {
table: self._validate_synthetics_config(conf)
for table, conf in overrides.items()
for table, conf in all_table_specific_configs.items()
}

configs = {
table: override_config_dicts.get(table, default_config_dict)
table: table_specific_config_dicts.get(table, default_config_dict)
for table in include_tables
}

Expand Down Expand Up @@ -1126,9 +1143,8 @@ def _attach_existing_reports(self, run_id: str, table: str) -> None:
self._evaluations[table].individual_report_json = individual_report_json
self._evaluations[table].cross_table_report_json = cross_table_report_json

def _validate_gretel_model(self, gretel_model: Optional[str]) -> tuple[str, str]:
gretel_model = (gretel_model or self._strategy.default_model).lower()
supported_models = self._strategy.supported_models
def _validate_gretel_model(self, gretel_model: str) -> tuple[str, str]:
supported_models = self._strategy.supported_gretel_models
if gretel_model not in supported_models:
msg = f"Invalid gretel model requested: {gretel_model}. The selected strategy supports: {supported_models}."
logger.warning(msg)
Expand All @@ -1155,7 +1171,7 @@ def _validate_synthetics_config(self, config: GretelModelConfig) -> dict[str, An
if (model_key := get_model_key(config_dict)) is None:
raise MultiTableException("Invalid config")
else:
supported_models = self._strategy.supported_models
supported_models = self._strategy.supported_model_keys
if model_key not in supported_models:
raise MultiTableException(
f"Invalid gretel model requested: {model_key}. "
Expand Down
7 changes: 4 additions & 3 deletions src/gretel_trainer/relational/strategies/ancestral.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ class AncestralStrategy:
def name(self) -> str:
return "ancestral"

# TODO: remove when `gretel_model` param is removed
@property
def default_model(self) -> str:
return "amplify"
def supported_gretel_models(self) -> list[str]:
return ["amplify"]

@property
def supported_models(self) -> list[str]:
def supported_model_keys(self) -> list[str]:
return ["amplify"]

def label_encode_keys(
Expand Down
9 changes: 5 additions & 4 deletions src/gretel_trainer/relational/strategies/independent.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ class IndependentStrategy:
def name(self) -> str:
return "independent"

# TODO: remove when `gretel_model` param is removed
@property
def default_model(self) -> str:
return "amplify"
def supported_gretel_models(self) -> list[str]:
return ["amplify", "actgan", "lstm", "tabular-dp"]

@property
def supported_models(self) -> list[str]:
return ["amplify", "actgan", "lstm", "tabular-dp"]
def supported_model_keys(self) -> list[str]:
return ["amplify", "actgan", "synthetics", "tabular_dp"]

def label_encode_keys(
self, rel_data: RelationalData, tables: dict[str, pd.DataFrame]
Expand Down
31 changes: 29 additions & 2 deletions tests/relational/test_train_synthetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def test_train_synthetics_custom_config_for_all_tables(ecom, tmpdir, project):

def test_train_synthetics_custom_configs_per_table(ecom, tmpdir, project):
mock_actgan_config = {"models": [{"actgan": {}}]}
mock_tabdp_config = {"models": [{"tabular-dp": {}}]}
mock_tabdp_config = {"models": [{"tabular_dp": {}}]}

# We set amplify on the MultiTable instance...
mt = MultiTable(ecom, project_display_name=tmpdir, gretel_model="amplify")
Expand All @@ -109,7 +109,7 @@ def test_train_synthetics_custom_configs_per_table(ecom, tmpdir, project):


def test_train_synthetics_table_config_and_mt_init_default(ecom, tmpdir, project):
mock_tabdp_config = {"models": [{"tabular-dp": {}}]}
mock_tabdp_config = {"models": [{"tabular_dp": {}}]}

# We set amplify on the MultiTable instance...
mt = MultiTable(ecom, project_display_name=tmpdir, gretel_model="amplify")
Expand Down Expand Up @@ -140,6 +140,33 @@ def __eq__(self, other):
return list(other["models"][0])[0] == "amplify"


def test_train_synthetics_validates_against_configured_strategy(pets, tmpdir):
# Independent strategy
mt_independent = MultiTable(
pets, project_display_name=tmpdir, strategy="independent"
)

mt_independent.train_synthetics(config="synthetics/tabular-lstm")
mt_independent.train_synthetics(config="synthetics/tabular-actgan")
mt_independent.train_synthetics(config="synthetics/amplify")
mt_independent.train_synthetics(config="synthetics/tabular-differential-privacy")
with pytest.raises(MultiTableException):
mt_independent.train_synthetics(config="synthetics/time-series")

# Ancestral strategy
mt_ancestral = MultiTable(pets, project_display_name=tmpdir, strategy="ancestral")

mt_ancestral.train_synthetics(config="synthetics/amplify")
with pytest.raises(MultiTableException):
mt_ancestral.train_synthetics(config="synthetics/tabular-lstm")
with pytest.raises(MultiTableException):
mt_ancestral.train_synthetics(config="synthetics/tabular-actgan")
with pytest.raises(MultiTableException):
mt_ancestral.train_synthetics(config="synthetics/tabular-differential-privacy")
with pytest.raises(MultiTableException):
mt_ancestral.train_synthetics(config="synthetics/time-series")


def test_train_synthetics_errors(ecom, tmpdir):
actgan_config = {"models": [{"actgan": {}}]}
mt = MultiTable(ecom, project_display_name=tmpdir)
Expand Down