Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return python lists for __getitem__ calls to list type series #8265

Merged
merged 12 commits into from
May 20, 2021
6 changes: 6 additions & 0 deletions python/cudf/cudf/_lib/cpp/scalar/scalar.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ from libcpp.string cimport string
from cudf._lib.cpp.types cimport data_type
from cudf._lib.cpp.wrappers.decimals cimport scale_type

from cudf._lib.cpp.column.column_view cimport column_view


cdef extern from "cudf/scalar/scalar.hpp" namespace "cudf" nogil:
cdef cppclass scalar:
scalar() except +
Expand Down Expand Up @@ -60,3 +63,6 @@ cdef extern from "cudf/scalar/scalar.hpp" namespace "cudf" nogil:
bool is_valid) except +
int64_t value() except +
# TODO: Figure out how to add an int32 overload of value()

cdef cppclass list_scalar(scalar):
column_view view() except +
63 changes: 58 additions & 5 deletions python/cudf/cudf/_lib/scalar.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,18 @@ from libcpp.utility cimport move
from libcpp cimport bool

import cudf
from cudf._lib.types import cudf_to_np_types, duration_unit_map
from cudf.core.dtypes import ListDtype
from cudf._lib.types import (
cudf_to_np_types,
duration_unit_map
)
from cudf._lib.types import datetime_unit_map
from cudf._lib.types cimport underlying_type_t_type_id
from cudf._lib.types cimport underlying_type_t_type_id, dtype_from_column_view

from cudf._lib.column cimport Column
from cudf._lib.cpp.column.column_view cimport column_view
from cudf._lib.table cimport Table
from cudf._lib.interop import to_arrow

from cudf._lib.cpp.wrappers.timestamps cimport (
timestamp_s,
Expand All @@ -41,12 +50,12 @@ from cudf._lib.cpp.scalar.scalar cimport (
timestamp_scalar,
duration_scalar,
string_scalar,
fixed_point_scalar
fixed_point_scalar,
list_scalar,
)
from cudf.utils.dtypes import _decimal_to_int64
from cudf.utils.dtypes import _decimal_to_int64, is_list_dtype
cimport cudf._lib.cpp.types as libcudf_types


cdef class DeviceScalar:

def __init__(self, value, dtype):
Expand Down Expand Up @@ -97,6 +106,8 @@ cdef class DeviceScalar:
def _to_host_scalar(self):
if isinstance(self.dtype, cudf.Decimal64Dtype):
result = _get_py_decimal_from_fixed_point(self.c_value)
elif is_list_dtype(self.dtype):
result = _get_py_list_from_list(self.c_value)
elif pd.api.types.is_string_dtype(self.dtype):
result = _get_py_string_from_string(self.c_value)
elif pd.api.types.is_numeric_dtype(self.dtype):
Expand Down Expand Up @@ -159,6 +170,22 @@ cdef class DeviceScalar:
raise TypeError(
"Must pass a dtype when constructing from a fixed-point scalar"
)
elif cdtype.id() == libcudf_types.LIST:
if (
<list_scalar*>s.get_raw_ptr()
)[0].view().type().id() == libcudf_types.LIST:
s._dtype = dtype_from_column_view(
(<list_scalar*>s.get_raw_ptr())[0].view()
)
else:
s._dtype = ListDtype(
cudf_to_np_types[
<underlying_type_t_type_id>(
(<list_scalar*>s.get_raw_ptr())[0]
.view().type().id()
)
]
)
else:
if dtype is not None:
s._dtype = dtype
Expand Down Expand Up @@ -268,6 +295,19 @@ cdef _set_decimal64_from_scalar(unique_ptr[scalar]& s,
)
)

cdef _get_py_list_from_list(unique_ptr[scalar]& s):

if not s.get()[0].is_valid():
return cudf.NA

cdef column_view list_col_view = (<list_scalar*>s.get()).view()
cdef Column list_col = Column.from_column_view(list_col_view, None)
cdef Table to_arrow_table = Table({"col": list_col})

arrow_table = to_arrow(to_arrow_table, [["col", []]])
result = arrow_table['col'].to_pylist()
return _nested_na_replace(result)

cdef _get_py_string_from_string(unique_ptr[scalar]& s):
if not s.get()[0].is_valid():
return cudf.NA
Expand Down Expand Up @@ -440,3 +480,16 @@ def _create_proxy_nat_scalar(dtype):
return result
else:
raise TypeError('NAT only valid for datetime and timedelta')


def _nested_na_replace(input_list):
'''
Replace `None` with `cudf.NA` in the result of
`__getitem__` calls to list type columns
'''
for idx, value in enumerate(input_list):
if isinstance(value, list):
_nested_na_replace(value)
elif value is None:
input_list[idx] = cudf.NA
return input_list
6 changes: 5 additions & 1 deletion python/cudf/cudf/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,11 @@ def __getitem__(self, arg):
arg = list(arg)
data = self._sr._column[arg]

if is_scalar(data) or _is_null_host_scalar(data):
if (
isinstance(data, list)
or is_scalar(data)
or _is_null_host_scalar(data)
):
return data
index = self._sr.index.take(arg)
return self._sr._copy_construct(data=data, index=index)
Expand Down
18 changes: 18 additions & 0 deletions python/cudf/cudf/tests/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest

import cudf
from cudf import NA
from cudf.tests.utils import assert_eq


Expand Down Expand Up @@ -332,3 +333,20 @@ def test_concatenate_list_with_nonlist():
gdf1 = cudf.DataFrame({"A": [["a", "c"], ["b", "d"], ["c", "d"]]})
gdf2 = cudf.DataFrame({"A": ["a", "b", "c"]})
gdf1["A"] + gdf2["A"]


@pytest.mark.parametrize(
"indata,expect",
[
([1], [1]),
([1, 2, 3], [1, 2, 3]),
([[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]),
([None], [NA]),
([1, None, 3], [1, NA, 3]),
([[1, None, 3], [None, 5, 6]], [[1, NA, 3], [NA, 5, 6]]),
],
)
def test_list_getitem(indata, expect):
list_sr = cudf.Series([indata])
# __getitem__ shall fill None with cudf.NA
assert list_sr[0] == expect