From 1382cdbd1f11a75afcedb51790c6248799deccab Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 3 Apr 2024 11:58:40 +0000 Subject: [PATCH 1/2] Handle case of scan aggregation in groupby-transform When performing a groupby-transform with a scan aggregation, the intermediate result obtained from calling groupby-agg is already the correct shape and does not need to be broadcast to align with the grouping keys. To handle this, skip the broadcast when the requested transform is a scan. While here, tighten up the input checking: transform accepts only a single aggregation (a name or a callable), not the more general interface offered by agg. - Closes #12621 - Closes #15448 --- python/cudf/cudf/core/groupby/groupby.py | 12 +++++- .../cudf/cudf/tests/groupby/test_transform.py | 43 +++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) create mode 100644 python/cudf/cudf/tests/groupby/test_transform.py diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py index 945e546af1a..dabc1cde8be 100644 --- a/python/cudf/cudf/core/groupby/groupby.py +++ b/python/cudf/cudf/core/groupby/groupby.py @@ -1767,13 +1767,23 @@ def transform(self, function): -------- agg """ + if not (isinstance(function, str) or callable(function)): + raise ValueError( + "Aggregation must be a named aggregation or a callable" + ) try: result = self.agg(function) except TypeError as e: raise NotImplementedError( "Currently, `transform()` supports only aggregations."
) from e - + # If the aggregation is a scan, don't broadcast + if libgroupby._is_all_scan_aggregate([[function]]): + if len(result) != len(self.obj): + raise AssertionError( + "Unexpected result length for scan transform" + ) + return result return self._broadcast(result) def rolling(self, *args, **kwargs): diff --git a/python/cudf/cudf/tests/groupby/test_transform.py b/python/cudf/cudf/tests/groupby/test_transform.py new file mode 100644 index 00000000000..50f0db15d6b --- /dev/null +++ b/python/cudf/cudf/tests/groupby/test_transform.py @@ -0,0 +1,43 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +import itertools + +import pytest + +import cudf +from cudf.testing._utils import assert_eq + + +@pytest.fixture(params=[False, True], ids=["no-null-keys", "null-keys"]) +def keys_null(request): + return request.param + + +@pytest.fixture(params=[False, True], ids=["no-null-values", "null-values"]) +def values_null(request): + return request.param + + +@pytest.fixture +def df(keys_null, values_null): + keys = ["a", "b", "a", "c", "b", "b", "c", "a"] + r = range(len(keys)) + if keys_null: + keys[::3] = itertools.repeat(None, len(r[::3])) + values = list(range(len(keys))) + if values_null: + values[1::3] = itertools.repeat(None, len(r[1::3])) + return cudf.DataFrame({"key": keys, "values": values}) + + +@pytest.mark.parametrize("agg", ["cumsum", "cumprod", "max", "sum", "prod"]) +def test_transform_broadcast(agg, df): + pf = df.to_pandas() + got = df.groupby("key").transform(agg) + expect = pf.groupby("key").transform(agg) + assert_eq(got, expect, check_dtype=False) + + +def test_transform_invalid(): + df = cudf.DataFrame({"key": [1, 1], "values": [4, 5]}) + with pytest.raises(ValueError): + df.groupby("key").transform({"values": "cumprod"}) From f45b11f8077b968da1cd700bde61202d6df7e868 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Mon, 15 Apr 2024 16:28:08 +0000 Subject: [PATCH 2/2] Try and match pandas exception types --- python/cudf/cudf/core/groupby/groupby.py | 
2 +- python/cudf/cudf/tests/groupby/test_transform.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py index dabc1cde8be..dd4924676f3 100644 --- a/python/cudf/cudf/core/groupby/groupby.py +++ b/python/cudf/cudf/core/groupby/groupby.py @@ -1768,7 +1768,7 @@ def transform(self, function): agg """ if not (isinstance(function, str) or callable(function)): - raise ValueError( + raise TypeError( "Aggregation must be a named aggregation or a callable" ) try: diff --git a/python/cudf/cudf/tests/groupby/test_transform.py b/python/cudf/cudf/tests/groupby/test_transform.py index 50f0db15d6b..78d7fbfd879 100644 --- a/python/cudf/cudf/tests/groupby/test_transform.py +++ b/python/cudf/cudf/tests/groupby/test_transform.py @@ -39,5 +39,5 @@ def test_transform_broadcast(agg, df): def test_transform_invalid(): df = cudf.DataFrame({"key": [1, 1], "values": [4, 5]}) - with pytest.raises(ValueError): + with pytest.raises(TypeError): df.groupby("key").transform({"values": "cumprod"})