Revise shuffle deprecation to align with dask/dask #14762

Merged: 3 commits, Jan 17, 2024
python/dask_cudf/dask_cudf/core.py: 7 changes (6 additions, 1 deletion)
@@ -26,7 +26,7 @@

from dask_cudf import sorting
from dask_cudf.accessors import ListMethods, StructMethods
from dask_cudf.sorting import _get_shuffle_method
from dask_cudf.sorting import _deprecate_shuffle_kwarg, _get_shuffle_method


class _Frame(dd.core._Frame, OperatorMethodMixin):
@@ -111,6 +111,7 @@ def do_apply_rows(df, func, incols, outcols, kwargs):
do_apply_rows, func, incols, outcols, kwargs, meta=meta
)

@_deprecate_shuffle_kwarg
@_dask_cudf_nvtx_annotate
def merge(self, other, shuffle_method=None, **kwargs):
on = kwargs.pop("on", None)
@@ -123,6 +124,7 @@ def merge(self, other, shuffle_method=None, **kwargs):
**kwargs,
)

@_deprecate_shuffle_kwarg
@_dask_cudf_nvtx_annotate
def join(self, other, shuffle_method=None, **kwargs):
# CuDF doesn't support "right" join yet
@@ -141,6 +143,7 @@ def join(self, other, shuffle_method=None, **kwargs):
**kwargs,
)

@_deprecate_shuffle_kwarg
@_dask_cudf_nvtx_annotate
def set_index(
self,
@@ -216,6 +219,7 @@ def set_index(
**kwargs,
)

@_deprecate_shuffle_kwarg
@_dask_cudf_nvtx_annotate
def sort_values(
self,
@@ -298,6 +302,7 @@ def var(
else:
return _parallel_var(self, meta, skipna, split_every, out)

@_deprecate_shuffle_kwarg
@_dask_cudf_nvtx_annotate
def shuffle(self, *args, shuffle_method=None, **kwargs):
"""Wraps dask.dataframe DataFrame.shuffle method"""
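For context, a hedged usage sketch of the user-facing effect of the decorated methods above; it assumes a GPU environment with dask_cudf installed, and the frame and column names are made up for illustration:

import cudf
import dask_cudf

ddf = dask_cudf.from_cudf(
    cudf.DataFrame({"a": [3, 1, 2], "b": [1, 2, 3]}), npartitions=2
)

# New spelling: the shuffle backend is selected with `shuffle_method`.
out = ddf.sort_values("a", shuffle_method="tasks")

# Old spelling: still accepted for now, but `_deprecate_shuffle_kwarg`
# emits a FutureWarning and forwards the value as `shuffle_method="tasks"`.
out = ddf.sort_values("a", shuffle="tasks")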
python/dask_cudf/dask_cudf/groupby.py: 20 changes (14 additions, 6 deletions)
@@ -17,6 +17,8 @@
import cudf
from cudf.utils.nvtx_annotation import _dask_cudf_nvtx_annotate

from dask_cudf.sorting import _deprecate_shuffle_kwarg

# aggregations that are dask-cudf optimized
OPTIMIZED_AGGS = (
"count",
@@ -189,8 +191,11 @@ def last(self, split_every=None, split_out=1):
split_out,
)

@_deprecate_shuffle_kwarg
@_dask_cudf_nvtx_annotate
def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
def aggregate(
self, arg, split_every=None, split_out=1, shuffle_method=None
):
if arg == "size":
return self.size()

@@ -211,15 +216,15 @@ def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
sep=self.sep,
sort=self.sort,
as_index=self.as_index,
shuffle_method=shuffle,
shuffle_method=shuffle_method,
**self.dropna,
)

return super().aggregate(
arg,
split_every=split_every,
split_out=split_out,
shuffle=shuffle,
shuffle_method=shuffle_method,
)


@@ -330,8 +335,11 @@ def last(self, split_every=None, split_out=1):
split_out,
)[self._slice]

@_deprecate_shuffle_kwarg
@_dask_cudf_nvtx_annotate
def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):
def aggregate(
self, arg, split_every=None, split_out=1, shuffle_method=None
):
if arg == "size":
return self.size()

@@ -342,14 +350,14 @@ def aggregate(self, arg, split_every=None, split_out=1, shuffle=None):

if _groupby_optimized(self) and _aggs_optimized(arg, OPTIMIZED_AGGS):
return _make_groupby_agg_call(
self, arg, split_every, split_out, shuffle
self, arg, split_every, split_out, shuffle_method
)[self._slice]

return super().aggregate(
arg,
split_every=split_every,
split_out=split_out,
shuffle=shuffle,
shuffle_method=shuffle_method,
)


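A similar hedged sketch for the groupby path, showing the keyword rename in `aggregate` (GPU environment and column names assumed):

import cudf
import dask_cudf

ddf = dask_cudf.from_cudf(
    cudf.DataFrame({"key": [0, 1, 0, 1], "x": [1.0, 2.0, 3.0, 4.0]}),
    npartitions=2,
)

# The dask_cudf-specific keyword now matches dask/dask.
new = ddf.groupby("key").aggregate({"x": "sum"}, shuffle_method="tasks")

# The old `shuffle=` keyword still works through the decorator, but it
# raises a FutureWarning and is rewritten to `shuffle_method=`.
old = ddf.groupby("key").aggregate({"x": "sum"}, shuffle="tasks")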
python/dask_cudf/dask_cudf/sorting.py: 28 changes (28 additions, 0 deletions)
@@ -1,6 +1,8 @@
# Copyright (c) 2020-2024, NVIDIA CORPORATION.

import warnings
from collections.abc import Iterator
from functools import wraps

import cupy
import numpy as np
@@ -21,6 +23,31 @@
_SHUFFLE_SUPPORT = ("tasks", "p2p") # "disk" not supported


def _deprecate_shuffle_kwarg(func):
@wraps(func)
def wrapper(*args, **kwargs):
old_arg_value = kwargs.pop("shuffle", None)

if old_arg_value is not None:
new_arg_value = old_arg_value
msg = (
"the 'shuffle' keyword is deprecated, "
"use 'shuffle_method' instead."
)

warnings.warn(msg, FutureWarning)
if kwargs.get("shuffle_method") is not None:
msg = (
"Can only specify 'shuffle' "
"or 'shuffle_method', not both."
)
raise TypeError(msg)
kwargs["shuffle_method"] = new_arg_value
return func(*args, **kwargs)

return wrapper


@_dask_cudf_nvtx_annotate
def set_index_post(df, index_name, drop, column_dtype):
df2 = df.set_index(index_name, drop=drop)
@@ -229,6 +256,7 @@ def quantile_divisions(df, by, npartitions):
return divisions


@_deprecate_shuffle_kwarg
@_dask_cudf_nvtx_annotate
def sort_values(
df,
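A self-contained sketch of the deprecation decorator's behavior, using a hypothetical `demo_sort` function in place of the real dask_cudf methods:

import warnings
from functools import wraps


def _deprecate_shuffle_kwarg(func):
    # Same pattern as the decorator added above: pop the old keyword,
    # warn, and forward its value under the new name.
    @wraps(func)
    def wrapper(*args, **kwargs):
        old_arg_value = kwargs.pop("shuffle", None)
        if old_arg_value is not None:
            warnings.warn(
                "the 'shuffle' keyword is deprecated, "
                "use 'shuffle_method' instead.",
                FutureWarning,
            )
            if kwargs.get("shuffle_method") is not None:
                raise TypeError(
                    "Can only specify 'shuffle' or 'shuffle_method', not both."
                )
            kwargs["shuffle_method"] = old_arg_value
        return func(*args, **kwargs)

    return wrapper


@_deprecate_shuffle_kwarg
def demo_sort(values, shuffle_method=None):
    # The wrapped function only ever sees the new keyword name.
    return sorted(values), shuffle_method


with warnings.catch_warnings(record=True) as record:
    warnings.simplefilter("always")
    result, method = demo_sort([3, 1, 2], shuffle="tasks")
assert method == "tasks"
assert any(issubclass(w.category, FutureWarning) for w in record)

# Passing both spellings at once is rejected outright.
try:
    demo_sort([3, 1, 2], shuffle="tasks", shuffle_method="p2p")
except TypeError as err:
    print(err)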
python/dask_cudf/dask_cudf/tests/test_groupby.py: 24 changes (18 additions, 6 deletions)
@@ -834,26 +834,38 @@ def test_groupby_shuffle():

# Sorted aggregation, single-partition output
# (sort=True, split_out=1)
got = gddf.groupby("a", sort=True).agg(spec, shuffle=True, split_out=1)
got = gddf.groupby("a", sort=True).agg(
spec, shuffle_method=True, split_out=1
)
dd.assert_eq(expect, got)

# Sorted aggregation, multi-partition output
# (sort=True, split_out=2)
got = gddf.groupby("a", sort=True).agg(spec, shuffle=True, split_out=2)
got = gddf.groupby("a", sort=True).agg(
spec, shuffle_method=True, split_out=2
)
dd.assert_eq(expect, got)

# Un-sorted aggregation, single-partition output
# (sort=False, split_out=1)
got = gddf.groupby("a", sort=False).agg(spec, shuffle=True, split_out=1)
got = gddf.groupby("a", sort=False).agg(
spec, shuffle_method=True, split_out=1
)
dd.assert_eq(expect.sort_index(), got.compute().sort_index())

# Un-sorted aggregation, multi-partition output
# (sort=False, split_out=2)
# NOTE: `shuffle=True` should be default
# NOTE: `shuffle_method=True` should be default
got = gddf.groupby("a", sort=False).agg(spec, split_out=2)
dd.assert_eq(expect, got.compute().sort_index())

# Sorted aggregation fails with split_out>1 when shuffle is False
# (sort=True, split_out=2, shuffle=False)
# (sort=True, split_out=2, shuffle_method=False)
with pytest.raises(ValueError):
gddf.groupby("a", sort=True).agg(spec, shuffle=False, split_out=2)
gddf.groupby("a", sort=True).agg(
spec, shuffle_method=False, split_out=2
)

# Check shuffle kwarg deprecation
with pytest.warns(match="'shuffle' keyword is deprecated"):
gddf.groupby("a", sort=True).agg(spec, shuffle=False)