Add groupby.transform (only support for aggregations) #10005

Merged (13 commits) on Jan 20, 2022
1 change: 1 addition & 0 deletions docs/cudf/source/api_docs/groupby.rst
@@ -34,6 +34,7 @@ Function application
   SeriesGroupBy.aggregate
   DataFrameGroupBy.aggregate
   GroupBy.pipe
   GroupBy.transform

Computations / descriptive stats
--------------------------------
21 changes: 21 additions & 0 deletions docs/cudf/source/basics/groupby.rst
@@ -220,6 +220,27 @@ Limitations
.. |describe| replace:: ``describe``
.. _describe: https://pandas.pydata.org/pandas-docs/stable/user_guide/groupby.html#flexible-apply


Transform
---------

The ``.transform()`` method aggregates each group and broadcasts the
result to the group size, producing a Series/DataFrame with the same
number of rows as the input Series/DataFrame.

.. code:: python

    >>> import cudf
    >>> df = cudf.DataFrame({'a': [2, 1, 1, 2, 2], 'b': [1, 2, 3, 4, 5]})
    >>> df.groupby('a').transform('max')
       b
    0  5
    1  3
    2  3
    3  5
    4  5
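
The aggregate-then-broadcast behaviour described above can be sketched in plain Python, without cudf or a GPU. This is only an illustration of the idea, not how cudf implements it; the helper name `broadcast_aggregate` is hypothetical, and the sample data is the same as in the example above.

```python
def broadcast_aggregate(keys, values, agg):
    """Aggregate `values` per key, then repeat each result for every row of its group.

    Hypothetical helper for illustration only; it is not part of cudf.
    """
    # Step 1: collect the values that belong to each distinct key.
    groups = {}
    for key, value in zip(keys, values):
        groups.setdefault(key, []).append(value)

    # Step 2: aggregate, giving one result per group.
    per_group = {key: agg(group_values) for key, group_values in groups.items()}

    # Step 3: broadcast, looking up each row's group result in the original row order.
    return [per_group[key] for key in keys]


a = [2, 1, 1, 2, 2]
b = [1, 2, 3, 4, 5]

max_per_row = broadcast_aggregate(a, b, max)
sum_per_row = broadcast_aggregate(a, b, sum)
print(max_per_row)  # [5, 3, 3, 5, 5]
print(sum_per_row)  # [10, 5, 5, 10, 10]
```

The output has one entry per input row, which is what distinguishes a transform from a plain aggregation that returns one entry per group.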


Rolling window calculations
---------------------------

45 changes: 45 additions & 0 deletions python/cudf/cudf/core/groupby/groupby.py
@@ -655,6 +655,51 @@ def rolling_avg(val, avg):
        kwargs.update({"chunks": offsets})
        return grouped_values.apply_chunks(function, **kwargs)

    def transform(self, function):
        """Apply an aggregation, then broadcast the result to the group size.

        Parameters
        ----------
        function : str or callable
            Aggregation to apply to each group. Note that currently,
            only aggregations are supported by `transform`.

        Returns
        -------
        Series or DataFrame
            An object with the same number of rows as the input, with
            the per-group aggregation result broadcast to every row of
            its group.

        Examples
        --------
        .. code-block:: python

            import cudf
            df = cudf.DataFrame({'a': [2, 1, 1, 2, 2], 'b': [1, 2, 3, 4, 5]})
            df.groupby('a').transform('max')
               b
            0  5
            1  3
            2  3
            3  5
            4  5

        """
        try:
            # first, try aggregating:
            result = self.agg(function)
        except TypeError as e:
            raise NotImplementedError(
                "Currently, `transform()` supports only aggregations."
            ) from e

        if not result.index.equals(self.grouping.keys):
            result = result._align_to_index(
                self.grouping.keys, how="right", allow_non_unique=True
            )
            result = result.reset_index(drop=True)
        return result

    def rolling(self, *args, **kwargs):
        """
        Returns a `RollingGroupby` object that enables rolling window
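
The `try`/`except` in `transform` above re-raises a `TypeError` as `NotImplementedError` using `raise ... from e`, which keeps the original error attached as `__cause__`. The sketch below shows only that chaining pattern in plain Python. The names `fake_agg` and `transform_like` are hypothetical stand-ins, and the assumption that an unsupported function surfaces as `TypeError` is taken from the diff, not from any knowledge of how `self.agg` decides to raise.

```python
# Hypothetical table of supported aggregations; not cudf's actual list.
_SUPPORTED = {"sum": sum, "max": max, "min": min}


def fake_agg(values, function):
    """Hypothetical stand-in for `self.agg`; not cudf's implementation."""
    if function not in _SUPPORTED:
        # Assumption: anything that is not an aggregation fails with TypeError.
        raise TypeError(f"unsupported aggregation: {function!r}")
    return _SUPPORTED[function](values)


def transform_like(values, function):
    """Aggregate, then repeat the result once per input row (single group)."""
    try:
        result = fake_agg(values, function)
    except TypeError as e:
        # `from e` records the TypeError as the cause of the new exception.
        raise NotImplementedError(
            "Currently, `transform()` supports only aggregations."
        ) from e
    return [result] * len(values)


caught = None
try:
    transform_like([1, 2, 3], "cumsum")
except NotImplementedError as err:
    caught = err

print(transform_like([1, 2, 3], "max"))  # [3, 3, 3]
print(type(caught.__cause__).__name__)   # TypeError
```

Chaining this way gives callers a clear "not supported yet" message while the traceback still shows the underlying failure.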
22 changes: 22 additions & 0 deletions python/cudf/cudf/tests/test_groupby.py
@@ -2362,4 +2362,26 @@ def test_groupby_get_group(pdf, group, name, obj):
    assert_groupby_results_equal(expected, actual)


@pytest.mark.parametrize(
    "by",
    [
        "a",
        ["a", "b"],
        pd.Series([2, 1, 1, 2, 2]),
        pd.Series(["b", "a", "a", "b", "b"]),
    ],
)
@pytest.mark.parametrize("agg", ["sum", "mean", lambda df: df.mean()])
def test_groupby_transform_aggregation(by, agg):
    gdf = cudf.DataFrame(
        {"a": [2, 2, 1, 2, 1], "b": [1, 1, 1, 2, 2], "c": [1, 2, 3, 4, 5]}
    )
    pdf = gdf.to_pandas()

    expected = pdf.groupby(by).transform(agg)
    actual = gdf.groupby(by).transform(agg)

    assert_groupby_results_equal(expected, actual)


# TODO: Add a test including datetime64[ms] column in input data