diff --git a/python/dask_cudf/dask_cudf/groupby.py b/python/dask_cudf/dask_cudf/groupby.py index 159602f195a..28add2e87f5 100644 --- a/python/dask_cudf/dask_cudf/groupby.py +++ b/python/dask_cudf/dask_cudf/groupby.py @@ -11,6 +11,7 @@ split_out_on_cols, ) from dask.dataframe.groupby import DataFrameGroupBy, SeriesGroupBy +from dask.utils import funcname import cudf from cudf.utils.utils import _dask_cudf_nvtx_annotate @@ -469,6 +470,74 @@ def aggregate(self, arg, split_every=None, split_out=1, shuffle=None): ) +def _shuffle_aggregate( + ddf, + gb_cols, + chunk, + chunk_kwargs, + aggregate, + aggregate_kwargs, + split_every, + split_out, + token=None, + sort=None, + shuffle=None, +): + # Shuffle-based groupby aggregation + # NOTE: This function is the dask_cudf version of + # dask.dataframe.groupby._shuffle_aggregate + + # Step 1 - Chunkwise groupby operation + chunk_name = f"{token or funcname(chunk)}-chunk" + chunked = ddf.map_partitions( + chunk, + meta=chunk(ddf._meta, **chunk_kwargs), + enforce_metadata=False, + token=chunk_name, + **chunk_kwargs, + ) + + # Step 2 - Perform global sort or shuffle + shuffle_npartitions = max( + chunked.npartitions // split_every, + split_out, + ) + if sort and split_out > 1: + # Sort-based code path + result = ( + chunked.repartition(npartitions=shuffle_npartitions) + .sort_values( + gb_cols, + ignore_index=True, + shuffle=shuffle, + ) + .map_partitions( + aggregate, + meta=aggregate(chunked._meta, **aggregate_kwargs), + enforce_metadata=False, + **aggregate_kwargs, + ) + ) + else: + # Hash-based code path + result = chunked.shuffle( + gb_cols, + npartitions=shuffle_npartitions, + ignore_index=True, + shuffle=shuffle, + ).map_partitions( + aggregate, + meta=aggregate(chunked._meta, **aggregate_kwargs), + enforce_metadata=False, + **aggregate_kwargs, + ) + + # Step 3 - Repartition and return + if split_out < result.npartitions: + return result.repartition(npartitions=split_out) + return result + + @_dask_cudf_nvtx_annotate def groupby_agg( ddf, @@ -501,12 +570,6 @@ def groupby_agg( in `dask.dataframe`, because it allows the cudf backend to perform multiple aggregations at once. """ - if shuffle: - # Temporary error until shuffle-based groupby is implemented - raise NotImplementedError( - "The shuffle option is not yet implemented in dask_cudf." - ) - # Assert that aggregations are supported aggs = _redirect_aggs(aggs_in) if not _aggs_supported(aggs, SUPPORTED_AGGS): @@ -523,16 +586,6 @@ def groupby_agg( split_every = split_every or 8 split_out = split_out or 1 - # Deal with sort/shuffle defaults - if split_out > 1 and sort: - # NOTE: This can be changed when `shuffle` is not `None` - # as soon as the shuffle-based groupby is implemented - raise ValueError( - "dask-cudf's groupby algorithm does not yet support " - "`sort=True` when `split_out>1`. Please use `split_out=1`, " - "or try grouping with `sort=False`." - ) - # Standardize `gb_cols`, `columns`, and `aggs` if isinstance(gb_cols, str): gb_cols = [gb_cols] @@ -609,6 +662,36 @@ def groupby_agg( "aggs_renames": aggs_renames, } + # Use shuffle=True for split_out>1 + if sort and split_out > 1 and shuffle is None: + shuffle = "tasks" + + # Check if we are using the shuffle-based algorithm + if shuffle: + # Shuffle-based aggregation + return _shuffle_aggregate( + ddf, + gb_cols, + chunk, + chunk_kwargs, + aggregate, + aggregate_kwargs, + split_every, + split_out, + token="cudf-aggregate", + sort=sort, + shuffle=shuffle if isinstance(shuffle, str) else None, + ) + + # Deal with sort/shuffle defaults + if split_out > 1 and sort: + raise ValueError( + "dask-cudf's groupby algorithm does not yet support " + "`sort=True` when `split_out>1`, unless a shuffle-based " + "algorithm is used. Please use `split_out=1`, group " + "with `sort=False`, or set `shuffle=True`." + ) + return aca( [ddf], chunk=chunk, diff --git a/python/dask_cudf/dask_cudf/tests/test_groupby.py b/python/dask_cudf/dask_cudf/tests/test_groupby.py index cc27c7f2a86..f2047c34684 100644 --- a/python/dask_cudf/dask_cudf/tests/test_groupby.py +++ b/python/dask_cudf/dask_cudf/tests/test_groupby.py @@ -837,3 +837,38 @@ def test_groupby_all_columns(func): actual = func(gddf) dd.assert_eq(expect, actual) + + +def test_groupby_shuffle(): + df = cudf.datasets.randomdata( + nrows=640, dtypes={"a": str, "b": int, "c": int} + ) + gddf = dask_cudf.from_cudf(df, 8) + spec = {"b": "mean", "c": "max"} + expect = df.groupby("a", sort=True).agg(spec) + + # Sorted aggregation, single-partition output + # (sort=True, split_out=1) + got = gddf.groupby("a", sort=True).agg(spec, shuffle=True, split_out=1) + dd.assert_eq(expect, got) + + # Sorted aggregation, multi-partition output + # (sort=True, split_out=2) + got = gddf.groupby("a", sort=True).agg(spec, shuffle=True, split_out=2) + dd.assert_eq(expect, got) + + # Un-sorted aggregation, single-partition output + # (sort=False, split_out=1) + got = gddf.groupby("a", sort=False).agg(spec, shuffle=True, split_out=1) + dd.assert_eq(expect.sort_index(), got.compute().sort_index()) + + # Un-sorted aggregation, multi-partition output + # (sort=False, split_out=2) + # NOTE: `shuffle=True` should be default + got = gddf.groupby("a", sort=False).agg(spec, split_out=2) + dd.assert_eq(expect, got.compute().sort_index()) + + # Sorted aggregation fails with split_out>1 when shuffle is False + # (sort=True, split_out=2, shuffle=False) + with pytest.raises(ValueError): + gddf.groupby("a", sort=True).agg(spec, shuffle=False, split_out=2)