Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate lists/sorting to pylibcudf #16179

Merged
merged 3 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 7 additions & 21 deletions python/cudf/cudf/_lib/lists.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,13 @@ from cudf._lib.pylibcudf.libcudf.column.column cimport column
from cudf._lib.pylibcudf.libcudf.lists.lists_column_view cimport (
lists_column_view,
)
from cudf._lib.pylibcudf.libcudf.lists.sorting cimport (
sort_lists as cpp_sort_lists,
)
from cudf._lib.pylibcudf.libcudf.lists.stream_compaction cimport (
distinct as cpp_distinct,
)
from cudf._lib.pylibcudf.libcudf.types cimport (
nan_equality,
null_equality,
null_order,
order,
size_type,
)
from cudf._lib.utils cimport columns_from_pylibcudf_table
Expand Down Expand Up @@ -80,24 +76,14 @@ def distinct(Column col, bool nulls_equal, bool nans_all_equal):

@acquire_spill_lock()
def sort_lists(Column col, bool ascending, str na_position):
cdef shared_ptr[lists_column_view] list_view = (
make_shared[lists_column_view](col.view())
)
cdef order c_sort_order = (
order.ASCENDING if ascending else order.DESCENDING
)
cdef null_order c_null_prec = (
null_order.BEFORE if na_position == "first" else null_order.AFTER
)

cdef unique_ptr[column] c_result

with nogil:
c_result = move(
cpp_sort_lists(list_view.get()[0], c_sort_order, c_null_prec)
return Column.from_pylibcudf(
pylibcudf.lists.sort_lists(
col.to_pylibcudf(mode="read"),
ascending,
null_order.BEFORE if na_position == "first" else null_order.AFTER,
False,
)

return Column.from_unique_ptr(move(c_result))
)


@acquire_spill_lock()
Expand Down
6 changes: 6 additions & 0 deletions python/cudf/cudf/_lib/pylibcudf/libcudf/lists/sorting.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,9 @@ cdef extern from "cudf/lists/sorting.hpp" namespace "cudf::lists" nogil:
order column_order,
null_order null_precedence
) except +
Matt711 marked this conversation as resolved.
Show resolved Hide resolved

# Declaration of cudf::lists::stable_sort_lists from "cudf/lists/sorting.hpp".
# `except +` lets Cython translate a thrown C++ exception into a Python one.
cdef unique_ptr[column] stable_sort_lists(
const lists_column_view source_column,
order column_order,
null_order null_precedence
) except +
4 changes: 3 additions & 1 deletion python/cudf/cudf/_lib/pylibcudf/lists.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from libcpp cimport bool

from cudf._lib.pylibcudf.libcudf.types cimport size_type
from cudf._lib.pylibcudf.libcudf.types cimport null_order, size_type

from .column cimport Column
from .scalar cimport Scalar
Expand Down Expand Up @@ -35,3 +35,5 @@ cpdef Column segmented_gather(Column, Column)
cpdef Column extract_list_element(Column, ColumnOrSizeType)

cpdef Column count_elements(Column)

