diff --git a/python/dask_cudf/dask_cudf/backends.py b/python/dask_cudf/dask_cudf/backends.py
index 65d9c438fba..b0da82eaeee 100644
--- a/python/dask_cudf/dask_cudf/backends.py
+++ b/python/dask_cudf/dask_cudf/backends.py
@@ -481,6 +481,38 @@ def sizeof_cudf_series_index(obj):
     return obj.memory_usage()
 
 
+# TODO: Remove try/except when cudf is pinned to dask>=2023.10.0
+try:
+    from dask.dataframe.dispatch import partd_encode_dispatch
+except ImportError:
+    # This dask version does not provide partd_encode_dispatch,
+    # so there is nothing to register.
+    pass
+else:
+    @partd_encode_dispatch.register(cudf.DataFrame)
+    def _simple_cudf_encode(_):
+        """Return a pickle-based ``partd.Encode`` factory for cudf data.
+
+        The dispatched object is unused; only its type selects this
+        implementation.
+        """
+        # Basic pickle-based encoding for a partd k-v store
+        import pickle
+        from functools import partial
+
+        import partd
+
+        def join(dfs):
+            # Return an empty frame rather than calling cudf.concat
+            # with nothing to concatenate.
+            if not dfs:
+                return cudf.DataFrame()
+            return cudf.concat(dfs)
+
+        dumps = partial(pickle.dumps, protocol=pickle.HIGHEST_PROTOCOL)
+        return partial(partd.Encode, dumps, pickle.loads, join)
+
+
 def _default_backend(func, *args, **kwargs):
     # Utility to call a dask.dataframe function with
     # the default ("pandas") backend
diff --git a/python/dask_cudf/dask_cudf/tests/test_sort.py b/python/dask_cudf/dask_cudf/tests/test_sort.py
index 94609b180d6..e58255cda06 100644
--- a/python/dask_cudf/dask_cudf/tests/test_sort.py
+++ b/python/dask_cudf/dask_cudf/tests/test_sort.py
@@ -114,3 +114,15 @@ def test_sort_values_empty_string(by):
     if "a" in by:
         expect = df.sort_values(by)
         assert dd.assert_eq(got, expect, check_index=False)
+
+
+def test_disk_shuffle():
+    """Shuffle a cudf-backed collection with the disk-based method."""
+    try:
+        from dask.dataframe.dispatch import partd_encode_dispatch  # noqa: F401
+    except ImportError:
+        pytest.skip("need a version of dask that has partd_encode_dispatch")
+    df = cudf.DataFrame({"a": [1, 2, 3] * 20, "b": [4, 5, 6, 7] * 15})
+    ddf = dd.from_pandas(df, npartitions=4)
+    got = dd.DataFrame.shuffle(ddf, "a", shuffle="disk")
+    dd.assert_eq(got, df)