From 4f5bedbf5797a3de23ac928a0269769f9e411a61 Mon Sep 17 00:00:00 2001 From: Mike Knepper Date: Tue, 16 May 2023 08:33:42 -0500 Subject: [PATCH 1/9] Move skip_table function to core module --- src/gretel_trainer/relational/connectors.py | 30 ++++++++------------- src/gretel_trainer/relational/core.py | 12 +++++++++ 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/src/gretel_trainer/relational/connectors.py b/src/gretel_trainer/relational/connectors.py index 4888a3fb..9d80a72c 100644 --- a/src/gretel_trainer/relational/connectors.py +++ b/src/gretel_trainer/relational/connectors.py @@ -1,12 +1,16 @@ import logging -from typing import Dict, List, Optional, Tuple +from typing import Optional import pandas as pd from sqlalchemy import create_engine, inspect from sqlalchemy.engine.base import Engine from sqlalchemy.exc import OperationalError -from gretel_trainer.relational.core import MultiTableException, RelationalData +from gretel_trainer.relational.core import ( + MultiTableException, + RelationalData, + skip_table, +) logger = logging.getLogger(__name__) @@ -38,7 +42,7 @@ def __init__(self, engine: Engine): logger.info("Successfully connected to db") def extract( - self, only: Optional[List[str]] = None, ignore: Optional[List[str]] = None + self, only: Optional[list[str]] = None, ignore: Optional[list[str]] = None ) -> RelationalData: """ Extracts table data and relationships from the database. @@ -50,17 +54,17 @@ def extract( inspector = inspect(self.engine) relational_data = RelationalData() - foreign_keys: List[Tuple[str, dict]] = [] + foreign_keys: list[tuple[str, dict]] = [] for table_name in inspector.get_table_names(): - if _skip_table(table_name, only, ignore): + if skip_table(table_name, only, ignore): continue logger.debug(f"Extracting source data from `{table_name}`") df = pd.read_sql_table(table_name, self.engine) primary_key = inspector.get_pk_constraint(table_name)["constrained_columns"] for fk in inspector.get_foreign_keys(table_name): - if _skip_table(fk["referred_table"], only, ignore): + if skip_table(fk["referred_table"], only, ignore): continue else: foreign_keys.append((table_name, fk)) @@ -78,25 +82,13 @@ def extract( return relational_data - def save(self, tables: Dict[str, pd.DataFrame], prefix: str = "") -> None: + def save(self, tables: dict[str, pd.DataFrame], prefix: str = "") -> None: for name, data in tables.items(): data.to_sql( f"{prefix}{name}", con=self.engine, if_exists="replace", index=False ) -def _skip_table( - table: str, only: Optional[List[str]], ignore: Optional[List[str]] -) -> bool: - skip = False - if only is not None and table not in only: - skip = True - if ignore is not None and table in ignore: - skip = True - - return skip - - def sqlite_conn(path: str) -> Connector: engine = create_engine(f"sqlite:///{path}") return Connector(engine) diff --git a/src/gretel_trainer/relational/core.py b/src/gretel_trainer/relational/core.py index bc3a8e4a..3f307bde 100644 --- a/src/gretel_trainer/relational/core.py +++ b/src/gretel_trainer/relational/core.py @@ -651,6 +651,18 @@ def debug_summary(self) -> Dict[str, Any]: } +def skip_table( + table: str, only: Optional[list[str]], ignore: Optional[list[str]] +) -> bool: + skip = False + if only is not None and table not in only: + skip = True + if ignore is not None and table in ignore: + skip = True + + return skip + + def _ok_for_train_and_seed(col: str, df: pd.DataFrame) -> bool: if _is_highly_nan(col, df): return False From e6a3be3009b932eef4231a09d57110452824bf1b Mon Sep 17 00:00:00 2001 From: Mike Knepper Date: Tue, 16 May 2023 11:30:23 -0500 Subject: [PATCH 2/9] Add method to return modelable tables from a given table name --- src/gretel_trainer/relational/core.py | 10 ++++++++++ tests/relational/test_relational_data.py | 3 --- tests/relational/test_relational_data_with_json.py | 14 ++++++++++++++ 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/gretel_trainer/relational/core.py b/src/gretel_trainer/relational/core.py index 3f307bde..702630f3 100644 --- a/src/gretel_trainer/relational/core.py +++ b/src/gretel_trainer/relational/core.py @@ -456,6 +456,16 @@ def _is_invented(self, table: str) -> bool: and self.graph.nodes[table]["metadata"].invented_table_metadata is not None ) + def get_modelable_table_names(self, table: str) -> list[str]: + """Returns a list of MODELABLE table names connected to the provided table. + If the provided table is already modelable, returns [table]. + If the provided table is not modelable (e.g. source with JSON) returns invented tables from that source. + """ + if (rel_json := self.relational_jsons.get(table)) is not None: + return rel_json.table_names + else: + return [table] + def get_public_name(self, table: str) -> Optional[str]: if table in self.relational_jsons: return table diff --git a/tests/relational/test_relational_data.py b/tests/relational/test_relational_data.py index 755bf8b1..cf5141b0 100644 --- a/tests/relational/test_relational_data.py +++ b/tests/relational/test_relational_data.py @@ -1,6 +1,3 @@ -import os -import tempfile - import pandas as pd import pytest diff --git a/tests/relational/test_relational_data_with_json.py b/tests/relational/test_relational_data_with_json.py index 9ef30c60..b8a488e2 100644 --- a/tests/relational/test_relational_data_with_json.py +++ b/tests/relational/test_relational_data_with_json.py @@ -108,6 +108,20 @@ def test_list_tables_accepts_various_scopes(documents): ) +def test_get_modelable_table_names(documents): + # Given a source-with-JSON name, returns the tables invented from that source + assert set(documents.get_modelable_table_names("purchases")) == { + "purchases-sfx", + "purchases-data-years-sfx", + } + + # Invented tables are modelable + assert documents.get_modelable_table_names("purchases-sfx") == ["purchases-sfx"] + assert documents.get_modelable_table_names("purchases-data-years-sfx") == [ + "purchases-data-years-sfx" + ] + + def test_invented_json_column_names(documents, bball): # The root invented table adds columns for dictionary properties lifted from nested JSON objects assert set(documents.get_table_columns("purchases-sfx")) == { From 496a9461b7fa431344350a5d219077ef5b3355ea Mon Sep 17 00:00:00 2001 From: Mike Knepper Date: Tue, 16 May 2023 11:39:14 -0500 Subject: [PATCH 3/9] Add new transform method that accepts a single config for all tables --- src/gretel_trainer/relational/multi_table.py | 67 ++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/src/gretel_trainer/relational/multi_table.py b/src/gretel_trainer/relational/multi_table.py index d6c397f1..3c8e24b0 100644 --- a/src/gretel_trainer/relational/multi_table.py +++ b/src/gretel_trainer/relational/multi_table.py @@ -33,6 +33,7 @@ MultiTableException, RelationalData, Scope, + skip_table, ) from gretel_trainer.relational.json import InventedTableMetadata, RelationalJson from gretel_trainer.relational.log import silent_logs @@ -557,7 +558,73 @@ def classify(self, config: GretelModelConfig, all_rows: bool = False) -> None: ) def train_transform_models(self, configs: Dict[str, GretelModelConfig]) -> None: + """ + DEPRECATED: Please use `train_transforms` instead. + """ + logger.warning( + "This method is deprecated and will be removed in a future release. " + "Please use `train_transforms` instead." + ) + use_configs = {} for table, config in configs.items(): + for m_table in self.relational_data.get_modelable_table_names(table): + use_configs[m_table] = config + + for table, config in use_configs.items(): + transform_config = make_transform_config( + self.relational_data, table, config + ) + + # Ensure consistent, friendly data source names in Console + table_data = self.relational_data.get_table_data(table) + transforms_train_path = self._working_dir / f"transforms_train_{table}.csv" + table_data.to_csv(transforms_train_path, index=False) + + # Create model + model = self._project.create_model_obj( + model_config=transform_config, data_source=str(transforms_train_path) + ) + self._transforms_train.models[table] = model + + self._backup() + + task = TransformsTrainTask( + transforms_train=self._transforms_train, + multitable=self, + ) + run_task(task, self._extended_sdk) + + def train_transforms( + self, + config: GretelModelConfig, + *, + only: Optional[list[str]] = None, + ignore: Optional[list[str]] = None, + ) -> None: + if only is not None and ignore is not None: + raise MultiTableException("Cannot specify both `only` and `ignore`.") + + if only is not None: + only = [ + modelable_name + for table in only + for modelable_name in self.relational_data.get_modelable_table_names( + table + ) + ] + if ignore is not None: + ignore = [ + modelable_name + for table in ignore + for modelable_name in self.relational_data.get_modelable_table_names( + table + ) + ] + + for table in self.relational_data.list_all_tables(): + if skip_table(table, only, ignore): + continue + transform_config = make_transform_config( self.relational_data, table, config ) From fd854e54b89da33cbf1516ee6f174377b1428445 Mon Sep 17 00:00:00 2001 From: Mike Knepper Date: Wed, 17 May 2023 15:05:49 -0500 Subject: [PATCH 4/9] Docstring tweak --- src/gretel_trainer/relational/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gretel_trainer/relational/core.py b/src/gretel_trainer/relational/core.py index 702630f3..ad2e49d6 100644 --- a/src/gretel_trainer/relational/core.py +++ b/src/gretel_trainer/relational/core.py @@ -459,7 +459,7 @@ def _is_invented(self, table: str) -> bool: def get_modelable_table_names(self, table: str) -> list[str]: """Returns a list of MODELABLE table names connected to the provided table. If the provided table is already modelable, returns [table]. - If the provided table is not modelable (e.g. source with JSON) returns invented tables from that source. + If the provided table is not modelable (e.g. source with JSON), returns tables invented from that source. """ if (rel_json := self.relational_jsons.get(table)) is not None: return rel_json.table_names From e1e4399d6c8a9d5ba48d972ba0586b2af1f2102a Mon Sep 17 00:00:00 2001 From: Mike Knepper Date: Wed, 17 May 2023 16:23:13 -0500 Subject: [PATCH 5/9] Extract method and add unit tests for train transforms --- src/gretel_trainer/relational/multi_table.py | 56 +++++++--------- tests/relational/test_train_transforms.py | 70 ++++++++++++++++++++ 2 files changed, 93 insertions(+), 33 deletions(-) create mode 100644 tests/relational/test_train_transforms.py diff --git a/src/gretel_trainer/relational/multi_table.py b/src/gretel_trainer/relational/multi_table.py index 3c8e24b0..f24a650d 100644 --- a/src/gretel_trainer/relational/multi_table.py +++ b/src/gretel_trainer/relational/multi_table.py @@ -557,20 +557,10 @@ def classify(self, config: GretelModelConfig, all_rows: bool = False) -> None: self._project, str(archive_path) ) - def train_transform_models(self, configs: Dict[str, GretelModelConfig]) -> None: - """ - DEPRECATED: Please use `train_transforms` instead. - """ - logger.warning( - "This method is deprecated and will be removed in a future release. " - "Please use `train_transforms` instead." - ) - use_configs = {} + def _setup_transforms_train_state( + self, configs: dict[str, GretelModelConfig] + ) -> None: for table, config in configs.items(): - for m_table in self.relational_data.get_modelable_table_names(table): - use_configs[m_table] = config - - for table, config in use_configs.items(): transform_config = make_transform_config( self.relational_data, table, config ) @@ -588,6 +578,20 @@ def train_transform_models(self, configs: Dict[str, GretelModelConfig]) -> None: self._backup() + def train_transform_models(self, configs: dict[str, GretelModelConfig]) -> None: + """ + DEPRECATED: Please use `train_transforms` instead. + """ + logger.warning( + "This method is deprecated and will be removed in a future release. " + "Please use `train_transforms` instead." + ) + use_configs = {} + for table, config in configs.items(): + for m_table in self.relational_data.get_modelable_table_names(table): + use_configs[m_table] = config + + self._setup_transforms_train_state(use_configs) task = TransformsTrainTask( transforms_train=self._transforms_train, multitable=self, @@ -621,27 +625,13 @@ def train_transforms( ) ] - for table in self.relational_data.list_all_tables(): - if skip_table(table, only, ignore): - continue - - transform_config = make_transform_config( - self.relational_data, table, config - ) - - # Ensure consistent, friendly data source names in Console - table_data = self.relational_data.get_table_data(table) - transforms_train_path = self._working_dir / f"transforms_train_{table}.csv" - table_data.to_csv(transforms_train_path, index=False) - - # Create model - model = self._project.create_model_obj( - model_config=transform_config, data_source=str(transforms_train_path) - ) - self._transforms_train.models[table] = model - - self._backup() + configs = { + table: config + for table in self.relational_data.list_all_tables() + if not skip_table(table, only, ignore) + } + self._setup_transforms_train_state(configs) task = TransformsTrainTask( transforms_train=self._transforms_train, multitable=self, diff --git a/tests/relational/test_train_transforms.py b/tests/relational/test_train_transforms.py new file mode 100644 index 00000000..bc265bb6 --- /dev/null +++ b/tests/relational/test_train_transforms.py @@ -0,0 +1,70 @@ +import tempfile +from unittest.mock import patch + +import pytest + +from gretel_trainer.relational.multi_table import MultiTable + + +# The assertions in this file are concerned with setting up the transforms train +# workflow state properly, and stop short of kicking off the task. +@pytest.fixture(autouse=True) +def run_task(): + with patch("gretel_trainer.relational.multi_table.run_task"): + yield + + +@pytest.fixture(autouse=True) +def backup(): + with patch.object(MultiTable, "_backup", return_value=None): + yield + + +@pytest.fixture() +def tmpdir(project): + with tempfile.TemporaryDirectory() as tmpdir: + project.name = tmpdir + yield tmpdir + + +def test_train_transforms_defaults_to_transforming_all_tables(ecom, tmpdir): + mt = MultiTable(ecom, project_display_name=tmpdir) + mt.train_transforms("transform/default") + transforms_train = mt._transforms_train + + assert set(transforms_train.models.keys()) == set(ecom.list_all_tables()) + + +def test_train_transforms_only_includes_specified_tables(ecom, tmpdir): + mt = MultiTable(ecom, project_display_name=tmpdir) + mt.train_transforms("transform/default", only=["events", "users"]) + transforms_train = mt._transforms_train + + assert set(transforms_train.models.keys()) == {"events", "users"} + + +def test_train_transforms_ignore_excludes_specified_tables(ecom, tmpdir): + mt = MultiTable(ecom, project_display_name=tmpdir) + mt.train_transforms("transform/default", ignore=["distribution_center", "products"]) + transforms_train = mt._transforms_train + + assert set(transforms_train.models.keys()) == { + "events", + "users", + "order_items", + "inventory_items", + } + + +# The public method under test here is deprecated +def test_train_transform_models(ecom, tmpdir): + mt = MultiTable(ecom, project_display_name=tmpdir) + mt.train_transform_models( + configs={ + "events": "transform/default", + "users": "transform/default", + } + ) + transforms_train = mt._transforms_train + + assert set(transforms_train.models.keys()) == {"events", "users"} From 4eaf38d126d60366143a6d122a8753586417f3cf Mon Sep 17 00:00:00 2001 From: Mike Knepper Date: Wed, 17 May 2023 16:33:23 -0500 Subject: [PATCH 6/9] Assert that a model is created --- tests/relational/test_train_transforms.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/relational/test_train_transforms.py b/tests/relational/test_train_transforms.py index bc265bb6..77359639 100644 --- a/tests/relational/test_train_transforms.py +++ b/tests/relational/test_train_transforms.py @@ -1,5 +1,5 @@ import tempfile -from unittest.mock import patch +from unittest.mock import ANY, patch import pytest @@ -35,12 +35,16 @@ def test_train_transforms_defaults_to_transforming_all_tables(ecom, tmpdir): assert set(transforms_train.models.keys()) == set(ecom.list_all_tables()) -def test_train_transforms_only_includes_specified_tables(ecom, tmpdir): +def test_train_transforms_only_includes_specified_tables(ecom, tmpdir, project): mt = MultiTable(ecom, project_display_name=tmpdir) - mt.train_transforms("transform/default", only=["events", "users"]) + mt.train_transforms("transform/default", only=["users"]) transforms_train = mt._transforms_train - assert set(transforms_train.models.keys()) == {"events", "users"} + assert set(transforms_train.models.keys()) == {"users"} + project.create_model_obj.assert_called_with( + model_config=ANY, # a tailored transforms config, in dict form + data_source=f"{tmpdir}/transforms_train_users.csv", + ) def test_train_transforms_ignore_excludes_specified_tables(ecom, tmpdir): From b1c31a83b098aa7cda5fcdfb0ddec500079b20ab Mon Sep 17 00:00:00 2001 From: Mike Knepper Date: Thu, 18 May 2023 11:43:37 -0500 Subject: [PATCH 7/9] Return empty list and log warning if unrecognized table name --- src/gretel_trainer/relational/core.py | 4 ++++ tests/relational/test_relational_data_with_json.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/src/gretel_trainer/relational/core.py b/src/gretel_trainer/relational/core.py index ad2e49d6..a8f6c17a 100644 --- a/src/gretel_trainer/relational/core.py +++ b/src/gretel_trainer/relational/core.py @@ -460,9 +460,13 @@ def get_modelable_table_names(self, table: str) -> list[str]: """Returns a list of MODELABLE table names connected to the provided table. If the provided table is already modelable, returns [table]. If the provided table is not modelable (e.g. source with JSON), returns tables invented from that source. + If the provided table does not exist, returns empty list. """ if (rel_json := self.relational_jsons.get(table)) is not None: return rel_json.table_names + elif table not in self.list_all_tables(Scope.ALL): + logger.warning(f"Unrecognized table name: `{table}`") + return [] else: return [table] diff --git a/tests/relational/test_relational_data_with_json.py b/tests/relational/test_relational_data_with_json.py index b8a488e2..ce2b91a3 100644 --- a/tests/relational/test_relational_data_with_json.py +++ b/tests/relational/test_relational_data_with_json.py @@ -121,6 +121,9 @@ def test_get_modelable_table_names(documents): "purchases-data-years-sfx" ] + # Unknown tables return empty list + assert documents.get_modelable_table_names("nonsense") == [] + def test_invented_json_column_names(documents, bball): # The root invented table adds columns for dictionary properties lifted from nested JSON objects From 61097336f02ce5a2b8de1c195b84db0a79aa8d99 Mon Sep 17 00:00:00 2001 From: Mike Knepper Date: Thu, 18 May 2023 11:54:01 -0500 Subject: [PATCH 8/9] Exit early if unrecognized table names are provided --- src/gretel_trainer/relational/core.py | 1 - src/gretel_trainer/relational/multi_table.py | 31 ++++++++++---------- tests/relational/test_train_transforms.py | 11 +++++++ 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/src/gretel_trainer/relational/core.py b/src/gretel_trainer/relational/core.py index a8f6c17a..6053cab4 100644 --- a/src/gretel_trainer/relational/core.py +++ b/src/gretel_trainer/relational/core.py @@ -465,7 +465,6 @@ def get_modelable_table_names(self, table: str) -> list[str]: if (rel_json := self.relational_jsons.get(table)) is not None: return rel_json.table_names elif table not in self.list_all_tables(Scope.ALL): - logger.warning(f"Unrecognized table name: `{table}`") return [] else: return [table] diff --git a/src/gretel_trainer/relational/multi_table.py b/src/gretel_trainer/relational/multi_table.py index f24a650d..de93ab99 100644 --- a/src/gretel_trainer/relational/multi_table.py +++ b/src/gretel_trainer/relational/multi_table.py @@ -608,27 +608,28 @@ def train_transforms( if only is not None and ignore is not None: raise MultiTableException("Cannot specify both `only` and `ignore`.") + m_only = None if only is not None: - only = [ - modelable_name - for table in only - for modelable_name in self.relational_data.get_modelable_table_names( - table - ) - ] + m_only = [] + for table in only: + m_names = self.relational_data.get_modelable_table_names(table) + if len(m_names) == 0: + raise MultiTableException(f"Unrecognized table name: `{table}`") + m_only.extend(m_names) + + m_ignore = None if ignore is not None: - ignore = [ - modelable_name - for table in ignore - for modelable_name in self.relational_data.get_modelable_table_names( - table - ) - ] + m_ignore = [] + for table in ignore: + m_names = self.relational_data.get_modelable_table_names(table) + if len(m_names) == 0: + raise MultiTableException(f"Unrecognized table name: `{table}`") + m_ignore.extend(m_names) configs = { table: config for table in self.relational_data.list_all_tables() - if not skip_table(table, only, ignore) + if not skip_table(table, m_only, m_ignore) } self._setup_transforms_train_state(configs) diff --git a/tests/relational/test_train_transforms.py b/tests/relational/test_train_transforms.py index 77359639..d7c1e4ec 100644 --- a/tests/relational/test_train_transforms.py +++ b/tests/relational/test_train_transforms.py @@ -3,6 +3,7 @@ import pytest +from gretel_trainer.relational.core import MultiTableException from gretel_trainer.relational.multi_table import MultiTable @@ -60,6 +61,16 @@ def test_train_transforms_ignore_excludes_specified_tables(ecom, tmpdir): } +def test_train_transforms_exits_early_if_unrecognized_tables(ecom, tmpdir, project): + mt = MultiTable(ecom, project_display_name=tmpdir) + with pytest.raises(MultiTableException): + mt.train_transforms("transform/default", ignore=["nonsense"]) + transforms_train = mt._transforms_train + + assert len(transforms_train.models) == 0 + project.create_model_obj.assert_not_called() + + # The public method under test here is deprecated def test_train_transform_models(ecom, tmpdir): mt = MultiTable(ecom, project_display_name=tmpdir) From ced662fc77746674873b4ba47bfe43712436e2de Mon Sep 17 00:00:00 2001 From: Mike Knepper Date: Thu, 18 May 2023 14:41:51 -0500 Subject: [PATCH 9/9] Document behavior of multiple calls --- tests/relational/test_train_transforms.py | 25 +++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/relational/test_train_transforms.py b/tests/relational/test_train_transforms.py index d7c1e4ec..4c93cdbd 100644 --- a/tests/relational/test_train_transforms.py +++ b/tests/relational/test_train_transforms.py @@ -71,6 +71,31 @@ def test_train_transforms_exits_early_if_unrecognized_tables(ecom, tmpdir, proje project.create_model_obj.assert_not_called() +def test_train_transforms_multiple_calls_additive(ecom, tmpdir): + mt = MultiTable(ecom, project_display_name=tmpdir) + mt.train_transforms("transform/default", only=["products"]) + mt.train_transforms("transform/default", only=["users"]) + + # We do not lose the first table model + assert set(mt._transforms_train.models.keys()) == {"products", "users"} + + +def test_train_transforms_multiple_calls_overwrite(ecom, tmpdir, project): + project.create_model_obj.return_value = "m1" + + mt = MultiTable(ecom, project_display_name=tmpdir) + mt.train_transforms("transform/default", only=["products"]) + + assert mt._transforms_train.models["products"] == "m1" + + project.reset_mock() + project.create_model_obj.return_value = "m2" + + # calling a second time will create a new model for the table that overwrites the original + mt.train_transforms("transform/default", only=["products"]) + assert mt._transforms_train.models["products"] == "m2" + + # The public method under test here is deprecated def test_train_transform_models(ecom, tmpdir): mt = MultiTable(ecom, project_display_name=tmpdir)