diff --git a/python/cudf/cudf/_lib/cpp/scalar/scalar.pxd b/python/cudf/cudf/_lib/cpp/scalar/scalar.pxd index f0b6ea0b606..771ec9100d1 100644 --- a/python/cudf/cudf/_lib/cpp/scalar/scalar.pxd +++ b/python/cudf/cudf/_lib/cpp/scalar/scalar.pxd @@ -70,4 +70,5 @@ cdef extern from "cudf/scalar/scalar.hpp" namespace "cudf" nogil: column_view view() except + cdef cppclass struct_scalar(scalar): + struct_scalar(table_view cols, bool valid) except + table_view view() except + diff --git a/python/cudf/cudf/_lib/scalar.pyx b/python/cudf/cudf/_lib/scalar.pyx index 9429ab0ee57..9e50f42d625 100644 --- a/python/cudf/cudf/_lib/scalar.pyx +++ b/python/cudf/cudf/_lib/scalar.pyx @@ -30,7 +30,7 @@ from cudf._lib.column cimport Column from cudf._lib.cpp.column.column_view cimport column_view from cudf._lib.cpp.table.table_view cimport table_view from cudf._lib.table cimport Table -from cudf._lib.interop import to_arrow +from cudf._lib.interop import to_arrow, from_arrow from cudf._lib.cpp.wrappers.timestamps cimport ( timestamp_s, @@ -87,6 +87,8 @@ cdef class DeviceScalar: elif isinstance(dtype, cudf.ListDtype): _set_list_from_pylist( self.c_value, value, dtype, valid) + elif isinstance(dtype, cudf.StructDtype): + _set_struct_from_pydict(self.c_value, value, dtype, valid) elif pd.api.types.is_string_dtype(dtype): _set_string_from_np_string(self.c_value, value, valid) elif pd.api.types.is_numeric_dtype(dtype): @@ -172,7 +174,6 @@ cdef class DeviceScalar: s.c_value = move(ptr) cdtype = s.get_raw_ptr()[0].type() - if cdtype.id() == libcudf_types.DECIMAL64 and dtype is None: raise TypeError( "Must pass a dtype when constructing from a fixed-point scalar" @@ -308,6 +309,36 @@ cdef _set_decimal64_from_scalar(unique_ptr[scalar]& s, ) ) +cdef _set_struct_from_pydict(unique_ptr[scalar]& s, + object value, + object dtype, + bool valid=True): + arrow_schema = dtype.to_arrow() + columns = [str(i) for i in range(len(arrow_schema))] + if valid: + pyarrow_table = pa.Table.from_arrays( + [ + pa.array([value[f.name]], from_pandas=True, type=f.type) + for f in arrow_schema + ], + names=columns + ) + else: + pyarrow_table = pa.Table.from_arrays( + [ + pa.array([], from_pandas=True, type=f.type) + for f in arrow_schema + ], + names=columns + ) + + cdef Table table = from_arrow(pyarrow_table, column_names=columns) + cdef table_view struct_view = table.view() + + s.reset( + new struct_scalar(struct_view, valid) + ) + cdef _get_py_dict_from_struct(unique_ptr[scalar]& s): if not s.get()[0].is_valid(): return cudf.NA diff --git a/python/cudf/cudf/core/scalar.py b/python/cudf/cudf/core/scalar.py index ad39642cf60..db9bc6d6c85 100644 --- a/python/cudf/cudf/core/scalar.py +++ b/python/cudf/cudf/core/scalar.py @@ -7,7 +7,7 @@ from cudf._lib.scalar import DeviceScalar, _is_null_host_scalar from cudf.core.column.column import ColumnBase -from cudf.core.dtypes import Decimal64Dtype, ListDtype +from cudf.core.dtypes import Decimal64Dtype, ListDtype, StructDtype from cudf.core.index import BaseIndex from cudf.core.series import Series from cudf.utils.dtypes import ( @@ -131,6 +131,21 @@ def _preprocess_host_value(self, value, dtype): raise ValueError(f"Can not coerce {value} to ListDtype") else: return NA, dtype + + if isinstance(value, dict): + if dtype is not None: + raise TypeError("dict may not be cast to a different dtype") + else: + dtype = StructDtype.from_arrow( + pa.infer_type([value], from_pandas=True) + ) + return value, dtype + elif isinstance(dtype, StructDtype): + if value is not None: + raise ValueError(f"Can not coerce {value} to StructDType") + else: + return NA, dtype + if isinstance(dtype, Decimal64Dtype): value = pa.scalar( value, type=pa.decimal128(dtype.precision, dtype.scale) diff --git a/python/cudf/cudf/tests/test_struct.py b/python/cudf/cudf/tests/test_struct.py index da2af1469c0..4f3bb9bda92 100644 --- a/python/cudf/cudf/tests/test_struct.py +++ b/python/cudf/cudf/tests/test_struct.py @@ -6,6 +6,7 @@ import pytest import cudf +from cudf.core.dtypes import StructDtype from cudf.testing._utils import assert_eq @@ -98,3 +99,23 @@ def test_serialize_struct_dtype(fields): def test_struct_getitem(series, expected): sr = cudf.Series(series) assert sr[0] == expected + + +@pytest.mark.parametrize( + "data", + [ + {"a": 1, "b": "rapids", "c": [1, 2, 3, 4]}, + {"a": 1, "b": "rapids", "c": [1, 2, 3, 4], "d": cudf.NA}, + {"a": "Hello"}, + {"b": [], "c": [1, 2, 3]}, + ], +) +def test_struct_scalar_host_construction(data): + slr = cudf.Scalar(data) + assert slr.value == data + assert list(slr.device_value.value.values()) == list(data.values()) + + +def test_struct_scalar_null(): + slr = cudf.Scalar(cudf.NA, dtype=StructDtype) + assert slr.device_value.value is cudf.NA