Skip to content

Commit

Permalink
Allow empty, dtyped IntervalIndex input to keep dtype (rapidsai#69)
Browse files Browse the repository at this point in the history
Before, this would always default to int64.
  • Loading branch information
mroeschke authored Oct 12, 2023
1 parent 506e1a6 commit f1564a8
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
7 changes: 4 additions & 3 deletions python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -3170,13 +3170,14 @@ def __init__(
data = data
elif isinstance(data, pd.Series) and (is_interval_dtype(data.dtype)):
data = column.as_column(data, data.dtype)
elif isinstance(data, (pd._libs.interval.Interval, pd.IntervalIndex)):
elif isinstance(data, (pd.Interval, pd.IntervalIndex)):
data = column.as_column(
data,
dtype=dtype,
)
elif not data:
dtype = IntervalDtype("int64", closed)
elif len(data) == 0:
subtype = getattr(data, "dtype", "int64")
dtype = IntervalDtype(subtype, closed)
data = column.column_empty_like_same_mask(
column.as_column(data), dtype
)
Expand Down
7 changes: 7 additions & 0 deletions python/cudf/cudf/tests/indexes/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,3 +308,10 @@ def test_interval_range_floating(start, stop, freq, periods):
)
got = interval_range(start=start, end=stop, freq=freq, periods=periods)
assert_eq(expected, got)


def test_intervalindex_empty_typed_non_int():
data = np.array([], dtype="datetime64[ns]")
result = cudf.IntervalIndex(data)
expected = pd.IntervalIndex(data)
assert_eq(result, expected)

0 comments on commit f1564a8

Please sign in to comment.