Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Python bindings for lists::contains #7547

Merged
merged 17 commits into from
Mar 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions python/cudf/cudf/_lib/cpp/lists/contains.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) 2021, NVIDIA CORPORATION.

from libcpp.memory cimport unique_ptr
from cudf._lib.cpp.scalar.scalar cimport scalar

from cudf._lib.cpp.column.column cimport column
from cudf._lib.cpp.lists.lists_column_view cimport lists_column_view

from cudf._lib.cpp.column.column_view cimport column_view

# Cython declaration of the list-membership routine exposed by the C++
# header "cudf/lists/contains.hpp" (C++ namespace ``cudf::lists``).
# ``nogil`` on the extern block declares the function as callable without
# holding the Python GIL.
cdef extern from "cudf/lists/contains.hpp" namespace "cudf::lists" nogil:
# contains(lists, search_key): takes a view over a lists column and a
# scalar search key, and returns a ``unique_ptr``-owned result column.
# ``except +`` tells Cython to convert a thrown C++ exception into a
# Python exception instead of letting it propagate through C code.
cdef unique_ptr[column] contains(
lists_column_view lists,
scalar search_key,
) except +
24 changes: 24 additions & 0 deletions python/cudf/cudf/_lib/lists.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ from cudf._lib.cpp.lists.lists_column_view cimport lists_column_view
from cudf._lib.cpp.column.column_view cimport column_view
from cudf._lib.cpp.column.column cimport column

from cudf._lib.scalar cimport DeviceScalar
from cudf._lib.cpp.scalar.scalar cimport scalar

from cudf._lib.cpp.table.table cimport table
from cudf._lib.cpp.table.table_view cimport table_view
from cudf._lib.cpp.types cimport size_type, order, null_order
Expand All @@ -29,6 +32,8 @@ from cudf._lib.types cimport (
)
from cudf.core.dtypes import ListDtype

from cudf._lib.cpp.lists.contains cimport contains

from cudf._lib.cpp.lists.extract cimport extract_list_element


Expand Down Expand Up @@ -93,10 +98,29 @@ def extract_element(Column col, size_type index):
cdef shared_ptr[lists_column_view] list_view = (
make_shared[lists_column_view](col.view())
)

cdef unique_ptr[column] c_result

with nogil:
c_result = move(extract_list_element(list_view.get()[0], index))

result = Column.from_unique_ptr(move(c_result))
return result


def contains_scalar(Column col, DeviceScalar search_key):
    """
    Run the C++ ``contains`` routine on a list column with a scalar key.

    Parameters
    ----------
    col : Column
        Column whose view is wrapped in a ``lists_column_view``.
    search_key : DeviceScalar
        Scalar whose underlying C++ ``scalar`` is passed as the search key.

    Returns
    -------
    Column
        Column built from the ``unique_ptr[column]`` returned by the
        C++ call.
    """
    # Borrow the raw C++ scalar; ``search_key`` keeps it alive for the call.
    cdef const scalar* c_search_key = search_key.get_raw_ptr()
    # The lists view is held through a shared_ptr and dereferenced at the
    # call site, the same way ``extract_element`` in this module does it.
    cdef shared_ptr[lists_column_view] c_lists = (
        make_shared[lists_column_view](col.view())
    )
    cdef unique_ptr[column] c_out

    # The C++ call touches no Python objects, so release the GIL around it.
    with nogil:
        c_out = move(contains(
            c_lists.get()[0],
            c_search_key[0],
        ))

    return Column.from_unique_ptr(move(c_out))
45 changes: 44 additions & 1 deletion python/cudf/cudf/core/column/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@

import cudf
from cudf._lib.copying import segmented_gather
from cudf._lib.lists import count_elements, extract_element, sort_lists
from cudf._lib.lists import (
contains_scalar,
count_elements,
extract_element,
sort_lists,
)
from cudf.core.buffer import Buffer
from cudf.core.column import ColumnBase, as_column, column
from cudf.core.column.methods import ColumnMethodsMixin
Expand Down Expand Up @@ -210,6 +215,44 @@ def get(self, index):
else:
raise IndexError("list index out of range")

