Skip to content

Commit

Permalink
Struct scalar from host dictionary (#8629)
Browse files Browse the repository at this point in the history
Allows the creation of a `struct_scalar` from a Python dictionary. Partly addresses #8558.

Authors:
  - https://github.com/shaneding

Approvers:
  - https://github.com/brandon-b-miller
  - Ram (Ramakrishna Prabhu) (https://github.com/rgsl888prabhu)

URL: #8629
  • Loading branch information
shaneding authored Jul 9, 2021
1 parent 1b34652 commit cef51bd
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 3 deletions.
1 change: 1 addition & 0 deletions python/cudf/cudf/_lib/cpp/scalar/scalar.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,5 @@ cdef extern from "cudf/scalar/scalar.hpp" namespace "cudf" nogil:
column_view view() except +

# Cython declaration of the C++ `cudf::struct_scalar` class (a `scalar`
# subclass) from "cudf/scalar/scalar.hpp".
cdef cppclass struct_scalar(scalar):
# Construct from a `table_view` of field columns plus a validity flag.
# `except +` makes Cython translate C++ exceptions into Python exceptions.
struct_scalar(table_view cols, bool valid) except +
# Return the scalar's fields as a `table_view`.
table_view view() except +
35 changes: 33 additions & 2 deletions python/cudf/cudf/_lib/scalar.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ from cudf._lib.column cimport Column
from cudf._lib.cpp.column.column_view cimport column_view
from cudf._lib.cpp.table.table_view cimport table_view
from cudf._lib.table cimport Table
from cudf._lib.interop import to_arrow
from cudf._lib.interop import to_arrow, from_arrow

from cudf._lib.cpp.wrappers.timestamps cimport (
timestamp_s,
Expand Down Expand Up @@ -87,6 +87,8 @@ cdef class DeviceScalar:
elif isinstance(dtype, cudf.ListDtype):
_set_list_from_pylist(
self.c_value, value, dtype, valid)
elif isinstance(dtype, cudf.StructDtype):
_set_struct_from_pydict(self.c_value, value, dtype, valid)
elif pd.api.types.is_string_dtype(dtype):
_set_string_from_np_string(self.c_value, value, valid)
elif pd.api.types.is_numeric_dtype(dtype):
Expand Down Expand Up @@ -172,7 +174,6 @@ cdef class DeviceScalar:

s.c_value = move(ptr)
cdtype = s.get_raw_ptr()[0].type()

if cdtype.id() == libcudf_types.DECIMAL64 and dtype is None:
raise TypeError(
"Must pass a dtype when constructing from a fixed-point scalar"
Expand Down Expand Up @@ -308,6 +309,36 @@ cdef _set_decimal64_from_scalar(unique_ptr[scalar]& s,
)
)

cdef _set_struct_from_pydict(unique_ptr[scalar]& s,
                             object value,
                             object dtype,
                             bool valid=True):
    """
    Reset ``s`` to a new ``struct_scalar`` built from a host dictionary.

    Parameters
    ----------
    s : unique_ptr[scalar]&
        Destination pointer; it is reset to own the new ``struct_scalar``.
    value : dict
        Mapping from field name to field value. Only read when ``valid``
        is true, so a null placeholder may be passed for a null scalar.
    dtype : cudf.StructDtype
        Struct dtype; its Arrow form supplies the field names and types.
    valid : bool, default True
        Whether the resulting scalar is valid (non-null).

    Raises
    ------
    KeyError
        If ``valid`` is true and ``value`` lacks one of the dtype's fields.
    """
    arrow_schema = dtype.to_arrow()
    # Positional placeholder names for the intermediate table's columns.
    columns = [str(i) for i in range(len(arrow_schema))]

    if valid:
        field_values = [value[f.name] for f in arrow_schema]
    else:
        # A scalar should always carry exactly one row per child column, so
        # a null struct is built from a single null row rather than from
        # empty (zero-row) columns.
        # NOTE(review): libcudf is expected to require one row per child of
        # a struct_scalar -- confirm against the libcudf version in use.
        field_values = [None] * len(columns)

    pyarrow_table = pa.Table.from_arrays(
        [
            pa.array([field_value], from_pandas=True, type=f.type)
            for field_value, f in zip(field_values, arrow_schema)
        ],
        names=columns
    )

    cdef Table table = from_arrow(pyarrow_table, column_names=columns)
    cdef table_view struct_view = table.view()

    s.reset(
        new struct_scalar(struct_view, valid)
    )

cdef _get_py_dict_from_struct(unique_ptr[scalar]& s):
if not s.get()[0].is_valid():
return cudf.NA
Expand Down
17 changes: 16 additions & 1 deletion python/cudf/cudf/core/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from cudf._lib.scalar import DeviceScalar, _is_null_host_scalar
from cudf.core.column.column import ColumnBase
from cudf.core.dtypes import Decimal64Dtype, ListDtype
from cudf.core.dtypes import Decimal64Dtype, ListDtype, StructDtype
from cudf.core.index import BaseIndex
from cudf.core.series import Series
from cudf.utils.dtypes import (
Expand Down Expand Up @@ -131,6 +131,21 @@ def _preprocess_host_value(self, value, dtype):
raise ValueError(f"Can not coerce {value} to ListDtype")
else:
return NA, dtype

if isinstance(value, dict):
if dtype is not None:
raise TypeError("dict may not be cast to a different dtype")
else:
dtype = StructDtype.from_arrow(
pa.infer_type([value], from_pandas=True)
)
return value, dtype
elif isinstance(dtype, StructDtype):
if value is not None:
raise ValueError(f"Can not coerce {value} to StructDType")
else:
return NA, dtype

if isinstance(dtype, Decimal64Dtype):
value = pa.scalar(
value, type=pa.decimal128(dtype.precision, dtype.scale)
Expand Down
21 changes: 21 additions & 0 deletions python/cudf/cudf/tests/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

import cudf
from cudf.core.dtypes import StructDtype
from cudf.testing._utils import assert_eq


Expand Down Expand Up @@ -98,3 +99,23 @@ def test_serialize_struct_dtype(fields):
def test_struct_getitem(series, expected):
    """The first element of a Series built from ``series`` equals ``expected``."""
    first_element = cudf.Series(series)[0]
    assert first_element == expected


@pytest.mark.parametrize(
    "data",
    [
        {"a": 1, "b": "rapids", "c": [1, 2, 3, 4]},
        {"a": 1, "b": "rapids", "c": [1, 2, 3, 4], "d": cudf.NA},
        {"a": "Hello"},
        {"b": [], "c": [1, 2, 3]},
    ],
)
def test_struct_scalar_host_construction(data):
    """A Scalar built from a host dict matches it on both host and device."""
    scalar = cudf.Scalar(data)

    # Host-side value compares equal to the input dictionary.
    assert scalar.value == data

    # Device-side value exposes the same field values, in the same order.
    device_fields = list(scalar.device_value.value.values())
    expected_fields = list(data.values())
    assert device_fields == expected_fields


def test_struct_scalar_null():
    """A Scalar built from ``cudf.NA`` with a struct dtype has an NA device value."""
    null_scalar = cudf.Scalar(cudf.NA, dtype=StructDtype)
    device_value = null_scalar.device_value.value
    assert device_value is cudf.NA

0 comments on commit cef51bd

Please sign in to comment.