From 145686dfbbd717d5ca9f5ce104beb325b2e8bf37 Mon Sep 17 00:00:00 2001 From: Mike Knepper Date: Fri, 26 May 2023 08:56:45 -0500 Subject: [PATCH] Use sets for only/ignore --- src/gretel_trainer/relational/connectors.py | 2 +- src/gretel_trainer/relational/core.py | 2 +- src/gretel_trainer/relational/multi_table.py | 16 +++++++------- tests/relational/test_connectors.py | 6 +++--- tests/relational/test_train_synthetics.py | 22 ++++++++++---------- tests/relational/test_train_transforms.py | 18 ++++++++-------- 6 files changed, 33 insertions(+), 33 deletions(-) diff --git a/src/gretel_trainer/relational/connectors.py b/src/gretel_trainer/relational/connectors.py index 9d80a72c..535e40d5 100644 --- a/src/gretel_trainer/relational/connectors.py +++ b/src/gretel_trainer/relational/connectors.py @@ -42,7 +42,7 @@ def __init__(self, engine: Engine): logger.info("Successfully connected to db") def extract( - self, only: Optional[list[str]] = None, ignore: Optional[list[str]] = None + self, only: Optional[set[str]] = None, ignore: Optional[set[str]] = None ) -> RelationalData: """ Extracts table data and relationships from the database. diff --git a/src/gretel_trainer/relational/core.py b/src/gretel_trainer/relational/core.py index 5eea88ef..a9df9c12 100644 --- a/src/gretel_trainer/relational/core.py +++ b/src/gretel_trainer/relational/core.py @@ -665,7 +665,7 @@ def debug_summary(self) -> dict[str, Any]: def skip_table( - table: str, only: Optional[list[str]], ignore: Optional[list[str]] + table: str, only: Optional[set[str]], ignore: Optional[set[str]] ) -> bool: skip = False if only is not None and table not in only: diff --git a/src/gretel_trainer/relational/multi_table.py b/src/gretel_trainer/relational/multi_table.py index 8303bb12..0203b98c 100644 --- a/src/gretel_trainer/relational/multi_table.py +++ b/src/gretel_trainer/relational/multi_table.py @@ -514,8 +514,8 @@ def delete_models( self, workflow: str, *, - only: Optional[list[str]] = None, - ignore: Optional[list[str]] = None, + only: Optional[set[str]] = None, + ignore: Optional[set[str]] = None, ) -> None: only, ignore = self._get_only_and_ignore(only, ignore) tables = [ @@ -625,8 +625,8 @@ def train_transforms( self, config: GretelModelConfig, *, - only: Optional[list[str]] = None, - ignore: Optional[list[str]] = None, + only: Optional[set[str]] = None, + ignore: Optional[set[str]] = None, ) -> None: only, ignore = self._get_only_and_ignore(only, ignore) @@ -727,8 +727,8 @@ def run_transforms( self.transform_output_tables = reshaped_tables def _get_only_and_ignore( - self, only: Optional[list[str]], ignore: Optional[list[str]] - ) -> tuple[Optional[list[str]], Optional[list[str]]]: + self, only: Optional[set[str]], ignore: Optional[set[str]] + ) -> tuple[Optional[set[str]], Optional[set[str]]]: if only is not None and ignore is not None: raise MultiTableException("Cannot specify both `only` and `ignore`.") @@ -819,8 +819,8 @@ def train(self) -> None: def train_synthetics( self, *, - only: Optional[list[str]] = None, - ignore: Optional[list[str]] = None, + only: Optional[set[str]] = None, + ignore: Optional[set[str]] = None, ) -> None: """ Train synthetic data models for the tables in the tableset, diff --git a/tests/relational/test_connectors.py b/tests/relational/test_connectors.py index dbee2014..0567a7d5 100644 --- a/tests/relational/test_connectors.py +++ b/tests/relational/test_connectors.py @@ -17,11 +17,11 @@ def test_extract_subsets_of_relational_data(example_dbs): connector = sqlite_conn(f.name) with pytest.raises(MultiTableException): - connector.extract(only=["users"], ignore=["events"]) + connector.extract(only={"users"}, ignore={"events"}) - only = connector.extract(only=["users", "events", "products"]) + only = connector.extract(only={"users", "events", "products"}) ignore = connector.extract( - ignore=["distribution_center", "order_items", "inventory_items"] + ignore={"distribution_center", "order_items", "inventory_items"} ) expected_tables = {"users", "events", "products"} diff --git a/tests/relational/test_train_synthetics.py b/tests/relational/test_train_synthetics.py index 149d977a..ccbb784b 100644 --- a/tests/relational/test_train_synthetics.py +++ b/tests/relational/test_train_synthetics.py @@ -37,7 +37,7 @@ def test_train_synthetics_defaults_to_training_all_tables(ecom, tmpdir): def test_train_synthetics_only_includes_specified_tables(ecom, tmpdir, project): mt = MultiTable(ecom, project_display_name=tmpdir) - mt.train_synthetics(only=["users"]) + mt.train_synthetics(only={"users"}) assert set(mt._synthetics_train.models.keys()) == {"users"} project.create_model_obj.assert_called_with( @@ -48,7 +48,7 @@ def test_train_synthetics_only_includes_specified_tables(ecom, tmpdir, project): def test_train_synthetics_ignore_excludes_specified_tables(ecom, tmpdir): mt = MultiTable(ecom, project_display_name=tmpdir) - mt.train_synthetics(ignore=["distribution_center", "products"]) + mt.train_synthetics(ignore={"distribution_center", "products"}) assert set(mt._synthetics_train.models.keys()) == { "events", @@ -61,7 +61,7 @@ def test_train_synthetics_ignore_excludes_specified_tables(ecom, tmpdir): def test_train_synthetics_exits_early_if_unrecognized_tables(ecom, tmpdir, project): mt = MultiTable(ecom, project_display_name=tmpdir) with pytest.raises(MultiTableException): - mt.train_synthetics(ignore=["nonsense"]) + mt.train_synthetics(ignore={"nonsense"}) assert len(mt._synthetics_train.models) == 0 project.create_model_obj.assert_not_called() @@ -69,8 +69,8 @@ def test_train_synthetics_exits_early_if_unrecognized_tables(ecom, tmpdir, proje def test_train_synthetics_multiple_calls_additive(ecom, tmpdir): mt = MultiTable(ecom, project_display_name=tmpdir) - mt.train_synthetics(only=["products"]) - mt.train_synthetics(only=["users"]) + mt.train_synthetics(only={"products"}) + mt.train_synthetics(only={"users"}) # We do not lose the first table model assert set(mt._synthetics_train.models.keys()) == {"products", "users"} @@ -80,7 +80,7 @@ def test_train_synthetics_multiple_calls_overwrite(ecom, tmpdir, project): project.create_model_obj.return_value = "m1" mt = MultiTable(ecom, project_display_name=tmpdir) - mt.train_synthetics(only=["products"]) + mt.train_synthetics(only={"products"}) assert mt._synthetics_train.models["products"] == "m1" @@ -88,20 +88,20 @@ def test_train_synthetics_multiple_calls_overwrite(ecom, tmpdir, project): project.create_model_obj.return_value = "m2" # calling a second time will create a new model for the table that overwrites the original - mt.train_synthetics(only=["products"]) + mt.train_synthetics(only={"products"}) assert mt._synthetics_train.models["products"] == "m2" def test_train_synthetics_after_deleting_models(ecom, tmpdir): mt = MultiTable(ecom, project_display_name=tmpdir) - mt.train_synthetics(only=["products"]) + mt.train_synthetics(only={"products"}) mt.delete_models("synthetics") - mt.train_synthetics(only=["users"]) + mt.train_synthetics(only={"users"}) assert set(mt._synthetics_train.models.keys()) == {"users"} # You can scope deletion to a subset of tables - mt.train_synthetics(only=["users", "products"]) + mt.train_synthetics(only={"users", "products"}) assert set(mt._synthetics_train.models.keys()) == {"users", "products"} - mt.delete_models("synthetics", ignore=["products"]) + mt.delete_models("synthetics", ignore={"products"}) assert set(mt._synthetics_train.models.keys()) == {"products"} diff --git a/tests/relational/test_train_transforms.py b/tests/relational/test_train_transforms.py index d355021a..4beadb89 100644 --- a/tests/relational/test_train_transforms.py +++ b/tests/relational/test_train_transforms.py @@ -38,7 +38,7 @@ def test_train_transforms_defaults_to_transforming_all_tables(ecom, tmpdir): def test_train_transforms_only_includes_specified_tables(ecom, tmpdir, project): mt = MultiTable(ecom, project_display_name=tmpdir) - mt.train_transforms("transform/default", only=["users"]) + mt.train_transforms("transform/default", only={"users"}) transforms_train = mt._transforms_train assert set(transforms_train.models.keys()) == {"users"} @@ -50,7 +50,7 @@ def test_train_transforms_only_includes_specified_tables(ecom, tmpdir, project): def test_train_transforms_ignore_excludes_specified_tables(ecom, tmpdir): mt = MultiTable(ecom, project_display_name=tmpdir) - mt.train_transforms("transform/default", ignore=["distribution_center", "products"]) + mt.train_transforms("transform/default", ignore={"distribution_center", "products"}) transforms_train = mt._transforms_train assert set(transforms_train.models.keys()) == { @@ -64,7 +64,7 @@ def test_train_transforms_ignore_excludes_specified_tables(ecom, tmpdir): def test_train_transforms_exits_early_if_unrecognized_tables(ecom, tmpdir, project): mt = MultiTable(ecom, project_display_name=tmpdir) with pytest.raises(MultiTableException): - mt.train_transforms("transform/default", ignore=["nonsense"]) + mt.train_transforms("transform/default", ignore={"nonsense"}) transforms_train = mt._transforms_train assert len(transforms_train.models) == 0 @@ -73,8 +73,8 @@ def test_train_transforms_exits_early_if_unrecognized_tables(ecom, tmpdir, proje def test_train_transforms_multiple_calls_additive(ecom, tmpdir): mt = MultiTable(ecom, project_display_name=tmpdir) - mt.train_transforms("transform/default", only=["products"]) - mt.train_transforms("transform/default", only=["users"]) + mt.train_transforms("transform/default", only={"products"}) + mt.train_transforms("transform/default", only={"users"}) # We do not lose the first table model assert set(mt._transforms_train.models.keys()) == {"products", "users"} @@ -84,7 +84,7 @@ def test_train_transforms_multiple_calls_overwrite(ecom, tmpdir, project): project.create_model_obj.return_value = "m1" mt = MultiTable(ecom, project_display_name=tmpdir) - mt.train_transforms("transform/default", only=["products"]) + mt.train_transforms("transform/default", only={"products"}) assert mt._transforms_train.models["products"] == "m1" @@ -92,15 +92,15 @@ def test_train_transforms_multiple_calls_overwrite(ecom, tmpdir, project): project.create_model_obj.return_value = "m2" # calling a second time will create a new model for the table that overwrites the original - mt.train_transforms("transform/default", only=["products"]) + mt.train_transforms("transform/default", only={"products"}) assert mt._transforms_train.models["products"] == "m2" def test_train_transforms_after_deleting_models(ecom, tmpdir): mt = MultiTable(ecom, project_display_name=tmpdir) - mt.train_transforms("transform/default", only=["products"]) + mt.train_transforms("transform/default", only={"products"}) mt.delete_models("transforms") - mt.train_transforms("transform/default", only=["users"]) + mt.train_transforms("transform/default", only={"users"}) assert set(mt._transforms_train.models.keys()) == {"users"}