Skip to content

Commit

Permalink
Merge pull request #101 from stealthrocket/paramspec2
Browse files Browse the repository at this point in the history
Improve type checking of call results
  • Loading branch information
chriso authored Mar 4, 2024
2 parents 78c0ced + a9148a1 commit b268e72
Showing 1 changed file with 107 additions and 103 deletions.
210 changes: 107 additions & 103 deletions src/dispatch/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,17 @@
import logging
from functools import wraps
from types import CoroutineType
from typing import Any, Callable, Dict, Generic, ParamSpec, TypeAlias, TypeVar
from typing import (
Any,
Callable,
Coroutine,
Dict,
Generic,
ParamSpec,
TypeAlias,
TypeVar,
overload,
)

import dispatch.coroutine
from dispatch.client import Client
Expand All @@ -23,44 +33,20 @@
"""


P = ParamSpec("P")
T = TypeVar("T")


class Function(Generic[P, T]):
"""Callable wrapper around a function meant to be used throughout the
Dispatch Python SDK.
"""

__slots__ = ("_endpoint", "_client", "_name", "_primitive_func", "_func")
class PrimitiveFunction:
__slots__ = ("_endpoint", "_client", "_name", "_primitive_func")

def __init__(
self,
endpoint: str,
client: Client,
name: str,
primitive_func: PrimitiveFunctionType,
func: Callable[P, T] | None,
coroutine: bool = False,
):
self._endpoint = endpoint
self._client = client
self._name = name
self._primitive_func = primitive_func
if func:
self._func: Callable[P, T] | None = (
durable(self._call_async) if coroutine else func
)
else:
self._func = None

def __call__(self, *args: P.args, **kwargs: P.kwargs):
if self._func is None:
raise ValueError("cannot call a primitive function directly")
return self._func(*args, **kwargs)

def _primitive_call(self, input: Input) -> Output:
return self._primitive_func(input)

@property
def endpoint(self) -> str:
Expand All @@ -70,8 +56,62 @@ def endpoint(self) -> str:
def name(self) -> str:
return self._name

def _primitive_call(self, input: Input) -> Output:
return self._primitive_func(input)

def _primitive_dispatch(self, input: Any = None) -> DispatchID:
[dispatch_id] = self._client.dispatch([self._build_primitive_call(input)])
return dispatch_id

def _build_primitive_call(
self, input: Any, correlation_id: int | None = None
) -> Call:
return Call(
correlation_id=correlation_id,
endpoint=self.endpoint,
function=self.name,
input=input,
)


P = ParamSpec("P")
T = TypeVar("T")


class Function(PrimitiveFunction, Generic[P, T]):
"""Callable wrapper around a function meant to be used throughout the
Dispatch Python SDK.
"""

__slots__ = ("_func_indirect",)

def __init__(
self,
endpoint: str,
client: Client,
name: str,
primitive_func: PrimitiveFunctionType,
func: Callable,
):
PrimitiveFunction.__init__(self, endpoint, client, name, primitive_func)

self._func_indirect: Callable[P, Coroutine[Any, Any, T]] = durable(
self._call_async
)

async def _call_async(self, *args: P.args, **kwargs: P.kwargs) -> T:
return await dispatch.coroutine.call(
self.build_call(*args, **kwargs, correlation_id=None)
)

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T]:
"""Call the function asynchronously (through Dispatch), and return a
coroutine that can be awaited to retrieve the call result."""
return self._func_indirect(*args, **kwargs)

def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID:
"""Dispatch a call to the function.
"""Dispatch an asynchronous call to the function without
waiting for a result.
The Registry this function was registered with must be initialized
with a Client / api_key for this call facility to be available.
Expand All @@ -88,16 +128,6 @@ def dispatch(self, *args: P.args, **kwargs: P.kwargs) -> DispatchID:
"""
return self._primitive_dispatch(Arguments(args, kwargs))

def _primitive_dispatch(self, input: Any = None) -> DispatchID:
[dispatch_id] = self._client.dispatch([self._build_primitive_call(input)])
return dispatch_id

async def _call_async(self, *args: P.args, **kwargs: P.kwargs) -> T:
"""Asynchronously call the function from a @dispatch.function."""
return await dispatch.coroutine.call(
self.build_call(*args, **kwargs, correlation_id=None)
)

def build_call(
self, *args: P.args, correlation_id: int | None = None, **kwargs: P.kwargs
) -> Call:
Expand All @@ -117,16 +147,6 @@ def build_call(
Arguments(args, kwargs), correlation_id=correlation_id
)

