diff --git a/src/dispatch/__init__.py b/src/dispatch/__init__.py index 154257db..45a45e75 100644 --- a/src/dispatch/__init__.py +++ b/src/dispatch/__init__.py @@ -4,7 +4,7 @@ import dispatch.integrations from dispatch.coroutine import all, any, call, gather, race -from dispatch.function import DEFAULT_API_URL, Client +from dispatch.function import DEFAULT_API_URL, Client, Registry from dispatch.id import DispatchID from dispatch.proto import Call, Error, Input, Output from dispatch.status import Status @@ -23,4 +23,5 @@ "all", "any", "race", + "Registry", ] diff --git a/src/dispatch/fastapi.py b/src/dispatch/fastapi.py index 97c4f7ed..0cb98384 100644 --- a/src/dispatch/fastapi.py +++ b/src/dispatch/fastapi.py @@ -115,8 +115,7 @@ def __init__( "request verification is disabled because DISPATCH_VERIFICATION_KEY is not set" ) - self.client = Client(api_key=api_key, api_url=api_url) - super().__init__(endpoint, self.client) + super().__init__(endpoint, api_key=api_key, api_url=api_url) function_service = _new_app(self, verification_key) app.mount("/dispatch.sdk.v1.FunctionService", function_service) @@ -225,7 +224,7 @@ async def execute(request: fastapi.Request): raise _ConnectError(400, "invalid_argument", "function is required") try: - func = function_registry._functions[req.function] + func = function_registry.functions[req.function] except KeyError: logger.debug("function '%s' not found", req.function) raise _ConnectError( diff --git a/src/dispatch/function.py b/src/dispatch/function.py index 55277fa9..db14868a 100644 --- a/src/dispatch/function.py +++ b/src/dispatch/function.py @@ -100,7 +100,6 @@ def __init__( client: Client, name: str, primitive_func: PrimitiveFunctionType, - func: Callable, ): PrimitiveFunction.__init__(self, endpoint, client, name, primitive_func) @@ -158,21 +157,30 @@ def build_call( class Registry: - """Registry of local functions.""" + """Registry of functions.""" - __slots__ = ("_functions", "_endpoint", "_client") + __slots__ = ("functions", "endpoint", "client") - def __init__(self, endpoint: str, client: Client): - """Initialize a local function registry. + def __init__( + self, endpoint: str, api_key: str | None = None, api_url: str | None = None + ): + """Initialize a function registry. Args: endpoint: URL of the endpoint that the function is accessible from. - client: Client for the Dispatch API. Used to dispatch calls to - local functions. + + api_key: Dispatch API key to use for authentication when + dispatching calls to functions. Uses the value of the + DISPATCH_API_KEY environment variable by default. + + api_url: The URL of the Dispatch API to use when dispatching calls + to functions. Uses the value of the DISPATCH_API_URL environment + variable if set, otherwise defaults to the public Dispatch API + (DEFAULT_API_URL). """ - self._functions: Dict[str, PrimitiveFunction] = {} - self._endpoint = endpoint - self._client = client + self.functions: Dict[str, PrimitiveFunction] = {} + self.endpoint = endpoint + self.client = Client(api_key=api_key, api_url=api_url) @overload def function(self, func: Callable[P, Coroutine[Any, Any, T]]) -> Function[P, T]: ... @@ -215,9 +223,7 @@ def primitive_func(input: Input) -> Output: primitive_func.__qualname__ = f"{name}_primitive" primitive_func = durable(primitive_func) - wrapped_func = Function[P, T]( - self._endpoint, self._client, name, primitive_func, func - ) + wrapped_func = Function[P, T](self.endpoint, self.client, name, primitive_func) self._register(name, wrapped_func) return wrapped_func @@ -228,20 +234,20 @@ def primitive_function( name = primitive_func.__qualname__ logger.info("registering primitive function: %s", name) wrapped_func = PrimitiveFunction( - self._endpoint, self._client, name, primitive_func + self.endpoint, self.client, name, primitive_func ) self._register(name, wrapped_func) return wrapped_func def _register(self, name: str, wrapped_func: PrimitiveFunction): - if name in self._functions: + if name in self.functions: raise ValueError(f"function already registered with name '{name}'") - self._functions[name] = wrapped_func + self.functions[name] = wrapped_func def set_client(self, client: Client): - """Set the Client instance used to dispatch calls to local functions.""" - self._client = client - for fn in self._functions.values(): + """Set the Client instance used to dispatch calls to registered functions.""" + self.client = client + for fn in self.functions.values(): fn._client = client diff --git a/tests/dispatch/test_function.py b/tests/dispatch/test_function.py index 6f4a93ab..0befc0f7 100644 --- a/tests/dispatch/test_function.py +++ b/tests/dispatch/test_function.py @@ -6,8 +6,11 @@ class TestFunction(unittest.TestCase): def setUp(self): - self.client = Client(api_url="http://dispatch.com", api_key="foobar") - self.dispatch = Registry(endpoint="http://example.com", client=self.client) + self.dispatch = Registry( + endpoint="http://example.com", + api_url="http://dispatch.com", + api_key="foobar", + ) def test_serializable(self): @self.dispatch.function