Skip to content

Commit

Permalink
Rework use of timezone naive and aware operations to be consistent
Browse files Browse the repository at this point in the history
  • Loading branch information
stefannica committed Jan 23, 2025
1 parent 78ced86 commit c41c43e
Show file tree
Hide file tree
Showing 60 changed files with 247 additions and 223 deletions.
8 changes: 2 additions & 6 deletions src/zenml/analytics/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
The base functionalities are adapted to work with the ZenML analytics server.
"""

import datetime
import locale
from types import TracebackType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
Expand All @@ -32,6 +31,7 @@
)
from zenml.environment import Environment, get_environment
from zenml.logger import get_logger
from zenml.utils.time_utils import utc_now_tz_aware

if TYPE_CHECKING:
from zenml.analytics.enums import AnalyticsEvent
Expand Down Expand Up @@ -284,11 +284,7 @@ def track(

try:
# Timezone as tzdata
tz = (
datetime.datetime.now(datetime.timezone.utc)
.astimezone()
.tzname()
)
tz = utc_now_tz_aware().astimezone().tzname()
if tz is not None:
properties.update({"timezone": tz})

Expand Down
10 changes: 5 additions & 5 deletions src/zenml/cli/service_connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# permissions and limitations under the License.
"""Service connector CLI commands."""

from datetime import datetime, timezone
from datetime import datetime
from typing import Any, Dict, List, Optional, Union, cast
from uuid import UUID

Expand All @@ -25,7 +25,6 @@
is_sorted_or_filtered,
list_options,
print_page_info,
seconds_to_human_readable,
)
from zenml.client import Client
from zenml.console import console
Expand All @@ -37,6 +36,7 @@
ServiceConnectorResourcesModel,
ServiceConnectorResponse,
)
from zenml.utils.time_utils import seconds_to_human_readable, utc_now


# Service connectors
Expand Down Expand Up @@ -292,7 +292,7 @@ def prompt_expires_at(
default_str = ""
if default is not None:
seconds = int(
(default - datetime.now(timezone.utc)).total_seconds()
(default - utc_now(tz_aware=default)).total_seconds()
)
default_str = (
f" [{str(default)} i.e. in "
Expand All @@ -309,15 +309,15 @@ def prompt_expires_at(

assert expires_at is not None
assert isinstance(expires_at, datetime)
if expires_at < datetime.now(timezone.utc):
if expires_at < utc_now(tz_aware=expires_at):
cli_utils.warning(
"The expiration time must be in the future. Please enter a "
"later date and time."
)
continue

seconds = int(
(expires_at - datetime.now(timezone.utc)).total_seconds()
(expires_at - utc_now(tz_aware=expires_at)).total_seconds()
)

confirm = click.confirm(
Expand Down
4 changes: 2 additions & 2 deletions src/zenml/cli/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import re
import time
import webbrowser
from datetime import datetime, timezone
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -77,6 +76,7 @@
)
from zenml.utils import requirements_utils
from zenml.utils.dashboard_utils import get_component_url, get_stack_url
from zenml.utils.time_utils import utc_now_tz_aware
from zenml.utils.yaml_utils import read_yaml, write_yaml

if TYPE_CHECKING:
Expand Down Expand Up @@ -1575,7 +1575,7 @@ def deploy(
):
raise click.Abort()

date_start = datetime.now(timezone.utc)
date_start = utc_now_tz_aware()

webbrowser.open(deployment_config.deployment_url)
console.print(
Expand Down
2 changes: 1 addition & 1 deletion src/zenml/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""Utility functions for the CLI."""

