From dcdfa9609a33e5fbdfb8b8d04c3af03b2fa982bf Mon Sep 17 00:00:00 2001
From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com>
Date: Mon, 23 Dec 2024 08:59:26 +0000
Subject: [PATCH 1/2] fix: DataFrameGroupBy.get_group was raising with length>1 tuples

---
 python/cudf/cudf/core/groupby/groupby.py | 2 +-
 python/cudf/cudf/tests/test_groupby.py   | 7 +++++++
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py
index be3cc410174..a6af8e5dff4 100644
--- a/python/cudf/cudf/core/groupby/groupby.py
+++ b/python/cudf/cudf/core/groupby/groupby.py
@@ -641,7 +641,7 @@ def get_group(self, name, obj=None):
                 "instead of ``gb.get_group(name, obj=df)``.",
                 FutureWarning,
             )
-        if is_list_like(self._by):
+        if is_list_like(self._by) and len(self._by) == 1:
             if isinstance(name, tuple) and len(name) == 1:
                 name = name[0]
             else:
diff --git a/python/cudf/cudf/tests/test_groupby.py b/python/cudf/cudf/tests/test_groupby.py
index db4f3cd3c9f..5abb3ff085d 100644
--- a/python/cudf/cudf/tests/test_groupby.py
+++ b/python/cudf/cudf/tests/test_groupby.py
@@ -4076,6 +4076,13 @@ def test_get_group_list_like():
         df.groupby(["a"]).get_group([1])
 
 
+def test_get_group_list_like_len_2():
+    df = cudf.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], 'c': [3, 2, 1]})
+    result = df.groupby(["a", "b"]).get_group((1, 4))
+    expected = df.to_pandas().groupby(["a", "b"]).get_group((1, 4))
+    assert_eq(result, expected)
+
+
 def test_size_as_index_false():
     df = pd.DataFrame({"a": [1, 2, 1], "b": [1, 2, 3]}, columns=["a", "b"])
     expected = df.groupby("a", as_index=False).size()

From fdd6508e6d1eb0a59713ab08027da24bcf05e497 Mon Sep 17 00:00:00 2001
From: Marco Edward Gorelli
Date: Mon, 23 Dec 2024 16:29:17 +0000
Subject: [PATCH 2/2] Update python/cudf/cudf/tests/test_groupby.py

---
 python/cudf/cudf/tests/test_groupby.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/cudf/cudf/tests/test_groupby.py b/python/cudf/cudf/tests/test_groupby.py
index 5abb3ff085d..23950f044f8 100644
--- a/python/cudf/cudf/tests/test_groupby.py
+++ b/python/cudf/cudf/tests/test_groupby.py
@@ -4077,7 +4077,7 @@ def test_get_group_list_like():
 
 
 def test_get_group_list_like_len_2():
-    df = cudf.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], 'c': [3, 2, 1]})
+    df = cudf.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [3, 2, 1]})
     result = df.groupby(["a", "b"]).get_group((1, 4))
     expected = df.to_pandas().groupby(["a", "b"]).get_group((1, 4))
     assert_eq(result, expected)