diff --git a/python/cudf/cudf/_lib/aggregation.pxd b/python/cudf/cudf/_lib/aggregation.pxd index bb332c44237..972f95d5aab 100644 --- a/python/cudf/cudf/_lib/aggregation.pxd +++ b/python/cudf/cudf/_lib/aggregation.pxd @@ -4,7 +4,7 @@ from libcpp.memory cimport unique_ptr from cudf._lib.cpp.aggregation cimport aggregation -cdef unique_ptr[aggregation] make_aggregation(op, kwargs=*) except * - cdef class Aggregation: cdef unique_ptr[aggregation] c_obj + +cdef Aggregation make_aggregation(op, kwargs=*) diff --git a/python/cudf/cudf/_lib/aggregation.pyx b/python/cudf/cudf/_lib/aggregation.pyx index 7138bb49743..682d8cbf329 100644 --- a/python/cudf/cudf/_lib/aggregation.pyx +++ b/python/cudf/cudf/_lib/aggregation.pyx @@ -56,85 +56,55 @@ class AggregationKind(Enum): cdef class Aggregation: - def __init__(self, op, **kwargs): - self.c_obj = move(make_aggregation(op, kwargs)) - + """A Cython wrapper for aggregations. + + **This class should never be instantiated using a standard constructor, + only using one of its many factories.** These factories handle mapping + different cudf operations to their libcudf analogs, e.g. + `cudf.DataFrame.idxmin` -> `libcudf.argmin`. Additionally, they perform + any additional configuration needed to translate Python arguments into + their corresponding C++ types (for instance, C++ enumerations used for + flag arguments). The factory approach is necessary to support operations + like `df.agg(lambda x: x.sum())`; such functions are called with this + class as an argument to generate the desired aggregation. + """ @property def kind(self): - return AggregationKind(self.c_obj.get()[0].kind).name.lower() - - -cdef unique_ptr[aggregation] make_aggregation(op, kwargs={}) except *: - """ - Parameters - ---------- - op : str or callable - If callable, must meet one of the following requirements: - - * Is of the form lambda x: x.agg(*args, **kwargs), where - `agg` is the name of a supported aggregation. 
Used to - to specify aggregations that take arguments, e.g., - `lambda x: x.quantile(0.5)`. - * Is a user defined aggregation function that operates on - group values. In this case, the output dtype must be - specified in the `kwargs` dictionary. - - Returns - ------- - unique_ptr[aggregation] - """ - cdef Aggregation agg - if isinstance(op, str): - agg = getattr(_AggregationFactory, op)(**kwargs) - elif callable(op): - if op is list: - agg = _AggregationFactory.collect() - elif "dtype" in kwargs: - agg = _AggregationFactory.from_udf(op, **kwargs) - else: - agg = op(_AggregationFactory) - else: - raise TypeError("Unknown aggregation {}".format(op)) - return move(agg.c_obj) - -# The Cython pattern below enables us to create an Aggregation -# without ever calling its `__init__` method, which would otherwise -# result in a RecursionError. -cdef class _AggregationFactory: + return AggregationKind(self.c_obj.get()[0].kind).name @classmethod def sum(cls): - cdef Aggregation agg = Aggregation.__new__(Aggregation) + cdef Aggregation agg = cls() agg.c_obj = move(libcudf_aggregation.make_sum_aggregation()) return agg @classmethod def min(cls): - cdef Aggregation agg = Aggregation.__new__(Aggregation) + cdef Aggregation agg = cls() agg.c_obj = move(libcudf_aggregation.make_min_aggregation()) return agg @classmethod def max(cls): - cdef Aggregation agg = Aggregation.__new__(Aggregation) + cdef Aggregation agg = cls() agg.c_obj = move(libcudf_aggregation.make_max_aggregation()) return agg @classmethod def idxmin(cls): - cdef Aggregation agg = Aggregation.__new__(Aggregation) + cdef Aggregation agg = cls() agg.c_obj = move(libcudf_aggregation.make_argmin_aggregation()) return agg @classmethod def idxmax(cls): - cdef Aggregation agg = Aggregation.__new__(Aggregation) + cdef Aggregation agg = cls() agg.c_obj = move(libcudf_aggregation.make_argmax_aggregation()) return agg @classmethod def mean(cls): - cdef Aggregation agg = Aggregation.__new__(Aggregation) + cdef Aggregation agg = 
cls() agg.c_obj = move(libcudf_aggregation.make_mean_aggregation()) return agg @@ -146,7 +116,7 @@ cdef class _AggregationFactory: else: c_null_handling = libcudf_types.null_policy.INCLUDE - cdef Aggregation agg = Aggregation.__new__(Aggregation) + cdef Aggregation agg = cls() agg.c_obj = move(libcudf_aggregation.make_count_aggregation( c_null_handling )) @@ -154,7 +124,7 @@ cdef class _AggregationFactory: @classmethod def size(cls): - cdef Aggregation agg = Aggregation.__new__(Aggregation) + cdef Aggregation agg = cls() agg.c_obj = move(libcudf_aggregation.make_count_aggregation( ( NullHandling.INCLUDE @@ -164,13 +134,13 @@ cdef class _AggregationFactory: @classmethod def nunique(cls): - cdef Aggregation agg = Aggregation.__new__(Aggregation) + cdef Aggregation agg = cls() agg.c_obj = move(libcudf_aggregation.make_nunique_aggregation()) return agg @classmethod def nth(cls, libcudf_types.size_type size): - cdef Aggregation agg = Aggregation.__new__(Aggregation) + cdef Aggregation agg = cls() agg.c_obj = move( libcudf_aggregation.make_nth_element_aggregation(size) ) @@ -178,49 +148,49 @@ cdef class _AggregationFactory: @classmethod def any(cls): - cdef Aggregation agg = Aggregation.__new__(Aggregation) + cdef Aggregation agg = cls() agg.c_obj = move(libcudf_aggregation.make_any_aggregation()) return agg @classmethod def all(cls): - cdef Aggregation agg = Aggregation.__new__(Aggregation) + cdef Aggregation agg = cls() agg.c_obj = move(libcudf_aggregation.make_all_aggregation()) return agg @classmethod def product(cls): - cdef Aggregation agg = Aggregation.__new__(Aggregation) + cdef Aggregation agg = cls() agg.c_obj = move(libcudf_aggregation.make_product_aggregation()) return agg @classmethod def sum_of_squares(cls): - cdef Aggregation agg = Aggregation.__new__(Aggregation) + cdef Aggregation agg = cls() agg.c_obj = move(libcudf_aggregation.make_sum_of_squares_aggregation()) return agg @classmethod def var(cls, ddof=1): - cdef Aggregation agg = 
Aggregation.__new__(Aggregation) + cdef Aggregation agg = cls() agg.c_obj = move(libcudf_aggregation.make_variance_aggregation(ddof)) return agg @classmethod def std(cls, ddof=1): - cdef Aggregation agg = Aggregation.__new__(Aggregation) + cdef Aggregation agg = cls() agg.c_obj = move(libcudf_aggregation.make_std_aggregation(ddof)) return agg @classmethod def median(cls): - cdef Aggregation agg = Aggregation.__new__(Aggregation) + cdef Aggregation agg = cls() agg.c_obj = move(libcudf_aggregation.make_median_aggregation()) return agg @classmethod def quantile(cls, q=0.5, interpolation="linear"): - cdef Aggregation agg = Aggregation.__new__(Aggregation) + cdef Aggregation agg = cls() if not pd.api.types.is_list_like(q): q = [q] @@ -240,19 +210,19 @@ cdef class _AggregationFactory: @classmethod def collect(cls): - cdef Aggregation agg = Aggregation.__new__(Aggregation) + cdef Aggregation agg = cls() agg.c_obj = move(libcudf_aggregation.make_collect_list_aggregation()) return agg @classmethod def unique(cls): - cdef Aggregation agg = Aggregation.__new__(Aggregation) + cdef Aggregation agg = cls() agg.c_obj = move(libcudf_aggregation.make_collect_set_aggregation()) return agg @classmethod def from_udf(cls, op, *args, **kwargs): - cdef Aggregation agg = Aggregation.__new__(Aggregation) + cdef Aggregation agg = cls() cdef libcudf_types.type_id tid cdef libcudf_types.data_type out_dtype @@ -282,3 +252,42 @@ cdef class _AggregationFactory: libcudf_aggregation.udf_type.PTX, cpp_str, out_dtype )) return agg + + +cdef Aggregation make_aggregation(op, kwargs=None): + r""" + Parameters + ---------- + op : str or callable + If callable, must meet one of the following requirements: + + * Is of the form lambda x: x.agg(*args, **kwargs), where + `agg` is the name of a supported aggregation. Used to + specify aggregations that take arguments, e.g., + `lambda x: x.quantile(0.5)`. + * Is a user defined aggregation function that operates on + group values. 
In this case, the output dtype must be + specified in the `kwargs` dictionary. + kwargs : dict, optional + Any keyword arguments to be passed to the op. + + Returns + ------- + Aggregation + """ + if kwargs is None: + kwargs = {} + + cdef Aggregation agg + if isinstance(op, str): + agg = getattr(Aggregation, op)(**kwargs) + elif callable(op): + if op is list: + agg = Aggregation.collect() + elif "dtype" in kwargs: + agg = Aggregation.from_udf(op, **kwargs) + else: + agg = op(Aggregation) + else: + raise TypeError(f"Unknown aggregation {op}") + return agg diff --git a/python/cudf/cudf/_lib/groupby.pyx b/python/cudf/cudf/_lib/groupby.pyx index 4584841dd33..3c2b541f728 100644 --- a/python/cudf/cudf/_lib/groupby.pyx +++ b/python/cudf/cudf/_lib/groupby.pyx @@ -1,6 +1,15 @@ # Copyright (c) 2020, NVIDIA CORPORATION. from collections import defaultdict +from pandas.core.groupby.groupby import DataError +from cudf.utils.dtypes import ( + is_categorical_dtype, + is_string_dtype, + is_list_dtype, + is_interval_dtype, + is_struct_dtype, + is_decimal_dtype, +) import numpy as np import rmm @@ -13,56 +22,23 @@ from libcpp cimport bool from cudf._lib.column cimport Column from cudf._lib.table cimport Table -from cudf._lib.aggregation cimport make_aggregation, Aggregation +from cudf._lib.aggregation cimport Aggregation, make_aggregation from cudf._lib.cpp.table.table cimport table, table_view cimport cudf._lib.cpp.types as libcudf_types cimport cudf._lib.cpp.groupby as libcudf_groupby -cimport cudf._lib.cpp.aggregation as libcudf_aggregation # The sets below define the possible aggregations that can be performed on -# different dtypes. The uppercased versions of these strings correspond to -# elements of the AggregationKind enum. 
-_CATEGORICAL_AGGS = { - "count", - "size", - "nunique", - "unique", -} - -_STRING_AGGS = { - "count", - "size", - "max", - "min", - "nunique", - "nth", - "collect", - "unique", -} - -_LIST_AGGS = { - "collect", -} - -_STRUCT_AGGS = { -} - -_INTERVAL_AGGS = { -} - -_DECIMAL_AGGS = { - "count", - "sum", - "argmin", - "argmax", - "min", - "max", - "nunique", - "nth", - "collect" -} +# different dtypes. These strings must be elements of the AggregationKind enum. +_CATEGORICAL_AGGS = {"COUNT", "SIZE", "NUNIQUE", "UNIQUE"} +_STRING_AGGS = {"COUNT", "SIZE", "MAX", "MIN", "NUNIQUE", "NTH", "COLLECT", + "UNIQUE"} +_LIST_AGGS = {"COLLECT"} +_STRUCT_AGGS = set() +_INTERVAL_AGGS = set() +_DECIMAL_AGGS = {"COUNT", "SUM", "ARGMIN", "ARGMAX", "MIN", "MAX", "NUNIQUE", + "NTH", "COLLECT"} cdef class GroupBy: @@ -132,21 +108,51 @@ cdef class GroupBy: """ from cudf.core.column_accessor import ColumnAccessor cdef vector[libcudf_groupby.aggregation_request] c_agg_requests + cdef libcudf_groupby.aggregation_request c_agg_request cdef Column col + cdef Aggregation agg_obj - aggregations = _drop_unsupported_aggs(values, aggregations) + allow_empty = all(len(v) == 0 for v in aggregations.values()) + included_aggregations = defaultdict(list) for i, (col_name, aggs) in enumerate(aggregations.items()): col = values._data[col_name] - c_agg_requests.push_back( - move(libcudf_groupby.aggregation_request()) + dtype = col.dtype + + valid_aggregations = ( + _LIST_AGGS if is_list_dtype(dtype) + else _STRING_AGGS if is_string_dtype(dtype) + else _CATEGORICAL_AGGS if is_categorical_dtype(dtype) + else _STRUCT_AGGS if is_struct_dtype(dtype) + else _INTERVAL_AGGS if is_interval_dtype(dtype) + else _DECIMAL_AGGS if is_decimal_dtype(dtype) + else "ALL" ) - c_agg_requests[i].values = col.view() + if (valid_aggregations is _DECIMAL_AGGS + and rmm._cuda.gpu.runtimeGetVersion() < 11000): + raise RuntimeError( + "Decimal aggregations are only supported on CUDA >= 11 " + "due to an nvcc compiler bug." 
+ ) + + c_agg_request = move(libcudf_groupby.aggregation_request()) for agg in aggs: - c_agg_requests[i].aggregations.push_back( - move(make_aggregation(agg)) + agg_obj = make_aggregation(agg) + if (valid_aggregations == "ALL" + or agg_obj.kind in valid_aggregations): + included_aggregations[col_name].append(agg) + c_agg_request.aggregations.push_back( + move(agg_obj.c_obj) + ) + if not c_agg_request.aggregations.empty(): + c_agg_request.values = col.view() + c_agg_requests.push_back( + move(c_agg_request) ) + if c_agg_requests.empty() and not allow_empty: + raise DataError("All requested aggregations are unsupported.") + cdef pair[ unique_ptr[table], vector[libcudf_groupby.aggregation_result] @@ -176,81 +182,14 @@ cdef class GroupBy: ) result_data = ColumnAccessor(multiindex=True) - for i, col_name in enumerate(aggregations): - for j, agg_name in enumerate(aggregations[col_name]): + # Note: This loop relies on the included_aggregations dict being + # insertion ordered to map results to requested aggregations by index. + for i, col_name in enumerate(included_aggregations): + for j, agg_name in enumerate(included_aggregations[col_name]): if callable(agg_name): agg_name = agg_name.__name__ result_data[(col_name, agg_name)] = ( Column.from_unique_ptr(move(c_result.second[i].results[j])) ) - result = Table(data=result_data, index=grouped_keys) - return result - - -def _drop_unsupported_aggs(Table values, aggs): - """ - Drop any aggregations that are not supported. 
- """ - from pandas.core.groupby.groupby import DataError - - if all(len(v) == 0 for v in aggs.values()): - return aggs - - from cudf.utils.dtypes import ( - is_categorical_dtype, - is_string_dtype, - is_list_dtype, - is_interval_dtype, - is_struct_dtype, - is_decimal_dtype, - ) - result = aggs.copy() - - for col_name in aggs: - if ( - is_list_dtype(values._data[col_name].dtype) - ): - for i, agg_name in enumerate(aggs[col_name]): - if Aggregation(agg_name).kind not in _LIST_AGGS: - del result[col_name][i] - elif ( - is_string_dtype(values._data[col_name].dtype) - ): - for i, agg_name in enumerate(aggs[col_name]): - if Aggregation(agg_name).kind not in _STRING_AGGS: - del result[col_name][i] - elif ( - is_categorical_dtype(values._data[col_name].dtype) - ): - for i, agg_name in enumerate(aggs[col_name]): - if Aggregation(agg_name).kind not in _CATEGORICAL_AGGS: - del result[col_name][i] - elif ( - is_struct_dtype(values._data[col_name].dtype) - ): - for i, agg_name in enumerate(aggs[col_name]): - if Aggregation(agg_name).kind not in _STRUCT_AGGS: - del result[col_name][i] - elif ( - is_interval_dtype(values._data[col_name].dtype) - ): - for i, agg_name in enumerate(aggs[col_name]): - if Aggregation(agg_name).kind not in _INTERVAL_AGGS: - del result[col_name][i] - elif ( - is_decimal_dtype(values._data[col_name].dtype) - ): - if rmm._cuda.gpu.runtimeGetVersion() < 11000: - raise RuntimeError( - "Decimal aggregations are only supported on CUDA >= 11 " - "due to an nvcc compiler bug." 
- ) - for i, agg_name in enumerate(aggs[col_name]): - if Aggregation(agg_name).kind not in _DECIMAL_AGGS: - del result[col_name][i] - - if all(len(v) == 0 for v in result.values()): - raise DataError("No numeric types to aggregate") - - return result + return Table(data=result_data, index=grouped_keys) diff --git a/python/cudf/cudf/_lib/reduce.pyx b/python/cudf/cudf/_lib/reduce.pyx index 62013ea88ae..e5723331f3c 100644 --- a/python/cudf/cudf/_lib/reduce.pyx +++ b/python/cudf/cudf/_lib/reduce.pyx @@ -12,7 +12,7 @@ from cudf._lib.scalar cimport DeviceScalar from cudf._lib.column cimport Column from cudf._lib.types import np_to_cudf_types from cudf._lib.types cimport underlying_type_t_type_id, dtype_to_data_type -from cudf._lib.aggregation cimport make_aggregation, aggregation +from cudf._lib.aggregation cimport make_aggregation, Aggregation from libcpp.memory cimport unique_ptr from libcpp.utility cimport move, pair import numpy as np @@ -45,9 +45,7 @@ def reduce(reduction_op, Column incol, dtype=None, **kwargs): cdef column_view c_incol_view = incol.view() cdef unique_ptr[scalar] c_result - cdef unique_ptr[aggregation] c_agg = move(make_aggregation( - reduction_op, kwargs - )) + cdef Aggregation cython_agg = make_aggregation(reduction_op, kwargs) cdef data_type c_out_dtype = dtype_to_data_type(col_dtype) @@ -65,7 +63,7 @@ def reduce(reduction_op, Column incol, dtype=None, **kwargs): with nogil: c_result = move(cpp_reduce( c_incol_view, - c_agg, + cython_agg.c_obj, c_out_dtype )) @@ -95,9 +93,7 @@ def scan(scan_op, Column incol, inclusive, **kwargs): """ cdef column_view c_incol_view = incol.view() cdef unique_ptr[column] c_result - cdef unique_ptr[aggregation] c_agg = move( - make_aggregation(scan_op, kwargs) - ) + cdef Aggregation cython_agg = make_aggregation(scan_op, kwargs) cdef scan_type c_inclusive if inclusive is True: @@ -108,7 +104,7 @@ def scan(scan_op, Column incol, inclusive, **kwargs): with nogil: c_result = move(cpp_scan( c_incol_view, - c_agg, + 
cython_agg.c_obj, c_inclusive )) diff --git a/python/cudf/cudf/_lib/rolling.pyx b/python/cudf/cudf/_lib/rolling.pyx index 9c818f39c38..d67fb431ec4 100644 --- a/python/cudf/cudf/_lib/rolling.pyx +++ b/python/cudf/cudf/_lib/rolling.pyx @@ -8,12 +8,11 @@ from libcpp.memory cimport unique_ptr from libcpp.utility cimport move from cudf._lib.column cimport Column -from cudf._lib.aggregation cimport make_aggregation +from cudf._lib.aggregation cimport Aggregation, make_aggregation from cudf._lib.cpp.types cimport size_type from cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.column.column_view cimport column_view -from cudf._lib.cpp.aggregation cimport aggregation from cudf._lib.cpp.rolling cimport ( rolling_window as cpp_rolling_window ) @@ -47,14 +46,12 @@ def rolling(Column source_column, Column pre_column_window, cdef column_view source_column_view = source_column.view() cdef column_view pre_column_window_view cdef column_view fwd_column_window_view - cdef unique_ptr[aggregation] agg + cdef Aggregation cython_agg if callable(op): - agg = move( - make_aggregation(op, {'dtype': source_column.dtype}) - ) + cython_agg = make_aggregation(op, {'dtype': source_column.dtype}) else: - agg = move(make_aggregation(op)) + cython_agg = make_aggregation(op) if window is None: if center: @@ -71,7 +68,7 @@ def rolling(Column source_column, Column pre_column_window, pre_column_window_view, fwd_column_window_view, c_min_periods, - agg) + cython_agg.c_obj) ) else: c_min_periods = min_periods @@ -89,7 +86,7 @@ def rolling(Column source_column, Column pre_column_window, c_window, c_forward_window, c_min_periods, - agg) + cython_agg.c_obj) ) return Column.from_unique_ptr(move(c_result)) diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py index cc94548d9a2..a52fae994e7 100644 --- a/python/cudf/cudf/core/groupby/groupby.py +++ b/python/cudf/cudf/core/groupby/groupby.py @@ -1,6 +1,5 @@ # Copyright (c) 2020, NVIDIA CORPORATION. 
import collections -import functools import pickle import warnings @@ -570,47 +569,106 @@ def rolling(self, *args, **kwargs): """ return cudf.core.window.rolling.RollingGroupby(self, *args, **kwargs) + def count(self, dropna=True): + """Compute the number of values in each column. -# Set of valid groupby aggregations that are monkey-patched into the GroupBy -# namespace. -_VALID_GROUPBY_AGGS = { - "count", - "sum", - "idxmin", - "idxmax", - "min", - "max", - "mean", - "var", - "std", - "quantile", - "median", - "nunique", - "collect", - "unique", -} - - -# Dynamically bind the different aggregation methods. -def _agg_func_name_with_args(self, func_name, *args, **kwargs): - """ - Aggregate given an aggregate function name and arguments to the - function, e.g., `_agg_func_name_with_args("quantile", 0.5)`. The named - aggregations must be members of _AggregationFactory. - """ + Parameters + ---------- + dropna : bool + If ``True``, don't include null values in the count. + """ + + def func(x): + return getattr(x, "count")(dropna=dropna) + + return self.agg(func) + + def sum(self): + """Compute the column-wise sum of the values in each group.""" + return self.agg("sum") + + def idxmin(self): + """Get the column-wise index of the minimum value in each group.""" + return self.agg("idxmin") - def func(x): - """Compute the {} of the group.""".format(func_name) - return getattr(x, func_name)(*args, **kwargs) + def idxmax(self): + """Get the column-wise index of the maximum value in each group.""" + return self.agg("idxmax") + + def min(self): + """Get the column-wise minimum value in each group.""" + return self.agg("min") + + def max(self): + """Get the column-wise maximum value in each group.""" + return self.agg("max") + + def mean(self): + """Compute the column-wise mean of the values in each group.""" + return self.agg("mean") + + def median(self): + """Get the column-wise median of the values in each group.""" + return self.agg("median") + + def var(self, ddof=1): + 
"""Compute the column-wise variance of the values in each group. + + Parameters + ---------- + ddof : int + The delta degrees of freedom. N - ddof is the divisor used to + normalize the variance. + """ - func.__name__ = func_name - return self.agg(func) + def func(x): + return getattr(x, "var")(ddof=ddof) + + return self.agg(func) + + def std(self, ddof=1): + """Compute the column-wise std of the values in each group. + + Parameters + ---------- + ddof : int + The delta degrees of freedom. N - ddof is the divisor used to + normalize the standard deviation. + """ + + def func(x): + return getattr(x, "std")(ddof=ddof) + + return self.agg(func) + + def quantile(self, q=0.5, interpolation="linear"): + """Compute the column-wise quantiles of the values in each group. + + Parameters + ---------- + q : float or array-like + The quantiles to compute. + interpolation : {"linear", "lower", "higher", "midpoint", "nearest"} + The interpolation method to use when the desired quantile lies + between two data points. Defaults to "linear". 
+ """ + + def func(x): + return getattr(x, "quantile")(q=q, interpolation=interpolation) + + return self.agg(func) + + def nunique(self): + """Compute the number of unique values in each column in each group.""" + return self.agg("nunique") + def collect(self): + """Get a list of all the values for each column in each group.""" + return self.agg("collect") -for key in _VALID_GROUPBY_AGGS: - setattr( - GroupBy, key, functools.partialmethod(_agg_func_name_with_args, key) - ) + def unique(self): + """Get a list of the unique values for each column in each group.""" + return self.agg("unique") class DataFrameGroupBy(GroupBy): diff --git a/python/cudf/cudf/tests/test_groupby.py b/python/cudf/cudf/tests/test_groupby.py index 4dbe608af82..868387b100e 100644 --- a/python/cudf/cudf/tests/test_groupby.py +++ b/python/cudf/cudf/tests/test_groupby.py @@ -1236,7 +1236,11 @@ def test_raise_data_error(): pdf = pd.DataFrame({"a": [1, 2, 3, 4], "b": ["a", "b", "c", "d"]}) gdf = cudf.from_pandas(pdf) - assert_exceptions_equal(pdf.groupby("a").mean, gdf.groupby("a").mean) + assert_exceptions_equal( + pdf.groupby("a").mean, + gdf.groupby("a").mean, + compare_error_message=False, + ) def test_drop_unsupported_multi_agg():