Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Support groupby operations for decimal dtypes #7731

Merged
merged 39 commits into from
Apr 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
26bafd0
Don't identify decimals as strings.
vyasr Mar 24, 2021
babcdfc
Reject all extension types as string types.
vyasr Mar 25, 2021
2b00611
Create separate lists for extension type methods.
vyasr Mar 25, 2021
76ab556
Merge branch 'branch-0.19' into fix/issue7687_part2
vyasr Mar 25, 2021
1ebde51
Enable collect for decimals.
vyasr Mar 25, 2021
4c5d876
Enable argmin and argmax.
vyasr Mar 25, 2021
4134e43
Fix variance key name.
vyasr Mar 25, 2021
43cf580
Move groupby aggregation list to groupby.py and clean up the assignme…
vyasr Mar 25, 2021
474a179
Disable aggs that are overrides of actual methods.
vyasr Mar 25, 2021
25e74ef
Move more logic out of the GroupBy class.
vyasr Mar 25, 2021
8a44827
Simplify getattr usage.
vyasr Mar 25, 2021
6b5c67f
Clearly documented unknown failures.
vyasr Mar 25, 2021
8e45ad0
Match other class groupbys to strings.
vyasr Mar 25, 2021
81ffe0a
Fix style and remove unsupported operations.
vyasr Mar 25, 2021
6d3fad3
Apply black reformattings.
vyasr Mar 25, 2021
714742d
Remove variance from obviously unsupported types.
vyasr Mar 25, 2021
ea4ed2e
Defer getattr to getitem if possible.
vyasr Mar 25, 2021
026bb4e
Make getattr safe for copying.
vyasr Mar 25, 2021
1259032
Remove support for aggregating structs.
vyasr Mar 25, 2021
6c61806
Update documented list of groupby operations.
vyasr Mar 25, 2021
a14d30f
Move function out of loop.
vyasr Mar 25, 2021
12caa06
Merge branch 'branch-0.19' into fix/issue7687_part2
vyasr Mar 26, 2021
5c71bfe
Remove redundant test, add test of decimal.
vyasr Mar 27, 2021
25811b0
Fix formatting.
vyasr Mar 27, 2021
1450f2d
Merge branch 'branch-0.19' into fix/issue7687_part2
vyasr Mar 28, 2021
39c45ac
Merge branch 'branch-0.19' into fix/issue7687_part2
vyasr Mar 29, 2021
c036d0c
Add more rigorous test (currently includes debugging statements).
vyasr Mar 29, 2021
d2385ec
Add support for pandas Series composed of decimal.Decimal objects.
vyasr Mar 29, 2021
5395563
Clean up the testing code and use Decimal to make pandas and cudf com…
vyasr Mar 29, 2021
cec2c13
Rewrite test logic to avoid duplicates, but remove those tests for id…
vyasr Mar 29, 2021
eadc028
Minor cleanup.
vyasr Mar 30, 2021
395856a
Apply black.
vyasr Mar 30, 2021
c8a83b3
Don't overwrite dtype variable.
vyasr Mar 30, 2021
040401c
Skip decimal tests on CUDA 10.x.
vyasr Mar 31, 2021
286b686
Rename pyarrow_dtype to pyarrow_type.
vyasr Mar 31, 2021
91642d3
Use rmm to get the CUDA version.
vyasr Mar 31, 2021
2e7dfb8
Make decimal fail loudly on older architectures.
vyasr Mar 31, 2021
fbfcdf8
Fix import order.
vyasr Mar 31, 2021
72bc1b7
Change exception message to indicate that the underlying cause is a c…
vyasr Mar 31, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions docs/cudf/source/groupby.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,24 +120,24 @@ a

The following table summarizes the available aggregations and the types that support them:

