diff --git a/python/dask_cudf/dask_cudf/groupby.py b/python/dask_cudf/dask_cudf/groupby.py index 73fe1bd2196..336fdaf009c 100644 --- a/python/dask_cudf/dask_cudf/groupby.py +++ b/python/dask_cudf/dask_cudf/groupby.py @@ -62,7 +62,16 @@ def aggregate(self, arg, split_every=None, split_out=1): return self.size() arg = _redirect_aggs(arg) - _supported = {"count", "mean", "std", "var", "sum", "min", "max"} + _supported = { + "count", + "mean", + "std", + "var", + "sum", + "min", + "max", + "collect", + } if ( isinstance(self.obj, DaskDataFrame) and isinstance(self.index, (str, list)) @@ -109,7 +118,16 @@ def aggregate(self, arg, split_every=None, split_out=1): return self.size() arg = _redirect_aggs(arg) - _supported = {"count", "mean", "std", "var", "sum", "min", "max"} + _supported = { + "count", + "mean", + "std", + "var", + "sum", + "min", + "max", + "collect", + } if ( isinstance(self.obj, DaskDataFrame) and isinstance(self.index, (str, list)) @@ -147,7 +165,7 @@ def groupby_agg( This aggregation algorithm only supports the following options: - {"count", "mean", "std", "var", "sum", "min", "max"} + {"count", "mean", "std", "var", "sum", "min", "max", "collect"} This "optimized" approach is more performant than the algorithm in `dask.dataframe`, because it allows the cudf backend to @@ -173,7 +191,7 @@ def groupby_agg( # strings (no lists) str_cols_out = True for col in aggs: - if isinstance(aggs[col], str): + if isinstance(aggs[col], str) or callable(aggs[col]): aggs[col] = [aggs[col]] else: str_cols_out = False @@ -181,7 +199,16 @@ def groupby_agg( columns.append(col) # Assert that aggregations are supported - _supported = {"count", "mean", "std", "var", "sum", "min", "max"} + _supported = { + "count", + "mean", + "std", + "var", + "sum", + "min", + "max", + "collect", + } if not _is_supported(aggs, _supported): raise ValueError( f"Supported aggs include {_supported} for groupby_agg API. " @@ -282,7 +309,13 @@ def groupby_agg( def _redirect_aggs(arg): """ Redirect aggregations to their corresponding name in cuDF """ - redirects = {sum: "sum", max: "max", min: "min"} + redirects = { + sum: "sum", + max: "max", + min: "min", + list: "collect", + "list": "collect", + } if isinstance(arg, dict): new_arg = dict() for col in arg: @@ -400,6 +433,8 @@ def _tree_node_agg(dfs, gb_cols, split_out, dropna, sort, sep): agg_dict[col] = ["sum"] elif agg in ("min", "max"): agg_dict[col] = [agg] + elif agg == "collect": + agg_dict[col] = ["collect"] else: raise ValueError(f"Unexpected aggregation: {agg}") @@ -478,6 +513,9 @@ def _finalize_gb_agg( gb.drop(columns=[sum_name], inplace=True) if "count" not in agg_list: gb.drop(columns=[count_name], inplace=True) + if "collect" in agg_list: + collect_name = _make_name(col, "collect", sep=sep) + gb[collect_name] = gb[collect_name].list.concat() # Ensure sorted keys if `sort=True` if sort: diff --git a/python/dask_cudf/dask_cudf/tests/test_groupby.py b/python/dask_cudf/dask_cudf/tests/test_groupby.py index e3a3045dcc7..356567fdef0 100644 --- a/python/dask_cudf/dask_cudf/tests/test_groupby.py +++ b/python/dask_cudf/dask_cudf/tests/test_groupby.py @@ -125,6 +125,33 @@ def test_groupby_std(func): dd.assert_eq(a, b) +@pytest.mark.parametrize( + "func", + [ + lambda df: df.groupby("x").agg({"y": "collect"}), + pytest.param( + lambda df: df.groupby("x").y.agg("collect"), marks=pytest.mark.skip + ), + ], +) +def test_groupby_collect(func): + pdf = pd.DataFrame( + { + "x": np.random.randint(0, 5, size=10000), + "y": np.random.normal(size=10000), + } + ) + + gdf = cudf.DataFrame.from_pandas(pdf) + + ddf = dask_cudf.from_cudf(gdf, npartitions=5) + + a = func(gdf).to_pandas() + b = func(ddf).compute().to_pandas() + + dd.assert_eq(a, b) + + # reason gotattr in cudf @pytest.mark.parametrize( "func",