Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve detection of async callables #1444

Merged
merged 10 commits into from
May 28, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions starlette/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import asyncio
import functools
import typing


def iscoroutinefunction(obj: typing.Any) -> bool:
Copy link
Member

@florimondmanca florimondmanca May 24, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coming from reviewing #1644

This is a nit, but should we call this is_async_callable?

AFAICT, having "something we can call and await" is all we care about when checking for apps or endpoints.

"Coroutine function" is a well-defined word in the Python glossary: a "function which returns a coroutine object", defined with "async def".

https://docs.python.org/3/glossary.html#term-coroutine-function

Also, the current naming makes it look as a compatibility shim on asyncio.iscoroutinefunction which focuses on checking just the definition above^, whereas we do some more checks, such as looking for __call__.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a nit, but should we call this is_async_callable?

Sure. 👍 I'll adapt later today. Thanks for the review. 🙏

while isinstance(obj, functools.partial):
Kludex marked this conversation as resolved.
Show resolved Hide resolved
obj = obj.func

return asyncio.iscoroutinefunction(obj) or (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @Kludex, would it make sense to use inspect.iscoroutinefunction here?

I understand why we cannot just use the said function alone, since async def __call__(...) is not captured by it:

>>> import inspect
>>> class A:
...     async def __call__(self):
...             ...
... 
>>> inspect.iscoroutinefunction(A())
False

I have looked at the code of asyncio.iscoroutinefunction and it only additionally check for the deprecated @coroutine decorator.

Cheers,
Libor

This comment was marked as off-topic.

callable(obj) and asyncio.iscoroutinefunction(obj.__call__)
)
4 changes: 2 additions & 2 deletions starlette/authentication.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import asyncio
import functools
import inspect
import typing

