diff --git a/python/dask_cudf/dask_cudf/groupby.py b/python/dask_cudf/dask_cudf/groupby.py index 6fbfd802fb8..600d6cc7412 100644 --- a/python/dask_cudf/dask_cudf/groupby.py +++ b/python/dask_cudf/dask_cudf/groupby.py @@ -78,7 +78,13 @@ def aggregate(self, arg, split_every=None, split_out=1): } if ( isinstance(self.obj, DaskDataFrame) - and isinstance(self.index, (str, list)) + and ( + isinstance(self.index, str) + or ( + isinstance(self.index, list) + and all(isinstance(x, str) for x in self.index) + ) + ) and _is_supported(arg, _supported) ): if isinstance(self._meta.grouping.keys, cudf.MultiIndex): diff --git a/python/dask_cudf/dask_cudf/tests/test_groupby.py b/python/dask_cudf/dask_cudf/tests/test_groupby.py index 0c6f7686275..61fa32b76ed 100644 --- a/python/dask_cudf/dask_cudf/tests/test_groupby.py +++ b/python/dask_cudf/dask_cudf/tests/test_groupby.py @@ -631,3 +631,17 @@ def test_groupby_first_last(data, agg): getattr(gdf.groupby("a"), agg)(), getattr(gddf.groupby("a"), agg)().compute(), ) + + +def test_groupby_with_list_of_series(): + df = cudf.DataFrame({"a": [1, 2, 3, 4, 5]}) + gdf = dask_cudf.from_cudf(df, npartitions=2) + gs = cudf.Series([1, 1, 1, 2, 2], name="id") + ggs = dask_cudf.from_cudf(gs, npartitions=2) + + ddf = dd.from_pandas(df.to_pandas(), npartitions=2) + pgs = dd.from_pandas(gs.to_pandas(), npartitions=2) + + dd.assert_eq( + gdf.groupby([ggs]).agg(["sum"]), ddf.groupby([pgs]).agg(["sum"]) + )