Skip to content

Commit

Permalink
Copy nested types upon construction (#8244)
Browse files Browse the repository at this point in the history
Closes #7561 

This PR makes sure upon constructing cudf object, nested types from the pyarrow array is copied to cudf object. This should handle arbitrary nesting of `Lists`, `Structs`. For decimal types, precision is copied from the array.

Authors:
  - Michael Wang (https://github.com/isVoid)
  - Keith Kraus (https://github.com/kkraus14)

Approvers:
  - Keith Kraus (https://github.com/kkraus14)

URL: #8244
  • Loading branch information
isVoid authored May 20, 2021
1 parent b553144 commit c7d0524
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 7 deletions.
64 changes: 61 additions & 3 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@
from cudf._typing import BinaryOperand, ColumnLike, Dtype, ScalarLike
from cudf.core.abc import Serializable
from cudf.core.buffer import Buffer
from cudf.core.dtypes import CategoricalDtype, IntervalDtype
from cudf.core.dtypes import (
CategoricalDtype,
IntervalDtype,
ListDtype,
StructDtype,
)
from cudf.utils import ioutils, utils
from cudf.utils.dtypes import (
check_cast_unsupported_dtype,
Expand Down Expand Up @@ -291,8 +296,7 @@ def from_arrow(cls, array: pa.Array) -> ColumnBase:
"None"
]

if isinstance(result.dtype, cudf.Decimal64Dtype):
result.dtype.precision = array.type.precision
result = _copy_type_metadata_from_arrow(array, result)
return result

def _get_mask_as_column(self) -> ColumnBase:
Expand Down Expand Up @@ -2230,6 +2234,60 @@ def full(size: int, fill_value: ScalarLike, dtype: Dtype = None) -> ColumnBase:
return ColumnBase.from_scalar(cudf.Scalar(fill_value, dtype), size)


def _copy_type_metadata_from_arrow(
arrow_array: pa.array, cudf_column: ColumnBase
) -> ColumnBase:
"""
Similar to `Column._copy_type_metadata`, except copies type metadata
from arrow array into a cudf column. Recursive for every level.
* When `arrow_array` is struct type and `cudf_column` is StructDtype, copy
field names.
* When `arrow_array` is decimal type and `cudf_column` is
Decimal64Dtype, copy precisions.
"""
if pa.types.is_decimal(arrow_array.type) and isinstance(
cudf_column, cudf.core.column.DecimalColumn
):
cudf_column.dtype.precision = arrow_array.type.precision
elif pa.types.is_struct(arrow_array.type) and isinstance(
cudf_column, cudf.core.column.StructColumn
):
base_children = tuple(
_copy_type_metadata_from_arrow(arrow_array.field(i), col_child)
for i, col_child in enumerate(cudf_column.base_children)
)
cudf_column.set_base_children(base_children)
return cudf.core.column.StructColumn(
data=None,
size=cudf_column.base_size,
dtype=StructDtype.from_arrow(arrow_array.type),
mask=cudf_column.base_mask,
offset=cudf_column.offset,
null_count=cudf_column.null_count,
children=base_children,
)
elif pa.types.is_list(arrow_array.type) and isinstance(
cudf_column, cudf.core.column.ListColumn
):
if arrow_array.values and cudf_column.base_children:
base_children = (
cudf_column.base_children[0],
_copy_type_metadata_from_arrow(
arrow_array.values, cudf_column.base_children[1]
),
)
return cudf.core.column.ListColumn(
size=cudf_column.base_size,
dtype=ListDtype.from_arrow(arrow_array.type),
mask=cudf_column.base_mask,
offset=cudf_column.offset,
null_count=cudf_column.null_count,
children=base_children,
)

return cudf_column


def _concat_columns(objs: "MutableSequence[ColumnBase]") -> ColumnBase:
"""Concatenate a sequence of columns."""
if len(objs) == 0:
Expand Down
8 changes: 5 additions & 3 deletions python/cudf/cudf/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ def __init__(self, element_type: Any) -> None:
def element_type(self) -> Dtype:
if isinstance(self._typ.value_type, pa.ListType):
return ListDtype.from_arrow(self._typ.value_type)
elif isinstance(self._typ.value_type, pa.StructType):
return StructDtype.from_arrow(self._typ.value_type)
else:
return np.dtype(self._typ.value_type.to_pandas_dtype()).name

Expand Down Expand Up @@ -176,10 +178,10 @@ def __eq__(self, other):
return self._typ.equals(other._typ)

def __repr__(self):
if isinstance(self.element_type, ListDtype):
return f"ListDtype({self.element_type.__repr__()})"
if isinstance(self.element_type, (ListDtype, StructDtype)):
return f"{type(self).__name__}({self.element_type.__repr__()})"
else:
return f"ListDtype({self.element_type})"
return f"{type(self).__name__}({self.element_type})"

def __hash__(self):
return hash(self._typ)
Expand Down
104 changes: 103 additions & 1 deletion python/cudf/cudf/tests/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@
import pytest

import cudf
from cudf.core.column import ColumnBase
from cudf.core.dtypes import (
CategoricalDtype,
Decimal64Dtype,
IntervalDtype,
ListDtype,
StructDtype,
IntervalDtype,
)
from cudf.tests.utils import assert_eq
from cudf.utils.dtypes import np_to_pa_dtype


def test_cdt_basic():
Expand Down Expand Up @@ -155,3 +157,103 @@ def test_interval_dtype_pyarrow_round_trip(fields, closed):
expect = pa_array
got = IntervalDtype.from_arrow(expect).to_arrow()
assert expect.equals(got)


def assert_column_array_dtype_equal(column: ColumnBase, array: pa.array):
"""
In cudf, each column holds its dtype. And since column may have child
columns, child columns also holds their datatype. This method tests
that every level of `column` matches the type of the given `array`
recursively.
"""

if isinstance(column.dtype, ListDtype):
return array.type.equals(
column.dtype.to_arrow()
) and assert_column_array_dtype_equal(
column.base_children[1], array.values
)
elif isinstance(column.dtype, StructDtype):
return array.type.equals(column.dtype.to_arrow()) and all(
[
assert_column_array_dtype_equal(child, array.field(i))
for i, child in enumerate(column.base_children)
]
)
elif isinstance(column.dtype, Decimal64Dtype):
return array.type.equals(column.dtype.to_arrow())
elif isinstance(column.dtype, CategoricalDtype):
raise NotImplementedError()
else:
return array.type.equals(np_to_pa_dtype(column.dtype))


@pytest.mark.parametrize(
"data",
[
[[{"name": 123}]],
[
[
{
"IsLeapYear": False,
"data": {"Year": 1999, "Month": 7},
"names": ["Mike", None],
},
{
"IsLeapYear": True,
"data": {"Year": 2004, "Month": 12},
"names": None,
},
{
"IsLeapYear": False,
"data": {"Year": 1996, "Month": 2},
"names": ["Rose", "Richard"],
},
]
],
[
[None, {"human?": True, "deets": {"weight": 2.4, "age": 27}}],
[
{"human?": None, "deets": {"weight": 5.3, "age": 25}},
{"human?": False, "deets": {"weight": 8.0, "age": 31}},
{"human?": False, "deets": None},
],
[],
None,
[{"human?": None, "deets": {"weight": 6.9, "age": None}}],
],
[
{
"name": "var0",
"val": [
{"name": "var1", "val": None, "type": "optional<struct>"}
],
"type": "list",
},
{},
{
"name": "var2",
"val": [
{
"name": "var3",
"val": {"field": 42},
"type": "optional<struct>",
},
{
"name": "var4",
"val": {"field": 3.14},
"type": "optional<struct>",
},
],
"type": "list",
},
None,
],
],
)
def test_lists_of_structs_dtype(data):
got = cudf.Series(data)
expected = pa.array(data)

assert_column_array_dtype_equal(got._column, expected)
assert expected.equals(got._column.to_arrow())

0 comments on commit c7d0524

Please sign in to comment.