From 88c1aad2faa77bb2b3fa3b36d2686bc6ed72b116 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Mon, 13 Jan 2025 12:56:22 -0500 Subject: [PATCH 1/4] feat: overhaul config to use pydantic --- hatchet_sdk/client.py | 9 +- hatchet_sdk/hatchet.py | 6 +- hatchet_sdk/loader.py | 329 ++++++++++------------------------- hatchet_sdk/utils/tracing.py | 16 +- 4 files changed, 115 insertions(+), 245 deletions(-) diff --git a/hatchet_sdk/client.py b/hatchet_sdk/client.py index 45dfd394..b45fc39b 100644 --- a/hatchet_sdk/client.py +++ b/hatchet_sdk/client.py @@ -12,7 +12,7 @@ from .clients.dispatcher.dispatcher import DispatcherClient, new_dispatcher from .clients.events import EventClient, new_event from .clients.rest_client import RestApi -from .loader import ClientConfig, ConfigLoader +from .loader import ClientConfig class Client: @@ -37,11 +37,10 @@ def from_environment( loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - config: ClientConfig = ConfigLoader(".").load_client_config(defaults) for opt_function in opts_functions: - opt_function(config) + opt_function(defaults) - return cls.from_config(config, debug) + return cls.from_config(defaults, debug) @classmethod def from_config( @@ -116,4 +115,4 @@ def with_host_port_impl(config: ClientConfig): new_client = Client.from_environment -new_client_raw = Client.from_config +new_client_raw = Client.from_config \ No newline at end of file diff --git a/hatchet_sdk/hatchet.py b/hatchet_sdk/hatchet.py index bf0e9089..198aeb1e 100644 --- a/hatchet_sdk/hatchet.py +++ b/hatchet_sdk/hatchet.py @@ -16,7 +16,7 @@ from hatchet_sdk.features.cron import CronClient from hatchet_sdk.features.scheduled import ScheduledClient from hatchet_sdk.labels import DesiredWorkerLabel -from hatchet_sdk.loader import ClientConfig, ConfigLoader +from hatchet_sdk.loader import ClientConfig from hatchet_sdk.rate_limit import RateLimit from hatchet_sdk.v2.callable import HatchetCallable @@ -190,9 +190,7 @@ class HatchetRest: rest: RestApi def __init__(self, config: ClientConfig = ClientConfig()): - _config: ClientConfig = ConfigLoader(".").load_client_config(config) - self.rest = RestApi(_config.server_url, _config.token, _config.tenant_id) - + self.rest = RestApi(config.server_url, config.token, config.tenant_id) class Hatchet: """ diff --git a/hatchet_sdk/loader.py b/hatchet_sdk/loader.py index d754c2ae..7ec30f06 100644 --- a/hatchet_sdk/loader.py +++ b/hatchet_sdk/loader.py @@ -1,246 +1,107 @@ -import json import os from logging import Logger, getLogger -from typing import Dict, Optional - -import yaml - -from .token import get_addresses_from_jwt, get_tenant_id_from_jwt - - -class ClientTLSConfig: - def __init__( - self, - tls_strategy: str, - cert_file: str, - key_file: str, - ca_file: str, - server_name: str, - ): - self.tls_strategy = tls_strategy - self.cert_file = cert_file - self.key_file = key_file - self.ca_file = ca_file - self.server_name = server_name - - -class ClientConfig: - logInterceptor: Logger - - def __init__( - self, - tenant_id: str = None, - tls_config: ClientTLSConfig = None, - token: str = None, - host_port: str = "localhost:7070", - server_url: str = "https://app.dev.hatchet-tools.com", - namespace: str = None, - listener_v2_timeout: int = None, - logger: Logger = None, - grpc_max_recv_message_length: int = 4 * 1024 * 1024, # 4MB - grpc_max_send_message_length: int = 4 * 1024 * 1024, # 4MB - otel_exporter_oltp_endpoint: str | None = None, - otel_service_name: str | None = None, - otel_exporter_oltp_headers: dict[str, str] | None = None, - otel_exporter_oltp_protocol: str | None = None, - worker_healthcheck_port: int | None = None, - worker_healthcheck_enabled: bool | None = None, - ): - self.tenant_id = tenant_id - self.tls_config = tls_config - self.host_port = host_port - self.token = token - self.server_url = server_url - self.namespace = "" - self.logInterceptor = logger - self.grpc_max_recv_message_length = grpc_max_recv_message_length - self.grpc_max_send_message_length = grpc_max_send_message_length - self.otel_exporter_oltp_endpoint = otel_exporter_oltp_endpoint - self.otel_service_name = otel_service_name - self.otel_exporter_oltp_headers = otel_exporter_oltp_headers - self.otel_exporter_oltp_protocol = otel_exporter_oltp_protocol - self.worker_healthcheck_port = worker_healthcheck_port - self.worker_healthcheck_enabled = worker_healthcheck_enabled - - if not self.logInterceptor: - self.logInterceptor = getLogger() - - # case on whether the namespace already has a trailing underscore - if namespace and not namespace.endswith("_"): - self.namespace = f"{namespace}_" - elif namespace: - self.namespace = namespace - - self.namespace = self.namespace.lower() - - self.listener_v2_timeout = listener_v2_timeout - - -class ConfigLoader: - def __init__(self, directory: str): - self.directory = directory - - def load_client_config(self, defaults: ClientConfig) -> ClientConfig: - config_file_path = os.path.join(self.directory, "client.yaml") - config_data: object = {"tls": {}} - - # determine if client.yaml exists - if os.path.exists(config_file_path): - with open(config_file_path, "r") as file: - config_data = yaml.safe_load(file) - - def get_config_value(key, env_var): - if key in config_data: - return config_data[key] - - if self._get_env_var(env_var) is not None: - return self._get_env_var(env_var) - - return getattr(defaults, key, None) - - namespace = get_config_value("namespace", "HATCHET_CLIENT_NAMESPACE") - - tenant_id = get_config_value("tenantId", "HATCHET_CLIENT_TENANT_ID") - token = get_config_value("token", "HATCHET_CLIENT_TOKEN") - listener_v2_timeout = get_config_value( - "listener_v2_timeout", "HATCHET_CLIENT_LISTENER_V2_TIMEOUT" - ) - listener_v2_timeout = int(listener_v2_timeout) if listener_v2_timeout else None - +from typing import cast + +from pydantic import BaseModel, ValidationError, ValidationInfo, field_validator + +from .token import get_tenant_id_from_jwt + + +class ClientTLSConfig(BaseModel): + tls_strategy: str + cert_file: str | None + key_file: str | None + ca_file: str | None + server_name: str + + +def _load_tls_config(host_port: str) -> ClientTLSConfig: + return ClientTLSConfig( + tls_strategy=os.getenv("HATCHET_CLIENT_TLS_STRATEGY", "tls"), + cert_file=os.getenv("HATCHET_CLIENT_TLS_CERT_FILE"), + key_file=os.getenv("HATCHET_CLIENT_TLS_KEY_FILE"), + ca_file=os.getenv("HATCHET_CLIENT_TLS_ROOT_CA_FILE"), + server_name=os.getenv( + "HATCHET_CLIENT_TLS_SERVER_NAME", host_port.split(":")[0] + ), + ) + + +def parse_listener_timeout(timeout: str | None) -> int | None: + if timeout is None: + return None + + return int(timeout) + + +class ClientConfig(BaseModel): + token: str = os.getenv("HATCHET_CLIENT_TOKEN", "") + logger: Logger = getLogger() + tenant_id: str = os.getenv("HATCHET_CLIENT_TENANT_ID", "") + host_port: str = os.getenv("HATCHET_CLIENT_HOST_PORT", "localhost:7070") + tls_config: ClientTLSConfig = _load_tls_config(host_port) + server_url: str = "https://app.dev.hatchet-tools.com" + namespace: str = os.getenv("HATCHET_CLIENT_NAMESPACE", "") + listener_v2_timeout: int | None = parse_listener_timeout( + os.getenv("HATCHET_CLIENT_LISTENER_V2_TIMEOUT") + ) + grpc_max_recv_message_length: int = int( + os.getenv("HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH", 4 * 1024 * 1024) + ) # 4MB + grpc_max_send_message_length: int = int( + os.getenv("HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH", 4 * 1024 * 1024) + ) # 4MB + otel_exporter_oltp_endpoint: str | None = os.getenv( + "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_ENDPOINT" + ) + otel_service_name: str | None = os.getenv("HATCHET_CLIENT_OTEL_SERVICE_NAME") + otel_exporter_oltp_headers: str | None = os.getenv( + "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS" + ) + otel_exporter_oltp_protocol: str | None = os.getenv( + "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_PROTOCOL" + ) + worker_healthcheck_port: int = int( + os.getenv("HATCHET_CLIENT_WORKER_HEALTHCHECK_PORT", 8001) + ) + worker_healthcheck_enabled: bool = ( + os.getenv("HATCHET_CLIENT_WORKER_HEALTHCHECK_ENABLED", "False") == "True" + ) + + @field_validator("token", mode="after") + @classmethod + def validate_token(cls, token: str) -> str: if not token: - raise ValueError( - "Token must be set via HATCHET_CLIENT_TOKEN environment variable" - ) - - host_port = get_config_value("hostPort", "HATCHET_CLIENT_HOST_PORT") - server_url: str | None = None + raise ValidationError("Token must be set") - grpc_max_recv_message_length = get_config_value( - "grpc_max_recv_message_length", - "HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH", - ) - grpc_max_send_message_length = get_config_value( - "grpc_max_send_message_length", - "HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH", - ) + return token - if grpc_max_recv_message_length: - grpc_max_recv_message_length = int(grpc_max_recv_message_length) + @field_validator("namespace", mode="after") + @classmethod + def validate_namespace(cls, namespace: str) -> str: + if not namespace.endswith("_"): + namespace = f"{namespace}_" - if grpc_max_send_message_length: - grpc_max_send_message_length = int(grpc_max_send_message_length) + return namespace.lower() - if not host_port: - # extract host and port from token - server_url, grpc_broadcast_address = get_addresses_from_jwt(token) - host_port = grpc_broadcast_address + @field_validator("tenant_id", mode="after") + @classmethod + def validate_tenant_id(cls, tenant_id: str, info: ValidationInfo) -> str: + token = cast(str | None, info.data.get("token")) if not tenant_id: - tenant_id = get_tenant_id_from_jwt(token) - - tls_config = self._load_tls_config(config_data["tls"], host_port) - - otel_exporter_oltp_endpoint = get_config_value( - "otel_exporter_oltp_endpoint", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_ENDPOINT" - ) + if not token: + raise ValidationError( + "Token must be set before attempting to infer tenant ID" + ) - otel_service_name = get_config_value( - "otel_service_name", "HATCHET_CLIENT_OTEL_SERVICE_NAME" - ) + return get_tenant_id_from_jwt(token) - _oltp_headers = get_config_value( - "otel_exporter_oltp_headers", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS" - ) + return tenant_id - if _oltp_headers: - try: - otel_header_key, api_key = _oltp_headers.split("=", maxsplit=1) - otel_exporter_oltp_headers = {otel_header_key: api_key} - except ValueError: - raise ValueError( - "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS must be in the format `key=value`" - ) - else: - otel_exporter_oltp_headers = None - - otel_exporter_oltp_protocol = get_config_value( - "otel_exporter_oltp_protocol", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_PROTOCOL" - ) - - worker_healthcheck_port = int( - get_config_value( - "worker_healthcheck_port", "HATCHET_CLIENT_WORKER_HEALTHCHECK_PORT" - ) - or 8001 - ) - - worker_healthcheck_enabled = ( - str( - get_config_value( - "worker_healthcheck_port", - "HATCHET_CLIENT_WORKER_HEALTHCHECK_ENABLED", - ) - ) - == "True" - ) - - return ClientConfig( - tenant_id=tenant_id, - tls_config=tls_config, - token=token, - host_port=host_port, - server_url=server_url, - namespace=namespace, - listener_v2_timeout=listener_v2_timeout, - logger=defaults.logInterceptor, - grpc_max_recv_message_length=grpc_max_recv_message_length, - grpc_max_send_message_length=grpc_max_send_message_length, - otel_exporter_oltp_endpoint=otel_exporter_oltp_endpoint, - otel_service_name=otel_service_name, - otel_exporter_oltp_headers=otel_exporter_oltp_headers, - otel_exporter_oltp_protocol=otel_exporter_oltp_protocol, - worker_healthcheck_port=worker_healthcheck_port, - worker_healthcheck_enabled=worker_healthcheck_enabled, - ) - - def _load_tls_config(self, tls_data: Dict, host_port) -> ClientTLSConfig: - tls_strategy = ( - tls_data["tlsStrategy"] - if "tlsStrategy" in tls_data - else self._get_env_var("HATCHET_CLIENT_TLS_STRATEGY") - ) - - if not tls_strategy: - tls_strategy = "tls" - - cert_file = ( - tls_data["tlsCertFile"] - if "tlsCertFile" in tls_data - else self._get_env_var("HATCHET_CLIENT_TLS_CERT_FILE") - ) - key_file = ( - tls_data["tlsKeyFile"] - if "tlsKeyFile" in tls_data - else self._get_env_var("HATCHET_CLIENT_TLS_KEY_FILE") - ) - ca_file = ( - tls_data["tlsRootCAFile"] - if "tlsRootCAFile" in tls_data - else self._get_env_var("HATCHET_CLIENT_TLS_ROOT_CA_FILE") - ) - - server_name = ( - tls_data["tlsServerName"] - if "tlsServerName" in tls_data - else self._get_env_var("HATCHET_CLIENT_TLS_SERVER_NAME") - ) - - # if server_name is not set, use the host from the host_port - if not server_name: - server_name = host_port.split(":")[0] - - return ClientTLSConfig(tls_strategy, cert_file, key_file, ca_file, server_name) - - @staticmethod - def _get_env_var(env_var: str, default: Optional[str] = None) -> str: - return os.environ.get(env_var, default) + ## TODO: Fix host port overrides here + ## Old code: + ## if not host_port: + ## ## extract host and port from token + ## server_url, grpc_broadcast_address = get_addresses_from_jwt(token) + ## host_port = grpc_broadcast_address \ No newline at end of file diff --git a/hatchet_sdk/utils/tracing.py b/hatchet_sdk/utils/tracing.py index afc398f7..72509f6f 100644 --- a/hatchet_sdk/utils/tracing.py +++ b/hatchet_sdk/utils/tracing.py @@ -16,6 +16,18 @@ OTEL_CARRIER_KEY = "__otel_carrier" +def parse_headers(headers: str | None) -> dict[str, str]: + if headers is None: + return {} + + try: + otel_header_key, api_key = headers.split("=", maxsplit=1) + + return {otel_header_key: api_key} + except ValueError: + raise ValueError("OTLP headers must be in the format `key=value`") + + @cache def create_tracer(config: ClientConfig) -> Tracer: ## TODO: Figure out how to specify protocol here @@ -27,7 +39,7 @@ def create_tracer(config: ClientConfig) -> Tracer: processor = BatchSpanProcessor( OTLPSpanExporter( endpoint=config.otel_exporter_oltp_endpoint, - headers=config.otel_exporter_oltp_headers, + headers=parse_headers(config.otel_exporter_oltp_headers), ), ) @@ -67,4 +79,4 @@ def parse_carrier_from_metadata(metadata: dict[str, Any] | None) -> Context | No TraceContextTextMapPropagator().extract(_ctx) if (_ctx := metadata.get(OTEL_CARRIER_KEY)) else None - ) + ) \ No newline at end of file From 6d581f7a1cf533cd4c2a15d11d746a4a78dee2ac Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Mon, 13 Jan 2025 12:57:00 -0500 Subject: [PATCH 2/4] fix: ignore python-version --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 6b15a2af..a8fca96c 100644 --- a/.gitignore +++ b/.gitignore @@ -162,3 +162,4 @@ cython_debug/ #.idea/ openapitools.json +.python-version From 42d7539b7f4e7aec48135c269c47cdfae49afc52 Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Mon, 13 Jan 2025 12:57:28 -0500 Subject: [PATCH 3/4] fix: lint --- hatchet_sdk/client.py | 2 +- hatchet_sdk/hatchet.py | 1 + hatchet_sdk/loader.py | 2 +- hatchet_sdk/utils/tracing.py | 2 +- 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/hatchet_sdk/client.py b/hatchet_sdk/client.py index b45fc39b..f52a9418 100644 --- a/hatchet_sdk/client.py +++ b/hatchet_sdk/client.py @@ -115,4 +115,4 @@ def with_host_port_impl(config: ClientConfig): new_client = Client.from_environment -new_client_raw = Client.from_config \ No newline at end of file +new_client_raw = Client.from_config diff --git a/hatchet_sdk/hatchet.py b/hatchet_sdk/hatchet.py index 198aeb1e..3977d62f 100644 --- a/hatchet_sdk/hatchet.py +++ b/hatchet_sdk/hatchet.py @@ -192,6 +192,7 @@ class HatchetRest: def __init__(self, config: ClientConfig = ClientConfig()): self.rest = RestApi(config.server_url, config.token, config.tenant_id) + class Hatchet: """ Main client for interacting with the Hatchet SDK. diff --git a/hatchet_sdk/loader.py b/hatchet_sdk/loader.py index 7ec30f06..265aa4f0 100644 --- a/hatchet_sdk/loader.py +++ b/hatchet_sdk/loader.py @@ -104,4 +104,4 @@ def validate_tenant_id(cls, tenant_id: str, info: ValidationInfo) -> str: ## if not host_port: ## ## extract host and port from token ## server_url, grpc_broadcast_address = get_addresses_from_jwt(token) - ## host_port = grpc_broadcast_address \ No newline at end of file + ## host_port = grpc_broadcast_address diff --git a/hatchet_sdk/utils/tracing.py b/hatchet_sdk/utils/tracing.py index 72509f6f..19341780 100644 --- a/hatchet_sdk/utils/tracing.py +++ b/hatchet_sdk/utils/tracing.py @@ -79,4 +79,4 @@ def parse_carrier_from_metadata(metadata: dict[str, Any] | None) -> Context | No TraceContextTextMapPropagator().extract(_ctx) if (_ctx := metadata.get(OTEL_CARRIER_KEY)) else None - ) \ No newline at end of file + ) From 2223426bba6c0ace1f4367392dd0d5d933b7203a Mon Sep 17 00:00:00 2001 From: mrkaye97 Date: Mon, 13 Jan 2025 12:58:46 -0500 Subject: [PATCH 4/4] feat: add loader to mypy --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 69380b67..7207c430 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,7 @@ files = [ "hatchet_sdk/clients/rest/models/workflow_run.py", "hatchet_sdk/context/worker_context.py", "hatchet_sdk/clients/dispatcher/dispatcher.py", + "hatchet_sdk/loader.py", ] follow_imports = "silent" disable_error_code = ["unused-coroutine"]