Skip to content

Commit

Permalink
ListColumn __setitem__ (#8606)
Browse files Browse the repository at this point in the history
  • Loading branch information
brandon-b-miller authored Jun 30, 2021
1 parent dab8a62 commit 5884b95
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 21 deletions.
1 change: 1 addition & 0 deletions python/cudf/cudf/_lib/cpp/scalar/scalar.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ cdef extern from "cudf/scalar/scalar.hpp" namespace "cudf" nogil:

# Cython declaration of the C++ class list_scalar (a subclass of scalar)
# exposed from the "cudf/scalar/scalar.hpp" header.
cdef cppclass list_scalar(scalar):
# Construct a list scalar from a column_view; C++ exceptions are
# translated to Python exceptions (`except +`).
list_scalar(column_view col) except +
# Same, with an explicit boolean flag. NOTE(review): presumably
# is_valid=False yields a null list scalar -- confirm against libcudf.
list_scalar(column_view col, bool is_valid) except +
# Accessor returning a column_view for this scalar.
column_view view() except +

cdef cppclass struct_scalar(scalar):
Expand Down
34 changes: 15 additions & 19 deletions python/cudf/cudf/_lib/scalar.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -333,20 +333,18 @@ cdef _set_list_from_pylist(unique_ptr[scalar]& s,
value = value if valid else [cudf.NA]
cdef Column col
if isinstance(dtype.element_type, ListDtype):
col = cudf.core.column.as_column(
pa.array(
value, from_pandas=True, type=dtype.element_type.to_arrow()
)
)
pa_type = dtype.element_type.to_arrow()
else:
col = cudf.core.column.as_column(
pa.array(value, from_pandas=True)
)
pa_type = dtype.to_arrow().value_type
col = cudf.core.column.as_column(
pa.array(value, from_pandas=True, type=pa_type)
)
cdef column_view col_view = col.view()
s.reset(
new list_scalar(col_view)
new list_scalar(col_view, valid)
)


cdef _get_py_list_from_list(unique_ptr[scalar]& s):

if not s.get()[0].is_valid():
Expand Down Expand Up @@ -497,18 +495,16 @@ cdef _get_np_scalar_from_timedelta64(unique_ptr[scalar]& s):


def as_device_scalar(val, dtype=None):
if dtype:
if isinstance(val, (cudf.Scalar, DeviceScalar)) and dtype != val.dtype:
raise TypeError("Can't update dtype of existing GPU scalar")
if isinstance(val, (cudf.Scalar, DeviceScalar)):
if dtype == val.dtype or dtype is None:
if isinstance(val, DeviceScalar):
return val
else:
return val.device_value
else:
return cudf.Scalar(value=val, dtype=dtype).device_value
raise TypeError("Can't update dtype of existing GPU scalar")
else:
if isinstance(val, DeviceScalar):
return val
if isinstance(val, cudf.Scalar):
return val.device_value
else:
return cudf.Scalar(val).device_value
return cudf.Scalar(val, dtype=dtype).device_value


def _is_null_host_scalar(slr):
Expand Down
12 changes: 12 additions & 0 deletions python/cudf/cudf/core/column/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,18 @@ def __sizeof__(self):

return self._cached_sizeof

def __setitem__(self, key, value):
    """Assign a single list value (or a null) to the rows selected by ``key``.

    Parameters
    ----------
    key
        Row selector; forwarded unchanged to the parent ``__setitem__``.
    value : list, cudf.Scalar, cudf.NA or None
        A Python list is first wrapped in a ``cudf.Scalar``. ``cudf.NA``
        and ``None`` are both treated as a null and converted to a null
        scalar carrying this column's dtype.

    Raises
    ------
    TypeError
        If the (wrapped) scalar's dtype differs from this column's dtype.
    ValueError
        If ``value`` is not a list, a ``cudf.Scalar`` or a null.
    """
    if value is None or value is cudf.NA:
        # A bare null has no dtype of its own, so build a null scalar
        # that is typed like this column.
        value = cudf.Scalar(cudf.NA, dtype=self.dtype)
    else:
        if isinstance(value, list):
            value = cudf.Scalar(value)
        if not isinstance(value, cudf.Scalar):
            raise ValueError(f"Can not set {value} into ListColumn")
        if value.dtype != self.dtype:
            raise TypeError("list nesting level mismatch")
    super().__setitem__(key, value)

@property
def base_size(self):
# in some cases, libcudf will return an empty ListColumn with no
Expand Down
5 changes: 4 additions & 1 deletion python/cudf/cudf/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,10 @@ def __setitem__(self, key, value):
# coerce value into a scalar or column
if is_scalar(value):
value = to_cudf_compatible_scalar(value)
else:
elif not (
isinstance(value, list)
and isinstance(self._sr._column.dtype, cudf.ListDtype)
):
value = column.as_column(value)
if (
not isinstance(
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def _preprocess_host_value(self, value, dtype):
)
return value, dtype
elif isinstance(dtype, ListDtype):
if value is not None:
if value not in {None, NA}:
raise ValueError(f"Can not coerce {value} to ListDtype")
else:
return NA, dtype
Expand Down
64 changes: 64 additions & 0 deletions python/cudf/cudf/tests/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,3 +457,67 @@ def test_serialize_list_columns(data):
df = cudf.DataFrame(data)
recreated = df.__class__.deserialize(*df.serialize())
assert_eq(recreated, df)


@pytest.mark.parametrize(
    "data,item",
    [
        (
            # basic list into a list column
            [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
            [0, 0, 0],
        ),
        (
            # nested list into nested list column
            [
                [[1, 2, 3], [4, 5, 6]],
                [[1, 2, 3], [4, 5, 6]],
                [[1, 2, 3], [4, 5, 6]],
            ],
            [[0, 0, 0], [0, 0, 0]],
        ),
        (
            # NA into a list column
            [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
            NA,
        ),
        (
            # NA into nested list column
            [
                [[1, 2, 3], [4, 5, 6]],
                [[1, 2, 3], [4, 5, 6]],
                [[1, 2, 3], [4, 5, 6]],
            ],
            NA,
        ),
    ],
)
def test_listcol_setitem(data, item):
    """Setting row 1 of a list Series equals building it from edited data."""
    sr = cudf.Series(data)
    sr[1] = item

    # Build the expectation from a shallow copy so the parametrized input
    # object is not mutated by the test.
    expected_rows = list(data)
    expected_rows[1] = item
    expect = cudf.Series(expected_rows)

    assert_eq(expect, sr)


@pytest.mark.parametrize(
    "data,item,error",
    [
        (
            # deeper nesting than the column's list dtype
            [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
            [[1, 2, 3], [4, 5, 6]],
            "list nesting level mismatch",
        ),
        (
            # a plain non-list scalar
            [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
            0,
            "Can not set 0 into ListColumn",
        ),
    ],
)
def test_listcol_setitem_error_cases(data, item, error):
    """Invalid values assigned into a list Series raise with a clear message."""
    sr = cudf.Series(data)
    # Only the error types the setter raises are accepted; BaseException
    # would also let KeyboardInterrupt/SystemExit satisfy the check.
    with pytest.raises((TypeError, ValueError), match=error):
        sr[1] = item

0 comments on commit 5884b95

Please sign in to comment.