Skip to content

Commit

Permalink
Handle interval scalars in the form of a list
Browse files Browse the repository at this point in the history
  • Loading branch information
galipremsagar committed Aug 25, 2023
1 parent 384b33f commit f5e8039
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
10 changes: 7 additions & 3 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -2452,25 +2452,29 @@ def as_column(

def _construct_array(
arbitrary: Any, dtype: Optional[Dtype]
) -> Union[np.ndarray, cupy.ndarray]:
) -> Union[np.ndarray, cupy.ndarray, pd.api.extensions.ExtensionArray]:
"""
Construct a CuPy or NumPy array from `arbitrary`
Construct a CuPy/NumPy/Pandas array from `arbitrary`
"""
try:
dtype = dtype if dtype is None else cudf.dtype(dtype)
arbitrary = cupy.asarray(arbitrary, dtype=dtype)
except (TypeError, ValueError):
native_dtype = dtype
inferred_dtype = None
if (
dtype is None
and not cudf._lib.scalar._is_null_host_scalar(arbitrary)
and infer_dtype(arbitrary, skipna=False)
and (inferred_dtype := infer_dtype(arbitrary, skipna=False))
in (
"mixed",
"mixed-integer",
)
):
native_dtype = "object"
if inferred_dtype == "interval":
# Only way to construct an Interval column.
return pd.array(arbitrary)
arbitrary = np.asarray(
arbitrary,
dtype=native_dtype
Expand Down
15 changes: 14 additions & 1 deletion python/cudf/cudf/tests/test_interval.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
# Copyright (c) 2020-2023, NVIDIA CORPORATION.

import pandas as pd
import numpy as np
import pytest

import cudf
Expand Down Expand Up @@ -132,3 +133,15 @@ def test_create_interval_df(data1, data2, data3, data4, closed):
dtype="interval",
)
assert_eq(expect_three, got_three)


def test_create_interval_index_from_list():
    """Build an Index from a plain list of interval scalars in both libraries.

    The input list holds a leading NaN followed by two right-closed
    ``pd.Interval`` scalars. The pandas Index and the cudf Index built from
    that same list are compared with ``assert_eq``.
    """
    bounds = ((2.0, 3.0), (3.0, 4.0))
    data = [
        np.nan,
        *(pd.Interval(lo, hi, closed="right") for lo, hi in bounds),
    ]

    # pandas result first, then cudf, then compare (pandas, cudf).
    pandas_index = pd.Index(data)
    cudf_index = cudf.Index(data)

    assert_eq(pandas_index, cudf_index)

0 comments on commit f5e8039

Please sign in to comment.