Skip to content

Commit

Permalink
Reconstruct dtypes correctly for list aggs of struct columns (#12290)
Browse files Browse the repository at this point in the history
As usual when returning from libcudf, we need to reconstruct a struct
dtype with appropriate labels. For groupby.agg(list) this can be done
by matching on the element_type of the result column and
reconstructing with a new list dtype with a leaf from the original
column.

Closes #11765
Closes #11907

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - Ashwin Srinath (https://github.com/shwina)

URL: #12290
  • Loading branch information
wence- authored Jan 23, 2023
1 parent 2bfe9e3 commit 24efb9c
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 20 deletions.
36 changes: 27 additions & 9 deletions python/cudf/cudf/_lib/groupby.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
# Copyright (c) 2020-2023, NVIDIA CORPORATION.

from pandas.core.groupby.groupby import DataError

Expand Down Expand Up @@ -41,15 +41,33 @@ from cudf._lib.cpp.types cimport size_type

# The sets below define the possible aggregations that can be performed on
# different dtypes. These strings must be elements of the AggregationKind enum.
_CATEGORICAL_AGGS = {"COUNT", "SIZE", "NUNIQUE", "UNIQUE"}
_STRING_AGGS = {"COUNT", "SIZE", "MAX", "MIN", "NUNIQUE", "NTH", "COLLECT",
"UNIQUE"}
# The libcudf infrastructure exists for "COLLECT" support on
# categoricals, but the dtype support in python does not.
_CATEGORICAL_AGGS = {"COUNT", "NUNIQUE", "SIZE", "UNIQUE"}
_STRING_AGGS = {
"COLLECT",
"COUNT",
"MAX",
"MIN",
"NTH",
"NUNIQUE",
"SIZE",
"UNIQUE",
}
_LIST_AGGS = {"COLLECT"}
_STRUCT_AGGS = {"CORRELATION", "COVARIANCE"}
_INTERVAL_AGGS = set()
_DECIMAL_AGGS = {"COUNT", "SUM", "ARGMIN", "ARGMAX", "MIN", "MAX", "NUNIQUE",
"NTH", "COLLECT"}

_STRUCT_AGGS = {"COLLECT", "CORRELATION", "COVARIANCE"}
_INTERVAL_AGGS = {"COLLECT"}
_DECIMAL_AGGS = {
"ARGMIN",
"ARGMAX",
"COLLECT",
"COUNT",
"MAX",
"MIN",
"NTH",
"NUNIQUE",
"SUM",
}
# workaround for https://github.com/cython/cython/issues/3885
ctypedef const scalar constscalar

Expand Down
6 changes: 3 additions & 3 deletions python/cudf/cudf/core/dtypes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
# Copyright (c) 2020-2023, NVIDIA CORPORATION.

import decimal
import operator
Expand Down Expand Up @@ -344,7 +344,7 @@ def __init__(self, element_type: Any) -> None:
)
self._typ = pa.list_(element_type)

@property
@cached_property
def element_type(self) -> Dtype:
"""
Returns the element type of the ``ListDtype``.
Expand Down Expand Up @@ -373,7 +373,7 @@ def element_type(self) -> Dtype:
else:
return cudf.dtype(self._typ.value_type.to_pandas_dtype()).name

@property
@cached_property
def leaf_type(self):
"""
Returns the type of the leaf values.
Expand Down
16 changes: 13 additions & 3 deletions python/cudf/cudf/core/groupby/groupby.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
# Copyright (c) 2020-2023, NVIDIA CORPORATION.

import itertools
import pickle
Expand Down Expand Up @@ -457,6 +457,7 @@ def agg(self, func):
2 3.0 3.00 1.0 1.0
"""
column_names, columns, normalized_aggs = self._normalize_aggs(func)
orig_dtypes = tuple(c.dtype for c in columns)

# Note: When there are no key columns, the below produces
# a Float64Index, while Pandas returns an Int64Index
Expand All @@ -473,15 +474,24 @@ def agg(self, func):

multilevel = _is_multi_agg(func)
data = {}
for col_name, aggs, cols in zip(
column_names, included_aggregations, result_columns
for col_name, aggs, cols, orig_dtype in zip(
column_names,
included_aggregations,
result_columns,
orig_dtypes,
):
for agg, col in zip(aggs, cols):
if multilevel:
agg_name = agg.__name__ if callable(agg) else agg
key = (col_name, agg_name)
else:
key = col_name
if (
agg in {list, "collect"}
and orig_dtype != col.dtype.element_type
):
# Structs lose their labels which we reconstruct here
col = col._with_type_metadata(cudf.ListDtype(orig_dtype))
data[key] = col
data = ColumnAccessor(data, multiindex=multilevel)
if not multilevel:
Expand Down
12 changes: 7 additions & 5 deletions python/cudf/cudf/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1595,11 +1595,13 @@ def test_groupby_list_of_structs(list_agg):
}
)
gdf = cudf.from_pandas(pdf)

with pytest.raises(
pd.errors.DataError if PANDAS_GE_150 else pd.core.base.DataError
):
gdf.groupby("a").agg({"b": list_agg})
grouped = gdf.groupby("a").agg({"b": list_agg})
assert_groupby_results_equal(
pdf.groupby("a").agg({"b": list}),
grouped,
check_dtype=True,
)
assert grouped["b"].dtype.element_type == gdf["b"].dtype


@pytest.mark.parametrize("list_agg", [list, "collect"])
Expand Down

0 comments on commit 24efb9c

Please sign in to comment.