Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Overhaul config using Pydantic #297

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,4 @@ cython_debug/
#.idea/

openapitools.json
.python-version
7 changes: 3 additions & 4 deletions hatchet_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .clients.dispatcher.dispatcher import DispatcherClient, new_dispatcher
from .clients.events import EventClient, new_event
from .clients.rest_client import RestApi
from .loader import ClientConfig, ConfigLoader
from .loader import ClientConfig


class Client:
Expand All @@ -37,11 +37,10 @@ def from_environment(
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

config: ClientConfig = ConfigLoader(".").load_client_config(defaults)
for opt_function in opts_functions:
opt_function(config)
opt_function(defaults)

return cls.from_config(config, debug)
return cls.from_config(defaults, debug)

@classmethod
def from_config(
Expand Down
5 changes: 2 additions & 3 deletions hatchet_sdk/hatchet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from hatchet_sdk.features.cron import CronClient
from hatchet_sdk.features.scheduled import ScheduledClient
from hatchet_sdk.labels import DesiredWorkerLabel
from hatchet_sdk.loader import ClientConfig, ConfigLoader
from hatchet_sdk.loader import ClientConfig
from hatchet_sdk.rate_limit import RateLimit
from hatchet_sdk.v2.callable import HatchetCallable

Expand Down Expand Up @@ -190,8 +190,7 @@ class HatchetRest:
rest: RestApi

def __init__(self, config: ClientConfig = ClientConfig()):
_config: ClientConfig = ConfigLoader(".").load_client_config(config)
self.rest = RestApi(_config.server_url, _config.token, _config.tenant_id)
self.rest = RestApi(config.server_url, config.token, config.tenant_id)


class Hatchet:
Expand Down
329 changes: 95 additions & 234 deletions hatchet_sdk/loader.py
Original file line number Diff line number Diff line change
@@ -1,246 +1,107 @@
import json
import os
from logging import Logger, getLogger
from typing import Dict, Optional

import yaml

from .token import get_addresses_from_jwt, get_tenant_id_from_jwt


class ClientTLSConfig:
    """Plain holder for the TLS settings handed to the client.

    All five values are stored exactly as given; nothing is validated or
    defaulted here.
    """

    def __init__(
        self,
        tls_strategy: str,
        cert_file: str,
        key_file: str,
        ca_file: str,
        server_name: str,
    ):
        self.tls_strategy = tls_strategy
        # The three file settings travel together, so assign them as a group
        # (left to right, preserving the attribute order).
        self.cert_file, self.key_file, self.ca_file = cert_file, key_file, ca_file
        self.server_name = server_name


class ClientConfig:
    """Container for the client's connection, logging, gRPC, OTel and
    worker health-check settings.

    Values are stored as passed, with two exceptions:

    * ``logger`` is stored as ``logInterceptor`` and replaced by the root
      logger when falsy.
    * ``namespace`` is lower-cased and given a trailing underscore when it is
      non-empty; an unset namespace becomes the empty string.
    """

    logInterceptor: Logger

    def __init__(
        self,
        tenant_id: str = None,
        tls_config: ClientTLSConfig = None,
        token: str = None,
        host_port: str = "localhost:7070",
        server_url: str = "https://app.dev.hatchet-tools.com",
        namespace: str = None,
        listener_v2_timeout: int = None,
        logger: Logger = None,
        grpc_max_recv_message_length: int = 4 * 1024 * 1024,  # 4MB
        grpc_max_send_message_length: int = 4 * 1024 * 1024,  # 4MB
        otel_exporter_oltp_endpoint: str | None = None,
        otel_service_name: str | None = None,
        otel_exporter_oltp_headers: dict[str, str] | None = None,
        otel_exporter_oltp_protocol: str | None = None,
        worker_healthcheck_port: int | None = None,
        worker_healthcheck_enabled: bool | None = None,
    ):
        # Work out the namespace prefix up front: nothing when unset,
        # otherwise the given value with a trailing underscore guaranteed.
        if not namespace:
            prefix = ""
        elif namespace.endswith("_"):
            prefix = namespace
        else:
            prefix = f"{namespace}_"

        # Connection identity
        self.tenant_id = tenant_id
        self.tls_config = tls_config
        self.host_port = host_port
        self.token = token
        self.server_url = server_url
        self.namespace = prefix.lower()

        # Logging: fall back to the root logger when none was supplied
        self.logInterceptor = logger if logger else getLogger()

        # gRPC message size limits
        self.grpc_max_recv_message_length = grpc_max_recv_message_length
        self.grpc_max_send_message_length = grpc_max_send_message_length

        # OpenTelemetry exporter settings
        self.otel_exporter_oltp_endpoint = otel_exporter_oltp_endpoint
        self.otel_service_name = otel_service_name
        self.otel_exporter_oltp_headers = otel_exporter_oltp_headers
        self.otel_exporter_oltp_protocol = otel_exporter_oltp_protocol

        # Worker health check
        self.worker_healthcheck_port = worker_healthcheck_port
        self.worker_healthcheck_enabled = worker_healthcheck_enabled

        self.listener_v2_timeout = listener_v2_timeout


class ConfigLoader:
def __init__(self, directory: str):
self.directory = directory

def load_client_config(self, defaults: ClientConfig) -> ClientConfig:
config_file_path = os.path.join(self.directory, "client.yaml")
config_data: object = {"tls": {}}

# determine if client.yaml exists
if os.path.exists(config_file_path):
with open(config_file_path, "r") as file:
config_data = yaml.safe_load(file)

def get_config_value(key, env_var):
if key in config_data:
return config_data[key]

if self._get_env_var(env_var) is not None:
return self._get_env_var(env_var)

return getattr(defaults, key, None)

namespace = get_config_value("namespace", "HATCHET_CLIENT_NAMESPACE")

tenant_id = get_config_value("tenantId", "HATCHET_CLIENT_TENANT_ID")
token = get_config_value("token", "HATCHET_CLIENT_TOKEN")
listener_v2_timeout = get_config_value(
"listener_v2_timeout", "HATCHET_CLIENT_LISTENER_V2_TIMEOUT"
)
listener_v2_timeout = int(listener_v2_timeout) if listener_v2_timeout else None

from typing import cast

from pydantic import BaseModel, ValidationError, ValidationInfo, field_validator

from .token import get_tenant_id_from_jwt


class ClientTLSConfig(BaseModel):
    """Pydantic model grouping the client's TLS settings.

    None of the fields declares a default, so every one of them must be
    passed explicitly when constructing the model (``None`` is accepted for
    the three ``*_file`` fields).
    """

    # TLS mode selector; _load_tls_config reads it from
    # HATCHET_CLIENT_TLS_STRATEGY and falls back to "tls".
    tls_strategy: str
    # Read from HATCHET_CLIENT_TLS_CERT_FILE; None when that variable is unset.
    # Presumably a filesystem path - confirm against the code that consumes it.
    cert_file: str | None
    # Read from HATCHET_CLIENT_TLS_KEY_FILE; None when that variable is unset.
    key_file: str | None
    # Read from HATCHET_CLIENT_TLS_ROOT_CA_FILE; None when that variable is unset.
    ca_file: str | None
    # Read from HATCHET_CLIENT_TLS_SERVER_NAME; _load_tls_config falls back to
    # the host part of the configured host:port.
    server_name: str


def _load_tls_config(host_port: str) -> ClientTLSConfig:
    """Assemble the client TLS settings from ``HATCHET_CLIENT_TLS_*`` variables.

    Args:
        host_port: Address in ``host:port`` form. The text before the first
            ``:`` is used as the server name when
            ``HATCHET_CLIENT_TLS_SERVER_NAME`` is unset or empty.

    Returns:
        A ``ClientTLSConfig`` populated from the environment. The strategy
        falls back to ``"tls"``; the three file settings are passed through
        as read (``None`` when unset).
    """
    # Use `or` rather than a getenv default: a variable that is exported but
    # empty (`VAR=`) must fall back too, instead of producing an empty
    # strategy or server name. This matches the loader this function replaced.
    tls_strategy = os.getenv("HATCHET_CLIENT_TLS_STRATEGY") or "tls"
    server_name = (
        os.getenv("HATCHET_CLIENT_TLS_SERVER_NAME") or host_port.split(":")[0]
    )

    return ClientTLSConfig(
        tls_strategy=tls_strategy,
        cert_file=os.getenv("HATCHET_CLIENT_TLS_CERT_FILE"),
        key_file=os.getenv("HATCHET_CLIENT_TLS_KEY_FILE"),
        ca_file=os.getenv("HATCHET_CLIENT_TLS_ROOT_CA_FILE"),
        server_name=server_name,
    )


def parse_listener_timeout(timeout: str | None) -> int | None:
if timeout is None:
return None

return int(timeout)


class ClientConfig(BaseModel):
token: str = os.getenv("HATCHET_CLIENT_TOKEN", "")
logger: Logger = getLogger()
tenant_id: str = os.getenv("HATCHET_CLIENT_TENANT_ID", "")
host_port: str = os.getenv("HATCHET_CLIENT_HOST_PORT", "localhost:7070")
tls_config: ClientTLSConfig = _load_tls_config(host_port)
server_url: str = "https://app.dev.hatchet-tools.com"
namespace: str = os.getenv("HATCHET_CLIENT_NAMESPACE", "")
listener_v2_timeout: int | None = parse_listener_timeout(
os.getenv("HATCHET_CLIENT_LISTENER_V2_TIMEOUT")
)
grpc_max_recv_message_length: int = int(
os.getenv("HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH", 4 * 1024 * 1024)
) # 4MB
grpc_max_send_message_length: int = int(
os.getenv("HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH", 4 * 1024 * 1024)
) # 4MB
otel_exporter_oltp_endpoint: str | None = os.getenv(
"HATCHET_CLIENT_OTEL_EXPORTER_OTLP_ENDPOINT"
)
otel_service_name: str | None = os.getenv("HATCHET_CLIENT_OTEL_SERVICE_NAME")
otel_exporter_oltp_headers: str | None = os.getenv(
"HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS"
)
otel_exporter_oltp_protocol: str | None = os.getenv(
"HATCHET_CLIENT_OTEL_EXPORTER_OTLP_PROTOCOL"
)
worker_healthcheck_port: int = int(
os.getenv("HATCHET_CLIENT_WORKER_HEALTHCHECK_PORT", 8001)
)
worker_healthcheck_enabled: bool = (
os.getenv("HATCHET_CLIENT_WORKER_HEALTHCHECK_ENABLED", "False") == "True"
)

@field_validator("token", mode="after")
@classmethod
def validate_token(cls, token: str) -> str:
if not token:
raise ValueError(
"Token must be set via HATCHET_CLIENT_TOKEN environment variable"
)

host_port = get_config_value("hostPort", "HATCHET_CLIENT_HOST_PORT")
server_url: str | None = None
raise ValidationError("Token must be set")

grpc_max_recv_message_length = get_config_value(
"grpc_max_recv_message_length",
"HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH",
)
grpc_max_send_message_length = get_config_value(
"grpc_max_send_message_length",
"HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH",
)
return token

if grpc_max_recv_message_length:
grpc_max_recv_message_length = int(grpc_max_recv_message_length)
@field_validator("namespace", mode="after")
@classmethod
def validate_namespace(cls, namespace: str) -> str:
if not namespace.endswith("_"):
namespace = f"{namespace}_"

if grpc_max_send_message_length:
grpc_max_send_message_length = int(grpc_max_send_message_length)
return namespace.lower()

if not host_port:
# extract host and port from token
server_url, grpc_broadcast_address = get_addresses_from_jwt(token)
host_port = grpc_broadcast_address
@field_validator("tenant_id", mode="after")
@classmethod
def validate_tenant_id(cls, tenant_id: str, info: ValidationInfo) -> str:
token = cast(str | None, info.data.get("token"))

if not tenant_id:
tenant_id = get_tenant_id_from_jwt(token)

tls_config = self._load_tls_config(config_data["tls"], host_port)

otel_exporter_oltp_endpoint = get_config_value(
"otel_exporter_oltp_endpoint", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_ENDPOINT"
)
if not token:
raise ValidationError(
"Token must be set before attempting to infer tenant ID"
)

otel_service_name = get_config_value(
"otel_service_name", "HATCHET_CLIENT_OTEL_SERVICE_NAME"
)
return get_tenant_id_from_jwt(token)

_oltp_headers = get_config_value(
"otel_exporter_oltp_headers", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS"
)
return tenant_id

if _oltp_headers:
try:
otel_header_key, api_key = _oltp_headers.split("=", maxsplit=1)
otel_exporter_oltp_headers = {otel_header_key: api_key}
except ValueError:
raise ValueError(
"HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS must be in the format `key=value`"
)
else:
otel_exporter_oltp_headers = None

otel_exporter_oltp_protocol = get_config_value(
"otel_exporter_oltp_protocol", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_PROTOCOL"
)

worker_healthcheck_port = int(
get_config_value(
"worker_healthcheck_port", "HATCHET_CLIENT_WORKER_HEALTHCHECK_PORT"
)
or 8001
)

worker_healthcheck_enabled = (
str(
get_config_value(
"worker_healthcheck_port",
"HATCHET_CLIENT_WORKER_HEALTHCHECK_ENABLED",
)
)
== "True"
)

return ClientConfig(
tenant_id=tenant_id,
tls_config=tls_config,
token=token,
host_port=host_port,
server_url=server_url,
namespace=namespace,
listener_v2_timeout=listener_v2_timeout,
logger=defaults.logInterceptor,
grpc_max_recv_message_length=grpc_max_recv_message_length,
grpc_max_send_message_length=grpc_max_send_message_length,
otel_exporter_oltp_endpoint=otel_exporter_oltp_endpoint,
otel_service_name=otel_service_name,
otel_exporter_oltp_headers=otel_exporter_oltp_headers,
otel_exporter_oltp_protocol=otel_exporter_oltp_protocol,
worker_healthcheck_port=worker_healthcheck_port,
worker_healthcheck_enabled=worker_healthcheck_enabled,
)

def _load_tls_config(self, tls_data: Dict, host_port) -> ClientTLSConfig:
tls_strategy = (
tls_data["tlsStrategy"]
if "tlsStrategy" in tls_data
else self._get_env_var("HATCHET_CLIENT_TLS_STRATEGY")
)

if not tls_strategy:
tls_strategy = "tls"

cert_file = (
tls_data["tlsCertFile"]
if "tlsCertFile" in tls_data
else self._get_env_var("HATCHET_CLIENT_TLS_CERT_FILE")
)
key_file = (
tls_data["tlsKeyFile"]
if "tlsKeyFile" in tls_data
else self._get_env_var("HATCHET_CLIENT_TLS_KEY_FILE")
)
ca_file = (
tls_data["tlsRootCAFile"]
if "tlsRootCAFile" in tls_data
else self._get_env_var("HATCHET_CLIENT_TLS_ROOT_CA_FILE")
)

server_name = (
tls_data["tlsServerName"]
if "tlsServerName" in tls_data
else self._get_env_var("HATCHET_CLIENT_TLS_SERVER_NAME")
)

# if server_name is not set, use the host from the host_port
if not server_name:
server_name = host_port.split(":")[0]

return ClientTLSConfig(tls_strategy, cert_file, key_file, ca_file, server_name)

@staticmethod
def _get_env_var(env_var: str, default: Optional[str] = None) -> str:
return os.environ.get(env_var, default)
## TODO: Fix host port overrides here
## Old code:
## if not host_port:
## ## extract host and port from token
## server_url, grpc_broadcast_address = get_addresses_from_jwt(token)
## host_port = grpc_broadcast_address
Loading
Loading