Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ASGITransport.request() #1021

Merged
merged 2 commits into from
Jun 13, 2020
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions httpx/_transports/asgi.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import typing
from typing import Callable, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union

import httpcore
import sniffio

from .._content_streams import ByteStream
from .._utils import warn_deprecated

if typing.TYPE_CHECKING: # pragma: no cover
if TYPE_CHECKING: # pragma: no cover
import asyncio
import trio

Event = typing.Union[asyncio.Event, trio.Event]
Event = Union[asyncio.Event, trio.Event]


def create_event() -> "Event":
Expand Down Expand Up @@ -78,6 +77,10 @@ async def request(
stream: httpcore.AsyncByteStream = None,
timeout: Dict[str, Optional[float]] = None,
) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], httpcore.AsyncByteStream]:
headers = [] if headers is None else headers
stream = ByteStream(b"") if stream is None else stream

# ASGI scope.
scheme, host, port, full_path = url
path, _, query = full_path.partition(b"?")
scope = {
Expand All @@ -93,20 +96,22 @@ async def request(
"client": self.client,
"root_path": self.root_path,
}

# Request.
request_body_chunks = stream.__aiter__()
request_complete = False

# Response.
status_code = None
response_headers = None
body_parts = []
request_complete = False
response_started = False
response_complete = create_event()

headers = [] if headers is None else headers
stream = ByteStream(b"") if stream is None else stream

request_body_chunks = stream.__aiter__()
# ASGI callables.

async def receive() -> dict:
nonlocal request_complete, response_complete
nonlocal request_complete

if request_complete:
await response_complete.wait()
Expand All @@ -120,8 +125,7 @@ async def receive() -> dict:
return {"type": "http.request", "body": body, "more_body": True}

async def send(message: dict) -> None:
nonlocal status_code, response_headers, body_parts
nonlocal response_started, response_complete
nonlocal status_code, response_headers, response_started

if message["type"] == "http.response.start":
assert not response_started
Expand All @@ -144,7 +148,7 @@ async def send(message: dict) -> None:
try:
await self.app(scope, receive, send)
except Exception:
if self.raise_app_exceptions or not response_complete:
if self.raise_app_exceptions or not response_complete.is_set():
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this was a typo before.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Think that was my mistake. Missed this when changing response_complete from a bool to an event. Also didn't realise the nonlocal could be removed 😬

raise

assert response_complete.is_set()
Expand Down