diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 20f302f7e59..4bf4b2b87f2 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -40,7 +40,12 @@ from cudf._typing import BinaryOperand, ColumnLike, Dtype, ScalarLike from cudf.core.abc import Serializable from cudf.core.buffer import Buffer -from cudf.core.dtypes import CategoricalDtype, IntervalDtype +from cudf.core.dtypes import ( + CategoricalDtype, + IntervalDtype, + ListDtype, + StructDtype, +) from cudf.utils import ioutils, utils from cudf.utils.dtypes import ( check_cast_unsupported_dtype, @@ -291,8 +296,7 @@ def from_arrow(cls, array: pa.Array) -> ColumnBase: "None" ] - if isinstance(result.dtype, cudf.Decimal64Dtype): - result.dtype.precision = array.type.precision + result = _copy_type_metadata_from_arrow(array, result) return result def _get_mask_as_column(self) -> ColumnBase: @@ -2230,6 +2234,60 @@ def full(size: int, fill_value: ScalarLike, dtype: Dtype = None) -> ColumnBase: return ColumnBase.from_scalar(cudf.Scalar(fill_value, dtype), size) +def _copy_type_metadata_from_arrow( + arrow_array: pa.array, cudf_column: ColumnBase +) -> ColumnBase: + """ + Similar to `Column._copy_type_metadata`, except copies type metadata + from arrow array into a cudf column. Recursive for every level. + * When `arrow_array` is struct type and `cudf_column` is StructDtype, copy + field names. + * When `arrow_array` is decimal type and `cudf_column` is + Decimal64Dtype, copy precisions. + """ + if pa.types.is_decimal(arrow_array.type) and isinstance( + cudf_column, cudf.core.column.DecimalColumn + ): + cudf_column.dtype.precision = arrow_array.type.precision + elif pa.types.is_struct(arrow_array.type) and isinstance( + cudf_column, cudf.core.column.StructColumn + ): + base_children = tuple( + _copy_type_metadata_from_arrow(arrow_array.field(i), col_child) + for i, col_child in enumerate(cudf_column.base_children) + ) + cudf_column.set_base_children(base_children) + return cudf.core.column.StructColumn( + data=None, + size=cudf_column.base_size, + dtype=StructDtype.from_arrow(arrow_array.type), + mask=cudf_column.base_mask, + offset=cudf_column.offset, + null_count=cudf_column.null_count, + children=base_children, + ) + elif pa.types.is_list(arrow_array.type) and isinstance( + cudf_column, cudf.core.column.ListColumn + ): + if arrow_array.values and cudf_column.base_children: + base_children = ( + cudf_column.base_children[0], + _copy_type_metadata_from_arrow( + arrow_array.values, cudf_column.base_children[1] + ), + ) + return cudf.core.column.ListColumn( + size=cudf_column.base_size, + dtype=ListDtype.from_arrow(arrow_array.type), + mask=cudf_column.base_mask, + offset=cudf_column.offset, + null_count=cudf_column.null_count, + children=base_children, + ) + + return cudf_column + + def _concat_columns(objs: "MutableSequence[ColumnBase]") -> ColumnBase: """Concatenate a sequence of columns.""" if len(objs) == 0: diff --git a/python/cudf/cudf/core/dtypes.py b/python/cudf/cudf/core/dtypes.py index 7db8ba15caa..f0b0dbba4a5 100644 --- a/python/cudf/cudf/core/dtypes.py +++ b/python/cudf/cudf/core/dtypes.py @@ -143,6 +143,8 @@ def __init__(self, element_type: Any) -> None: def element_type(self) -> Dtype: if isinstance(self._typ.value_type, pa.ListType): return ListDtype.from_arrow(self._typ.value_type) + elif isinstance(self._typ.value_type, pa.StructType): + return StructDtype.from_arrow(self._typ.value_type) else: return np.dtype(self._typ.value_type.to_pandas_dtype()).name @@ -176,10 +178,10 @@ def __eq__(self, other): return self._typ.equals(other._typ) def __repr__(self): - if isinstance(self.element_type, ListDtype): - return f"ListDtype({self.element_type.__repr__()})" + if isinstance(self.element_type, (ListDtype, StructDtype)): + return f"{type(self).__name__}({self.element_type.__repr__()})" else: - return f"ListDtype({self.element_type})" + return f"{type(self).__name__}({self.element_type})" def __hash__(self): return hash(self._typ) diff --git a/python/cudf/cudf/tests/test_dtypes.py b/python/cudf/cudf/tests/test_dtypes.py index b6e2aac0304..a5895caf49f 100644 --- a/python/cudf/cudf/tests/test_dtypes.py +++ b/python/cudf/cudf/tests/test_dtypes.py @@ -6,14 +6,16 @@ import pytest import cudf +from cudf.core.column import ColumnBase from cudf.core.dtypes import ( CategoricalDtype, Decimal64Dtype, + IntervalDtype, ListDtype, StructDtype, - IntervalDtype, ) from cudf.tests.utils import assert_eq +from cudf.utils.dtypes import np_to_pa_dtype def test_cdt_basic(): @@ -155,3 +157,103 @@ def test_interval_dtype_pyarrow_round_trip(fields, closed): expect = pa_array got = IntervalDtype.from_arrow(expect).to_arrow() assert expect.equals(got) + + +def assert_column_array_dtype_equal(column: ColumnBase, array: pa.array): + """ + In cudf, each column holds its dtype. And since column may have child + columns, child columns also holds their datatype. This method tests + that every level of `column` matches the type of the given `array` + recursively. + """ + + if isinstance(column.dtype, ListDtype): + return array.type.equals( + column.dtype.to_arrow() + ) and assert_column_array_dtype_equal( + column.base_children[1], array.values + ) + elif isinstance(column.dtype, StructDtype): + return array.type.equals(column.dtype.to_arrow()) and all( + [ + assert_column_array_dtype_equal(child, array.field(i)) + for i, child in enumerate(column.base_children) + ] + ) + elif isinstance(column.dtype, Decimal64Dtype): + return array.type.equals(column.dtype.to_arrow()) + elif isinstance(column.dtype, CategoricalDtype): + raise NotImplementedError() + else: + return array.type.equals(np_to_pa_dtype(column.dtype)) + + +@pytest.mark.parametrize( + "data", + [ + [[{"name": 123}]], + [ + [ + { + "IsLeapYear": False, + "data": {"Year": 1999, "Month": 7}, + "names": ["Mike", None], + }, + { + "IsLeapYear": True, + "data": {"Year": 2004, "Month": 12}, + "names": None, + }, + { + "IsLeapYear": False, + "data": {"Year": 1996, "Month": 2}, + "names": ["Rose", "Richard"], + }, + ] + ], + [ + [None, {"human?": True, "deets": {"weight": 2.4, "age": 27}}], + [ + {"human?": None, "deets": {"weight": 5.3, "age": 25}}, + {"human?": False, "deets": {"weight": 8.0, "age": 31}}, + {"human?": False, "deets": None}, + ], + [], + None, + [{"human?": None, "deets": {"weight": 6.9, "age": None}}], + ], + [ + { + "name": "var0", + "val": [ + {"name": "var1", "val": None, "type": "optional"} + ], + "type": "list", + }, + {}, + { + "name": "var2", + "val": [ + { + "name": "var3", + "val": {"field": 42}, + "type": "optional", + }, + { + "name": "var4", + "val": {"field": 3.14}, + "type": "optional", + }, + ], + "type": "list", + }, + None, + ], + ], +) +def test_lists_of_structs_dtype(data): + got = cudf.Series(data) + expected = pa.array(data) + + assert_column_array_dtype_equal(got._column, expected) + assert expected.equals(got._column.to_arrow())