Skip to content

Commit

Permalink
Single transforms config (#109)
Browse files Browse the repository at this point in the history
* Move skip_table function to core module
* Add method to return modelable tables from a given table name
* Add new transform method that accepts a single config for all tables
  • Loading branch information
mikeknep authored May 18, 2023
1 parent 58e8891 commit 8ac8209
Show file tree
Hide file tree
Showing 6 changed files with 222 additions and 23 deletions.
30 changes: 11 additions & 19 deletions src/gretel_trainer/relational/connectors.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import logging
from typing import Dict, List, Optional, Tuple
from typing import Optional

import pandas as pd
from sqlalchemy import create_engine, inspect
from sqlalchemy.engine.base import Engine
from sqlalchemy.exc import OperationalError

from gretel_trainer.relational.core import MultiTableException, RelationalData
from gretel_trainer.relational.core import (
MultiTableException,
RelationalData,
skip_table,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -38,7 +42,7 @@ def __init__(self, engine: Engine):
logger.info("Successfully connected to db")

def extract(
self, only: Optional[List[str]] = None, ignore: Optional[List[str]] = None
self, only: Optional[list[str]] = None, ignore: Optional[list[str]] = None
) -> RelationalData:
"""
Extracts table data and relationships from the database.
Expand All @@ -50,17 +54,17 @@ def extract(
inspector = inspect(self.engine)

relational_data = RelationalData()
foreign_keys: List[Tuple[str, dict]] = []
foreign_keys: list[tuple[str, dict]] = []

for table_name in inspector.get_table_names():
if _skip_table(table_name, only, ignore):
if skip_table(table_name, only, ignore):
continue

logger.debug(f"Extracting source data from `{table_name}`")
df = pd.read_sql_table(table_name, self.engine)
primary_key = inspector.get_pk_constraint(table_name)["constrained_columns"]
for fk in inspector.get_foreign_keys(table_name):
if _skip_table(fk["referred_table"], only, ignore):
if skip_table(fk["referred_table"], only, ignore):
continue
else:
foreign_keys.append((table_name, fk))
Expand All @@ -78,25 +82,13 @@ def extract(

return relational_data

def save(self, tables: Dict[str, pd.DataFrame], prefix: str = "") -> None:
    def save(self, tables: dict[str, pd.DataFrame], prefix: str = "") -> None:
        """Write each DataFrame to the connected database.

        Table names are the dict keys, optionally prepended with ``prefix``.
        Existing tables with the same name are replaced.
        """
        for name, data in tables.items():
            data.to_sql(
                f"{prefix}{name}", con=self.engine, if_exists="replace", index=False
            )


def _skip_table(
    table: str, only: Optional[List[str]], ignore: Optional[List[str]]
) -> bool:
    """Return True if ``table`` should be excluded from extraction.

    A table is excluded when an allow-list (``only``) is given and omits it,
    or when a deny-list (``ignore``) is given and contains it.
    """
    excluded_by_only = only is not None and table not in only
    excluded_by_ignore = ignore is not None and table in ignore
    return excluded_by_only or excluded_by_ignore


def sqlite_conn(path: str) -> Connector:
    """Return a Connector backed by the SQLite database file at ``path``."""
    connection_string = f"sqlite:///{path}"
    return Connector(create_engine(connection_string))
Expand Down
25 changes: 25 additions & 0 deletions src/gretel_trainer/relational/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,19 @@ def _is_invented(self, table: str) -> bool:
and self.graph.nodes[table]["metadata"].invented_table_metadata is not None
)

def get_modelable_table_names(self, table: str) -> list[str]:
"""Returns a list of MODELABLE table names connected to the provided table.
If the provided table is already modelable, returns [table].
If the provided table is not modelable (e.g. source with JSON), returns tables invented from that source.
If the provided table does not exist, returns empty list.
"""
if (rel_json := self.relational_jsons.get(table)) is not None:
return rel_json.table_names
elif table not in self.list_all_tables(Scope.ALL):
return []
else:
return [table]

def get_public_name(self, table: str) -> Optional[str]:
if table in self.relational_jsons:
return table
Expand Down Expand Up @@ -651,6 +664,18 @@ def debug_summary(self) -> Dict[str, Any]:
}


def skip_table(
    table: str, only: Optional[list[str]], ignore: Optional[list[str]]
) -> bool:
    """Return True if ``table`` should be skipped.

    ``only`` acts as an allow-list: a table absent from it is skipped.
    ``ignore`` acts as a deny-list: a table present in it is skipped.
    A ``None`` list imposes no restriction.
    """
    if only is not None and table not in only:
        return True
    if ignore is not None and table in ignore:
        return True
    return False


def _ok_for_train_and_seed(col: str, df: pd.DataFrame) -> bool:
if _is_highly_nan(col, df):
return False
Expand Down
60 changes: 59 additions & 1 deletion src/gretel_trainer/relational/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
MultiTableException,
RelationalData,
Scope,
skip_table,
)
from gretel_trainer.relational.json import InventedTableMetadata, RelationalJson
from gretel_trainer.relational.log import silent_logs
Expand Down Expand Up @@ -556,7 +557,9 @@ def classify(self, config: GretelModelConfig, all_rows: bool = False) -> None:
self._project, str(archive_path)
)

def train_transform_models(self, configs: Dict[str, GretelModelConfig]) -> None:
def _setup_transforms_train_state(
self, configs: dict[str, GretelModelConfig]
) -> None:
for table, config in configs.items():
transform_config = make_transform_config(
self.relational_data, table, config
Expand All @@ -575,6 +578,61 @@ def train_transform_models(self, configs: Dict[str, GretelModelConfig]) -> None:

self._backup()

def train_transform_models(self, configs: dict[str, GretelModelConfig]) -> None:
"""
DEPRECATED: Please use `train_transforms` instead.
"""
logger.warning(
"This method is deprecated and will be removed in a future release. "
"Please use `train_transforms` instead."
)
use_configs = {}
for table, config in configs.items():
for m_table in self.relational_data.get_modelable_table_names(table):
use_configs[m_table] = config

self._setup_transforms_train_state(use_configs)
task = TransformsTrainTask(
transforms_train=self._transforms_train,
multitable=self,
)
run_task(task, self._extended_sdk)

    def train_transforms(
        self,
        config: GretelModelConfig,
        *,
        only: Optional[list[str]] = None,
        ignore: Optional[list[str]] = None,
    ) -> None:
        """Train transforms models on tables using a single shared config.

        Args:
            config: Transforms model config applied to every selected table.
            only: If set, restrict training to these tables (expanded to
                their modelable tables). Mutually exclusive with ``ignore``.
            ignore: If set, train all tables except these (expanded to
                their modelable tables). Mutually exclusive with ``only``.

        Raises:
            MultiTableException: If both ``only`` and ``ignore`` are given,
                or if either list contains an unrecognized table name.
        """
        if only is not None and ignore is not None:
            raise MultiTableException("Cannot specify both `only` and `ignore`.")

        # Expand user-facing names to modelable table names (a source table
        # containing JSON maps to its invented tables); an empty expansion
        # means the name is unknown.
        m_only = None
        if only is not None:
            m_only = []
            for table in only:
                m_names = self.relational_data.get_modelable_table_names(table)
                if len(m_names) == 0:
                    raise MultiTableException(f"Unrecognized table name: `{table}`")
                m_only.extend(m_names)

        m_ignore = None
        if ignore is not None:
            m_ignore = []
            for table in ignore:
                m_names = self.relational_data.get_modelable_table_names(table)
                if len(m_names) == 0:
                    raise MultiTableException(f"Unrecognized table name: `{table}`")
                m_ignore.extend(m_names)

        # Apply the one shared config to every table that survives the filters.
        configs = {
            table: config
            for table in self.relational_data.list_all_tables()
            if not skip_table(table, m_only, m_ignore)
        }

        self._setup_transforms_train_state(configs)
        task = TransformsTrainTask(
            transforms_train=self._transforms_train,
            multitable=self,
Expand Down
3 changes: 0 additions & 3 deletions tests/relational/test_relational_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import os
import tempfile

import pandas as pd
import pytest

Expand Down
17 changes: 17 additions & 0 deletions tests/relational/test_relational_data_with_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,23 @@ def test_list_tables_accepts_various_scopes(documents):
)


def test_get_modelable_table_names(documents):
    """Any known name resolves to modelable tables; unknown names to []."""
    # Given a source-with-JSON name, returns the tables invented from that source
    assert set(documents.get_modelable_table_names("purchases")) == {
        "purchases-sfx",
        "purchases-data-years-sfx",
    }

    # Invented tables are modelable
    assert documents.get_modelable_table_names("purchases-sfx") == ["purchases-sfx"]
    assert documents.get_modelable_table_names("purchases-data-years-sfx") == [
        "purchases-data-years-sfx"
    ]

    # Unknown tables return empty list
    assert documents.get_modelable_table_names("nonsense") == []


def test_invented_json_column_names(documents, bball):
# The root invented table adds columns for dictionary properties lifted from nested JSON objects
assert set(documents.get_table_columns("purchases-sfx")) == {
Expand Down
110 changes: 110 additions & 0 deletions tests/relational/test_train_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import tempfile
from unittest.mock import ANY, patch

import pytest

from gretel_trainer.relational.core import MultiTableException
from gretel_trainer.relational.multi_table import MultiTable


# The assertions in this file are concerned with setting up the transforms train
# workflow state properly, and stop short of kicking off the task.
@pytest.fixture(autouse=True)
def run_task():
    """Stub out task execution; these tests only verify train-state setup."""
    with patch("gretel_trainer.relational.multi_table.run_task"):
        yield


@pytest.fixture(autouse=True)
def backup():
    """Replace MultiTable._backup with a no-op for every test."""
    with patch.object(MultiTable, "_backup", return_value=None):
        yield


@pytest.fixture()
def tmpdir(project):
    """Yield a temporary directory path, also used as the project name."""
    with tempfile.TemporaryDirectory() as tmpdir:
        project.name = tmpdir
        yield tmpdir


def test_train_transforms_defaults_to_transforming_all_tables(ecom, tmpdir):
    """With neither `only` nor `ignore`, every table gets a transforms model."""
    multitable = MultiTable(ecom, project_display_name=tmpdir)
    multitable.train_transforms("transform/default")

    expected_tables = set(ecom.list_all_tables())
    assert set(multitable._transforms_train.models.keys()) == expected_tables


def test_train_transforms_only_includes_specified_tables(ecom, tmpdir, project):
    """`only` restricts training to exactly the named tables."""
    multitable = MultiTable(ecom, project_display_name=tmpdir)
    multitable.train_transforms("transform/default", only=["users"])

    assert set(multitable._transforms_train.models.keys()) == {"users"}
    project.create_model_obj.assert_called_with(
        model_config=ANY,  # a tailored transforms config, in dict form
        data_source=f"{tmpdir}/transforms_train_users.csv",
    )


def test_train_transforms_ignore_excludes_specified_tables(ecom, tmpdir):
    """`ignore` drops the named tables; all remaining tables are trained."""
    multitable = MultiTable(ecom, project_display_name=tmpdir)
    multitable.train_transforms(
        "transform/default", ignore=["distribution_center", "products"]
    )

    trained = set(multitable._transforms_train.models.keys())
    assert trained == {"events", "users", "order_items", "inventory_items"}


def test_train_transforms_exits_early_if_unrecognized_tables(ecom, tmpdir, project):
    """An unknown table name raises before any model state is created."""
    multitable = MultiTable(ecom, project_display_name=tmpdir)
    with pytest.raises(MultiTableException):
        multitable.train_transforms("transform/default", ignore=["nonsense"])

    assert len(multitable._transforms_train.models) == 0
    project.create_model_obj.assert_not_called()


def test_train_transforms_multiple_calls_additive(ecom, tmpdir):
    """Training different tables across separate calls accumulates models."""
    multitable = MultiTable(ecom, project_display_name=tmpdir)
    for table in ("products", "users"):
        multitable.train_transforms("transform/default", only=[table])

    # We do not lose the first table model
    assert set(multitable._transforms_train.models.keys()) == {"products", "users"}


def test_train_transforms_multiple_calls_overwrite(ecom, tmpdir, project):
    """Re-training the same table replaces its previously stored model."""
    project.create_model_obj.return_value = "m1"

    mt = MultiTable(ecom, project_display_name=tmpdir)
    mt.train_transforms("transform/default", only=["products"])

    assert mt._transforms_train.models["products"] == "m1"

    project.reset_mock()
    project.create_model_obj.return_value = "m2"

    # calling a second time will create a new model for the table that overwrites the original
    mt.train_transforms("transform/default", only=["products"])
    assert mt._transforms_train.models["products"] == "m2"


# The public method under test here is deprecated
def test_train_transform_models(ecom, tmpdir):
    """The legacy per-table-config API still trains the given tables."""
    multitable = MultiTable(ecom, project_display_name=tmpdir)
    config_map = {
        "events": "transform/default",
        "users": "transform/default",
    }
    multitable.train_transform_models(configs=config_map)

    assert set(multitable._transforms_train.models.keys()) == {"events", "users"}

0 comments on commit 8ac8209

Please sign in to comment.