Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Struct scalar from host dictionary #8629

Merged
merged 8 commits into from
Jul 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/cudf/cudf/_lib/cpp/scalar/scalar.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,5 @@ cdef extern from "cudf/scalar/scalar.hpp" namespace "cudf" nogil:
column_view view() except +

cdef cppclass struct_scalar(scalar):
struct_scalar(table_view cols, bool valid) except +
table_view view() except +
35 changes: 33 additions & 2 deletions python/cudf/cudf/_lib/scalar.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ from cudf._lib.column cimport Column
from cudf._lib.cpp.column.column_view cimport column_view
from cudf._lib.cpp.table.table_view cimport table_view
from cudf._lib.table cimport Table
from cudf._lib.interop import to_arrow
from cudf._lib.interop import to_arrow, from_arrow

from cudf._lib.cpp.wrappers.timestamps cimport (
timestamp_s,
Expand Down Expand Up @@ -87,6 +87,8 @@ cdef class DeviceScalar:
elif isinstance(dtype, cudf.ListDtype):
_set_list_from_pylist(
self.c_value, value, dtype, valid)
elif isinstance(dtype, cudf.StructDtype):
_set_struct_from_pydict(self.c_value, value, dtype, valid)
elif pd.api.types.is_string_dtype(dtype):
_set_string_from_np_string(self.c_value, value, valid)
elif pd.api.types.is_numeric_dtype(dtype):
Expand Down Expand Up @@ -172,7 +174,6 @@ cdef class DeviceScalar:

s.c_value = move(ptr)
cdtype = s.get_raw_ptr()[0].type()

if cdtype.id() == libcudf_types.DECIMAL64 and dtype is None:
raise TypeError(
"Must pass a dtype when constructing from a fixed-point scalar"
Expand Down Expand Up @@ -308,6 +309,36 @@ cdef _set_decimal64_from_scalar(unique_ptr[scalar]& s,
)
)

cdef _set_struct_from_pydict(unique_ptr[scalar]& s,
object value,
object dtype,
bool valid=True):
arrow_schema = dtype.to_arrow()
columns = [str(i) for i in range(len(arrow_schema))]
if valid:
pyarrow_table = pa.Table.from_arrays(
[
pa.array([value[f.name]], from_pandas=True, type=f.type)
for f in arrow_schema
],
names=columns
)
else:
pyarrow_table = pa.Table.from_arrays(
[
pa.array([], from_pandas=True, type=f.type)
shaneding marked this conversation as resolved.
Show resolved Hide resolved
for f in arrow_schema
],
names=columns
)

cdef Table table = from_arrow(pyarrow_table, column_names=columns)
cdef table_view struct_view = table.view()

s.reset(
new struct_scalar(struct_view, valid)
)

cdef _get_py_dict_from_struct(unique_ptr[scalar]& s):
if not s.get()[0].is_valid():
return cudf.NA
Expand Down
17 changes: 16 additions & 1 deletion python/cudf/cudf/core/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from cudf._lib.scalar import DeviceScalar, _is_null_host_scalar
from cudf.core.column.column import ColumnBase
from cudf.core.dtypes import Decimal64Dtype, ListDtype
from cudf.core.dtypes import Decimal64Dtype, ListDtype, StructDtype
from cudf.core.index import BaseIndex
from cudf.core.series import Series
from cudf.utils.dtypes import (
Expand Down Expand Up @@ -131,6 +131,21 @@ def _preprocess_host_value(self, value, dtype):
raise ValueError(f"Can not coerce {value} to ListDtype")
else:
return NA, dtype

if isinstance(value, dict):
if dtype is not None:
raise TypeError("dict may not be cast to a different dtype")
else:
dtype = StructDtype.from_arrow(
pa.infer_type([value], from_pandas=True)
)
return value, dtype
elif isinstance(dtype, StructDtype):
if value is not None:
raise ValueError(f"Can not coerce {value} to StructDType")
else:
return NA, dtype

if isinstance(dtype, Decimal64Dtype):
value = pa.scalar(
value, type=pa.decimal128(dtype.precision, dtype.scale)
Expand Down
21 changes: 21 additions & 0 deletions python/cudf/cudf/tests/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

import cudf
from cudf.core.dtypes import StructDtype
from cudf.testing._utils import assert_eq


Expand Down Expand Up @@ -98,3 +99,23 @@ def test_serialize_struct_dtype(fields):
def test_struct_getitem(series, expected):
sr = cudf.Series(series)
assert sr[0] == expected


@pytest.mark.parametrize(
"data",
[
{"a": 1, "b": "rapids", "c": [1, 2, 3, 4]},
{"a": 1, "b": "rapids", "c": [1, 2, 3, 4], "d": cudf.NA},
{"a": "Hello"},
{"b": [], "c": [1, 2, 3]},
shaneding marked this conversation as resolved.
Show resolved Hide resolved
],
)
def test_struct_scalar_host_construction(data):
slr = cudf.Scalar(data)
assert slr.value == data
assert list(slr.device_value.value.values()) == list(data.values())


def test_struct_scalar_null():
slr = cudf.Scalar(cudf.NA, dtype=StructDtype)
assert slr.device_value.value is cudf.NA