diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py index 945e546af1a..dd4924676f3 100644 --- a/python/cudf/cudf/core/groupby/groupby.py +++ b/python/cudf/cudf/core/groupby/groupby.py @@ -1767,13 +1767,23 @@ def transform(self, function): -------- agg """ + if not (isinstance(function, str) or callable(function)): + raise TypeError( + "Aggregation must be a named aggregation or a callable" + ) try: result = self.agg(function) except TypeError as e: raise NotImplementedError( "Currently, `transform()` supports only aggregations." ) from e - + # If the aggregation is a scan, don't broadcast + if libgroupby._is_all_scan_aggregate([[function]]): + if len(result) != len(self.obj): + raise AssertionError( + "Unexpected result length for scan transform" + ) + return result return self._broadcast(result) def rolling(self, *args, **kwargs): diff --git a/python/cudf/cudf/tests/groupby/test_transform.py b/python/cudf/cudf/tests/groupby/test_transform.py new file mode 100644 index 00000000000..78d7fbfd879 --- /dev/null +++ b/python/cudf/cudf/tests/groupby/test_transform.py @@ -0,0 +1,43 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +import itertools + +import pytest + +import cudf +from cudf.testing._utils import assert_eq + + +@pytest.fixture(params=[False, True], ids=["no-null-keys", "null-keys"]) +def keys_null(request): + return request.param + + +@pytest.fixture(params=[False, True], ids=["no-null-values", "null-values"]) +def values_null(request): + return request.param + + +@pytest.fixture +def df(keys_null, values_null): + keys = ["a", "b", "a", "c", "b", "b", "c", "a"] + r = range(len(keys)) + if keys_null: + keys[::3] = itertools.repeat(None, len(r[::3])) + values = list(range(len(keys))) + if values_null: + values[1::3] = itertools.repeat(None, len(r[1::3])) + return cudf.DataFrame({"key": keys, "values": values}) + + +@pytest.mark.parametrize("agg", ["cumsum", "cumprod", "max", "sum", "prod"]) +def test_transform_broadcast(agg, df): + pf = df.to_pandas() + got = df.groupby("key").transform(agg) + expect = pf.groupby("key").transform(agg) + assert_eq(got, expect, check_dtype=False) + + +def test_transform_invalid(): + df = cudf.DataFrame({"key": [1, 1], "values": [4, 5]}) + with pytest.raises(TypeError): + df.groupby("key").transform({"values": "cumprod"})