Skip to content

Commit

Permalink
Validate types in pylibcudf Column/Table constructors (#15088)
Browse files Browse the repository at this point in the history
Otherwise, someone can pass any random object to the constructor and will receive an unfriendly segfault when interacting with libcudf.

Authors:
  - Lawrence Mitchell (https://github.com/wence-)

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

URL: #15088
  • Loading branch information
wence- authored Feb 20, 2024
1 parent 3150676 commit 66b3a93
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 18 deletions.
32 changes: 31 additions & 1 deletion python/cudf/cudf/_lib/cpp/join.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ from rmm._lib.device_uvector cimport device_uvector
from cudf._lib.cpp.column.column cimport column
from cudf._lib.cpp.table.table cimport table
from cudf._lib.cpp.table.table_view cimport table_view
from cudf._lib.cpp.types cimport size_type
from cudf._lib.cpp.types cimport null_equality, size_type

ctypedef unique_ptr[device_uvector[size_type]] gather_map_type
ctypedef pair[gather_map_type, gather_map_type] gather_map_pair_type
Expand Down Expand Up @@ -40,3 +40,33 @@ cdef extern from "cudf/join.hpp" namespace "cudf" nogil:
const table_view left_keys,
const table_view right_keys,
) except +

cdef gather_map_pair_type inner_join(
const table_view left_keys,
const table_view right_keys,
null_equality nulls_equal,
) except +

cdef gather_map_pair_type left_join(
const table_view left_keys,
const table_view right_keys,
null_equality nulls_equal,
) except +

cdef gather_map_pair_type full_join(
const table_view left_keys,
const table_view right_keys,
null_equality nulls_equal,
) except +

cdef gather_map_type left_semi_join(
const table_view left_keys,
const table_view right_keys,
null_equality nulls_equal,
) except +

cdef gather_map_type left_anti_join(
const table_view left_keys,
const table_view right_keys,
null_equality nulls_equal,
) except +
2 changes: 2 additions & 0 deletions python/cudf/cudf/_lib/join.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def join(list lhs, list rhs, how=None):
left_rows, right_rows = join_func(
pylibcudf.Table([c.to_pylibcudf(mode="read") for c in lhs]),
pylibcudf.Table([c.to_pylibcudf(mode="read") for c in rhs]),
pylibcudf.types.NullEquality.EQUAL
)
return Column.from_pylibcudf(left_rows), Column.from_pylibcudf(right_rows)

Expand All @@ -37,5 +38,6 @@ def semi_join(list lhs, list rhs, how=None):
join_func(
pylibcudf.Table([c.to_pylibcudf(mode="read") for c in lhs]),
pylibcudf.Table([c.to_pylibcudf(mode="read") for c in rhs]),
pylibcudf.types.NullEquality.EQUAL
)
), None
2 changes: 1 addition & 1 deletion python/cudf/cudf/_lib/pylibcudf/column.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ cdef class Column:
gpumemoryview _mask
size_type _null_count
size_type _offset
# children: List[Column]
# _children: List[Column]
list _children
size_type _num_children

Expand Down
2 changes: 2 additions & 0 deletions python/cudf/cudf/_lib/pylibcudf/column.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ cdef class Column:
gpumemoryview mask, size_type null_count, size_type offset,
list children
):
if not all(isinstance(c, Column) for c in children):
raise ValueError("All children must be pylibcudf Column objects")
self._data_type = data_type
self._size = size
self._data = data
Expand Down
32 changes: 27 additions & 5 deletions python/cudf/cudf/_lib/pylibcudf/join.pxd
Original file line number Diff line number Diff line change
@@ -1,15 +1,37 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

from cudf._lib.cpp.types cimport null_equality

from .column cimport Column
from .table cimport Table


cpdef tuple inner_join(Table left_keys, Table right_keys)
cpdef tuple inner_join(
Table left_keys,
Table right_keys,
null_equality nulls_equal
)

cpdef tuple left_join(Table left_keys, Table right_keys)
cpdef tuple left_join(
Table left_keys,
Table right_keys,
null_equality nulls_equal
)

cpdef tuple full_join(Table left_keys, Table right_keys)
cpdef tuple full_join(
Table left_keys,
Table right_keys,
null_equality nulls_equal
)

cpdef Column left_semi_join(Table left_keys, Table right_keys)
cpdef Column left_semi_join(
Table left_keys,
Table right_keys,
null_equality nulls_equal
)

cpdef Column left_anti_join(Table left_keys, Table right_keys)
cpdef Column left_anti_join(
Table left_keys,
Table right_keys,
null_equality nulls_equal
)
64 changes: 53 additions & 11 deletions python/cudf/cudf/_lib/pylibcudf/join.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ from rmm._lib.device_buffer cimport device_buffer

from cudf._lib.cpp cimport join as cpp_join
from cudf._lib.cpp.column.column cimport column
from cudf._lib.cpp.types cimport data_type, size_type, type_id
from cudf._lib.cpp.types cimport data_type, null_equality, size_type, type_id

from .column cimport Column
from .table cimport Table
Expand All @@ -32,7 +32,11 @@ cdef Column _column_from_gather_map(cpp_join.gather_map_type gather_map):
)


