Skip to content

Commit

Permalink
Add Python bindings for cudf::lists::index_of (#10549)
Browse files Browse the repository at this point in the history
This PR adds Python bindings for `cudf::lists::index_of` in the form of `ListMethods.index`.

Authors:
  - https://github.com/ChrisJar

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)
  - Ashwin Srinath (https://github.com/shwina)

URL: #10549
  • Loading branch information
ChrisJar authored Apr 10, 2022
1 parent bc43e6a commit bf4ffc9
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 11 deletions.
7 changes: 6 additions & 1 deletion python/cudf/cudf/_lib/cpp/lists/contains.pxd
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021, NVIDIA CORPORATION.
# Copyright (c) 2021-2022, NVIDIA CORPORATION.

from libcpp.memory cimport unique_ptr

Expand All @@ -13,3 +13,8 @@ cdef extern from "cudf/lists/contains.hpp" namespace "cudf::lists" nogil:
lists_column_view lists,
scalar search_key,
) except +

# Cython declaration of ``cudf::lists::index_of`` from
# "cudf/lists/contains.hpp". It takes a lists column view and a scalar
# search key and returns an owning pointer to the resulting column.
# ``except +`` asks Cython to translate C++ exceptions thrown by the call
# into Python exceptions.
cdef unique_ptr[column] index_of(
lists_column_view lists,
scalar search_key,
) except +
23 changes: 21 additions & 2 deletions python/cudf/cudf/_lib/lists.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021, NVIDIA CORPORATION.
# Copyright (c) 2021-2022, NVIDIA CORPORATION.

from libcpp cimport bool
from libcpp.memory cimport make_shared, shared_ptr, unique_ptr
Expand Down Expand Up @@ -40,7 +40,7 @@ from cudf._lib.types cimport (

from cudf.core.dtypes import ListDtype

from cudf._lib.cpp.lists.contains cimport contains
from cudf._lib.cpp.lists.contains cimport contains, index_of as cpp_index_of
from cudf._lib.cpp.lists.extract cimport extract_list_element
from cudf._lib.utils cimport data_from_unique_ptr, table_view_from_table

Expand Down Expand Up @@ -162,6 +162,25 @@ def contains_scalar(Column col, object py_search_key):
return result


def index_of(Column col, object py_search_key):
    """
    Look up a scalar search key inside every list row of ``col``.

    Parameters
    ----------
    col : Column
        The list column to search.
    py_search_key : object
        Python-level scalar exposing a ``device_value`` attribute that
        holds the device-side scalar to search for.

    Returns
    -------
    Column
        A new column wrapping the result of libcudf's
        ``cudf::lists::index_of``. The accompanying tests expect one
        position per row, -1 where the key is not found, and null where
        the list row or the key is null.
    """
    cdef DeviceScalar c_key_owner = py_search_key.device_value

    # Hold the view behind a shared_ptr so it can be dereferenced inside
    # the nogil section below.
    cdef shared_ptr[lists_column_view] c_lists = (
        make_shared[lists_column_view](col.view())
    )
    cdef const scalar* c_key = c_key_owner.get_raw_ptr()

    cdef unique_ptr[column] c_positions

    # The libcudf call does not need the GIL, so it is released here.
    with nogil:
        c_positions = move(cpp_index_of(c_lists.get()[0], c_key[0]))

    return Column.from_unique_ptr(move(c_positions))


def concatenate_rows(tbl):
cdef unique_ptr[column] c_result

Expand Down
26 changes: 18 additions & 8 deletions python/cudf/cudf/core/column/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
count_elements,
drop_list_duplicates,
extract_element,
index_of,
sort_lists,
)
from cudf._lib.strings.convert.convert_lists import format_list_column
Expand Down Expand Up @@ -424,16 +425,25 @@ def contains(self, search_key: ScalarLike) -> ParentType:
)
except RuntimeError as e:
if (
"Type/Scale of search key does not"
"match list column element type" in str(e)
"Type/Scale of search key does not "
"match list column element type." in str(e)
):
raise TypeError(
"Type/Scale of search key does not"
"match list column element type"
) from e
raise TypeError(str(e)) from e
raise
else:
return res
return res

