Skip to content

Commit

Permalink
Reduce unnecessary JSON parsing (#112)
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeknep authored May 25, 2023
1 parent 7047e94 commit e89359a
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 32 deletions.
23 changes: 16 additions & 7 deletions src/gretel_trainer/relational/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
IngestResponseT,
InventedTableMetadata,
RelationalJson,
get_json_columns,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -129,12 +130,21 @@ def add_table(
the table includes nested JSON data.
"""
primary_key = self._format_key_column(primary_key)
rj_ingest = RelationalJson.ingest(name, primary_key, data)
if rj_ingest is not None:
if (rj_ingest := self._check_for_json(name, primary_key, data)) is not None:
self._add_rel_json_and_tables(name, rj_ingest)
else:
self._add_single_table(name=name, primary_key=primary_key, data=data)

def _check_for_json(
    self,
    table: str,
    primary_key: list[str],
    data: pd.DataFrame,
) -> Optional[IngestResponseT]:
    """Ingest `data` as relational JSON if any of its columns hold JSON.

    Runs the cheap column sniff first so the (expensive) ingest is only
    attempted when JSON columns are actually present.

    Returns the ingest response when JSON columns are detected, or None
    when the table is flat. Note RelationalJson.ingest may itself return
    None, so a non-None json_cols does not guarantee a non-None result.
    """
    json_cols = get_json_columns(data)
    if not json_cols:
        return None
    return RelationalJson.ingest(table, primary_key, data, json_cols)

def _add_rel_json_and_tables(self, table: str, rj_ingest: IngestResponseT) -> None:
rel_json, commands = rj_ingest
tables, foreign_keys = commands
Expand Down Expand Up @@ -374,8 +384,9 @@ def remove_foreign_key_constraint(
def update_table_data(self, table: str, data: pd.DataFrame) -> None:
if table in self.relational_jsons:
_, original_pk, original_fks = self._remove_relational_json(table)
new_rj_ingest = RelationalJson.ingest(table, original_pk, data)
if new_rj_ingest is not None:
if (
new_rj_ingest := self._check_for_json(table, original_pk, data)
) is not None:
self._add_rel_json_and_tables(table, new_rj_ingest)
parent_table_name = new_rj_ingest[0].root_table_name
else:
Expand All @@ -401,9 +412,7 @@ def update_table_data(self, table: str, data: pd.DataFrame) -> None:
)

if (
new_rj_ingest := RelationalJson.ingest(
table, metadata.primary_key, data
)
new_rj_ingest := self._check_for_json(table, metadata.primary_key, data)
) is not None:
original_foreign_keys = self._get_user_defined_fks_to_table(table)
self.graph.remove_node(table)
Expand Down
68 changes: 43 additions & 25 deletions src/gretel_trainer/relational/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

logger = logging.getLogger(__name__)

# Number of non-null values sampled per column when sniffing for JSON
# content (see get_json_columns) — a heuristic, not a full-column scan.
PREVIEW_ROW_COUNT = 5

# JSON dict to multi-column and list to multi-table

FIELD_SEPARATOR = ">"
Expand Down Expand Up @@ -60,41 +62,39 @@ def nulls_to_empty_lists(series: pd.Series) -> pd.Series:


def _normalize_json(
    nested_dfs: list[tuple[str, pd.DataFrame]],
    flat_dfs: list[tuple[str, pd.DataFrame]],
    columns: Optional[list[str]] = None,
) -> list[tuple[str, pd.DataFrame]]:
    """Recursively flatten JSON content in `nested_dfs` into flat tables.

    Dict-valued columns are expanded into prefixed scalar columns on the
    same table; list-valued columns are exploded into new child tables.
    `columns`, when provided, restricts the scan to known JSON columns
    (an optimization for the first pass); recursive calls fall back to
    scanning all object-dtype columns of the invented tables.

    Returns the accumulated list of (name, DataFrame) flat tables.
    """
    if not nested_dfs:
        return flat_dfs
    name, df = nested_dfs.pop()
    # `columns` applies only to the table it was computed for; otherwise
    # scan every object-dtype column.
    cols_to_scan = columns or [col for col in df.columns if df.dtypes[col] == "object"]
    dict_cols = [col for col in cols_to_scan if df[col].dropna().apply(is_dict).all()]
    if dict_cols:
        df[dict_cols] = nulls_to_empty_dicts(df[dict_cols])
        for col in dict_cols:
            new_cols = pandas_json_normalize(df[col]).add_prefix(col + FIELD_SEPARATOR)
            df = pd.concat([df, new_cols], axis="columns")
            # Drop expanded columns that turned out to be entirely null.
            df = df.drop(columns=new_cols.columns[new_cols.isnull().all()])
        # Re-queue: the expanded columns may themselves contain JSON.
        nested_dfs.append((name, df.drop(columns=dict_cols)))
    else:
        list_cols = [col for col in cols_to_scan if df[col].dropna().apply(is_list).all()]
        if list_cols:
            for col in list_cols:
                # Each list column becomes a child table: one row per
                # element, with order preserved and a back-reference to
                # the parent row's index.
                new_table = df[col].explode().dropna().rename(CONTENT_COLUMN).to_frame()
                new_table[ORDER_COLUMN] = new_table.groupby(level=0).cumcount()
                nested_dfs.append(
                    (
                        name + TABLE_SEPARATOR + col,
                        new_table.reset_index(names=name + ID_SUFFIX),
                    )
                )
            nested_dfs.append((name, df.drop(columns=list_cols)))
        else:
            flat_dfs.append((name, df))
    return _normalize_json(nested_dfs, flat_dfs)


Expand Down Expand Up @@ -167,10 +167,14 @@ def __init__(

@classmethod
def ingest(
cls, table_name: str, primary_key: list[str], df: pd.DataFrame
cls,
table_name: str,
primary_key: list[str],
df: pd.DataFrame,
json_columns: Optional[list[str]] = None,
) -> Optional[IngestResponseT]:
logger.debug(f"Checking table `{table_name}` for JSON columns")
tables = _normalize_json([(table_name, df.copy())], [])
tables = _normalize_json([(table_name, df.copy())], [], json_columns)
# If we created additional tables (from JSON lists) or added columns (from JSON dicts)
if len(tables) > 1 or len(tables[0][1].columns) > len(df.columns):
mappings = {name: sanitize_str(name) for name, _ in tables}
Expand Down Expand Up @@ -336,5 +340,19 @@ def _generate_commands(
return (_add_single_table, add_foreign_key)


def get_json_columns(df: pd.DataFrame) -> list[str]:
    """Return names of object-dtype columns that appear to contain JSON.

    Only the first PREVIEW_ROW_COUNT non-null values of each candidate
    column are inspected, so this is a cheap heuristic rather than a
    full-column scan. A column qualifies when every previewed value is a
    dict, or every previewed value is a list.
    """
    column_previews = {
        col: df[col].dropna().head(PREVIEW_ROW_COUNT)
        for col in df.columns
        if df.dtypes[col] == "object"
    }

    return [
        col
        for col, series in column_previews.items()
        # Guard against empty previews: Series.all() is vacuously True on
        # an empty Series, which would misclassify all-null columns as JSON.
        if not series.empty
        and (series.apply(is_dict).all() or series.apply(is_list).all())
    ]


# (invented table dicts, foreign key dicts) — the commands to apply when
# registering a flattened JSON table (unpacked as `tables, foreign_keys`).
CommandsT = tuple[list[dict], list[dict]]
# Result of RelationalJson.ingest: the RelationalJson plus its commands
# (unpacked as `rel_json, commands`).
IngestResponseT = tuple[RelationalJson, CommandsT]
35 changes: 35 additions & 0 deletions tests/relational/test_relational_data_with_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
RelationalData,
Scope,
)
from gretel_trainer.relational.json import get_json_columns


@pytest.fixture
Expand All @@ -24,6 +25,13 @@ def bball():
return rel_data


def test_list_json_cols(documents, bball):
    # Flat tables report no JSON columns; tables with nested data report
    # each candidate column.
    users = documents.get_table_data("users")
    purchases = documents.get_table_data("purchases")
    assert get_json_columns(users) == []
    assert get_json_columns(purchases) == ["data"]

    bball_json_cols = get_json_columns(bball.get_table_data("bball"))
    assert set(bball_json_cols) == {"draft", "teams"}


def test_json_columns_produce_invented_flattened_tables(documents):
pdtest.assert_frame_equal(
documents.get_table_data("purchases-sfx"),
Expand Down Expand Up @@ -751,6 +759,33 @@ def test_lists_of_lists():
)


def test_mix_of_dict_and_list_cols():
    """A table mixing dict- and list-valued columns flattens both kinds."""
    source = pd.DataFrame(
        data={
            "id": [1, 2],
            "dcol": [{"language": "english"}, {"language": "spanish"}],
            "lcol": [["a", "b"], ["c", "d"]],
        }
    )
    rel_data = RelationalData()
    rel_data.add_table(name="mix", primary_key=None, data=source)

    # The dict column stays on the root table; the list column becomes a
    # child table.
    assert set(rel_data.list_all_tables()) == {"mix-sfx", "mix-lcol-sfx"}

    root_columns = set(rel_data.get_table_data("mix-sfx").columns)
    assert root_columns == {"id", "~PRIMARY_KEY_ID~", "dcol>language"}

    child_columns = set(rel_data.get_table_data("mix-lcol-sfx").columns)
    assert child_columns == {"~PRIMARY_KEY_ID~", "content", "array~order", "mix~id"}


def test_all_tables_are_present_in_debug_summary(documents):
assert documents.debug_summary() == {
"foreign_key_count": 4,
Expand Down

0 comments on commit e89359a

Please sign in to comment.