Skip to content

Commit

Permalink
Allow string aggs for dask_cudf.CudfDataFrameGroupBy.aggregate (#10222
Browse files Browse the repository at this point in the history
)

I noticed that `CudfDataFrameGroupBy.aggregate` doesn't actually support passing aggregations as strings, for example something like

```python
import cudf
import dask_cudf

gdf = cudf.DataFrame({'id4': 4*list(range(6)), 'id5': 4*list(reversed(range(6))), 'v3': 6*list(range(4))})
gddf = dask_cudf.from_cudf(gdf, npartitions=5)

gddf.groupby("id4").agg("mean")
```

Would actually end up using the upstream `aggregate` implementation. This is because:

- `CudfDataFrameGroupBy.aggregate` does not convert string aggs to a dict before calling `_is_supported` on them
- `_is_supported` only handles list / dict aggs, returning false otherwise

I've resolved this by adding string support to `_is_supported`, and moving the conversion of aggs to the internal `groupby_agg`.

It looks like this is exposing some failures for `first` and `last` groupby aggs, as tests that were originally using upstream Dask to compute these aggregations (I assume accidentally since these aggregations are listed as supported) are now using dask-cuDF and getting the wrong result.

Authors:
  - Charles Blackmon-Luca (https://github.com/charlesbluca)

Approvers:
  - Bradley Dice (https://github.com/bdice)
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: #10222
  • Loading branch information
charlesbluca authored May 13, 2022
1 parent b64452a commit 4ad1e51
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 6 deletions.
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/resample.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION &
# SPDX-FileCopyrightText: Copyright (c) 2021-2022, NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier:
# Apache-2.0
#
Expand Down
6 changes: 5 additions & 1 deletion python/dask_cudf/dask_cudf/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ def last(self, split_every=None, split_out=1):
def aggregate(self, arg, split_every=None, split_out=1):
if arg == "size":
return self.size()

arg = _redirect_aggs(arg)

if _groupby_supported(self) and _aggs_supported(arg, SUPPORTED_AGGS):
Expand Down Expand Up @@ -431,6 +432,7 @@ def last(self, split_every=None, split_out=1):
def aggregate(self, arg, split_every=None, split_out=1):
if arg == "size":
return self.size()

arg = _redirect_aggs(arg)

if not isinstance(arg, dict):
Expand Down Expand Up @@ -503,7 +505,7 @@ def groupby_agg(
if isinstance(gb_cols, str):
gb_cols = [gb_cols]
columns = [c for c in ddf.columns if c not in gb_cols]
if isinstance(aggs, list):
if not isinstance(aggs, dict):
aggs = {col: aggs for col in columns}

# Assert if our output will have a MultiIndex; this will be the case if
Expand Down Expand Up @@ -665,6 +667,8 @@ def _aggs_supported(arg, supported: set):
_global_set = set(arg)

return bool(_global_set.issubset(supported))
elif isinstance(arg, str):
return arg in supported
return False


Expand Down
20 changes: 16 additions & 4 deletions python/dask_cudf/dask_cudf/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ def test_groupby_basic(series, aggregation):
"func",
[
lambda df: df.groupby("x").agg({"y": "max"}),
lambda df: df.groupby("x").agg(["sum", "max"]),
lambda df: df.groupby("x").y.agg(["sum", "max"]),
lambda df: df.groupby("x").agg("sum"),
lambda df: df.groupby("x").y.agg("sum"),
],
)
def test_groupby_agg(func):
Expand Down Expand Up @@ -663,11 +666,20 @@ def test_groupby_agg_redirect(aggregations):


@pytest.mark.parametrize(
"arg",
[["not_supported"], {"a": "not_supported"}, {"a": ["not_supported"]}],
"arg,supported",
[
("sum", True),
(["sum"], True),
({"a": "sum"}, True),
({"a": ["sum"]}, True),
("not_supported", False),
(["not_supported"], False),
({"a": "not_supported"}, False),
({"a": ["not_supported"]}, False),
],
)
def test_is_supported(arg):
assert _aggs_supported(arg, {"supported"}) is False
def test_is_supported(arg, supported):
assert _aggs_supported(arg, SUPPORTED_AGGS) is supported


def test_groupby_unique_lists():
Expand Down

0 comments on commit 4ad1e51

Please sign in to comment.