Single transforms config #109

Merged
9 commits merged on May 18, 2023
30 changes: 11 additions & 19 deletions src/gretel_trainer/relational/connectors.py
@@ -1,12 +1,16 @@
 import logging
-from typing import Dict, List, Optional, Tuple
+from typing import Optional

 import pandas as pd
 from sqlalchemy import create_engine, inspect
 from sqlalchemy.engine.base import Engine
 from sqlalchemy.exc import OperationalError

-from gretel_trainer.relational.core import MultiTableException, RelationalData
+from gretel_trainer.relational.core import (
+    MultiTableException,
+    RelationalData,
+    skip_table,
+)

 logger = logging.getLogger(__name__)
@@ -38,7 +42,7 @@ def __init__(self, engine: Engine):
         logger.info("Successfully connected to db")

     def extract(
-        self, only: Optional[List[str]] = None, ignore: Optional[List[str]] = None
+        self, only: Optional[list[str]] = None, ignore: Optional[list[str]] = None
     ) -> RelationalData:
         """
         Extracts table data and relationships from the database.
@@ -50,17 +54,17 @@ def extract(
         inspector = inspect(self.engine)

         relational_data = RelationalData()
-        foreign_keys: List[Tuple[str, dict]] = []
+        foreign_keys: list[tuple[str, dict]] = []

         for table_name in inspector.get_table_names():
-            if _skip_table(table_name, only, ignore):
+            if skip_table(table_name, only, ignore):
                 continue

             logger.debug(f"Extracting source data from `{table_name}`")
             df = pd.read_sql_table(table_name, self.engine)
             primary_key = inspector.get_pk_constraint(table_name)["constrained_columns"]
             for fk in inspector.get_foreign_keys(table_name):
-                if _skip_table(fk["referred_table"], only, ignore):
+                if skip_table(fk["referred_table"], only, ignore):
                     continue
                 else:
                     foreign_keys.append((table_name, fk))
@@ -78,25 +82,13 @@

         return relational_data

-    def save(self, tables: Dict[str, pd.DataFrame], prefix: str = "") -> None:
+    def save(self, tables: dict[str, pd.DataFrame], prefix: str = "") -> None:
         for name, data in tables.items():
             data.to_sql(
                 f"{prefix}{name}", con=self.engine, if_exists="replace", index=False
             )


-def _skip_table(
-    table: str, only: Optional[List[str]], ignore: Optional[List[str]]
-) -> bool:
-    skip = False
-    if only is not None and table not in only:
-        skip = True
-    if ignore is not None and table in ignore:
-        skip = True
-
-    return skip
-
-
 def sqlite_conn(path: str) -> Connector:
     engine = create_engine(f"sqlite:///{path}")
     return Connector(engine)
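For orientation, a minimal sketch of how the connector API fits together after this change. The database path and table names below are placeholders for illustration, not part of this PR:

from gretel_trainer.relational.connectors import sqlite_conn

# Hypothetical SQLite database; "audit_log" and the prefix are example values.
connector = sqlite_conn("example.db")

# `only` and `ignore` filter tables by name during extraction; foreign keys
# that point at skipped tables are skipped as well.
relational_data = connector.extract(ignore=["audit_log"])

# Output tables can later be written back with a name prefix to avoid
# overwriting the source tables:
# connector.save(output_tables, prefix="synth_")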
22 changes: 22 additions & 0 deletions src/gretel_trainer/relational/core.py
@@ -456,6 +456,16 @@ def _is_invented(self, table: str) -> bool:
             and self.graph.nodes[table]["metadata"].invented_table_metadata is not None
         )

+    def get_modelable_table_names(self, table: str) -> list[str]:
+        """Returns a list of MODELABLE table names connected to the provided table.
+        If the provided table is already modelable, returns [table].
+        If the provided table is not modelable (e.g. source with JSON), returns tables invented from that source.
+        """
+        if (rel_json := self.relational_jsons.get(table)) is not None:
+            return rel_json.table_names
+        else:
+            return [table]
+
     def get_public_name(self, table: str) -> Optional[str]:
         if table in self.relational_jsons:
             return table
@@ -651,6 +661,18 @@ def debug_summary(self) -> Dict[str, Any]:
         }


+def skip_table(
+    table: str, only: Optional[list[str]], ignore: Optional[list[str]]
+) -> bool:
+    skip = False
+    if only is not None and table not in only:
+        skip = True
+    if ignore is not None and table in ignore:
+        skip = True
+
+    return skip
+
+
 def _ok_for_train_and_seed(col: str, df: pd.DataFrame) -> bool:
     if _is_highly_nan(col, df):
         return False
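A quick sketch of the `skip_table` helper's semantics now that it is public in core.py; the table names are hypothetical:

from gretel_trainer.relational.core import skip_table

# With neither filter set, nothing is skipped.
assert skip_table("users", None, None) is False

# `only` acts as an allow-list: tables not named in it are skipped.
assert skip_table("events", ["users"], None) is True

# `ignore` acts as a deny-list: tables named in it are skipped.
assert skip_table("users", None, ["users"]) is True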
67 changes: 67 additions & 0 deletions src/gretel_trainer/relational/multi_table.py
@@ -33,6 +33,7 @@
     MultiTableException,
     RelationalData,
     Scope,
+    skip_table,
 )
 from gretel_trainer.relational.json import InventedTableMetadata, RelationalJson
 from gretel_trainer.relational.log import silent_logs
@@ -557,7 +558,73 @@ def classify(self, config: GretelModelConfig, all_rows: bool = False) -> None:
         )

     def train_transform_models(self, configs: Dict[str, GretelModelConfig]) -> None:
+        """
+        DEPRECATED: Please use `train_transforms` instead.
+        """
+        logger.warning(
+            "This method is deprecated and will be removed in a future release. "
+            "Please use `train_transforms` instead."
+        )
+        use_configs = {}
+        for table, config in configs.items():
+            for m_table in self.relational_data.get_modelable_table_names(table):
+                use_configs[m_table] = config
+
+        for table, config in use_configs.items():
+            transform_config = make_transform_config(
+                self.relational_data, table, config
+            )
+
+            # Ensure consistent, friendly data source names in Console
+            table_data = self.relational_data.get_table_data(table)
+            transforms_train_path = self._working_dir / f"transforms_train_{table}.csv"
+            table_data.to_csv(transforms_train_path, index=False)
+
+            # Create model
+            model = self._project.create_model_obj(
+                model_config=transform_config, data_source=str(transforms_train_path)
+            )
+            self._transforms_train.models[table] = model
+
+        self._backup()
+
+        task = TransformsTrainTask(
+            transforms_train=self._transforms_train,
+            multitable=self,
+        )
+        run_task(task, self._extended_sdk)

+    def train_transforms(
+        self,
+        config: GretelModelConfig,
+        *,
+        only: Optional[list[str]] = None,
+        ignore: Optional[list[str]] = None,
+    ) -> None:
+        if only is not None and ignore is not None:
+            raise MultiTableException("Cannot specify both `only` and `ignore`.")
+
+        if only is not None:
+            only = [
+                modelable_name
+                for table in only
+                for modelable_name in self.relational_data.get_modelable_table_names(
+                    table
+                )
+            ]
+        if ignore is not None:
+            ignore = [
+                modelable_name
+                for table in ignore
+                for modelable_name in self.relational_data.get_modelable_table_names(
+                    table
+                )
+            ]
+
+        for table in self.relational_data.list_all_tables():
+            if skip_table(table, only, ignore):
+                continue
+
             transform_config = make_transform_config(
                 self.relational_data, table, config
             )
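For context, a hedged usage sketch of the new single-config `train_transforms` API added above. The `MultiTable` setup, the config file path, and the table name are placeholders and not part of this diff:

# Assumes `relational_data` was built elsewhere, e.g. via Connector.extract().
mt = MultiTable(relational_data)

# One transforms config applied to every modelable table...
mt.train_transforms("transforms_config.yml")

# ...or to a filtered subset. `only` and `ignore` are mutually exclusive.
mt.train_transforms("transforms_config.yml", ignore=["users"])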
3 changes: 0 additions & 3 deletions tests/relational/test_relational_data.py
@@ -1,6 +1,3 @@
-import os
-import tempfile
-
 import pandas as pd
 import pytest
14 changes: 14 additions & 0 deletions tests/relational/test_relational_data_with_json.py
@@ -108,6 +108,20 @@ def test_list_tables_accepts_various_scopes(documents):
     )


+def test_get_modelable_table_names(documents):
+    # Given a source-with-JSON name, returns the tables invented from that source
+    assert set(documents.get_modelable_table_names("purchases")) == {
+        "purchases-sfx",
+        "purchases-data-years-sfx",
+    }
+
+    # Invented tables are modelable
+    assert documents.get_modelable_table_names("purchases-sfx") == ["purchases-sfx"]
+    assert documents.get_modelable_table_names("purchases-data-years-sfx") == [
+        "purchases-data-years-sfx"
+    ]
+
+
 def test_invented_json_column_names(documents, bball):
     # The root invented table adds columns for dictionary properties lifted from nested JSON objects
     assert set(documents.get_table_columns("purchases-sfx")) == {