Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restrict the allowed pandas timezone objects in cudf #16013

Merged
merged 13 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion python/cudf/cudf/core/_internals/timezones.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,50 @@
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
from __future__ import annotations

import datetime
import os
import zoneinfo
from functools import lru_cache
from typing import TYPE_CHECKING, Literal

import numpy as np
import pandas as pd

import cudf
from cudf._lib.timezone import make_timezone_transition_table
from cudf.core.column.column import as_column

if TYPE_CHECKING:
from cudf.core.column.datetime import DatetimeColumn
from cudf.core.column.timedelta import TimeDeltaColumn


def get_compatible_timezone(dtype: pd.DatetimeTZDtype) -> pd.DatetimeTZDtype:
"""Convert dtype.tz object to zoneinfo object if possible."""
tz = dtype.tz
if isinstance(tz, zoneinfo.ZoneInfo):
return dtype
wence- marked this conversation as resolved.
Show resolved Hide resolved
if cudf.get_option("mode.pandas_compatible"):
raise NotImplementedError(
f"{tz} must be a zoneinfo.ZoneInfo object in pandas_compatible mode."
)
elif (tzname := getattr(tz, "zone", None)) is not None:
# pytz-like
key = tzname
elif (tz_file := getattr(tz, "_filename", None)) is not None:
# dateutil-like
key = tz_file.split("zoneinfo/")[-1]
elif isinstance(tz, datetime.tzinfo):
# Try to get UTC-like tzinfos
reference = datetime.datetime.now()
key = tz.tzname(reference)
wence- marked this conversation as resolved.
Show resolved Hide resolved
if not (isinstance(key, str) and key.lower() == "utc"):
raise NotImplementedError(f"cudf does not support {tz}")
else:
raise NotImplementedError(f"cudf does not support {tz}")
new_tz = zoneinfo.ZoneInfo(key)
return pd.DatetimeTZDtype(dtype.unit, new_tz)


@lru_cache(maxsize=20)
def get_tz_data(zone_name: str) -> tuple[DatetimeColumn, TimeDeltaColumn]:
"""
Expand Down Expand Up @@ -87,6 +116,8 @@ def _read_tzfile_as_columns(
)

if not transition_times_and_offsets:
from cudf.core.column.column import as_column

# this happens for UTC-like zones
min_date = np.int64(np.iinfo("int64").min + 1).astype("M8[s]")
return (as_column([min_date]), as_column([np.timedelta64(0, "s")]))
Expand Down
16 changes: 16 additions & 0 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
is_string_dtype,
)
from cudf.core._compat import PANDAS_GE_210
from cudf.core._internals.timezones import get_compatible_timezone
from cudf.core.abc import Serializable
from cudf.core.buffer import (
Buffer,
Expand Down Expand Up @@ -1854,6 +1855,21 @@ def as_column(
arbitrary.dtype,
(pd.CategoricalDtype, pd.IntervalDtype, pd.DatetimeTZDtype),
):
if isinstance(arbitrary.dtype, pd.DatetimeTZDtype):
new_tz = get_compatible_timezone(arbitrary.dtype)
arbitrary = arbitrary.astype(new_tz)
if isinstance(arbitrary.dtype, pd.CategoricalDtype) and isinstance(
arbitrary.dtype.categories.dtype, pd.DatetimeTZDtype
):
new_tz = get_compatible_timezone(
arbitrary.dtype.categories.dtype
)
new_cats = arbitrary.dtype.categories.astype(new_tz)
new_dtype = pd.CategoricalDtype(
categories=new_cats, ordered=arbitrary.dtype.ordered
)
arbitrary = arbitrary.astype(new_dtype)

return as_column(
pa.array(arbitrary, from_pandas=True),
nan_as_null=nan_as_null,
Expand Down
33 changes: 14 additions & 19 deletions python/cudf/cudf/core/column/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
from cudf._lib.search import search_sorted
from cudf.api.types import is_datetime64_dtype, is_scalar, is_timedelta64_dtype
from cudf.core._compat import PANDAS_GE_220
from cudf.core._internals.timezones import (
check_ambiguous_and_nonexistent,
get_compatible_timezone,
get_tz_data,
)
from cudf.core.column import ColumnBase, as_column, column, string
from cudf.core.column.timedelta import _unit_to_nanoseconds_conversion
from cudf.utils.dtypes import _get_base_dtype
Expand Down Expand Up @@ -282,8 +287,6 @@ def __contains__(self, item: ScalarLike) -> bool:

@functools.cached_property
def time_unit(self) -> str:
if isinstance(self.dtype, pd.DatetimeTZDtype):
return self.dtype.unit
return np.datetime_data(self.dtype)[0]

@property
Expand Down Expand Up @@ -715,8 +718,6 @@ def _find_ambiguous_and_nonexistent(
transitions occur in the time zone database for the given timezone.
If no transitions occur, the tuple `(False, False)` is returned.
"""
from cudf.core._internals.timezones import get_tz_data

