Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds list.take, python binding for cudf::lists::segmented_gather #7591

Merged
merged 9 commits into from
Mar 18, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion python/cudf/cudf/_lib/copying.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pandas as pd

from libcpp cimport bool
from libcpp.memory cimport make_unique, unique_ptr
from libcpp.memory cimport make_unique, unique_ptr, shared_ptr, make_shared
from libcpp.vector cimport vector
from libcpp.utility cimport move
from libc.stdint cimport int32_t, int64_t
Expand All @@ -24,6 +24,10 @@ from cudf._lib.cpp.scalar.scalar cimport scalar
from cudf._lib.cpp.table.table cimport table
from cudf._lib.cpp.table.table_view cimport table_view
from cudf._lib.cpp.types cimport size_type
from cudf._lib.cpp.lists.lists_column_view cimport lists_column_view
from cudf._lib.cpp.lists.gather cimport (
segmented_gather as cpp_segmented_gather
)
cimport cudf._lib.cpp.copying as cpp_copying

# workaround for https://github.com/cython/cython/issues/3885
Expand Down Expand Up @@ -704,3 +708,22 @@ def sample(Table input, size_type n,
else input._index_names
)
)


def segmented_gather(Column source_column, Column gather_map):
cdef shared_ptr[lists_column_view] source_LCV = (
make_shared[lists_column_view](source_column.view())
)
cdef shared_ptr[lists_column_view] gather_map_LCV = (
make_shared[lists_column_view](gather_map.view())
)
kkraus14 marked this conversation as resolved.
Show resolved Hide resolved
cdef unique_ptr[column] c_result

with nogil:
c_result = move(
cpp_segmented_gather(
source_LCV.get()[0], gather_map_LCV.get()[0])
)

result = Column.from_unique_ptr(move(c_result))
return result
13 changes: 13 additions & 0 deletions python/cudf/cudf/_lib/cpp/lists/gather.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021, NVIDIA CORPORATION.

from libcpp.memory cimport unique_ptr

from cudf._lib.cpp.column.column cimport column
from cudf._lib.cpp.lists.lists_column_view cimport lists_column_view


cdef extern from "cudf/lists/gather.hpp" namespace "cudf::lists" nogil:
cdef unique_ptr[column] segmented_gather(
const lists_column_view source_column,
const lists_column_view gather_map_list
) except +
56 changes: 54 additions & 2 deletions python/cudf/cudf/core/column/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@

import pickle

import numpy as np
import pyarrow as pa

import cudf
from cudf._lib.copying import segmented_gather
from cudf._lib.lists import count_elements
from cudf.core.buffer import Buffer
from cudf.core.column import ColumnBase, column
from cudf.core.column import ColumnBase, as_column, column
from cudf.core.column.methods import ColumnMethodsMixin
from cudf.utils.dtypes import is_list_dtype
from cudf.utils.dtypes import is_list_dtype, is_numerical_dtype


class ListColumn(ColumnBase):
Expand Down Expand Up @@ -228,3 +230,53 @@ def len(self):
dtype: int32
"""
return self._return_or_inplace(count_elements(self._column))

def take(self, lists_indices):
"""
Collect list elements based on given indices.

Parameters
----------
lists_indices: List type arrays
Specifies what to collect from each row
isVoid marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
ListColumn

Examples
--------
>>> s = cudf.Series([[1, 2, 3], None, [4, 5]])
>>> s
0 [1, 2, 3]
1 None
2 [4, 5]
dtype: list
>>> s.list.take([[0, 1], [], []])
0 [1, 2]
1 None
2 []
dtype: list
"""

lists_indices_col = as_column(lists_indices)
if not isinstance(lists_indices_col, ListColumn):
raise ValueError("lists_indices should be list type array.")
kkraus14 marked this conversation as resolved.
Show resolved Hide resolved
if not lists_indices_col.size == self._column.size:
raise ValueError(
"lists_indices and list column is of different " "size."
)
if lists_indices_col.null_count > 0:
raise ValueError("lists_indices contains null elements.")
isVoid marked this conversation as resolved.
Show resolved Hide resolved
if not is_numerical_dtype(
lists_indices_col.children[1].dtype
) or not np.issubdtype(
lists_indices_col.children[1].dtype, np.integer
):
raise TypeError(
"lists_indices should be column of values of index types."
)

return self._return_or_inplace(
segmented_gather(self._column, lists_indices_col)
)
47 changes: 47 additions & 0 deletions python/cudf/cudf/tests/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,50 @@ def test_len(data):
got = gsr.list.len()

assert_eq(expect, got, check_dtype=False)


@pytest.mark.parametrize(
("data", "idx"),
[
([[1, 2, 3], [3, 4, 5], [4, 5, 6]], [[0, 1], [2], [1, 2]]),
([[1, 2, 3], [3, 4, 5], [4, 5, 6]], [[1, 2, 0], [1, 0, 2], [0, 1, 2]]),
([[1, 2, 3], []], [[0, 1], []]),
([[1, 2, 3], [None]], [[0, 1], []]),
([[1, None, 3], None], [[0, 1], []]),
],
)
def test_take(data, idx):
ps = pd.Series(data)
gs = cudf.from_pandas(ps)

expected = pd.Series(zip(ps, idx)).map(
lambda x: [x[0][i] for i in x[1]] if x[0] is not None else None
)
got = gs.list.take(idx)
assert_eq(expected, got)


@pytest.mark.parametrize(
("invalid", "exception"),
[
([[0]], pytest.raises(ValueError, match="different size")),
([1, 2, 3, 4], pytest.raises(ValueError, match="should be list type")),
(
[["a", "b"], ["c"]],
pytest.raises(
TypeError, match="should be column of values of index types"
),
),
(
[[[1], [0]], [[0]]],
pytest.raises(
TypeError, match="should be column of values of index types"
),
),
([[0, 1], None], pytest.raises(ValueError, match="contains null")),
],
)
def test_take_invalid(invalid, exception):
gs = cudf.Series([[0, 1], [2, 3]])
with exception:
gs.list.take(invalid)