Skip to content

Commit

Permalink
fallback to upstream dask code-path if list of Series is given
Browse files Browse the repository at this point in the history
  • Loading branch information
galipremsagar committed Aug 12, 2021
1 parent ad7f2b2 commit 828390d
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
8 changes: 7 additions & 1 deletion python/dask_cudf/dask_cudf/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,13 @@ def aggregate(self, arg, split_every=None, split_out=1):
}
if (
isinstance(self.obj, DaskDataFrame)
and isinstance(self.index, (str, list))
and (
isinstance(self.index, str)
or (
isinstance(self.index, list)
and all(isinstance(x, str) for x in self.index)
)
)
and _is_supported(arg, _supported)
):
if isinstance(self._meta.grouping.keys, cudf.MultiIndex):
Expand Down
14 changes: 14 additions & 0 deletions python/dask_cudf/dask_cudf/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,3 +631,17 @@ def test_groupby_first_last(data, agg):
getattr(gdf.groupby("a"), agg)(),
getattr(gddf.groupby("a"), agg)().compute(),
)


def test_groupby_with_list_of_series():
df = cudf.DataFrame({"a": [1, 2, 3, 4, 5]})
gdf = dask_cudf.from_cudf(df, npartitions=2)
gs = cudf.Series([1, 1, 1, 2, 2], name="id")
ggs = dask_cudf.from_cudf(gs, npartitions=2)

ddf = dd.from_pandas(df.to_pandas(), npartitions=2)
pgs = dd.from_pandas(gs.to_pandas(), npartitions=2)

dd.assert_eq(
gdf.groupby([ggs]).agg(["sum"]), ddf.groupby([pgs]).agg(["sum"])
)

0 comments on commit 828390d

Please sign in to comment.