Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[3.8] Refactor web error handling (#5270). #5295

Merged
merged 1 commit into from
Nov 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aiohttp/client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def data_received(self, data: bytes) -> None:
self._payload = payload

if self._skip_payload or message.code in (204, 304):
self.feed_data((message, EMPTY_PAYLOAD), 0) # type: ignore
self.feed_data((message, EMPTY_PAYLOAD), 0)
else:
self.feed_data((message, payload), 0)
if payload is not None:
Expand Down
2 changes: 1 addition & 1 deletion aiohttp/http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def feed_data(
if not payload_parser.done:
self._payload_parser = payload_parser
else:
payload = EMPTY_PAYLOAD # type: ignore
payload = EMPTY_PAYLOAD

messages.append((msg, payload))
else:
Expand Down
11 changes: 8 additions & 3 deletions aiohttp/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import warnings
from typing import Awaitable, Callable, Generic, List, Optional, Tuple, TypeVar

from typing_extensions import Final

from .base_protocol import BaseProtocol
from .helpers import BaseTimerContext, set_exception, set_result
from .log import internal_logger
Expand Down Expand Up @@ -504,7 +506,10 @@ def _read_nowait(self, n: int) -> bytes:
return b"".join(chunks) if chunks else b""


class EmptyStreamReader(AsyncStreamReaderMixin):
class EmptyStreamReader(StreamReader): # lgtm [py/missing-call-to-init]
def __init__(self) -> None:
pass

def exception(self) -> Optional[BaseException]:
return None

Expand Down Expand Up @@ -549,11 +554,11 @@ async def readchunk(self) -> Tuple[bytes, bool]:
async def readexactly(self, n: int) -> bytes:
raise asyncio.IncompleteReadError(b"", n)

def read_nowait(self) -> bytes:
def read_nowait(self, n: int = -1) -> bytes:
return b""


EMPTY_PAYLOAD = EmptyStreamReader()
EMPTY_PAYLOAD: Final[StreamReader] = EmptyStreamReader()


class DataQueue(Generic[_T]):
Expand Down
165 changes: 76 additions & 89 deletions aiohttp/web_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,19 @@
Callable,
Deque,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)

import attr
import yarl

from .abc import AbstractAccessLogger, AbstractStreamWriter
from .base_protocol import BaseProtocol
from .helpers import CeilTimeout, current_task
from .helpers import CeilTimeout
from .http import (
HttpProcessingError,
HttpRequestParser,
Expand Down Expand Up @@ -58,7 +61,6 @@

_RequestHandler = Callable[[BaseRequest], Awaitable[StreamResponse]]


ERROR = RawRequestMessage(
"UNKNOWN", "/", HttpVersion10, {}, {}, True, False, False, False, yarl.URL("/")
)
Expand All @@ -72,6 +74,16 @@ class PayloadAccessError(Exception):
"""Payload was accessed after response was sent."""


@attr.s(auto_attribs=True, frozen=True, slots=True)
class _ErrInfo:
    # Immutable record of an HTTP parsing failure; carried through the
    # message queue in place of a RawRequestMessage (see _MsgType below).
    status: int  # HTTP status code to answer with (e.g. 400)
    exc: BaseException  # the exception raised by the parser
    message: str  # text used for the error response body


# A parsed message queued for handling: either a real request message or
# an _ErrInfo describing a parse failure, paired with its payload stream.
_MsgType = Tuple[Union[RawRequestMessage, _ErrInfo], StreamReader]


class RequestHandler(BaseProtocol):
"""HTTP protocol implementation.

Expand All @@ -83,32 +95,28 @@ class RequestHandler(BaseProtocol):
status line, bad headers or incomplete payload. If any error occurs,
connection gets closed.

:param keepalive_timeout: number of seconds before closing
keep-alive connection
:type keepalive_timeout: int or None
keepalive_timeout -- number of seconds before closing
keep-alive connection

:param bool tcp_keepalive: TCP keep-alive is on, default is on
tcp_keepalive -- TCP keep-alive is on, default is on

:param bool debug: enable debug mode
debug -- enable debug mode

:param logger: custom logger object
:type logger: aiohttp.log.server_logger
logger -- custom logger object

:param access_log_class: custom class for access_logger
:type access_log_class: aiohttp.abc.AbstractAccessLogger
access_log_class -- custom class for access_logger

:param access_log: custom logging object
:type access_log: aiohttp.log.server_logger
access_log -- custom logging object

:param str access_log_format: access log format string
access_log_format -- access log format string

:param loop: Optional event loop
loop -- Optional event loop

:param int max_line_size: Optional maximum header line size
max_line_size -- Optional maximum header line size

:param int max_field_size: Optional maximum header field size
max_field_size -- Optional maximum header field size

:param int max_headers: Optional maximum header size
max_headers -- Optional maximum header size

"""

Expand All @@ -128,7 +136,6 @@ class RequestHandler(BaseProtocol):
"_messages",
"_message_tail",
"_waiter",
"_error_handler",
"_task_handler",
"_upgrade",
"_payload_parser",
Expand Down Expand Up @@ -161,19 +168,14 @@ def __init__(
lingering_time: float = 10.0,
read_bufsize: int = 2 ** 16,
):

super().__init__(loop)

self._request_count = 0
self._keepalive = False
self._current_request = None # type: Optional[BaseRequest]
self._manager = manager # type: Optional[Server]
self._request_handler = (
manager.request_handler
) # type: Optional[_RequestHandler]
self._request_factory = (
manager.request_factory
) # type: Optional[_RequestFactory]
self._request_handler: Optional[_RequestHandler] = manager.request_handler
self._request_factory: Optional[_RequestFactory] = manager.request_factory

self._tcp_keepalive = tcp_keepalive
# placeholder to be replaced on keepalive timeout setup
Expand All @@ -182,11 +184,10 @@ def __init__(
self._keepalive_timeout = keepalive_timeout
self._lingering_time = float(lingering_time)

self._messages: Deque[Tuple[RawRequestMessage, StreamReader]] = deque()
self._messages: Deque[_MsgType] = deque()
self._message_tail = b""

self._waiter = None # type: Optional[asyncio.Future[None]]
self._error_handler = None # type: Optional[asyncio.Task[None]]
self._task_handler = None # type: Optional[asyncio.Task[None]]

self._upgrade = False
Expand Down Expand Up @@ -239,9 +240,6 @@ async def shutdown(self, timeout: Optional[float] = 15.0) -> None:
# wait for handlers
with suppress(asyncio.CancelledError, asyncio.TimeoutError):
with CeilTimeout(timeout, loop=self._loop):
if self._error_handler is not None and not self._error_handler.done():
await self._error_handler

if self._current_request is not None:
self._current_request._cancel(asyncio.CancelledError())

Expand Down Expand Up @@ -288,8 +286,6 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:
exc = ConnectionResetError("Connection lost")
self._current_request._cancel(exc)

if self._error_handler is not None:
self._error_handler.cancel()
if self._task_handler is not None:
self._task_handler.cancel()
if self._waiter is not None:
Expand Down Expand Up @@ -318,40 +314,30 @@ def data_received(self, data: bytes) -> None:
if self._force_close or self._close:
return
# parse http messages
messages: Sequence[_MsgType]
if self._payload_parser is None and not self._upgrade:
assert self._request_parser is not None
try:
messages, upgraded, tail = self._request_parser.feed_data(data)
except HttpProcessingError as exc:
# something happened during parsing
self._error_handler = self._loop.create_task(
self.handle_parse_error(
StreamWriter(self, self._loop), 400, exc, exc.message
)
)
self.close()
except Exception as exc:
# 500: internal error
self._error_handler = self._loop.create_task(
self.handle_parse_error(StreamWriter(self, self._loop), 500, exc)
)
self.close()
else:
if messages:
# sometimes the parser returns no messages
for (msg, payload) in messages:
self._request_count += 1
self._messages.append((msg, payload))

waiter = self._waiter
if waiter is not None:
if not waiter.done():
# don't set result twice
waiter.set_result(None)

self._upgrade = upgraded
if upgraded and tail:
self._message_tail = tail
messages = [
(_ErrInfo(status=400, exc=exc, message=exc.message), EMPTY_PAYLOAD)
]
upgraded = False
tail = b""

for msg, payload in messages or ():
self._request_count += 1
self._messages.append((msg, payload))

waiter = self._waiter
if messages and waiter is not None and not waiter.done():
# don't set result twice
waiter.set_result(None)

self._upgrade = upgraded
if upgraded and tail:
self._message_tail = tail

# no parser, just store
elif self._payload_parser is None and self._upgrade and data:
Expand Down Expand Up @@ -424,12 +410,13 @@ async def _handle_request(
self,
request: BaseRequest,
start_time: float,
request_handler: Callable[[BaseRequest], Awaitable[StreamResponse]],
) -> Tuple[StreamResponse, bool]:
assert self._request_handler is not None
try:
try:
self._current_request = request
resp = await self._request_handler(request)
resp = await request_handler(request)
finally:
self._current_request = None
except HTTPException as exc:
Expand Down Expand Up @@ -487,10 +474,19 @@ async def start(self) -> None:

manager.requests_count += 1
writer = StreamWriter(self, loop)
if isinstance(message, _ErrInfo):
# make request_factory work
request_handler = self._make_error_handler(message)
message = ERROR
else:
request_handler = self._request_handler

request = self._request_factory(message, payload, self, writer, handler)
try:
# a new task is used for copy context vars (#3406)
task = self._loop.create_task(self._handle_request(request, start))
task = self._loop.create_task(
self._handle_request(request, start, request_handler)
)
try:
resp, reset = await task
except (asyncio.CancelledError, ConnectionError):
Expand Down Expand Up @@ -568,7 +564,7 @@ async def start(self) -> None:
# remove handler, close transport if no handlers left
if not self._force_close:
self._task_handler = None
if self.transport is not None and self._error_handler is None:
if self.transport is not None:
self.transport.close()

async def finish_response(
Expand Down Expand Up @@ -620,6 +616,13 @@ def handle_error(
information. It always closes current connection."""
self.log_exception("Error handling request", exc_info=exc)

# some data already got sent, connection is broken
if request.writer.output_size > 0:
raise ConnectionError(
"Response is sent already, cannot send another response "
"with the error message"
)

ct = "text/plain"
if status == HTTPStatus.INTERNAL_SERVER_ERROR:
title = "{0.value} {0.phrase}".format(HTTPStatus.INTERNAL_SERVER_ERROR)
Expand Down Expand Up @@ -648,30 +651,14 @@ def handle_error(
resp = Response(status=status, text=message, content_type=ct)
resp.force_close()

# some data already got sent, connection is broken
if request.writer.output_size > 0 or self.transport is None:
self.force_close()

return resp

async def handle_parse_error(
self,
writer: AbstractStreamWriter,
status: int,
exc: Optional[BaseException] = None,
message: Optional[str] = None,
) -> None:
task = current_task()
assert task is not None
request = BaseRequest(
ERROR, EMPTY_PAYLOAD, self, writer, task, self._loop # type: ignore
)

resp = self.handle_error(request, status, exc, message)
await resp.prepare(request)
await resp.write_eof()

if self.transport is not None:
self.transport.close()
def _make_error_handler(
self, err_info: _ErrInfo
) -> Callable[[BaseRequest], Awaitable[StreamResponse]]:
async def handler(request: BaseRequest) -> StreamResponse:
return self.handle_error(
request, err_info.status, err_info.exc, err_info.message
)

self._error_handler = None
return handler
4 changes: 4 additions & 0 deletions tests/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -1490,3 +1490,7 @@ async def test_stream_reader_iter_chunks_chunked_encoding(protocol) -> None:
async for data, end_of_chunk in stream.iter_chunks():
assert (data, end_of_chunk) == (next(it), True)
pytest.raises(StopIteration, next, it)


def test_isinstance_check() -> None:
    """EMPTY_PAYLOAD must satisfy isinstance checks against StreamReader."""
    payload = streams.EMPTY_PAYLOAD
    assert isinstance(payload, streams.StreamReader)
Loading