Skip to content
/ cudf Public
forked from rapidsai/cudf

Commit

Permalink
groupby: Transfer struct dtype into collected aggregate
Browse files Browse the repository at this point in the history
As usual when returning from libcudf, we need to reconstruct a struct
dtype with appropriate labels. For groupby.agg(list) this can be done
by matching on the element_type of the result column and
reconstructing with a new list dtype with a leaf from the original
column.

Closes rapidsai#11765.
  • Loading branch information
wence- committed Jan 17, 2023
1 parent edc46d3 commit 44cc3be
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
18 changes: 15 additions & 3 deletions python/cudf/cudf/core/groupby/groupby.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
# Copyright (c) 2020-2023, NVIDIA CORPORATION.

import itertools
import pickle
Expand All @@ -21,6 +21,7 @@
from cudf.core.abc import Serializable
from cudf.core.column.column import ColumnBase, arange, as_column
from cudf.core.column_accessor import ColumnAccessor
from cudf.core.dtypes import is_categorical_dtype
from cudf.core.mixins import Reducible, Scannable
from cudf.core.multiindex import MultiIndex
from cudf.utils.utils import GetAttrGetItemMixin, _cudf_nvtx_annotate
Expand Down Expand Up @@ -457,6 +458,7 @@ def agg(self, func):
2 3.0 3.00 1.0 1.0
"""
column_names, columns, normalized_aggs = self._normalize_aggs(func)
orig_dtypes = tuple(c.dtype for c in columns)

# Note: When there are no key columns, the below produces
# a Float64Index, while Pandas returns an Int64Index
Expand All @@ -473,15 +475,25 @@ def agg(self, func):

multilevel = _is_multi_agg(func)
data = {}
for col_name, aggs, cols in zip(
column_names, included_aggregations, result_columns
for col_name, aggs, cols, orig_dtype in zip(
column_names,
included_aggregations,
result_columns,
orig_dtypes,
):
for agg, col in zip(aggs, cols):
if multilevel:
agg_name = agg.__name__ if callable(agg) else agg
key = (col_name, agg_name)
else:
key = col_name
if (
agg in {list, "collect"}
and not is_categorical_dtype(orig_dtype)
and orig_dtype != col.dtype.element_type
):
# Structs lose their labels which we reconstruct here
col = col._with_type_metadata(cudf.ListDtype(orig_dtype))
data[key] = col
data = ColumnAccessor(data, multiindex=multilevel)
if not multilevel:
Expand Down
9 changes: 5 additions & 4 deletions python/cudf/cudf/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1596,10 +1596,11 @@ def test_groupby_list_of_structs(list_agg):
)
gdf = cudf.from_pandas(pdf)

with pytest.raises(
pd.errors.DataError if PANDAS_GE_150 else pd.core.base.DataError
):
gdf.groupby("a").agg({"b": list_agg})
assert_groupby_results_equal(
pdf.groupby("a").agg({"b": list_agg}),
gdf.groupby("a").agg({"b": list_agg}),
check_dtype=True,
)


@pytest.mark.parametrize("list_agg", [list, "collect"])
Expand Down

0 comments on commit 44cc3be

Please sign in to comment.