diff --git a/python/dask_cudf/dask_cudf/groupby.py b/python/dask_cudf/dask_cudf/groupby.py index 149d98ebfb9..1bc270a5b9f 100644 --- a/python/dask_cudf/dask_cudf/groupby.py +++ b/python/dask_cudf/dask_cudf/groupby.py @@ -1,4 +1,5 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. + import math from operator import getitem from typing import Set @@ -42,19 +43,11 @@ def __init__(self, *args, **kwargs): def __getitem__(self, key): if isinstance(key, list): g = CudfDataFrameGroupBy( - self.obj, - by=self.index, - slice=key, - sort=self.sort, - **self.dropna, + self.obj, by=self.by, slice=key, sort=self.sort, **self.dropna, ) else: g = CudfSeriesGroupBy( - self.obj, - by=self.index, - slice=key, - sort=self.sort, - **self.dropna, + self.obj, by=self.by, slice=key, sort=self.sort, **self.dropna, ) g._meta = g._meta[key] @@ -63,8 +56,8 @@ def __getitem__(self, key): def mean(self, split_every=None, split_out=1): return groupby_agg( self.obj, - self.index, - {c: "mean" for c in self.obj.columns if c not in self.index}, + self.by, + {c: "mean" for c in self.obj.columns if c not in self.by}, split_every=split_every, split_out=split_out, dropna=self.dropna, @@ -76,8 +69,8 @@ def mean(self, split_every=None, split_out=1): def collect(self, split_every=None, split_out=1): return groupby_agg( self.obj, - self.index, - {c: "collect" for c in self.obj.columns if c not in self.index}, + self.by, + {c: "collect" for c in self.obj.columns if c not in self.by}, split_every=split_every, split_out=split_out, dropna=self.dropna, @@ -94,10 +87,10 @@ def aggregate(self, arg, split_every=None, split_out=1): if ( isinstance(self.obj, DaskDataFrame) and ( - isinstance(self.index, str) + isinstance(self.by, str) or ( - isinstance(self.index, list) - and all(isinstance(x, str) for x in self.index) + isinstance(self.by, list) + and all(isinstance(x, str) for x in self.by) ) ) and _is_supported(arg, SUPPORTED_AGGS) @@ -133,7 +126,7 @@ def __init__(self, *args, **kwargs): def mean(self, split_every=None, split_out=1): return groupby_agg( self.obj, - self.index, + self.by, {self._slice: "mean"}, split_every=split_every, split_out=split_out, @@ -146,7 +139,7 @@ def mean(self, split_every=None, split_out=1): def std(self, split_every=None, split_out=1): return groupby_agg( self.obj, - self.index, + self.by, {self._slice: "std"}, split_every=split_every, split_out=split_out, @@ -159,7 +152,7 @@ def std(self, split_every=None, split_out=1): def var(self, split_every=None, split_out=1): return groupby_agg( self.obj, - self.index, + self.by, {self._slice: "var"}, split_every=split_every, split_out=split_out, @@ -172,7 +165,7 @@ def var(self, split_every=None, split_out=1): def collect(self, split_every=None, split_out=1): return groupby_agg( self.obj, - self.index, + self.by, {self._slice: "collect"}, split_every=split_every, split_out=split_out, @@ -192,12 +185,12 @@ def aggregate(self, arg, split_every=None, split_out=1): if ( isinstance(self.obj, DaskDataFrame) - and isinstance(self.index, (str, list)) + and isinstance(self.by, (str, list)) and _is_supported(arg, SUPPORTED_AGGS) ): return groupby_agg( self.obj, - self.index, + self.by, arg, split_every=split_every, split_out=split_out,