def contains(self, search_key):
    """
    Create a column of bool values indicating whether the specified
    scalar is an element of each row of a list column.

    Parameters
    ----------
    search_key : scalar
        Element being searched for in each row of the list column.
        Either a ``cudf.Scalar`` or a plain value; a plain value is
        wrapped in a ``cudf.Scalar`` before the search.

    Returns
    -------
    Column

    Raises
    ------
    TypeError
        If the type/scale of ``search_key`` does not match the element
        type of the list column.

    Examples
    --------
    >>> s = cudf.Series([[1, 2, 3], [3, 4, 5], [4, 5, 6]])
    >>> s.list.contains(4)
    0    False
    1     True
    2     True
    dtype: bool
    """
    # ``device_value`` only exists on cudf.Scalar, so wrap plain Python /
    # NumPy values to make calls such as ``s.list.contains(4)`` work.
    if not isinstance(search_key, cudf.Scalar):
        search_key = cudf.Scalar(search_key)

    # NOTE: the trailing space after "not" is required. Adjacent string
    # literals are concatenated verbatim, and without it the text becomes
    # "does notmatch", which can never match the underlying error message.
    mismatch_msg = (
        "Type/Scale of search key does not "
        "match list column element type"
    )

    try:
        return self._return_or_inplace(
            contains_scalar(self._column, search_key.device_value)
        )
    except RuntimeError as e:
        # Surface a dtype mismatch as the more idiomatic TypeError; any
        # other runtime failure is re-raised untouched.
        if mismatch_msg in str(e):
            raise TypeError(mismatch_msg) from e
        raise

@property
def leaves(self):
"""
Expand Down
35 changes: 35 additions & 0 deletions python/cudf/cudf/tests/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,38 @@ def test_get_nulls():
with pytest.raises(IndexError, match="list index out of range"):
sr = cudf.Series([[], [], []])
sr.list.get(100)


@pytest.mark.parametrize(
    "data, scalar, expect",
    [
        # key present in one row, other row empty
        ([[1, 2, 3], []], 1, [True, False]),
        # key absent from every row
        ([[1, 2, 3], [], [3, 4, 5]], 6, [False, False, False]),
        # a null row produces a null result
        ([[1.0, 2.0, 3.0], None, []], 2.0, [True, None, False]),
        # key found in rows that also contain a null element
        ([[None, "b", "c"], [], ["b", "e", "f"]], "b", [True, False, True]),
        # key not found in a row containing a null element -> null
        ([[None, 2, 3], None, []], 1, [None, None, False]),
        ([[None, "b", "c"], [], ["b", "e", "f"]], "d", [None, False, False]),
    ],
)
def test_contains_scalar(data, scalar, expect):
    """``list.contains`` with a non-null key matches the expected mask."""
    lists = cudf.Series(data)
    # Build the key with the list's element dtype so the types line up.
    search_key = cudf.Scalar(scalar, lists.dtype.element_type)
    expected = cudf.Series(expect)
    actual = lists.list.contains(search_key)
    assert_eq(expected, actual)


@pytest.mark.parametrize(
    "data, expect",
    [
        ([[1, 2, 3], []], [None, None]),
        ([[1.0, 2.0, 3.0], None, []], [None, None, None]),
        ([[None, 2, 3], [], None], [None, None, None]),
        ([[1, 2, 3], [3, 4, 5]], [None, None]),
        ([[], [], []], [None, None, None]),
    ],
)
def test_contains_null_search_key(data, expect):
    """A null search key is expected to give a null result for every row."""
    lists = cudf.Series(data)
    # Null scalar typed with the list's element dtype.
    null_key = cudf.Scalar(cudf.NA, lists.dtype.element_type)
    expected = cudf.Series(expect, dtype="bool")
    actual = lists.list.contains(null_key)
    assert_eq(expected, actual)