Skip to content

Commit

Permalink
Merge pull request #134 from stealthrocket/remote-functions
Browse files Browse the repository at this point in the history
Remote endpoints
  • Loading branch information
chriso authored Mar 21, 2024
2 parents ada9942 + e90a64c commit 99b9ed6
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 25 deletions.
3 changes: 2 additions & 1 deletion src/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import dispatch.integrations
from dispatch.coroutine import all, any, call, gather, race
from dispatch.function import DEFAULT_API_URL, Client
from dispatch.function import DEFAULT_API_URL, Client, Registry
from dispatch.id import DispatchID
from dispatch.proto import Call, Error, Input, Output
from dispatch.status import Status
Expand All @@ -23,4 +23,5 @@
"all",
"any",
"race",
"Registry",
]
5 changes: 2 additions & 3 deletions src/dispatch/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,7 @@ def __init__(
"request verification is disabled because DISPATCH_VERIFICATION_KEY is not set"
)

self.client = Client(api_key=api_key, api_url=api_url)
super().__init__(endpoint, self.client)
super().__init__(endpoint, api_key=api_key, api_url=api_url)

function_service = _new_app(self, verification_key)
app.mount("/dispatch.sdk.v1.FunctionService", function_service)
Expand Down Expand Up @@ -225,7 +224,7 @@ async def execute(request: fastapi.Request):
raise _ConnectError(400, "invalid_argument", "function is required")

try:
func = function_registry._functions[req.function]
func = function_registry.functions[req.function]
except KeyError:
logger.debug("function '%s' not found", req.function)
raise _ConnectError(
Expand Down
44 changes: 25 additions & 19 deletions src/dispatch/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ def __init__(
client: Client,
name: str,
primitive_func: PrimitiveFunctionType,
func: Callable,
):
PrimitiveFunction.__init__(self, endpoint, client, name, primitive_func)

Expand Down Expand Up @@ -158,21 +157,30 @@ def build_call(


class Registry:
"""Registry of local functions."""
"""Registry of functions."""

__slots__ = ("_functions", "_endpoint", "_client")
__slots__ = ("functions", "endpoint", "client")

def __init__(self, endpoint: str, client: Client):
"""Initialize a local function registry.
def __init__(
self, endpoint: str, api_key: str | None = None, api_url: str | None = None
):
"""Initialize a function registry.
Args:
endpoint: URL of the endpoint that the function is accessible from.
client: Client for the Dispatch API. Used to dispatch calls to
local functions.
api_key: Dispatch API key to use for authentication when
dispatching calls to functions. Uses the value of the
DISPATCH_API_KEY environment variable by default.
api_url: The URL of the Dispatch API to use when dispatching calls
to functions. Uses the value of the DISPATCH_API_URL environment
variable if set, otherwise defaults to the public Dispatch API
(DEFAULT_API_URL).
"""
self._functions: Dict[str, PrimitiveFunction] = {}
self._endpoint = endpoint
self._client = client
self.functions: Dict[str, PrimitiveFunction] = {}
self.endpoint = endpoint
self.client = Client(api_key=api_key, api_url=api_url)

@overload
def function(self, func: Callable[P, Coroutine[Any, Any, T]]) -> Function[P, T]: ...
Expand Down Expand Up @@ -215,9 +223,7 @@ def primitive_func(input: Input) -> Output:
primitive_func.__qualname__ = f"{name}_primitive"
primitive_func = durable(primitive_func)

wrapped_func = Function[P, T](
self._endpoint, self._client, name, primitive_func, func
)
wrapped_func = Function[P, T](self.endpoint, self.client, name, primitive_func)
self._register(name, wrapped_func)
return wrapped_func

Expand All @@ -228,20 +234,20 @@ def primitive_function(
name = primitive_func.__qualname__
logger.info("registering primitive function: %s", name)
wrapped_func = PrimitiveFunction(
self._endpoint, self._client, name, primitive_func
self.endpoint, self.client, name, primitive_func
)
self._register(name, wrapped_func)
return wrapped_func

def _register(self, name: str, wrapped_func: PrimitiveFunction):
if name in self._functions:
if name in self.functions:
raise ValueError(f"function already registered with name '{name}'")
self._functions[name] = wrapped_func
self.functions[name] = wrapped_func

def set_client(self, client: Client):
"""Set the Client instance used to dispatch calls to local functions."""
self._client = client
for fn in self._functions.values():
"""Set the Client instance used to dispatch calls to registered functions."""
self.client = client
for fn in self.functions.values():
fn._client = client


Expand Down
7 changes: 5 additions & 2 deletions tests/dispatch/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@

class TestFunction(unittest.TestCase):
def setUp(self):
self.client = Client(api_url="http://dispatch.com", api_key="foobar")
self.dispatch = Registry(endpoint="http://example.com", client=self.client)
self.dispatch = Registry(
endpoint="http://example.com",
api_url="http://dispatch.com",
api_key="foobar",
)

def test_serializable(self):
@self.dispatch.function
Expand Down

0 comments on commit 99b9ed6

Please sign in to comment.