# Sort the elements within each list row. `stable = *` marks the last argument
# as optional; its default value is given by the definition in lists.pyx.
cpdef Column sort_lists(Column, bool, null_order, bool stable = *)
57 changes: 56 additions & 1 deletion python/cudf/cudf/_lib/pylibcudf/lists.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,12 @@ from cudf._lib.pylibcudf.libcudf.lists.count_elements cimport (
from cudf._lib.pylibcudf.libcudf.lists.extract cimport (
extract_list_element as cpp_extract_list_element,
)
from cudf._lib.pylibcudf.libcudf.lists.sorting cimport (
sort_lists as cpp_sort_lists,
stable_sort_lists as cpp_stable_sort_lists,
)
from cudf._lib.pylibcudf.libcudf.table.table cimport table
from cudf._lib.pylibcudf.libcudf.types cimport size_type
from cudf._lib.pylibcudf.libcudf.types cimport null_order, order, size_type
from cudf._lib.pylibcudf.lists cimport ColumnOrScalar, ColumnOrSizeType

from .column cimport Column, ListColumnView
Expand Down Expand Up @@ -320,3 +324,54 @@ cpdef Column count_elements(Column input):
c_result = move(cpp_count_elements(list_view.view()))

return Column.from_libcudf(move(c_result))


cpdef Column sort_lists(
    Column input,
    bool ascending,
    null_order na_position,
    bool stable = False
):
    """Sort the elements within a list in each row of a list column.

    For details, see :cpp:func:`sort_lists` and
    :cpp:func:`stable_sort_lists`.

    Parameters
    ----------
    input : Column
        The input column.
    ascending : bool
        If true, the sort order is ascending. Otherwise, the sort order is
        descending.
    na_position : NullOrder
        Whether null elements order before (``NullOrder.BEFORE``) or after
        (``NullOrder.AFTER``) the non-null elements. The precedence is
        applied relative to the sort order: with ``ascending=True``,
        ``NullOrder.BEFORE`` places nulls at the start of each list, while
        with ``ascending=False`` it places them at the end.
    stable : bool, default False
        If true, :cpp:func:`stable_sort_lists` is used. Otherwise,
        :cpp:func:`sort_lists` is used.

    Returns
    -------
    Column
        A new Column with elements in each list sorted.
    """
    cdef unique_ptr[column] c_result
    cdef ListColumnView list_view = input.list_view()

    # Translate the Python-level boolean into the libcudf order enum before
    # entering the nogil section.
    cdef order c_sort_order = (
        order.ASCENDING if ascending else order.DESCENDING
    )

    with nogil:
        if stable:
            c_result = move(cpp_stable_sort_lists(
                list_view.view(),
                c_sort_order,
                na_position,
            ))
        else:
            c_result = move(cpp_sort_lists(
                list_view.view(),
                c_sort_order,
                na_position,
            ))
    return Column.from_libcudf(move(c_result))
46 changes: 46 additions & 0 deletions python/cudf/cudf/pylibcudf_tests/test_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ def column():
return pa.array([3, 2, 5, 6]), pa.array([-1, 0, 0, 0], type=pa.int32())


@pytest.fixture
def lists_column():
    """Provide three integer lists: unsorted, one with a null, one with a repeat."""
    rows = [
        [4, 2, 3, 1],
        [1, 2, None, 4],
        [-10, 10, 10, 0],
    ]
    return rows


def test_concatenate_rows(test_data):
arrow_tbl = pa.Table.from_arrays(test_data[0], names=["a", "b"])
plc_tbl = plc.interop.from_arrow(arrow_tbl)
Expand Down Expand Up @@ -191,3 +196,44 @@ def test_count_elements(test_data):
expect = pa.array([1, 1, 0, 3], type=pa.int32())

assert_column_eq(expect, res)


# One case per (ascending, na_position) combination. Per the expected values,
# null precedence is relative to the sort order: NullOrder.BEFORE puts nulls
# first when ascending and last when descending (and vice versa for AFTER).
@pytest.mark.parametrize(
    "ascending,na_position,expected",
    [
        (
            True,
            plc.types.NullOrder.BEFORE,
            [[1, 2, 3, 4], [None, 1, 2, 4], [-10, 0, 10, 10]],
        ),
        (
            True,
            plc.types.NullOrder.AFTER,
            [[1, 2, 3, 4], [1, 2, 4, None], [-10, 0, 10, 10]],
        ),
        (
            False,
            plc.types.NullOrder.BEFORE,
            [[4, 3, 2, 1], [4, 2, 1, None], [10, 10, 0, -10]],
        ),
        (
            False,
            plc.types.NullOrder.AFTER,
            [[4, 3, 2, 1], [None, 4, 2, 1], [10, 10, 0, -10]],
        ),
    ],
)
def test_sort_lists(lists_column, ascending, na_position, expected):
    """Check that stable and unstable list sorts both match the expected rows."""
    plc_column = plc.interop.from_arrow(pa.array(lists_column))
    res = plc.lists.sort_lists(plc_column, ascending, na_position, False)
    res_stable = plc.lists.sort_lists(plc_column, ascending, na_position, True)

    expect = pa.array(expected)

    assert_column_eq(expect, res)
    assert_column_eq(expect, res_stable)
Matt711 marked this conversation as resolved.
Show resolved Hide resolved
Loading