diff --git a/docs/cudf/source/api_docs/index_objects.rst b/docs/cudf/source/api_docs/index_objects.rst
index 0a6e3c169f0..9163440b23c 100644
--- a/docs/cudf/source/api_docs/index_objects.rst
+++ b/docs/cudf/source/api_docs/index_objects.rst
@@ -283,6 +283,7 @@ Time-specific operations
    DatetimeIndex.round
    DatetimeIndex.ceil
    DatetimeIndex.floor
+   DatetimeIndex.tz_localize

 Conversion
 ~~~~~~~~~~
diff --git a/docs/cudf/source/api_docs/series.rst b/docs/cudf/source/api_docs/series.rst
index 9cd0770431c..b38ef3e382c 100644
--- a/docs/cudf/source/api_docs/series.rst
+++ b/docs/cudf/source/api_docs/series.rst
@@ -295,6 +295,7 @@ Datetime methods
    round
    floor
    ceil
+   tz_localize


 Timedelta properties
diff --git a/python/cudf/cudf/_lib/column.pyx b/python/cudf/cudf/_lib/column.pyx
index 2a163a795eb..428db210532 100644
--- a/python/cudf/cudf/_lib/column.pyx
+++ b/python/cudf/cudf/_lib/column.pyx
@@ -8,7 +8,7 @@ import rmm

 import cudf
 import cudf._lib as libcudf
-from cudf.api.types import is_categorical_dtype
+from cudf.api.types import is_categorical_dtype, is_datetime64tz_dtype
 from cudf.core.buffer import (
     Buffer,
     CopyOnWriteBuffer,
@@ -16,7 +16,7 @@ from cudf.core.buffer import (
     acquire_spill_lock,
     as_buffer,
 )
-
+from cudf.utils.dtypes import _get_base_dtype
 from cpython.buffer cimport PyObject_CheckBuffer
 from libc.stdint cimport uintptr_t
 from libcpp.memory cimport make_unique, unique_ptr
@@ -313,9 +313,13 @@ cdef class Column:
     cdef mutable_column_view mutable_view(self) except *:
         if is_categorical_dtype(self.dtype):
             col = self.base_children[0]
+            data_dtype = col.dtype
+        elif is_datetime64tz_dtype(self.dtype):
+            col = self
+            data_dtype = _get_base_dtype(col.dtype)
         else:
             col = self
-        data_dtype = col.dtype
+            data_dtype = col.dtype

         cdef libcudf_types.data_type dtype = dtype_to_data_type(data_dtype)
         cdef libcudf_types.size_type offset = self.offset
@@ -373,9 +377,12 @@ cdef class Column:
         if is_categorical_dtype(self.dtype):
             col = self.base_children[0]
             data_dtype = col.dtype
+        elif is_datetime64tz_dtype(self.dtype):
+            col = self
+            data_dtype = _get_base_dtype(col.dtype)
         else:
             col = self
-        data_dtype = self.dtype
+            data_dtype = col.dtype

         cdef libcudf_types.data_type dtype = dtype_to_data_type(data_dtype)
         cdef libcudf_types.size_type offset = self.offset
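The `_get_base_dtype` helper that these view changes import is added to `cudf/utils/dtypes.py` at the end of this diff. As a rough illustration of what it does (the zone below is only an example), libcudf never sees the timezone itself, only the naive base dtype:

import numpy as np
import pandas as pd

tz_dtype = pd.DatetimeTZDtype("ns", "America/New_York")
# libcudf is timezone-unaware, so a timezone-aware column is handed to it
# through its base (naive) datetime64 dtype with the same resolution.
base_dtype = np.dtype(f"<M8[{tz_dtype.unit}]")
assert base_dtype == np.dtype("datetime64[ns]")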
""" try: # like zoneinfo, we first look in TZPATH - return _find_and_read_tzfile_tzpath(zone_name) + tz_table = _find_and_read_tzfile_tzpath(zone_name) except zoneinfo.ZoneInfoNotFoundError: # if that fails, we fall back to using `tzdata` - return _find_and_read_tzfile_tzdata(zone_name) + tz_table = _find_and_read_tzfile_tzdata(zone_name) + return tz_table def _find_and_read_tzfile_tzpath(zone_name): @@ -67,5 +78,143 @@ def _find_and_read_tzfile_tzdata(zone_name): def _read_tzfile_as_frame(tzdir, zone_name): - dt, offsets = build_timezone_transition_table(tzdir, zone_name) - return DataFrame._from_columns([dt, offsets], ["dt", "offsets"]) + transition_times_and_offsets = make_timezone_transition_table( + tzdir, zone_name + ) + + if not transition_times_and_offsets: + # this happens for UTC-like zones + min_date = np.int64(np.iinfo("int64").min + 1).astype("M8[s]") + transition_times_and_offsets = as_column([min_date]), as_column( + [np.timedelta64(0, "s")] + ) + + return DataFrame._from_columns( + transition_times_and_offsets, ["transition_times", "offsets"] + ) + + +def _find_ambiguous_and_nonexistent( + data: DatetimeColumn, zone_name: str +) -> Tuple: + """ + Recognize ambiguous and nonexistent timestamps for the given timezone. + + Returns a tuple of columns, both of "bool" dtype and of the same + size as `data`, that respectively indicate ambiguous and + nonexistent timestamps in `data` with the value `True`. + + Ambiguous and/or nonexistent timestamps are only possible if any + transitions occur in the time zone database for the given timezone. + If no transitions occur, the tuple `(False, False)` is returned. + """ + tz_data_for_zone = get_tz_data(zone_name) + transition_times = tz_data_for_zone["transition_times"] + offsets = tz_data_for_zone["offsets"].astype( + f"timedelta64[{data._time_unit}]" + ) + + if len(offsets) == 1: # no transitions + return False, False + + transition_times, offsets, old_offsets = ( + transition_times[1:]._column, + offsets[1:]._column, + offsets[:-1]._column, + ) + + # Assume we have two clocks at the moment of transition: + # - Clock 1 is turned forward or backwards correctly + # - Clock 2 makes no changes + clock_1 = transition_times + offsets + clock_2 = transition_times + old_offsets + + # At the start of an ambiguous time period, Clock 1 (which has + # been turned back) reads less than Clock 2: + cond = clock_1 < clock_2 + ambiguous_begin = clock_1.apply_boolean_mask(cond) + + # The end of an ambiguous time period is what Clock 2 reads at + # the moment of transition: + ambiguous_end = clock_2.apply_boolean_mask(cond) + ambiguous = label_bins( + data, + left_edges=ambiguous_begin, + left_inclusive=True, + right_edges=ambiguous_end, + right_inclusive=False, + ).notnull() + + # At the start of a non-existent time period, Clock 2 reads less + # than Clock 1 (which has been turned forward): + cond = clock_1 > clock_2 + nonexistent_begin = clock_2.apply_boolean_mask(cond) + + # The end of the non-existent time period is what Clock 1 reads + # at the moment of transition: + nonexistent_end = clock_1.apply_boolean_mask(cond) + nonexistent = label_bins( + data, + left_edges=nonexistent_begin, + left_inclusive=True, + right_edges=nonexistent_end, + right_inclusive=False, + ).notnull() + + return ambiguous, nonexistent + + +def localize( + data: DatetimeColumn, zone_name: str, ambiguous, nonexistent +) -> DatetimeTZColumn: + if ambiguous != "NaT": + raise NotImplementedError( + "Only ambiguous='NaT' is currently supported" + ) + if nonexistent != "NaT": + 
diff --git a/python/cudf/cudf/core/column/__init__.py b/python/cudf/cudf/core/column/__init__.py
index 96e2a7554cf..aba4ded4f9d 100644
--- a/python/cudf/cudf/core/column/__init__.py
+++ b/python/cudf/cudf/core/column/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2022, NVIDIA CORPORATION.
+# Copyright (c) 2020-2023, NVIDIA CORPORATION.

 """
 isort: skip_file
@@ -23,6 +23,7 @@
     serialize_columns,
 )
 from cudf.core.column.datetime import DatetimeColumn  # noqa: F401
+from cudf.core.column.datetime import DatetimeTZColumn  # noqa: F401
 from cudf.core.column.lists import ListColumn  # noqa: F401
 from cudf.core.column.numerical import NumericalColumn  # noqa: F401
 from cudf.core.column.string import StringColumn  # noqa: F401
diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py
index 4575f57d565..042a1060fae 100644
--- a/python/cudf/cudf/core/column/column.py
+++ b/python/cudf/cudf/core/column/column.py
@@ -1554,6 +1554,17 @@ def build_column(
             offset=offset,
             null_count=null_count,
         )
+    elif is_datetime64tz_dtype(dtype):
+        if data is None:
+            raise TypeError("Must specify data buffer")
+        return cudf.core.column.datetime.DatetimeTZColumn(
+            data=data,
+            dtype=dtype,
+            mask=mask,
+            size=size,
+            offset=offset,
+            null_count=null_count,
+        )
     elif dtype.type is np.timedelta64:
         if data is None:
             raise TypeError("Must specify data buffer")
@@ -2093,9 +2104,7 @@ def as_column(
                 data = _make_copy_replacing_NaT_with_null(data)
                 mask = data.mask

-            data = cudf.core.column.datetime.DatetimeColumn(
-                data=buffer, mask=mask, dtype=arbitrary.dtype
-            )
+            data = build_column(data=buffer, mask=mask, dtype=arbitrary.dtype)
         elif arb_dtype.kind == "m":

             time_unit = get_time_unit(arbitrary)
@@ -2243,8 +2252,8 @@ def as_column(
                     raise TypeError
                 if is_datetime64tz_dtype(dtype):
                     raise NotImplementedError(
-                        "cuDF does not yet support "
-                        "timezone-aware datetimes"
+                        "Use `tz_localize()` to construct "
+                        "timezone aware data."
                     )
                 if is_list_dtype(dtype):
                     data = pa.array(arbitrary)
diff --git a/python/cudf/cudf/core/column/datetime.py b/python/cudf/cudf/core/column/datetime.py
index 14aa7bdd84b..4c65a631adc 100644
--- a/python/cudf/cudf/core/column/datetime.py
+++ b/python/cudf/cudf/core/column/datetime.py
@@ -10,6 +10,7 @@

 import numpy as np
 import pandas as pd
+import pyarrow as pa

 import cudf
 from cudf import _lib as libcudf
@@ -20,10 +21,16 @@
     DtypeObj,
     ScalarLike,
 )
-from cudf.api.types import is_datetime64_dtype, is_scalar, is_timedelta64_dtype
+from cudf.api.types import (
+    is_datetime64_dtype,
+    is_datetime64tz_dtype,
+    is_scalar,
+    is_timedelta64_dtype,
+)
 from cudf.core.buffer import Buffer, cuda_array_interface_wrapper
 from cudf.core.column import ColumnBase, as_column, column, string
 from cudf.core.column.timedelta import _unit_to_nanoseconds_conversion
+from cudf.utils.dtypes import _get_base_dtype
 from cudf.utils.utils import _fillna_natwise

 _guess_datetime_format = pd.core.tools.datetimes.guess_datetime_format
@@ -517,6 +524,63 @@ def can_cast_safely(self, to_dtype: Dtype) -> bool:
         else:
             return False

+    def _with_type_metadata(self, dtype):
+        if is_datetime64tz_dtype(dtype):
+            return DatetimeTZColumn(
+                data=self.base_data,
+                dtype=dtype,
+                mask=self.base_mask,
+                size=self.size,
+                offset=self.offset,
+                null_count=self.null_count,
+            )
+        return self
+
+
+class DatetimeTZColumn(DatetimeColumn):
+    def __init__(
+        self,
+        data: Buffer,
+        dtype: pd.DatetimeTZDtype,
+        mask: Buffer = None,
+        size: int = None,
+        offset: int = 0,
+        null_count: int = None,
+    ):
+        super().__init__(
+            data=data,
+            dtype=_get_base_dtype(dtype),
+            mask=mask,
+            size=size,
+            offset=offset,
+            null_count=null_count,
+        )
+        self._dtype = dtype
+
+    def to_pandas(
+        self, index: pd.Index = None, nullable: bool = False, **kwargs
+    ) -> "cudf.Series":
+        return self._local_time.to_pandas().dt.tz_localize(
+            self.dtype.tz, ambiguous="NaT", nonexistent="NaT"
+        )
+
+    def to_arrow(self):
+        return pa.compute.assume_timezone(
+            self._local_time.to_arrow(), str(self.dtype.tz)
+        )
+
+    @property
+    def _local_time(self):
+        """Return the local time as naive timestamps."""
+        from cudf.core._internals.timezones import utc_to_local
+
+        return utc_to_local(self, str(self.dtype.tz))
+
+    def as_string_column(
+        self, dtype: Dtype, format=None, **kwargs
+    ) -> "cudf.core.column.StringColumn":
+        return self._local_time.as_string_column(dtype, format, **kwargs)
+

 def infer_format(element: str, **kwargs) -> str:
     """
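The key design point in `DatetimeTZColumn` above is that the underlying buffer always holds UTC timestamps; wall-clock values are only derived through `_local_time` when exporting. A small usage sketch of the intended behaviour, assuming a cudf build with this patch applied:

import cudf

s = cudf.to_datetime(
    cudf.Series(["2018-03-01 09:00:00", "2018-03-02 09:00:00"])
)
aware = s.dt.tz_localize("America/New_York")

aware.to_pandas()  # wall time 09:00 with the -05:00 offset re-attached
aware.to_arrow()   # pyarrow timestamp[ns, tz=America/New_York]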
diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py
index c6aba0d360a..13f6843b1cf 100644
--- a/python/cudf/cudf/core/dataframe.py
+++ b/python/cudf/cudf/core/dataframe.py
@@ -1697,20 +1697,10 @@ def _concat(
                 out = out.set_index(
                     cudf.core.index.as_index(out.index._values)
                 )
-
-        # Reassign precision for decimal cols & type schema for struct cols
         for name, col in out._data.items():
-            if isinstance(
-                col,
-                (
-                    cudf.core.column.DecimalBaseColumn,
-                    cudf.core.column.StructColumn,
-                    cudf.core.column.ListColumn,
-                ),
-            ):
-                out._data[name] = col._with_type_metadata(
-                    tables[0]._data[name].dtype
-                )
+            out._data[name] = col._with_type_metadata(
+                tables[0]._data[name].dtype
+            )

         # Reassign index and column names
         if objs[0]._data.multiindex:
@@ -3805,10 +3795,7 @@ def transpose(self):
                 )
                 for codes in result_columns
             ]
-        elif isinstance(
-            source_dtype,
-            (cudf.ListDtype, cudf.StructDtype, cudf.core.dtypes.DecimalDtype),
-        ):
+        else:
             result_columns = [
                 result_column._with_type_metadata(source_dtype)
                 for result_column in result_columns
diff --git a/python/cudf/cudf/core/dtypes.py b/python/cudf/cudf/core/dtypes.py
index d80557355ff..406708fd58b 100644
--- a/python/cudf/cudf/core/dtypes.py
+++ b/python/cudf/cudf/core/dtypes.py
@@ -71,6 +71,8 @@ def dtype(arbitrary):
         return np.dtype("object")
     elif isinstance(pd_dtype, pd.IntervalDtype):
         return cudf.IntervalDtype.from_pandas(pd_dtype)
+    elif isinstance(pd_dtype, pd.DatetimeTZDtype):
+        return pd_dtype
     else:
         raise TypeError(
             f"Cannot interpret {arbitrary} as a valid cuDF dtype"
diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py
index d8b9ee4d006..230a5054a00 100644
--- a/python/cudf/cudf/core/frame.py
+++ b/python/cudf/cudf/core/frame.py
@@ -1125,7 +1125,12 @@ def from_arrow(cls, data):
                 result[name] = result[name].as_string_column(cudf.dtype("str"))
             elif name in data.column_names and isinstance(
                 data[name].type,
-                (pa.StructType, pa.ListType, pa.Decimal128Type),
+                (
+                    pa.StructType,
+                    pa.ListType,
+                    pa.Decimal128Type,
+                    pa.TimestampType,
+                ),
             ):
                 # In case of struct column, libcudf is not aware of names of
                 # struct fields, hence renaming the struct fields is
@@ -1138,6 +1143,9 @@ def from_arrow(cls, data):
                 # In case of list column, there is a possibility of nested
                 # list columns to have struct or decimal columns inside them.

+                # Datetimes ("timestamps") may need timezone metadata
+                # attached to them, as libcudf is timezone-unaware
+
                 # All of these cases are handled by calling the
                 # _with_type_metadata method on the column.
                 result[name] = result[name]._with_type_metadata(
diff --git a/python/cudf/cudf/core/index.py b/python/cudf/cudf/core/index.py
index d1408fec160..281290e1788 100644
--- a/python/cudf/cudf/core/index.py
+++ b/python/cudf/cudf/core/index.py
@@ -2364,7 +2364,13 @@ def isocalendar(self):

     @_cudf_nvtx_annotate
     def to_pandas(self, nullable=False):
-        nanos = self._values.astype("datetime64[ns]")
+        # TODO: no need to convert to nanos with Pandas 2.x
+        if isinstance(self.dtype, pd.DatetimeTZDtype):
+            nanos = self._values.astype(
+                pd.DatetimeTZDtype("ns", self.dtype.tz)
+            )
+        else:
+            nanos = self._values.astype("datetime64[ns]")
         return pd.DatetimeIndex(nanos.to_pandas(), name=self.name)

     @_cudf_nvtx_annotate
@@ -2490,6 +2496,57 @@ def round(self, freq):

         return self.__class__._from_data({self.name: out_column})

+    def tz_localize(self, tz, ambiguous="NaT", nonexistent="NaT"):
+        """
+        Localize timezone-naive data to timezone-aware data.
+
+        Parameters
+        ----------
+        tz : str
+            Timezone to convert timestamps to.
+
+        Returns
+        -------
+        DatetimeIndex containing timezone aware timestamps.
+
+        Examples
+        --------
+        >>> import cudf
+        >>> import pandas as pd
+        >>> tz_naive = cudf.date_range('2018-03-01 09:00', periods=3, freq='D')
+        >>> tz_aware = tz_naive.tz_localize("America/New_York")
+        >>> tz_aware
+        DatetimeIndex(['2018-03-01 09:00:00-05:00',
+                       '2018-03-02 09:00:00-05:00',
+                       '2018-03-03 09:00:00-05:00'],
+                      dtype='datetime64[ns, America/New_York]')
+
+        Ambiguous or nonexistent datetimes are converted to NaT.
+
+        >>> s = cudf.to_datetime(cudf.Series(['2018-10-28 01:20:00',
+        ...                                   '2018-10-28 02:36:00',
+        ...                                   '2018-10-28 03:46:00']))
+        >>> s.dt.tz_localize("CET")
+        0    2018-10-28 01:20:00.000000000
+        1                             <NA>
+        2    2018-10-28 03:46:00.000000000
+        dtype: datetime64[ns, CET]
+
+        Notes
+        -----
+        'NaT' is currently the only supported option for the
+        ``ambiguous`` and ``nonexistent`` arguments. Any
+        ambiguous or nonexistent timestamps are converted
+        to 'NaT'.
+        """
+        from cudf.core._internals.timezones import localize
+
+        if tz is None:
+            result_col = self._column._local_time
+        else:
+            result_col = localize(self._column, tz, ambiguous, nonexistent)
+        return DatetimeIndex._from_data({self.name: result_col})
+

 class TimedeltaIndex(GenericIndex):
     """
+ """ + from cudf.core._internals.timezones import localize + + if tz is None: + result_col = self._column._local_time + else: + result_col = localize(self._column, tz, ambiguous, nonexistent) + return DatetimeIndex._from_data({self.name: result_col}) + class TimedeltaIndex(GenericIndex): """ diff --git a/python/cudf/cudf/core/series.py b/python/cudf/cudf/core/series.py index 9e07d135926..57a3653edf1 100644 --- a/python/cudf/cudf/core/series.py +++ b/python/cudf/cudf/core/series.py @@ -60,7 +60,7 @@ from cudf.core.column.struct import StructMethods from cudf.core.column_accessor import ColumnAccessor from cudf.core.groupby.groupby import SeriesGroupBy, groupby_doc_template -from cudf.core.index import BaseIndex, RangeIndex, as_index +from cudf.core.index import BaseIndex, DatetimeIndex, RangeIndex, as_index from cudf.core.indexed_frame import ( IndexedFrame, _FrameIndexer, @@ -1541,15 +1541,7 @@ def _concat(cls, objs, axis=0, index=True): col = concat_columns([o._column for o in objs]) - # Reassign precision for decimal cols & type schema for struct cols - if isinstance( - col, - ( - cudf.core.column.DecimalBaseColumn, - cudf.core.column.StructColumn, - cudf.core.column.ListColumn, - ), - ): + if len(objs): col = col._with_type_metadata(objs[0].dtype) return cls(data=col, index=index, name=name) @@ -4602,6 +4594,21 @@ def strftime(self, date_format, *args, **kwargs): data=str_col, index=self.series._index, name=self.series.name ) + @copy_docstring(DatetimeIndex.tz_localize) + def tz_localize(self, tz, ambiguous="NaT", nonexistent="NaT"): + from cudf.core._internals.timezones import localize + + if tz is None: + result_col = self.series._column._local_time + else: + result_col = localize( + self.series._column, tz, ambiguous, nonexistent + ) + return Series._from_data( + data={self.series.name: result_col}, + index=self.series._index, + ) + class TimedeltaProperties: """ diff --git a/python/cudf/cudf/tests/indexes/datetime/test_time_specific.py b/python/cudf/cudf/tests/indexes/datetime/test_time_specific.py index 03d0d3d4602..7a3fcc25033 100644 --- a/python/cudf/cudf/tests/indexes/datetime/test_time_specific.py +++ b/python/cudf/cudf/tests/indexes/datetime/test_time_specific.py @@ -1 +1,16 @@ # Copyright (c) 2022-2023, NVIDIA CORPORATION. +import pandas as pd + +import cudf +from cudf.testing._utils import assert_eq + + +def test_tz_localize(): + pidx = pd.date_range("2001-01-01", "2001-01-02", freq="1s") + pidx = pidx.astype(" bool: raise NotImplementedError(f"Unsupported dtype: {dtype}") +def _get_base_dtype(dtype: DtypeObj) -> DtypeObj: + # TODO: replace the use of this function with just `dtype.base` + # when Pandas 2.1.0 is the minimum version we support: + # https://github.com/pandas-dev/pandas/pull/52706 + if isinstance(dtype, pd.DatetimeTZDtype): + return np.dtype(f"