Skip to content

Commit

Permalink
Adds list.take, python binding for cudf::lists::segmented_gather (#…
Browse files Browse the repository at this point in the history
…7591)

Closes #7465 

Implements `ListColumn.list.take` based on `cudf::lists:segmented_gather`. Gather elements inside each list based on the provided positions. Example:

```python
>>> s = cudf.Series([[1, 2, 3], [4, 5]])
>>> s
0    [1, 2, 3]
1       [4, 5]
dtype: list
>>> s.list.take([[2, 1], [1, 0]])
0    [3, 2]
1    [5, 4]
dtype: list
```

Authors:
  - Michael Wang (@isVoid)

Approvers:
  - Keith Kraus (@kkraus14)

URL: #7591
  • Loading branch information
isVoid authored Mar 18, 2021
1 parent 9aa33ef commit 951b455
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 3 deletions.
25 changes: 24 additions & 1 deletion python/cudf/cudf/_lib/copying.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pandas as pd

from libcpp cimport bool
from libcpp.memory cimport make_unique, unique_ptr
from libcpp.memory cimport make_unique, unique_ptr, shared_ptr, make_shared
from libcpp.vector cimport vector
from libcpp.utility cimport move
from libc.stdint cimport int32_t, int64_t
Expand All @@ -24,6 +24,10 @@ from cudf._lib.cpp.scalar.scalar cimport scalar
from cudf._lib.cpp.table.table cimport table
from cudf._lib.cpp.table.table_view cimport table_view
from cudf._lib.cpp.types cimport size_type
from cudf._lib.cpp.lists.lists_column_view cimport lists_column_view
from cudf._lib.cpp.lists.gather cimport (
segmented_gather as cpp_segmented_gather
)
cimport cudf._lib.cpp.copying as cpp_copying

# workaround for https://github.com/cython/cython/issues/3885
Expand Down Expand Up @@ -704,3 +708,22 @@ def sample(Table input, size_type n,
else input._index_names
)
)


def segmented_gather(Column source_column, Column gather_map):
cdef shared_ptr[lists_column_view] source_LCV = (
make_shared[lists_column_view](source_column.view())
)
cdef shared_ptr[lists_column_view] gather_map_LCV = (
make_shared[lists_column_view](gather_map.view())
)
cdef unique_ptr[column] c_result

with nogil:
c_result = move(
cpp_segmented_gather(
source_LCV.get()[0], gather_map_LCV.get()[0])
)

result = Column.from_unique_ptr(move(c_result))
return result
13 changes: 13 additions & 0 deletions python/cudf/cudf/_lib/cpp/lists/gather.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021, NVIDIA CORPORATION.

from libcpp.memory cimport unique_ptr

from cudf._lib.cpp.column.column cimport column
from cudf._lib.cpp.lists.lists_column_view cimport lists_column_view


cdef extern from "cudf/lists/gather.hpp" namespace "cudf::lists" nogil:
cdef unique_ptr[column] segmented_gather(
const lists_column_view source_column,
const lists_column_view gather_map_list
) except +
61 changes: 59 additions & 2 deletions python/cudf/cudf/core/column/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@

import pickle

import numpy as np
import pyarrow as pa

import cudf
from cudf._lib.copying import segmented_gather
from cudf._lib.lists import count_elements
from cudf.core.buffer import Buffer
from cudf.core.column import ColumnBase, column
from cudf.core.column import ColumnBase, as_column, column
from cudf.core.column.methods import ColumnMethodsMixin
from cudf.utils.dtypes import is_list_dtype
from cudf.utils.dtypes import is_list_dtype, is_numerical_dtype


class ListColumn(ColumnBase):
Expand Down Expand Up @@ -228,3 +230,58 @@ def len(self):
dtype: int32
"""
return self._return_or_inplace(count_elements(self._column))

def take(self, lists_indices):
"""
Collect list elements based on given indices.
Parameters
----------
lists_indices: List type arrays
Specifies what to collect from each row
Returns
-------
ListColumn
Examples
--------
>>> s = cudf.Series([[1, 2, 3], None, [4, 5]])
>>> s
0 [1, 2, 3]
1 None
2 [4, 5]
dtype: list
>>> s.list.take([[0, 1], [], []])
0 [1, 2]
1 None
2 []
dtype: list
"""

lists_indices_col = as_column(lists_indices)
if not isinstance(lists_indices_col, ListColumn):
raise ValueError("lists_indices should be list type array.")
if not lists_indices_col.size == self._column.size:
raise ValueError(
"lists_indices and list column is of different " "size."
)
if not is_numerical_dtype(
lists_indices_col.children[1].dtype
) or not np.issubdtype(
lists_indices_col.children[1].dtype, np.integer
):
raise TypeError(
"lists_indices should be column of values of index types."
)

try:
res = self._return_or_inplace(
segmented_gather(self._column, lists_indices_col)
)
except RuntimeError as e:
if "contains nulls" in str(e):
raise ValueError("lists_indices contains null.") from e
raise
else:
return res
47 changes: 47 additions & 0 deletions python/cudf/cudf/tests/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,50 @@ def test_len(data):
got = gsr.list.len()

assert_eq(expect, got, check_dtype=False)


@pytest.mark.parametrize(
("data", "idx"),
[
([[1, 2, 3], [3, 4, 5], [4, 5, 6]], [[0, 1], [2], [1, 2]]),
([[1, 2, 3], [3, 4, 5], [4, 5, 6]], [[1, 2, 0], [1, 0, 2], [0, 1, 2]]),
([[1, 2, 3], []], [[0, 1], []]),
([[1, 2, 3], [None]], [[0, 1], []]),
([[1, None, 3], None], [[0, 1], []]),
],
)
def test_take(data, idx):
ps = pd.Series(data)
gs = cudf.from_pandas(ps)

expected = pd.Series(zip(ps, idx)).map(
lambda x: [x[0][i] for i in x[1]] if x[0] is not None else None
)
got = gs.list.take(idx)
assert_eq(expected, got)


@pytest.mark.parametrize(
("invalid", "exception"),
[
([[0]], pytest.raises(ValueError, match="different size")),
([1, 2, 3, 4], pytest.raises(ValueError, match="should be list type")),
(
[["a", "b"], ["c"]],
pytest.raises(
TypeError, match="should be column of values of index types"
),
),
(
[[[1], [0]], [[0]]],
pytest.raises(
TypeError, match="should be column of values of index types"
),
),
([[0, 1], None], pytest.raises(ValueError, match="contains null")),
],
)
def test_take_invalid(invalid, exception):
gs = cudf.Series([[0, 1], [2, 3]])
with exception:
gs.list.take(invalid)

0 comments on commit 951b455

Please sign in to comment.