Skip to content

Commit

Permalink
Support table_specific_configs on train_transforms
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 6aea00256448962abaa5d7c74d8fc86f4f23fd78
  • Loading branch information
mikeknep committed Sep 13, 2023
1 parent b162985 commit 8feb616
Show file tree
Hide file tree
Showing 6 changed files with 275 additions and 113 deletions.
12 changes: 0 additions & 12 deletions src/gretel_trainer/relational/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,18 +807,6 @@ def debug_summary(self) -> dict[str, Any]:
}


def skip_table(
table: str, only: Optional[set[str]], ignore: Optional[set[str]]
) -> bool:
skip = False
if only is not None and table not in only:
skip = True
if ignore is not None and table in ignore:
skip = True

return skip


def _ok_for_train_and_seed(col: str, df: pd.DataFrame) -> bool:
if _is_highly_nan(col, df):
return False
Expand Down
87 changes: 87 additions & 0 deletions src/gretel_trainer/relational/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,90 @@ def _passthrough_policy(columns: list[str]) -> dict[str, Any]:
}
],
}


def assemble_configs(
rel_data: RelationalData,
config: GretelModelConfig,
table_specific_configs: Optional[dict[str, GretelModelConfig]],
only: Optional[set[str]],
ignore: Optional[set[str]],
) -> dict[str, Any]:
only, ignore = _expand_only_and_ignore(rel_data, only, ignore)

tables_in_scope = [
table
for table in rel_data.list_all_tables()
if not _skip_table(table, only, ignore)
]

# Standardize type of all provided models
config_dict = ingest(config)
table_specific_config_dicts = {
table: ingest(conf) for table, conf in (table_specific_configs or {}).items()
}

# Translate any JSON-source tables in table_specific_configs to invented tables
all_table_specific_config_dicts = {}
for table, conf in table_specific_config_dicts.items():
m_names = rel_data.get_modelable_table_names(table)
if len(m_names) == 0:
raise MultiTableException(f"Unrecognized table name: `{table}`")
for m_name in m_names:
all_table_specific_config_dicts[m_name] = table_specific_config_dicts.get(
m_name, conf
)

# Ensure compatibility between only/ignore and table_specific_configs
omitted_tables_with_overrides_specified = []
for table in all_table_specific_config_dicts:
if _skip_table(table, only, ignore):
omitted_tables_with_overrides_specified.append(table)
if len(omitted_tables_with_overrides_specified) > 0:
raise MultiTableException(
f"Cannot provide configs for tables that have been omitted from synthetics training: "
f"{omitted_tables_with_overrides_specified}"
)

return {
table: all_table_specific_config_dicts.get(table, config_dict)
for table in tables_in_scope
}