transition_times, offsets = get_tz_data(zone_name)
offsets = offsets.astype(f"timedelta64[{self.time_unit}]") # type: ignore[assignment]

Expand Down Expand Up @@ -775,26 +776,22 @@ def tz_localize(
ambiguous: Literal["NaT"] = "NaT",
nonexistent: Literal["NaT"] = "NaT",
):
from cudf.core._internals.timezones import (
check_ambiguous_and_nonexistent,
get_tz_data,
)

if tz is None:
return self.copy()
ambiguous, nonexistent = check_ambiguous_and_nonexistent(
ambiguous, nonexistent
)
dtype = pd.DatetimeTZDtype(self.time_unit, tz)
dtype = get_compatible_timezone(pd.DatetimeTZDtype(self.time_unit, tz))
tzname = dtype.tz.key
ambiguous_col, nonexistent_col = self._find_ambiguous_and_nonexistent(
tz
tzname
)
localized = self._scatter_by_column(
self.isnull() | (ambiguous_col | nonexistent_col),
cudf.Scalar(cudf.NaT, dtype=self.dtype),
)

transition_times, offsets = get_tz_data(tz)
transition_times, offsets = get_tz_data(tzname)
transition_times_local = (transition_times + offsets).astype(
localized.dtype
)
Expand Down Expand Up @@ -835,7 +832,7 @@ def __init__(
offset=offset,
null_count=null_count,
)
self._dtype = dtype
self._dtype = get_compatible_timezone(dtype)

def to_pandas(
self,
Expand All @@ -855,6 +852,10 @@ def to_arrow(self):
self._local_time.to_arrow(), str(self.dtype.tz)
)

@functools.cached_property
def time_unit(self) -> str:
return self.dtype.unit

@property
def _utc_time(self):
"""Return UTC time as naive timestamps."""
Expand All @@ -870,8 +871,6 @@ def _utc_time(self):
@property
def _local_time(self):
"""Return the local time as naive timestamps."""
from cudf.core._internals.timezones import get_tz_data

transition_times, offsets = get_tz_data(str(self.dtype.tz))
transition_times = transition_times.astype(_get_base_dtype(self.dtype))
indices = search_sorted([transition_times], [self], "right") - 1
Expand Down Expand Up @@ -901,10 +900,6 @@ def __repr__(self):
)

def tz_localize(self, tz: str | None, ambiguous="NaT", nonexistent="NaT"):
from cudf.core._internals.timezones import (
check_ambiguous_and_nonexistent,
)

if tz is None:
return self._local_time
ambiguous, nonexistent = check_ambiguous_and_nonexistent(
Expand Down
12 changes: 5 additions & 7 deletions python/cudf/cudf/tests/indexes/datetime/test_indexing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
import zoneinfo

import pandas as pd

Expand All @@ -7,13 +8,10 @@


def test_slice_datetimetz_index():
tz = zoneinfo.ZoneInfo("US/Eastern")
data = ["2001-01-01", "2001-01-02", None, None, "2001-01-03"]
pidx = pd.DatetimeIndex(data, dtype="datetime64[ns]").tz_localize(
"US/Eastern"
)
idx = cudf.DatetimeIndex(data, dtype="datetime64[ns]").tz_localize(
"US/Eastern"
)
pidx = pd.DatetimeIndex(data, dtype="datetime64[ns]").tz_localize(tz)
idx = cudf.DatetimeIndex(data, dtype="datetime64[ns]").tz_localize(tz)
expected = pidx[1:4]
got = idx[1:4]
assert_eq(expected, got)
13 changes: 6 additions & 7 deletions python/cudf/cudf/tests/indexes/datetime/test_time_specific.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION.
import zoneinfo

import pandas as pd

import cudf
from cudf.testing._utils import assert_eq


def test_tz_localize():
tz = zoneinfo.ZoneInfo("America/New_York")
pidx = pd.date_range("2001-01-01", "2001-01-02", freq="1s")
pidx = pidx.astype("<M8[ns]")
idx = cudf.from_pandas(pidx)
assert pidx.dtype == idx.dtype
assert_eq(
pidx.tz_localize("America/New_York"),
idx.tz_localize("America/New_York"),
)
assert_eq(pidx.tz_localize(tz), idx.tz_localize(tz))
wence- marked this conversation as resolved.
Show resolved Hide resolved


def test_tz_convert():
tz = zoneinfo.ZoneInfo("America/New_York")
pidx = pd.date_range("2023-01-01", periods=3, freq="h")
idx = cudf.from_pandas(pidx)
pidx = pidx.tz_localize("UTC")
idx = idx.tz_localize("UTC")
assert_eq(
pidx.tz_convert("America/New_York"), idx.tz_convert("America/New_York")
)
assert_eq(pidx.tz_convert(tz), idx.tz_convert(tz))


def test_delocalize_naive():
Expand Down
40 changes: 35 additions & 5 deletions python/cudf/cudf/tests/series/test_datetimelike.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) 2023-2024, NVIDIA CORPORATION.

