Skip to content

Commit

Permalink
Allow Transform v2 configs in Relational Trainer
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 3d2cd99109da9ce710a025ca65165a8cf8c92650
  • Loading branch information
mikeknep committed Oct 20, 2023
1 parent e0a51b8 commit bc02f80
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 12 deletions.
26 changes: 14 additions & 12 deletions src/gretel_trainer/relational/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
RelationalData,
)

TRANSFORM_MODEL_KEYS = ["transform", "transforms", "transform_v2"]


def get_model_key(config_dict: dict[str, Any]) -> Optional[str]:
try:
Expand Down Expand Up @@ -56,20 +58,20 @@ def make_transform_config(
tailored_config = ingest(config)
tailored_config["name"] = _model_name("transforms", table)

model_key, model = next(iter(tailored_config["models"][0].items()))

# Ensure we have a transform config
if model_key not in TRANSFORM_MODEL_KEYS:
raise MultiTableException("Invalid transform config")

# Tv2 configs pass through unaltered (except for name, above)
if model_key == "transform_v2":
return tailored_config

# We add a passthrough policy to Tv1 configs to avoid transforming PK/FK columns
key_columns = rel_data.get_all_key_columns(table)
if len(key_columns) > 0:
try:
model = tailored_config["models"][0]
try:
model_key = "transform"
xform = model[model_key]
except KeyError:
model_key = "transforms"
xform = model[model_key]
policies = xform["policies"]
except KeyError:
raise MultiTableException("Invalid transform config")

policies = model["policies"]
passthrough_policy = _passthrough_policy(key_columns)
adjusted_policies = [passthrough_policy] + policies

Expand Down
17 changes: 17 additions & 0 deletions tests/relational/test_model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,23 @@ def test_synthetics_config_handles_noncompliant_table_names():
assert config["name"] == "synthetics-hello__world"


def test_transform_requires_valid_config(mutagenesis):
with pytest.raises(MultiTableException):
make_transform_config(mutagenesis, "atom", "synthetics/amplify")


def test_transform_v2_config_is_unaltered(mutagenesis):
tv2_config = {
"schema_version": "1.0",
"name": "original-name",
"models": [{"transform_v2": {"some": "Tv2 config"}}],
}
config = make_transform_config(mutagenesis, "atom", tv2_config)
assert config["name"] == "transforms-atom"
assert config["schema_version"] == tv2_config["schema_version"]
assert config["models"] == tv2_config["models"]


def test_transforms_config_prepends_workflow(mutagenesis):
config = make_transform_config(mutagenesis, "atom", "transform/default")
assert config["name"] == "transforms-atom"
Expand Down

0 comments on commit bc02f80

Please sign in to comment.