diff --git a/src/gretel_trainer/relational/multi_table.py b/src/gretel_trainer/relational/multi_table.py index af54c5ca..2cf5436f 100644 --- a/src/gretel_trainer/relational/multi_table.py +++ b/src/gretel_trainer/relational/multi_table.py @@ -85,8 +85,7 @@ class MultiTable: Args: relational_data (RelationalData): Core data structure representing the source tables and their relationships. - strategy (str, optional): The strategy to use. Supports "independent" (default) and "ancestral". - gretel_model (str, optional): The underlying Gretel model to use. Default and acceptable models vary based on strategy. + strategy (str, optional): The strategy to use for synthetics. Supports "independent" (default) and "ancestral". project_display_name (str, optional): Display name in the console for a new Gretel project holding models and artifacts. Defaults to "multi-table". refresh_interval (int, optional): Frequency in seconds to poll Gretel Cloud for job statuses. Must be at least 30. Defaults to 60 (1m). backup (Backup, optional): Should not be supplied manually; instead use the `restore` classmethod. @@ -103,9 +102,20 @@ def __init__( backup: Optional[Backup] = None, ): self._strategy = _validate_strategy(strategy) - model_name, model_config = self._validate_gretel_model(gretel_model) - self._gretel_model = model_name - self._model_config = model_config + if gretel_model is not None: + logger.warning( + "The `gretel_model` argument is deprecated and will be removed in a future release. " + "Going forward you should provide a config to `train_synthetics`." + ) + model_name, model_config = self._validate_gretel_model(gretel_model) + self._gretel_model = model_name + self._model_config = model_config + else: + # Set these to the original default for backwards compatibility. + # When we completely remove the `gretel_model` init param, these attrs can be removed as well. + # We don't need to validate here because the default model (amplify) works with both strategies. + self._gretel_model = "amplify" + self._model_config = "synthetics/amplify" self._set_refresh_interval(refresh_interval) self.relational_data = relational_data self._artifact_collection = ArtifactCollection(hybrid=self._hybrid) @@ -851,17 +861,17 @@ def train_synthetics( # along the way informing the user of which required tables are missing). self._strategy.validate_preserved_tables(omit_tables, self.relational_data) - # Translate any JSON-source tables in table_config_overrides to invented tables - overrides = {} + # Translate any JSON-source tables in table_specific_configs to invented tables + all_table_specific_configs = {} for table, conf in (table_specific_configs or {}).items(): m_names = self.relational_data.get_modelable_table_names(table) if len(m_names) == 0: raise MultiTableException(f"Unrecognized table name: `{table}`") - overrides.update({m: conf for m in m_names}) + all_table_specific_configs.update({m: conf for m in m_names}) # Ensure compatibility between only/ignore and table_specific_configs omitted_tables_with_overrides_specified = [] - for table in overrides: + for table in all_table_specific_configs: if table in omit_tables: omitted_tables_with_overrides_specified.append(table) if len(omitted_tables_with_overrides_specified) > 0: @@ -870,21 +880,28 @@ def train_synthetics( f"{omitted_tables_with_overrides_specified}" ) - # Validate the provided config, or fall back to the one configured on self - # (gretel_model, which is validated against the strategy on MultiTable initialization) + # Validate the provided config + # Currently an optional argument for backwards compatibility; if None, fall back to the one configured + # on the MultiTable instance via the deprecated `gretel_model` parameter if config is not None: default_config_dict = self._validate_synthetics_config(config) else: + logger.warning( + "Calling `train_synthetics` without specifying a `config` is deprecated; " + "in a future release, this argument will be required. " + "For now, falling back to the model configured on the MultiTable instance " + "(which is also deprecated and scheduled for removal)." + ) default_config_dict = ingest(self._model_config) - # Validate any override configs - override_config_dicts = { + # Validate any table-specific configs + table_specific_config_dicts = { table: self._validate_synthetics_config(conf) - for table, conf in overrides.items() + for table, conf in all_table_specific_configs.items() } configs = { - table: override_config_dicts.get(table, default_config_dict) + table: table_specific_config_dicts.get(table, default_config_dict) for table in include_tables } @@ -1126,9 +1143,8 @@ def _attach_existing_reports(self, run_id: str, table: str) -> None: self._evaluations[table].individual_report_json = individual_report_json self._evaluations[table].cross_table_report_json = cross_table_report_json - def _validate_gretel_model(self, gretel_model: Optional[str]) -> tuple[str, str]: - gretel_model = (gretel_model or self._strategy.default_model).lower() - supported_models = self._strategy.supported_models + def _validate_gretel_model(self, gretel_model: str) -> tuple[str, str]: + supported_models = self._strategy.supported_gretel_models if gretel_model not in supported_models: msg = f"Invalid gretel model requested: {gretel_model}. The selected strategy supports: {supported_models}." logger.warning(msg) @@ -1155,7 +1171,7 @@ def _validate_synthetics_config(self, config: GretelModelConfig) -> dict[str, An if (model_key := get_model_key(config_dict)) is None: raise MultiTableException("Invalid config") else: - supported_models = self._strategy.supported_models + supported_models = self._strategy.supported_model_keys if model_key not in supported_models: raise MultiTableException( f"Invalid gretel model requested: {model_key}. " diff --git a/src/gretel_trainer/relational/strategies/ancestral.py b/src/gretel_trainer/relational/strategies/ancestral.py index 2c1c5ab9..7b143d06 100644 --- a/src/gretel_trainer/relational/strategies/ancestral.py +++ b/src/gretel_trainer/relational/strategies/ancestral.py @@ -19,12 +19,13 @@ class AncestralStrategy: def name(self) -> str: return "ancestral" + # TODO: remove when `gretel_model` param is removed @property - def default_model(self) -> str: - return "amplify" + def supported_gretel_models(self) -> list[str]: + return ["amplify"] @property - def supported_models(self) -> list[str]: + def supported_model_keys(self) -> list[str]: return ["amplify"] def label_encode_keys( diff --git a/src/gretel_trainer/relational/strategies/independent.py b/src/gretel_trainer/relational/strategies/independent.py index 0057acdb..4e8eb2f9 100644 --- a/src/gretel_trainer/relational/strategies/independent.py +++ b/src/gretel_trainer/relational/strategies/independent.py @@ -20,13 +20,14 @@ class IndependentStrategy: def name(self) -> str: return "independent" + # TODO: remove when `gretel_model` param is removed @property - def default_model(self) -> str: - return "amplify" + def supported_gretel_models(self) -> list[str]: + return ["amplify", "actgan", "lstm", "tabular-dp"] @property - def supported_models(self) -> list[str]: - return ["amplify", "actgan", "lstm", "tabular-dp"] + def supported_model_keys(self) -> list[str]: + return ["amplify", "actgan", "synthetics", "tabular_dp"] def label_encode_keys( self, rel_data: RelationalData, tables: dict[str, pd.DataFrame] diff --git a/tests/relational/test_train_synthetics.py b/tests/relational/test_train_synthetics.py index 4170d689..21174a32 100644 --- a/tests/relational/test_train_synthetics.py +++ b/tests/relational/test_train_synthetics.py @@ -85,7 +85,7 @@ def test_train_synthetics_custom_config_for_all_tables(ecom, tmpdir, project): def test_train_synthetics_custom_configs_per_table(ecom, tmpdir, project): mock_actgan_config = {"models": [{"actgan": {}}]} - mock_tabdp_config = {"models": [{"tabular-dp": {}}]} + mock_tabdp_config = {"models": [{"tabular_dp": {}}]} # We set amplify on the MultiTable instance... mt = MultiTable(ecom, project_display_name=tmpdir, gretel_model="amplify") @@ -109,7 +109,7 @@ def test_train_synthetics_custom_configs_per_table(ecom, tmpdir, project): def test_train_synthetics_table_config_and_mt_init_default(ecom, tmpdir, project): - mock_tabdp_config = {"models": [{"tabular-dp": {}}]} + mock_tabdp_config = {"models": [{"tabular_dp": {}}]} # We set amplify on the MultiTable instance... mt = MultiTable(ecom, project_display_name=tmpdir, gretel_model="amplify") @@ -140,6 +140,33 @@ def __eq__(self, other): return list(other["models"][0])[0] == "amplify" +def test_train_synthetics_validates_against_configured_strategy(pets, tmpdir): + # Independent strategy + mt_independent = MultiTable( + pets, project_display_name=tmpdir, strategy="independent" + ) + + mt_independent.train_synthetics(config="synthetics/tabular-lstm") + mt_independent.train_synthetics(config="synthetics/tabular-actgan") + mt_independent.train_synthetics(config="synthetics/amplify") + mt_independent.train_synthetics(config="synthetics/tabular-differential-privacy") + with pytest.raises(MultiTableException): + mt_independent.train_synthetics(config="synthetics/time-series") + + # Ancestral strategy + mt_ancestral = MultiTable(pets, project_display_name=tmpdir, strategy="ancestral") + + mt_ancestral.train_synthetics(config="synthetics/amplify") + with pytest.raises(MultiTableException): + mt_ancestral.train_synthetics(config="synthetics/tabular-lstm") + with pytest.raises(MultiTableException): + mt_ancestral.train_synthetics(config="synthetics/tabular-actgan") + with pytest.raises(MultiTableException): + mt_ancestral.train_synthetics(config="synthetics/tabular-differential-privacy") + with pytest.raises(MultiTableException): + mt_ancestral.train_synthetics(config="synthetics/time-series") + + def test_train_synthetics_errors(ecom, tmpdir): actgan_config = {"models": [{"actgan": {}}]} mt = MultiTable(ecom, project_display_name=tmpdir)