Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Single transforms config #109

Merged
merged 9 commits into from
May 18, 2023
30 changes: 11 additions & 19 deletions src/gretel_trainer/relational/connectors.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import logging
from typing import Dict, List, Optional, Tuple
from typing import Optional

import pandas as pd
from sqlalchemy import create_engine, inspect
from sqlalchemy.engine.base import Engine
from sqlalchemy.exc import OperationalError

from gretel_trainer.relational.core import MultiTableException, RelationalData
from gretel_trainer.relational.core import (
MultiTableException,
RelationalData,
skip_table,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -38,7 +42,7 @@ def __init__(self, engine: Engine):
logger.info("Successfully connected to db")

def extract(
self, only: Optional[List[str]] = None, ignore: Optional[List[str]] = None
self, only: Optional[list[str]] = None, ignore: Optional[list[str]] = None
) -> RelationalData:
"""
Extracts table data and relationships from the database.
Expand All @@ -50,17 +54,17 @@ def extract(
inspector = inspect(self.engine)

relational_data = RelationalData()
foreign_keys: List[Tuple[str, dict]] = []
foreign_keys: list[tuple[str, dict]] = []

for table_name in inspector.get_table_names():
if _skip_table(table_name, only, ignore):
if skip_table(table_name, only, ignore):
continue

logger.debug(f"Extracting source data from `{table_name}`")
df = pd.read_sql_table(table_name, self.engine)
primary_key = inspector.get_pk_constraint(table_name)["constrained_columns"]
for fk in inspector.get_foreign_keys(table_name):
if _skip_table(fk["referred_table"], only, ignore):
if skip_table(fk["referred_table"], only, ignore):
continue
else:
foreign_keys.append((table_name, fk))
Expand All @@ -78,25 +82,13 @@ def extract(

return relational_data

def save(self, tables: Dict[str, pd.DataFrame], prefix: str = "") -> None:
def save(self, tables: dict[str, pd.DataFrame], prefix: str = "") -> None:
    """Write each dataframe to the database as `{prefix}{name}`.

    Any existing table with the same name is dropped and replaced.
    """
    for table_name, dataframe in tables.items():
        destination = f"{prefix}{table_name}"
        dataframe.to_sql(
            destination,
            con=self.engine,
            if_exists="replace",
            index=False,
        )


def _skip_table(
table: str, only: Optional[List[str]], ignore: Optional[List[str]]
) -> bool:
skip = False
if only is not None and table not in only:
skip = True
if ignore is not None and table in ignore:
skip = True

return skip


def sqlite_conn(path: str) -> Connector:
    """Build a Connector backed by the SQLite database file at `path`."""
    return Connector(create_engine(f"sqlite:///{path}"))
Expand Down
25 changes: 25 additions & 0 deletions src/gretel_trainer/relational/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,19 @@ def _is_invented(self, table: str) -> bool:
and self.graph.nodes[table]["metadata"].invented_table_metadata is not None
)

def get_modelable_table_names(self, table: str) -> list[str]:
    """Resolve `table` to the MODELABLE table names it corresponds to.

    A source table that was flattened from JSON resolves to the tables
    invented from it; a table that is itself modelable resolves to a
    single-element list containing itself; an unknown name resolves to
    an empty list.
    """
    rel_json = self.relational_jsons.get(table)
    if rel_json is not None:
        return rel_json.table_names
    if table in self.list_all_tables(Scope.ALL):
        return [table]
    return []

def get_public_name(self, table: str) -> Optional[str]:
if table in self.relational_jsons:
return table
Expand Down Expand Up @@ -651,6 +664,18 @@ def debug_summary(self) -> Dict[str, Any]:
}


def skip_table(
    table: str, only: Optional[list[str]], ignore: Optional[list[str]]
) -> bool:
    """Return True when `table` should be excluded from processing.

    `only` acts as an allowlist (anything not in it is skipped) and
    `ignore` acts as a denylist (anything in it is skipped).
    """
    if only is not None and table not in only:
        return True
    return ignore is not None and table in ignore


def _ok_for_train_and_seed(col: str, df: pd.DataFrame) -> bool:
if _is_highly_nan(col, df):
return False
Expand Down
60 changes: 59 additions & 1 deletion src/gretel_trainer/relational/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
MultiTableException,
RelationalData,
Scope,
skip_table,
)
from gretel_trainer.relational.json import InventedTableMetadata, RelationalJson
from gretel_trainer.relational.log import silent_logs
Expand Down Expand Up @@ -556,7 +557,9 @@ def classify(self, config: GretelModelConfig, all_rows: bool = False) -> None:
self._project, str(archive_path)
)

def train_transform_models(self, configs: Dict[str, GretelModelConfig]) -> None:
def _setup_transforms_train_state(
self, configs: dict[str, GretelModelConfig]
) -> None:
for table, config in configs.items():
transform_config = make_transform_config(
self.relational_data, table, config
Expand All @@ -575,6 +578,61 @@ def train_transform_models(self, configs: Dict[str, GretelModelConfig]) -> None:

self._backup()

def train_transform_models(self, configs: dict[str, GretelModelConfig]) -> None:
    """
    DEPRECATED: Please use `train_transforms` instead.
    """
    logger.warning(
        "This method is deprecated and will be removed in a future release. "
        "Please use `train_transforms` instead."
    )
    # Expand each provided table name to its modelable table(s), applying
    # the same config to every table a source expands into.
    expanded_configs = {
        m_table: config
        for table, config in configs.items()
        for m_table in self.relational_data.get_modelable_table_names(table)
    }

    self._setup_transforms_train_state(expanded_configs)
    task = TransformsTrainTask(
        transforms_train=self._transforms_train,
        multitable=self,
    )
    run_task(task, self._extended_sdk)

def train_transforms(
self,
config: GretelModelConfig,
*,
only: Optional[list[str]] = None,
ignore: Optional[list[str]] = None,
) -> None:
if only is not None and ignore is not None:
raise MultiTableException("Cannot specify both `only` and `ignore`.")

m_only = None
if only is not None:
m_only = []
for table in only:
m_names = self.relational_data.get_modelable_table_names(table)
if len(m_names) == 0:
raise MultiTableException(f"Unrecognized table name: `{table}`")
m_only.extend(m_names)

m_ignore = None
if ignore is not None:
m_ignore = []
for table in ignore:
m_names = self.relational_data.get_modelable_table_names(table)
if len(m_names) == 0:
raise MultiTableException(f"Unrecognized table name: `{table}`")
m_ignore.extend(m_names)

configs = {
table: config
for table in self.relational_data.list_all_tables()
if not skip_table(table, m_only, m_ignore)
}

self._setup_transforms_train_state(configs)
task = TransformsTrainTask(
transforms_train=self._transforms_train,
multitable=self,
Expand Down
3 changes: 0 additions & 3 deletions tests/relational/test_relational_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import os
import tempfile

import pandas as pd
import pytest

Expand Down
17 changes: 17 additions & 0 deletions tests/relational/test_relational_data_with_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,23 @@ def test_list_tables_accepts_various_scopes(documents):
)


def test_get_modelable_table_names(documents):
    # A source table flattened from JSON resolves to its invented tables
    invented = documents.get_modelable_table_names("purchases")
    assert set(invented) == {"purchases-sfx", "purchases-data-years-sfx"}

    # Invented tables are themselves modelable and resolve to themselves
    for name in ("purchases-sfx", "purchases-data-years-sfx"):
        assert documents.get_modelable_table_names(name) == [name]

    # A name matching no known table resolves to nothing
    assert documents.get_modelable_table_names("nonsense") == []


def test_invented_json_column_names(documents, bball):
# The root invented table adds columns for dictionary properties lifted from nested JSON objects
assert set(documents.get_table_columns("purchases-sfx")) == {
Expand Down
85 changes: 85 additions & 0 deletions tests/relational/test_train_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import tempfile
from unittest.mock import ANY, patch

import pytest

from gretel_trainer.relational.core import MultiTableException
from gretel_trainer.relational.multi_table import MultiTable


# The assertions in this file are concerned with setting up the transforms train
# workflow state properly, and stop short of kicking off the task.
@pytest.fixture(autouse=True)
def run_task():
    """Stub out task execution; these tests only verify training setup."""
    patcher = patch("gretel_trainer.relational.multi_table.run_task")
    patcher.start()
    try:
        yield
    finally:
        patcher.stop()


@pytest.fixture(autouse=True)
def backup():
    """Disable state backups so tests never touch backup storage."""
    patcher = patch.object(MultiTable, "_backup", return_value=None)
    patcher.start()
    try:
        yield
    finally:
        patcher.stop()


@pytest.fixture()
def tmpdir(project):
    """Yield a temporary directory and point the project name at it."""
    with tempfile.TemporaryDirectory() as working_dir:
        project.name = working_dir
        yield working_dir


def test_train_transforms_defaults_to_transforming_all_tables(ecom, tmpdir):
    multitable = MultiTable(ecom, project_display_name=tmpdir)
    multitable.train_transforms("transform/default")

    trained_tables = set(multitable._transforms_train.models.keys())
    assert trained_tables == set(ecom.list_all_tables())


def test_train_transforms_only_includes_specified_tables(ecom, tmpdir, project):
    mt = MultiTable(ecom, project_display_name=tmpdir)
    mt.train_transforms("transform/default", only=["users"])
    transforms_train = mt._transforms_train

    # Only the requested table gets a transforms model
    assert set(transforms_train.models.keys()) == {"users"}
    project.create_model_obj.assert_called_with(
        model_config=ANY,  # a tailored transforms config, in dict form
        data_source=f"{tmpdir}/transforms_train_users.csv",
    )


def test_train_transforms_ignore_excludes_specified_tables(ecom, tmpdir):
    multitable = MultiTable(ecom, project_display_name=tmpdir)
    multitable.train_transforms(
        "transform/default", ignore=["distribution_center", "products"]
    )

    # Every table except the ignored ones gets a transforms model
    expected_tables = {"events", "users", "order_items", "inventory_items"}
    assert set(multitable._transforms_train.models.keys()) == expected_tables


def test_train_transforms_exits_early_if_unrecognized_tables(ecom, tmpdir, project):
    multitable = MultiTable(ecom, project_display_name=tmpdir)

    with pytest.raises(MultiTableException):
        multitable.train_transforms("transform/default", ignore=["nonsense"])

    # No training state was created and no models were submitted
    assert len(multitable._transforms_train.models) == 0
    project.create_model_obj.assert_not_called()


# The public method under test here is deprecated
def test_train_transform_models(ecom, tmpdir):
    multitable = MultiTable(ecom, project_display_name=tmpdir)
    multitable.train_transform_models(
        configs={table: "transform/default" for table in ("events", "users")}
    )

    assert set(multitable._transforms_train.models.keys()) == {"events", "users"}