Skip to content

Commit

Permalink
Use correct index when returning Series from GroupBy.apply() (#9016)
Browse files Browse the repository at this point in the history
Closes #8898 

Originally, when returning a Series from a `GroupBy.apply()` operation, we would pass in `self.grouping.keys[offsets[:-1]]` as the index, which was meant to grab each unique group key, assuming that `self.grouping.keys` is sorted. However, because it is not sorted, this just ends up grabbing 5 group keys at random.

Since we are already calling `GroupBy._grouped()` in this operation, we can use the `group_names` returned by that as the index instead, which is what the result of `self.grouping.keys[offsets[:-1]]` would be if `self.grouping.keys` was sorted.

Authors:
  - Charles Blackmon-Luca (https://github.com/charlesbluca)

Approvers:
  - Michael Wang (https://github.com/isVoid)

URL: #9016
  • Loading branch information
charlesbluca authored Aug 13, 2021
1 parent fb29071 commit 233943d
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
6 changes: 2 additions & 4 deletions python/cudf/cudf/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def mult(df):
"""
if not callable(function):
raise TypeError(f"type {type(function)} is not callable")
_, offsets, _, grouped_values = self._grouped()
group_names, offsets, _, grouped_values = self._grouped()

ngroups = len(offsets) - 1
if ngroups > self._MAX_GROUPS_BEFORE_WARN:
Expand All @@ -467,9 +467,7 @@ def mult(df):
return self.obj.__class__()

if cudf.utils.dtypes.is_scalar(chunk_results[0]):
result = cudf.Series(
chunk_results, index=self.grouping.keys[offsets[:-1]]
)
result = cudf.Series(chunk_results, index=group_names)
result.index.names = self.grouping.names
elif isinstance(chunk_results[0], cudf.Series):
result = cudf.concat(chunk_results, axis=1).T
Expand Down
8 changes: 4 additions & 4 deletions python/cudf/cudf/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2106,11 +2106,11 @@ def test_groupby_first(data, agg):
assert_groupby_results_equal(expect, got, check_dtype=False)


def test_groupby_apply_series_name():
def test_groupby_apply_series():
def foo(x):
return x.sum()

got = make_frame(DataFrame, 3).groupby("x").y.apply(foo)
expect = make_frame(pd.DataFrame, 3).groupby("x").y.apply(foo)
got = make_frame(DataFrame, 100).groupby("x").y.apply(foo)
expect = make_frame(pd.DataFrame, 100).groupby("x").y.apply(foo)

assert expect.name == got.name
assert_groupby_results_equal(expect, got)

0 comments on commit 233943d

Please sign in to comment.