Skip to content

Commit

Permalink
Handle case of scan aggregation in groupby-transform (#15450)
Browse files Browse the repository at this point in the history
When performing a groupby-transform with a scan aggregation, the intermediate result obtained from calling groupby-agg is already the correct shape and does not need to be broadcast to align with the grouping keys.

To handle this, make sure that if the requested transform is a scan that we don't try and broadcast.

While here, tighten up the input checking: transform only applies to a single aggregation, rather than the more general interface offered by agg.

- Closes #12621
- Closes #15448

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - Bradley Dice (https://github.com/bdice)

URL: #15450
  • Loading branch information
wence- authored Apr 16, 2024
1 parent 8919690 commit c1dcc31
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 1 deletion.
12 changes: 11 additions & 1 deletion python/cudf/cudf/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1767,13 +1767,23 @@ def transform(self, function):
--------
agg
"""
if not (isinstance(function, str) or callable(function)):
raise TypeError(
"Aggregation must be a named aggregation or a callable"
)
try:
result = self.agg(function)
except TypeError as e:
raise NotImplementedError(
"Currently, `transform()` supports only aggregations."
) from e

# If the aggregation is a scan, don't broadcast
if libgroupby._is_all_scan_aggregate([[function]]):
if len(result) != len(self.obj):
raise AssertionError(
"Unexpected result length for scan transform"
)
return result
return self._broadcast(result)

def rolling(self, *args, **kwargs):
Expand Down
43 changes: 43 additions & 0 deletions python/cudf/cudf/tests/groupby/test_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
import itertools

import pytest

import cudf
from cudf.testing._utils import assert_eq


@pytest.fixture(params=[False, True], ids=["no-null-keys", "null-keys"])
def keys_null(request):
return request.param


@pytest.fixture(params=[False, True], ids=["no-null-values", "null-values"])
def values_null(request):
return request.param


@pytest.fixture
def df(keys_null, values_null):
keys = ["a", "b", "a", "c", "b", "b", "c", "a"]
r = range(len(keys))
if keys_null:
keys[::3] = itertools.repeat(None, len(r[::3]))
values = list(range(len(keys)))
if values_null:
values[1::3] = itertools.repeat(None, len(r[1::3]))
return cudf.DataFrame({"key": keys, "values": values})


@pytest.mark.parametrize("agg", ["cumsum", "cumprod", "max", "sum", "prod"])
def test_transform_broadcast(agg, df):
pf = df.to_pandas()
got = df.groupby("key").transform(agg)
expect = pf.groupby("key").transform(agg)
assert_eq(got, expect, check_dtype=False)


def test_transform_invalid():
df = cudf.DataFrame({"key": [1, 1], "values": [4, 5]})
with pytest.raises(TypeError):
df.groupby("key").transform({"values": "cumprod"})

0 comments on commit c1dcc31

Please sign in to comment.