From e89359ae3039cd3bc2fdd47a8751a2a2b535c409 Mon Sep 17 00:00:00 2001 From: Mike Knepper Date: Thu, 25 May 2023 14:39:34 -0500 Subject: [PATCH] Reduce unnecessary JSON parsing (#112) --- src/gretel_trainer/relational/core.py | 23 +++++-- src/gretel_trainer/relational/json.py | 68 ++++++++++++------- .../test_relational_data_with_json.py | 35 ++++++++++ 3 files changed, 94 insertions(+), 32 deletions(-) diff --git a/src/gretel_trainer/relational/core.py b/src/gretel_trainer/relational/core.py index 5eea88ef..393c1782 100644 --- a/src/gretel_trainer/relational/core.py +++ b/src/gretel_trainer/relational/core.py @@ -16,6 +16,7 @@ IngestResponseT, InventedTableMetadata, RelationalJson, + get_json_columns, ) logger = logging.getLogger(__name__) @@ -129,12 +130,21 @@ def add_table( the table includes nested JSON data. """ primary_key = self._format_key_column(primary_key) - rj_ingest = RelationalJson.ingest(name, primary_key, data) - if rj_ingest is not None: + if (rj_ingest := self._check_for_json(name, primary_key, data)) is not None: self._add_rel_json_and_tables(name, rj_ingest) else: self._add_single_table(name=name, primary_key=primary_key, data=data) + def _check_for_json( + self, + table: str, + primary_key: list[str], + data: pd.DataFrame, + ) -> Optional[IngestResponseT]: + json_cols = get_json_columns(data) + if len(json_cols) > 0: + return RelationalJson.ingest(table, primary_key, data, json_cols) + def _add_rel_json_and_tables(self, table: str, rj_ingest: IngestResponseT) -> None: rel_json, commands = rj_ingest tables, foreign_keys = commands @@ -374,8 +384,9 @@ def remove_foreign_key_constraint( def update_table_data(self, table: str, data: pd.DataFrame) -> None: if table in self.relational_jsons: _, original_pk, original_fks = self._remove_relational_json(table) - new_rj_ingest = RelationalJson.ingest(table, original_pk, data) - if new_rj_ingest is not None: + if ( + new_rj_ingest := self._check_for_json(table, original_pk, data) + ) is not None: self._add_rel_json_and_tables(table, new_rj_ingest) parent_table_name = new_rj_ingest[0].root_table_name else: @@ -401,9 +412,7 @@ def update_table_data(self, table: str, data: pd.DataFrame) -> None: ) if ( - new_rj_ingest := RelationalJson.ingest( - table, metadata.primary_key, data - ) + new_rj_ingest := self._check_for_json(table, metadata.primary_key, data) ) is not None: original_foreign_keys = self._get_user_defined_fks_to_table(table) self.graph.remove_node(table) diff --git a/src/gretel_trainer/relational/json.py b/src/gretel_trainer/relational/json.py index be74f213..5c88bb12 100644 --- a/src/gretel_trainer/relational/json.py +++ b/src/gretel_trainer/relational/json.py @@ -13,6 +13,8 @@ logger = logging.getLogger(__name__) +PREVIEW_ROW_COUNT = 5 + # JSON dict to multi-column and list to multi-table FIELD_SEPARATOR = ">" @@ -60,21 +62,15 @@ def nulls_to_empty_lists(series: pd.Series) -> pd.Series: def _normalize_json( - nested_dfs: list[tuple[str, pd.DataFrame]], flat_dfs: list[tuple[str, pd.DataFrame]] + nested_dfs: list[tuple[str, pd.DataFrame]], + flat_dfs: list[tuple[str, pd.DataFrame]], + columns: Optional[list[str]] = None, ) -> list[tuple[str, pd.DataFrame]]: if not nested_dfs: return flat_dfs name, df = nested_dfs.pop() - dict_cols = [ - col - for col in df.columns - if df[col].apply(is_dict).any() and df[col].dropna().apply(is_dict).all() - ] - list_cols = [ - col - for col in df.columns - if df[col].apply(is_list).any() and df[col].dropna().apply(is_list).all() - ] + cols_to_scan = columns or [col for col in df.columns if df.dtypes[col] == "object"] + dict_cols = [col for col in cols_to_scan if df[col].dropna().apply(is_dict).all()] if dict_cols: df[dict_cols] = nulls_to_empty_dicts(df[dict_cols]) for col in dict_cols: @@ -82,19 +78,23 @@ def _normalize_json( df = pd.concat([df, new_cols], axis="columns") df = df.drop(columns=new_cols.columns[new_cols.isnull().all()]) nested_dfs.append((name, df.drop(columns=dict_cols))) - elif list_cols: - for col in list_cols: - new_table = df[col].explode().dropna().rename(CONTENT_COLUMN).to_frame() - new_table[ORDER_COLUMN] = new_table.groupby(level=0).cumcount() - nested_dfs.append( - ( - name + TABLE_SEPARATOR + col, - new_table.reset_index(names=name + ID_SUFFIX), - ) - ) - nested_dfs.append((name, df.drop(columns=list_cols))) else: - flat_dfs.append((name, df)) + list_cols = [ + col for col in cols_to_scan if df[col].dropna().apply(is_list).all() + ] + if list_cols: + for col in list_cols: + new_table = df[col].explode().dropna().rename(CONTENT_COLUMN).to_frame() + new_table[ORDER_COLUMN] = new_table.groupby(level=0).cumcount() + nested_dfs.append( + ( + name + TABLE_SEPARATOR + col, + new_table.reset_index(names=name + ID_SUFFIX), + ) + ) + nested_dfs.append((name, df.drop(columns=list_cols))) + else: + flat_dfs.append((name, df)) return _normalize_json(nested_dfs, flat_dfs) @@ -167,10 +167,14 @@ def __init__( @classmethod def ingest( - cls, table_name: str, primary_key: list[str], df: pd.DataFrame + cls, + table_name: str, + primary_key: list[str], + df: pd.DataFrame, + json_columns: Optional[list[str]] = None, ) -> Optional[IngestResponseT]: logger.debug(f"Checking table `{table_name}` for JSON columns") - tables = _normalize_json([(table_name, df.copy())], []) + tables = _normalize_json([(table_name, df.copy())], [], json_columns) # If we created additional tables (from JSON lists) or added columns (from JSON dicts) if len(tables) > 1 or len(tables[0][1].columns) > len(df.columns): mappings = {name: sanitize_str(name) for name, _ in tables} @@ -336,5 +340,19 @@ def _generate_commands( return (_add_single_table, add_foreign_key) +def get_json_columns(df: pd.DataFrame) -> list[str]: + column_previews = { + col: df[col].dropna().head(PREVIEW_ROW_COUNT) + for col in df.columns + if df.dtypes[col] == "object" + } + + return [ + col + for col, series in column_previews.items() + if series.apply(is_dict).all() or series.apply(is_list).all() + ] + + CommandsT = tuple[list[dict], list[dict]] IngestResponseT = tuple[RelationalJson, CommandsT] diff --git a/tests/relational/test_relational_data_with_json.py b/tests/relational/test_relational_data_with_json.py index ce2b91a3..249fc8eb 100644 --- a/tests/relational/test_relational_data_with_json.py +++ b/tests/relational/test_relational_data_with_json.py @@ -8,6 +8,7 @@ RelationalData, Scope, ) +from gretel_trainer.relational.json import get_json_columns @pytest.fixture @@ -24,6 +25,13 @@ def bball(): return rel_data +def test_list_json_cols(documents, bball): + assert get_json_columns(documents.get_table_data("users")) == [] + assert get_json_columns(documents.get_table_data("purchases")) == ["data"] + + assert set(get_json_columns(bball.get_table_data("bball"))) == {"draft", "teams"} + + def test_json_columns_produce_invented_flattened_tables(documents): pdtest.assert_frame_equal( documents.get_table_data("purchases-sfx"), @@ -751,6 +759,33 @@ def test_lists_of_lists(): ) +def test_mix_of_dict_and_list_cols(): + df = pd.DataFrame( + data={ + "id": [1, 2], + "dcol": [{"language": "english"}, {"language": "spanish"}], + "lcol": [["a", "b"], ["c", "d"]], + } + ) + rel_data = RelationalData() + rel_data.add_table(name="mix", primary_key=None, data=df) + assert set(rel_data.list_all_tables()) == { + "mix-sfx", + "mix-lcol-sfx", + } + assert set(rel_data.get_table_data("mix-sfx").columns) == { + "id", + "~PRIMARY_KEY_ID~", + "dcol>language", + } + assert set(rel_data.get_table_data("mix-lcol-sfx").columns) == { + "~PRIMARY_KEY_ID~", + "content", + "array~order", + "mix~id", + } + + def test_all_tables_are_present_in_debug_summary(documents): assert documents.debug_summary() == { "foreign_key_count": 4,