From ce49adee3de481f331831db5a7e0452fe3889800 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Wed, 30 Oct 2024 16:27:11 -0700 Subject: [PATCH 1/2] Fix groupby.get_group with length-1 tuple with list-like grouper --- python/cudf/cudf/core/groupby/groupby.py | 5 +++++ python/cudf/cudf/tests/test_groupby.py | 13 +++++++++++++ 2 files changed, 18 insertions(+) diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py index 6630bd96c01..e59b948aba9 100644 --- a/python/cudf/cudf/core/groupby/groupby.py +++ b/python/cudf/cudf/core/groupby/groupby.py @@ -481,6 +481,11 @@ def get_group(self, name, obj=None): "instead of ``gb.get_group(name, obj=df)``.", FutureWarning, ) + if is_list_like(self._by): + if isinstance(name, tuple) and len(name) == 1: + name = name[0] + else: + raise KeyError(name) return obj.iloc[self.indices[name]] @_performance_tracking diff --git a/python/cudf/cudf/tests/test_groupby.py b/python/cudf/cudf/tests/test_groupby.py index 6b222841622..9767642fe49 100644 --- a/python/cudf/cudf/tests/test_groupby.py +++ b/python/cudf/cudf/tests/test_groupby.py @@ -4059,3 +4059,16 @@ def test_ndim(): pgb = pser.groupby([0, 0, 1]) ggb = gser.groupby(cudf.Series([0, 0, 1])) assert pgb.ndim == ggb.ndim + + +def test_get_group_list_like(): + df = cudf.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + result = df.groupby(["a"]).get_group((1,)) + expected = df.to_pandas().groupby(["a"]).get_group((1,)) + assert_eq(result, expected) + + with pytest.raises(KeyError): + df.groupby(["a"]).get_group((1, 2)) + + with pytest.raises(KeyError): + df.groupby(["a"]).get_group([1]) From af4d6776f5179028b62e89610ff1877b983ae3e2 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Thu, 31 Oct 2024 09:47:35 -0700 Subject: [PATCH 2/2] Add skipif for pandas<2.2 --- python/cudf/cudf/tests/test_groupby.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/cudf/cudf/tests/test_groupby.py b/python/cudf/cudf/tests/test_groupby.py index 9767642fe49..e4422e204bc 100644 --- a/python/cudf/cudf/tests/test_groupby.py +++ b/python/cudf/cudf/tests/test_groupby.py @@ -4061,6 +4061,9 @@ def test_ndim(): assert pgb.ndim == ggb.ndim +@pytest.mark.skipif( + not PANDAS_GE_220, reason="pandas behavior applicable in >=2.2" +) def test_get_group_list_like(): df = cudf.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) result = df.groupby(["a"]).get_group((1,))