Skip to content

Commit

Permalink
Rename iscoroutinefunction to is_async_callable
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed May 27, 2022
1 parent dd8b7ed commit 7c74668
Show file tree
Hide file tree
Showing 9 changed files with 29 additions and 29 deletions.
2 changes: 1 addition & 1 deletion starlette/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import typing


def iscoroutinefunction(obj: typing.Any) -> bool:
def is_async_callable(obj: typing.Any) -> bool:
while isinstance(obj, functools.partial):
obj = obj.func

Expand Down
4 changes: 2 additions & 2 deletions starlette/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import typing
from urllib.parse import urlencode

from starlette._utils import iscoroutinefunction
from starlette._utils import is_async_callable
from starlette.exceptions import HTTPException
from starlette.requests import HTTPConnection, Request
from starlette.responses import RedirectResponse, Response
Expand Down Expand Up @@ -53,7 +53,7 @@ async def websocket_wrapper(

return websocket_wrapper

elif iscoroutinefunction(func):
elif is_async_callable(func):
# Handle async request/response functions.
@functools.wraps(func)
async def async_wrapper(
Expand Down
4 changes: 2 additions & 2 deletions starlette/background.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
else: # pragma: no cover
from typing_extensions import ParamSpec

from starlette._utils import iscoroutinefunction
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool

P = ParamSpec("P")
Expand All @@ -19,7 +19,7 @@ def __init__(
self.func = func
self.args = args
self.kwargs = kwargs
self.is_async = iscoroutinefunction(func)
self.is_async = is_async_callable(func)

async def __call__(self) -> None:
if self.is_async:
Expand Down
4 changes: 2 additions & 2 deletions starlette/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import typing

from starlette import status
from starlette._utils import iscoroutinefunction
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
Expand Down Expand Up @@ -37,7 +37,7 @@ async def dispatch(self) -> None:
handler: typing.Callable[[Request], typing.Any] = getattr(
self, handler_name, self.method_not_allowed
)
is_async = iscoroutinefunction(handler)
is_async = is_async_callable(handler)
if is_async:
response = await handler(request)
else:
Expand Down
4 changes: 2 additions & 2 deletions starlette/middleware/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import traceback
import typing

from starlette._utils import iscoroutinefunction
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from starlette.responses import HTMLResponse, PlainTextResponse, Response
Expand Down Expand Up @@ -170,7 +170,7 @@ async def _send(message: Message) -> None:
response = self.error_response(request, exc)
else:
# Use an installed 500 error handler.
if iscoroutinefunction(self.handler):
if is_async_callable(self.handler):
response = await self.handler(request, exc)
else:
response = await run_in_threadpool(self.handler, request, exc)
Expand Down
4 changes: 2 additions & 2 deletions starlette/middleware/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import typing

from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException
from starlette.requests import Request
Expand Down Expand Up @@ -79,7 +79,7 @@ async def sender(message: Message) -> None:
raise RuntimeError(msg) from exc

request = Request(scope, receive=receive)
if asyncio.iscoroutinefunction(handler):
if is_async_callable(handler):
response = await handler(request, exc)
else:
response = await run_in_threadpool(handler, request, exc)
Expand Down
8 changes: 4 additions & 4 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from contextlib import asynccontextmanager
from enum import Enum

from starlette._utils import iscoroutinefunction
from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.convertors import CONVERTOR_TYPES, Convertor
from starlette.datastructures import URL, Headers, URLPath
Expand Down Expand Up @@ -57,7 +57,7 @@ def request_response(func: typing.Callable) -> ASGIApp:
Takes a function or coroutine `func(request) -> response`,
and returns an ASGI application.
"""
is_coroutine = iscoroutinefunction(func)
is_coroutine = is_async_callable(func)

async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive=receive, send=send)
Expand Down Expand Up @@ -608,7 +608,7 @@ async def startup(self) -> None:
Run any `.on_startup` event handlers.
"""
for handler in self.on_startup:
if iscoroutinefunction(handler):
if is_async_callable(handler):
await handler()
else:
handler()
Expand All @@ -618,7 +618,7 @@ async def shutdown(self) -> None:
Run any `.on_shutdown` event handlers.
"""
for handler in self.on_shutdown:
if iscoroutinefunction(handler):
if is_async_callable(handler):
await handler()
else:
handler()
Expand Down
4 changes: 2 additions & 2 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import requests
from anyio.streams.stapled import StapledObjectStream

from starlette._utils import iscoroutinefunction
from starlette._utils import is_async_callable
from starlette.types import Message, Receive, Scope, Send
from starlette.websockets import WebSocketDisconnect

Expand Down Expand Up @@ -84,7 +84,7 @@ def _get_reason_phrase(status_code: int) -> str:
def _is_asgi3(app: typing.Union[ASGI2App, ASGI3App]) -> bool:
if inspect.isclass(app):
return hasattr(app, "__await__")
return iscoroutinefunction(app)
return is_async_callable(app)


class _WrapASGI2:
Expand Down
24 changes: 12 additions & 12 deletions tests/test__utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import functools

from starlette._utils import iscoroutinefunction
from starlette._utils import is_async_callable


def test_async_func():
Expand All @@ -10,8 +10,8 @@ async def async_func():
def func():
... # pragma: no cover

assert iscoroutinefunction(async_func)
assert not iscoroutinefunction(func)
assert is_async_callable(async_func)
assert not is_async_callable(func)


def test_async_partial():
Expand All @@ -22,10 +22,10 @@ def func(a, b):
... # pragma: no cover

partial = functools.partial(async_func, 1)
assert iscoroutinefunction(partial)
assert is_async_callable(partial)

partial = functools.partial(func, 1)
assert not iscoroutinefunction(partial)
assert not is_async_callable(partial)


def test_async_method():
Expand All @@ -37,8 +37,8 @@ class Sync:
def method(self):
... # pragma: no cover

assert iscoroutinefunction(Async().method)
assert not iscoroutinefunction(Sync().method)
assert is_async_callable(Async().method)
assert not is_async_callable(Sync().method)


def test_async_object_call():
Expand All @@ -50,8 +50,8 @@ class Sync:
def __call__(self):
... # pragma: no cover

assert iscoroutinefunction(Async())
assert not iscoroutinefunction(Sync())
assert is_async_callable(Async())
assert not is_async_callable(Sync())


def test_async_partial_object_call():
Expand All @@ -64,10 +64,10 @@ def __call__(self, a, b):
... # pragma: no cover

partial = functools.partial(Async(), 1)
assert iscoroutinefunction(partial)
assert is_async_callable(partial)

partial = functools.partial(Sync(), 1)
assert not iscoroutinefunction(partial)
assert not is_async_callable(partial)


def test_async_nested_partial():
Expand All @@ -76,4 +76,4 @@ async def async_func(a, b):

partial = functools.partial(async_func, b=2)
nested_partial = functools.partial(partial, a=1)
assert iscoroutinefunction(nested_partial)
assert is_async_callable(nested_partial)

0 comments on commit 7c74668

Please sign in to comment.