| Aggregations\dtypes | Numeric | Datetime | String | Categorical | List | Struct |
| ------------------- | -------- | ------- | -------- | ----------- | ---- | ------ |
| count | ✅ | ✅ | ✅ | ✅ | | |
| size | ✅ | ✅ | ✅ | ✅ | | |
| sum | ✅ | ✅ | | | | |
| idxmin | ✅ | ✅ | | | | |
| idxmax | ✅ | ✅ | | | | |
| min | ✅ | ✅ | ✅ | | | |
| max | ✅ | ✅ | ✅ | | | |
| mean | ✅ | ✅ | | | | |
| var | ✅ | ✅ | | | | |
| std | ✅ | ✅ | | | | |
| quantile | ✅ | ✅ | | | | |
| median | ✅ | ✅ | | | | |
| nunique | ✅ | ✅ | ✅ | ✅ | | |
| nth | ✅ | ✅ | ✅ | | | |
| collect | ✅ | ✅ | ✅ | | ✅ | |
| unique | ✅ | ✅ | ✅ | ✅ | | |
| Aggregations\dtypes | Numeric | Datetime | String | Categorical | List | Struct | Interval | Decimal |
| ------------------- | -------- | ------- | -------- | ----------- | ---- | ------ | -------- | ------- |
| count | ✅ | ✅ | ✅ | ✅ | | | | ✅ |
| size | ✅ | ✅ | ✅ | ✅ | | | | ✅ |
| sum | ✅ | ✅ | | | | | | ✅ |
| idxmin | ✅ | ✅ | | | | | | ✅ |
| idxmax | ✅ | ✅ | | | | | | ✅ |
| min | ✅ | ✅ | ✅ | | | | | ✅ |
| max | ✅ | ✅ | ✅ | | | | | ✅ |
| mean | ✅ | ✅ | | | | | | |
| var | ✅ | ✅ | | | | | | |
| std | ✅ | ✅ | | | | | | |
| quantile | ✅ | ✅ | | | | | | |
| median | ✅ | ✅ | | | | | | |
| nunique | ✅ | ✅ | ✅ | ✅ | | | | ✅ |
| nth | ✅ | ✅ | ✅ | | | | | ✅ |
| collect | ✅ | ✅ | ✅ | | ✅ | | | ✅ |
| unique | ✅ | ✅ | ✅ | ✅ | | | | |

## GroupBy apply

Expand Down
69 changes: 49 additions & 20 deletions python/cudf/cudf/_lib/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections import defaultdict

import numpy as np
import rmm

from libcpp.pair cimport pair
from libcpp.memory cimport unique_ptr
Expand All @@ -20,25 +21,9 @@ cimport cudf._lib.cpp.groupby as libcudf_groupby
cimport cudf._lib.cpp.aggregation as libcudf_aggregation


_GROUPBY_AGGS = {
"count",
"size",
"sum",
"idxmin",
"idxmax",
"min",
"max",
"mean",
"var",
"std",
"quantile",
"median",
"nunique",
"nth",
"collect",
"unique",
}

# The sets below define the possible aggregations that can be performed on
# different dtypes. The uppercased versions of these strings correspond to
# elements of the AggregationKind enum.
_CATEGORICAL_AGGS = {
"count",
"size",
Expand All @@ -61,6 +46,24 @@ _LIST_AGGS = {
"collect",
}

_STRUCT_AGGS = {
}

_INTERVAL_AGGS = {
}

_DECIMAL_AGGS = {
"count",
"sum",
"argmin",
"argmax",
"min",
"max",
"nunique",
"nth",
"collect"
}


cdef class GroupBy:
cdef unique_ptr[libcudf_groupby.groupby] c_obj
Expand Down Expand Up @@ -197,7 +200,10 @@ def _drop_unsupported_aggs(Table values, aggs):
from cudf.utils.dtypes import (
is_categorical_dtype,
is_string_dtype,
is_list_dtype
is_list_dtype,
is_interval_dtype,
is_struct_dtype,
is_decimal_dtype,
)
result = aggs.copy()

Expand All @@ -220,6 +226,29 @@ def _drop_unsupported_aggs(Table values, aggs):
for i, agg_name in enumerate(aggs[col_name]):
if Aggregation(agg_name).kind not in _CATEGORICAL_AGGS:
del result[col_name][i]
elif (
is_struct_dtype(values._data[col_name].dtype)
):
for i, agg_name in enumerate(aggs[col_name]):
if Aggregation(agg_name).kind not in _STRUCT_AGGS:
del result[col_name][i]
elif (
is_interval_dtype(values._data[col_name].dtype)
):
for i, agg_name in enumerate(aggs[col_name]):
if Aggregation(agg_name).kind not in _INTERVAL_AGGS:
del result[col_name][i]
elif (
is_decimal_dtype(values._data[col_name].dtype)
):
if rmm._cuda.gpu.runtimeGetVersion() < 11000:
raise RuntimeError(
"Decimal aggregations are only supported on CUDA >= 11 "
"due to an nvcc compiler bug."
)
for i, agg_name in enumerate(aggs[col_name]):
if Aggregation(agg_name).kind not in _DECIMAL_AGGS:
del result[col_name][i]

if all(len(v) == 0 for v in result.values()):
raise DataError("No numeric types to aggregate")
Expand Down
14 changes: 10 additions & 4 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,8 @@ def from_arrow(cls, array: pa.Array) -> ColumnBase:
array.type, pd.core.arrays._arrow_utils.ArrowIntervalType
):
return cudf.core.column.IntervalColumn.from_arrow(array)
elif isinstance(array.type, pa.Decimal128Type):
return cudf.core.column.DecimalColumn.from_arrow(array)

