Skip to content

Commit

Permalink
Expose null equality handling in pylibcudf join routines
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- committed Feb 20, 2024
1 parent 94eb621 commit c792e15
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 17 deletions.
32 changes: 31 additions & 1 deletion python/cudf/cudf/_lib/cpp/join.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ from rmm._lib.device_uvector cimport device_uvector
from cudf._lib.cpp.column.column cimport column
from cudf._lib.cpp.table.table cimport table
from cudf._lib.cpp.table.table_view cimport table_view
from cudf._lib.cpp.types cimport size_type
from cudf._lib.cpp.types cimport null_equality, size_type

ctypedef unique_ptr[device_uvector[size_type]] gather_map_type
ctypedef pair[gather_map_type, gather_map_type] gather_map_pair_type
Expand Down Expand Up @@ -40,3 +40,33 @@ cdef extern from "cudf/join.hpp" namespace "cudf" nogil:
const table_view left_keys,
const table_view right_keys,
) except +

# cudf::inner_join (declared in "cudf/join.hpp") taking the two key table
# views plus an explicit null_equality value.
# Returns gather_map_pair_type, i.e. a pair of gather maps.
# `except +` lets Cython convert a C++ exception into a Python exception.
cdef gather_map_pair_type inner_join(
const table_view left_keys,
const table_view right_keys,
null_equality nulls_equal,
) except +

# cudf::left_join (declared in "cudf/join.hpp") taking the two key table
# views plus an explicit null_equality value.
# Returns gather_map_pair_type, i.e. a pair of gather maps.
# `except +` lets Cython convert a C++ exception into a Python exception.
cdef gather_map_pair_type left_join(
const table_view left_keys,
const table_view right_keys,
null_equality nulls_equal,
) except +

# cudf::full_join (declared in "cudf/join.hpp") taking the two key table
# views plus an explicit null_equality value.
# Returns gather_map_pair_type, i.e. a pair of gather maps.
# `except +` lets Cython convert a C++ exception into a Python exception.
cdef gather_map_pair_type full_join(
const table_view left_keys,
const table_view right_keys,
null_equality nulls_equal,
) except +

# cudf::left_semi_join (declared in "cudf/join.hpp") taking the two key
# table views plus an explicit null_equality value.
# Returns a single gather map (gather_map_type), not a pair.
# `except +` lets Cython convert a C++ exception into a Python exception.
cdef gather_map_type left_semi_join(
const table_view left_keys,
const table_view right_keys,
null_equality nulls_equal,
) except +

# cudf::left_anti_join (declared in "cudf/join.hpp") taking the two key
# table views plus an explicit null_equality value.
# Returns a single gather map (gather_map_type), not a pair.
# `except +` lets Cython convert a C++ exception into a Python exception.
cdef gather_map_type left_anti_join(
const table_view left_keys,
const table_view right_keys,
null_equality nulls_equal,
) except +
2 changes: 2 additions & 0 deletions python/cudf/cudf/_lib/join.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def join(list lhs, list rhs, how=None):
left_rows, right_rows = join_func(
pylibcudf.Table([c.to_pylibcudf(mode="read") for c in lhs]),
pylibcudf.Table([c.to_pylibcudf(mode="read") for c in rhs]),
pylibcudf.types.NullEquality.EQUAL
)
return Column.from_pylibcudf(left_rows), Column.from_pylibcudf(right_rows)

Expand All @@ -37,5 +38,6 @@ def semi_join(list lhs, list rhs, how=None):
join_func(
pylibcudf.Table([c.to_pylibcudf(mode="read") for c in lhs]),
pylibcudf.Table([c.to_pylibcudf(mode="read") for c in rhs]),
pylibcudf.types.NullEquality.EQUAL
)
), None
32 changes: 27 additions & 5 deletions python/cudf/cudf/_lib/pylibcudf/join.pxd
Original file line number Diff line number Diff line change
@@ -1,15 +1,37 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from cudf._lib.cpp.types cimport null_equality

from .column cimport Column
from .table cimport Table


cpdef tuple inner_join(Table left_keys, Table right_keys)
# Declaration of the inner_join function implemented in join.pyx.
# nulls_equal is a required argument of the C++ null_equality enum type.
# The implementation returns a tuple of two Columns, one per gather map
# produced by the C++ join.
cpdef tuple inner_join(
Table left_keys,
Table right_keys,
null_equality nulls_equal
)

cpdef tuple left_join(Table left_keys, Table right_keys)
# Declaration of the left_join function implemented in join.pyx.
# nulls_equal is a required argument of the C++ null_equality enum type.
# The implementation returns a tuple of two Columns, one per gather map
# produced by the C++ join.
cpdef tuple left_join(
Table left_keys,
Table right_keys,
null_equality nulls_equal
)

cpdef tuple full_join(Table left_keys, Table right_keys)
# Declaration of the full_join function implemented in join.pyx.
# nulls_equal is a required argument of the C++ null_equality enum type.
# The implementation returns a tuple of two Columns, one per gather map
# produced by the C++ join.
cpdef tuple full_join(
Table left_keys,
Table right_keys,
null_equality nulls_equal
)

cpdef Column left_semi_join(Table left_keys, Table right_keys)
# Declaration of the left_semi_join function implemented in join.pyx.
# nulls_equal is a required argument of the C++ null_equality enum type.
# The implementation returns a single Column built from the gather map
# produced by the C++ join.
cpdef Column left_semi_join(
Table left_keys,
Table right_keys,
null_equality nulls_equal
)

