Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bindings for index_of with column search key #10696

Merged
merged 5 commits into from
Apr 27, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions python/cudf/cudf/_lib/cpp/lists/contains.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,8 @@ cdef extern from "cudf/lists/contains.hpp" namespace "cudf::lists" nogil:
lists_column_view lists,
scalar search_key,
) except +

cdef unique_ptr[column] index_of(
lists_column_view lists,
column_view search_keys,
) except +
20 changes: 19 additions & 1 deletion python/cudf/cudf/_lib/lists.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def contains_scalar(Column col, object py_search_key):
return result


def index_of(Column col, object py_search_key):
def index_of_scalar(Column col, object py_search_key):

cdef DeviceScalar search_key = py_search_key.device_value

Expand All @@ -195,6 +195,24 @@ def index_of(Column col, object py_search_key):
return Column.from_unique_ptr(move(c_result))


def index_of_column(Column col, Column search_keys):

cdef column_view keys_view = search_keys.view()

cdef shared_ptr[lists_column_view] list_view = (
make_shared[lists_column_view](col.view())
)

cdef unique_ptr[column] c_result

with nogil:
c_result = move(cpp_index_of(
list_view.get()[0],
keys_view,
))
return Column.from_unique_ptr(move(c_result))


def concatenate_rows(list source_columns):
cdef unique_ptr[column] c_result

Expand Down
58 changes: 54 additions & 4 deletions python/cudf/cudf/core/column/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
drop_list_duplicates,
extract_element_column,
extract_element_scalar,
index_of,
index_of_column,
index_of_scalar,
sort_lists,
)
from cudf._lib.strings.convert.convert_lists import format_list_column
Expand Down Expand Up @@ -453,10 +454,59 @@ def contains(self, search_key: ScalarLike) -> ParentType:
raise
return res

def index(self, search_key: ScalarLike) -> ParentType:
search_key = cudf.Scalar(search_key)
def index(self, search_key: Union[ScalarLike, ColumnLike]) -> ParentType:
bdice marked this conversation as resolved.
Show resolved Hide resolved
"""
Return integers representing the index of the search key for each row.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first line "brief" should be followed by a blank line before the longer summary.

Suggested change
Return integers representing the index of the search key for each row.
Return integers representing the index of the search key for each row.

If ``search_key`` is a sequence, it must be the same length as the
Series and ``search_key[i]`` represents the search key for the
``i``-th row of the Series.

If the search key is not contained in a row, return -1.
If either the row or the search key are null, return <NA>.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the search key is contained multiple times, does this return the smallest matching index?

Suggested change
If the search key is not contained in a row, return -1.
If either the row or the search key are null, return <NA>.
If the search key is not contained in a row, return -1.
If either the row or the search key are null, return <NA>.
If the search key is contained multiple times, the smallest matching
index is returned.


Parameters
----------
search_key : scalar or sequence of scalars
Element or elements being searched for in each row of the list
column

Returns
-------
Series or Index

Examples
--------
>>> s = cudf.Series([[1, 2, 3], [3, 4, 5], [4, 5, 6]])
>>> s.list.index(4)
0 -1
1 1
2 0
dtype: int32

>>> s = cudf.Series([["a", "b", "c"], ["x", "y", "z"]])
>>> s.list.index(["b", "z"])
0 1
1 2
dtype: int32

>>> s = cudf.Series([[4, 5, 6], None, [-3, -2, -1]])
>>> s.list.index([None, 3, -2])
0 <NA>
1 <NA>
2 1
dtype: int32
"""

try:
res = self._return_or_inplace(index_of(self._column, search_key))
if is_scalar(search_key):
res = self._return_or_inplace(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no post-processing needed so I would return directly instead of saving a variable res and returning later.

Suggested change
res = self._return_or_inplace(
return self._return_or_inplace(

index_of_scalar(self._column, cudf.Scalar(search_key))
)
else:
res = self._return_or_inplace(
index_of_column(self._column, as_column(search_key))
)

except RuntimeError as e:
if (
"Type/Scale of search key does not "
Expand Down
33 changes: 27 additions & 6 deletions python/cudf/cudf/tests/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import cudf
from cudf import NA
from cudf._lib.copying import get_element
from cudf.api.types import is_scalar
from cudf.testing._utils import (
DATETIME_TYPES,
NUMERIC_TYPES,
Expand Down Expand Up @@ -425,7 +426,7 @@ def test_contains_invalid(data, scalar):


@pytest.mark.parametrize(
"data, scalar, expect",
"data, search_key, expect",
[
(
[[1, 2, 3], [], [3, 4, 5]],
Expand All @@ -448,6 +449,16 @@ def test_contains_invalid(data, scalar):
"y",
[3, -1],
),
(
[["h", "a", None], ["t", "g"]],
["a", "b"],
[1, -1],
),
(
[None, ["h", "i"], ["p", "k", "z"]],
["x", None, "z"],
[None, None, 2],
),
(
[["d", None, "e"], [None, "f"], []],
cudf.Scalar(cudf.NA, "O"),
Expand All @@ -460,15 +471,21 @@ def test_contains_invalid(data, scalar):
),
],
)
def test_index(data, scalar, expect):
def test_index(data, search_key, expect):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a test case for multi-level nested data? (if that is supported)

sr = cudf.Series({"a": [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]})
sr.list.index([[1, 2], [7, 8]])  # returns [0, 1]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still curious about multi-level nesting here. If multi-level nesting is supported, we'll need to revise a few other items as well. e.g. is_scalar might not be the appropriate check if "list scalars" are provided to check against a list of lists -- scalar-like input would have one fewer dimension / nested level that the input column, while column-like input would have an equal number of nested levels.

sr = cudf.Series(data)
expect = cudf.Series(expect, dtype="int32")
got = sr.list.index(cudf.Scalar(scalar, sr.dtype.element_type))
if is_scalar(search_key):
got = sr.list.index(cudf.Scalar(search_key, sr.dtype.element_type))
else:
got = sr.list.index(
cudf.Series(search_key, dtype=sr.dtype.element_type)
)

assert_eq(expect, got)


@pytest.mark.parametrize(
"data, scalar",
"data, search_key",
[
(
[[9, None, 8], [], [7, 6, 5]],
Expand All @@ -478,16 +495,20 @@ def test_index(data, scalar, expect):
[["a", "b", "c"], None, [None, "d"]],
2,
),
(
[["e", "s"], ["t", "w"]],
[5, 6],
),
],
)
def test_index_invalid(data, scalar):
def test_index_invalid(data, search_key):
Copy link
Contributor

@bdice bdice Apr 21, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a test for the invalid case where the search key is not the right length? e.g. len(sr) != len(search_key)

sr = cudf.Series(data)
with pytest.raises(
TypeError,
match="Type/Scale of search key does not "
"match list column element type.",
):
sr.list.index(scalar)
sr.list.index(search_key)


@pytest.mark.parametrize(
Expand Down