diff --git a/nvtabular/ops/categorify.py b/nvtabular/ops/categorify.py
index daf0df101fc..03e2f68dfb3 100644
--- a/nvtabular/ops/categorify.py
+++ b/nvtabular/ops/categorify.py
@@ -758,7 +758,8 @@ def _top_level_groupby(df, options: FitOptions):
             df_gb = type(df)()
             ignore_index = True
             df_gb[cat_col_selector_str] = _concat(
-                [df[col] for col in cat_col_selector.names], ignore_index
+                [_maybe_flatten_list_column(col, df)[col] for col in cat_col_selector.names],
+                ignore_index,
             )
             cat_col_selector = ColumnSelector([cat_col_selector_str])
         else:
@@ -795,9 +796,7 @@
 
         # Perform groupby and flatten column index
         # (flattening provides better cudf/pd support)
-        if is_list_col(cat_col_selector, df_gb):
-            # handle list columns by encoding the list values
-            df_gb = dispatch.flatten_list_column(df_gb[cat_col_selector.names[0]])
+        df_gb = _maybe_flatten_list_column(cat_col_selector.names[0], df_gb)
         # NOTE: groupby(..., dropna=False) requires pandas>=1.1.0
         gb = df_gb.groupby(cat_col_selector.names, dropna=False).agg(agg_dict)
         gb.columns = [
@@ -1414,6 +1413,15 @@ def is_list_col(col_selector, df):
     return has_lists
 
 
+def _maybe_flatten_list_column(col: str, df):
+    # Flatten the specified column (col) if it is
+    # a list dtype. Otherwise, pass back df "as is"
+    selector = ColumnSelector([col])
+    if is_list_col(selector, df):
+        return dispatch.flatten_list_column(df[selector.names[0]])
+    return df
+
+
 def _hash_bucket(df, num_buckets, col, encode_type="joint"):
     if encode_type == "joint":
         nb = num_buckets[col[0]]
diff --git a/tests/unit/ops/test_categorify.py b/tests/unit/ops/test_categorify.py
index a2b674fb85f..08d3bdd97b6 100644
--- a/tests/unit/ops/test_categorify.py
+++ b/tests/unit/ops/test_categorify.py
@@ -654,3 +654,35 @@
     unique_C2 = pd.read_parquet("./categories/unique.C2.parquet")
     assert str(unique_C2["C2"].iloc[0]) in ["", "nan"]
     assert unique_C2["C2_size"].iloc[0] == 0
+
+
+@pytest.mark.parametrize("cpu", _CPU)
+def test_categorify_joint_list(cpu):
+    df = pd.DataFrame(
+        {
+            "Author": ["User_A", "User_E", "User_B", "User_C"],
+            "Engaging User": [
+                ["User_B", "User_C"],
+                [],
+                ["User_A", "User_D"],
+                ["User_A"],
+            ],
+            "Post": [1, 2, 3, 4],
+        }
+    )
+    cat_names = ["Post", ["Author", "Engaging User"]]
+    cats = cat_names >> nvt.ops.Categorify(encode_type="joint")
+    workflow = nvt.Workflow(cats)
+    df_out = (
+        workflow.fit_transform(nvt.Dataset(df, cpu=cpu)).to_ddf().compute(scheduler="synchronous")
+    )
+
+    compare_a = df_out["Author"].to_list() if cpu else df_out["Author"].to_arrow().to_pylist()
+    compare_e = (
+        df_out["Engaging User"].explode().dropna().to_list()
+        if cpu
+        else df_out["Engaging User"].explode().dropna().to_arrow().to_pylist()
+    )
+
+    assert compare_a == [1, 5, 2, 3]
+    assert compare_e == [2, 3, 1, 4, 1]