diff --git a/python/dask_cudf/dask_cudf/groupby.py b/python/dask_cudf/dask_cudf/groupby.py index 58024e1b71a..3a1557e77f4 100644 --- a/python/dask_cudf/dask_cudf/groupby.py +++ b/python/dask_cudf/dask_cudf/groupby.py @@ -249,7 +249,7 @@ def last(self, split_every=None, split_out=1): ) @_dask_cudf_nvtx_annotate - def aggregate(self, arg, split_every=None, split_out=1): + def aggregate(self, arg, split_every=None, split_out=1, shuffle=None): if arg == "size": return self.size() @@ -274,7 +274,12 @@ def aggregate(self, arg, split_every=None, split_out=1): ) return super().aggregate( - arg, split_every=split_every, split_out=split_out + arg, + split_every=split_every, + split_out=split_out, + # TODO: Change following line to `shuffle=shuffle,` + # when dask_cudf is pinned to dask>2022.8.0 + **({} if shuffle is None else {"shuffle": shuffle}), ) @@ -436,7 +441,7 @@ def last(self, split_every=None, split_out=1): )[self._slice] @_dask_cudf_nvtx_annotate - def aggregate(self, arg, split_every=None, split_out=1): + def aggregate(self, arg, split_every=None, split_out=1, shuffle=None): if arg == "size": return self.size() @@ -459,7 +464,12 @@ def aggregate(self, arg, split_every=None, split_out=1): )[self._slice] return super().aggregate( - arg, split_every=split_every, split_out=split_out + arg, + split_every=split_every, + split_out=split_out, + # TODO: Change following line to `shuffle=shuffle,` + # when dask_cudf is pinned to dask>2022.8.0 + **({} if shuffle is None else {"shuffle": shuffle}), ) diff --git a/python/dask_cudf/dask_cudf/tests/test_groupby.py b/python/dask_cudf/dask_cudf/tests/test_groupby.py index 9d2fc5196e8..e6c23992c4e 100644 --- a/python/dask_cudf/dask_cudf/tests/test_groupby.py +++ b/python/dask_cudf/dask_cudf/tests/test_groupby.py @@ -575,12 +575,12 @@ def test_groupby_categorical_key(): ddf = gddf.to_dask_dataframe() got = ( - gddf.groupby("name") + gddf.groupby("name", sort=True) .agg({"x": ["mean", "max"], "y": ["mean", "count"]}) .compute() ) expect = ( - ddf.groupby("name") + ddf.groupby("name", sort=True) .agg({"x": ["mean", "max"], "y": ["mean", "count"]}) .compute() )