From 89787f24b957408d051791ebe725d5eee30c4814 Mon Sep 17 00:00:00 2001 From: GALI PREM SAGAR Date: Fri, 25 Aug 2023 14:44:12 -0500 Subject: [PATCH] Handle `Interval` scalars when passed in list-like inputs to `cudf.Index` (#13956) closes #13952 This PR fixes an issue with `IntervalColumn` construction where we can utilize the existing type inference to create a pandas Series and then construct an `IntervalColumn` out of it since pyarrow is unable to read this kind of input correctly. Authors: - GALI PREM SAGAR (https://github.com/galipremsagar) Approvers: - Matthew Roeschke (https://github.com/mroeschke) URL: https://github.com/rapidsai/cudf/pull/13956 --- python/cudf/cudf/core/column/column.py | 10 +++++++--- python/cudf/cudf/tests/test_interval.py | 13 +++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 446f01ef419..eafcc18450d 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -2454,25 +2454,29 @@ def as_column( def _construct_array( arbitrary: Any, dtype: Optional[Dtype] -) -> Union[np.ndarray, cupy.ndarray]: +) -> Union[np.ndarray, cupy.ndarray, pd.api.extensions.ExtensionArray]: """ - Construct a CuPy or NumPy array from `arbitrary` + Construct a CuPy/NumPy/Pandas array from `arbitrary` """ try: dtype = dtype if dtype is None else cudf.dtype(dtype) arbitrary = cupy.asarray(arbitrary, dtype=dtype) except (TypeError, ValueError): native_dtype = dtype + inferred_dtype = None if ( dtype is None and not cudf._lib.scalar._is_null_host_scalar(arbitrary) - and infer_dtype(arbitrary, skipna=False) + and (inferred_dtype := infer_dtype(arbitrary, skipna=False)) in ( "mixed", "mixed-integer", ) ): native_dtype = "object" + if inferred_dtype == "interval": + # Only way to construct an Interval column. + return pd.array(arbitrary) arbitrary = np.asarray( arbitrary, dtype=native_dtype diff --git a/python/cudf/cudf/tests/test_interval.py b/python/cudf/cudf/tests/test_interval.py index 18454172289..f2e8f585a69 100644 --- a/python/cudf/cudf/tests/test_interval.py +++ b/python/cudf/cudf/tests/test_interval.py @@ -136,6 +136,19 @@ def test_create_interval_df(data1, data2, data3, data4, closed): assert_eq(expect_three, got_three) +def test_create_interval_index_from_list(): + interval_list = [ + np.nan, + pd.Interval(2.0, 3.0, closed="right"), + pd.Interval(3.0, 4.0, closed="right"), + ] + + expected = pd.Index(interval_list) + actual = cudf.Index(interval_list) + + assert_eq(expected, actual) + + def test_interval_index_unique(): interval_list = [ np.nan,