Skip to content

Commit

Permalink
Implement handler for setsorted
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- committed Jul 11, 2024
1 parent 5a39b6b commit 9bb267f
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 8 deletions.
29 changes: 28 additions & 1 deletion python/cudf_polars/cudf_polars/dsl/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,7 @@ def __init__(
self.name = name
self.options = options
self.children = children
if self.name not in ("round", "unique"):
if self.name not in ("round", "unique", "setsorted"):
raise NotImplementedError(f"Unary function {name=}")

def do_evaluate(
Expand Down Expand Up @@ -923,6 +923,33 @@ def do_evaluate(
if maintain_order:
return Column(column).sorted_like(values)
return Column(column)
elif self.name == "setsorted":
(column,) = (
child.evaluate(df, context=context, mapping=mapping)
for child in self.children
)
(asc,) = self.options
order = (
plc.types.Order.ASCENDING
if asc == "ascending"
else plc.types.Order.DESCENDING
)
null_order = plc.types.NullOrder.BEFORE
if column.obj.null_count() > 0 and (n := column.obj.size()) > 1:
# PERF: This invokes four stream synchronisations!
has_nulls_first = not plc.copying.get_element(column.obj, 0).is_valid()
has_nulls_last = not plc.copying.get_element(
column.obj, n - 1
).is_valid()
if (order == plc.types.Order.DESCENDING and has_nulls_first) or (
order == plc.types.Order.ASCENDING and has_nulls_last
):
null_order = plc.types.NullOrder.AFTER
return column.set_sorted(
is_sorted=plc.types.Sorted.YES,
order=order,
null_order=null_order,
)
raise NotImplementedError(
f"Unimplemented unary function {self.name=}"
) # pragma: no cover; init trips first
Expand Down
8 changes: 1 addition & 7 deletions python/cudf_polars/tests/expressions/test_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,7 @@ def dtype(request):
return request.param


@pytest.fixture(
params=[
False,
pytest.param(True, marks=pytest.mark.xfail(reason="No handler for set_sorted")),
],
ids=["unsorted", "sorted"],
)
@pytest.fixture(params=[False, True], ids=["unsorted", "sorted"])
def is_sorted(request):
return request.param

Expand Down
31 changes: 31 additions & 0 deletions python/cudf_polars/tests/expressions/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

import polars as pl

import cudf._lib.pylibcudf as plc

from cudf_polars import translate_ir
from cudf_polars.testing.asserts import assert_gpu_result_equal


Expand Down Expand Up @@ -51,3 +54,31 @@ def test_sort_by_expression(descending, nulls_last, maintain_order):
)
)
assert_gpu_result_equal(query, check_row_order=maintain_order)


@pytest.mark.parametrize("descending", [False, True])
@pytest.mark.parametrize("nulls_last", [False, True])
def test_setsorted(descending, nulls_last, with_nulls):
values = sorted([1, 2, 3, 4, 5, 6, -2], reverse=descending)
if with_nulls:
values[-1 if nulls_last else 0] = None
df = pl.LazyFrame({"a": values})

q = df.set_sorted("a", descending=descending)

assert_gpu_result_equal(q)

df = translate_ir(q._ldf.visit()).evaluate(cache={})

(a,) = df.columns

assert a.is_sorted == plc.types.Sorted.YES
null_order = (
plc.types.NullOrder.AFTER
if (descending ^ nulls_last) and with_nulls
else plc.types.NullOrder.BEFORE
)
assert a.null_order == null_order
assert a.order == (
plc.types.Order.DESCENDING if descending else plc.types.Order.ASCENDING
)

0 comments on commit 9bb267f

Please sign in to comment.