import contextlib
import datetime
import json
import os
import platform
Expand Down Expand Up @@ -1581,6 +1580,7 @@ def print_components_table(
configurations.append(component_config)
print_table(configurations)


def print_service_connectors_table(
client: "Client",
connectors: Sequence["ServiceConnectorResponse"],
Expand Down
5 changes: 3 additions & 2 deletions src/zenml/config/pipeline_configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# permissions and limitations under the License.
"""Pipeline configuration classes."""

from datetime import datetime, timezone
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from pydantic import SerializeAsAny, field_validator
Expand All @@ -23,6 +23,7 @@
from zenml.config.source import SourceWithValidator
from zenml.config.strict_base_model import StrictBaseModel
from zenml.model.model import Model
from zenml.utils.time_utils import utc_now

if TYPE_CHECKING:
from zenml.config import DockerSettings
Expand Down Expand Up @@ -61,7 +62,7 @@ def _get_full_substitutions(
The full substitutions dict including date and time.
"""
if start_time is None:
start_time = datetime.now(timezone.utc)
start_time = utc_now()
ret = self.substitutions.copy()
ret.setdefault("date", start_time.strftime("%Y_%m_%d"))
ret.setdefault("time", start_time.strftime("%H_%M_%S_%f"))
Expand Down
7 changes: 3 additions & 4 deletions src/zenml/event_hub/base_event_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""Base class for event hub implementations."""

from abc import ABC, abstractmethod
from datetime import datetime, timedelta, timezone
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple

from zenml import EventSourceResponse
Expand All @@ -28,6 +28,7 @@
TriggerExecutionResponse,
TriggerResponse,
)
from zenml.utils.time_utils import utc_now
from zenml.zen_server.auth import AuthContext
from zenml.zen_server.jwt import JWTToken

Expand Down Expand Up @@ -134,9 +135,7 @@ def trigger_action(
)
expires: Optional[datetime] = None
if trigger.action.auth_window:
expires = datetime.now(timezone.utc) + timedelta(
minutes=trigger.action.auth_window
)
expires = utc_now() + timedelta(minutes=trigger.action.auth_window)
encoded_token = token.encode(expires=expires)
auth_context = AuthContext(
user=trigger.action.service_account,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from zenml.orchestrators.utils import get_orchestrator_run_name
from zenml.stack import StackValidator
from zenml.utils import io_utils
from zenml.utils.time_utils import utc_now

if TYPE_CHECKING:
from zenml.config import ResourceSettings
Expand Down Expand Up @@ -408,8 +409,7 @@ def _translate_schedule(
if schedule:
if schedule.cron_expression:
start_time = schedule.start_time or (
datetime.datetime.now(datetime.timezone.utc)
- datetime.timedelta(7)
utc_now() - datetime.timedelta(7)
)
return {
"schedule": schedule.cron_expression,
Expand All @@ -429,7 +429,6 @@ def _translate_schedule(
"schedule": "@once",
# set a start time in the past and disable catchup so airflow
# runs the dag immediately
"start_date": datetime.datetime.now(datetime.timezone.utc)
- datetime.timedelta(7),
"start_date": utc_now() - datetime.timedelta(7),
"catchup": False,
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import os
import re
from datetime import datetime, timezone
from datetime import timezone
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -64,6 +64,7 @@
from zenml.orchestrators.utils import get_orchestrator_run_name
from zenml.stack import StackValidator
from zenml.utils.env_utils import split_environment_variables
from zenml.utils.time_utils import utc_now

if TYPE_CHECKING:
from zenml.models import PipelineDeploymentResponse, PipelineRunResponse
Expand Down Expand Up @@ -553,8 +554,7 @@ def prepare_or_run_pipeline(
enabled=True,
)
next_execution = (
deployment.schedule.start_time
or datetime.now(timezone.utc)
deployment.schedule.start_time or utc_now()
) + deployment.schedule.interval_second
else:
# One-time schedule
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
)
from zenml.utils.enum_utils import StrEnum
from zenml.utils.secret_utils import PlainSerializedSecretStr
from zenml.utils.time_utils import utc_now_tz_aware

logger = get_logger(__name__)

Expand Down Expand Up @@ -711,7 +712,7 @@ def get_boto3_session(
return session, None

# Refresh expired sessions
now = datetime.datetime.now(datetime.timezone.utc)
now = utc_now_tz_aware()
expires_at = expires_at.replace(tzinfo=datetime.timezone.utc)
# check if the token expires in the near future
if expires_at > now + datetime.timedelta(
Expand Down Expand Up @@ -959,9 +960,7 @@ def _authenticate(
# determine the expiration time of the temporary credentials
# from the boto3 session, so we assume the default IAM role
# expiration date is used
expiration_time = datetime.datetime.now(
tz=datetime.timezone.utc
) + datetime.timedelta(
expiration_time = utc_now_tz_aware() + datetime.timedelta(
seconds=DEFAULT_IAM_ROLE_TOKEN_EXPIRATION
)
return session, expiration_time
Expand Down Expand Up @@ -1673,9 +1672,7 @@ def _auto_configure(
# expiration time of the temporary credentials from the
# boto3 session, so we assume the default IAM role
# expiration period is used
expires_at = datetime.datetime.now(
tz=datetime.timezone.utc
) + datetime.timedelta(
expires_at = utc_now_tz_aware() + datetime.timedelta(
seconds=DEFAULT_IAM_ROLE_TOKEN_EXPIRATION
)

Expand Down Expand Up @@ -1720,9 +1717,7 @@ def _auto_configure(
aws_secret_access_key=credentials["SecretAccessKey"],
aws_session_token=credentials["SessionToken"],
)
expires_at = datetime.datetime.now(
tz=datetime.timezone.utc
) + datetime.timedelta(
expires_at = utc_now_tz_aware() + datetime.timedelta(
seconds=DEFAULT_STS_TOKEN_EXPIRATION
)

Expand Down Expand Up @@ -2130,9 +2125,9 @@ def _get_connector_client(
# Kubernetes authentication tokens issued by AWS EKS have a fixed
# expiration time of 15 minutes
# source: https://aws.github.io/aws-eks-best-practices/security/docs/iam/#controlling-access-to-eks-clusters
expires_at = datetime.datetime.now(
tz=datetime.timezone.utc
) + datetime.timedelta(minutes=EKS_KUBE_API_TOKEN_EXPIRATION)
expires_at = utc_now_tz_aware() + datetime.timedelta(
minutes=EKS_KUBE_API_TOKEN_EXPIRATION
)

# get cluster details
cluster_arn = cluster["cluster"]["arn"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
)
from zenml.utils.enum_utils import StrEnum
from zenml.utils.secret_utils import PlainSerializedSecretStr
from zenml.utils.time_utils import to_local_tz, utc_now

# Configure the logging level for azure.identity
logging.getLogger("azure.identity").setLevel(logging.WARNING)
Expand Down Expand Up @@ -171,12 +172,7 @@ def __init__(self, token: str, expires_at: datetime.datetime):
self.token = token

# Convert the expiration time from UTC to local time
expires_at.replace(tzinfo=datetime.timezone.utc)
expires_at = expires_at.astimezone(
datetime.datetime.now().astimezone().tzinfo
)

self.expires_on = int(expires_at.timestamp())
self.expires_on = int(to_local_tz(expires_at).timestamp())

def get_token(self, *scopes: str, **kwargs: Any) -> Any:
"""Get token.
Expand Down Expand Up @@ -604,11 +600,9 @@ def get_azure_credential(
return session, None

# Refresh expired sessions
now = datetime.datetime.now(datetime.timezone.utc)
expires_at = expires_at.replace(tzinfo=datetime.timezone.utc)

# check if the token expires in the near future
if expires_at > now + datetime.timedelta(
if expires_at > utc_now(tz_aware=expires_at) + datetime.timedelta(
minutes=AZURE_SESSION_EXPIRATION_BUFFER
):
return session, expires_at
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
from zenml.utils.enum_utils import StrEnum
from zenml.utils.pydantic_utils import before_validator_handler
from zenml.utils.secret_utils import PlainSerializedSecretStr
from zenml.utils.time_utils import utc_now

logger = get_logger(__name__)

Expand Down Expand Up @@ -1124,10 +1125,9 @@ def get_session(
return session, None

# Refresh expired sessions
now = datetime.datetime.now(datetime.timezone.utc)
expires_at = expires_at.replace(tzinfo=datetime.timezone.utc)

# check if the token expires in the near future
if expires_at > now + datetime.timedelta(
if expires_at > utc_now(tz_aware=expires_at) + datetime.timedelta(
minutes=GCP_SESSION_EXPIRATION_BUFFER
):
return session, expires_at
Expand Down
6 changes: 3 additions & 3 deletions src/zenml/integrations/kubernetes/orchestrators/kube_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
Adjusted from https://github.com/tensorflow/tfx/blob/master/tfx/utils/kube_utils.py.
"""

import datetime
import enum
import re
import time
Expand All @@ -47,6 +46,7 @@
build_service_account_manifest,
)
from zenml.logger import get_logger
from zenml.utils.time_utils import utc_now

logger = get_logger(__name__)

Expand Down Expand Up @@ -248,7 +248,7 @@ def wait_pod(
Returns:
The pod object which meets the exit condition.
"""
start_time = datetime.datetime.now(datetime.timezone.utc)
start_time = utc_now()

# Link to exponential back-off algorithm used here:
# https://cloud.google.com/storage/docs/exponential-backoff
Expand Down Expand Up @@ -288,7 +288,7 @@ def wait_pod(
return resp

# Check if wait timed out.
elapse_time = datetime.datetime.now(datetime.timezone.utc) - start_time
elapse_time = utc_now() - start_time
if elapse_time.seconds >= timeout_sec and timeout_sec != 0:
raise RuntimeError(
f"Waiting for pod `{namespace}:{pod_name}` timed out after "
Expand Down
Loading

0 comments on commit c41c43e

Please sign in to comment.