Commit

Add groupby.transform (only support for aggregations) (#10005)
Closes #4522

This PR adds support for groupby aggregations via the `transform()` API, where the result of the aggregation is broadcast to the size of the group. Note that more general transformations are not supported at this time.
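To illustrate the difference between aggregating and transforming, here is a minimal pure-Python sketch. It is not cudf code: the helper names `aggregate` and `transform` are hypothetical and only model the semantics described above (one result per group versus one result per input row).

```python
from collections import defaultdict


def aggregate(keys, values, func):
    """Reduce values to one result per distinct key (what an aggregation returns)."""
    groups = defaultdict(list)
    for key, value in zip(keys, values):
        groups[key].append(value)
    return {key: func(group) for key, group in groups.items()}


def transform(keys, values, func):
    """Aggregate per group, then repeat each result for every row of its group."""
    per_group = aggregate(keys, values, func)
    return [per_group[key] for key in keys]


a = [2, 1, 1, 2, 2]
b = [1, 2, 3, 4, 5]

print(aggregate(a, b, max))  # {2: 5, 1: 3}
print(transform(a, b, max))  # [5, 3, 3, 5, 5]
```

The transformed output always has as many entries as the input, which is what makes it convenient to assign back as a new column.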

Authors:
  - Ashwin Srinath (https://github.com/shwina)

Approvers:
  - Michael Wang (https://github.com/isVoid)
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: #10005
shwina authored Jan 20, 2022
1 parent 2bd7320 commit e78f47a
Showing 4 changed files with 109 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/cudf/source/api_docs/groupby.rst
@@ -34,6 +34,7 @@ Function application
SeriesGroupBy.aggregate
DataFrameGroupBy.aggregate
GroupBy.pipe
GroupBy.transform

Computations / descriptive stats
--------------------------------
23 changes: 23 additions & 0 deletions docs/cudf/source/basics/groupby.rst
@@ -1,3 +1,5 @@
.. _basics.groupby:

GroupBy
=======

@@ -220,6 +222,27 @@ Limitations
.. |describe| replace:: ``describe``
.. _describe: https://pandas.pydata.org/pandas-docs/stable/user_guide/groupby.html#flexible-apply


Transform
---------

The ``.transform()`` method computes an aggregation for each group and
broadcasts the result to every row of that group, so the output
Series/DataFrame has the same size as the input Series/DataFrame.

.. code:: python

    >>> import cudf
    >>> df = cudf.DataFrame({'a': [2, 1, 1, 2, 2], 'b': [1, 2, 3, 4, 5]})
    >>> df.groupby('a').transform('max')
       b
    0  5
    1  3
    2  3
    3  5
    4  5

Rolling window calculations
---------------------------

64 changes: 63 additions & 1 deletion python/cudf/cudf/core/groupby/groupby.py
@@ -184,11 +184,25 @@ def agg(self, func):
        Parameters
        ----------
        func : str, callable, list or dict
            Argument specifying the aggregation(s) to perform on the
            groups. `func` can be any of the following:

              - string: the name of a supported aggregation
              - callable: a function that accepts a Series/DataFrame and
                performs a supported operation on it.
              - list: a list of strings/callables specifying the
                aggregations to perform on every column.
              - dict: a mapping of column names to string/callable
                specifying the aggregations to perform on those
                columns.

            See :ref:`the user guide <basics.groupby>` for supported
            aggregations.

        Returns
        -------
        A Series or DataFrame containing the combined results of the
        aggregation.
        aggregation(s).

        Examples
        --------
@@ -655,6 +669,54 @@ def rolling_avg(val, avg):
kwargs.update({"chunks": offsets})
return grouped_values.apply_chunks(function, **kwargs)

    def transform(self, function):
        """Apply an aggregation, then broadcast the result to the group size.

        Parameters
        ----------
        function: str or callable
            Aggregation to apply to each group. Note that the set of
            operations currently supported by `transform` is identical
            to that supported by the `agg` method.

        Returns
        -------
        A Series or DataFrame of the same size as the input, with the
        result of the aggregation per group broadcasted to the group
        size.

        Examples
        --------
        .. code-block:: python

          import cudf
          df = cudf.DataFrame({'a': [2, 1, 1, 2, 2], 'b': [1, 2, 3, 4, 5]})
          df.groupby('a').transform('max')
             b
          0  5
          1  3
          2  3
          3  5
          4  5

        See also
        --------
        cudf.core.groupby.GroupBy.agg
        """
        try:
            result = self.agg(function)
        except TypeError as e:
            raise NotImplementedError(
                "Currently, `transform()` supports only aggregations."
            ) from e

        if not result.index.equals(self.grouping.keys):
            result = result._align_to_index(
                self.grouping.keys, how="right", allow_non_unique=True
            )
            result = result.reset_index(drop=True)
        return result
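The method body above has three steps: aggregate, translate a `TypeError` into `NotImplementedError`, then align the per-group result to the original row keys. The sketch below models those steps in pure Python. It does not call cudf or its private `_align_to_index` method. `agg_sketch` and `transform_sketch` are hypothetical names, and the rule that a non-scalar result raises `TypeError` is an assumption made for illustration only.

```python
from collections import defaultdict


def agg_sketch(keys, values, func):
    """Hypothetical stand-in for `agg`: one scalar per group, sorted by key."""
    groups = defaultdict(list)
    for key, value in zip(keys, values):
        groups[key].append(value)
    result = {}
    for key in sorted(groups):
        reduced = func(groups[key])
        # Assumption for this sketch: a container result means the function
        # is not an aggregation, so it is rejected with TypeError.
        if isinstance(reduced, (list, tuple, dict, set)):
            raise TypeError("func must reduce each group to a scalar")
        result[key] = reduced
    return result


def transform_sketch(keys, values, func):
    # Step 1: aggregate, translating TypeError into NotImplementedError.
    try:
        per_group = agg_sketch(keys, values, func)
    except TypeError as e:
        raise NotImplementedError(
            "Currently, `transform()` supports only aggregations."
        ) from e
    # Step 2: "right"-style alignment: emit one row per input key, in input
    # order, looking each key up in the per-group result.
    return [per_group[key] for key in keys]


keys = [2, 1, 1, 2, 2]
values = [1, 2, 3, 4, 5]
print(transform_sketch(keys, values, sum))  # [10, 5, 5, 10, 10]

try:
    transform_sketch(keys, values, sorted)  # sorted() is not an aggregation
except NotImplementedError as err:
    print(type(err.__cause__).__name__)  # TypeError
```

Chaining with `from e` keeps the original `TypeError` available as `__cause__`, so the underlying failure remains visible in the traceback.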

def rolling(self, *args, **kwargs):
"""
Returns a `RollingGroupby` object that enables rolling window
22 changes: 22 additions & 0 deletions python/cudf/cudf/tests/test_groupby.py
@@ -2362,6 +2362,28 @@ def test_groupby_get_group(pdf, group, name, obj):
assert_groupby_results_equal(expected, actual)


@pytest.mark.parametrize(
    "by",
    [
        "a",
        ["a", "b"],
        pd.Series([2, 1, 1, 2, 2]),
        pd.Series(["b", "a", "a", "b", "b"]),
    ],
)
@pytest.mark.parametrize("agg", ["sum", "mean", lambda df: df.mean()])
def test_groupby_transform_aggregation(by, agg):
    gdf = cudf.DataFrame(
        {"a": [2, 2, 1, 2, 1], "b": [1, 1, 1, 2, 2], "c": [1, 2, 3, 4, 5]}
    )
    pdf = gdf.to_pandas()

    expected = pdf.groupby(by).transform(agg)
    actual = gdf.groupby(by).transform(agg)

    assert_groupby_results_equal(expected, actual)
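For intuition about what the test compares, the following pure-Python sketch works out the `"sum"` transform of column `c` from the test's data for three of the `by` variants. It is only an illustration: `transform_sum` is a hypothetical helper, composite keys are modelled as tuples, and neither cudf nor pandas is involved.

```python
from collections import defaultdict


def transform_sum(keys, values):
    """Per-group sum repeated for every row; keys may be scalars or tuples."""
    totals = defaultdict(int)
    for key, value in zip(keys, values):
        totals[key] += value
    return [totals[key] for key in keys]


a = [2, 2, 1, 2, 1]
b = [1, 1, 1, 2, 2]
c = [1, 2, 3, 4, 5]

by_a = transform_sum(a, c)                       # models by="a"
by_a_b = transform_sum(list(zip(a, b)), c)       # models by=["a", "b"]
by_external = transform_sum([2, 1, 1, 2, 2], c)  # models by=pd.Series([2, 1, 1, 2, 2])

print(by_a)         # [7, 7, 8, 7, 8]
print(by_a_b)       # [3, 3, 3, 4, 5]
print(by_external)  # [10, 5, 5, 10, 10]
```

In every variant the output has five entries, one per input row, regardless of how many distinct groups the keys define.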


def test_groupby_select_then_ffill():
pdf = pd.DataFrame(
{
