Skip to content

Commit

Permalink
Fix DatetimeIndex & TimedeltaIndex constructors (rapidsai#11342)
Browse files Browse the repository at this point in the history
Closes rapidsai#11335 

This PR fixes an issue in the `DatetimeIndex` and `TimedeltaIndex` constructors where the underlying column could remain a numeric or string column instead of being converted to a `DatetimeColumn` or `TimedeltaColumn`, respectively. That mismatch is the root cause of the errors seen in some downstream API calls.

Authors:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: rapidsai#11342
  • Loading branch information
galipremsagar authored Jul 25, 2022
1 parent a652ca9 commit 2d214ea
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 16 deletions.
32 changes: 16 additions & 16 deletions python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -1783,15 +1783,12 @@ def __init__(
elif dtype not in valid_dtypes:
raise TypeError("Invalid dtype")

if copy:
data = column.as_column(data).copy()
kwargs = _setdefault_name(data, name=name)
if isinstance(data, np.ndarray) and data.dtype.kind == "M":
data = column.as_column(data)
elif isinstance(data, pd.DatetimeIndex):
data = column.as_column(data.values)
elif isinstance(data, (list, tuple)):
data = column.as_column(np.array(data, dtype=dtype))
data = column.as_column(data, dtype=dtype)

if copy:
data = data.copy()

super().__init__(data, **kwargs)

@property # type: ignore
Expand Down Expand Up @@ -2263,15 +2260,18 @@ def __init__(
"dtype parameter is supported"
)

if copy:
data = column.as_column(data).copy()
valid_dtypes = tuple(
f"timedelta64[{res}]" for res in ("s", "ms", "us", "ns")
)
if dtype not in valid_dtypes:
raise TypeError("Invalid dtype")

kwargs = _setdefault_name(data, name=name)
if isinstance(data, np.ndarray) and data.dtype.kind == "m":
data = column.as_column(data)
elif isinstance(data, pd.TimedeltaIndex):
data = column.as_column(data.values)
elif isinstance(data, (list, tuple)):
data = column.as_column(np.array(data, dtype=dtype))
data = column.as_column(data, dtype=dtype)

if copy:
data = data.copy()

super().__init__(data, **kwargs)

@_cudf_nvtx_annotate
Expand Down
28 changes: 28 additions & 0 deletions python/cudf/cudf/tests/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -2007,3 +2007,31 @@ def test_last(idx, offset):
got = g.last(offset=offset)

assert_eq(expect, got)


@pytest.mark.parametrize(
    "data",
    [
        [
            "2020-01-31",
            "2020-02-15",
            "2020-02-29",
            "2020-03-15",
            "2020-03-31",
            "2020-04-15",
            "2020-04-30",
        ],
        [43534, 43543, 37897, 2000],
    ],
)
@pytest.mark.parametrize("dtype", [None, "datetime64[ns]"])
def test_datetime_constructor(data, dtype):
    """Compare ``cudf.DatetimeIndex`` construction against pandas.

    Each parametrized ``data`` list (date strings or plain integers) is
    passed to both libraries' ``DatetimeIndex`` constructors, first as
    the raw list and then wrapped in the library's own ``Series``, with
    ``dtype`` either omitted (``None``) or ``"datetime64[ns]"``.
    """
    # Case 1: the raw Python list is handed straight to the constructor.
    pandas_from_list = pd.DatetimeIndex(data=data, dtype=dtype)
    cudf_from_list = cudf.DatetimeIndex(data=data, dtype=dtype)
    assert_eq(pandas_from_list, cudf_from_list)

    # Case 2: the same values wrapped in each library's Series type.
    pandas_from_series = pd.DatetimeIndex(data=pd.Series(data), dtype=dtype)
    cudf_from_series = cudf.DatetimeIndex(data=cudf.Series(data), dtype=dtype)
    assert_eq(pandas_from_series, cudf_from_series)
14 changes: 14 additions & 0 deletions python/cudf/cudf/tests/test_timedelta.py
Original file line number Diff line number Diff line change
Expand Up @@ -1389,3 +1389,17 @@ def test_create_TimedeltaIndex(dtype, name):
)
pdi = gdi.to_pandas()
assert_eq(pdi, gdi)


@pytest.mark.parametrize("data", [[43534, 43543, 37897, 2000]])
@pytest.mark.parametrize("dtype", ["timedelta64[ns]"])
def test_timedelta_constructor(data, dtype):
    """Compare ``cudf.TimedeltaIndex`` construction against pandas.

    The integer ``data`` list is passed to both libraries'
    ``TimedeltaIndex`` constructors with an explicit ``dtype``, first as
    the raw list and then wrapped in the library's own ``Series``.
    """
    # Case 1: the raw Python list is handed straight to the constructor.
    pandas_from_list = pd.TimedeltaIndex(data=data, dtype=dtype)
    cudf_from_list = cudf.TimedeltaIndex(data=data, dtype=dtype)
    assert_eq(pandas_from_list, cudf_from_list)

    # Case 2: the same values wrapped in each library's Series type.
    pandas_from_series = pd.TimedeltaIndex(data=pd.Series(data), dtype=dtype)
    cudf_from_series = cudf.TimedeltaIndex(data=cudf.Series(data), dtype=dtype)
    assert_eq(pandas_from_series, cudf_from_series)

0 comments on commit 2d214ea

Please sign in to comment.