Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle case of scan aggregation in groupby-transform #15450

Merged
merged 3 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion python/cudf/cudf/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1767,13 +1767,23 @@ def transform(self, function):
--------
agg
"""
if not (isinstance(function, str) or callable(function)):
raise ValueError(
"Aggregation must be a named aggregation or a callable"
)
try:
result = self.agg(function)
except TypeError as e:
raise NotImplementedError(
"Currently, `transform()` supports only aggregations."
) from e

# If the aggregation is a scan, don't broadcast
if libgroupby._is_all_scan_aggregate([[function]]):
if len(result) != len(self.obj):
raise AssertionError(
"Unexpected result length for scan transform"
)
return result
return self._broadcast(result)

def rolling(self, *args, **kwargs):
Expand Down
43 changes: 43 additions & 0 deletions python/cudf/cudf/tests/groupby/test_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
import itertools

import pytest

import cudf
from cudf.testing._utils import assert_eq


@pytest.fixture(params=[False, True], ids=["no-null-keys", "null-keys"])
def keys_null(request):
    """Parametrized flag: should the grouping keys contain nulls?

    Runs each dependent test twice, once with ``False`` and once with
    ``True``.
    """
    include_null_keys = request.param
    return include_null_keys


@pytest.fixture(params=[False, True], ids=["no-null-values", "null-values"])
def values_null(request):
    """Parametrized flag: should the value column contain nulls?

    Runs each dependent test twice, once with ``False`` and once with
    ``True``.
    """
    include_null_values = request.param
    return include_null_values


@pytest.fixture
def df(keys_null, values_null):
    """Build a small two-column cudf DataFrame for groupby-transform tests.

    The ``key`` column holds eight string labels and the ``values`` column
    holds the integers ``0..7``. When ``keys_null`` is true, every third key
    starting at position 0 is replaced by ``None``; when ``values_null`` is
    true, every third value starting at position 1 is replaced by ``None``.
    """
    labels = ["a", "b", "a", "c", "b", "b", "c", "a"]
    n_rows = len(labels)

    # Positions 0, 3, 6 become null keys when requested.
    key_column = [
        None if keys_null and pos % 3 == 0 else label
        for pos, label in enumerate(labels)
    ]
    # Positions 1, 4, 7 become null values when requested.
    value_column = [
        None if values_null and pos % 3 == 1 else pos
        for pos in range(n_rows)
    ]
    return cudf.DataFrame({"key": key_column, "values": value_column})


@pytest.mark.parametrize("agg", ["cumsum", "cumprod", "max", "sum", "prod"])
def test_transform_broadcast(agg, df):
    """Check that cudf groupby-transform agrees with pandas for ``agg``.

    The same named aggregation is applied through ``groupby("key")`` on the
    cudf frame and on its pandas conversion, and the results are compared
    without checking dtypes.
    """
    pandas_frame = df.to_pandas()
    cudf_grouped = df.groupby("key")
    pandas_grouped = pandas_frame.groupby("key")

    got = cudf_grouped.transform(agg)
    expect = pandas_grouped.transform(agg)
    assert_eq(got, expect, check_dtype=False)


def test_transform_invalid():
df = cudf.DataFrame({"key": [1, 1], "values": [4, 5]})
with pytest.raises(ValueError):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we care to raise the same error as pandas? This ends up hitting a TypeError in pandas for a few of the "bad" inputs that I tried like {"values": "cumprod"} or the tuple ("cumprod",) or the integer 3. Only invalid function names (strings) raise ValueError from what I can tell.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh probably, I didn't check exhaustively what pandas produces.

df.groupby("key").transform({"values": "cumprod"})
Loading