def _expand_only_and_ignore(
rel_data: RelationalData, only: Optional[set[str]], ignore: Optional[set[str]]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
"""
Accepts the `only` and `ignore` parameter values as provided by the user and:
- ensures both are not set (must provide one or the other, or neither)
- translates any JSON-source tables to the invented tables
"""
if only is not None and ignore is not None:
raise MultiTableException("Cannot specify both `only` and `ignore`.")

modelable_tables = set()
for table in only or ignore or {}:
m_names = rel_data.get_modelable_table_names(table)
if len(m_names) == 0:
raise MultiTableException(f"Unrecognized table name: `{table}`")
modelable_tables.update(m_names)

if only is None:
return (None, modelable_tables)
elif ignore is None:
return (modelable_tables, None)
else:
return (None, None)


def _skip_table(
table: str, only: Optional[set[str]], ignore: Optional[set[str]]
) -> bool:
skip = False
if only is not None and table not in only:
skip = True
if ignore is not None and table in ignore:
skip = True

return skip
112 changes: 22 additions & 90 deletions src/gretel_trainer/relational/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,13 @@
MultiTableException,
RelationalData,
Scope,
skip_table,
UserFriendlyDataT,
)
from gretel_trainer.relational.json import InventedTableMetadata, ProducerMetadata
from gretel_trainer.relational.log import silent_logs
from gretel_trainer.relational.model_config import (
assemble_configs,
get_model_key,
ingest,
make_classify_config,
make_evaluate_config,
make_synthetics_config,
Expand Down Expand Up @@ -569,17 +568,13 @@ def train_transforms(
self,
config: GretelModelConfig,
*,
table_specific_configs: Optional[dict[str, GretelModelConfig]] = None,
only: Optional[set[str]] = None,
ignore: Optional[set[str]] = None,
) -> None:
only, ignore = self._get_only_and_ignore(only, ignore)

configs = {
table: config
for table in self.relational_data.list_all_tables()
if not skip_table(table, only, ignore)
}

configs = assemble_configs(
self.relational_data, config, table_specific_configs, only, ignore
)
self._setup_transforms_train_state(configs)
task = TransformsTrainTask(
transforms_train=self._transforms_train,
Expand Down Expand Up @@ -689,31 +684,6 @@ def run_transforms(
self._backup()
self.transform_output_tables = reshaped_tables

def _get_only_and_ignore(
self, only: Optional[set[str]], ignore: Optional[set[str]]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
"""
Accepts the `only` and `ignore` parameter values as provided by the user and:
- ensures both are not set (must provide one or the other, or neither)
- translates any JSON-source tables to the invented tables
"""
if only is not None and ignore is not None:
raise MultiTableException("Cannot specify both `only` and `ignore`.")

modelable_tables = set()
for table in only or ignore or {}:
m_names = self.relational_data.get_modelable_table_names(table)
if len(m_names) == 0:
raise MultiTableException(f"Unrecognized table name: `{table}`")
modelable_tables.update(m_names)

if only is None:
return (None, modelable_tables)
elif ignore is None:
return (modelable_tables, None)
else:
return (None, None)

def _train_synthetics_models(self, configs: dict[str, dict[str, Any]]) -> None:
"""
Uses the configured strategy to prepare training data sources for each table,
Expand Down Expand Up @@ -773,59 +743,27 @@ def train_synthetics(
Train synthetic data models for the tables in the tableset,
optionally scoped by either `only` or `ignore`.
"""
only, ignore = self._get_only_and_ignore(only, ignore)

include_tables: list[str] = []
omit_tables: list[str] = []
for table in self.relational_data.list_all_tables():
if skip_table(table, only, ignore):
omit_tables.append(table)
else:
include_tables.append(table)

# TODO: Ancestral strategy requires that for each table omitted from synthetics ("preserved"),
# all its ancestors must also be omitted. In the future, it'd be nice to either find a way to
# eliminate this requirement completely, or (less ideal) allow incremental training of tables,
# e.g. train a few in one "batch", then a few more before generating (perhaps with some logs
# along the way informing the user of which required tables are missing).
self._strategy.validate_preserved_tables(omit_tables, self.relational_data)

# Translate any JSON-source tables in table_specific_configs to invented tables
all_table_specific_configs = {}
for table, conf in (table_specific_configs or {}).items():
m_names = self.relational_data.get_modelable_table_names(table)
if len(m_names) == 0:
raise MultiTableException(f"Unrecognized table name: `{table}`")
all_table_specific_configs.update({m: conf for m in m_names})

# Ensure compatibility between only/ignore and table_specific_configs
omitted_tables_with_overrides_specified = []
for table in all_table_specific_configs:
if table in omit_tables:
omitted_tables_with_overrides_specified.append(table)
if len(omitted_tables_with_overrides_specified) > 0:
raise MultiTableException(
f"Cannot provide configs for tables that have been omitted from synthetics training: "
f"{omitted_tables_with_overrides_specified}"
)

# This argument used to be hinted as Optional with a None default and would
# fall back to a removed attribute on the MultiTable instance. Out of an
# abundance of caution, verify the client isn't relying on that old behavior.
if config is None:
raise MultiTableException(f"Must provide a default model config.")

# Validate the provided configs
default_config_dict = self._validate_synthetics_config(config)
table_specific_config_dicts = {
table: self._validate_synthetics_config(conf)
for table, conf in all_table_specific_configs.items()
}
configs = assemble_configs(
self.relational_data, config, table_specific_configs, only, ignore
)

configs = {
table: table_specific_config_dicts.get(table, default_config_dict)
for table in include_tables
}
# validate table scope (preserved tables) against the strategy
excluded_tables = [
table
for table in self.relational_data.list_all_tables()
if table not in configs
]
self._strategy.validate_preserved_tables(excluded_tables, self.relational_data)

# validate all provided model configs are supported by the strategy
for conf in configs.values():
self._validate_synthetics_config(conf)

self._train_synthetics_models(configs)

Expand Down Expand Up @@ -1070,15 +1008,11 @@ def _attach_existing_reports(self, run_id: str, table: str) -> None:
self._evaluations[table].individual_report_json = individual_report_json
self._evaluations[table].cross_table_report_json = cross_table_report_json

def _validate_synthetics_config(self, config: GretelModelConfig) -> dict[str, Any]:
def _validate_synthetics_config(self, config_dict: dict[str, Any]) -> None:
"""
Validates that the provided config:
- has the general shape of a Gretel model config (or can be read into one, e.g. blueprints)
- is supported by the configured synthetics strategy
Returns the parsed config as read by read_model_config.
Validates that the provided config (in dict form)
is supported by the configured synthetics strategy
"""
config_dict = ingest(config)
if (model_key := get_model_key(config_dict)) is None:
raise MultiTableException("Invalid config")
else:
Expand All @@ -1089,8 +1023,6 @@ def _validate_synthetics_config(self, config: GretelModelConfig) -> dict[str, An
f"The selected strategy supports: {supported_models}."
)

return config_dict


def _validate_strategy(strategy: str) -> Union[IndependentStrategy, AncestralStrategy]:
strategy = strategy.lower()
Expand Down
11 changes: 11 additions & 0 deletions tests/relational/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,17 @@ def _get_invented_table_suffix(make_suffix_execution_number: int):
return f"invented_{str(make_suffix_execution_number)}"


@pytest.fixture
def invented_tables(get_invented_table_suffix) -> dict[str, str]:
return {
"purchases_root": f"purchases_{get_invented_table_suffix(1)}",
"purchases_data_years": f"purchases_{get_invented_table_suffix(2)}",
"bball_root": f"bball_{get_invented_table_suffix(1)}",
"bball_suspensions": f"bball_{get_invented_table_suffix(2)}",
"bball_teams": f"bball_{get_invented_table_suffix(3)}",
}


@pytest.fixture()
def project():
with patch(
Expand Down
Loading

0 comments on commit 8feb616

Please sign in to comment.