Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve detection of async callables #1444

Merged
merged 10 commits into from
May 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions starlette/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import asyncio
import functools
import typing


def is_async_callable(obj: typing.Any) -> bool:
while isinstance(obj, functools.partial):
Kludex marked this conversation as resolved.
Show resolved Hide resolved
obj = obj.func

return asyncio.iscoroutinefunction(obj) or (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @Kludex, would it make sense to use inspect.iscoroutinefunction here?

I understand why we cannot just use the said function alone, since async def __call__(...) is not captured by it:

>>> import inspect
>>> class A:
...     async def __call__(self):
...             ...
... 
>>> inspect.iscoroutinefunction(A())
False

I have looked at the code of asyncio.iscoroutinefunction and it only additionally check for the deprecated @coroutine decorator.

Cheers,
Libor

This comment was marked as off-topic.

callable(obj) and asyncio.iscoroutinefunction(obj.__call__)
)
4 changes: 2 additions & 2 deletions starlette/authentication.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import asyncio
import functools
import inspect
import typing
from urllib.parse import urlencode

from starlette._utils import is_async_callable
from starlette.exceptions import HTTPException
from starlette.requests import HTTPConnection, Request
from starlette.responses import RedirectResponse, Response
Expand Down Expand Up @@ -53,7 +53,7 @@ async def websocket_wrapper(

return websocket_wrapper

elif asyncio.iscoroutinefunction(func):
elif is_async_callable(func):
# Handle async request/response functions.
@functools.wraps(func)
async def async_wrapper(
Expand Down
4 changes: 2 additions & 2 deletions starlette/background.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import sys
import typing

Expand All @@ -7,6 +6,7 @@
else: # pragma: no cover
from typing_extensions import ParamSpec

from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool

P = ParamSpec("P")
Expand All @@ -19,7 +19,7 @@ def __init__(
self.func = func
self.args = args
self.kwargs = kwargs
self.is_async = asyncio.iscoroutinefunction(func)
self.is_async = is_async_callable(func)

async def __call__(self) -> None:
if self.is_async:
Expand Down
4 changes: 2 additions & 2 deletions starlette/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import asyncio
import json
import typing

from starlette import status
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
Expand Down Expand Up @@ -37,7 +37,7 @@ async def dispatch(self) -> None:
handler: typing.Callable[[Request], typing.Any] = getattr(
self, handler_name, self.method_not_allowed
)
is_async = asyncio.iscoroutinefunction(handler)
is_async = is_async_callable(handler)
if is_async:
response = await handler(request)
else:
Expand Down
4 changes: 2 additions & 2 deletions starlette/middleware/errors.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import asyncio
import html
import inspect
import traceback
import typing

from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from starlette.responses import HTMLResponse, PlainTextResponse, Response
Expand Down Expand Up @@ -170,7 +170,7 @@ async def _send(message: Message) -> None:
response = self.error_response(request, exc)
else:
# Use an installed 500 error handler.
if asyncio.iscoroutinefunction(self.handler):
if is_async_callable(self.handler):
response = await self.handler(request, exc)
else:
response = await run_in_threadpool(self.handler, request, exc)
Expand Down
4 changes: 2 additions & 2 deletions starlette/middleware/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import typing

from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
Expand Down Expand Up @@ -79,7 +79,7 @@ async def sender(message: Message) -> None:
raise RuntimeError(msg) from exc

request = Request(scope, receive=receive)
if asyncio.iscoroutinefunction(handler):
if is_async_callable(handler):
response = await handler(request, exc)
else:
response = await run_in_threadpool(handler, request, exc)
Expand Down
15 changes: 10 additions & 5 deletions starlette/routing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import contextlib
import functools
import inspect
Expand All @@ -10,6 +9,7 @@
from contextlib import asynccontextmanager
from enum import Enum

from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.convertors import CONVERTOR_TYPES, Convertor
from starlette.datastructures import URL, Headers, URLPath
Expand Down Expand Up @@ -37,11 +37,16 @@ class Match(Enum):
FULL = 2


def iscoroutinefunction_or_partial(obj: typing.Any) -> bool:
def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover
"""
Correctly determines if an object is a coroutine function,
including those wrapped in functools.partial objects.
"""
warnings.warn(
"iscoroutinefunction_or_partial is deprecated, "
"and will be removed in a future release.",
DeprecationWarning,
)
while isinstance(obj, functools.partial):
obj = obj.func
return inspect.iscoroutinefunction(obj)
Expand All @@ -52,7 +57,7 @@ def request_response(func: typing.Callable) -> ASGIApp:
Takes a function or coroutine `func(request) -> response`,
and returns an ASGI application.
"""
is_coroutine = iscoroutinefunction_or_partial(func)
is_coroutine = is_async_callable(func)

async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive=receive, send=send)
Expand Down Expand Up @@ -603,7 +608,7 @@ async def startup(self) -> None:
Run any `.on_startup` event handlers.
"""
for handler in self.on_startup:
if asyncio.iscoroutinefunction(handler):
if is_async_callable(handler):
await handler()
else:
handler()
Expand All @@ -613,7 +618,7 @@ async def shutdown(self) -> None:
Run any `.on_shutdown` event handlers.
"""
for handler in self.on_shutdown:
if asyncio.iscoroutinefunction(handler):
if is_async_callable(handler):
await handler()
else:
handler()
Expand Down
7 changes: 2 additions & 5 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import contextlib
import http
import inspect
Expand All @@ -16,6 +15,7 @@
import requests
from anyio.streams.stapled import StapledObjectStream

from starlette._utils import is_async_callable
from starlette.types import Message, Receive, Scope, Send
from starlette.websockets import WebSocketDisconnect

Expand Down Expand Up @@ -84,10 +84,7 @@ def _get_reason_phrase(status_code: int) -> str:
def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool:
if inspect.isclass(app):
return hasattr(app, "__await__")
elif inspect.isfunction(app):
return asyncio.iscoroutinefunction(app)
call = getattr(app, "__call__", None)
return asyncio.iscoroutinefunction(call)
return is_async_callable(app)


class _WrapASGI2:
Expand Down
79 changes: 79 additions & 0 deletions tests/test__utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import functools

from starlette._utils import is_async_callable


def test_async_func():
async def async_func():
... # pragma: no cover

def func():
... # pragma: no cover

assert is_async_callable(async_func)
assert not is_async_callable(func)


def test_async_partial():
Kludex marked this conversation as resolved.
Show resolved Hide resolved
async def async_func(a, b):
... # pragma: no cover

def func(a, b):
... # pragma: no cover

partial = functools.partial(async_func, 1)
assert is_async_callable(partial)

partial = functools.partial(func, 1)
assert not is_async_callable(partial)


def test_async_method():
class Async:
async def method(self):
... # pragma: no cover

class Sync:
def method(self):
... # pragma: no cover

assert is_async_callable(Async().method)
assert not is_async_callable(Sync().method)


def test_async_object_call():
class Async:
async def __call__(self):
... # pragma: no cover

class Sync:
def __call__(self):
... # pragma: no cover

assert is_async_callable(Async())
assert not is_async_callable(Sync())


def test_async_partial_object_call():
class Async:
async def __call__(self, a, b):
... # pragma: no cover

class Sync:
def __call__(self, a, b):
... # pragma: no cover

partial = functools.partial(Async(), 1)
assert is_async_callable(partial)

partial = functools.partial(Sync(), 1)
assert not is_async_callable(partial)


def test_async_nested_partial():
async def async_func(a, b):
... # pragma: no cover

partial = functools.partial(async_func, b=2)
nested_partial = functools.partial(partial, a=1)
assert is_async_callable(nested_partial)