# Patch summary (commentary only; everything before the first "diff --git"
# header is not part of any hunk).
#
# dsl/expr.py
#   * __init__: "setsorted" is added to the accepted unary function names;
#     any other name still raises NotImplementedError.
#   * do_evaluate: new "setsorted" branch. It evaluates the single child,
#     unpacks the single option `asc`, and maps it to Order.ASCENDING when it
#     equals "ascending", otherwise Order.DESCENDING. The null order defaults
#     to NullOrder.BEFORE. Only when the column has nulls and more than one
#     row are the first and last elements inspected: nulls first on a
#     descending column, or nulls last on an ascending column, switch the
#     null order to NullOrder.AFTER. The column is returned through
#     set_sorted(is_sorted=Sorted.YES, order=..., null_order=...).
#
# tests/expressions/test_agg.py
#   * The `is_sorted` fixture no longer marks its True parameter as xfail.
#
# tests/expressions/test_sort.py
#   * Imports pylibcudf (as plc) and translate_ir, and adds test_setsorted,
#     which checks the GPU result and the sortedness flags on the evaluated
#     column.
#   * NOTE(review): test_setsorted takes a `with_nulls` fixture that this
#     patch does not define; presumably it comes from a conftest. Confirm.
#   * NOTE(review): indentation below was reconstructed from the code's
#     syntax after the line structure was lost; confirm against the target
#     files before applying.
diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py
index f83d9e82d30..86d0a65a2d8 100644
--- a/python/cudf_polars/cudf_polars/dsl/expr.py
+++ b/python/cudf_polars/cudf_polars/dsl/expr.py
@@ -867,7 +867,7 @@ def __init__(
         self.name = name
         self.options = options
         self.children = children
-        if self.name not in ("round", "unique"):
+        if self.name not in ("round", "unique", "setsorted"):
             raise NotImplementedError(f"Unary function {name=}")
 
     def do_evaluate(
@@ -923,6 +923,33 @@ def do_evaluate(
             if maintain_order:
                 return Column(column).sorted_like(values)
             return Column(column)
+        elif self.name == "setsorted":
+            (column,) = (
+                child.evaluate(df, context=context, mapping=mapping)
+                for child in self.children
+            )
+            (asc,) = self.options
+            order = (
+                plc.types.Order.ASCENDING
+                if asc == "ascending"
+                else plc.types.Order.DESCENDING
+            )
+            null_order = plc.types.NullOrder.BEFORE
+            if column.obj.null_count() > 0 and (n := column.obj.size()) > 1:
+                # PERF: This invokes four stream synchronisations!
+                has_nulls_first = not plc.copying.get_element(column.obj, 0).is_valid()
+                has_nulls_last = not plc.copying.get_element(
+                    column.obj, n - 1
+                ).is_valid()
+                if (order == plc.types.Order.DESCENDING and has_nulls_first) or (
+                    order == plc.types.Order.ASCENDING and has_nulls_last
+                ):
+                    null_order = plc.types.NullOrder.AFTER
+            return column.set_sorted(
+                is_sorted=plc.types.Sorted.YES,
+                order=order,
+                null_order=null_order,
+            )
         raise NotImplementedError(
             f"Unimplemented unary function {self.name=}"
         )  # pragma: no cover; init trips first
diff --git a/python/cudf_polars/tests/expressions/test_agg.py b/python/cudf_polars/tests/expressions/test_agg.py
index 267d0a99692..37c52d35083 100644
--- a/python/cudf_polars/tests/expressions/test_agg.py
+++ b/python/cudf_polars/tests/expressions/test_agg.py
@@ -20,13 +20,7 @@ def dtype(request):
     return request.param
 
 
-@pytest.fixture(
-    params=[
-        False,
-        pytest.param(True, marks=pytest.mark.xfail(reason="No handler for set_sorted")),
-    ],
-    ids=["unsorted", "sorted"],
-)
+@pytest.fixture(params=[False, True], ids=["unsorted", "sorted"])
 def is_sorted(request):
     return request.param
 
diff --git a/python/cudf_polars/tests/expressions/test_sort.py b/python/cudf_polars/tests/expressions/test_sort.py
index 0195266f5c6..d46df92db94 100644
--- a/python/cudf_polars/tests/expressions/test_sort.py
+++ b/python/cudf_polars/tests/expressions/test_sort.py
@@ -8,6 +8,9 @@
 
 import polars as pl
 
+import cudf._lib.pylibcudf as plc
+
+from cudf_polars import translate_ir
 from cudf_polars.testing.asserts import assert_gpu_result_equal
 
 
@@ -51,3 +54,31 @@ def test_sort_by_expression(descending, nulls_last, maintain_order):
         )
     )
     assert_gpu_result_equal(query, check_row_order=maintain_order)
+
+
+@pytest.mark.parametrize("descending", [False, True])
+@pytest.mark.parametrize("nulls_last", [False, True])
+def test_setsorted(descending, nulls_last, with_nulls):
+    values = sorted([1, 2, 3, 4, 5, 6, -2], reverse=descending)
+    if with_nulls:
+        values[-1 if nulls_last else 0] = None
+    df = pl.LazyFrame({"a": values})
+
+    q = df.set_sorted("a", descending=descending)
+
+    assert_gpu_result_equal(q)
+
+    df = translate_ir(q._ldf.visit()).evaluate(cache={})
+
+    (a,) = df.columns
+
+    assert a.is_sorted == plc.types.Sorted.YES
+    null_order = (
+        plc.types.NullOrder.AFTER
+        if (descending ^ nulls_last) and with_nulls
+        else plc.types.NullOrder.BEFORE
+    )
+    assert a.null_order == null_order
+    assert a.order == (
+        plc.types.Order.DESCENDING if descending else plc.types.Order.ASCENDING
+    )