Skip to content

Commit

Permalink
Remove more repetition from agg dictionary construction
Browse files (browse the repository at this point in the history)
  • Loading branch information
charlesbluca committed May 11, 2022
1 parent 08d8e7f commit d652157
Showing 1 changed file with 14 additions and 19 deletions.
33 changes: 14 additions & 19 deletions python/dask_cudf/dask_cudf/groupby.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -84,25 +84,20 @@ def __getitem__(self, key):
return g

@_dask_cudf_nvtx_annotate
def _columns_not_in_by(self):
"""Generator of the columns contained in the groupby agg result"""
def _make_groupby_method_aggs(self, agg_name):
"""Create aggs dictionary for aggregation methods"""

if isinstance(self.by, list):
for c in self.obj.columns:
if c not in self.by:
yield c
else:
for c in self.obj.columns:
if c != self.by:
yield c
return {c: agg_name for c in self.obj.columns if c not in self.by}
return {c: agg_name for c in self.obj.columns if c != self.by}

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
def count(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
self.by,
{c: "count" for c in self._columns_not_in_by()},
self._make_groupby_method_aggs("count"),
split_every=split_every,
split_out=split_out,
sep=self.sep,
Expand All @@ -117,7 +112,7 @@ def mean(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
self.by,
{c: "mean" for c in self._columns_not_in_by()},
self._make_groupby_method_aggs("mean"),
split_every=split_every,
split_out=split_out,
sep=self.sep,
Expand All @@ -132,7 +127,7 @@ def std(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
self.by,
{c: "std" for c in self._columns_not_in_by()},
self._make_groupby_method_aggs("std"),
split_every=split_every,
split_out=split_out,
sep=self.sep,
Expand All @@ -147,7 +142,7 @@ def var(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
self.by,
{c: "var" for c in self._columns_not_in_by()},
self._make_groupby_method_aggs("var"),
split_every=split_every,
split_out=split_out,
sep=self.sep,
Expand All @@ -162,7 +157,7 @@ def sum(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
self.by,
{c: "sum" for c in self._columns_not_in_by()},
self._make_groupby_method_aggs("sum"),
split_every=split_every,
split_out=split_out,
sep=self.sep,
Expand All @@ -177,7 +172,7 @@ def min(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
self.by,
{c: "min" for c in self._columns_not_in_by()},
self._make_groupby_method_aggs("min"),
split_every=split_every,
split_out=split_out,
sep=self.sep,
Expand All @@ -192,7 +187,7 @@ def max(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
self.by,
{c: "max" for c in self._columns_not_in_by()},
self._make_groupby_method_aggs("max"),
split_every=split_every,
split_out=split_out,
sep=self.sep,
Expand All @@ -207,7 +202,7 @@ def collect(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
self.by,
{c: "collect" for c in self._columns_not_in_by()},
self._make_groupby_method_aggs("collect"),
split_every=split_every,
split_out=split_out,
sep=self.sep,
Expand All @@ -222,7 +217,7 @@ def first(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
self.by,
{c: "first" for c in self._columns_not_in_by()},
self._make_groupby_method_aggs("first"),
split_every=split_every,
split_out=split_out,
sep=self.sep,
Expand All @@ -237,7 +232,7 @@ def last(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
self.by,
{c: "last" for c in self._columns_not_in_by()},
self._make_groupby_method_aggs("last"),
split_every=split_every,
split_out=split_out,
sep=self.sep,
Expand Down

0 comments on commit d652157

Please sign in to comment.