Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow string aggs for dask_cudf.CudfDataFrameGroupBy.aggregate #10222

Merged
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/resample.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION &
# SPDX-FileCopyrightText: Copyright (c) 2021-2022, NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier:
# Apache-2.0
#
Expand Down
6 changes: 5 additions & 1 deletion python/dask_cudf/dask_cudf/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def last(self, split_every=None, split_out=1):
def aggregate(self, arg, split_every=None, split_out=1):
if arg == "size":
return self.size()

arg = _redirect_aggs(arg)

if _groupby_supported(self) and _aggs_supported(arg, SUPPORTED_AGGS):
Expand Down Expand Up @@ -431,6 +432,7 @@ def last(self, split_every=None, split_out=1):
def aggregate(self, arg, split_every=None, split_out=1):
if arg == "size":
return self.size()

arg = _redirect_aggs(arg)
charlesbluca marked this conversation as resolved.
Show resolved Hide resolved

if not isinstance(arg, dict):
Expand Down Expand Up @@ -503,7 +505,7 @@ def groupby_agg(
if isinstance(gb_cols, str):
gb_cols = [gb_cols]
columns = [c for c in ddf.columns if c not in gb_cols]
if isinstance(aggs, list):
if not isinstance(aggs, dict):
aggs = {col: aggs for col in columns}

# Assert if our output will have a MultiIndex; this will be the case if
Expand Down Expand Up @@ -665,6 +667,8 @@ def _aggs_supported(arg, supported: set):
_global_set = set(arg)

return bool(_global_set.issubset(supported))
elif isinstance(arg, str):
return arg in supported
return False


Expand Down
20 changes: 16 additions & 4 deletions python/dask_cudf/dask_cudf/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ def test_groupby_basic(series, aggregation):
"func",
[
lambda df: df.groupby("x").agg({"y": "max"}),
lambda df: df.groupby("x").agg(["sum", "max"]),
lambda df: df.groupby("x").y.agg(["sum", "max"]),
lambda df: df.groupby("x").agg("sum"),
lambda df: df.groupby("x").y.agg("sum"),
],
)
def test_groupby_agg(func):
Expand Down Expand Up @@ -663,11 +666,20 @@ def test_groupby_agg_redirect(aggregations):


@pytest.mark.parametrize(
"arg",
[["not_supported"], {"a": "not_supported"}, {"a": ["not_supported"]}],
"arg,supported",
[
("sum", True),
(["sum"], True),
({"a": "sum"}, True),
({"a": ["sum"]}, True),
("not_supported", False),
(["not_supported"], False),
({"a": "not_supported"}, False),
({"a": ["not_supported"]}, False),
],
)
def test_is_supported(arg):
assert _aggs_supported(arg, {"supported"}) is False
def test_is_supported(arg, supported):
assert _aggs_supported(arg, SUPPORTED_AGGS) is supported


def test_groupby_unique_lists():
Expand Down