From 4948aa25557f07a65ce8d8d4afd8ded66576f3a8 Mon Sep 17 00:00:00 2001
From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com>
Date: Wed, 21 Feb 2024 11:54:50 -1000
Subject: [PATCH] Fix Series.groupby.shift with a MultiIndex (#15098)

closes #15087
closes #11259 (The typing annotation is incorrect, but I guess there needs to be a check somewhere to make `_copy_type_metadata` stricter)

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - Bradley Dice (https://github.com/bdice)

URL: https://github.com/rapidsai/cudf/pull/15098
---
 python/cudf/cudf/core/multiindex.py    |  3 ++-
 python/cudf/cudf/tests/test_groupby.py | 11 ++++++++++-
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/python/cudf/cudf/core/multiindex.py b/python/cudf/cudf/core/multiindex.py
index 9466d172eb1..df1b1ea10cd 100644
--- a/python/cudf/cudf/core/multiindex.py
+++ b/python/cudf/cudf/core/multiindex.py
@@ -2037,7 +2037,8 @@ def _copy_type_metadata(
         self: MultiIndex, other: MultiIndex, *, override_dtypes=None
     ) -> MultiIndex:
         res = super()._copy_type_metadata(other)
-        res._names = other._names
+        if isinstance(other, MultiIndex):
+            res._names = other._names
         return res
 
     @_cudf_nvtx_annotate
diff --git a/python/cudf/cudf/tests/test_groupby.py b/python/cudf/cudf/tests/test_groupby.py
index e8dbdd35352..c22e47bdf06 100644
--- a/python/cudf/cudf/tests/test_groupby.py
+++ b/python/cudf/cudf/tests/test_groupby.py
@@ -3308,7 +3308,6 @@ def test_groupby_pct_change(data, gkey, periods, fill_method):
     assert_eq(expected, actual)
 
 
-@pytest.mark.xfail(reason="https://github.com/rapidsai/cudf/issues/11259")
 @pytest.mark.parametrize("periods", [-5, 5])
 def test_groupby_pct_change_multiindex_dataframe(periods):
     gdf = cudf.DataFrame(
@@ -3812,3 +3811,13 @@ def test_groupby_internal_groups_empty(gdf):
     gb = gdf.groupby("y")._groupby
     _, _, grouped_vals = gb.groups([])
     assert grouped_vals == []
+
+
+def test_groupby_shift_series_multiindex():
+    idx = cudf.MultiIndex.from_tuples(
+        [("a", 1), ("a", 2), ("b", 1), ("b", 2)], names=["f", "s"]
+    )
+    ser = Series(range(4), index=idx)
+    result = ser.groupby(level=0).shift(1)
+    expected = ser.to_pandas().groupby(level=0).shift(1)
+    assert_eq(expected, result)