Skip to content

Commit

Permalink
Use sets for only/ignore
Browse files: browse the repository at this point in the history
  • Loading branch information
mikeknep committed May 26, 2023
1 parent 8360d3a commit 145686d
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 33 deletions.
2 changes: 1 addition & 1 deletion src/gretel_trainer/relational/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self, engine: Engine):
logger.info("Successfully connected to db")

def extract(
self, only: Optional[list[str]] = None, ignore: Optional[list[str]] = None
self, only: Optional[set[str]] = None, ignore: Optional[set[str]] = None
) -> RelationalData:
"""
Extracts table data and relationships from the database.
Expand Down
2 changes: 1 addition & 1 deletion src/gretel_trainer/relational/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ def debug_summary(self) -> dict[str, Any]:


def skip_table(
table: str, only: Optional[list[str]], ignore: Optional[list[str]]
table: str, only: Optional[set[str]], ignore: Optional[set[str]]
) -> bool:
skip = False
if only is not None and table not in only:
Expand Down
16 changes: 8 additions & 8 deletions src/gretel_trainer/relational/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,8 +514,8 @@ def delete_models(
self,
workflow: str,
*,
only: Optional[list[str]] = None,
ignore: Optional[list[str]] = None,
only: Optional[set[str]] = None,
ignore: Optional[set[str]] = None,
) -> None:
only, ignore = self._get_only_and_ignore(only, ignore)
tables = [
Expand Down Expand Up @@ -625,8 +625,8 @@ def train_transforms(
self,
config: GretelModelConfig,
*,
only: Optional[list[str]] = None,
ignore: Optional[list[str]] = None,
only: Optional[set[str]] = None,
ignore: Optional[set[str]] = None,
) -> None:
only, ignore = self._get_only_and_ignore(only, ignore)

Expand Down Expand Up @@ -727,8 +727,8 @@ def run_transforms(
self.transform_output_tables = reshaped_tables

def _get_only_and_ignore(
self, only: Optional[list[str]], ignore: Optional[list[str]]
) -> tuple[Optional[list[str]], Optional[list[str]]]:
self, only: Optional[set[str]], ignore: Optional[set[str]]
) -> tuple[Optional[set[str]], Optional[set[str]]]:
if only is not None and ignore is not None:
raise MultiTableException("Cannot specify both `only` and `ignore`.")

Expand Down Expand Up @@ -819,8 +819,8 @@ def train(self) -> None:
def train_synthetics(
self,
*,
only: Optional[list[str]] = None,
ignore: Optional[list[str]] = None,
only: Optional[set[str]] = None,
ignore: Optional[set[str]] = None,
) -> None:
"""
Train synthetic data models for the tables in the tableset,
Expand Down
6 changes: 3 additions & 3 deletions tests/relational/test_connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ def test_extract_subsets_of_relational_data(example_dbs):
connector = sqlite_conn(f.name)

with pytest.raises(MultiTableException):
connector.extract(only=["users"], ignore=["events"])
connector.extract(only={"users"}, ignore={"events"})

only = connector.extract(only=["users", "events", "products"])
only = connector.extract(only={"users", "events", "products"})
ignore = connector.extract(
ignore=["distribution_center", "order_items", "inventory_items"]
ignore={"distribution_center", "order_items", "inventory_items"}
)

expected_tables = {"users", "events", "products"}
Expand Down
22 changes: 11 additions & 11 deletions tests/relational/test_train_synthetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_train_synthetics_defaults_to_training_all_tables(ecom, tmpdir):

def test_train_synthetics_only_includes_specified_tables(ecom, tmpdir, project):
mt = MultiTable(ecom, project_display_name=tmpdir)
mt.train_synthetics(only=["users"])
mt.train_synthetics(only={"users"})

