diff --git a/python/cudf/cudf/_lib/cpp/lists/contains.pxd b/python/cudf/cudf/_lib/cpp/lists/contains.pxd index 46aea37643f..e3cb01721a0 100644 --- a/python/cudf/cudf/_lib/cpp/lists/contains.pxd +++ b/python/cudf/cudf/_lib/cpp/lists/contains.pxd @@ -18,3 +18,8 @@ cdef extern from "cudf/lists/contains.hpp" namespace "cudf::lists" nogil: lists_column_view lists, scalar search_key, ) except + + + cdef unique_ptr[column] index_of( + lists_column_view lists, + column_view search_keys, + ) except + diff --git a/python/cudf/cudf/_lib/lists.pyx b/python/cudf/cudf/_lib/lists.pyx index e5a705ab603..025fb0665d3 100644 --- a/python/cudf/cudf/_lib/lists.pyx +++ b/python/cudf/cudf/_lib/lists.pyx @@ -176,7 +176,7 @@ def contains_scalar(Column col, object py_search_key): return result -def index_of(Column col, object py_search_key): +def index_of_scalar(Column col, object py_search_key): cdef DeviceScalar search_key = py_search_key.device_value @@ -195,6 +195,24 @@ def index_of(Column col, object py_search_key): return Column.from_unique_ptr(move(c_result)) +def index_of_column(Column col, Column search_keys): + + cdef column_view keys_view = search_keys.view() + + cdef shared_ptr[lists_column_view] list_view = ( + make_shared[lists_column_view](col.view()) + ) + + cdef unique_ptr[column] c_result + + with nogil: + c_result = move(cpp_index_of( + list_view.get()[0], + keys_view, + )) + return Column.from_unique_ptr(move(c_result)) + + def concatenate_rows(list source_columns): cdef unique_ptr[column] c_result diff --git a/python/cudf/cudf/core/column/lists.py b/python/cudf/cudf/core/column/lists.py index df6aaa91a2b..2964378d114 100644 --- a/python/cudf/cudf/core/column/lists.py +++ b/python/cudf/cudf/core/column/lists.py @@ -17,7 +17,8 @@ drop_list_duplicates, extract_element_column, extract_element_scalar, - index_of, + index_of_column, + index_of_scalar, sort_lists, ) from cudf._lib.strings.convert.convert_lists import format_list_column @@ -463,10 +464,61 @@ def contains(self, search_key: ScalarLike) -> ParentType: raise return res - def index(self, search_key: ScalarLike) -> ParentType: - search_key = cudf.Scalar(search_key) + def index(self, search_key: Union[ScalarLike, ColumnLike]) -> ParentType: + """ + Returns integers representing the index of the search key for each row. + + If ``search_key`` is a sequence, it must be the same length as the + Series and ``search_key[i]`` represents the search key for the + ``i``-th row of the Series. + + If the search key is not contained in a row, -1 is returned. If either + the row or the search key are null, is returned. If the search key + is contained multiple times, the smallest matching index is returned. + + Parameters + ---------- + search_key : scalar or sequence of scalars + Element or elements being searched for in each row of the list + column + + Returns + ------- + Series or Index + + Examples + -------- + >>> s = cudf.Series([[1, 2, 3], [3, 4, 5], [4, 5, 6]]) + >>> s.list.index(4) + 0 -1 + 1 1 + 2 0 + dtype: int32 + + >>> s = cudf.Series([["a", "b", "c"], ["x", "y", "z"]]) + >>> s.list.index(["b", "z"]) + 0 1 + 1 2 + dtype: int32 + + >>> s = cudf.Series([[4, 5, 6], None, [-3, -2, -1]]) + >>> s.list.index([None, 3, -2]) + 0 + 1 + 2 1 + dtype: int32 + """ + try: - res = self._return_or_inplace(index_of(self._column, search_key)) + if is_scalar(search_key): + return self._return_or_inplace( + index_of_scalar(self._column, cudf.Scalar(search_key)) + ) + else: + return self._return_or_inplace( + index_of_column(self._column, as_column(search_key)) + ) + except RuntimeError as e: if ( "Type/Scale of search key does not " @@ -474,7 +526,6 @@ def index(self, search_key: ScalarLike) -> ParentType: ): raise TypeError(str(e)) from e raise - return res @property def leaves(self) -> ParentType: diff --git a/python/cudf/cudf/tests/test_list.py b/python/cudf/cudf/tests/test_list.py index c21e1a0f61f..09eee3520e5 100644 --- a/python/cudf/cudf/tests/test_list.py +++ b/python/cudf/cudf/tests/test_list.py @@ -11,6 +11,7 @@ import cudf from cudf import NA from cudf._lib.copying import get_element +from cudf.api.types import is_scalar from cudf.testing._utils import ( DATETIME_TYPES, NUMERIC_TYPES, @@ -425,7 +426,7 @@ def test_contains_invalid(data, scalar): @pytest.mark.parametrize( - "data, scalar, expect", + "data, search_key, expect", [ ( [[1, 2, 3], [], [3, 4, 5]], @@ -448,6 +449,16 @@ def test_contains_invalid(data, scalar): "y", [3, -1], ), + ( + [["h", "a", None], ["t", "g"]], + ["a", "b"], + [1, -1], + ), + ( + [None, ["h", "i"], ["p", "k", "z"]], + ["x", None, "z"], + [None, None, 2], + ), ( [["d", None, "e"], [None, "f"], []], cudf.Scalar(cudf.NA, "O"), @@ -460,15 +471,21 @@ def test_contains_invalid(data, scalar): ), ], ) -def test_index(data, scalar, expect): +def test_index(data, search_key, expect): sr = cudf.Series(data) expect = cudf.Series(expect, dtype="int32") - got = sr.list.index(cudf.Scalar(scalar, sr.dtype.element_type)) + if is_scalar(search_key): + got = sr.list.index(cudf.Scalar(search_key, sr.dtype.element_type)) + else: + got = sr.list.index( + cudf.Series(search_key, dtype=sr.dtype.element_type) + ) + assert_eq(expect, got) @pytest.mark.parametrize( - "data, scalar", + "data, search_key", [ ( [[9, None, 8], [], [7, 6, 5]], @@ -478,16 +495,42 @@ def test_index(data, scalar, expect): [["a", "b", "c"], None, [None, "d"]], 2, ), + ( + [["e", "s"], ["t", "w"]], + [5, 6], + ), ], ) -def test_index_invalid(data, scalar): +def test_index_invalid_type(data, search_key): sr = cudf.Series(data) with pytest.raises( TypeError, match="Type/Scale of search key does not " "match list column element type.", ): - sr.list.index(scalar) + sr.list.index(search_key) + + +@pytest.mark.parametrize( + "data, search_key", + [ + ( + [[5, 8], [2, 6]], + [8, 2, 4], + ), + ( + [["h", "j"], ["p", None], ["t", "z"]], + ["j", "a"], + ), + ], +) +def test_index_invalid_length(data, search_key): + sr = cudf.Series(data) + with pytest.raises( + RuntimeError, + match="Number of search keys must match list column size.", + ): + sr.list.index(search_key) @pytest.mark.parametrize(