Support groupby operations for decimal dtypes (#7731)
This PR resolves #7687. It also cleans up some of the code base's internals. There is more cleanup I would like to do, but in the interest of resolving the issue quickly, I'll defer to a future PR anything that this change doesn't directly have to touch.

I still need help determining why a few aggregations aren't working. The problems fall into two groups:
1. The `var` and `std` aggregations currently don't fail, but they always return columns filled with NULLs. I found the implementation of the dispatch for these methods in `variance.cu`/`compound.cuh`, and at least nominally these methods are _not_ currently supported: the corresponding `enable_if_t` is gated on the type satisfying `std::is_arithmetic`, which decimal types do not. However, I'm not sure whether that classification is incorrect and these types are actually supported by `libcudf`, or whether there really is no implementation; I tried to find one, but there are many different files related to aggregation and I'm sure I didn't find all of them. If we simply don't have an implementation, I can remove these from the list of valid aggregations.
2. The `mean`, `quantile`, and `median` aggregations all raise a `RuntimeError` from `binaryop.hpp`: "Input must have fixed_point data_type." I've traced the error to the Cython `GroupBy.aggregate` method, specifically the line where it calls through to the underlying `c_obj`'s `aggregate` method. The C++ call stack is quite deep after that, though, and I haven't yet been able to pinpoint whether the failure is a missing cast somewhere (i.e., `libcudf` treats the column as a floating-point type when it isn't) or whether the problem lies elsewhere.
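For context on item 1, an `enable_if_t` gate on `std::is_arithmetic` means the `var`/`std` code paths simply never exist for fixed-point types. A rough Python analogue of that kind of type-gated dispatch (illustrative only — not cudf code):

```python
from decimal import Decimal

def _is_arithmetic(dtype) -> bool:
    # Analogue of std::is_arithmetic: true for built-in numeric types,
    # false for fixed-point (Decimal) types.
    return issubclass(dtype, (int, float))

def group_var(values, dtype, ddof=1):
    # Mirrors the enable_if_t gate: non-arithmetic dtypes never get an
    # implementation, so they surface as unsupported.
    if not _is_arithmetic(dtype):
        raise TypeError(f"var is unsupported for {dtype.__name__}")
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / (len(values) - ddof)

print(group_var([1.0, 2.0, 3.0], float))  # 1.0
try:
    group_var([Decimal("1"), Decimal("2")], Decimal)
except TypeError as exc:
    print(exc)  # var is unsupported for Decimal
```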

**Update**
Thanks to @codereport, I've now marked all of the above as unsupported operations. After some discussion with other devs, I've also handled the other extended types. I still need to write tests, but I think this PR is ready for review in its current form so reviewers can flag anything I've missed in the implementation.

Authors:
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Ashwin Srinath (https://github.com/shwina)
  - Keith Kraus (https://github.com/kkraus14)

URL: #7731
vyasr authored Apr 1, 2021
1 parent e379ab1 commit 24f3016
Showing 6 changed files with 208 additions and 102 deletions.
36 changes: 18 additions & 18 deletions docs/cudf/source/groupby.md
@@ -120,24 +120,24 @@ a

The following table summarizes the available aggregations and the types that support them:

| Aggregations\dtypes | Numeric | Datetime | String | Categorical | List | Struct |
| ------------------- | ------- | -------- | ------ | ----------- | ---- | ------ |
| count               | ✅      | ✅       | ✅     | ✅          |      |        |
| size                | ✅      | ✅       | ✅     | ✅          |      |        |
| sum                 | ✅      | ✅       |        |             |      |        |
| idxmin              | ✅      | ✅       |        |             |      |        |
| idxmax              | ✅      | ✅       |        |             |      |        |
| min                 | ✅      | ✅       | ✅     |             |      |        |
| max                 | ✅      | ✅       | ✅     |             |      |        |
| mean                | ✅      | ✅       |        |             |      |        |
| var                 | ✅      | ✅       |        |             |      |        |
| std                 | ✅      | ✅       |        |             |      |        |
| quantile            | ✅      | ✅       |        |             |      |        |
| median              | ✅      | ✅       |        |             |      |        |
| nunique             | ✅      | ✅       | ✅     | ✅          |      |        |
| nth                 | ✅      | ✅       | ✅     |             |      |        |
| collect             | ✅      | ✅       | ✅     |             | ✅   |        |
| unique              | ✅      | ✅       | ✅     | ✅          |      |        |
| Aggregations\dtypes | Numeric | Datetime | String | Categorical | List | Struct | Interval | Decimal |
| ------------------- | ------- | -------- | ------ | ----------- | ---- | ------ | -------- | ------- |
| count               | ✅      | ✅       | ✅     | ✅          |      |        |          | ✅      |
| size                | ✅      | ✅       | ✅     | ✅          |      |        |          | ✅      |
| sum                 | ✅      | ✅       |        |             |      |        |          | ✅      |
| idxmin              | ✅      | ✅       |        |             |      |        |          | ✅      |
| idxmax              | ✅      | ✅       |        |             |      |        |          | ✅      |
| min                 | ✅      | ✅       | ✅     |             |      |        |          | ✅      |
| max                 | ✅      | ✅       | ✅     |             |      |        |          | ✅      |
| mean                | ✅      | ✅       |        |             |      |        |          |         |
| var                 | ✅      | ✅       |        |             |      |        |          |         |
| std                 | ✅      | ✅       |        |             |      |        |          |         |
| quantile            | ✅      | ✅       |        |             |      |        |          |         |
| median              | ✅      | ✅       |        |             |      |        |          |         |
| nunique             | ✅      | ✅       | ✅     | ✅          |      |        |          | ✅      |
| nth                 | ✅      | ✅       | ✅     |             |      |        |          | ✅      |
| collect             | ✅      | ✅       | ✅     |             | ✅   |        |          | ✅      |
| unique              | ✅      | ✅       | ✅     | ✅          |      |        |          |         |

## GroupBy apply

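Why a dedicated Decimal column in the table above matters for aggregations like `sum`: fixed-point values cannot be safely round-tripped through floating point. A stdlib illustration of the difference (not cudf code):

```python
from decimal import Decimal

# Summing exact fixed-point values preserves the scale...
exact = sum([Decimal("0.10")] * 3)
print(exact)  # 0.30

# ...while the same values as binary floats accumulate rounding error.
approx = sum([0.10] * 3)
print(approx)  # 0.30000000000000004
```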
69 changes: 49 additions & 20 deletions python/cudf/cudf/_lib/groupby.pyx
@@ -3,6 +3,7 @@
from collections import defaultdict

import numpy as np
import rmm

from libcpp.pair cimport pair
from libcpp.memory cimport unique_ptr
@@ -20,25 +20,9 @@ cimport cudf._lib.cpp.groupby as libcudf_groupby
cimport cudf._lib.cpp.aggregation as libcudf_aggregation


_GROUPBY_AGGS = {
"count",
"size",
"sum",
"idxmin",
"idxmax",
"min",
"max",
"mean",
"var",
"std",
"quantile",
"median",
"nunique",
"nth",
"collect",
"unique",
}

# The sets below define the possible aggregations that can be performed on
# different dtypes. The uppercased versions of these strings correspond to
# elements of the AggregationKind enum.
_CATEGORICAL_AGGS = {
"count",
"size",
@@ -61,6 +46,24 @@ _LIST_AGGS = {
"collect",
}

_STRUCT_AGGS = {
}

_INTERVAL_AGGS = {
}

_DECIMAL_AGGS = {
"count",
"sum",
"argmin",
"argmax",
"min",
"max",
"nunique",
"nth",
"collect"
}


cdef class GroupBy:
cdef unique_ptr[libcudf_groupby.groupby] c_obj
@@ -197,7 +200,10 @@ def _drop_unsupported_aggs(Table values, aggs):
from cudf.utils.dtypes import (
is_categorical_dtype,
is_string_dtype,
is_list_dtype
is_list_dtype,
is_interval_dtype,
is_struct_dtype,
is_decimal_dtype,
)
result = aggs.copy()

@@ -220,6 +226,29 @@ def _drop_unsupported_aggs(Table values, aggs):
for i, agg_name in enumerate(aggs[col_name]):
if Aggregation(agg_name).kind not in _CATEGORICAL_AGGS:
del result[col_name][i]
elif (
is_struct_dtype(values._data[col_name].dtype)
):
for i, agg_name in enumerate(aggs[col_name]):
if Aggregation(agg_name).kind not in _STRUCT_AGGS:
del result[col_name][i]
elif (
is_interval_dtype(values._data[col_name].dtype)
):
for i, agg_name in enumerate(aggs[col_name]):
if Aggregation(agg_name).kind not in _INTERVAL_AGGS:
del result[col_name][i]
elif (
is_decimal_dtype(values._data[col_name].dtype)
):
if rmm._cuda.gpu.runtimeGetVersion() < 11000:
raise RuntimeError(
"Decimal aggregations are only supported on CUDA >= 11 "
"due to an nvcc compiler bug."
)
for i, agg_name in enumerate(aggs[col_name]):
if Aggregation(agg_name).kind not in _DECIMAL_AGGS:
del result[col_name][i]

if all(len(v) == 0 for v in result.values()):
raise DataError("No numeric types to aggregate")
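The dtype-based filtering that `_drop_unsupported_aggs` performs can be sketched in plain Python (a simplified model — the dtype encoding and helper names here are stand-ins, not the Cython implementation):

```python
_DECIMAL_AGGS = {"count", "sum", "min", "max", "nunique", "nth", "collect"}
_STRUCT_AGGS = set()  # nothing supported yet

_AGGS_BY_DTYPE = {"decimal": _DECIMAL_AGGS, "struct": _STRUCT_AGGS}

class DataError(Exception):
    pass

def drop_unsupported_aggs(dtypes, aggs):
    """Remove aggregations that a column's dtype cannot support.

    dtypes: mapping of column name -> dtype string
    aggs:   mapping of column name -> list of requested aggregations
    """
    result = {
        col: [a for a in requested
              # Dtypes without an entry keep every requested aggregation.
              if a in _AGGS_BY_DTYPE.get(dtypes[col], set(requested))]
        for col, requested in aggs.items()
    }
    if all(len(v) == 0 for v in result.values()):
        raise DataError("No numeric types to aggregate")
    return result

print(drop_unsupported_aggs({"x": "decimal"}, {"x": ["sum", "mean"]}))
# {'x': ['sum']}
```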
14 changes: 10 additions & 4 deletions python/cudf/cudf/core/column/column.py
@@ -426,6 +426,8 @@ def from_arrow(cls, array: pa.Array) -> ColumnBase:
array.type, pd.core.arrays._arrow_utils.ArrowIntervalType
):
return cudf.core.column.IntervalColumn.from_arrow(array)
elif isinstance(array.type, pa.Decimal128Type):
return cudf.core.column.DecimalColumn.from_arrow(array)

return libcudf.interop.from_arrow(data, data.column_names)._data[
"None"
@@ -1846,10 +1848,14 @@ def as_column(
cupy.asarray(arbitrary), nan_as_null=nan_as_null, dtype=dtype
)
else:
data = as_column(
pa.array(arbitrary, from_pandas=nan_as_null),
dtype=arbitrary.dtype,
)
pyarrow_array = pa.array(arbitrary, from_pandas=nan_as_null)
if isinstance(pyarrow_array.type, pa.Decimal128Type):
pyarrow_type = cudf.Decimal64Dtype.from_arrow(
pyarrow_array.type
)
else:
pyarrow_type = arbitrary.dtype
data = as_column(pyarrow_array, dtype=pyarrow_type)
if dtype is not None:
data = data.astype(dtype)

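The new branch above maps a `pa.Decimal128Type` to `cudf.Decimal64Dtype`, whose key metadata is a (precision, scale) pair. A sketch of how such a pair can be derived from a stdlib `Decimal` (the helper name is hypothetical — not part of cudf or pyarrow):

```python
from decimal import Decimal

def precision_and_scale(d: Decimal):
    # A fixed-point dtype stores integer digits at a fixed scale; derive
    # the (precision, scale) pair that represents this value exactly.
    # Assumes a non-positive exponent (i.e. no scientific notation).
    sign, digits, exponent = d.as_tuple()
    scale = -exponent
    precision = max(len(digits), scale)
    return precision, scale

print(precision_and_scale(Decimal("12.345")))  # (5, 3)
print(precision_and_scale(Decimal("0.10")))    # (2, 2)
```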
82 changes: 53 additions & 29 deletions python/cudf/cudf/core/groupby/groupby.py
@@ -13,6 +13,8 @@
from cudf.utils.utils import cached_property


# Note that all valid aggregation methods (e.g. GroupBy.min) are bound to the
# class after its definition (see below).
class GroupBy(Serializable):

_MAX_GROUPS_BEFORE_WARN = 100
@@ -58,14 +60,6 @@ def __init__(
else:
self.grouping = _Grouping(obj, by, level)

def __getattribute__(self, key):
try:
return super().__getattribute__(key)
except AttributeError:
if key in libgroupby._GROUPBY_AGGS:
return functools.partial(self._agg_func_name_with_args, key)
raise

def __iter__(self):
group_names, offsets, _, grouped_values = self._grouped()
if isinstance(group_names, cudf.Index):
@@ -267,19 +261,6 @@ def _grouped(self):
group_names = grouped_keys.unique()
return (group_names, offsets, grouped_keys, grouped_values)

def _agg_func_name_with_args(self, func_name, *args, **kwargs):
"""
Aggregate given an aggregate function name
and arguments to the function, e.g.,
`_agg_func_name_with_args("quantile", 0.5)`
"""

def func(x):
return getattr(x, func_name)(*args, **kwargs)

func.__name__ = func_name
return self.agg(func)

def _normalize_aggs(self, aggs):
"""
Normalize aggs to a dict mapping column names
@@ -590,6 +571,48 @@ def rolling(self, *args, **kwargs):
return cudf.core.window.rolling.RollingGroupby(self, *args, **kwargs)


# Set of valid groupby aggregations that are monkey-patched into the GroupBy
# namespace.
_VALID_GROUPBY_AGGS = {
"count",
"sum",
"idxmin",
"idxmax",
"min",
"max",
"mean",
"var",
"std",
"quantile",
"median",
"nunique",
"collect",
"unique",
}


# Dynamically bind the different aggregation methods.
def _agg_func_name_with_args(self, func_name, *args, **kwargs):
"""
Aggregate given an aggregate function name and arguments to the
function, e.g., `_agg_func_name_with_args("quantile", 0.5)`. The named
aggregations must be members of _AggregationFactory.
"""

def func(x):
"""Compute the {} of the group.""".format(func_name)
return getattr(x, func_name)(*args, **kwargs)

func.__name__ = func_name
return self.agg(func)


for key in _VALID_GROUPBY_AGGS:
setattr(
GroupBy, key, functools.partialmethod(_agg_func_name_with_args, key)
)


class DataFrameGroupBy(GroupBy):
def __init__(
self, obj, by=None, level=None, sort=False, as_index=True, dropna=True
@@ -685,15 +708,16 @@ def __init__(
dropna=dropna,
)

def __getattribute__(self, key):
def __getattr__(self, key):
# Without this check, copying can trigger a RecursionError. See
# https://nedbatchelder.com/blog/201010/surprising_getattr_recursion.html # noqa: E501
# for an explanation.
if key == "obj":
raise AttributeError
try:
return super().__getattribute__(key)
except AttributeError:
if key in self.obj:
return self.obj[key].groupby(
self.grouping, dropna=self._dropna, sort=self._sort
)
raise
return self[key]
except KeyError:
raise AttributeError

def __getitem__(self, key):
return self.obj[key].groupby(
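The `functools.partialmethod` binding used above can be seen in a minimal, self-contained form (a toy model, not the cudf classes):

```python
import functools

class Values(list):
    """Toy column whose named aggregations are ordinary methods."""
    def sum(self):
        return sum(iter(self))
    def max(self):
        return max(iter(self))

class GroupBy:
    _VALID_AGGS = ("sum", "max")

    def __init__(self, values):
        self.values = Values(values)

    def agg(self, func):
        # Generic aggregation path: apply the callable to the values.
        return func(self.values)

def _agg_func_name_with_args(self, func_name, *args, **kwargs):
    def func(x):
        return getattr(x, func_name)(*args, **kwargs)
    func.__name__ = func_name
    return self.agg(func)

# Bind one real method per aggregation; unlike a __getattribute__ hook,
# these show up for hasattr(), dir(), and help().
for key in GroupBy._VALID_AGGS:
    setattr(GroupBy, key,
            functools.partialmethod(_agg_func_name_with_args, key))

gb = GroupBy([3, 1, 2])
print(gb.sum())  # 6
print(gb.max())  # 3
```

Binding real methods also avoids the fragility of intercepting every attribute lookup, which is why the `__getattribute__` override could be removed.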