Skip to content

Commit

Permalink
Document behavior of multiple calls
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeknep committed May 18, 2023
1 parent 6109733 commit ced662f
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions tests/relational/test_train_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,31 @@ def test_train_transforms_exits_early_if_unrecognized_tables(ecom, tmpdir, proje
project.create_model_obj.assert_not_called()


def test_train_transforms_multiple_calls_additive(ecom, tmpdir):
mt = MultiTable(ecom, project_display_name=tmpdir)
mt.train_transforms("transform/default", only=["products"])
mt.train_transforms("transform/default", only=["users"])

# We do not lose the first table model
assert set(mt._transforms_train.models.keys()) == {"products", "users"}


def test_train_transforms_multiple_calls_overwrite(ecom, tmpdir, project):
project.create_model_obj.return_value = "m1"

mt = MultiTable(ecom, project_display_name=tmpdir)
mt.train_transforms("transform/default", only=["products"])

assert mt._transforms_train.models["products"] == "m1"

project.reset_mock()
project.create_model_obj.return_value = "m2"

# calling a second time will create a new model for the table that overwrites the original
mt.train_transforms("transform/default", only=["products"])
assert mt._transforms_train.models["products"] == "m2"


# The public method under test here is deprecated
def test_train_transform_models(ecom, tmpdir):
mt = MultiTable(ecom, project_display_name=tmpdir)
Expand Down

0 comments on commit ced662f

Please sign in to comment.