diff --git a/python/cudf/cudf/_lib/pylibcudf/groupby.pxd b/python/cudf/cudf/_lib/pylibcudf/groupby.pxd index d06959b3c31..f1b7a25d5f9 100644 --- a/python/cudf/cudf/_lib/pylibcudf/groupby.pxd +++ b/python/cudf/cudf/_lib/pylibcudf/groupby.pxd @@ -37,6 +37,7 @@ cdef class GroupByRequest: cdef class GroupBy: cdef unique_ptr[groupby] c_obj + cdef Table _keys cpdef tuple aggregate(self, list requests) cpdef tuple scan(self, list requests) cpdef tuple shift(self, Table values, list offset, list fill_values) diff --git a/python/cudf/cudf/_lib/pylibcudf/groupby.pyx b/python/cudf/cudf/_lib/pylibcudf/groupby.pyx index a3d5997bad5..7dfbe97741c 100644 --- a/python/cudf/cudf/_lib/pylibcudf/groupby.pyx +++ b/python/cudf/cudf/_lib/pylibcudf/groupby.pyx @@ -98,6 +98,9 @@ cdef class GroupBy: sorted keys_are_sorted=sorted.NO ): self.c_obj.reset(new groupby(keys.view(), null_handling, keys_are_sorted)) + # keep a reference to the keys table so it doesn't get + # deallocated from under us: + self._keys = keys @staticmethod cdef tuple _parse_outputs( @@ -254,14 +257,14 @@ cdef class GroupBy: ---------- values : Table, optional The columns to get group labels for. If not specified, - an empty table is returned for the group values. + `None` is returned for the group values. Returns ------- Tuple[Table, Table, List[int]] A tuple of tables containing three items: - A table of group keys - - A table of group values + - A table of group values or None - A list of integer offsets into the tables """ @@ -278,6 +281,6 @@ cdef class GroupBy: c_groups = dereference(self.c_obj).get_groups() return ( Table.from_libcudf(move(c_groups.keys)), - Table([]), + None, c_groups.offsets, ) diff --git a/python/cudf/cudf/tests/test_groupby.py b/python/cudf/cudf/tests/test_groupby.py index ca1e5f74d75..625be44e4dc 100644 --- a/python/cudf/cudf/tests/test_groupby.py +++ b/python/cudf/cudf/tests/test_groupby.py @@ -3762,4 +3762,5 @@ def test_groupby_internal_groups_empty(gdf): # test that we don't segfault when calling the internal # .groups() method with an empty list: gb = gdf.groupby("y")._groupby - gb.groups([]) + _, grouped_vals, _ = gb.groups([]) + assert grouped_vals is None