Support shuffle-based groupby aggregations in dask_cudf (#11800)
This PR corresponds to the `dask_cudf` version of dask/dask#9302 (adding a shuffle-based algorithm for high-cardinality groupby aggregations). The benefits of this algorithm are most significant for cases where `split_out>1` is necessary:

```python
agg = ddf.groupby("id").agg({"x": "mean", "y": "max"}, split_out=4, shuffle=True)
```
**NOTES**:

- ~`shuffle="explicit-comms"` is also supported (when `dask_cuda` is installed)~
- It should be possible to refactor and remove some of this code in the future. However, due to subtle differences between the groupby code in `dask.dataframe` and `dask_cudf`, the specialized `_shuffle_aggregate` implementation is currently necessary.
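
As a minimal sketch (synthetic data; the column names and sizes are illustrative, not from this PR), the shuffle-based path also lifts the previous restriction on `sort=True` with `split_out>1`:

```python
import cudf
import dask_cudf

# Synthetic high-cardinality input (sizes are illustrative)
df = cudf.datasets.randomdata(
    nrows=10_000, dtypes={"id": int, "x": float, "y": float}
)
ddf = dask_cudf.from_cudf(df, npartitions=16)

# Previously raised ValueError; now runs via the shuffle-based algorithm
agg = ddf.groupby("id", sort=True).agg(
    {"x": "mean", "y": "max"}, split_out=2, shuffle=True
)
```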

Authors:
  - Richard (Rick) Zamora (https://github.com/rjzamora)

Approvers:
  - Benjamin Zaitlen (https://github.com/quasiben)
  - Lawrence Mitchell (https://github.com/wence-)

URL: #11800
rjzamora authored Sep 28, 2022
1 parent 2e4acbb commit 5a4afec
Showing 2 changed files with 134 additions and 16 deletions.
115 changes: 99 additions & 16 deletions python/dask_cudf/dask_cudf/groupby.py
@@ -11,6 +11,7 @@
    split_out_on_cols,
)
from dask.dataframe.groupby import DataFrameGroupBy, SeriesGroupBy
from dask.utils import funcname

import cudf
from cudf.utils.utils import _dask_cudf_nvtx_annotate
@@ -469,6 +470,74 @@ def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
        )


def _shuffle_aggregate(
    ddf,
    gb_cols,
    chunk,
    chunk_kwargs,
    aggregate,
    aggregate_kwargs,
    split_every,
    split_out,
    token=None,
    sort=None,
    shuffle=None,
):
    # Shuffle-based groupby aggregation
    # NOTE: This function is the dask_cudf version of
    # dask.dataframe.groupby._shuffle_aggregate

    # Step 1 - Chunkwise groupby operation
    chunk_name = f"{token or funcname(chunk)}-chunk"
    chunked = ddf.map_partitions(
        chunk,
        meta=chunk(ddf._meta, **chunk_kwargs),
        enforce_metadata=False,
        token=chunk_name,
        **chunk_kwargs,
    )

    # Step 2 - Perform global sort or shuffle
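    # NOTE: The partition count is reduced by a factor of
    # ``split_every``, but never dropped below ``split_out``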
    shuffle_npartitions = max(
        chunked.npartitions // split_every,
        split_out,
    )
    if sort and split_out > 1:
        # Sort-based code path
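        # Sorting on the group keys both moves equal keys to the same
        # output partition and leaves the result globally ordered, so
        # the final aggregation can run partition-wise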
        result = (
            chunked.repartition(npartitions=shuffle_npartitions)
            .sort_values(
                gb_cols,
                ignore_index=True,
                shuffle=shuffle,
            )
            .map_partitions(
                aggregate,
                meta=aggregate(chunked._meta, **aggregate_kwargs),
                enforce_metadata=False,
                **aggregate_kwargs,
            )
        )
    else:
        # Hash-based code path
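        # A hash-based shuffle sends all rows with the same key to the
        # same partition, but imposes no global ordering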
        result = chunked.shuffle(
            gb_cols,
            npartitions=shuffle_npartitions,
            ignore_index=True,
            shuffle=shuffle,
        ).map_partitions(
            aggregate,
            meta=aggregate(chunked._meta, **aggregate_kwargs),
            enforce_metadata=False,
            **aggregate_kwargs,
        )

    # Step 3 - Repartition and return
    if split_out < result.npartitions:
        return result.repartition(npartitions=split_out)
    return result


@_dask_cudf_nvtx_annotate
def groupby_agg(
    ddf,
@@ -501,12 +570,6 @@ def groupby_agg(
    in `dask.dataframe`, because it allows the cudf backend to
    perform multiple aggregations at once.
    """
    if shuffle:
        # Temporary error until shuffle-based groupby is implemented
        raise NotImplementedError(
            "The shuffle option is not yet implemented in dask_cudf."
        )

    # Assert that aggregations are supported
    aggs = _redirect_aggs(aggs_in)
    if not _aggs_supported(aggs, SUPPORTED_AGGS):
@@ -523,16 +586,6 @@
    split_every = split_every or 8
    split_out = split_out or 1

    # Deal with sort/shuffle defaults
    if split_out > 1 and sort:
        # NOTE: This can be changed when `shuffle` is not `None`
        # as soon as the shuffle-based groupby is implemented
        raise ValueError(
            "dask-cudf's groupby algorithm does not yet support "
            "`sort=True` when `split_out>1`. Please use `split_out=1`, "
            "or try grouping with `sort=False`."
        )

    # Standardize `gb_cols`, `columns`, and `aggs`
    if isinstance(gb_cols, str):
        gb_cols = [gb_cols]
@@ -609,6 +662,36 @@
"aggs_renames": aggs_renames,
}

    # Default to a task-based shuffle for sort=True with split_out>1
    if sort and split_out > 1 and shuffle is None:
        shuffle = "tasks"

    # Check if we are using the shuffle-based algorithm
    if shuffle:
        # Shuffle-based aggregation
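        # NOTE: ``shuffle`` may be ``True`` or the name of a shuffle
        # method (e.g. "tasks"); method names are passed through, while
        # a plain ``True`` lets dask choose a default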
        return _shuffle_aggregate(
            ddf,
            gb_cols,
            chunk,
            chunk_kwargs,
            aggregate,
            aggregate_kwargs,
            split_every,
            split_out,
            token="cudf-aggregate",
            sort=sort,
            shuffle=shuffle if isinstance(shuffle, str) else None,
        )

    # Error out if sort=True and split_out>1, but the
    # shuffle-based algorithm was explicitly disabled
    if split_out > 1 and sort:
        raise ValueError(
            "dask-cudf's groupby algorithm does not yet support "
            "`sort=True` when `split_out>1`, unless a shuffle-based "
            "algorithm is used. Please use `split_out=1`, group "
            "with `sort=False`, or set `shuffle=True`."
        )

    return aca(
        [ddf],
        chunk=chunk,
35 changes: 35 additions & 0 deletions python/dask_cudf/dask_cudf/tests/test_groupby.py
@@ -837,3 +837,38 @@ def test_groupby_all_columns(func):
    actual = func(gddf)

    dd.assert_eq(expect, actual)


def test_groupby_shuffle():
    df = cudf.datasets.randomdata(
        nrows=640, dtypes={"a": str, "b": int, "c": int}
    )
    gddf = dask_cudf.from_cudf(df, 8)
    spec = {"b": "mean", "c": "max"}
    expect = df.groupby("a", sort=True).agg(spec)

    # Sorted aggregation, single-partition output
    # (sort=True, split_out=1)
    got = gddf.groupby("a", sort=True).agg(spec, shuffle=True, split_out=1)
    dd.assert_eq(expect, got)

    # Sorted aggregation, multi-partition output
    # (sort=True, split_out=2)
    got = gddf.groupby("a", sort=True).agg(spec, shuffle=True, split_out=2)
    dd.assert_eq(expect, got)

    # Un-sorted aggregation, single-partition output
    # (sort=False, split_out=1)
    got = gddf.groupby("a", sort=False).agg(spec, shuffle=True, split_out=1)
    dd.assert_eq(expect.sort_index(), got.compute().sort_index())

    # Un-sorted aggregation, multi-partition output
    # (sort=False, split_out=2)
    # NOTE: `shuffle=True` should be the default
    got = gddf.groupby("a", sort=False).agg(spec, split_out=2)
    dd.assert_eq(expect, got.compute().sort_index())

    # Sorted aggregation fails with split_out>1 when shuffle is False
    # (sort=True, split_out=2, shuffle=False)
    with pytest.raises(ValueError):
        gddf.groupby("a", sort=True).agg(spec, shuffle=False, split_out=2)
