Skip to content

Commit

Permalink
Add checks for HLG layers in dask-cudf groupby tests (#10853)
Browse files Browse the repository at this point in the history
This PR adds a helper function, `check_groupby_result`, to dask-cudf's groupby tests; it is used in the basic tests to ensure that we are using dask-cudf's `groupby_agg` function to compute the result as expected.

I also expanded `test_groupby_agg` to test all supported aggregations, and removed tests that were made superfluous by this change.

Authors:
  - Charles Blackmon-Luca (https://github.com/charlesbluca)

Approvers:
  - Mads R. B. Kristensen (https://github.com/madsbk)
  - Lawrence Mitchell (https://github.com/wence-)

URL: #10853
  • Loading branch information
charlesbluca authored Nov 7, 2022
1 parent 262631b commit 17b6b2e
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 111 deletions.
76 changes: 36 additions & 40 deletions python/dask_cudf/dask_cudf/groupby.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) 2020-2022, NVIDIA CORPORATION.

from functools import wraps
from typing import Set

import numpy as np
Expand All @@ -16,12 +17,8 @@
import cudf
from cudf.utils.utils import _dask_cudf_nvtx_annotate

CUMULATIVE_AGGS = (
"cumsum",
"cumcount",
)

AGGS = (
# aggregations that are dask-cudf optimized
OPTIMIZED_AGGS = (
"count",
"mean",
"std",
Expand All @@ -34,19 +31,18 @@
"last",
)

SUPPORTED_AGGS = (*AGGS, *CUMULATIVE_AGGS)


def _check_groupby_supported(func):
def _check_groupby_optimized(func):
"""
Decorator for dask-cudf's groupby methods that returns the dask-cudf
method if the groupby object is supported, otherwise reverting to the
upstream Dask method
optimized method if the groupby object is supported, otherwise
reverting to the upstream Dask method
"""

@wraps(func)
def wrapper(*args, **kwargs):
gb = args[0]
if _groupby_supported(gb):
if _groupby_optimized(gb):
return func(*args, **kwargs)
# note that we use upstream Dask's default kwargs for this call if
# none are specified; this shouldn't be an issue as those defaults are
Expand Down Expand Up @@ -94,7 +90,7 @@ def _make_groupby_method_aggs(self, agg_name):
return {c: agg_name for c in self.obj.columns if c != self.by}

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def count(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -109,7 +105,7 @@ def count(self, split_every=None, split_out=1):
)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def mean(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -124,7 +120,7 @@ def mean(self, split_every=None, split_out=1):
)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def std(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -139,7 +135,7 @@ def std(self, split_every=None, split_out=1):
)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def var(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -154,7 +150,7 @@ def var(self, split_every=None, split_out=1):
)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def sum(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -169,7 +165,7 @@ def sum(self, split_every=None, split_out=1):
)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def min(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -184,7 +180,7 @@ def min(self, split_every=None, split_out=1):
)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def max(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -199,7 +195,7 @@ def max(self, split_every=None, split_out=1):
)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def collect(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -214,7 +210,7 @@ def collect(self, split_every=None, split_out=1):
)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def first(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -229,7 +225,7 @@ def first(self, split_every=None, split_out=1):
)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def last(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -250,7 +246,7 @@ def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):

arg = _redirect_aggs(arg)

if _groupby_supported(self) and _aggs_supported(arg, SUPPORTED_AGGS):
if _groupby_optimized(self) and _aggs_optimized(arg, OPTIMIZED_AGGS):
if isinstance(self._meta.grouping.keys, cudf.MultiIndex):
keys = self._meta.grouping.keys.names
else:
Expand Down Expand Up @@ -287,7 +283,7 @@ def __init__(self, *args, sort=None, **kwargs):
super().__init__(*args, sort=sort, **kwargs)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def count(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -302,7 +298,7 @@ def count(self, split_every=None, split_out=1):
)[self._slice]

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def mean(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -317,7 +313,7 @@ def mean(self, split_every=None, split_out=1):
)[self._slice]

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def std(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -332,7 +328,7 @@ def std(self, split_every=None, split_out=1):
)[self._slice]

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def var(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -347,7 +343,7 @@ def var(self, split_every=None, split_out=1):
)[self._slice]

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def sum(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -362,7 +358,7 @@ def sum(self, split_every=None, split_out=1):
)[self._slice]

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def min(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -377,7 +373,7 @@ def min(self, split_every=None, split_out=1):
)[self._slice]

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def max(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -392,7 +388,7 @@ def max(self, split_every=None, split_out=1):
)[self._slice]

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def collect(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -407,7 +403,7 @@ def collect(self, split_every=None, split_out=1):
)[self._slice]

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def first(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -422,7 +418,7 @@ def first(self, split_every=None, split_out=1):
)[self._slice]

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def last(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -446,7 +442,7 @@ def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
if not isinstance(arg, dict):
arg = {self._slice: arg}

if _groupby_supported(self) and _aggs_supported(arg, SUPPORTED_AGGS):
if _groupby_optimized(self) and _aggs_optimized(arg, OPTIMIZED_AGGS):
return groupby_agg(
self.obj,
self.by,
Expand Down Expand Up @@ -569,9 +565,9 @@ def groupby_agg(
"""
# Assert that aggregations are supported
aggs = _redirect_aggs(aggs_in)
if not _aggs_supported(aggs, SUPPORTED_AGGS):
if not _aggs_optimized(aggs, OPTIMIZED_AGGS):
raise ValueError(
f"Supported aggs include {SUPPORTED_AGGS} for groupby_agg API. "
f"Supported aggs include {OPTIMIZED_AGGS} for groupby_agg API. "
f"Aggregations must be specified with dict or list syntax."
)

Expand Down Expand Up @@ -735,7 +731,7 @@ def _redirect_aggs(arg):


@_dask_cudf_nvtx_annotate
def _aggs_supported(arg, supported: set):
def _aggs_optimized(arg, supported: set):
"""Check that aggregations in `arg` are a subset of `supported`"""
if isinstance(arg, (list, dict)):
if isinstance(arg, dict):
Expand All @@ -757,8 +753,8 @@ def _aggs_supported(arg, supported: set):


@_dask_cudf_nvtx_annotate
def _groupby_supported(gb):
"""Check that groupby input is supported by dask-cudf"""
def _groupby_optimized(gb):
"""Check that groupby input can use dask-cudf optimized codepath"""
return isinstance(gb.obj, DaskDataFrame) and (
isinstance(gb.by, str)
or (isinstance(gb.by, list) and all(isinstance(x, str) for x in gb.by))
Expand Down Expand Up @@ -830,7 +826,7 @@ def _tree_node_agg(df, gb_cols, dropna, sort, sep):
agg = col.split(sep)[-1]
if agg in ("count", "sum"):
agg_dict[col] = ["sum"]
elif agg in SUPPORTED_AGGS:
elif agg in OPTIMIZED_AGGS:
agg_dict[col] = [agg]
else:
raise ValueError(f"Unexpected aggregation: {agg}")
Expand Down
Loading

0 comments on commit 17b6b2e

Please sign in to comment.