Skip to content

Commit

Permalink
Add bindings for index_of with column search key (#10696)
Browse files Browse the repository at this point in the history
This adds bindings for `index_of` to enable using `list.index` with a Series of search keys.

Closes #10692 

cc: @randerzander

Authors:
  - https://github.com/ChrisJar

Approvers:
  - Ashwin Srinath (https://github.com/shwina)
  - Bradley Dice (https://github.com/bdice)

URL: #10696
  • Loading branch information
ChrisJar authored Apr 27, 2022
1 parent dc1435b commit 09995a5
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 12 deletions.
5 changes: 5 additions & 0 deletions python/cudf/cudf/_lib/cpp/lists/contains.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,8 @@ cdef extern from "cudf/lists/contains.hpp" namespace "cudf::lists" nogil:
lists_column_view lists,
scalar search_key,
) except +

cdef unique_ptr[column] index_of(
lists_column_view lists,
column_view search_keys,
) except +
20 changes: 19 additions & 1 deletion python/cudf/cudf/_lib/lists.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def contains_scalar(Column col, object py_search_key):
return result


def index_of(Column col, object py_search_key):
def index_of_scalar(Column col, object py_search_key):

cdef DeviceScalar search_key = py_search_key.device_value

Expand All @@ -195,6 +195,24 @@ def index_of(Column col, object py_search_key):
return Column.from_unique_ptr(move(c_result))


def index_of_column(Column col, Column search_keys):

cdef column_view keys_view = search_keys.view()

cdef shared_ptr[lists_column_view] list_view = (
make_shared[lists_column_view](col.view())
)

cdef unique_ptr[column] c_result

with nogil:
c_result = move(cpp_index_of(
list_view.get()[0],
keys_view,
))
return Column.from_unique_ptr(move(c_result))


def concatenate_rows(list source_columns):
cdef unique_ptr[column] c_result

Expand Down
61 changes: 56 additions & 5 deletions python/cudf/cudf/core/column/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
drop_list_duplicates,
extract_element_column,
extract_element_scalar,
index_of,
index_of_column,
index_of_scalar,
sort_lists,
)
from cudf._lib.strings.convert.convert_lists import format_list_column
Expand Down Expand Up @@ -463,18 +464,68 @@ def contains(self, search_key: ScalarLike) -> ParentType:
raise
return res

def index(self, search_key: ScalarLike) -> ParentType:
search_key = cudf.Scalar(search_key)
def index(self, search_key: Union[ScalarLike, ColumnLike]) -> ParentType:
"""
Returns integers representing the index of the search key for each row.
If ``search_key`` is a sequence, it must be the same length as the
Series and ``search_key[i]`` represents the search key for the
``i``-th row of the Series.
If the search key is not contained in a row, -1 is returned. If either
the row or the search key are null, <NA> is returned. If the search key
is contained multiple times, the smallest matching index is returned.
Parameters
----------
search_key : scalar or sequence of scalars
Element or elements being searched for in each row of the list
column
Returns
-------
Series or Index
Examples
--------
>>> s = cudf.Series([[1, 2, 3], [3, 4, 5], [4, 5, 6]])
>>> s.list.index(4)
0 -1
1 1
2 0
dtype: int32
>>> s = cudf.Series([["a", "b", "c"], ["x", "y", "z"]])
>>> s.list.index(["b", "z"])
0 1
1 2
dtype: int32
>>> s = cudf.Series([[4, 5, 6], None, [-3, -2, -1]])
>>> s.list.index([None, 3, -2])
0 <NA>
1 <NA>
2 1
dtype: int32
"""

try:
res = self._return_or_inplace(index_of(self._column, search_key))
if is_scalar(search_key):
return self._return_or_inplace(
index_of_scalar(self._column, cudf.Scalar(search_key))
)
else:
return self._return_or_inplace(
index_of_column(self._column, as_column(search_key))
)

except RuntimeError as e:
if (
"Type/Scale of search key does not "
"match list column element type." in str(e)
):
raise TypeError(str(e)) from e
raise
return res

@property
def leaves(self) -> ParentType:
Expand Down
55 changes: 49 additions & 6 deletions python/cudf/cudf/tests/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import cudf
from cudf import NA
from cudf._lib.copying import get_element
from cudf.api.types import is_scalar
from cudf.testing._utils import (
DATETIME_TYPES,
NUMERIC_TYPES,
Expand Down Expand Up @@ -425,7 +426,7 @@ def test_contains_invalid(data, scalar):


@pytest.mark.parametrize(
"data, scalar, expect",
"data, search_key, expect",
[
(
[[1, 2, 3], [], [3, 4, 5]],
Expand All @@ -448,6 +449,16 @@ def test_contains_invalid(data, scalar):
"y",
[3, -1],
),
(
[["h", "a", None], ["t", "g"]],
["a", "b"],
[1, -1],
),
(
[None, ["h", "i"], ["p", "k", "z"]],
["x", None, "z"],
[None, None, 2],
),
(
[["d", None, "e"], [None, "f"], []],
cudf.Scalar(cudf.NA, "O"),
Expand All @@ -460,15 +471,21 @@ def test_contains_invalid(data, scalar):
),
],
)
def test_index(data, scalar, expect):
def test_index(data, search_key, expect):
sr = cudf.Series(data)
expect = cudf.Series(expect, dtype="int32")
got = sr.list.index(cudf.Scalar(scalar, sr.dtype.element_type))
if is_scalar(search_key):
got = sr.list.index(cudf.Scalar(search_key, sr.dtype.element_type))
else:
got = sr.list.index(
cudf.Series(search_key, dtype=sr.dtype.element_type)
)

assert_eq(expect, got)


@pytest.mark.parametrize(
"data, scalar",
"data, search_key",
[
(
[[9, None, 8], [], [7, 6, 5]],
Expand All @@ -478,16 +495,42 @@ def test_index(data, scalar, expect):
[["a", "b", "c"], None, [None, "d"]],
2,
),
(
[["e", "s"], ["t", "w"]],
[5, 6],
),
],
)
def test_index_invalid(data, scalar):
def test_index_invalid_type(data, search_key):
sr = cudf.Series(data)
with pytest.raises(
TypeError,
match="Type/Scale of search key does not "
"match list column element type.",
):
sr.list.index(scalar)
sr.list.index(search_key)


@pytest.mark.parametrize(
"data, search_key",
[
(
[[5, 8], [2, 6]],
[8, 2, 4],
),
(
[["h", "j"], ["p", None], ["t", "z"]],
["j", "a"],
),
],
)
def test_index_invalid_length(data, search_key):
sr = cudf.Series(data)
with pytest.raises(
RuntimeError,
match="Number of search keys must match list column size.",
):
sr.list.index(search_key)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 09995a5

Please sign in to comment.