from starlette._utils import iscoroutinefunction
from starlette.exceptions import HTTPException
from starlette.requests import HTTPConnection, Request
from starlette.responses import RedirectResponse, Response
Expand Down Expand Up @@ -52,7 +52,7 @@ async def websocket_wrapper(

return websocket_wrapper

elif asyncio.iscoroutinefunction(func):
elif iscoroutinefunction(func):
# Handle async request/response functions.
@functools.wraps(func)
async def async_wrapper(
Expand Down
4 changes: 2 additions & 2 deletions starlette/background.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import sys
import typing

Expand All @@ -7,6 +6,7 @@
else: # pragma: no cover
from typing_extensions import ParamSpec

from starlette._utils import iscoroutinefunction
from starlette.concurrency import run_in_threadpool

P = ParamSpec("P")
Expand All @@ -19,7 +19,7 @@ def __init__(
self.func = func
self.args = args
self.kwargs = kwargs
self.is_async = asyncio.iscoroutinefunction(func)
self.is_async = iscoroutinefunction(func)

async def __call__(self) -> None:
if self.is_async:
Expand Down
4 changes: 2 additions & 2 deletions starlette/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import asyncio
import json
import typing

from starlette import status
from starlette._utils import iscoroutinefunction
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
Expand Down Expand Up @@ -37,7 +37,7 @@ async def dispatch(self) -> None:
handler: typing.Callable[[Request], typing.Any] = getattr(
self, handler_name, self.method_not_allowed
)
is_async = asyncio.iscoroutinefunction(handler)
is_async = iscoroutinefunction(handler)
if is_async:
response = await handler(request)
else:
Expand Down
4 changes: 2 additions & 2 deletions starlette/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import http
import typing

from starlette._utils import iscoroutinefunction
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
Expand Down Expand Up @@ -94,7 +94,7 @@ async def sender(message: Message) -> None:
raise RuntimeError(msg) from exc

request = Request(scope, receive=receive)
if asyncio.iscoroutinefunction(handler):
if iscoroutinefunction(handler):
response = await handler(request, exc)
else:
response = await run_in_threadpool(handler, request, exc)
Expand Down
4 changes: 2 additions & 2 deletions starlette/middleware/errors.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import asyncio
import html
import inspect
import traceback
import typing

from starlette._utils import iscoroutinefunction
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from starlette.responses import HTMLResponse, PlainTextResponse, Response
Expand Down Expand Up @@ -168,7 +168,7 @@ async def _send(message: Message) -> None:
response = self.error_response(request, exc)
else:
# Use an installed 500 error handler.
if asyncio.iscoroutinefunction(self.handler):
if iscoroutinefunction(self.handler):
response = await self.handler(request, exc)
else:
response = await run_in_threadpool(self.handler, request, exc)
Expand Down
6 changes: 3 additions & 3 deletions starlette/routing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import contextlib
import functools
import inspect
Expand All @@ -10,6 +9,7 @@
import warnings
from enum import Enum

from starlette._utils import iscoroutinefunction
from starlette.concurrency import run_in_threadpool
from starlette.convertors import CONVERTOR_TYPES, Convertor
from starlette.datastructures import URL, Headers, URLPath
Expand Down Expand Up @@ -600,7 +600,7 @@ async def startup(self) -> None:
Run any `.on_startup` event handlers.
"""
for handler in self.on_startup:
if asyncio.iscoroutinefunction(handler):
if iscoroutinefunction(handler):
await handler()
else:
handler()
Expand All @@ -610,7 +610,7 @@ async def shutdown(self) -> None:
Run any `.on_shutdown` event handlers.
"""
for handler in self.on_shutdown:
if asyncio.iscoroutinefunction(handler):
if iscoroutinefunction(handler):
await handler()
else:
handler()
Expand Down
7 changes: 2 additions & 5 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import contextlib
import http
import inspect
Expand All @@ -16,6 +15,7 @@
import requests
from anyio.streams.stapled import StapledObjectStream

from starlette._utils import iscoroutinefunction
from starlette.types import Message, Receive, Scope, Send
from starlette.websockets import WebSocketDisconnect

Expand Down Expand Up @@ -84,10 +84,7 @@ def _get_reason_phrase(status_code: int) -> str:
def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool:
if inspect.isclass(app):
return hasattr(app, "__await__")
elif inspect.isfunction(app):
return asyncio.iscoroutinefunction(app)
call = getattr(app, "__call__", None)
return asyncio.iscoroutinefunction(call)
return iscoroutinefunction(app)


class _WrapASGI2:
Expand Down
79 changes: 79 additions & 0 deletions tests/test__utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import functools

from starlette._utils import iscoroutinefunction


def test_async_func():
async def async_func():
... # pragma: no cover

def func():
... # pragma: no cover

assert iscoroutinefunction(async_func)
assert not iscoroutinefunction(func)


def test_async_partial():
Kludex marked this conversation as resolved.
Show resolved Hide resolved
async def async_func(a, b):
... # pragma: no cover

def func(a, b):
... # pragma: no cover

partial = functools.partial(async_func, 1)
assert iscoroutinefunction(partial)

partial = functools.partial(func, 1)
assert not iscoroutinefunction(partial)


def test_async_method():
class Async:
async def method(self):
... # pragma: no cover

class Sync:
def method(self):
... # pragma: no cover

assert iscoroutinefunction(Async().method)
assert not iscoroutinefunction(Sync().method)


def test_async_object_call():
class Async:
async def __call__(self):
... # pragma: no cover

class Sync:
def __call__(self):
... # pragma: no cover

assert iscoroutinefunction(Async())
assert not iscoroutinefunction(Sync())


def test_async_partial_object_call():
class Async:
async def __call__(self, a, b):
... # pragma: no cover

class Sync:
def __call__(self, a, b):
... # pragma: no cover

partial = functools.partial(Async(), 1)
assert iscoroutinefunction(partial)

partial = functools.partial(Sync(), 1)
assert not iscoroutinefunction(partial)


def test_async_nested_partial():
async def async_func(a, b):
... # pragma: no cover

partial = functools.partial(async_func, b=2)
nested_partial = functools.partial(partial, a=1)
assert iscoroutinefunction(nested_partial)