diff --git a/python/dask_cudf/dask_cudf/groupby.py b/python/dask_cudf/dask_cudf/groupby.py index 6d819d9d462..a64aabe1a6b 100644 --- a/python/dask_cudf/dask_cudf/groupby.py +++ b/python/dask_cudf/dask_cudf/groupby.py @@ -84,17 +84,12 @@ def __getitem__(self, key): return g @_dask_cudf_nvtx_annotate - def _columns_not_in_by(self): - """Generator of the columns contained in the groupby agg result""" + def _make_groupby_method_aggs(self, agg_name): + """Create aggs dictionary for aggregation methods""" if isinstance(self.by, list): - for c in self.obj.columns: - if c not in self.by: - yield c - else: - for c in self.obj.columns: - if c != self.by: - yield c + return {c: agg_name for c in self.obj.columns if c not in self.by} + return {c: agg_name for c in self.obj.columns if c != self.by} @_dask_cudf_nvtx_annotate @_check_groupby_supported @@ -102,7 +97,7 @@ def count(self, split_every=None, split_out=1): return groupby_agg( self.obj, self.by, - {c: "count" for c in self._columns_not_in_by()}, + self._make_groupby_method_aggs("count"), split_every=split_every, split_out=split_out, sep=self.sep, @@ -117,7 +112,7 @@ def mean(self, split_every=None, split_out=1): return groupby_agg( self.obj, self.by, - {c: "mean" for c in self._columns_not_in_by()}, + self._make_groupby_method_aggs("mean"), split_every=split_every, split_out=split_out, sep=self.sep, @@ -132,7 +127,7 @@ def std(self, split_every=None, split_out=1): return groupby_agg( self.obj, self.by, - {c: "std" for c in self._columns_not_in_by()}, + self._make_groupby_method_aggs("std"), split_every=split_every, split_out=split_out, sep=self.sep, @@ -147,7 +142,7 @@ def var(self, split_every=None, split_out=1): return groupby_agg( self.obj, self.by, - {c: "var" for c in self._columns_not_in_by()}, + self._make_groupby_method_aggs("var"), split_every=split_every, split_out=split_out, sep=self.sep, @@ -162,7 +157,7 @@ def sum(self, split_every=None, split_out=1): return groupby_agg( self.obj, self.by, - {c: "sum" for c in self._columns_not_in_by()}, + self._make_groupby_method_aggs("sum"), split_every=split_every, split_out=split_out, sep=self.sep, @@ -177,7 +172,7 @@ def min(self, split_every=None, split_out=1): return groupby_agg( self.obj, self.by, - {c: "min" for c in self._columns_not_in_by()}, + self._make_groupby_method_aggs("min"), split_every=split_every, split_out=split_out, sep=self.sep, @@ -192,7 +187,7 @@ def max(self, split_every=None, split_out=1): return groupby_agg( self.obj, self.by, - {c: "max" for c in self._columns_not_in_by()}, + self._make_groupby_method_aggs("max"), split_every=split_every, split_out=split_out, sep=self.sep, @@ -207,7 +202,7 @@ def collect(self, split_every=None, split_out=1): return groupby_agg( self.obj, self.by, - {c: "collect" for c in self._columns_not_in_by()}, + self._make_groupby_method_aggs("collect"), split_every=split_every, split_out=split_out, sep=self.sep, @@ -222,7 +217,7 @@ def first(self, split_every=None, split_out=1): return groupby_agg( self.obj, self.by, - {c: "first" for c in self._columns_not_in_by()}, + self._make_groupby_method_aggs("first"), split_every=split_every, split_out=split_out, sep=self.sep, @@ -237,7 +232,7 @@ def last(self, split_every=None, split_out=1): return groupby_agg( self.obj, self.by, - {c: "last" for c in self._columns_not_in_by()}, + self._make_groupby_method_aggs("last"), split_every=split_every, split_out=split_out, sep=self.sep,