cpdef tuple inner_join(Table left_keys, Table right_keys):
cpdef tuple inner_join(
Table left_keys,
Table right_keys,
null_equality nulls_equal
):
"""Perform an inner join between two tables.
For details, see :cpp:func:`inner_join`.
Expand All @@ -43,6 +47,8 @@ cpdef tuple inner_join(Table left_keys, Table right_keys):
The left table to join.
right_keys : Table
The right table to join.
nulls_equal : NullEquality
Should nulls compare equal?
Returns
-------
Expand All @@ -52,14 +58,18 @@ cpdef tuple inner_join(Table left_keys, Table right_keys):
"""
cdef cpp_join.gather_map_pair_type c_result
with nogil:
c_result = cpp_join.inner_join(left_keys.view(), right_keys.view())
c_result = cpp_join.inner_join(left_keys.view(), right_keys.view(), nulls_equal)
return (
_column_from_gather_map(move(c_result.first)),
_column_from_gather_map(move(c_result.second)),
)


cpdef tuple left_join(Table left_keys, Table right_keys):
cpdef tuple left_join(
Table left_keys,
Table right_keys,
null_equality nulls_equal
):
"""Perform a left join between two tables.
For details, see :cpp:func:`left_join`.
Expand All @@ -70,6 +80,9 @@ cpdef tuple left_join(Table left_keys, Table right_keys):
The left table to join.
right_keys : Table
The right table to join.
nulls_equal : NullEquality
Should nulls compare equal?
Returns
-------
Expand All @@ -79,14 +92,18 @@ cpdef tuple left_join(Table left_keys, Table right_keys):
"""
cdef cpp_join.gather_map_pair_type c_result
with nogil:
c_result = cpp_join.left_join(left_keys.view(), right_keys.view())
c_result = cpp_join.left_join(left_keys.view(), right_keys.view(), nulls_equal)
return (
_column_from_gather_map(move(c_result.first)),
_column_from_gather_map(move(c_result.second)),
)


cpdef tuple full_join(Table left_keys, Table right_keys):
cpdef tuple full_join(
Table left_keys,
Table right_keys,
null_equality nulls_equal
):
"""Perform a full join between two tables.
For details, see :cpp:func:`full_join`.
Expand All @@ -97,6 +114,9 @@ cpdef tuple full_join(Table left_keys, Table right_keys):
The left table to join.
right_keys : Table
The right table to join.
nulls_equal : NullEquality
Should nulls compare equal?
Returns
-------
Expand All @@ -106,14 +126,18 @@ cpdef tuple full_join(Table left_keys, Table right_keys):
"""
cdef cpp_join.gather_map_pair_type c_result
with nogil:
c_result = cpp_join.full_join(left_keys.view(), right_keys.view())
c_result = cpp_join.full_join(left_keys.view(), right_keys.view(), nulls_equal)
return (
_column_from_gather_map(move(c_result.first)),
_column_from_gather_map(move(c_result.second)),
)


cpdef Column left_semi_join(Table left_keys, Table right_keys):
cpdef Column left_semi_join(
Table left_keys,
Table right_keys,
null_equality nulls_equal
):
"""Perform a left semi join between two tables.
For details, see :cpp:func:`left_semi_join`.
Expand All @@ -124,6 +148,9 @@ cpdef Column left_semi_join(Table left_keys, Table right_keys):
The left table to join.
right_keys : Table
The right table to join.
nulls_equal : NullEquality
Should nulls compare equal?
Returns
-------
Expand All @@ -132,11 +159,19 @@ cpdef Column left_semi_join(Table left_keys, Table right_keys):
"""
cdef cpp_join.gather_map_type c_result
with nogil:
c_result = cpp_join.left_semi_join(left_keys.view(), right_keys.view())
c_result = cpp_join.left_semi_join(
left_keys.view(),
right_keys.view(),
nulls_equal
)
return _column_from_gather_map(move(c_result))


cpdef Column left_anti_join(Table left_keys, Table right_keys):
cpdef Column left_anti_join(
Table left_keys,
Table right_keys,
null_equality nulls_equal
):
"""Perform a left anti join between two tables.
For details, see :cpp:func:`left_anti_join`.
Expand All @@ -147,6 +182,9 @@ cpdef Column left_anti_join(Table left_keys, Table right_keys):
The left table to join.
right_keys : Table
The right table to join.
nulls_equal : NullEquality
Should nulls compare equal?
Returns
-------
Expand All @@ -155,5 +193,9 @@ cpdef Column left_anti_join(Table left_keys, Table right_keys):
"""
cdef cpp_join.gather_map_type c_result
with nogil:
c_result = cpp_join.left_anti_join(left_keys.view(), right_keys.view())
c_result = cpp_join.left_anti_join(
left_keys.view(),
right_keys.view(),
nulls_equal
)
return _column_from_gather_map(move(c_result))
2 changes: 2 additions & 0 deletions python/cudf/cudf/_lib/pylibcudf/table.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ cdef class Table:
The columns in this table.
"""
def __init__(self, list columns):
if not all(isinstance(c, Column) for c in columns):
raise ValueError("All columns must be pylibcudf Column objects")
self._columns = columns

cdef table_view view(self) nogil:
Expand Down

0 comments on commit 66b3a93

Please sign in to comment.