Enable sorting on column with nulls using query-planning #15639

Merged 7 commits on May 9, 2024
53 changes: 53 additions & 0 deletions python/dask_cudf/dask_cudf/expr/_expr.py
@@ -1,8 +1,10 @@
# Copyright (c) 2024, NVIDIA CORPORATION.
import functools

import numpy as np
from dask_expr._cumulative import CumulativeBlockwise
from dask_expr._expr import Expr, VarColumns
from dask_expr._quantiles import RepartitionQuantiles
from dask_expr._reductions import Reduction, Var

from dask.dataframe.core import is_dataframe_like, make_meta, meta_nonempty
@@ -121,3 +123,54 @@ def _patched_var(


Expr.var = _patched_var


# Add custom code path for RepartitionQuantiles, because
# upstream logic fails when null values are present. Note
# that the cudf-specific code path can also be used for
# multi-column divisions in the future.


def _quantile(a, q):
if a.empty:
# Avoid calling `quantile` on empty data
return None, 0
a = a.to_frame() if a.ndim == 1 else a
return (
a.quantile(q=q.tolist(), interpolation="nearest", method="table"),
len(a),
)
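
The per-partition step above can be sketched with plain pandas as a stand-in for cudf (the helper name and sample data below are illustrative, not part of the PR). The key behavior, assuming pandas-like null-skipping `quantile` semantics: nulls are excluded from the quantile estimates, while the returned length still counts every row, so the later merge step can weight each partition correctly.

```python
import numpy as np
import pandas as pd

def approx_partition_quantiles(s, q):
    # Hypothetical pandas stand-in for the `_quantile` helper above:
    # empty partitions return no estimate, `quantile` skips nulls,
    # and the returned length still counts every row (nulls included).
    if s.empty:
        return None, 0
    return s.quantile(q=q.tolist(), interpolation="nearest"), len(s)

qs = np.linspace(0.0, 1.0, 5)  # 4 output partitions -> 5 boundaries
part = pd.Series([3.0, None, 1.0, 2.0, None, 4.0])
vals, n = approx_partition_quantiles(part, qs)
print(n)                            # 6: nulls still count toward the weight
print(vals.iloc[0], vals.iloc[-1])  # 1.0 4.0: boundaries ignore the nulls
```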


def merge_quantiles(finalq, qs, vals):
from dask_cudf.sorting import merge_quantiles as mq

return mq(finalq, qs, vals).iloc[:, 0].to_pandas()
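
A much-simplified sketch of what a merge step like `dask_cudf.sorting.merge_quantiles` has to do: pool the per-partition estimates, weighting each partition by its row count, then take the final quantiles over the pooled sample. The helper below and its crude `np.repeat` weighting are illustrative only; the real implementation is more careful.

```python
import numpy as np
import pandas as pd

def merge_quantiles_sketch(finalq, vals):
    # `vals` is a list of (per-partition quantile estimates, row count)
    # pairs, as produced by the per-partition step. Pool the estimates,
    # repeating each partition's sample roughly in proportion to its
    # row count, then take the final quantiles over the pooled sample.
    samples = []
    for est, n in vals:
        if est is None or n == 0:
            continue  # empty partitions contribute nothing
        est = np.asarray(est, dtype=float)
        samples.append(np.repeat(est, max(n // len(est), 1)))
    pooled = pd.Series(np.concatenate(samples))
    return pooled.quantile(q=list(finalq), interpolation="nearest")

finalq = [0.0, 0.5, 1.0]
vals = [(np.array([1.0, 2.0, 3.0]), 30), (np.array([4.0, 5.0, 6.0]), 3)]
divs = merge_quantiles_sketch(finalq, vals)  # monotone division boundaries
```

The larger first partition dominates the median here, which is the point of weighting by row count rather than pooling the partitions equally.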


_original_layer = RepartitionQuantiles._layer


def _cudf_layer(self):
if hasattr(self._meta, "to_pandas"):
Review comment (Contributor): nit: this canary for whether or not we have a cudf-backed object is also true for pyarrow objects, I think. I know this is a hypothetical right now, but just noting it.

Review comment (Contributor): Could you add a short comment about what exactly this code does differently from the upstream version such that it can handle nulls?

Reply (Member, author), to the first comment: Yeah, sorry. This was used as a simple toggle when I was experimenting in the dask-expr code itself. If the code lives in dask-cudf, we can do an instance check.

Reply (Member, author), to the second comment: Yes. I agree that a more-thorough comment is definitely needed here.

# pandas/cudf uses quantile in [0, 1]
# numpy / cupy uses [0, 100]
qs = np.linspace(0.0, 1.0, self.input_npartitions + 1)
val_dsk = {
(self._name, 0, i): (_quantile, key, qs)
for i, key in enumerate(self.frame.__dask_keys__())
}
merge_dsk = {
(self._name, 0): (
merge_quantiles,
qs,
[qs] * self.input_npartitions,
sorted(val_dsk),
)
}
return {**val_dsk, **merge_dsk}
else:
return _original_layer(self)


RepartitionQuantiles._layer = _cudf_layer
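
The patched `_layer` returns a raw task graph: one `_quantile` task per input partition, plus a single merge task that depends on the sorted list of those keys. A self-contained sketch of that graph shape, with a toy recursive scheduler and placeholder task functions (everything below is illustrative; real execution goes through Dask's schedulers):

```python
def resolve(dsk, arg):
    # Tuples that are graph keys are dependencies; lists are resolved
    # element-wise so a task can depend on a list of keys.
    if isinstance(arg, list):
        return [resolve(dsk, a) for a in arg]
    if isinstance(arg, tuple) and arg in dsk:
        return execute(dsk, arg)
    return arg

def execute(dsk, key):
    # A task is (callable, *args); resolve args, then call.
    func, *args = dsk[key]
    return func(*(resolve(dsk, a) for a in args))

def part_minmax(data, qs):
    s = sorted(data)
    return [s[0], s[-1]]  # placeholder for the per-partition quantiles

def merge(qs, qs_list, vals):
    return sorted(v for est in vals for v in est)  # placeholder merge

name = "repartition-quantiles-toy"
parts = [[3.0, 1.0], [4.0, 2.0]]
# Same shape as the graph built by _cudf_layer: (name, 0, i) per input
# partition, then a single (name, 0) task merging the sorted keys.
dsk = {(name, 0, i): (part_minmax, p, None) for i, p in enumerate(parts)}
dsk[(name, 0)] = (merge, None, [None, None], sorted(dsk))
print(execute(dsk, (name, 0)))  # [1.0, 2.0, 3.0, 4.0]
```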
1 change: 0 additions & 1 deletion python/dask_cudf/dask_cudf/tests/test_sort.py
@@ -72,7 +72,6 @@ def test_sort_repartition():
dd.assert_eq(len(new_ddf), len(ddf))


@xfail_dask_expr("dask-expr code path fails with nulls")
@pytest.mark.parametrize("na_position", ["first", "last"])
@pytest.mark.parametrize("ascending", [True, False])
@pytest.mark.parametrize("by", ["a", "b", ["a", "b"]])