Skip to content

Commit

Permalink
Return python lists for __getitem__ calls to list type series (#8265)
Browse files Browse the repository at this point in the history
Make it so that this works:

```
x = cudf.Series([[1,2,None]])
x[0]
# [1, 2, <NA>]
```

Authors:
  - https://github.com/brandon-b-miller

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)
  - Michael Wang (https://github.com/isVoid)

URL: #8265
  • Loading branch information
brandon-b-miller authored May 20, 2021
1 parent 2a1075e commit b553144
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 6 deletions.
6 changes: 6 additions & 0 deletions python/cudf/cudf/_lib/cpp/scalar/scalar.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ from libcpp.string cimport string
from cudf._lib.cpp.types cimport data_type
from cudf._lib.cpp.wrappers.decimals cimport scale_type

from cudf._lib.cpp.column.column_view cimport column_view


cdef extern from "cudf/scalar/scalar.hpp" namespace "cudf" nogil:
cdef cppclass scalar:
scalar() except +
Expand Down Expand Up @@ -60,3 +63,6 @@ cdef extern from "cudf/scalar/scalar.hpp" namespace "cudf" nogil:
bool is_valid) except +
int64_t value() except +
# TODO: Figure out how to add an int32 overload of value()

cdef cppclass list_scalar(scalar):
column_view view() except +
63 changes: 58 additions & 5 deletions python/cudf/cudf/_lib/scalar.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,18 @@ from libcpp.utility cimport move
from libcpp cimport bool

import cudf
from cudf._lib.types import cudf_to_np_types, duration_unit_map
from cudf.core.dtypes import ListDtype
from cudf._lib.types import (
cudf_to_np_types,
duration_unit_map
)
from cudf._lib.types import datetime_unit_map
from cudf._lib.types cimport underlying_type_t_type_id
from cudf._lib.types cimport underlying_type_t_type_id, dtype_from_column_view

from cudf._lib.column cimport Column
from cudf._lib.cpp.column.column_view cimport column_view
from cudf._lib.table cimport Table
from cudf._lib.interop import to_arrow

from cudf._lib.cpp.wrappers.timestamps cimport (
timestamp_s,
Expand All @@ -41,12 +50,12 @@ from cudf._lib.cpp.scalar.scalar cimport (
timestamp_scalar,
duration_scalar,
string_scalar,
fixed_point_scalar
fixed_point_scalar,
list_scalar,
)
from cudf.utils.dtypes import _decimal_to_int64
from cudf.utils.dtypes import _decimal_to_int64, is_list_dtype
cimport cudf._lib.cpp.types as libcudf_types


cdef class DeviceScalar:

def __init__(self, value, dtype):
Expand Down Expand Up @@ -97,6 +106,8 @@ cdef class DeviceScalar:
def _to_host_scalar(self):
if isinstance(self.dtype, cudf.Decimal64Dtype):
result = _get_py_decimal_from_fixed_point(self.c_value)
elif is_list_dtype(self.dtype):
result = _get_py_list_from_list(self.c_value)
elif pd.api.types.is_string_dtype(self.dtype):
result = _get_py_string_from_string(self.c_value)
elif pd.api.types.is_numeric_dtype(self.dtype):
Expand Down Expand Up @@ -159,6 +170,22 @@ cdef class DeviceScalar:
raise TypeError(
"Must pass a dtype when constructing from a fixed-point scalar"
)
elif cdtype.id() == libcudf_types.LIST:
if (
<list_scalar*>s.get_raw_ptr()
)[0].view().type().id() == libcudf_types.LIST:
s._dtype = dtype_from_column_view(
(<list_scalar*>s.get_raw_ptr())[0].view()
)
else:
s._dtype = ListDtype(
cudf_to_np_types[
<underlying_type_t_type_id>(
(<list_scalar*>s.get_raw_ptr())[0]
.view().type().id()
)
]
)
else:
if dtype is not None:
s._dtype = dtype
Expand Down Expand Up @@ -268,6 +295,19 @@ cdef _set_decimal64_from_scalar(unique_ptr[scalar]& s,
)
)

cdef _get_py_list_from_list(unique_ptr[scalar]& s):

if not s.get()[0].is_valid():
return cudf.NA

cdef column_view list_col_view = (<list_scalar*>s.get()).view()
cdef Column list_col = Column.from_column_view(list_col_view, None)
cdef Table to_arrow_table = Table({"col": list_col})

arrow_table = to_arrow(to_arrow_table, [["col", []]])
result = arrow_table['col'].to_pylist()
return _nested_na_replace(result)

cdef _get_py_string_from_string(unique_ptr[scalar]& s):
if not s.get()[0].is_valid():
return cudf.NA
Expand Down Expand Up @@ -440,3 +480,16 @@ def _create_proxy_nat_scalar(dtype):
return result
else:
raise TypeError('NAT only valid for datetime and timedelta')


def _nested_na_replace(input_list):
'''
Replace `None` with `cudf.NA` in the result of
`__getitem__` calls to list type columns
'''
for idx, value in enumerate(input_list):
if isinstance(value, list):
_nested_na_replace(value)
elif value is None:
input_list[idx] = cudf.NA
return input_list
6 changes: 5 additions & 1 deletion python/cudf/cudf/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,11 @@ def __getitem__(self, arg):
arg = list(arg)
data = self._sr._column[arg]

if is_scalar(data) or _is_null_host_scalar(data):
if (
isinstance(data, list)
or is_scalar(data)
or _is_null_host_scalar(data)
):
return data
index = self._sr.index.take(arg)
return self._sr._copy_construct(data=data, index=index)
Expand Down
18 changes: 18 additions & 0 deletions python/cudf/cudf/tests/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest

import cudf
from cudf import NA
from cudf.tests.utils import assert_eq


Expand Down Expand Up @@ -332,3 +333,20 @@ def test_concatenate_list_with_nonlist():
gdf1 = cudf.DataFrame({"A": [["a", "c"], ["b", "d"], ["c", "d"]]})
gdf2 = cudf.DataFrame({"A": ["a", "b", "c"]})
gdf1["A"] + gdf2["A"]


@pytest.mark.parametrize(
"indata,expect",
[
([1], [1]),
([1, 2, 3], [1, 2, 3]),
([[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]),
([None], [NA]),
([1, None, 3], [1, NA, 3]),
([[1, None, 3], [None, 5, 6]], [[1, NA, 3], [NA, 5, 6]]),
],
)
def test_list_getitem(indata, expect):
list_sr = cudf.Series([indata])
# __getitem__ shall fill None with cudf.NA
assert list_sr[0] == expect

0 comments on commit b553144

Please sign in to comment.