diff --git a/python/dask_cudf/dask_cudf/groupby.py b/python/dask_cudf/dask_cudf/groupby.py
index d137fac5fe3..a64aabe1a6b 100644
--- a/python/dask_cudf/dask_cudf/groupby.py
+++ b/python/dask_cudf/dask_cudf/groupby.py
@@ -83,13 +83,21 @@ def __getitem__(self, key):
         g._meta = g._meta[key]
         return g
 
+    @_dask_cudf_nvtx_annotate
+    def _make_groupby_method_aggs(self, agg_name):
+        """Create aggs dictionary for aggregation methods"""
+
+        if isinstance(self.by, list):
+            return {c: agg_name for c in self.obj.columns if c not in self.by}
+        return {c: agg_name for c in self.obj.columns if c != self.by}
+
     @_dask_cudf_nvtx_annotate
     @_check_groupby_supported
     def count(self, split_every=None, split_out=1):
         return groupby_agg(
             self.obj,
             self.by,
-            {c: "count" for c in self.obj.columns if c not in self.by},
+            self._make_groupby_method_aggs("count"),
             split_every=split_every,
             split_out=split_out,
             sep=self.sep,
@@ -104,7 +112,7 @@ def mean(self, split_every=None, split_out=1):
         return groupby_agg(
             self.obj,
             self.by,
-            {c: "mean" for c in self.obj.columns if c not in self.by},
+            self._make_groupby_method_aggs("mean"),
             split_every=split_every,
             split_out=split_out,
             sep=self.sep,
@@ -119,7 +127,7 @@ def std(self, split_every=None, split_out=1):
         return groupby_agg(
             self.obj,
             self.by,
-            {c: "std" for c in self.obj.columns if c not in self.by},
+            self._make_groupby_method_aggs("std"),
             split_every=split_every,
             split_out=split_out,
             sep=self.sep,
@@ -134,7 +142,7 @@ def var(self, split_every=None, split_out=1):
         return groupby_agg(
             self.obj,
             self.by,
-            {c: "var" for c in self.obj.columns if c not in self.by},
+            self._make_groupby_method_aggs("var"),
             split_every=split_every,
             split_out=split_out,
             sep=self.sep,
@@ -149,7 +157,7 @@ def sum(self, split_every=None, split_out=1):
         return groupby_agg(
             self.obj,
             self.by,
-            {c: "sum" for c in self.obj.columns if c not in self.by},
+            self._make_groupby_method_aggs("sum"),
             split_every=split_every,
             split_out=split_out,
             sep=self.sep,
@@ -164,7 +172,7 @@ def min(self, split_every=None, split_out=1):
         return groupby_agg(
             self.obj,
             self.by,
-            {c: "min" for c in self.obj.columns if c not in self.by},
+            self._make_groupby_method_aggs("min"),
             split_every=split_every,
             split_out=split_out,
             sep=self.sep,
@@ -179,7 +187,7 @@ def max(self, split_every=None, split_out=1):
         return groupby_agg(
             self.obj,
             self.by,
-            {c: "max" for c in self.obj.columns if c not in self.by},
+            self._make_groupby_method_aggs("max"),
             split_every=split_every,
             split_out=split_out,
             sep=self.sep,
@@ -194,7 +202,7 @@ def collect(self, split_every=None, split_out=1):
         return groupby_agg(
             self.obj,
             self.by,
-            {c: "collect" for c in self.obj.columns if c not in self.by},
+            self._make_groupby_method_aggs("collect"),
             split_every=split_every,
             split_out=split_out,
             sep=self.sep,
@@ -209,7 +217,7 @@ def first(self, split_every=None, split_out=1):
         return groupby_agg(
             self.obj,
             self.by,
-            {c: "first" for c in self.obj.columns if c not in self.by},
+            self._make_groupby_method_aggs("first"),
             split_every=split_every,
             split_out=split_out,
             sep=self.sep,
@@ -224,7 +232,7 @@ def last(self, split_every=None, split_out=1):
         return groupby_agg(
             self.obj,
             self.by,
-            {c: "last" for c in self.obj.columns if c not in self.by},
+            self._make_groupby_method_aggs("last"),
             split_every=split_every,
             split_out=split_out,
             sep=self.sep,
@@ -660,6 +668,7 @@ def _aggs_supported(arg, supported: set):
     return False
 
 
+@_dask_cudf_nvtx_annotate
 def _groupby_supported(gb):
     """Check that groupby input is supported by dask-cudf"""
     return isinstance(gb.obj, DaskDataFrame) and (
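
Note on the change above: the old comprehensions used `c not in self.by`, which is a substring test when `self.by` is a plain string rather than a list, so a column whose name is a substring of the key was silently dropped from the aggregation dict (cudf#10829). The snippet below is a minimal, standalone sketch of that pitfall and of the branch-on-type behavior that `_make_groupby_method_aggs` introduces; it is plain Python, not the dask_cudf code itself.

    # Standalone illustration (hypothetical column/key names) of the
    # membership-test bug fixed by _make_groupby_method_aggs.
    columns = ["xx", "x", "y"]
    by = "xx"  # a string key, not a list

    # Buggy: "x" in "xx" is True (substring match), so column "x" is dropped.
    buggy = {c: "sum" for c in columns if c not in by}

    # Fixed: use equality for a string key, membership only for a list key.
    fixed = {
        c: "sum"
        for c in columns
        if (c not in by if isinstance(by, list) else c != by)
    }

    print(buggy)  # {'y': 'sum'}            -- "x" is missing
    print(fixed)  # {'x': 'sum', 'y': 'sum'} -- only the key column is excluded
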
diff --git a/python/dask_cudf/dask_cudf/tests/test_groupby.py b/python/dask_cudf/dask_cudf/tests/test_groupby.py
index d2c9ecd0293..5aa9cffb789 100644
--- a/python/dask_cudf/dask_cudf/tests/test_groupby.py
+++ b/python/dask_cudf/dask_cudf/tests/test_groupby.py
@@ -18,20 +18,24 @@
 @pytest.mark.parametrize("series", [False, True])
 def test_groupby_basic(series, aggregation):
     np.random.seed(0)
+
+    # note that column name "x" is a substring of the groupby key;
+    # this gives us coverage for cudf#10829
     pdf = pd.DataFrame(
         {
-            "x": np.random.randint(0, 5, size=10000),
+            "xx": np.random.randint(0, 5, size=10000),
+            "x": np.random.normal(size=10000),
             "y": np.random.normal(size=10000),
         }
     )
 
     gdf = cudf.DataFrame.from_pandas(pdf)
 
-    gdf_grouped = gdf.groupby("x")
-    ddf_grouped = dask_cudf.from_cudf(gdf, npartitions=5).groupby("x")
+    gdf_grouped = gdf.groupby("xx")
+    ddf_grouped = dask_cudf.from_cudf(gdf, npartitions=5).groupby("xx")
 
     if series:
-        gdf_grouped = gdf_grouped.x
-        ddf_grouped = ddf_grouped.x
+        gdf_grouped = gdf_grouped.xx
+        ddf_grouped = ddf_grouped.xx
 
     a = getattr(gdf_grouped, aggregation)()
     b = getattr(ddf_grouped, aggregation)().compute()
@@ -41,8 +45,8 @@ def test_groupby_basic(series, aggregation):
     else:
         dd.assert_eq(a, b)
 
-    a = gdf_grouped.agg({"x": aggregation})
-    b = ddf_grouped.agg({"x": aggregation}).compute()
+    a = gdf_grouped.agg({"xx": aggregation})
+    b = ddf_grouped.agg({"xx": aggregation}).compute()
 
     if aggregation == "count":
         dd.assert_eq(a, b, check_dtype=False)
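
The test change above renames the groupby key to "xx" and keeps a column named "x" so the substring case stays covered. The sketch below is a small reproduction-style example of that same scenario outside the test suite; it uses only APIs already shown in this patch (cudf.DataFrame, dask_cudf.from_cudf, groupby) and requires a CUDA-capable environment with cudf and dask_cudf installed, with made-up sample data.

    # Group by "xx" while a column named "x" (a substring of the key) exists.
    import cudf
    import dask_cudf

    gdf = cudf.DataFrame(
        {
            "xx": [0, 1, 0, 1],
            "x": [1.0, 2.0, 3.0, 4.0],
            "y": [10.0, 20.0, 30.0, 40.0],
        }
    )
    ddf = dask_cudf.from_cudf(gdf, npartitions=2)

    # With the groupby.py fix, both "x" and "y" are aggregated and the
    # dask_cudf result matches cudf's own groupby result.
    print(ddf.groupby("xx").sum().compute())
    print(gdf.groupby("xx").sum())
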