diff --git a/tests/relational/test_train_synthetics.py b/tests/relational/test_train_synthetics.py index ccbb784b..f805aa26 100644 --- a/tests/relational/test_train_synthetics.py +++ b/tests/relational/test_train_synthetics.py @@ -76,6 +76,28 @@ def test_train_synthetics_multiple_calls_additive(ecom, tmpdir): assert set(mt._synthetics_train.models.keys()) == {"products", "users"} +def test_train_synthetics_models_for_dbs_with_invented_tables(documents, tmpdir): + mt = MultiTable(documents, project_display_name=tmpdir) + mt.train_synthetics() + + assert set(mt._synthetics_train.models.keys()) == { + "users", + "payments", + "purchases-sfx", + "purchases-data-years-sfx", + } + + +def test_train_synthetics_table_filters_cascade_to_invented_tables(documents, tmpdir): + # When a user provides the ("public") name of a table that contained JSON and led + # to the creation of invented tables, we recognize that as implicitly applying to + # all the tables internally created from that source table. + mt = MultiTable(documents, project_display_name=tmpdir) + mt.train_synthetics(ignore={"purchases"}) + + assert set(mt._synthetics_train.models.keys()) == {"users", "payments"} + + def test_train_synthetics_multiple_calls_overwrite(ecom, tmpdir, project): project.create_model_obj.return_value = "m1"