Skip to content

Commit

Permalink
[3.8] Refactor web error handling (#5270). (#5295)
Browse files Browse the repository at this point in the history
(cherry picked from commit e9fdf0a)

Co-authored-by: Andrew Svetlov <[email protected]>
  • Loading branch information
asvetlov authored Nov 27, 2020
1 parent 135cb16 commit 88f8f3b
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 852 deletions.
2 changes: 1 addition & 1 deletion aiohttp/client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def data_received(self, data: bytes) -> None:
self._payload = payload

if self._skip_payload or message.code in (204, 304):
self.feed_data((message, EMPTY_PAYLOAD), 0) # type: ignore
self.feed_data((message, EMPTY_PAYLOAD), 0)
else:
self.feed_data((message, payload), 0)
if payload is not None:
Expand Down
2 changes: 1 addition & 1 deletion aiohttp/http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def feed_data(
if not payload_parser.done:
self._payload_parser = payload_parser
else:
payload = EMPTY_PAYLOAD # type: ignore
payload = EMPTY_PAYLOAD

messages.append((msg, payload))
else:
Expand Down
11 changes: 8 additions & 3 deletions aiohttp/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import warnings
from typing import Awaitable, Callable, Generic, List, Optional, Tuple, TypeVar

from typing_extensions import Final

from .base_protocol import BaseProtocol
from .helpers import BaseTimerContext, set_exception, set_result
from .log import internal_logger
Expand Down Expand Up @@ -504,7 +506,10 @@ def _read_nowait(self, n: int) -> bytes:
return b"".join(chunks) if chunks else b""


class EmptyStreamReader(AsyncStreamReaderMixin):
class EmptyStreamReader(StreamReader): # lgtm [py/missing-call-to-init]
def __init__(self) -> None:
    # Deliberately do NOT call StreamReader.__init__: this reader is
    # permanently empty, so none of the parent's buffering/protocol state
    # is needed (hence the lgtm missing-call-to-init suppression on the
    # class line).
    pass

def exception(self) -> Optional[BaseException]:
    """Return the stored exception; an empty stream never has one."""
    return None

Expand Down Expand Up @@ -549,11 +554,11 @@ async def readchunk(self) -> Tuple[bytes, bool]:
async def readexactly(self, n: int) -> bytes:
    # The stream holds no data, so the requested bytes can never arrive;
    # report an incomplete read with an empty partial result.
    raise asyncio.IncompleteReadError(b"", n)

def read_nowait(self) -> bytes:
def read_nowait(self, n: int = -1) -> bytes:
    # *n* is accepted (and ignored) so the signature matches
    # StreamReader.read_nowait; the result is always empty.
    return b""


EMPTY_PAYLOAD = EmptyStreamReader()
EMPTY_PAYLOAD: Final[StreamReader] = EmptyStreamReader()


class DataQueue(Generic[_T]):
Expand Down
165 changes: 76 additions & 89 deletions aiohttp/web_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,19 @@
Callable,
Deque,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)

import attr
import yarl

from .abc import AbstractAccessLogger, AbstractStreamWriter
from .base_protocol import BaseProtocol
from .helpers import CeilTimeout, current_task
from .helpers import CeilTimeout
from .http import (
HttpProcessingError,
HttpRequestParser,
Expand Down Expand Up @@ -58,7 +61,6 @@

_RequestHandler = Callable[[BaseRequest], Awaitable[StreamResponse]]


ERROR = RawRequestMessage(
"UNKNOWN", "/", HttpVersion10, {}, {}, True, False, False, False, yarl.URL("/")
)
Expand All @@ -72,6 +74,16 @@ class PayloadAccessError(Exception):
"""Payload was accessed after response was sent."""


@attr.s(auto_attribs=True, frozen=True, slots=True)
class _ErrInfo:
    """Details of an HTTP parse error captured in data_received.

    Queued in place of a RawRequestMessage and later converted into an
    error response handler by RequestHandler._make_error_handler.
    """

    # HTTP status to respond with (e.g. 400 for HttpProcessingError).
    status: int
    # The original exception raised by the parser.
    exc: BaseException
    # Human-readable message passed on to handle_error.
    message: str


_MsgType = Tuple[Union[RawRequestMessage, _ErrInfo], StreamReader]


class RequestHandler(BaseProtocol):
"""HTTP protocol implementation.
Expand All @@ -83,32 +95,28 @@ class RequestHandler(BaseProtocol):
status line, bad headers or incomplete payload. If any error occurs,
connection gets closed.
:param keepalive_timeout: number of seconds before closing
keep-alive connection
:type keepalive_timeout: int or None
keepalive_timeout -- number of seconds before closing
keep-alive connection
:param bool tcp_keepalive: TCP keep-alive is on, default is on
tcp_keepalive -- TCP keep-alive is on, default is on
:param bool debug: enable debug mode
debug -- enable debug mode
:param logger: custom logger object
:type logger: aiohttp.log.server_logger
logger -- custom logger object
:param access_log_class: custom class for access_logger
:type access_log_class: aiohttp.abc.AbstractAccessLogger
access_log_class -- custom class for access_logger
:param access_log: custom logging object
:type access_log: aiohttp.log.server_logger
access_log -- custom logging object
:param str access_log_format: access log format string
access_log_format -- access log format string
:param loop: Optional event loop
loop -- Optional event loop
:param int max_line_size: Optional maximum header line size
max_line_size -- Optional maximum header line size
:param int max_field_size: Optional maximum header field size
max_field_size -- Optional maximum header field size
:param int max_headers: Optional maximum header size
max_headers -- Optional maximum header size
"""

Expand All @@ -128,7 +136,6 @@ class RequestHandler(BaseProtocol):
"_messages",
"_message_tail",
"_waiter",
"_error_handler",
"_task_handler",
"_upgrade",
"_payload_parser",
Expand Down Expand Up @@ -161,19 +168,14 @@ def __init__(
lingering_time: float = 10.0,
read_bufsize: int = 2 ** 16,
):

super().__init__(loop)

self._request_count = 0
self._keepalive = False
self._current_request = None # type: Optional[BaseRequest]
self._manager = manager # type: Optional[Server]
self._request_handler = (
manager.request_handler
) # type: Optional[_RequestHandler]
self._request_factory = (
manager.request_factory
) # type: Optional[_RequestFactory]
self._request_handler: Optional[_RequestHandler] = manager.request_handler
self._request_factory: Optional[_RequestFactory] = manager.request_factory

self._tcp_keepalive = tcp_keepalive
# placeholder to be replaced on keepalive timeout setup
Expand All @@ -182,11 +184,10 @@ def __init__(
self._keepalive_timeout = keepalive_timeout
self._lingering_time = float(lingering_time)

self._messages: Deque[Tuple[RawRequestMessage, StreamReader]] = deque()
self._messages: Deque[_MsgType] = deque()
self._message_tail = b""

self._waiter = None # type: Optional[asyncio.Future[None]]
self._error_handler = None # type: Optional[asyncio.Task[None]]
self._task_handler = None # type: Optional[asyncio.Task[None]]

self._upgrade = False
Expand Down Expand Up @@ -239,9 +240,6 @@ async def shutdown(self, timeout: Optional[float] = 15.0) -> None:
# wait for handlers
with suppress(asyncio.CancelledError, asyncio.TimeoutError):
with CeilTimeout(timeout, loop=self._loop):
if self._error_handler is not None and not self._error_handler.done():
await self._error_handler

if self._current_request is not None:
self._current_request._cancel(asyncio.CancelledError())

Expand Down Expand Up @@ -288,8 +286,6 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:
exc = ConnectionResetError("Connection lost")
self._current_request._cancel(exc)

if self._error_handler is not None:
self._error_handler.cancel()
if self._task_handler is not None:
self._task_handler.cancel()
if self._waiter is not None:
Expand Down Expand Up @@ -318,40 +314,30 @@ def data_received(self, data: bytes) -> None:
if self._force_close or self._close:
return
# parse http messages
messages: Sequence[_MsgType]
if self._payload_parser is None and not self._upgrade:
assert self._request_parser is not None
try:
messages, upgraded, tail = self._request_parser.feed_data(data)
except HttpProcessingError as exc:
# something happened during parsing
self._error_handler = self._loop.create_task(
self.handle_parse_error(
StreamWriter(self, self._loop), 400, exc, exc.message
)
)
self.close()
except Exception as exc:
# 500: internal error
self._error_handler = self._loop.create_task(
self.handle_parse_error(StreamWriter(self, self._loop), 500, exc)
)
self.close()
else:
if messages:
# sometimes the parser returns no messages
for (msg, payload) in messages:
self._request_count += 1
self._messages.append((msg, payload))

waiter = self._waiter
if waiter is not None:
if not waiter.done():
# don't set result twice
waiter.set_result(None)

self._upgrade = upgraded
if upgraded and tail:
self._message_tail = tail
messages = [
(_ErrInfo(status=400, exc=exc, message=exc.message), EMPTY_PAYLOAD)
]
upgraded = False
tail = b""

for msg, payload in messages or ():
self._request_count += 1
self._messages.append((msg, payload))

waiter = self._waiter
if messages and waiter is not None and not waiter.done():
# don't set result twice
waiter.set_result(None)

self._upgrade = upgraded
if upgraded and tail:
self._message_tail = tail

# no parser, just store
elif self._payload_parser is None and self._upgrade and data:
Expand Down Expand Up @@ -424,12 +410,13 @@ async def _handle_request(
self,
request: BaseRequest,
start_time: float,
request_handler: Callable[[BaseRequest], Awaitable[StreamResponse]],
) -> Tuple[StreamResponse, bool]:
assert self._request_handler is not None
try:
try:
self._current_request = request
resp = await self._request_handler(request)
resp = await request_handler(request)
finally:
self._current_request = None
except HTTPException as exc:
Expand Down Expand Up @@ -487,10 +474,19 @@ async def start(self) -> None:

manager.requests_count += 1
writer = StreamWriter(self, loop)
if isinstance(message, _ErrInfo):
# make request_factory work
request_handler = self._make_error_handler(message)
message = ERROR
else:
request_handler = self._request_handler

request = self._request_factory(message, payload, self, writer, handler)
try:
# a new task is used for copy context vars (#3406)
task = self._loop.create_task(self._handle_request(request, start))
task = self._loop.create_task(
self._handle_request(request, start, request_handler)
)
try:
resp, reset = await task
except (asyncio.CancelledError, ConnectionError):
Expand Down Expand Up @@ -568,7 +564,7 @@ async def start(self) -> None:
# remove handler, close transport if no handlers left
if not self._force_close:
self._task_handler = None
if self.transport is not None and self._error_handler is None:
if self.transport is not None:
self.transport.close()

async def finish_response(
Expand Down Expand Up @@ -620,6 +616,13 @@ def handle_error(
information. It always closes current connection."""
self.log_exception("Error handling request", exc_info=exc)

# some data already got sent, connection is broken
if request.writer.output_size > 0:
raise ConnectionError(
"Response is sent already, cannot send another response "
"with the error message"
)

ct = "text/plain"
if status == HTTPStatus.INTERNAL_SERVER_ERROR:
title = "{0.value} {0.phrase}".format(HTTPStatus.INTERNAL_SERVER_ERROR)
Expand Down Expand Up @@ -648,30 +651,14 @@ def handle_error(
resp = Response(status=status, text=message, content_type=ct)
resp.force_close()

# some data already got sent, connection is broken
if request.writer.output_size > 0 or self.transport is None:
self.force_close()

return resp

async def handle_parse_error(
self,
writer: AbstractStreamWriter,
status: int,
exc: Optional[BaseException] = None,
message: Optional[str] = None,
) -> None:
task = current_task()
assert task is not None
request = BaseRequest(
ERROR, EMPTY_PAYLOAD, self, writer, task, self._loop # type: ignore
)

resp = self.handle_error(request, status, exc, message)
await resp.prepare(request)
await resp.write_eof()

if self.transport is not None:
self.transport.close()
def _make_error_handler(
self, err_info: _ErrInfo
) -> Callable[[BaseRequest], Awaitable[StreamResponse]]:
async def handler(request: BaseRequest) -> StreamResponse:
return self.handle_error(
request, err_info.status, err_info.exc, err_info.message
)

self._error_handler = None
return handler
4 changes: 4 additions & 0 deletions tests/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -1490,3 +1490,7 @@ async def test_stream_reader_iter_chunks_chunked_encoding(protocol) -> None:
async for data, end_of_chunk in stream.iter_chunks():
assert (data, end_of_chunk) == (next(it), True)
pytest.raises(StopIteration, next, it)


def test_isinstance_check() -> None:
    # EMPTY_PAYLOAD must be a genuine StreamReader instance
    # (EmptyStreamReader subclasses StreamReader) so isinstance checks
    # on payloads keep working.
    assert isinstance(streams.EMPTY_PAYLOAD, streams.StreamReader)
Loading

0 comments on commit 88f8f3b

Please sign in to comment.