Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Replace dask groupby .index usages with .by #10193

Merged
merged 1 commit into from
Feb 2, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 17 additions & 24 deletions python/dask_cudf/dask_cudf/groupby.py
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2022, NVIDIA CORPORATION.

import math
from operator import getitem
from typing import Set
Expand Down / Expand Up
@@ -42,19 +43,11 @@ def __init__(self, *args, **kwargs):
def __getitem__(self, key):
if isinstance(key, list):
g = CudfDataFrameGroupBy(
self.obj,
by=self.index,
slice=key,
sort=self.sort,
**self.dropna,
self.obj, by=self.by, slice=key, sort=self.sort, **self.dropna,
)
else:
g = CudfSeriesGroupBy(
self.obj,
by=self.index,
slice=key,
sort=self.sort,
**self.dropna,
self.obj, by=self.by, slice=key, sort=self.sort, **self.dropna,
)

g._meta = g._meta[key]
Expand All
@@ -63,8 +56,8 @@ def __getitem__(self, key):
def mean(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
self.index,
{c: "mean" for c in self.obj.columns if c not in self.index},
self.by,
{c: "mean" for c in self.obj.columns if c not in self.by},
split_every=split_every,
split_out=split_out,
dropna=self.dropna,
Expand All
@@ -76,8 +69,8 @@ def mean(self, split_every=None, split_out=1):
def collect(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
self.index,
{c: "collect" for c in self.obj.columns if c not in self.index},
self.by,
{c: "collect" for c in self.obj.columns if c not in self.by},
split_every=split_every,
split_out=split_out,
dropna=self.dropna,
Expand All
@@ -94,10 +87,10 @@ def aggregate(self, arg, split_every=None, split_out=1):
if (
isinstance(self.obj, DaskDataFrame)
and (
isinstance(self.index, str)
isinstance(self.by, str)
or (
isinstance(self.index, list)
and all(isinstance(x, str) for x in self.index)
isinstance(self.by, list)
and all(isinstance(x, str) for x in self.by)
)
)
and _is_supported(arg, SUPPORTED_AGGS)
Expand Down / Expand Up
@@ -133,7 +126,7 @@ def __init__(self, *args, **kwargs):
def mean(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
self.index,
self.by,
{self._slice: "mean"},
split_every=split_every,
split_out=split_out,
Expand All
@@ -146,7 +139,7 @@ def mean(self, split_every=None, split_out=1):
def std(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
self.index,
self.by,
{self._slice: "std"},
split_every=split_every,
split_out=split_out,
Expand All
@@ -159,7 +152,7 @@ def std(self, split_every=None, split_out=1):
def var(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
self.index,
self.by,
{self._slice: "var"},
split_every=split_every,
split_out=split_out,
Expand All
@@ -172,7 +165,7 @@ def var(self, split_every=None, split_out=1):
def collect(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
self.index,
self.by,
{self._slice: "collect"},
split_every=split_every,
split_out=split_out,
Expand All
@@ -192,12 +185,12 @@ def aggregate(self, arg, split_every=None, split_out=1):

if (
isinstance(self.obj, DaskDataFrame)
and isinstance(self.index, (str, list))
and isinstance(self.by, (str, list))
and _is_supported(arg, SUPPORTED_AGGS)
):
return groupby_agg(
self.obj,
self.index,
self.by,
arg,
split_every=split_every,
split_out=split_out,
Expand Down