assert set(mt._synthetics_train.models.keys()) == {"users"}
project.create_model_obj.assert_called_with(
Expand All @@ -48,7 +48,7 @@ def test_train_synthetics_only_includes_specified_tables(ecom, tmpdir, project):

def test_train_synthetics_ignore_excludes_specified_tables(ecom, tmpdir):
mt = MultiTable(ecom, project_display_name=tmpdir)
mt.train_synthetics(ignore=["distribution_center", "products"])
mt.train_synthetics(ignore={"distribution_center", "products"})

assert set(mt._synthetics_train.models.keys()) == {
"events",
Expand All @@ -61,16 +61,16 @@ def test_train_synthetics_ignore_excludes_specified_tables(ecom, tmpdir):
def test_train_synthetics_exits_early_if_unrecognized_tables(ecom, tmpdir, project):
mt = MultiTable(ecom, project_display_name=tmpdir)
with pytest.raises(MultiTableException):
mt.train_synthetics(ignore=["nonsense"])
mt.train_synthetics(ignore={"nonsense"})

assert len(mt._synthetics_train.models) == 0
project.create_model_obj.assert_not_called()


def test_train_synthetics_multiple_calls_additive(ecom, tmpdir):
mt = MultiTable(ecom, project_display_name=tmpdir)
mt.train_synthetics(only=["products"])
mt.train_synthetics(only=["users"])
mt.train_synthetics(only={"products"})
mt.train_synthetics(only={"users"})

# We do not lose the first table model
assert set(mt._synthetics_train.models.keys()) == {"products", "users"}
Expand All @@ -80,28 +80,28 @@ def test_train_synthetics_multiple_calls_overwrite(ecom, tmpdir, project):
project.create_model_obj.return_value = "m1"

mt = MultiTable(ecom, project_display_name=tmpdir)
mt.train_synthetics(only=["products"])
mt.train_synthetics(only={"products"})

assert mt._synthetics_train.models["products"] == "m1"

project.reset_mock()
project.create_model_obj.return_value = "m2"

# calling a second time will create a new model for the table that overwrites the original
mt.train_synthetics(only=["products"])
mt.train_synthetics(only={"products"})
assert mt._synthetics_train.models["products"] == "m2"


def test_train_synthetics_after_deleting_models(ecom, tmpdir):
mt = MultiTable(ecom, project_display_name=tmpdir)
mt.train_synthetics(only=["products"])
mt.train_synthetics(only={"products"})
mt.delete_models("synthetics")
mt.train_synthetics(only=["users"])
mt.train_synthetics(only={"users"})

assert set(mt._synthetics_train.models.keys()) == {"users"}

# You can scope deletion to a subset of tables
mt.train_synthetics(only=["users", "products"])
mt.train_synthetics(only={"users", "products"})
assert set(mt._synthetics_train.models.keys()) == {"users", "products"}
mt.delete_models("synthetics", ignore=["products"])
mt.delete_models("synthetics", ignore={"products"})
assert set(mt._synthetics_train.models.keys()) == {"products"}
18 changes: 9 additions & 9 deletions tests/relational/test_train_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_train_transforms_defaults_to_transforming_all_tables(ecom, tmpdir):

def test_train_transforms_only_includes_specified_tables(ecom, tmpdir, project):
mt = MultiTable(ecom, project_display_name=tmpdir)
mt.train_transforms("transform/default", only=["users"])
mt.train_transforms("transform/default", only={"users"})
transforms_train = mt._transforms_train

assert set(transforms_train.models.keys()) == {"users"}
Expand All @@ -50,7 +50,7 @@ def test_train_transforms_only_includes_specified_tables(ecom, tmpdir, project):

def test_train_transforms_ignore_excludes_specified_tables(ecom, tmpdir):
mt = MultiTable(ecom, project_display_name=tmpdir)
mt.train_transforms("transform/default", ignore=["distribution_center", "products"])
mt.train_transforms("transform/default", ignore={"distribution_center", "products"})
transforms_train = mt._transforms_train

assert set(transforms_train.models.keys()) == {
Expand All @@ -64,7 +64,7 @@ def test_train_transforms_ignore_excludes_specified_tables(ecom, tmpdir):
def test_train_transforms_exits_early_if_unrecognized_tables(ecom, tmpdir, project):
mt = MultiTable(ecom, project_display_name=tmpdir)
with pytest.raises(MultiTableException):
mt.train_transforms("transform/default", ignore=["nonsense"])
mt.train_transforms("transform/default", ignore={"nonsense"})
transforms_train = mt._transforms_train

assert len(transforms_train.models) == 0
Expand All @@ -73,8 +73,8 @@ def test_train_transforms_exits_early_if_unrecognized_tables(ecom, tmpdir, proje

def test_train_transforms_multiple_calls_additive(ecom, tmpdir):
mt = MultiTable(ecom, project_display_name=tmpdir)
mt.train_transforms("transform/default", only=["products"])
mt.train_transforms("transform/default", only=["users"])
mt.train_transforms("transform/default", only={"products"})
mt.train_transforms("transform/default", only={"users"})

# We do not lose the first table model
assert set(mt._transforms_train.models.keys()) == {"products", "users"}
Expand All @@ -84,23 +84,23 @@ def test_train_transforms_multiple_calls_overwrite(ecom, tmpdir, project):
project.create_model_obj.return_value = "m1"

mt = MultiTable(ecom, project_display_name=tmpdir)
mt.train_transforms("transform/default", only=["products"])
mt.train_transforms("transform/default", only={"products"})

assert mt._transforms_train.models["products"] == "m1"

project.reset_mock()
project.create_model_obj.return_value = "m2"

# calling a second time will create a new model for the table that overwrites the original
mt.train_transforms("transform/default", only=["products"])
mt.train_transforms("transform/default", only={"products"})
assert mt._transforms_train.models["products"] == "m2"


def test_train_transforms_after_deleting_models(ecom, tmpdir):
mt = MultiTable(ecom, project_display_name=tmpdir)
mt.train_transforms("transform/default", only=["products"])
mt.train_transforms("transform/default", only={"products"})
mt.delete_models("transforms")
mt.train_transforms("transform/default", only=["users"])
mt.train_transforms("transform/default", only={"users"})

assert set(mt._transforms_train.models.keys()) == {"users"}

Expand Down

0 comments on commit 145686d

Please sign in to comment.