From 8ac8209d5b62da71834e6ee255aa136179890b27 Mon Sep 17 00:00:00 2001
From: Mike Knepper
Date: Thu, 18 May 2023 14:46:59 -0500
Subject: [PATCH] Single transforms config (#109)

* Move skip_table function to core module

* Add method to return modelable tables from a given table name

* Add new transform method that accepts a single config for all tables
---
 src/gretel_trainer/relational/connectors.py  |  30 ++---
 src/gretel_trainer/relational/core.py        |  25 ++++
 src/gretel_trainer/relational/multi_table.py |  60 +++++++++-
 tests/relational/test_relational_data.py     |   3 -
 .../test_relational_data_with_json.py        |  17 +++
 tests/relational/test_train_transforms.py    | 110 ++++++++++++++++++
 6 files changed, 222 insertions(+), 23 deletions(-)
 create mode 100644 tests/relational/test_train_transforms.py

diff --git a/src/gretel_trainer/relational/connectors.py b/src/gretel_trainer/relational/connectors.py
index 4888a3fb..9d80a72c 100644
--- a/src/gretel_trainer/relational/connectors.py
+++ b/src/gretel_trainer/relational/connectors.py
@@ -1,12 +1,16 @@
 import logging
-from typing import Dict, List, Optional, Tuple
+from typing import Optional
 
 import pandas as pd
 from sqlalchemy import create_engine, inspect
 from sqlalchemy.engine.base import Engine
 from sqlalchemy.exc import OperationalError
 
-from gretel_trainer.relational.core import MultiTableException, RelationalData
+from gretel_trainer.relational.core import (
+    MultiTableException,
+    RelationalData,
+    skip_table,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -38,7 +42,7 @@ def __init__(self, engine: Engine):
         logger.info("Successfully connected to db")
 
     def extract(
-        self, only: Optional[List[str]] = None, ignore: Optional[List[str]] = None
+        self, only: Optional[list[str]] = None, ignore: Optional[list[str]] = None
     ) -> RelationalData:
         """
         Extracts table data and relationships from the database.
@@ -50,17 +54,17 @@ def extract(
         inspector = inspect(self.engine)
 
         relational_data = RelationalData()
-        foreign_keys: List[Tuple[str, dict]] = []
+        foreign_keys: list[tuple[str, dict]] = []
 
         for table_name in inspector.get_table_names():
-            if _skip_table(table_name, only, ignore):
+            if skip_table(table_name, only, ignore):
                 continue
 
             logger.debug(f"Extracting source data from `{table_name}`")
             df = pd.read_sql_table(table_name, self.engine)
             primary_key = inspector.get_pk_constraint(table_name)["constrained_columns"]
             for fk in inspector.get_foreign_keys(table_name):
-                if _skip_table(fk["referred_table"], only, ignore):
+                if skip_table(fk["referred_table"], only, ignore):
                     continue
                 else:
                     foreign_keys.append((table_name, fk))
@@ -78,25 +82,13 @@ def extract(
 
         return relational_data
 
-    def save(self, tables: Dict[str, pd.DataFrame], prefix: str = "") -> None:
+    def save(self, tables: dict[str, pd.DataFrame], prefix: str = "") -> None:
         for name, data in tables.items():
             data.to_sql(
                 f"{prefix}{name}", con=self.engine, if_exists="replace", index=False
             )
 
 
-def _skip_table(
-    table: str, only: Optional[List[str]], ignore: Optional[List[str]]
-) -> bool:
-    skip = False
-    if only is not None and table not in only:
-        skip = True
-    if ignore is not None and table in ignore:
-        skip = True
-
-    return skip
-
-
 def sqlite_conn(path: str) -> Connector:
     engine = create_engine(f"sqlite:///{path}")
     return Connector(engine)
diff --git a/src/gretel_trainer/relational/core.py b/src/gretel_trainer/relational/core.py
index bc3a8e4a..6053cab4 100644
--- a/src/gretel_trainer/relational/core.py
+++ b/src/gretel_trainer/relational/core.py
@@ -456,6 +456,19 @@ def _is_invented(self, table: str) -> bool:
             and self.graph.nodes[table]["metadata"].invented_table_metadata is not None
         )
 
+    def get_modelable_table_names(self, table: str) -> list[str]:
+        """Returns a list of MODELABLE table names connected to the provided table.
+        If the provided table is already modelable, returns [table].
+        If the provided table is not modelable (e.g. source with JSON), returns tables invented from that source.
+        If the provided table does not exist, returns an empty list.
+        """
+        if (rel_json := self.relational_jsons.get(table)) is not None:
+            return rel_json.table_names
+        elif table not in self.list_all_tables(Scope.ALL):
+            return []
+        else:
+            return [table]
+
     def get_public_name(self, table: str) -> Optional[str]:
         if table in self.relational_jsons:
             return table
@@ -651,6 +664,18 @@ def debug_summary(self) -> Dict[str, Any]:
         }
 
 
+def skip_table(
+    table: str, only: Optional[list[str]], ignore: Optional[list[str]]
+) -> bool:
+    skip = False
+    if only is not None and table not in only:
+        skip = True
+    if ignore is not None and table in ignore:
+        skip = True
+
+    return skip
+
+
 def _ok_for_train_and_seed(col: str, df: pd.DataFrame) -> bool:
     if _is_highly_nan(col, df):
         return False
diff --git a/src/gretel_trainer/relational/multi_table.py b/src/gretel_trainer/relational/multi_table.py
index d6c397f1..de93ab99 100644
--- a/src/gretel_trainer/relational/multi_table.py
+++ b/src/gretel_trainer/relational/multi_table.py
@@ -33,6 +33,7 @@
     MultiTableException,
     RelationalData,
     Scope,
+    skip_table,
 )
 from gretel_trainer.relational.json import InventedTableMetadata, RelationalJson
 from gretel_trainer.relational.log import silent_logs
@@ -556,7 +557,9 @@ def classify(self, config: GretelModelConfig, all_rows: bool = False) -> None:
             self._project, str(archive_path)
         )
 
-    def train_transform_models(self, configs: Dict[str, GretelModelConfig]) -> None:
+    def _setup_transforms_train_state(
+        self, configs: dict[str, GretelModelConfig]
+    ) -> None:
         for table, config in configs.items():
             transform_config = make_transform_config(
                 self.relational_data, table, config
             )
@@ -575,6 +578,61 @@ def train_transform_models(self, configs: Dict[str, GretelModelConfig]) -> None:
 
         self._backup()
 
+    def train_transform_models(self, configs: dict[str, GretelModelConfig]) -> None:
+        """
+        DEPRECATED: Please use `train_transforms` instead.
+        """
+        logger.warning(
+            "This method is deprecated and will be removed in a future release. "
+            "Please use `train_transforms` instead."
+        )
+        use_configs = {}
+        for table, config in configs.items():
+            for m_table in self.relational_data.get_modelable_table_names(table):
+                use_configs[m_table] = config
+
+        self._setup_transforms_train_state(use_configs)
+        task = TransformsTrainTask(
+            transforms_train=self._transforms_train,
+            multitable=self,
+        )
+        run_task(task, self._extended_sdk)
+
+    def train_transforms(
+        self,
+        config: GretelModelConfig,
+        *,
+        only: Optional[list[str]] = None,
+        ignore: Optional[list[str]] = None,
+    ) -> None:
+        if only is not None and ignore is not None:
+            raise MultiTableException("Cannot specify both `only` and `ignore`.")
+
+        m_only = None
+        if only is not None:
+            m_only = []
+            for table in only:
+                m_names = self.relational_data.get_modelable_table_names(table)
+                if len(m_names) == 0:
+                    raise MultiTableException(f"Unrecognized table name: `{table}`")
+                m_only.extend(m_names)
+
+        m_ignore = None
+        if ignore is not None:
+            m_ignore = []
+            for table in ignore:
+                m_names = self.relational_data.get_modelable_table_names(table)
+                if len(m_names) == 0:
+                    raise MultiTableException(f"Unrecognized table name: `{table}`")
+                m_ignore.extend(m_names)
+
+        configs = {
+            table: config
+            for table in self.relational_data.list_all_tables()
+            if not skip_table(table, m_only, m_ignore)
+        }
+
+        self._setup_transforms_train_state(configs)
         task = TransformsTrainTask(
             transforms_train=self._transforms_train,
             multitable=self,
diff --git a/tests/relational/test_relational_data.py b/tests/relational/test_relational_data.py
index 755bf8b1..cf5141b0 100644
--- a/tests/relational/test_relational_data.py
+++ b/tests/relational/test_relational_data.py
@@ -1,6 +1,3 @@
-import os
-import tempfile
-
 import pandas as pd
 import pytest
 
diff --git a/tests/relational/test_relational_data_with_json.py b/tests/relational/test_relational_data_with_json.py
index 9ef30c60..ce2b91a3 100644
--- a/tests/relational/test_relational_data_with_json.py
+++ b/tests/relational/test_relational_data_with_json.py
@@ -108,6 +108,23 @@ def test_list_tables_accepts_various_scopes(documents):
     )
 
 
+def test_get_modelable_table_names(documents):
+    # Given a source-with-JSON name, returns the tables invented from that source
+    assert set(documents.get_modelable_table_names("purchases")) == {
+        "purchases-sfx",
+        "purchases-data-years-sfx",
+    }
+
+    # Invented tables are modelable
+    assert documents.get_modelable_table_names("purchases-sfx") == ["purchases-sfx"]
+    assert documents.get_modelable_table_names("purchases-data-years-sfx") == [
+        "purchases-data-years-sfx"
+    ]
+
+    # Unknown tables return empty list
+    assert documents.get_modelable_table_names("nonsense") == []
+
+
 def test_invented_json_column_names(documents, bball):
     # The root invented table adds columns for dictionary properties lifted from nested JSON objects
     assert set(documents.get_table_columns("purchases-sfx")) == {
diff --git a/tests/relational/test_train_transforms.py b/tests/relational/test_train_transforms.py
new file mode 100644
index 00000000..4c93cdbd
--- /dev/null
+++ b/tests/relational/test_train_transforms.py
@@ -0,0 +1,110 @@
+import tempfile
+from unittest.mock import ANY, patch
+
+import pytest
+
+from gretel_trainer.relational.core import MultiTableException
+from gretel_trainer.relational.multi_table import MultiTable
+
+
+# The assertions in this file are concerned with setting up the transforms train
+# workflow state properly, and stop short of kicking off the task.
+@pytest.fixture(autouse=True)
+def run_task():
+    with patch("gretel_trainer.relational.multi_table.run_task"):
+        yield
+
+
+@pytest.fixture(autouse=True)
+def backup():
+    with patch.object(MultiTable, "_backup", return_value=None):
+        yield
+
+
+@pytest.fixture()
+def tmpdir(project):
+    with tempfile.TemporaryDirectory() as tmpdir:
+        project.name = tmpdir
+        yield tmpdir
+
+
+def test_train_transforms_defaults_to_transforming_all_tables(ecom, tmpdir):
+    mt = MultiTable(ecom, project_display_name=tmpdir)
+    mt.train_transforms("transform/default")
+    transforms_train = mt._transforms_train
+
+    assert set(transforms_train.models.keys()) == set(ecom.list_all_tables())
+
+
+def test_train_transforms_only_includes_specified_tables(ecom, tmpdir, project):
+    mt = MultiTable(ecom, project_display_name=tmpdir)
+    mt.train_transforms("transform/default", only=["users"])
+    transforms_train = mt._transforms_train
+
+    assert set(transforms_train.models.keys()) == {"users"}
+    project.create_model_obj.assert_called_with(
+        model_config=ANY,  # a tailored transforms config, in dict form
+        data_source=f"{tmpdir}/transforms_train_users.csv",
+    )
+
+
+def test_train_transforms_ignore_excludes_specified_tables(ecom, tmpdir):
+    mt = MultiTable(ecom, project_display_name=tmpdir)
+    mt.train_transforms("transform/default", ignore=["distribution_center", "products"])
+    transforms_train = mt._transforms_train
+
+    assert set(transforms_train.models.keys()) == {
+        "events",
+        "users",
+        "order_items",
+        "inventory_items",
+    }
+
+
+def test_train_transforms_exits_early_if_unrecognized_tables(ecom, tmpdir, project):
+    mt = MultiTable(ecom, project_display_name=tmpdir)
+    with pytest.raises(MultiTableException):
+        mt.train_transforms("transform/default", ignore=["nonsense"])
+    transforms_train = mt._transforms_train
+
+    assert len(transforms_train.models) == 0
+    project.create_model_obj.assert_not_called()
+
+
+def test_train_transforms_multiple_calls_additive(ecom, tmpdir):
+    mt = MultiTable(ecom, project_display_name=tmpdir)
+    mt.train_transforms("transform/default", only=["products"])
+    mt.train_transforms("transform/default", only=["users"])
+
+    # We do not lose the first table model
+    assert set(mt._transforms_train.models.keys()) == {"products", "users"}
+
+
+def test_train_transforms_multiple_calls_overwrite(ecom, tmpdir, project):
+    project.create_model_obj.return_value = "m1"
+
+    mt = MultiTable(ecom, project_display_name=tmpdir)
+    mt.train_transforms("transform/default", only=["products"])
+
+    assert mt._transforms_train.models["products"] == "m1"
+
+    project.reset_mock()
+    project.create_model_obj.return_value = "m2"
+
+    # calling a second time will create a new model for the table that overwrites the original
+    mt.train_transforms("transform/default", only=["products"])
+    assert mt._transforms_train.models["products"] == "m2"
+
+
+# The public method under test here is deprecated
+def test_train_transform_models(ecom, tmpdir):
+    mt = MultiTable(ecom, project_display_name=tmpdir)
+    mt.train_transform_models(
+        configs={
+            "events": "transform/default",
+            "users": "transform/default",
+        }
+    )
+    transforms_train = mt._transforms_train
+
+    assert set(transforms_train.models.keys()) == {"events", "users"}
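
Usage sketch (not part of the patch itself): how the new single-config API is expected to be called, assuming a configured Gretel session. The `ecom.db` path and `transforms-demo` display name below are illustrative; `transform/default` is the config string used in the tests above, and the table names come from the `ecom` test fixture.

    from gretel_trainer.relational.connectors import sqlite_conn
    from gretel_trainer.relational.multi_table import MultiTable

    # Extract source tables and relationships from a database (connectors.py).
    connector = sqlite_conn("ecom.db")  # illustrative path
    relational_data = connector.extract()

    mt = MultiTable(relational_data, project_display_name="transforms-demo")

    # New in this patch: one config applied across tables, scoped with `only` /
    # `ignore`. The two arguments are mutually exclusive, and unrecognized table
    # names raise MultiTableException; source tables containing JSON are expanded
    # to their invented modelable tables via RelationalData.get_modelable_table_names.
    mt.train_transforms("transform/default", ignore=["products"])

    # Repeat calls add models for newly selected tables and overwrite any
    # existing model for a table that is selected again.
    mt.train_transforms("transform/default", only=["users"])

    # Deprecated per-table form; still works, but logs a deprecation warning.
    mt.train_transform_models(configs={"users": "transform/default"})

Scoping with `only`/`ignore` mirrors `Connector.extract`, which now shares the same `skip_table` helper moved into `core`.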