From 2e980b889275c634fe6a54c6a4e03b22220337ea Mon Sep 17 00:00:00 2001 From: GALI PREM SAGAR Date: Fri, 13 Aug 2021 05:36:42 -0500 Subject: [PATCH] Fetch correct grouping keys `agg` of dask groupby (#9022) Fixes: #9020 This PR enables fallback to upstream `dask` when the groupby operation is performed by a list of `Series` objects. Authors: - GALI PREM SAGAR (https://github.com/galipremsagar) Approvers: - Ashwin Srinath (https://github.com/shwina) URL: https://github.com/rapidsai/cudf/pull/9022 --- python/dask_cudf/dask_cudf/groupby.py | 17 +++++++++++++++-- .../dask_cudf/dask_cudf/tests/test_groupby.py | 14 ++++++++++++++ 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/python/dask_cudf/dask_cudf/groupby.py b/python/dask_cudf/dask_cudf/groupby.py index 11184eb425e..600d6cc7412 100644 --- a/python/dask_cudf/dask_cudf/groupby.py +++ b/python/dask_cudf/dask_cudf/groupby.py @@ -16,6 +16,8 @@ from dask.dataframe.groupby import DataFrameGroupBy, SeriesGroupBy from dask.highlevelgraph import HighLevelGraph +import cudf + class CudfDataFrameGroupBy(DataFrameGroupBy): def __init__(self, *args, **kwargs): @@ -76,12 +78,23 @@ def aggregate(self, arg, split_every=None, split_out=1): } if ( isinstance(self.obj, DaskDataFrame) - and isinstance(self.index, (str, list)) + and ( + isinstance(self.index, str) + or ( + isinstance(self.index, list) + and all(isinstance(x, str) for x in self.index) + ) + ) and _is_supported(arg, _supported) ): + if isinstance(self._meta.grouping.keys, cudf.MultiIndex): + keys = self._meta.grouping.keys.names + else: + keys = self._meta.grouping.keys.name + return groupby_agg( self.obj, - self.index, + keys, arg, split_every=split_every, split_out=split_out, diff --git a/python/dask_cudf/dask_cudf/tests/test_groupby.py b/python/dask_cudf/dask_cudf/tests/test_groupby.py index 0c6f7686275..61fa32b76ed 100644 --- a/python/dask_cudf/dask_cudf/tests/test_groupby.py +++ b/python/dask_cudf/dask_cudf/tests/test_groupby.py @@ -631,3 +631,17 @@ def test_groupby_first_last(data, agg): getattr(gdf.groupby("a"), agg)(), getattr(gddf.groupby("a"), agg)().compute(), ) + + +def test_groupby_with_list_of_series(): + df = cudf.DataFrame({"a": [1, 2, 3, 4, 5]}) + gdf = dask_cudf.from_cudf(df, npartitions=2) + gs = cudf.Series([1, 1, 1, 2, 2], name="id") + ggs = dask_cudf.from_cudf(gs, npartitions=2) + + ddf = dd.from_pandas(df.to_pandas(), npartitions=2) + pgs = dd.from_pandas(gs.to_pandas(), npartitions=2) + + dd.assert_eq( + gdf.groupby([ggs]).agg(["sum"]), ddf.groupby([pgs]).agg(["sum"]) + )