Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable passing a sequence for the index argument to .list.get() #10564

Merged
merged 22 commits into from
Apr 12, 2022
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions python/cudf/cudf/_lib/cpp/lists/extract.pxd
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright (c) 2021, NVIDIA CORPORATION.
# Copyright (c) 2021-2022, NVIDIA CORPORATION.

from libcpp.memory cimport unique_ptr

from cudf._lib.cpp.column.column cimport column
from cudf._lib.cpp.column.column cimport column, column_view
from cudf._lib.cpp.lists.lists_column_view cimport lists_column_view
from cudf._lib.cpp.types cimport size_type

Expand All @@ -12,3 +12,7 @@ cdef extern from "cudf/lists/extract.hpp" namespace "cudf::lists" nogil:
const lists_column_view,
size_type
) except +
cdef unique_ptr[column] extract_list_element(
const lists_column_view,
column_view
) except +
20 changes: 18 additions & 2 deletions python/cudf/cudf/_lib/lists.pyx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021, NVIDIA CORPORATION.
# Copyright (c) 2021-2022, NVIDIA CORPORATION.

from libcpp cimport bool
from libcpp.memory cimport make_shared, shared_ptr, unique_ptr
Expand Down Expand Up @@ -126,7 +126,7 @@ def sort_lists(Column col, bool ascending, str na_position):
return Column.from_unique_ptr(move(c_result))


def extract_element(Column col, size_type index):
def extract_element_scalar(Column col, int index):
vyasr marked this conversation as resolved.
Show resolved Hide resolved
# shared_ptr required because lists_column_view has no default
# ctor
cdef shared_ptr[lists_column_view] list_view = (
Expand All @@ -142,6 +142,22 @@ def extract_element(Column col, size_type index):
return result


def extract_element_column(Column col, Column index):
cdef shared_ptr[lists_column_view] list_view = (
make_shared[lists_column_view](col.view())
)

cdef column_view index_view = index.view()

cdef unique_ptr[column] c_result

with nogil:
c_result = move(extract_list_element(list_view.get()[0], index_view))

result = Column.from_unique_ptr(move(c_result))
return result


def contains_scalar(Column col, object py_search_key):

cdef DeviceScalar search_key = py_search_key.device_value
Expand Down
45 changes: 34 additions & 11 deletions python/cudf/cudf/core/column/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pickle
from functools import cached_property
from typing import List, Optional, Sequence
from typing import List, Optional, Sequence, Union

import numpy as np
import pyarrow as pa
Expand All @@ -15,12 +15,17 @@
contains_scalar,
count_elements,
drop_list_duplicates,
extract_element,
extract_element_column,
extract_element_scalar,
sort_lists,
)
from cudf._lib.strings.convert.convert_lists import format_list_column
from cudf._typing import ColumnBinaryOperand, ColumnLike, Dtype, ScalarLike
from cudf.api.types import _is_non_decimal_numeric_dtype, is_list_dtype
from cudf.api.types import (
_is_non_decimal_numeric_dtype,
is_list_dtype,
is_scalar,
)
from cudf.core.buffer import Buffer
from cudf.core.column import ColumnBase, as_column, column
from cudf.core.column.methods import ColumnMethods, ParentType
Expand Down Expand Up @@ -338,18 +343,27 @@ def __init__(self, parent: ParentType):
super().__init__(parent=parent)

def get(
self, index: int, default: Optional[ScalarLike] = None
self,
index: int,
default: Optional[Union[ScalarLike, ColumnLike]] = None,
) -> ParentType:
"""
Extract element at the given index from each list.
Extract element at the given index from each list in a Series of lists.

If the index is out of bounds for any list,
return <NA> or, if provided, ``default``.
Thus, this method never raises an ``IndexError``.
``index`` can be an integer or a sequence of integers. If
``index`` is an integer, the element at position ``index`` is
extracted from each list. If ``index`` is a sequence, it must
be of the same length as the Series, and ``index[i]``
specifies the position of the element to extract from the
``i``-th list in the Series.

If the index is out of bounds for any list, return <NA> or, if
provided, ``default``. Thus, this method never raises an
``IndexError``.

Parameters
----------
index : int
index : int or sequence of ints
default : scalar, optional

Returns
Expand All @@ -372,14 +386,23 @@ def get(
2 6
dtype: int64

>>> s = cudf.Series([[1, 2], [3, 4, 5], [4, 5, 6]])
>>> s.list.get(2, default=0)
0 0
1 5
2 6
dtype: int64

>>> s.list.get([0, 1, 2])
0 1
1 4
2 6
dtype: int64
"""
out = extract_element(self._column, index)
if is_scalar(index):
out = extract_element_scalar(self._column, cudf.Scalar(index))
else:
index = as_column(index)
out = extract_element_column(self._column, as_column(index))
bdice marked this conversation as resolved.
Show resolved Hide resolved

if not (default is None or default is cudf.NA):
# determine rows for which `index` is out-of-bounds
Expand Down
12 changes: 12 additions & 0 deletions python/cudf/cudf/tests/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,18 @@ def test_get_default():
)


def test_get_ind_sequence():
# test .list.get() when `index` is a sequence
shwina marked this conversation as resolved.
Show resolved Hide resolved
sr = cudf.Series([[1, 2], [3, 4, 5], [6, 7, 8, 9]])
assert_eq(cudf.Series([1, 4, 8]), sr.list.get([0, 1, 2]))
assert_eq(cudf.Series([1, 4, 8]), sr.list.get(cudf.Series([0, 1, 2])))
assert_eq(cudf.Series([1, 4, 8]), sr.list.get(cudf.Series([0, 1, 2])))
shwina marked this conversation as resolved.
Show resolved Hide resolved
assert_eq(cudf.Series([cudf.NA, 5, cudf.NA]), sr.list.get([2, 2, -5]))
assert_eq(cudf.Series([0, 5, 0]), sr.list.get([2, 2, -5], default=0))
sr_nested = cudf.Series([[[1, 2], [3, 4], [5, 6]], [[5, 6], [7, 8]]])
assert_eq(cudf.Series([[1, 2], [7, 8]]), sr_nested.list.get([0, 1]))


@pytest.mark.parametrize(
"data, scalar, expect",
[
Expand Down
17 changes: 7 additions & 10 deletions python/dask_cudf/dask_cudf/tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,19 +381,16 @@ def test_contains(data, search_key):


@pytest.mark.parametrize(
"data, index, expectation",
"data, index",
[
(data_test_1(), 1, does_not_raise()),
(data_test_2(), 2, does_not_raise()),
(data_test_1(), 1),
(data_test_2(), 2),
],
)
def test_get(data, index, expectation):
with expectation:
expect = Series(data).list.get(index)

if expectation == does_not_raise():
ds = dgd.from_cudf(Series(data), 5)
assert_eq(expect, ds.list.get(index).compute())
def test_get(data, index):
expect = Series(data).list.get(index)
ds = dgd.from_cudf(Series(data), 5)
assert_eq(expect, ds.list.get(index).compute())


@pytest.mark.parametrize(
Expand Down