Skip to content

Commit

Permalink
Add support for ZoneInfo and generic UTC (#34683)
Browse files Browse the repository at this point in the history
* Add support for ZoneInfo and generic UTC

Certain providers rely on other datetime implementations
and fail to serialize.
  • Loading branch information
bolkedebruin authored Oct 6, 2023
1 parent 1fc2867 commit 7707f4a
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 9 deletions.
58 changes: 50 additions & 8 deletions airflow/serialization/serializers/timezone.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,22 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING
import datetime
from typing import TYPE_CHECKING, Any, cast

from airflow.utils.module_loading import qualname

if TYPE_CHECKING:
from pendulum.tz.timezone import Timezone

from airflow.serialization.serde import U


serializers = ["pendulum.tz.timezone.FixedTimezone", "pendulum.tz.timezone.Timezone"]
serializers = [
"pendulum.tz.timezone.FixedTimezone",
"pendulum.tz.timezone.Timezone",
"zoneinfo.ZoneInfo",
"backports.zoneinfo.ZoneInfo",
]

deserializers = serializers

__version__ = 1
Expand All @@ -43,21 +48,26 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
0 without the special case), but passing 0 into ``pendulum.timezone`` does
not give us UTC (but ``+00:00``).
"""
from pendulum.tz.timezone import FixedTimezone, Timezone
from pendulum.tz.timezone import FixedTimezone

name = qualname(o)

if isinstance(o, FixedTimezone):
if o.offset == 0:
return "UTC", name, __version__, True
return o.offset, name, __version__, True

if isinstance(o, Timezone):
return o.name, name, __version__, True
tz_name = _get_tzinfo_name(cast(datetime.tzinfo, o))
if tz_name is not None:
return tz_name, name, __version__, True

if cast(datetime.tzinfo, o).utcoffset(None) == datetime.timedelta(0):
return "UTC", qualname(FixedTimezone), __version__, True

return "", "", 0, False


def deserialize(classname: str, version: int, data: object) -> Timezone:
def deserialize(classname: str, version: int, data: object) -> Any:
from pendulum.tz import fixed_timezone, timezone

if not isinstance(data, (str, int)):
Expand All @@ -69,4 +79,36 @@ def deserialize(classname: str, version: int, data: object) -> Timezone:
if isinstance(data, int):
return fixed_timezone(data)

if classname == "zoneinfo.ZoneInfo":
from zoneinfo import ZoneInfo

return ZoneInfo(data)

if classname == "backports.zoneinfo.ZoneInfo":
# python version might have been upgraded, so we need to check
try:
from backports.zoneinfo import ZoneInfo
except ImportError:
from zoneinfo import ZoneInfo

return ZoneInfo(data)

return timezone(data)


# ported from pendulum.tz.timezone._get_tzinfo_name
def _get_tzinfo_name(tzinfo: datetime.tzinfo | None) -> str | None:
if tzinfo is None:
return None

if hasattr(tzinfo, "key"):
# zoneinfo timezone
return tzinfo.key
elif hasattr(tzinfo, "name"):
# Pendulum timezone
return tzinfo.name
elif hasattr(tzinfo, "zone"):
# pytz timezone
return tzinfo.zone # type: ignore[no-any-return]

return None
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ def write_version(filename: str = str(AIRFLOW_SOURCES_ROOT / "airflow" / "git_ve

_devel_only_tests = [
"aioresponses",
"backports.zoneinfo>=0.2.1;python_version<'3.9'",
"beautifulsoup4>=4.7.1",
"coverage>=7.2",
"pytest",
Expand Down
18 changes: 17 additions & 1 deletion tests/serialization/serializers/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,18 @@
import numpy as np
import pendulum.tz
import pytest
from dateutil.tz import tzutc
from pendulum import DateTime

from airflow import PY39
from airflow.models.param import Param, ParamsDict
from airflow.serialization.serde import DATA, deserialize, serialize

if PY39:
from zoneinfo import ZoneInfo
else:
from backports.zoneinfo import ZoneInfo


class TestSerializers:
def test_datetime(self):
Expand Down Expand Up @@ -62,8 +69,17 @@ def test_datetime(self):
d = deserialize(s)
assert i.timestamp() == d.timestamp()

def test_deserialize_datetime_v1(self):
i = DateTime(2022, 7, 10, tzinfo=tzutc())
s = serialize(i)
d = deserialize(s)
assert i.timestamp() == d.timestamp()

i = DateTime(2022, 7, 10, tzinfo=ZoneInfo("Europe/Paris"))
s = serialize(i)
d = deserialize(s)
assert i.timestamp() == d.timestamp()

def test_deserialize_datetime_v1(self):
s = {
"__classname__": "pendulum.datetime.DateTime",
"__version__": 1,
Expand Down

0 comments on commit 7707f4a

Please sign in to comment.