import datetime
import os
import zoneinfo

import pandas as pd
import pytest
Expand Down Expand Up @@ -70,7 +72,7 @@ def test_localize_ambiguous(request, unit, zone_name):
dtype=f"datetime64[{unit}]",
)
expect = s.to_pandas().dt.tz_localize(
zone_name, ambiguous="NaT", nonexistent="NaT"
zoneinfo.ZoneInfo(zone_name), ambiguous="NaT", nonexistent="NaT"
)
got = s.dt.tz_localize(zone_name)
assert_eq(expect, got)
Expand All @@ -96,7 +98,7 @@ def test_localize_nonexistent(request, unit, zone_name):
dtype=f"datetime64[{unit}]",
)
expect = s.to_pandas().dt.tz_localize(
zone_name, ambiguous="NaT", nonexistent="NaT"
zoneinfo.ZoneInfo(zone_name), ambiguous="NaT", nonexistent="NaT"
)
got = s.dt.tz_localize(zone_name)
assert_eq(expect, got)
Expand Down Expand Up @@ -130,6 +132,9 @@ def test_delocalize_naive():
"to_tz", ["Europe/London", "America/Chicago", "UTC", None]
)
def test_convert(from_tz, to_tz):
from_tz = zoneinfo.ZoneInfo(from_tz)
if to_tz is not None:
to_tz = zoneinfo.ZoneInfo(to_tz)
ps = pd.Series(pd.date_range("2023-01-01", periods=3, freq="h"))
gs = cudf.from_pandas(ps)
ps = ps.dt.tz_localize(from_tz)
Expand Down Expand Up @@ -169,6 +174,8 @@ def test_convert_from_naive():
],
)
def test_convert_edge_cases(data, original_timezone, target_timezone):
original_timezone = zoneinfo.ZoneInfo(original_timezone)
target_timezone = zoneinfo.ZoneInfo(target_timezone)
ps = pd.Series(data, dtype="datetime64[s]").dt.tz_localize(
original_timezone
)
Expand Down Expand Up @@ -229,10 +236,33 @@ def test_tz_convert_naive_typeerror():
"klass", ["Series", "DatetimeIndex", "Index", "CategoricalIndex"]
)
def test_from_pandas_obj_tz_aware(klass):
tz_aware_data = [
pd.Timestamp("2020-01-01", tz="UTC").tz_convert("US/Pacific")
]
tz = zoneinfo.ZoneInfo("US/Pacific")
tz_aware_data = [pd.Timestamp("2020-01-01", tz="UTC").tz_convert(tz)]
pandas_obj = getattr(pd, klass)(tz_aware_data)
result = cudf.from_pandas(pandas_obj)
expected = getattr(cudf, klass)(tz_aware_data)
assert_eq(result, expected)


@pytest.mark.parametrize(
"klass", ["Series", "DatetimeIndex", "Index", "CategoricalIndex"]
)
def test_from_pandas_obj_tz_aware_unsupported(klass):
tz = datetime.timezone(datetime.timedelta(hours=1))
tz_aware_data = [pd.Timestamp("2020-01-01", tz="UTC").tz_convert(tz)]
pandas_obj = getattr(pd, klass)(tz_aware_data)
with pytest.raises(NotImplementedError):
cudf.from_pandas(pandas_obj)


@pytest.mark.parametrize(
"klass", ["Series", "DatetimeIndex", "Index", "CategoricalIndex"]
)
def test_pandas_compatible_non_zoneinfo_raises(klass):
pytz = pytest.importorskip("pytz")
tz = pytz.timezone("US/Pacific")
tz_aware_data = [pd.Timestamp("2020-01-01", tz="UTC").tz_convert(tz)]
pandas_obj = getattr(pd, klass)(tz_aware_data)
with cudf.option_context("mode.pandas_compatible", True):
with pytest.raises(NotImplementedError):
cudf.from_pandas(pandas_obj)
Loading