diff --git a/src/gretel_trainer/relational/core.py b/src/gretel_trainer/relational/core.py
index 63c3a2dd..c3aacccf 100644
--- a/src/gretel_trainer/relational/core.py
+++ b/src/gretel_trainer/relational/core.py
@@ -807,18 +807,6 @@ def debug_summary(self) -> dict[str, Any]:
         }
 
 
-def skip_table(
-    table: str, only: Optional[set[str]], ignore: Optional[set[str]]
-) -> bool:
-    skip = False
-    if only is not None and table not in only:
-        skip = True
-    if ignore is not None and table in ignore:
-        skip = True
-
-    return skip
-
-
 def _ok_for_train_and_seed(col: str, df: pd.DataFrame) -> bool:
     if _is_highly_nan(col, df):
         return False
diff --git a/src/gretel_trainer/relational/model_config.py b/src/gretel_trainer/relational/model_config.py
index b43d3aee..b76b84bb 100644
--- a/src/gretel_trainer/relational/model_config.py
+++ b/src/gretel_trainer/relational/model_config.py
@@ -93,3 +93,90 @@ def _passthrough_policy(columns: list[str]) -> dict[str, Any]:
             }
         ],
     }
+
+
+def assemble_configs(
+    rel_data: RelationalData,
+    config: GretelModelConfig,
+    table_specific_configs: Optional[dict[str, GretelModelConfig]],
+    only: Optional[set[str]],
+    ignore: Optional[set[str]],
+) -> dict[str, Any]:
+    only, ignore = _expand_only_and_ignore(rel_data, only, ignore)
+
+    tables_in_scope = [
+        table
+        for table in rel_data.list_all_tables()
+        if not _skip_table(table, only, ignore)
+    ]
+
+    # Standardize type of all provided models
+    config_dict = ingest(config)
+    table_specific_config_dicts = {
+        table: ingest(conf) for table, conf in (table_specific_configs or {}).items()
+    }
+
+    # Translate any JSON-source tables in table_specific_configs to invented tables
+    all_table_specific_config_dicts = {}
+    for table, conf in table_specific_config_dicts.items():
+        m_names = rel_data.get_modelable_table_names(table)
+        if len(m_names) == 0:
+            raise MultiTableException(f"Unrecognized table name: `{table}`")
+        for m_name in m_names:
+            all_table_specific_config_dicts[m_name] = table_specific_config_dicts.get(
+                m_name, conf
+            )
+
+    # Ensure compatibility between only/ignore and table_specific_configs
+    omitted_tables_with_overrides_specified = []
+    for table in all_table_specific_config_dicts:
+        if _skip_table(table, only, ignore):
+            omitted_tables_with_overrides_specified.append(table)
+    if len(omitted_tables_with_overrides_specified) > 0:
+        raise MultiTableException(
+            f"Cannot provide configs for tables that have been omitted from synthetics training: "
+            f"{omitted_tables_with_overrides_specified}"
+        )
+
+    return {
+        table: all_table_specific_config_dicts.get(table, config_dict)
+        for table in tables_in_scope
+    }
+
+
+def _expand_only_and_ignore(
+    rel_data: RelationalData, only: Optional[set[str]], ignore: Optional[set[str]]
+) -> tuple[Optional[set[str]], Optional[set[str]]]:
+    """
+    Accepts the `only` and `ignore` parameter values as provided by the user and:
+    - ensures both are not set (must provide one or the other, or neither)
+    - translates any JSON-source tables to the invented tables
+    """
+    if only is not None and ignore is not None:
+        raise MultiTableException("Cannot specify both `only` and `ignore`.")
+
+    modelable_tables = set()
+    for table in only or ignore or {}:
+        m_names = rel_data.get_modelable_table_names(table)
+        if len(m_names) == 0:
+            raise MultiTableException(f"Unrecognized table name: `{table}`")
+        modelable_tables.update(m_names)
+
+    if only is None:
+        return (None, modelable_tables)
+    elif ignore is None:
+        return (modelable_tables, None)
+    else:
+        return (None, None)
+
+
+def _skip_table(
+    table: str, only: Optional[set[str]], ignore: Optional[set[str]]
+) -> bool:
+    skip = False
+    if only is not None and table not in only:
+        skip = True
+    if ignore is not None and table in ignore:
+        skip = True
+
+    return skip
diff --git a/src/gretel_trainer/relational/multi_table.py b/src/gretel_trainer/relational/multi_table.py
index 4b47a4a9..326d049a 100644
--- a/src/gretel_trainer/relational/multi_table.py
+++ b/src/gretel_trainer/relational/multi_table.py
@@ -44,14 +44,13 @@
     MultiTableException,
     RelationalData,
     Scope,
