Skip to content

Commit

Permalink
Rename util groupby functions from supported -> optimized
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbluca committed Nov 3, 2022
1 parent 5b2fe76 commit 2001ffe
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 37 deletions.
67 changes: 34 additions & 33 deletions python/dask_cudf/dask_cudf/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
import cudf
from cudf.utils.utils import _dask_cudf_nvtx_annotate

SUPPORTED_AGGS = (
# aggregations that are dask-cudf optimized
OPTIMIZED_AGGS = (
"count",
"mean",
"std",
Expand All @@ -30,16 +31,16 @@
)


def _check_groupby_supported(func):
def _check_groupby_optimized(func):
"""
Decorator for dask-cudf's groupby methods that returns the dask-cudf
method if the groupby object is supported, otherwise reverting to the
upstream Dask method
optimized method if the groupby object is supported, otherwise
reverting to the upstream Dask method
"""

def wrapper(*args, **kwargs):
gb = args[0]
if _groupby_supported(gb):
if _groupby_optimized(gb):
return func(*args, **kwargs)
# note that we use upstream Dask's default kwargs for this call if
# none are specified; this shouldn't be an issue as those defaults are
Expand Down Expand Up @@ -87,7 +88,7 @@ def _make_groupby_method_aggs(self, agg_name):
return {c: agg_name for c in self.obj.columns if c != self.by}

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def count(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -102,7 +103,7 @@ def count(self, split_every=None, split_out=1):
)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def mean(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -117,7 +118,7 @@ def mean(self, split_every=None, split_out=1):
)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def std(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -132,7 +133,7 @@ def std(self, split_every=None, split_out=1):
)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def var(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -147,7 +148,7 @@ def var(self, split_every=None, split_out=1):
)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def sum(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -162,7 +163,7 @@ def sum(self, split_every=None, split_out=1):
)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def min(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -177,7 +178,7 @@ def min(self, split_every=None, split_out=1):
)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def max(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -192,7 +193,7 @@ def max(self, split_every=None, split_out=1):
)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def collect(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -207,7 +208,7 @@ def collect(self, split_every=None, split_out=1):
)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def first(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -222,7 +223,7 @@ def first(self, split_every=None, split_out=1):
)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def last(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -243,7 +244,7 @@ def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):

arg = _redirect_aggs(arg)

if _groupby_supported(self) and _aggs_supported(arg, SUPPORTED_AGGS):
if _groupby_optimized(self) and _aggs_optimized(arg, OPTIMIZED_AGGS):
if isinstance(self._meta.grouping.keys, cudf.MultiIndex):
keys = self._meta.grouping.keys.names
else:
Expand Down Expand Up @@ -280,7 +281,7 @@ def __init__(self, *args, sort=None, **kwargs):
super().__init__(*args, sort=sort, **kwargs)

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def count(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -295,7 +296,7 @@ def count(self, split_every=None, split_out=1):
)[self._slice]

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def mean(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -310,7 +311,7 @@ def mean(self, split_every=None, split_out=1):
)[self._slice]

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def std(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -325,7 +326,7 @@ def std(self, split_every=None, split_out=1):
)[self._slice]

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def var(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -340,7 +341,7 @@ def var(self, split_every=None, split_out=1):
)[self._slice]

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def sum(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -355,7 +356,7 @@ def sum(self, split_every=None, split_out=1):
)[self._slice]

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def min(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -370,7 +371,7 @@ def min(self, split_every=None, split_out=1):
)[self._slice]

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def max(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -385,7 +386,7 @@ def max(self, split_every=None, split_out=1):
)[self._slice]

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def collect(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -400,7 +401,7 @@ def collect(self, split_every=None, split_out=1):
)[self._slice]

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def first(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -415,7 +416,7 @@ def first(self, split_every=None, split_out=1):
)[self._slice]

@_dask_cudf_nvtx_annotate
@_check_groupby_supported
@_check_groupby_optimized
def last(self, split_every=None, split_out=1):
return groupby_agg(
self.obj,
Expand All @@ -439,7 +440,7 @@ def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
if not isinstance(arg, dict):
arg = {self._slice: arg}

if _groupby_supported(self) and _aggs_supported(arg, SUPPORTED_AGGS):
if _groupby_optimized(self) and _aggs_optimized(arg, OPTIMIZED_AGGS):
return groupby_agg(
self.obj,
self.by,
Expand Down Expand Up @@ -562,9 +563,9 @@ def groupby_agg(
"""
# Assert that aggregations are supported
aggs = _redirect_aggs(aggs_in)
if not _aggs_supported(aggs, SUPPORTED_AGGS):
if not _aggs_optimized(aggs, OPTIMIZED_AGGS):
raise ValueError(
f"Supported aggs include {SUPPORTED_AGGS} for groupby_agg API. "
f"Supported aggs include {OPTIMIZED_AGGS} for groupby_agg API. "
f"Aggregations must be specified with dict or list syntax."
)

Expand Down Expand Up @@ -728,7 +729,7 @@ def _redirect_aggs(arg):


@_dask_cudf_nvtx_annotate
def _aggs_supported(arg, supported: set):
def _aggs_optimized(arg, supported: set):
"""Check that aggregations in `arg` are a subset of `supported`"""
if isinstance(arg, (list, dict)):
if isinstance(arg, dict):
Expand All @@ -750,8 +751,8 @@ def _aggs_supported(arg, supported: set):


@_dask_cudf_nvtx_annotate
def _groupby_supported(gb):
"""Check that groupby input is supported by dask-cudf"""
def _groupby_optimized(gb):
"""Check that groupby input can use dask-cudf optimized codepath"""
return isinstance(gb.obj, DaskDataFrame) and (
isinstance(gb.by, str)
or (isinstance(gb.by, list) and all(isinstance(x, str) for x in gb.by))
Expand Down Expand Up @@ -823,7 +824,7 @@ def _tree_node_agg(df, gb_cols, dropna, sort, sep):
agg = col.split(sep)[-1]
if agg in ("count", "sum"):
agg_dict[col] = ["sum"]
elif agg in SUPPORTED_AGGS:
elif agg in OPTIMIZED_AGGS:
agg_dict[col] = [agg]
else:
raise ValueError(f"Unexpected aggregation: {agg}")
Expand Down
8 changes: 4 additions & 4 deletions python/dask_cudf/dask_cudf/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from cudf.core._compat import PANDAS_GE_120

import dask_cudf
from dask_cudf.groupby import SUPPORTED_AGGS, _aggs_supported
from dask_cudf.groupby import OPTIMIZED_AGGS, _aggs_optimized


def assert_cudf_groupby_layers(ddf):
Expand Down Expand Up @@ -47,7 +47,7 @@ def pdf(request):
return pdf


@pytest.mark.parametrize("aggregation", SUPPORTED_AGGS)
@pytest.mark.parametrize("aggregation", OPTIMIZED_AGGS)
@pytest.mark.parametrize("series", [False, True])
def test_groupby_basic(series, aggregation, pdf):
gdf = cudf.DataFrame.from_pandas(pdf)
Expand Down Expand Up @@ -106,7 +106,7 @@ def test_groupby_cumulative(aggregation, pdf, series):
dd.assert_eq(a, b)


@pytest.mark.parametrize("aggregation", SUPPORTED_AGGS)
@pytest.mark.parametrize("aggregation", OPTIMIZED_AGGS)
@pytest.mark.parametrize(
"func",
[
Expand Down Expand Up @@ -706,7 +706,7 @@ def test_groupby_agg_redirect(aggregations):
],
)
def test_is_supported(arg, supported):
assert _aggs_supported(arg, SUPPORTED_AGGS) is supported
assert _aggs_optimized(arg, OPTIMIZED_AGGS) is supported


def test_groupby_unique_lists():
Expand Down

0 comments on commit 2001ffe

Please sign in to comment.