Skip to content

Commit

Permalink
Handle Interval scalars when passed in list-like inputs to `cudf.In…
Browse files Browse the repository at this point in the history
…dex` (#13956)

closes #13952 

This PR fixes an issue with `IntervalColumn` construction where we can utilize the existing type inference to create a pandas Series and then construct an `IntervalColumn` out of it since pyarrow is unable to read this kind of input correctly.

Authors:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - Matthew Roeschke (https://github.com/mroeschke)

URL: #13956
  • Loading branch information
galipremsagar authored Aug 25, 2023
1 parent 6d10a82 commit 89787f2
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
10 changes: 7 additions & 3 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -2454,25 +2454,29 @@ def as_column(

def _construct_array(
arbitrary: Any, dtype: Optional[Dtype]
) -> Union[np.ndarray, cupy.ndarray]:
) -> Union[np.ndarray, cupy.ndarray, pd.api.extensions.ExtensionArray]:
"""
Construct a CuPy or NumPy array from `arbitrary`
Construct a CuPy/NumPy/Pandas array from `arbitrary`
"""
try:
dtype = dtype if dtype is None else cudf.dtype(dtype)
arbitrary = cupy.asarray(arbitrary, dtype=dtype)
except (TypeError, ValueError):
native_dtype = dtype
inferred_dtype = None
if (
dtype is None
and not cudf._lib.scalar._is_null_host_scalar(arbitrary)
and infer_dtype(arbitrary, skipna=False)
and (inferred_dtype := infer_dtype(arbitrary, skipna=False))
in (
"mixed",
"mixed-integer",
)
):
native_dtype = "object"
if inferred_dtype == "interval":
# Only way to construct an Interval column.
return pd.array(arbitrary)
arbitrary = np.asarray(
arbitrary,
dtype=native_dtype
Expand Down
13 changes: 13 additions & 0 deletions python/cudf/cudf/tests/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,19 @@ def test_create_interval_df(data1, data2, data3, data4, closed):
assert_eq(expect_three, got_three)


def test_create_interval_index_from_list():
interval_list = [
np.nan,
pd.Interval(2.0, 3.0, closed="right"),
pd.Interval(3.0, 4.0, closed="right"),
]

expected = pd.Index(interval_list)
actual = cudf.Index(interval_list)

assert_eq(expected, actual)


def test_interval_index_unique():
interval_list = [
np.nan,
Expand Down

0 comments on commit 89787f2

Please sign in to comment.