Skip to content

Commit

Permalink
fix up to_arrow for datatype and add some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
brandon-b-miller committed May 30, 2024
1 parent f09afa1 commit bffe500
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 9 deletions.
31 changes: 22 additions & 9 deletions python/cudf/cudf/_lib/pylibcudf/interop.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -209,15 +209,28 @@ def to_arrow(cudf_object, metadata=None):


@to_arrow.register(DataType)
def _to_arrow_datatype(cudf_object, **kwargs):
    """
    Convert a pylibcudf DataType to a PyArrow type.

    Some types cannot be rebuilt from the DataType alone, so the missing
    information is supplied through keyword arguments:

    - decimal types (DECIMAL32, DECIMAL64, DECIMAL128) require
      ``precision``
    - STRUCT requires ``fields`` (forwarded to ``pa.struct``)
    - LIST requires ``value_type`` (forwarded to ``pa.list_``)

    Returns
    -------
    The PyArrow type for ``cudf_object``. For every other type id this is
    the result of looking the id up in ``ARROW_TO_PYLIBCUDF_TYPES``, which
    is ``None`` when the id is not a key of that mapping.

    Raises
    ------
    ValueError
        If the keyword argument required by the type is not provided.
    """
    tid = cudf_object.id()

    # The required arguments are compared against None rather than tested
    # for truthiness so that falsy-but-valid values (for example an empty
    # list of struct fields) are not reported as missing.
    if tid in {type_id.DECIMAL32, type_id.DECIMAL64, type_id.DECIMAL128}:
        if (precision := kwargs.get("precision")) is None:
            raise ValueError(
                "Precision must be provided for decimal types"
            )
        # no pa.decimal32 or pa.decimal64, so every decimal width maps to
        # decimal128; the DataType's scale is negated for Arrow.
        return pa.decimal128(precision, -cudf_object.scale())

    if tid == type_id.STRUCT:
        if (fields := kwargs.get("fields")) is None:
            raise ValueError(
                "Fields must be provided for struct types"
            )
        return pa.struct(fields)

    if tid == type_id.LIST:
        if (value_type := kwargs.get("value_type")) is None:
            raise ValueError(
                "Value type must be provided for list types"
            )
        return pa.list_(value_type)

    # NOTE(review): the mapping's name suggests it is keyed by Arrow types
    # rather than by type id; if so this lookup always yields None and an
    # inverse mapping should be used here -- confirm against its definition.
    return ARROW_TO_PYLIBCUDF_TYPES.get(tid)


@to_arrow.register(Table)
Expand Down
66 changes: 66 additions & 0 deletions python/cudf/cudf/pylibcudf_tests/test_interop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

import pyarrow as pa
import pytest

import cudf._lib.pylibcudf as plc


def test_list_dtype_roundtrip():
    """A LIST DataType converts back to Arrow only when given value_type."""
    arrow_list = pa.list_(pa.int32())
    converted = plc.interop.from_arrow(arrow_list)

    expected_plc = plc.types.DataType(plc.types.TypeId.LIST)
    assert converted == expected_plc

    # The element type is not supplied here, so the conversion must fail.
    with pytest.raises(ValueError):
        plc.interop.to_arrow(converted)

    element_type = arrow_list.value_type
    restored = plc.interop.to_arrow(converted, value_type=element_type)
    assert restored == arrow_list


def test_struct_dtype_roundtrip():
    """A STRUCT DataType converts back to Arrow only when given fields."""
    struct_type = pa.struct([("a", pa.int32()), ("b", pa.string())])
    plc_type = plc.interop.from_arrow(struct_type)

    assert plc_type == plc.types.DataType(plc.types.TypeId.STRUCT)

    # The fields are not supplied here, so the conversion must fail.
    with pytest.raises(ValueError):
        plc.interop.to_arrow(plc_type)

    # Pass an explicit list of pa.Field objects instead of the StructType
    # itself: pa.struct documents ``fields`` as an iterable of fields, and
    # this does not depend on StructType being iterable.
    fields = [struct_type.field(i) for i in range(struct_type.num_fields)]
    arrow_type = plc.interop.to_arrow(plc_type, fields=fields)
    assert arrow_type == struct_type


def test_decimal128_roundtrip():
    """A DECIMAL128 DataType converts back to Arrow only with a precision."""
    arrow_decimal = pa.decimal128(10, 2)
    converted = plc.interop.from_arrow(arrow_decimal)

    assert converted.id() == plc.types.TypeId.DECIMAL128

    # The precision is not supplied here, so the conversion must fail.
    with pytest.raises(ValueError):
        plc.interop.to_arrow(converted)

    digits = arrow_decimal.precision
    restored = plc.interop.to_arrow(converted, precision=digits)
    assert restored == arrow_decimal


@pytest.mark.parametrize(
    "data_type",
    [
        plc.types.DataType(tid)
        for tid in (
            plc.types.TypeId.DECIMAL32,
            plc.types.TypeId.DECIMAL64,
        )
    ],
)
def test_decimal_other(data_type):
    """DECIMAL32/DECIMAL64 DataTypes convert to pa.decimal128."""
    digits = 3

    # The precision is not supplied here, so the conversion must fail.
    with pytest.raises(ValueError):
        plc.interop.to_arrow(data_type)

    converted = plc.interop.to_arrow(data_type, precision=digits)
    assert converted == pa.decimal128(digits, 0)

0 comments on commit bffe500

Please sign in to comment.