diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index 721ebf22de7..0d833a7d341 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -2585,7 +2585,7 @@ def columns(self, columns): if not len(columns) == len(self._data.names): raise ValueError( - f"Length mismatch: expected {len(self._data.names)} elements ," + f"Length mismatch: expected {len(self._data.names)} elements, " f"got {len(columns)} elements" ) diff --git a/python/dask_cudf/dask_cudf/groupby.py b/python/dask_cudf/dask_cudf/groupby.py index 600d6cc7412..7e2c3a4f36c 100644 --- a/python/dask_cudf/dask_cudf/groupby.py +++ b/python/dask_cudf/dask_cudf/groupby.py @@ -198,34 +198,8 @@ def groupby_agg( in `dask.dataframe`, because it allows the cudf backend to perform multiple aggregations at once. """ - - # Deal with default split_out and split_every params - if split_every is False: - split_every = ddf.npartitions - split_every = split_every or 8 - split_out = split_out or 1 - - # Standardize `gb_cols` and `columns` lists - aggs = _redirect_aggs(aggs_in.copy()) - if isinstance(gb_cols, str): - gb_cols = [gb_cols] - columns = [c for c in ddf.columns if c not in gb_cols] - str_cols_out = False - if isinstance(aggs, dict): - # Use `str_cols_out` to specify if the output columns - # will have str (rather than MultiIndex/tuple) names. - # This happens when all values in the `aggs` dict are - # strings (no lists) - str_cols_out = True - for col in aggs: - if isinstance(aggs[col], str) or callable(aggs[col]): - aggs[col] = [aggs[col]] - else: - str_cols_out = False - if col in gb_cols: - columns.append(col) - # Assert that aggregations are supported + aggs = _redirect_aggs(aggs_in) _supported = { "count", "mean", @@ -244,10 +218,39 @@ def groupby_agg( f"Aggregations must be specified with dict or list syntax." ) - # Always convert aggs to dict for consistency + # Deal with default split_out and split_every params + if split_every is False: + split_every = ddf.npartitions + split_every = split_every or 8 + split_out = split_out or 1 + + # Standardize `gb_cols`, `columns`, and `aggs` + if isinstance(gb_cols, str): + gb_cols = [gb_cols] + columns = [c for c in ddf.columns if c not in gb_cols] if isinstance(aggs, list): aggs = {col: aggs for col in columns} + # Assert if our output will have a MultiIndex; this will be the case if + # any value in the `aggs` dict is not a string (i.e. multiple/named + # aggregations per column) + str_cols_out = True + aggs_renames = {} + for col in aggs: + if isinstance(aggs[col], str) or callable(aggs[col]): + aggs[col] = [aggs[col]] + elif isinstance(aggs[col], dict): + str_cols_out = False + col_aggs = [] + for k, v in aggs[col].items(): + aggs_renames[col, v] = k + col_aggs.append(v) + aggs[col] = col_aggs + else: + str_cols_out = False + if col in gb_cols: + columns.append(col) + # Begin graph construction dsk = {} token = tokenize(ddf, gb_cols, aggs) @@ -314,6 +317,13 @@ def groupby_agg( for col in aggs: _aggs[col] = _aggs[col][0] _meta = ddf._meta.groupby(gb_cols, as_index=as_index).agg(_aggs) + if aggs_renames: + col_array = [] + agg_array = [] + for col, agg in _meta.columns: + col_array.append(col) + agg_array.append(aggs_renames.get((col, agg), agg)) + _meta.columns = pd.MultiIndex.from_arrays([col_array, agg_array]) for s in range(split_out): dsk[(gb_agg_name, s)] = ( _finalize_gb_agg, @@ -326,6 +336,7 @@ def groupby_agg( sort, sep, str_cols_out, + aggs_renames, ) divisions = [None] * (split_out + 1) @@ -350,6 +361,10 @@ def _redirect_aggs(arg): for col in arg: if isinstance(arg[col], list): new_arg[col] = [redirects.get(agg, agg) for agg in arg[col]] + elif isinstance(arg[col], dict): + new_arg[col] = { + k: redirects.get(v, v) for k, v in arg[col].items() + } else: new_arg[col] = redirects.get(arg[col], arg[col]) return new_arg @@ -367,6 +382,8 @@ def _is_supported(arg, supported: set): for col in arg: if isinstance(arg[col], list): _global_set = _global_set.union(set(arg[col])) + elif isinstance(arg[col], dict): + _global_set = _global_set.union(set(arg[col].values())) else: _global_set.add(arg[col]) else: @@ -460,10 +477,8 @@ def _tree_node_agg(dfs, gb_cols, split_out, dropna, sort, sep): agg = col.split(sep)[-1] if agg in ("count", "sum"): agg_dict[col] = ["sum"] - elif agg in ("min", "max"): + elif agg in ("min", "max", "collect"): agg_dict[col] = [agg] - elif agg == "collect": - agg_dict[col] = ["collect"] else: raise ValueError(f"Unexpected aggregation: {agg}") @@ -508,6 +523,7 @@ def _finalize_gb_agg( sort, sep, str_cols_out, + aggs_renames, ): """ Final aggregation task. @@ -564,7 +580,7 @@ def _finalize_gb_agg( else: name, agg = col.split(sep) col_array.append(name) - agg_array.append(agg) + agg_array.append(aggs_renames.get((name, agg), agg)) if str_cols_out: gb.columns = col_array else: diff --git a/python/dask_cudf/dask_cudf/tests/test_groupby.py b/python/dask_cudf/dask_cudf/tests/test_groupby.py index 61fa32b76ed..6569ffa94c5 100644 --- a/python/dask_cudf/dask_cudf/tests/test_groupby.py +++ b/python/dask_cudf/dask_cudf/tests/test_groupby.py @@ -645,3 +645,32 @@ def test_groupby_with_list_of_series(): dd.assert_eq( gdf.groupby([ggs]).agg(["sum"]), ddf.groupby([pgs]).agg(["sum"]) ) + + +@pytest.mark.parametrize( + "func", + [ + lambda df: df.groupby("x").agg({"y": {"foo": "sum"}}), + lambda df: df.groupby("x").agg({"y": {"foo": "sum", "bar": "count"}}), + ], +) +def test_groupby_nested_dict(func): + pdf = pd.DataFrame( + { + "x": np.random.randint(0, 5, size=10000), + "y": np.random.normal(size=10000), + } + ) + + ddf = dd.from_pandas(pdf, npartitions=5) + c_ddf = ddf.map_partitions(cudf.from_pandas) + + a = func(ddf).compute() + b = func(c_ddf).compute().to_pandas() + + a.index.name = None + a.name = None + b.index.name = None + b.name = None + + dd.assert_eq(a, b)