Skip to content

Commit

Permalink
Merge pull request #129 from stealthrocket/function-name-fix
Browse files Browse the repository at this point in the history
Fix registration of synchronous functions
  • Loading branch information
chriso authored Mar 18, 2024
2 parents 8aa6795 + 988c45e commit f35c6ba
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
18 changes: 9 additions & 9 deletions src/dispatch/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,28 +182,28 @@ def function(self, func: Callable[P, T]) -> Function[P, T]: ...

def function(self, func):
"""Decorator that registers functions."""
name = func.__qualname__
if not inspect.iscoroutinefunction(func):
logger.info("registering function: %s", func.__qualname__)
return self._register_function(func)
logger.info("registering function: %s", name)
return self._register_function(name, func)

logger.info("registering coroutine: %s", func.__qualname__)
return self._register_coroutine(func)
logger.info("registering coroutine: %s", name)
return self._register_coroutine(name, func)

def _register_function(self, func: Callable[P, T]) -> Function[P, T]:
def _register_function(self, name: str, func: Callable[P, T]) -> Function[P, T]:
func = durable(func)

@wraps(func)
async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
return func(*args, **kwargs)

async_wrapper.__qualname__ = f"{func.__qualname__}_async"
async_wrapper.__qualname__ = f"{name}_async"

return self._register_coroutine(async_wrapper)
return self._register_coroutine(name, async_wrapper)

def _register_coroutine(
self, func: Callable[P, Coroutine[Any, Any, T]]
self, name: str, func: Callable[P, Coroutine[Any, Any, T]]
) -> Function[P, T]:
name = func.__qualname__
logger.info("registering coroutine: %s", name)

func = durable(func)
Expand Down
10 changes: 8 additions & 2 deletions tests/test_full.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def test_simple_end_to_end(self):
def my_function(name: str) -> str:
return f"Hello world: {name}"

call = my_function.build_call(52)
self.assertEqual(call.function.split(".")[-1], "my_function")

# The client.
[dispatch_id] = self.dispatch_client.dispatch([my_function.build_call(52)])

Expand All @@ -73,10 +76,13 @@ def my_function(name: str) -> str:

def test_simple_missing_signature(self):
@self.dispatch.function
def my_function(name: str) -> str:
async def my_function(name: str) -> str:
return f"Hello world: {name}"

[dispatch_id] = self.dispatch_client.dispatch([my_function.build_call(52)])
call = my_function.build_call(52)
self.assertEqual(call.function.split(".")[-1], "my_function")

[dispatch_id] = self.dispatch_client.dispatch([call])

self.dispatch_service.endpoint_client = EndpointClient.from_app(
self.endpoint_app
Expand Down

0 comments on commit f35c6ba

Please sign in to comment.