Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reconstruct dtypes correctly for list aggs of struct columns #12290

Merged
merged 8 commits into from
Jan 23, 2023
36 changes: 27 additions & 9 deletions python/cudf/cudf/_lib/groupby.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
# Copyright (c) 2020-2023, NVIDIA CORPORATION.

from pandas.core.groupby.groupby import DataError

Expand Down Expand Up @@ -41,15 +41,33 @@ from cudf._lib.cpp.types cimport size_type

# The sets below define the possible aggregations that can be performed on
# different dtypes. These strings must be elements of the AggregationKind enum.
_CATEGORICAL_AGGS = {"COUNT", "SIZE", "NUNIQUE", "UNIQUE"}
_STRING_AGGS = {"COUNT", "SIZE", "MAX", "MIN", "NUNIQUE", "NTH", "COLLECT",
"UNIQUE"}
# The libcudf infrastructure exists for "COLLECT" support on
# categoricals, but the dtype support in python does not.
_CATEGORICAL_AGGS = {"COUNT", "NUNIQUE", "SIZE", "UNIQUE"}
_STRING_AGGS = {
"COLLECT",
"COUNT",
"MAX",
"MIN",
"NTH",
"NUNIQUE",
"SIZE",
"UNIQUE",
}
_LIST_AGGS = {"COLLECT"}
_STRUCT_AGGS = {"CORRELATION", "COVARIANCE"}
_INTERVAL_AGGS = set()
_DECIMAL_AGGS = {"COUNT", "SUM", "ARGMIN", "ARGMAX", "MIN", "MAX", "NUNIQUE",
"NTH", "COLLECT"}

_STRUCT_AGGS = {"COLLECT", "CORRELATION", "COVARIANCE"}
_INTERVAL_AGGS = {"COLLECT"}
_DECIMAL_AGGS = {
"ARGMIN",
"ARGMAX",
"COLLECT",
"COUNT",
"MAX",
"MIN",
"NTH",
"NUNIQUE",
"SUM",
}
# workaround for https://github.com/cython/cython/issues/3885
ctypedef const scalar constscalar

Expand Down
6 changes: 3 additions & 3 deletions python/cudf/cudf/core/dtypes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
# Copyright (c) 2020-2023, NVIDIA CORPORATION.

import decimal
import operator
Expand Down Expand Up @@ -344,7 +344,7 @@ def __init__(self, element_type: Any) -> None:
)
self._typ = pa.list_(element_type)

@property
@cached_property
def element_type(self) -> Dtype:
"""
Returns the element type of the ``ListDtype``.
Expand Down Expand Up @@ -373,7 +373,7 @@ def element_type(self) -> Dtype:
else:
return cudf.dtype(self._typ.value_type.to_pandas_dtype()).name

@property
@cached_property
def leaf_type(self):
"""
Returns the type of the leaf values.
Expand Down
16 changes: 13 additions & 3 deletions python/cudf/cudf/core/groupby/groupby.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
# Copyright (c) 2020-2023, NVIDIA CORPORATION.

import itertools
import pickle
Expand Down Expand Up @@ -457,6 +457,7 @@ def agg(self, func):
2 3.0 3.00 1.0 1.0
"""
column_names, columns, normalized_aggs = self._normalize_aggs(func)
orig_dtypes = tuple(c.dtype for c in columns)

# Note: When there are no key columns, the below produces
# a Float64Index, while Pandas returns an Int64Index
Expand All @@ -473,15 +474,24 @@ def agg(self, func):

multilevel = _is_multi_agg(func)
data = {}
for col_name, aggs, cols in zip(
column_names, included_aggregations, result_columns
for col_name, aggs, cols, orig_dtype in zip(
column_names,
included_aggregations,
result_columns,
orig_dtypes,
):
for agg, col in zip(aggs, cols):
if multilevel:
agg_name = agg.__name__ if callable(agg) else agg
key = (col_name, agg_name)
else:
key = col_name
if (
agg in {list, "collect"}
and orig_dtype != col.dtype.element_type
):
# Structs lose their labels which we reconstruct here
col = col._with_type_metadata(cudf.ListDtype(orig_dtype))
data[key] = col
data = ColumnAccessor(data, multiindex=multilevel)
if not multilevel:
Expand Down
12 changes: 7 additions & 5 deletions python/cudf/cudf/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1595,11 +1595,13 @@ def test_groupby_list_of_structs(list_agg):
}
)
gdf = cudf.from_pandas(pdf)

with pytest.raises(
pd.errors.DataError if PANDAS_GE_150 else pd.core.base.DataError
):
gdf.groupby("a").agg({"b": list_agg})
grouped = gdf.groupby("a").agg({"b": list_agg})
assert_groupby_results_equal(
pdf.groupby("a").agg({"b": list}),
grouped,
check_dtype=True,
)
assert grouped["b"].dtype.element_type == gdf["b"].dtype


@pytest.mark.parametrize("list_agg", [list, "collect"])
Expand Down