Skip to content

Commit

Permalink
Clean up DatetimeIndex.__init__ constructor (#14774)
Browse files Browse the repository at this point in the history
Additionally adds some typing and remove validation done by `cudf.dtype` and add a unit test to ensure numpy dtype objects are accepted in the constructor

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #14774
  • Loading branch information
mroeschke authored Jan 19, 2024
1 parent eeee795 commit 2c1b949
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
22 changes: 10 additions & 12 deletions python/cudf/cudf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -2118,13 +2118,13 @@ def __init__(
data=None,
freq=None,
tz=None,
normalize=False,
normalize: bool = False,
closed=None,
ambiguous="raise",
dayfirst=False,
yearfirst=False,
ambiguous: Literal["raise"] = "raise",
dayfirst: bool = False,
yearfirst: bool = False,
dtype=None,
copy=False,
copy: bool = False,
name=None,
):
# we should be more strict on what we accept here but
Expand All @@ -2147,22 +2147,20 @@ def __init__(

self._freq = _validate_freq(freq)

valid_dtypes = tuple(
f"datetime64[{res}]" for res in ("s", "ms", "us", "ns")
)
if dtype is None:
# nanosecond default matches pandas
dtype = "datetime64[ns]"
elif dtype not in valid_dtypes:
raise TypeError("Invalid dtype")
dtype = cudf.dtype(dtype)
if dtype.kind != "M":
raise TypeError("dtype must be a datetime type")

kwargs = _setdefault_name(data, name=name)
name = _setdefault_name(data, name=name)["name"]
data = column.as_column(data, dtype=dtype)

if copy:
data = data.copy()

super().__init__(data, **kwargs)
super().__init__(data, name=name)

if self._freq is not None:
unique_vals = self.to_series().diff().unique()
Expand Down
8 changes: 8 additions & 0 deletions python/cudf/cudf/tests/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -2429,3 +2429,11 @@ def test_dateimeindex_from_noniso_string():
def test_to_datetime_errors_non_scalar_not_implemented(errors):
with pytest.raises(NotImplementedError):
cudf.to_datetime([1, ""], unit="s", errors=errors)


def test_datetimeindex_dtype_np_dtype():
dtype = np.dtype("datetime64[ns]")
data = [1]
gdti = cudf.DatetimeIndex(data, dtype=dtype)
pdti = pd.DatetimeIndex(data, dtype=dtype)
assert_eq(gdti, pdti)

0 comments on commit 2c1b949

Please sign in to comment.