def _build_primitive_call(
self, input: Any, correlation_id: int | None = None
) -> Call:
return Call(
correlation_id=correlation_id,
endpoint=self.endpoint,
function=self.name,
input=input,
)


class Registry:
"""Registry of local functions."""
Expand All @@ -141,89 +161,73 @@ def __init__(self, endpoint: str, client: Client):
client: Client for the Dispatch API. Used to dispatch calls to
local functions.
"""
self._functions: Dict[str, Function] = {}
self._functions: Dict[str, PrimitiveFunction] = {}
self._endpoint = endpoint
self._client = client

def function(self, func: Callable[P, T]) -> Function[P, T]:
@overload
def function(self, func: Callable[P, Coroutine[Any, Any, T]]) -> Function[P, T]: ...

@overload
def function(self, func: Callable[P, T]) -> Function[P, T]: ...

def function(self, func):
"""Decorator that registers functions."""
if inspect.iscoroutinefunction(func):
return self._register_coroutine(func)
return self._register_function(func)
if not inspect.iscoroutinefunction(func):
logger.info("registering function: %s", func.__qualname__)
return self._register_function(func)

def primitive_function(self, func: PrimitiveFunctionType) -> Function:
"""Decorator that registers primitive functions."""
return self._register_primitive_function(func)
logger.info("registering coroutine: %s", func.__qualname__)
return self._register_coroutine(func)

def _register_function(self, func: Callable[P, T]) -> Function[P, T]:
logger.info("registering function: %s", func.__qualname__)

# Register the function with the experimental.durable package, in case
# it's referenced from a coroutine.
func = durable(func)

@wraps(func)
def primitive_func(input: Input) -> Output:
try:
try:
args, kwargs = input.input_arguments()
except ValueError:
raise ValueError("incorrect input for function")
raw_output = func(*args, **kwargs)
except Exception as e:
logger.exception(
f"@dispatch.function: '{func.__name__}' raised an exception"
)
return Output.error(Error.from_exception(e))
else:
return Output.value(raw_output)

primitive_func.__qualname__ = f"{func.__qualname__}_primitive"
primitive_func = durable(primitive_func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
return func(*args, **kwargs)

async_wrapper.__qualname__ = f"{func.__qualname__}_async"

return self._register(primitive_func, func, coroutine=False)
return self._register_coroutine(async_wrapper)

def _register_coroutine(
self, func: Callable[P, CoroutineType[Any, Any, T]]
self, func: Callable[P, Coroutine[Any, Any, T]]
) -> Function[P, T]:
logger.info("registering coroutine: %s", func.__qualname__)
name = func.__qualname__
logger.info("registering coroutine: %s", name)

func = durable(func)

@wraps(func)
def primitive_func(input: Input) -> Output:
return OneShotScheduler(func).run(input)

primitive_func.__qualname__ = f"{func.__qualname__}_primitive"
primitive_func.__qualname__ = f"{name}_primitive"
primitive_func = durable(primitive_func)

return self._register(primitive_func, func, coroutine=True)
wrapped_func = Function[P, T](
self._endpoint, self._client, name, primitive_func, func
)
self._register(name, wrapped_func)
return wrapped_func

def _register_primitive_function(
def primitive_function(
self, primitive_func: PrimitiveFunctionType
) -> Function[P, T]:
logger.info("registering primitive function: %s", primitive_func.__qualname__)
return self._register(primitive_func, func=None, coroutine=False)
) -> PrimitiveFunction:
"""Decorator that registers primitive functions."""
name = primitive_func.__qualname__
logger.info("registering primitive function: %s", name)
wrapped_func = PrimitiveFunction(
self._endpoint, self._client, name, primitive_func
)
self._register(name, wrapped_func)
return wrapped_func

def _register(
self,
primitive_func: PrimitiveFunctionType,
func: Callable[P, T] | None,
coroutine: bool,
) -> Function[P, T]:
if func:
name = func.__qualname__
else:
name = primitive_func.__qualname__
def _register(self, name: str, wrapped_func: PrimitiveFunction):
if name in self._functions:
raise ValueError(
f"function or coroutine already registered with name '{name}'"
)
wrapped_func = Function[P, T](
self._endpoint, self._client, name, primitive_func, func, coroutine
)
raise ValueError(f"function already registered with name '{name}'")
self._functions[name] = wrapped_func
return wrapped_func

def set_client(self, client: Client):
"""Set the Client instance used to dispatch calls to local functions."""
Expand Down

0 comments on commit b268e72

Please sign in to comment.