From cdc7c2767d00a6b60c9982af6395626d957cd840 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Fri, 29 Dec 2023 12:46:09 -0800 Subject: [PATCH] Fix nan_as_null not being respected when passing arrow object --- python/cudf/cudf/core/column/column.py | 13 +++++++++++-- python/cudf/cudf/tests/test_series.py | 10 ++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 296fd6a41b0..583ad9b8f79 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -24,6 +24,7 @@ import numpy as np import pandas as pd import pyarrow as pa +import pyarrow.compute as pc from numba import cuda from typing_extensions import Self @@ -2006,11 +2007,19 @@ def as_column( return col elif isinstance(arbitrary, (pa.Array, pa.ChunkedArray)): - if isinstance(arbitrary, pa.lib.HalfFloatArray): + if pa.types.is_float16(arbitrary.type): raise NotImplementedError( "Type casting from `float16` to `float32` is not " "yet supported in pyarrow, see: " - "https://issues.apache.org/jira/browse/ARROW-3802" + "https://github.com/apache/arrow/issues/20213" + ) + elif (nan_as_null is None or nan_as_null) and pa.types.is_floating( + arbitrary.type + ): + arbitrary = pc.if_else( + pc.is_nan(arbitrary), + pa.nulls(len(arbitrary), type=arbitrary.type), + arbitrary, ) col = ColumnBase.from_arrow(arbitrary) diff --git a/python/cudf/cudf/tests/test_series.py b/python/cudf/cudf/tests/test_series.py index 39da34fa89c..f05fb7c4849 100644 --- a/python/cudf/cudf/tests/test_series.py +++ b/python/cudf/cudf/tests/test_series.py @@ -2572,6 +2572,16 @@ def test_series_arrow_list_types_roundtrip(): cudf.from_pandas(pdf) +@pytest.mark.parametrize("klass", [cudf.Index, cudf.Series]) +@pytest.mark.parametrize( + "data", [pa.array([float("nan")]), pa.chunked_array([[float("nan")]])] +) +def test_nan_as_null_from_arrow_objects(klass, data): + result = klass(data, nan_as_null=True) + expected = klass(pa.array([None], type=pa.float64())) + assert_eq(result, expected) + + @pytest.mark.parametrize("reso", ["M", "ps"]) @pytest.mark.parametrize("typ", ["M", "m"]) def test_series_invalid_reso_dtype(reso, typ):