def index(self, search_key: ScalarLike) -> ParentType:
    """
    Return, for each list, the position of ``search_key`` within it.

    Parameters
    ----------
    search_key : scalar
        The value to look for; it is wrapped in a ``cudf.Scalar``
        before the lookup.

    Returns
    -------
    Series or Index
        The result of ``index_of`` on the underlying column, passed
        through ``_return_or_inplace``.

    Raises
    ------
    TypeError
        If the lookup fails with a ``RuntimeError`` whose message
        reports that the search key type/scale does not match the
        list column element type. Any other ``RuntimeError`` is
        re-raised unchanged.
    """
    key = cudf.Scalar(search_key)
    mismatch_msg = (
        "Type/Scale of search key does not "
        "match list column element type."
    )
    try:
        return self._return_or_inplace(index_of(self._column, key))
    except RuntimeError as err:
        # Only the type-mismatch failure is converted to TypeError;
        # everything else propagates as the original RuntimeError.
        if mismatch_msg not in str(err):
            raise
        raise TypeError(str(err)) from err

@property
def leaves(self) -> ParentType:
Expand Down
89 changes: 89 additions & 0 deletions python/cudf/cudf/tests/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,95 @@ def test_contains_null_search_key(data, expect):
assert_eq(expect, got)


@pytest.mark.parametrize(
    "data, scalar",
    [
        ([[9, 0, 2], [], [1, None, 0]], "x"),
        ([["z", "y", None], None, [None, "x"]], 5),
    ],
)
def test_contains_invalid(data, scalar):
    """``list.contains`` raises TypeError for a mismatched search key."""
    series = cudf.Series(data)
    expected_message = (
        "Type/Scale of search key does not "
        "match list column element type."
    )
    with pytest.raises(TypeError, match=expected_message):
        series.list.contains(scalar)


@pytest.mark.parametrize(
    "data, scalar, expect",
    [
        ([[1, 2, 3], [], [3, 4, 5]], 3, [2, -1, 0]),
        ([[1.0, 2.0, 3.0], None, [2.0, 5.0]], 2.0, [1, None, 0]),
        ([[None, "b", "c"], [], ["b", "e", "f"]], "f", [-1, -1, 2]),
        ([[-5, None, 8], None, []], -5, [0, None, -1]),
        ([[None, "x", None, "y"], ["z", "i", "j"]], "y", [3, -1]),
        (
            [["d", None, "e"], [None, "f"], []],
            cudf.Scalar(cudf.NA, "O"),
            [None, None, None],
        ),
        (
            [None, [10, 9, 8], [5, 8, None]],
            cudf.Scalar(cudf.NA, "int64"),
            [None, None, None],
        ),
    ],
)
def test_index(data, scalar, expect):
    """``list.index`` returns the expected int32 position for each row.

    The cases cover keys that are found, keys that are absent (-1),
    null list rows (null result) and null search keys (all-null result).
    """
    sr = cudf.Series(data)
    expect = cudf.Series(expect, dtype="int32")

    # Build the key with the list's element type so its dtype matches.
    search_key = cudf.Scalar(scalar, sr.dtype.element_type)
    got = sr.list.index(search_key)

    assert_eq(expect, got)


@pytest.mark.parametrize(
    "data, scalar",
    [
        ([[9, None, 8], [], [7, 6, 5]], "c"),
        ([["a", "b", "c"], None, [None, "d"]], 2),
    ],
)
def test_index_invalid(data, scalar):
    """``list.index`` raises TypeError for a mismatched search key."""
    series = cudf.Series(data)
    expected_message = (
        "Type/Scale of search key does not "
        "match list column element type."
    )
    with pytest.raises(TypeError, match=expected_message):
        series.list.index(scalar)


@pytest.mark.parametrize(
"row",
[
Expand Down

0 comments on commit bf4ffc9

Please sign in to comment.