-    skip_table,
     UserFriendlyDataT,
 )
 from gretel_trainer.relational.json import InventedTableMetadata, ProducerMetadata
 from gretel_trainer.relational.log import silent_logs
 from gretel_trainer.relational.model_config import (
+    assemble_configs,
     get_model_key,
-    ingest,
     make_classify_config,
     make_evaluate_config,
     make_synthetics_config,
@@ -569,17 +568,13 @@ def train_transforms(
         self,
         config: GretelModelConfig,
         *,
+        table_specific_configs: Optional[dict[str, GretelModelConfig]] = None,
         only: Optional[set[str]] = None,
         ignore: Optional[set[str]] = None,
     ) -> None:
-        only, ignore = self._get_only_and_ignore(only, ignore)
-
-        configs = {
-            table: config
-            for table in self.relational_data.list_all_tables()
-            if not skip_table(table, only, ignore)
-        }
-
+        configs = assemble_configs(
+            self.relational_data, config, table_specific_configs, only, ignore
+        )
         self._setup_transforms_train_state(configs)
         task = TransformsTrainTask(
             transforms_train=self._transforms_train,
@@ -689,31 +684,6 @@ def run_transforms(
         self._backup()
         self.transform_output_tables = reshaped_tables
 
-    def _get_only_and_ignore(
-        self, only: Optional[set[str]], ignore: Optional[set[str]]
-    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
-        """
-        Accepts the `only` and `ignore` parameter values as provided by the user and:
-        - ensures both are not set (must provide one or the other, or neither)
-        - translates any JSON-source tables to the invented tables
-        """
-        if only is not None and ignore is not None:
-            raise MultiTableException("Cannot specify both `only` and `ignore`.")
-
-        modelable_tables = set()
-        for table in only or ignore or {}:
-            m_names = self.relational_data.get_modelable_table_names(table)
-            if len(m_names) == 0:
-                raise MultiTableException(f"Unrecognized table name: `{table}`")
-            modelable_tables.update(m_names)
-
-        if only is None:
-            return (None, modelable_tables)
-        elif ignore is None:
-            return (modelable_tables, None)
-        else:
-            return (None, None)
-
     def _train_synthetics_models(self, configs: dict[str, dict[str, Any]]) -> None:
         """
         Uses the configured strategy to prepare training data sources for each table,
@@ -773,59 +743,27 @@ def train_synthetics(
         Train synthetic data models for the tables in the tableset,
         optionally scoped by either `only` or `ignore`.
         """
-        only, ignore = self._get_only_and_ignore(only, ignore)
-
-        include_tables: list[str] = []
-        omit_tables: list[str] = []
-        for table in self.relational_data.list_all_tables():
-            if skip_table(table, only, ignore):
-                omit_tables.append(table)
-            else:
-                include_tables.append(table)
-
-        # TODO: Ancestral strategy requires that for each table omitted from synthetics ("preserved"),
-        # all its ancestors must also be omitted. In the future, it'd be nice to either find a way to
-        # eliminate this requirement completely, or (less ideal) allow incremental training of tables,
-        # e.g. train a few in one "batch", then a few more before generating (perhaps with some logs
-        # along the way informing the user of which required tables are missing).
-        self._strategy.validate_preserved_tables(omit_tables, self.relational_data)
-
-        # Translate any JSON-source tables in table_specific_configs to invented tables
-        all_table_specific_configs = {}
-        for table, conf in (table_specific_configs or {}).items():
-            m_names = self.relational_data.get_modelable_table_names(table)
-            if len(m_names) == 0:
-                raise MultiTableException(f"Unrecognized table name: `{table}`")
-            all_table_specific_configs.update({m: conf for m in m_names})
-
-        # Ensure compatibility between only/ignore and table_specific_configs
-        omitted_tables_with_overrides_specified = []
-        for table in all_table_specific_configs:
-            if table in omit_tables:
-                omitted_tables_with_overrides_specified.append(table)
-        if len(omitted_tables_with_overrides_specified) > 0:
-            raise MultiTableException(
-                f"Cannot provide configs for tables that have been omitted from synthetics training: "
-                f"{omitted_tables_with_overrides_specified}"
-            )
-
         # This argument used to be hinted as Optional with a None default and would
         # fall back to a removed attribute on the MultiTable instance. Out of an
         # abundance of caution, verify the client isn't relying on that old behavior.
         if config is None:
             raise MultiTableException(f"Must provide a default model config.")
 
-        # Validate the provided configs
-        default_config_dict = self._validate_synthetics_config(config)
-        table_specific_config_dicts = {
-            table: self._validate_synthetics_config(conf)
-            for table, conf in all_table_specific_configs.items()
-        }
+        configs = assemble_configs(
+            self.relational_data, config, table_specific_configs, only, ignore
+        )
 
-        configs = {
-            table: table_specific_config_dicts.get(table, default_config_dict)
-            for table in include_tables
-        }
+        # validate table scope (preserved tables) against the strategy
+        excluded_tables = [
+            table
+            for table in self.relational_data.list_all_tables()
+            if table not in configs
+        ]
+        self._strategy.validate_preserved_tables(excluded_tables, self.relational_data)
+
+        # validate all provided model configs are supported by the strategy
+        for conf in configs.values():
+            self._validate_synthetics_config(conf)
 
         self._train_synthetics_models(configs)
 
@@ -1070,15 +1008,11 @@ def _attach_existing_reports(self, run_id: str, table: str) -> None:
         self._evaluations[table].individual_report_json = individual_report_json
         self._evaluations[table].cross_table_report_json = cross_table_report_json
 
-    def _validate_synthetics_config(self, config: GretelModelConfig) -> dict[str, Any]:
+    def _validate_synthetics_config(self, config_dict: dict[str, Any]) -> None:
         """
-        Validates that the provided config:
-        - has the general shape of a Gretel model config (or can be read into one, e.g. blueprints)
-        - is supported by the configured synthetics strategy
-
-        Returns the parsed config as read by read_model_config.
+        Validates that the provided config (in dict form)
+        is supported by the configured synthetics strategy
         """
-        config_dict = ingest(config)
         if (model_key := get_model_key(config_dict)) is None:
             raise MultiTableException("Invalid config")
         else:
@@ -1089,8 +1023,6 @@ def _validate_synthetics_config(self, config: GretelModelConfig) -> dict[str, An
                 f"The selected strategy supports: {supported_models}."
             )
 
-        return config_dict
-
 
 def _validate_strategy(strategy: str) -> Union[IndependentStrategy, AncestralStrategy]:
     strategy = strategy.lower()
diff --git a/tests/relational/conftest.py b/tests/relational/conftest.py
index 4a06cd0c..aeb1b4dc 100644
--- a/tests/relational/conftest.py
+++ b/tests/relational/conftest.py
@@ -43,6 +43,17 @@ def _get_invented_table_suffix(make_suffix_execution_number: int):
     return f"invented_{str(make_suffix_execution_number)}"
 
 
+@pytest.fixture
+def invented_tables(get_invented_table_suffix) -> dict[str, str]:
+    return {
+        "purchases_root": f"purchases_{get_invented_table_suffix(1)}",
+        "purchases_data_years": f"purchases_{get_invented_table_suffix(2)}",
+        "bball_root": f"bball_{get_invented_table_suffix(1)}",
+        "bball_suspensions": f"bball_{get_invented_table_suffix(2)}",
+        "bball_teams": f"bball_{get_invented_table_suffix(3)}",
+    }
+
+
 @pytest.fixture()
 def project():
     with patch(
diff --git a/tests/relational/test_model_config.py b/tests/relational/test_model_config.py
index 80018162..3fbb3c8e 100644
--- a/tests/relational/test_model_config.py
+++ b/tests/relational/test_model_config.py
@@ -1,5 +1,9 @@
+import pytest
+
 from gretel_client.projects.models import read_model_config
+from gretel_trainer.relational.core import MultiTableException
 from gretel_trainer.relational.model_config import (
+    assemble_configs,
     get_model_key,
     make_evaluate_config,
     make_synthetics_config,
@@ -66,3 +70,154 @@ def get_policies(config):
             }
         ],
     }
+
+
+_ACTGAN_CONFIG = {"models": [{"actgan": {}}]}
+_LSTM_CONFIG = {"models": [{"synthetics": {}}]}
+_TABULAR_DP_CONFIG = {"models": [{"tabular_dp": {}}]}
+
+
+def test_assemble_configs(ecom):
+    # Apply a config to all tables
+    configs = assemble_configs(
+        rel_data=ecom,
+        config=_ACTGAN_CONFIG,
+        table_specific_configs=None,
+        only=None,
+        ignore=None,
+    )
+    assert len(configs) == len(ecom.list_all_tables())
+    assert all([config == _ACTGAN_CONFIG for config in configs.values()])
+
+    # Limit scope to a subset of tables
+    configs = assemble_configs(
+        rel_data=ecom,
+        config=_ACTGAN_CONFIG,
+        only={"events", "users"},
+        table_specific_configs=None,
+        ignore=None,
+    )
+    assert len(configs) == 2
+
+    # Exclude a table
+    configs = assemble_configs(
+        rel_data=ecom,
+        config=_ACTGAN_CONFIG,
+        ignore={"events"},
+        table_specific_configs=None,
+        only=None,
+    )
+    assert len(configs) == len(ecom.list_all_tables()) - 1
+
+    # Cannot specify both only and ignore
+    with pytest.raises(MultiTableException):
+        assemble_configs(
+            rel_data=ecom,
+            config=_ACTGAN_CONFIG,
+            only={"users"},
+            ignore={"events"},
+            table_specific_configs=None,
+        )
+
+    # Provide table-specific configs
+    configs = assemble_configs(
+        rel_data=ecom,
+        config=_ACTGAN_CONFIG,
+        table_specific_configs={"events": _LSTM_CONFIG},
+        only=None,
+        ignore=None,
+    )
+    assert configs["events"] == _LSTM_CONFIG
+    assert all(
+        [
+            config == _ACTGAN_CONFIG
+            for table, config in configs.items()
+            if table != "events"
+        ]
+    )
+
+    # Ensure no conflicts between table-specific configs and scope
+    with pytest.raises(MultiTableException):
+        assemble_configs(
+            rel_data=ecom,
+            config=_ACTGAN_CONFIG,
+            table_specific_configs={"events": _LSTM_CONFIG},
+            ignore={"events"},
+            only=None,
+        )
+    with pytest.raises(MultiTableException):
+        assemble_configs(
+            rel_data=ecom,
+            config=_ACTGAN_CONFIG,
+            table_specific_configs={"events": _LSTM_CONFIG},
+            only={"users"},
+            ignore=None,
+        )
+
+
+def test_assemble_configs_json(documents, invented_tables):
+    # If table_specific_configs includes a producer table, we apply it to all invented tables
+    configs = assemble_configs(
+        rel_data=documents,
+        config=_ACTGAN_CONFIG,
+        table_specific_configs={"purchases": _LSTM_CONFIG},
+        only=None,
+        ignore=None,
+    )
+    assert configs == {
+        "users": _ACTGAN_CONFIG,
+        "payments": _ACTGAN_CONFIG,
+        invented_tables["purchases_root"]: _LSTM_CONFIG,
+        invented_tables["purchases_data_years"]: _LSTM_CONFIG,
+    }
+
+    # If table_specific_configs includes a producer table AND an invented table,
+    # the more specific config takes precedence.
+    configs = assemble_configs(
+        rel_data=documents,
+        config=_ACTGAN_CONFIG,
+        table_specific_configs={
+            "purchases": _LSTM_CONFIG,
+            invented_tables["purchases_data_years"]: _TABULAR_DP_CONFIG,
+        },
+        only=None,
+        ignore=None,
+    )
+    assert configs == {
+        "users": _ACTGAN_CONFIG,
+        "payments": _ACTGAN_CONFIG,
+        invented_tables["purchases_root"]: _LSTM_CONFIG,
+        invented_tables["purchases_data_years"]: _TABULAR_DP_CONFIG,
+    }
+
+    # Ensure no conflicts between (invented) table-specific configs and scope
+    with pytest.raises(MultiTableException):
+        assemble_configs(
+            rel_data=documents,
+            config=_ACTGAN_CONFIG,
+            table_specific_configs={
+                "purchases": _LSTM_CONFIG,
+            },
+            ignore={"purchases"},
+            only=None,
+        )
+    with pytest.raises(MultiTableException):
+        assemble_configs(
+            rel_data=documents,
+            config=_ACTGAN_CONFIG,
+            table_specific_configs={
+                "purchases": _LSTM_CONFIG,
+            },
+            ignore={invented_tables["purchases_root"]},
+            only=None,
+        )
+    with pytest.raises(MultiTableException):
+        assemble_configs(
+            rel_data=documents,
+            config=_ACTGAN_CONFIG,
+            table_specific_configs={
+                invented_tables["purchases_root"]: _LSTM_CONFIG,
+            },
+            ignore={"purchases"},
+            only=None,
+        )
diff --git a/tests/relational/test_relational_data_with_json.py b/tests/relational/test_relational_data_with_json.py
index 83cfb1ff..491ff86d 100644
--- a/tests/relational/test_relational_data_with_json.py
+++ b/tests/relational/test_relational_data_with_json.py
@@ -10,17 +10,6 @@
 from gretel_trainer.relational.json import generate_unique_table_name, get_json_columns
 
 
-@pytest.fixture
-def invented_tables(get_invented_table_suffix) -> dict[str, str]:
-    return {
-        "purchases_root": f"purchases_{get_invented_table_suffix(1)}",
-        "purchases_data_years": f"purchases_{get_invented_table_suffix(2)}",
-        "bball_root": f"bball_{get_invented_table_suffix(1)}",
-        "bball_suspensions": f"bball_{get_invented_table_suffix(2)}",
-        "bball_teams": f"bball_{get_invented_table_suffix(3)}",
-    }
-
-
 @pytest.fixture
 def bball(tmpdir):
     bball_jsonl = """