From 048dd83dcfc7b7559d561fda378d4fa9392bfbd3 Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Sat, 13 Jun 2020 15:35:17 +0200 Subject: [PATCH] Refactor ASGITransport.request() --- httpx/_transports/asgi.py | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/httpx/_transports/asgi.py b/httpx/_transports/asgi.py index af03e24fee..2e22820992 100644 --- a/httpx/_transports/asgi.py +++ b/httpx/_transports/asgi.py @@ -1,5 +1,4 @@ -import typing -from typing import Callable, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union import httpcore import sniffio @@ -7,11 +6,11 @@ from .._content_streams import ByteStream from .._utils import warn_deprecated -if typing.TYPE_CHECKING: # pragma: no cover +if TYPE_CHECKING: # pragma: no cover import asyncio import trio - Event = typing.Union[asyncio.Event, trio.Event] + Event = Union[asyncio.Event, trio.Event] def create_event() -> "Event": @@ -78,6 +77,10 @@ async def request( stream: httpcore.AsyncByteStream = None, timeout: Dict[str, Optional[float]] = None, ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.AsyncByteStream]: + headers = [] if headers is None else headers + stream = ByteStream(b"") if stream is None else stream + + # ASGI scope. scheme, host, port, full_path = url path, _, query = full_path.partition(b"?") scope = { @@ -93,20 +96,22 @@ async def request( "client": self.client, "root_path": self.root_path, } + + # Request. + request_body_chunks = stream.__aiter__() + request_complete = False + + # Response. status_code = None response_headers = None body_parts = [] - request_complete = False response_started = False response_complete = create_event() - headers = [] if headers is None else headers - stream = ByteStream(b"") if stream is None else stream - - request_body_chunks = stream.__aiter__() + # ASGI callables. async def receive() -> dict: - nonlocal request_complete, response_complete + nonlocal request_complete if request_complete: await response_complete.wait() @@ -120,8 +125,7 @@ async def receive() -> dict: return {"type": "http.request", "body": body, "more_body": True} async def send(message: dict) -> None: - nonlocal status_code, response_headers, body_parts - nonlocal response_started, response_complete + nonlocal status_code, response_headers, response_started if message["type"] == "http.response.start": assert not response_started @@ -144,7 +148,7 @@ async def send(message: dict) -> None: try: await self.app(scope, receive, send) except Exception: - if self.raise_app_exceptions or not response_complete: + if self.raise_app_exceptions or not response_complete.is_set(): raise assert response_complete.is_set()