return libcudf.interop.from_arrow(data, data.column_names)._data[
"None"
Expand Down Expand Up @@ -1836,10 +1838,14 @@ def as_column(
cupy.asarray(arbitrary), nan_as_null=nan_as_null, dtype=dtype
)
else:
data = as_column(
pa.array(arbitrary, from_pandas=nan_as_null),
dtype=arbitrary.dtype,
)
pyarrow_array = pa.array(arbitrary, from_pandas=nan_as_null)
if isinstance(pyarrow_array.type, pa.Decimal128Type):
pyarrow_type = cudf.Decimal64Dtype.from_arrow(
pyarrow_array.type
)
else:
pyarrow_type = arbitrary.dtype
data = as_column(pyarrow_array, dtype=pyarrow_type)
if dtype is not None:
data = data.astype(dtype)

Expand Down
82 changes: 53 additions & 29 deletions python/cudf/cudf/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from cudf.utils.utils import cached_property


# Note that all valid aggregation methods (e.g. GroupBy.min) are bound to the
# class after its definition (see below).
class GroupBy(Serializable):

_MAX_GROUPS_BEFORE_WARN = 100
Expand Down Expand Up @@ -58,14 +60,6 @@ def __init__(
else:
self.grouping = _Grouping(obj, by, level)

def __getattribute__(self, key):
try:
return super().__getattribute__(key)
except AttributeError:
if key in libgroupby._GROUPBY_AGGS:
return functools.partial(self._agg_func_name_with_args, key)
raise

def __iter__(self):
group_names, offsets, _, grouped_values = self._grouped()
if isinstance(group_names, cudf.Index):
Expand Down Expand Up @@ -267,19 +261,6 @@ def _grouped(self):
group_names = grouped_keys.unique()
return (group_names, offsets, grouped_keys, grouped_values)

def _agg_func_name_with_args(self, func_name, *args, **kwargs):
"""
Aggregate given an aggregate function name
and arguments to the function, e.g.,
`_agg_func_name_with_args("quantile", 0.5)`
"""

def func(x):
return getattr(x, func_name)(*args, **kwargs)

func.__name__ = func_name
return self.agg(func)

def _normalize_aggs(self, aggs):
"""
Normalize aggs to a dict mapping column names
Expand Down Expand Up @@ -590,6 +571,48 @@ def rolling(self, *args, **kwargs):
return cudf.core.window.rolling.RollingGroupby(self, *args, **kwargs)


# Set of valid groupby aggregations that are monkey-patched into the GroupBy
# namespace.
_VALID_GROUPBY_AGGS = {
"count",
"sum",
"idxmin",
"idxmax",
"min",
"max",
"mean",
"var",
"std",
"quantile",
"median",
"nunique",
"collect",
"unique",
}


# Dynamically bind the different aggregation methods.
def _agg_func_name_with_args(self, func_name, *args, **kwargs):
"""
Aggregate given an aggregate function name and arguments to the
function, e.g., `_agg_func_name_with_args("quantile", 0.5)`. The named
aggregations must be members of _AggregationFactory.
"""

def func(x):
"""Compute the {} of the group.""".format(func_name)
return getattr(x, func_name)(*args, **kwargs)

func.__name__ = func_name
return self.agg(func)


for key in _VALID_GROUPBY_AGGS:
setattr(
GroupBy, key, functools.partialmethod(_agg_func_name_with_args, key)
)


class DataFrameGroupBy(GroupBy):
def __init__(
self, obj, by=None, level=None, sort=False, as_index=True, dropna=True
Expand Down Expand Up @@ -685,15 +708,16 @@ def __init__(
dropna=dropna,
)

def __getattribute__(self, key):
def __getattr__(self, key):
# Without this check, copying can trigger a RecursionError. See
# https://nedbatchelder.com/blog/201010/surprising_getattr_recursion.html # noqa: E501
# for an explanation.
if key == "obj":
raise AttributeError
try:
return super().__getattribute__(key)
except AttributeError:
if key in self.obj:
return self.obj[key].groupby(
self.grouping, dropna=self._dropna, sort=self._sort
)
raise
return self[key]
except KeyError:
raise AttributeError

def __getitem__(self, key):
return self.obj[key].groupby(
Expand Down
Loading