Add groupby.transform (only support for aggregations) #10005

Merged
merged 13 commits into from
Jan 20, 2022
1 change: 1 addition & 0 deletions docs/cudf/source/api_docs/groupby.rst
@@ -34,6 +34,7 @@ Function application
SeriesGroupBy.aggregate
DataFrameGroupBy.aggregate
GroupBy.pipe
GroupBy.transform

Computations / descriptive stats
--------------------------------
23 changes: 23 additions & 0 deletions docs/cudf/source/basics/groupby.rst
@@ -1,3 +1,5 @@
.. _basics.groupby:

GroupBy
=======

@@ -220,6 +222,27 @@ Limitations
.. |describe| replace:: ``describe``
.. _describe: https://pandas.pydata.org/pandas-docs/stable/user_guide/groupby.html#flexible-apply


Transform
---------

The ``.transform()`` method computes an aggregation for each group and
broadcasts each group's result back to every row of that group, so the
resulting Series/DataFrame has the same number of rows as the input.

.. code:: python

    >>> import cudf
    >>> df = cudf.DataFrame({'a': [2, 1, 1, 2, 2], 'b': [1, 2, 3, 4, 5]})
    >>> df.groupby('a').transform('max')
       b
    0  5
    1  3
    2  3
    3  5
    4  5
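
The aggregate-then-broadcast behaviour described above can be illustrated
without a GPU. The following is a minimal pure-Python sketch, not cuDF's
implementation; ``group_transform`` is a hypothetical helper introduced only
for this illustration, and it reuses the ``a``/``b`` data from the example
above.

```python
from collections import defaultdict


def group_transform(keys, values, agg):
    """Aggregate `values` per key, then broadcast each result to its rows.

    Hypothetical helper for illustration only; not part of cuDF.
    """
    groups = defaultdict(list)
    for key, value in zip(keys, values):
        groups[key].append(value)
    # Step 1: reduce every group to a single value.
    reduced = {key: agg(vals) for key, vals in groups.items()}
    # Step 2: emit one output row per input row, in input order.
    return [reduced[key] for key in keys]


a = [2, 1, 1, 2, 2]
b = [1, 2, 3, 4, 5]
result = group_transform(a, b, max)
print(result)  # [5, 3, 3, 5, 5]
```

The output has one entry per input row, which is what distinguishes
``transform`` from ``agg`` (one entry per group).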


Rolling window calculations
---------------------------

64 changes: 63 additions & 1 deletion python/cudf/cudf/core/groupby/groupby.py
@@ -184,11 +184,25 @@ def agg(self, func):
        Parameters
        ----------
        func : str, callable, list or dict
            Argument specifying the aggregation(s) to perform on the
            groups. `func` can be any of the following:

              - string: the name of a supported aggregation
              - callable: a function that accepts a Series/DataFrame and
                performs a supported operation on it.
              - list: a list of strings/callables specifying the
                aggregations to perform on every column.
              - dict: a mapping of column names to string/callable
                specifying the aggregations to perform on those
                columns.

            See :ref:`the user guide <basics.groupby>` for supported
            aggregations.

        Returns
        -------
        A Series or DataFrame containing the combined results of the
-       aggregation.
+       aggregation(s).

        Examples
        --------
@@ -655,6 +669,54 @@ def rolling_avg(val, avg):
        kwargs.update({"chunks": offsets})
        return grouped_values.apply_chunks(function, **kwargs)

    def transform(self, function):
        """Apply an aggregation, then broadcast the result to the group size.

        Parameters
        ----------
        function : str or callable
            Aggregation to apply to each group. Note that the set of
            operations currently supported by `transform` is identical
            to that supported by the `agg` method.

        Returns
        -------
        A Series or DataFrame of the same size as the input, with the
        result of the aggregation per group broadcast to the group
        size.

        Examples
        --------
        .. code-block:: python

            import cudf
            df = cudf.DataFrame({'a': [2, 1, 1, 2, 2], 'b': [1, 2, 3, 4, 5]})
            df.groupby('a').transform('max')
               b
            0  5
            1  3
            2  3
            3  5
            4  5

        See also
        --------
        cudf.core.groupby.GroupBy.agg
        """
        # Reuse `agg`; a TypeError from it is surfaced as NotImplementedError.
        try:
            result = self.agg(function)
        except TypeError as e:
            raise NotImplementedError(
                "Currently, `transform()` supports only aggregations."
            ) from e

        # Broadcast: align the per-group result to the original grouping keys.
        if not result.index.equals(self.grouping.keys):
            result = result._align_to_index(
                self.grouping.keys, how="right", allow_non_unique=True
            )
            result = result.reset_index(drop=True)
        return result
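
The two steps in `transform` (aggregate, then align the per-group result to the original, non-unique grouping keys) can be sketched on the CPU with pandas as a stand-in. This is only an illustration under stated assumptions: `transform_via_agg` is a hypothetical helper, it handles a single key column only, and `DataFrame.reindex` is used here in place of cuDF's internal `_align_to_index`, whose exact behaviour is not shown on this page.

```python
import pandas as pd


def transform_via_agg(df, by, func):
    """Hypothetical helper: aggregate per group, then align to the keys.

    `by` is assumed to be a single column name.
    """
    # One row per unique key, indexed by that key.
    result = df.groupby(by).agg(func)
    # Reindexing by the original key column repeats each group's row once
    # per input row, in input order (pandas stand-in for the align step).
    result = result.reindex(index=df[by])
    # Drop the key index so the output lines up with the input rows.
    return result.reset_index(drop=True)


df = pd.DataFrame({"a": [2, 1, 1, 2, 2], "b": [1, 2, 3, 4, 5]})
out = transform_via_agg(df, "a", "max")
print(out)
```

For this input the sketch produces the same `b` values as `df.groupby("a").transform("max")` in pandas.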

    def rolling(self, *args, **kwargs):
        """
        Returns a `RollingGroupby` object that enables rolling window
22 changes: 22 additions & 0 deletions python/cudf/cudf/tests/test_groupby.py
@@ -2362,6 +2362,28 @@ def test_groupby_get_group(pdf, group, name, obj):
    assert_groupby_results_equal(expected, actual)


@pytest.mark.parametrize(
    "by",
    [
        "a",
        ["a", "b"],
        pd.Series([2, 1, 1, 2, 2]),
        pd.Series(["b", "a", "a", "b", "b"]),
    ],
)
@pytest.mark.parametrize("agg", ["sum", "mean", lambda df: df.mean()])
def test_groupby_transform_aggregation(by, agg):
    gdf = cudf.DataFrame(
        {"a": [2, 2, 1, 2, 1], "b": [1, 1, 1, 2, 2], "c": [1, 2, 3, 4, 5]}
    )
    pdf = gdf.to_pandas()

    expected = pdf.groupby(by).transform(agg)
    actual = gdf.groupby(by).transform(agg)

    assert_groupby_results_equal(expected, actual)


def test_groupby_select_then_ffill():
    pdf = pd.DataFrame(
        {