From e78f47ae6cd19501d0875595b82f8618278ca4eb Mon Sep 17 00:00:00 2001 From: Ashwin Srinath <3190405+shwina@users.noreply.github.com> Date: Thu, 20 Jan 2022 08:22:17 -0500 Subject: [PATCH] Add `groupby.transform` (only support for aggregations) (#10005) Closes https://github.com/rapidsai/cudf/issues/4522 This PR adds support for doing groupby aggregations via the `transform()` API, where the result of the aggregation is broadcasted to the size of the group. Note that more general transformations are not supported at this time. Authors: - Ashwin Srinath (https://github.com/shwina) Approvers: - Michael Wang (https://github.com/isVoid) - GALI PREM SAGAR (https://github.com/galipremsagar) URL: https://github.com/rapidsai/cudf/pull/10005 --- docs/cudf/source/api_docs/groupby.rst | 1 + docs/cudf/source/basics/groupby.rst | 23 +++++++++ python/cudf/cudf/core/groupby/groupby.py | 64 +++++++++++++++++++++++- python/cudf/cudf/tests/test_groupby.py | 22 ++++++++ 4 files changed, 109 insertions(+), 1 deletion(-) diff --git a/docs/cudf/source/api_docs/groupby.rst b/docs/cudf/source/api_docs/groupby.rst index 575d7442cdf..190978a7581 100644 --- a/docs/cudf/source/api_docs/groupby.rst +++ b/docs/cudf/source/api_docs/groupby.rst @@ -34,6 +34,7 @@ Function application SeriesGroupBy.aggregate DataFrameGroupBy.aggregate GroupBy.pipe + GroupBy.transform Computations / descriptive stats -------------------------------- diff --git a/docs/cudf/source/basics/groupby.rst b/docs/cudf/source/basics/groupby.rst index f3269768025..cbc8f7e712f 100644 --- a/docs/cudf/source/basics/groupby.rst +++ b/docs/cudf/source/basics/groupby.rst @@ -1,3 +1,5 @@ +.. _basics.groupby: + GroupBy ======= @@ -220,6 +222,27 @@ Limitations .. |describe| replace:: ``describe`` .. _describe: https://pandas.pydata.org/pandas-docs/stable/user_guide/groupby.html#flexible-apply + +Transform +--------- + +The ``.transform()`` method aggregates per group, and broadcasts the +result to the group size, resulting in a Series/DataFrame that is of +the same size as the input Series/DataFrame. + +.. code:: python + + >>> import cudf + >>> df = cudf.DataFrame({'a': [2, 1, 1, 2, 2], 'b': [1, 2, 3, 4, 5]}) + >>> df.groupby('a').transform('max') + b + 0 5 + 1 3 + 2 3 + 3 5 + 4 5 + + Rolling window calculations --------------------------- diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py index 6da98bf980d..a393d8e9457 100644 --- a/python/cudf/cudf/core/groupby/groupby.py +++ b/python/cudf/cudf/core/groupby/groupby.py @@ -184,11 +184,25 @@ def agg(self, func): Parameters ---------- func : str, callable, list or dict + Argument specifying the aggregation(s) to perform on the + groups. `func` can be any of the following: + + - string: the name of a supported aggregation + - callable: a function that accepts a Series/DataFrame and + performs a supported operation on it. + - list: a list of strings/callables specifying the + aggregations to perform on every column. + - dict: a mapping of column names to string/callable + specifying the aggregations to perform on those + columns. + + See :ref:`the user guide ` for supported + aggregations. Returns ------- A Series or DataFrame containing the combined results of the - aggregation. + aggregation(s). Examples -------- @@ -655,6 +669,54 @@ def rolling_avg(val, avg): kwargs.update({"chunks": offsets}) return grouped_values.apply_chunks(function, **kwargs) + def transform(self, function): + """Apply an aggregation, then broadcast the result to the group size. + + Parameters + ---------- + function: str or callable + Aggregation to apply to each group. Note that the set of + operations currently supported by `transform` is identical + to that supported by the `agg` method. + + Returns + ------- + A Series or DataFrame of the same size as the input, with the + result of the aggregation per group broadcasted to the group + size. + + Examples + -------- + .. code-block:: python + + import cudf + df = cudf.DataFrame({'a': [2, 1, 1, 2, 2], 'b': [1, 2, 3, 4, 5]}) + df.groupby('a').transform('max') + b + 0 5 + 1 3 + 2 3 + 3 5 + 4 5 + + See also + -------- + cudf.core.groupby.GroupBy.agg + """ + try: + result = self.agg(function) + except TypeError as e: + raise NotImplementedError( + "Currently, `transform()` supports only aggregations." + ) from e + + if not result.index.equals(self.grouping.keys): + result = result._align_to_index( + self.grouping.keys, how="right", allow_non_unique=True + ) + result = result.reset_index(drop=True) + return result + def rolling(self, *args, **kwargs): """ Returns a `RollingGroupby` object that enables rolling window diff --git a/python/cudf/cudf/tests/test_groupby.py b/python/cudf/cudf/tests/test_groupby.py index c73e96de470..f5decd62ea9 100644 --- a/python/cudf/cudf/tests/test_groupby.py +++ b/python/cudf/cudf/tests/test_groupby.py @@ -2362,6 +2362,28 @@ def test_groupby_get_group(pdf, group, name, obj): assert_groupby_results_equal(expected, actual) +@pytest.mark.parametrize( + "by", + [ + "a", + ["a", "b"], + pd.Series([2, 1, 1, 2, 2]), + pd.Series(["b", "a", "a", "b", "b"]), + ], +) +@pytest.mark.parametrize("agg", ["sum", "mean", lambda df: df.mean()]) +def test_groupby_transform_aggregation(by, agg): + gdf = cudf.DataFrame( + {"a": [2, 2, 1, 2, 1], "b": [1, 1, 1, 2, 2], "c": [1, 2, 3, 4, 5]} + ) + pdf = gdf.to_pandas() + + expected = pdf.groupby(by).transform(agg) + actual = gdf.groupby(by).transform(agg) + + assert_groupby_results_equal(expected, actual) + + def test_groupby_select_then_ffill(): pdf = pd.DataFrame( {