Skip to content

Commit

Permalink
Enable passing a sequence for the index argument to .list.get() (#…
Browse files Browse the repository at this point in the history
…10564)

Closes #10552.

Depends on #10547

Authors:
  - Ashwin Srinath (https://github.com/shwina)

Approvers:
  - Bradley Dice (https://github.com/bdice)
  - Vyas Ramasubramani (https://github.com/vyasr)
  - Charles Blackmon-Luca (https://github.com/charlesbluca)

URL: #10564
  • Loading branch information
shwina authored Apr 12, 2022
1 parent 3c13ef1 commit 2348277
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 24 deletions.
8 changes: 6 additions & 2 deletions python/cudf/cudf/_lib/cpp/lists/extract.pxd
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright (c) 2021, NVIDIA CORPORATION.
# Copyright (c) 2021-2022, NVIDIA CORPORATION.

from libcpp.memory cimport unique_ptr

from cudf._lib.cpp.column.column cimport column
from cudf._lib.cpp.column.column cimport column, column_view
from cudf._lib.cpp.lists.lists_column_view cimport lists_column_view
from cudf._lib.cpp.types cimport size_type

Expand All @@ -12,3 +12,7 @@ cdef extern from "cudf/lists/extract.hpp" namespace "cudf::lists" nogil:
const lists_column_view,
size_type
) except +
cdef unique_ptr[column] extract_list_element(
const lists_column_view,
column_view
) except +
18 changes: 17 additions & 1 deletion python/cudf/cudf/_lib/lists.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def sort_lists(Column col, bool ascending, str na_position):
return Column.from_unique_ptr(move(c_result))


def extract_element(Column col, size_type index):
def extract_element_scalar(Column col, size_type index):
# shared_ptr required because lists_column_view has no default
# ctor
cdef shared_ptr[lists_column_view] list_view = (
Expand All @@ -142,6 +142,22 @@ def extract_element(Column col, size_type index):
return result


def extract_element_column(Column col, Column index):
cdef shared_ptr[lists_column_view] list_view = (
make_shared[lists_column_view](col.view())
)

cdef column_view index_view = index.view()

cdef unique_ptr[column] c_result

with nogil:
c_result = move(extract_list_element(list_view.get()[0], index_view))

result = Column.from_unique_ptr(move(c_result))
return result


def contains_scalar(Column col, object py_search_key):

cdef DeviceScalar search_key = py_search_key.device_value
Expand Down
45 changes: 34 additions & 11 deletions python/cudf/cudf/core/column/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pickle
from functools import cached_property
from typing import List, Optional, Sequence
from typing import List, Optional, Sequence, Union

import numpy as np
import pyarrow as pa
Expand All @@ -15,13 +15,18 @@
contains_scalar,
count_elements,
drop_list_duplicates,
extract_element,
extract_element_column,
extract_element_scalar,
index_of,
sort_lists,
)
from cudf._lib.strings.convert.convert_lists import format_list_column
from cudf._typing import ColumnBinaryOperand, ColumnLike, Dtype, ScalarLike
from cudf.api.types import _is_non_decimal_numeric_dtype, is_list_dtype
from cudf.api.types import (
_is_non_decimal_numeric_dtype,
is_list_dtype,
is_scalar,
)
from cudf.core.buffer import Buffer
from cudf.core.column import ColumnBase, as_column, column
from cudf.core.column.methods import ColumnMethods, ParentType
Expand Down Expand Up @@ -339,18 +344,27 @@ def __init__(self, parent: ParentType):
super().__init__(parent=parent)

def get(
self, index: int, default: Optional[ScalarLike] = None
self,
index: int,
default: Optional[Union[ScalarLike, ColumnLike]] = None,
) -> ParentType:
"""
Extract element at the given index from each list.
Extract element at the given index from each list in a Series of lists.
If the index is out of bounds for any list,
return <NA> or, if provided, ``default``.
Thus, this method never raises an ``IndexError``.
``index`` can be an integer or a sequence of integers. If
``index`` is an integer, the element at position ``index`` is
extracted from each list. If ``index`` is a sequence, it must
be of the same length as the Series, and ``index[i]``
specifies the position of the element to extract from the
``i``-th list in the Series.
If the index is out of bounds for any list, return <NA> or, if
provided, ``default``. Thus, this method never raises an
``IndexError``.
Parameters
----------
index : int
index : int or sequence of ints
default : scalar, optional
Returns
Expand All @@ -373,14 +387,23 @@ def get(
2 6
dtype: int64
>>> s = cudf.Series([[1, 2], [3, 4, 5], [4, 5, 6]])
>>> s.list.get(2, default=0)
0 0
1 5
2 6
dtype: int64
>>> s.list.get([0, 1, 2])
0 1
1 4
2 6
dtype: int64
"""
out = extract_element(self._column, index)
if is_scalar(index):
out = extract_element_scalar(self._column, cudf.Scalar(index))
else:
index = as_column(index)
out = extract_element_column(self._column, as_column(index))

if not (default is None or default is cudf.NA):
# determine rows for which `index` is out-of-bounds
Expand Down
11 changes: 11 additions & 0 deletions python/cudf/cudf/tests/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,17 @@ def test_get_default():
)


def test_get_ind_sequence():
# test .list.get() when `index` is a sequence
sr = cudf.Series([[1, 2], [3, 4, 5], [6, 7, 8, 9]])
assert_eq(cudf.Series([1, 4, 8]), sr.list.get([0, 1, 2]))
assert_eq(cudf.Series([1, 4, 8]), sr.list.get(cudf.Series([0, 1, 2])))
assert_eq(cudf.Series([cudf.NA, 5, cudf.NA]), sr.list.get([2, 2, -5]))
assert_eq(cudf.Series([0, 5, 0]), sr.list.get([2, 2, -5], default=0))
sr_nested = cudf.Series([[[1, 2], [3, 4], [5, 6]], [[5, 6], [7, 8]]])
assert_eq(cudf.Series([[1, 2], [7, 8]]), sr_nested.list.get([0, 1]))


@pytest.mark.parametrize(
"data, scalar, expect",
[
Expand Down
17 changes: 7 additions & 10 deletions python/dask_cudf/dask_cudf/tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,19 +381,16 @@ def test_contains(data, search_key):


@pytest.mark.parametrize(
"data, index, expectation",
"data, index",
[
(data_test_1(), 1, does_not_raise()),
(data_test_2(), 2, does_not_raise()),
(data_test_1(), 1),
(data_test_2(), 2),
],
)
def test_get(data, index, expectation):
with expectation:
expect = Series(data).list.get(index)

if expectation == does_not_raise():
ds = dgd.from_cudf(Series(data), 5)
assert_eq(expect, ds.list.get(index).compute())
def test_get(data, index):
expect = Series(data).list.get(index)
ds = dgd.from_cudf(Series(data), 5)
assert_eq(expect, ds.list.get(index).compute())


@pytest.mark.parametrize(
Expand Down

0 comments on commit 2348277

Please sign in to comment.