Skip to content

Commit

Permalink
Add handling for string by-columns in dask-cudf groupby (#10830)
Browse files Browse the repository at this point in the history
Converts string `by`-columns to lists when calling aggregation methods, which expect `Groupby.by` to be a list or tuple.

We might be able to do this conversion when initializing the groupby object instead; I started with this approach because upstream Dask seems careful not to overwrite the original `by` input when it is a string.

Closes #10829

Authors:
  - Charles Blackmon-Luca (https://github.com/charlesbluca)

Approvers:
  - Bradley Dice (https://github.com/bdice)
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: #10830
  • Loading branch information
charlesbluca authored May 11, 2022
1 parent 1889133 commit 16d9a92
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 17 deletions.
29 changes: 19 additions & 10 deletions python/dask_cudf/dask_cudf/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,21 @@ def __getitem__(self, key):
g._meta = g._meta[key]
return g

@_dask_cudf_nvtx_annotate
def _make_groupby_method_aggs(self, agg_name):
    """Build the aggregation dictionary used by the aggregation methods.

    Maps every column of ``self.obj`` that is not a group-by key to
    ``agg_name``.

    Parameters
    ----------
    agg_name
        Aggregation to assign to each non-key column (e.g. ``"count"``).

    Returns
    -------
    dict
        ``{column: agg_name}`` for each column of ``self.obj`` that is
        not selected by ``self.by``.
    """
    if isinstance(self.by, list):
        # Several keys: drop any column that appears in the key list.
        value_columns = [
            col for col in self.obj.columns if col not in self.by
        ]
    else:
        # Single (non-list) key: compare by equality. A membership test
        # against a string key would be a substring match, wrongly
        # dropping e.g. column "x" when grouping by "xx".
        value_columns = [col for col in self.obj.columns if col != self.by]
    return dict.fromkeys(value_columns, agg_name)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
def count(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
self.by,
{c: "count" for c in self.obj.columns if c not in self.by},
self._make_groupby_method_aggs("count"),
split_every=split_every,
split_out=split_out,
sep=self.sep,
Expand All @@ -104,7 +112,7 @@ def mean(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
self.by,
{c: "mean" for c in self.obj.columns if c not in self.by},
self._make_groupby_method_aggs("mean"),
split_every=split_every,
split_out=split_out,
sep=self.sep,
Expand All @@ -119,7 +127,7 @@ def std(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
self.by,
{c: "std" for c in self.obj.columns if c not in self.by},
self._make_groupby_method_aggs("std"),
split_every=split_every,
split_out=split_out,
sep=self.sep,
Expand All @@ -134,7 +142,7 @@ def var(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
self.by,
{c: "var" for c in self.obj.columns if c not in self.by},
self._make_groupby_method_aggs("var"),
split_every=split_every,
split_out=split_out,
sep=self.sep,
Expand All @@ -149,7 +157,7 @@ def sum(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
self.by,
{c: "sum" for c in self.obj.columns if c not in self.by},
self._make_groupby_method_aggs("sum"),
split_every=split_every,
split_out=split_out,
sep=self.sep,
Expand All @@ -164,7 +172,7 @@ def min(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
self.by,
{c: "min" for c in self.obj.columns if c not in self.by},
self._make_groupby_method_aggs("min"),
split_every=split_every,
split_out=split_out,
sep=self.sep,
Expand All @@ -179,7 +187,7 @@ def max(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
self.by,
{c: "max" for c in self.obj.columns if c not in self.by},
self._make_groupby_method_aggs("max"),
split_every=split_every,
split_out=split_out,
sep=self.sep,
Expand All @@ -194,7 +202,7 @@ def collect(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
self.by,
{c: "collect" for c in self.obj.columns if c not in self.by},
self._make_groupby_method_aggs("collect"),
split_every=split_every,
split_out=split_out,
sep=self.sep,
Expand All @@ -209,7 +217,7 @@ def first(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
self.by,
{c: "first" for c in self.obj.columns if c not in self.by},
self._make_groupby_method_aggs("first"),
split_every=split_every,
split_out=split_out,
sep=self.sep,
Expand All @@ -224,7 +232,7 @@ def last(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
self.by,
{c: "last" for c in self.obj.columns if c not in self.by},
self._make_groupby_method_aggs("last"),
split_every=split_every,
split_out=split_out,
sep=self.sep,
Expand Down Expand Up @@ -660,6 +668,7 @@ def _aggs_supported(arg, supported: set):
return False


@_dask_cudf_nvtx_annotate
def _groupby_supported(gb):
"""Check that groupby input is supported by dask-cudf"""
return isinstance(gb.obj, DaskDataFrame) and (
Expand Down
18 changes: 11 additions & 7 deletions python/dask_cudf/dask_cudf/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,24 @@
@pytest.mark.parametrize("series", [False, True])
def test_groupby_basic(series, aggregation):
np.random.seed(0)

# note that column name "x" is a substring of the groupby key;
# this gives us coverage for cudf#10829
pdf = pd.DataFrame(
{
"x": np.random.randint(0, 5, size=10000),
"xx": np.random.randint(0, 5, size=10000),
"x": np.random.normal(size=10000),
"y": np.random.normal(size=10000),
}
)

gdf = cudf.DataFrame.from_pandas(pdf)
gdf_grouped = gdf.groupby("x")
ddf_grouped = dask_cudf.from_cudf(gdf, npartitions=5).groupby("x")
gdf_grouped = gdf.groupby("xx")
ddf_grouped = dask_cudf.from_cudf(gdf, npartitions=5).groupby("xx")

if series:
gdf_grouped = gdf_grouped.x
ddf_grouped = ddf_grouped.x
gdf_grouped = gdf_grouped.xx
ddf_grouped = ddf_grouped.xx

a = getattr(gdf_grouped, aggregation)()
b = getattr(ddf_grouped, aggregation)().compute()
Expand All @@ -41,8 +45,8 @@ def test_groupby_basic(series, aggregation):
else:
dd.assert_eq(a, b)

a = gdf_grouped.agg({"x": aggregation})
b = ddf_grouped.agg({"x": aggregation}).compute()
a = gdf_grouped.agg({"xx": aggregation})
b = ddf_grouped.agg({"xx": aggregation}).compute()

if aggregation == "count":
dd.assert_eq(a, b, check_dtype=False)
Expand Down

0 comments on commit 16d9a92

Please sign in to comment.