diff --git a/python/cudf/cudf/_lib/cpp/lists/extract.pxd b/python/cudf/cudf/_lib/cpp/lists/extract.pxd index a023f728989..93a886d7268 100644 --- a/python/cudf/cudf/_lib/cpp/lists/extract.pxd +++ b/python/cudf/cudf/_lib/cpp/lists/extract.pxd @@ -1,8 +1,8 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. +# Copyright (c) 2021-2022, NVIDIA CORPORATION. from libcpp.memory cimport unique_ptr -from cudf._lib.cpp.column.column cimport column +from cudf._lib.cpp.column.column cimport column, column_view from cudf._lib.cpp.lists.lists_column_view cimport lists_column_view from cudf._lib.cpp.types cimport size_type @@ -12,3 +12,7 @@ cdef extern from "cudf/lists/extract.hpp" namespace "cudf::lists" nogil: const lists_column_view, size_type ) except + + cdef unique_ptr[column] extract_list_element( + const lists_column_view, + column_view + ) except + diff --git a/python/cudf/cudf/_lib/lists.pyx b/python/cudf/cudf/_lib/lists.pyx index 702cf86a995..523686fafe6 100644 --- a/python/cudf/cudf/_lib/lists.pyx +++ b/python/cudf/cudf/_lib/lists.pyx @@ -126,7 +126,7 @@ def sort_lists(Column col, bool ascending, str na_position): return Column.from_unique_ptr(move(c_result)) -def extract_element(Column col, size_type index): +def extract_element_scalar(Column col, size_type index): # shared_ptr required because lists_column_view has no default # ctor cdef shared_ptr[lists_column_view] list_view = ( @@ -142,6 +142,22 @@ def extract_element(Column col, size_type index): return result +def extract_element_column(Column col, Column index): + cdef shared_ptr[lists_column_view] list_view = ( + make_shared[lists_column_view](col.view()) + ) + + cdef column_view index_view = index.view() + + cdef unique_ptr[column] c_result + + with nogil: + c_result = move(extract_list_element(list_view.get()[0], index_view)) + + result = Column.from_unique_ptr(move(c_result)) + return result + + def contains_scalar(Column col, object py_search_key): cdef DeviceScalar search_key = py_search_key.device_value diff 
--git a/python/cudf/cudf/core/column/lists.py b/python/cudf/cudf/core/column/lists.py index 1c9b394d70d..8578bfe8147 100644 --- a/python/cudf/cudf/core/column/lists.py +++ b/python/cudf/cudf/core/column/lists.py @@ -2,7 +2,7 @@ import pickle from functools import cached_property -from typing import List, Optional, Sequence +from typing import List, Optional, Sequence, Union import numpy as np import pyarrow as pa @@ -15,13 +15,18 @@ contains_scalar, count_elements, drop_list_duplicates, - extract_element, + extract_element_column, + extract_element_scalar, index_of, sort_lists, ) from cudf._lib.strings.convert.convert_lists import format_list_column from cudf._typing import ColumnBinaryOperand, ColumnLike, Dtype, ScalarLike -from cudf.api.types import _is_non_decimal_numeric_dtype, is_list_dtype +from cudf.api.types import ( + _is_non_decimal_numeric_dtype, + is_list_dtype, + is_scalar, +) from cudf.core.buffer import Buffer from cudf.core.column import ColumnBase, as_column, column from cudf.core.column.methods import ColumnMethods, ParentType @@ -339,18 +344,27 @@ def __init__(self, parent: ParentType): super().__init__(parent=parent) def get( - self, index: int, default: Optional[ScalarLike] = None + self, + index: int, + default: Optional[Union[ScalarLike, ColumnLike]] = None, ) -> ParentType: """ - Extract element at the given index from each list. + Extract element at the given index from each list in a Series of lists. - If the index is out of bounds for any list, - return <NA> or, if provided, ``default``. - Thus, this method never raises an ``IndexError``. + ``index`` can be an integer or a sequence of integers. If + ``index`` is an integer, the element at position ``index`` is + extracted from each list. If ``index`` is a sequence, it must + be of the same length as the Series, and ``index[i]`` + specifies the position of the element to extract from the + ``i``-th list in the Series. 
+ + If the index is out of bounds for any list, return <NA> or, if + provided, ``default``. Thus, this method never raises an + ``IndexError``. Parameters ---------- - index : int + index : int or sequence of ints default : scalar, optional Returns @@ -373,14 +387,23 @@ def get( 2 6 dtype: int64 - >>> s = cudf.Series([[1, 2], [3, 4, 5], [4, 5, 6]]) >>> s.list.get(2, default=0) 0 0 1 5 2 6 dtype: int64 + + >>> s.list.get([0, 1, 2]) + 0 1 + 1 4 + 2 6 + dtype: int64 """ - out = extract_element(self._column, index) + if is_scalar(index): + out = extract_element_scalar(self._column, cudf.Scalar(index)) + else: + index = as_column(index) + out = extract_element_column(self._column, as_column(index)) if not (default is None or default is cudf.NA): # determine rows for which `index` is out-of-bounds diff --git a/python/cudf/cudf/tests/test_list.py b/python/cudf/cudf/tests/test_list.py index ade3d1903d8..cf53a3525ef 100644 --- a/python/cudf/cudf/tests/test_list.py +++ b/python/cudf/cudf/tests/test_list.py @@ -320,6 +320,17 @@ def test_get_default(): ) +def test_get_ind_sequence(): + # test .list.get() when `index` is a sequence + sr = cudf.Series([[1, 2], [3, 4, 5], [6, 7, 8, 9]]) + assert_eq(cudf.Series([1, 4, 8]), sr.list.get([0, 1, 2])) + assert_eq(cudf.Series([1, 4, 8]), sr.list.get(cudf.Series([0, 1, 2]))) + assert_eq(cudf.Series([cudf.NA, 5, cudf.NA]), sr.list.get([2, 2, -5])) + assert_eq(cudf.Series([0, 5, 0]), sr.list.get([2, 2, -5], default=0)) + sr_nested = cudf.Series([[[1, 2], [3, 4], [5, 6]], [[5, 6], [7, 8]]]) + assert_eq(cudf.Series([[1, 2], [7, 8]]), sr_nested.list.get([0, 1])) + + @pytest.mark.parametrize( "data, scalar, expect", [ diff --git a/python/dask_cudf/dask_cudf/tests/test_accessor.py b/python/dask_cudf/dask_cudf/tests/test_accessor.py index 95cf0c8d56d..f83800bf6b0 100644 --- a/python/dask_cudf/dask_cudf/tests/test_accessor.py +++ b/python/dask_cudf/dask_cudf/tests/test_accessor.py @@ -381,19 +381,16 @@ def test_contains(data, search_key): 
@pytest.mark.parametrize( - "data, index, expectation", + "data, index", [ - (data_test_1(), 1, does_not_raise()), - (data_test_2(), 2, does_not_raise()), + (data_test_1(), 1), + (data_test_2(), 2), ], ) -def test_get(data, index, expectation): - with expectation: - expect = Series(data).list.get(index) - - if expectation == does_not_raise(): - ds = dgd.from_cudf(Series(data), 5) - assert_eq(expect, ds.list.get(index).compute()) +def test_get(data, index): + expect = Series(data).list.get(index) + ds = dgd.from_cudf(Series(data), 5) + assert_eq(expect, ds.list.get(index).compute()) @pytest.mark.parametrize(