From 2c1b94970959a98780a603f18c560e79f558094d Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Thu, 18 Jan 2024 18:25:28 -1000 Subject: [PATCH] Clean up `DatetimeIndex.__init__` constructor (#14774) Additionally adds some typing and remove validation done by `cudf.dtype` and add a unit test to ensure numpy dtype objects are accepted in the constructor Authors: - Matthew Roeschke (https://github.com/mroeschke) - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - Vyas Ramasubramani (https://github.com/vyasr) URL: https://github.com/rapidsai/cudf/pull/14774 --- python/cudf/cudf/core/index.py | 22 ++++++++++------------ python/cudf/cudf/tests/test_datetime.py | 8 ++++++++ 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/python/cudf/cudf/core/index.py b/python/cudf/cudf/core/index.py index 3e8f6bc2ccb..96643ef08d3 100644 --- a/python/cudf/cudf/core/index.py +++ b/python/cudf/cudf/core/index.py @@ -2118,13 +2118,13 @@ def __init__( data=None, freq=None, tz=None, - normalize=False, + normalize: bool = False, closed=None, - ambiguous="raise", - dayfirst=False, - yearfirst=False, + ambiguous: Literal["raise"] = "raise", + dayfirst: bool = False, + yearfirst: bool = False, dtype=None, - copy=False, + copy: bool = False, name=None, ): # we should be more strict on what we accept here but @@ -2147,22 +2147,20 @@ def __init__( self._freq = _validate_freq(freq) - valid_dtypes = tuple( - f"datetime64[{res}]" for res in ("s", "ms", "us", "ns") - ) if dtype is None: # nanosecond default matches pandas dtype = "datetime64[ns]" - elif dtype not in valid_dtypes: - raise TypeError("Invalid dtype") + dtype = cudf.dtype(dtype) + if dtype.kind != "M": + raise TypeError("dtype must be a datetime type") - kwargs = _setdefault_name(data, name=name) + name = _setdefault_name(data, name=name)["name"] data = column.as_column(data, dtype=dtype) if copy: data = data.copy() - super().__init__(data, **kwargs) + super().__init__(data, name=name) if self._freq is not None: unique_vals = self.to_series().diff().unique() diff --git a/python/cudf/cudf/tests/test_datetime.py b/python/cudf/cudf/tests/test_datetime.py index 22d452fdda5..2ea2885bc7b 100644 --- a/python/cudf/cudf/tests/test_datetime.py +++ b/python/cudf/cudf/tests/test_datetime.py @@ -2429,3 +2429,11 @@ def test_dateimeindex_from_noniso_string(): def test_to_datetime_errors_non_scalar_not_implemented(errors): with pytest.raises(NotImplementedError): cudf.to_datetime([1, ""], unit="s", errors=errors) + + +def test_datetimeindex_dtype_np_dtype(): + dtype = np.dtype("datetime64[ns]") + data = [1] + gdti = cudf.DatetimeIndex(data, dtype=dtype) + pdti = pd.DatetimeIndex(data, dtype=dtype) + assert_eq(gdti, pdti)