cpdef Column left_anti_join(Table left_keys, Table right_keys)
# Declaration of the left_anti_join function implemented in join.pyx.
# nulls_equal is a required argument of the C++ null_equality enum type.
# The implementation returns a single Column built from the gather map
# produced by the C++ join.
cpdef Column left_anti_join(
Table left_keys,
Table right_keys,
null_equality nulls_equal
)
64 changes: 53 additions & 11 deletions python/cudf/cudf/_lib/pylibcudf/join.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ from rmm._lib.device_buffer cimport device_buffer

from cudf._lib.cpp cimport join as cpp_join
from cudf._lib.cpp.column.column cimport column
from cudf._lib.cpp.types cimport data_type, size_type, type_id
from cudf._lib.cpp.types cimport data_type, null_equality, size_type, type_id

from .column cimport Column
from .table cimport Table
Expand All @@ -32,7 +32,11 @@ cdef Column _column_from_gather_map(cpp_join.gather_map_type gather_map):
)


cpdef tuple inner_join(Table left_keys, Table right_keys):
cpdef tuple inner_join(
Table left_keys,
Table right_keys,
null_equality nulls_equal
):
"""Perform an inner join between two tables.
For details, see :cpp:func:`inner_join`.
Expand All @@ -43,6 +47,8 @@ cpdef tuple inner_join(Table left_keys, Table right_keys):
The left table to join.
right_keys : Table
The right table to join.
nulls_equal : NullEquality
Should nulls compare equal?
Returns
-------
Expand All @@ -52,14 +58,18 @@ cpdef tuple inner_join(Table left_keys, Table right_keys):
"""
cdef cpp_join.gather_map_pair_type c_result
with nogil:
c_result = cpp_join.inner_join(left_keys.view(), right_keys.view())
c_result = cpp_join.inner_join(left_keys.view(), right_keys.view(), nulls_equal)
return (
_column_from_gather_map(move(c_result.first)),
_column_from_gather_map(move(c_result.second)),
)


cpdef tuple left_join(Table left_keys, Table right_keys):
cpdef tuple left_join(
Table left_keys,
Table right_keys,
null_equality nulls_equal
):
"""Perform a left join between two tables.
For details, see :cpp:func:`left_join`.
Expand All @@ -70,6 +80,9 @@ cpdef tuple left_join(Table left_keys, Table right_keys):
The left table to join.
right_keys : Table
The right table to join.
nulls_equal : NullEquality
Should nulls compare equal?
Returns
-------
Expand All @@ -79,14 +92,18 @@ cpdef tuple left_join(Table left_keys, Table right_keys):
"""
cdef cpp_join.gather_map_pair_type c_result
with nogil:
c_result = cpp_join.left_join(left_keys.view(), right_keys.view())
c_result = cpp_join.left_join(left_keys.view(), right_keys.view(), nulls_equal)
return (
_column_from_gather_map(move(c_result.first)),
_column_from_gather_map(move(c_result.second)),
)


cpdef tuple full_join(Table left_keys, Table right_keys):
cpdef tuple full_join(
Table left_keys,
Table right_keys,
null_equality nulls_equal
):
"""Perform a full join between two tables.
For details, see :cpp:func:`full_join`.
Expand All @@ -97,6 +114,9 @@ cpdef tuple full_join(Table left_keys, Table right_keys):
The left table to join.
right_keys : Table
The right table to join.
nulls_equal : NullEquality
Should nulls compare equal?
Returns
-------
Expand All @@ -106,14 +126,18 @@ cpdef tuple full_join(Table left_keys, Table right_keys):
"""
cdef cpp_join.gather_map_pair_type c_result
with nogil:
c_result = cpp_join.full_join(left_keys.view(), right_keys.view())
c_result = cpp_join.full_join(left_keys.view(), right_keys.view(), nulls_equal)
return (
_column_from_gather_map(move(c_result.first)),
_column_from_gather_map(move(c_result.second)),
)


cpdef Column left_semi_join(Table left_keys, Table right_keys):
cpdef Column left_semi_join(
Table left_keys,
Table right_keys,
null_equality nulls_equal
):
"""Perform a left semi join between two tables.
For details, see :cpp:func:`left_semi_join`.
Expand All @@ -124,6 +148,9 @@ cpdef Column left_semi_join(Table left_keys, Table right_keys):
The left table to join.
right_keys : Table
The right table to join.
nulls_equal : NullEquality
Should nulls compare equal?
Returns
-------
Expand All @@ -132,11 +159,19 @@ cpdef Column left_semi_join(Table left_keys, Table right_keys):
"""
cdef cpp_join.gather_map_type c_result
with nogil:
c_result = cpp_join.left_semi_join(left_keys.view(), right_keys.view())
c_result = cpp_join.left_semi_join(
left_keys.view(),
right_keys.view(),
nulls_equal
)
return _column_from_gather_map(move(c_result))


cpdef Column left_anti_join(Table left_keys, Table right_keys):
cpdef Column left_anti_join(
Table left_keys,
Table right_keys,
null_equality nulls_equal
):
"""Perform a left anti join between two tables.
For details, see :cpp:func:`left_anti_join`.
Expand All @@ -147,6 +182,9 @@ cpdef Column left_anti_join(Table left_keys, Table right_keys):
The left table to join.
right_keys : Table
The right table to join.
nulls_equal : NullEquality
Should nulls compare equal?
Returns
-------
Expand All @@ -155,5 +193,9 @@ cpdef Column left_anti_join(Table left_keys, Table right_keys):
"""
cdef cpp_join.gather_map_type c_result
with nogil:
c_result = cpp_join.left_anti_join(left_keys.view(), right_keys.view())
c_result = cpp_join.left_anti_join(
left_keys.view(),
right_keys.view(),
nulls_equal
)
return _column_from_gather_map(move(c_result))

0 comments on commit c792e15

Please sign in to comment.