diff --git a/CHANGES/8089.bugfix.rst b/CHANGES/8089.bugfix.rst
new file mode 100644
index 00000000000..7f47448478d
--- /dev/null
+++ b/CHANGES/8089.bugfix.rst
@@ -0,0 +1,3 @@
+The asynchronous internals now set the underlying causes
+when assigning exceptions to the future objects
+-- by :user:`webknjaz`.
diff --git a/aiohttp/_http_parser.pyx b/aiohttp/_http_parser.pyx
index 3f28fbdab43..7ea9b32ca55 100644
--- a/aiohttp/_http_parser.pyx
+++ b/aiohttp/_http_parser.pyx
@@ -19,7 +19,7 @@ from multidict import CIMultiDict as _CIMultiDict, CIMultiDictProxy as _CIMultiD
 from yarl import URL as _URL
 
 from aiohttp import hdrs
-from aiohttp.helpers import DEBUG
+from aiohttp.helpers import DEBUG, set_exception
 
 from .http_exceptions import (
     BadHttpMessage,
@@ -763,11 +763,13 @@ cdef int cb_on_body(cparser.llhttp_t* parser,
     cdef bytes body = at[:length]
     try:
         pyparser._payload.feed_data(body, length)
-    except BaseException as exc:
+    except BaseException as underlying_exc:
+        reraised_exc = underlying_exc
         if pyparser._payload_exception is not None:
-            pyparser._payload.set_exception(pyparser._payload_exception(str(exc)))
-        else:
-            pyparser._payload.set_exception(exc)
+            reraised_exc = pyparser._payload_exception(str(underlying_exc))
+
+        set_exception(pyparser._payload, reraised_exc, underlying_exc)
+
         pyparser._payload_error = 1
         return -1
     else:
diff --git a/aiohttp/base_protocol.py b/aiohttp/base_protocol.py
index 4c9f0a752e3..dc1f24f99cd 100644
--- a/aiohttp/base_protocol.py
+++ b/aiohttp/base_protocol.py
@@ -1,6 +1,7 @@
 import asyncio
 from typing import Optional, cast
 
+from .helpers import set_exception
 from .tcp_helpers import tcp_nodelay
 
 
@@ -76,7 +77,11 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:
         if exc is None:
             waiter.set_result(None)
         else:
-            waiter.set_exception(exc)
+            set_exception(
+                waiter,
+                ConnectionError("Connection lost"),
+                exc,
+            )
 
     async def _drain_helper(self) -> None:
         if not self.connected:
diff --git a/aiohttp/client_proto.py b/aiohttp/client_proto.py
index 604b954903e..324b532476c 100644
--- a/aiohttp/client_proto.py
+++ b/aiohttp/client_proto.py
@@ -4,18 +4,21 @@
 
 from .base_protocol import BaseProtocol
 from .client_exceptions import (
+    ClientConnectionError,
     ClientOSError,
     ClientPayloadError,
     ServerDisconnectedError,
     SocketTimeoutError,
 )
 from .helpers import (
+    _EXC_SENTINEL,
     BaseTimerContext,
     set_exception,
     set_result,
     status_code_must_be_empty_body,
 )
 from .http import HttpResponseParser, RawResponseMessage, WebSocketReader
+from .http_exceptions import HttpProcessingError
 from .streams import EMPTY_PAYLOAD, DataQueue, StreamReader
 
 
@@ -80,33 +83,62 @@ def is_connected(self) -> bool:
     def connection_lost(self, exc: Optional[BaseException]) -> None:
         self._drop_timeout()
 
-        if exc is not None:
-            set_exception(self.closed, exc)
-        else:
+        original_connection_error = exc
+        reraised_exc = original_connection_error
+
+        connection_closed_cleanly = original_connection_error is None
+
+        if connection_closed_cleanly:
             set_result(self.closed, None)
+        else:
+            assert original_connection_error is not None
+            set_exception(
+                self.closed,
+                ClientConnectionError(
+                    f"Connection lost: {original_connection_error !s}",
+                ),
+                original_connection_error,
+            )
 
         if self._payload_parser is not None:
-            with suppress(Exception):
+            with suppress(Exception):  # FIXME: log this somehow?
                 self._payload_parser.feed_eof()
 
         uncompleted = None
         if self._parser is not None:
             try:
                 uncompleted = self._parser.feed_eof()
-            except Exception as e:
+            except Exception as underlying_exc:
                 if self._payload is not None:
-                    exc = ClientPayloadError("Response payload is not completed")
-                    exc.__cause__ = e
-                    self._payload.set_exception(exc)
+                    client_payload_exc_msg = (
+                        f"Response payload is not completed: {underlying_exc !r}"
+                    )
+                    if not connection_closed_cleanly:
+                        client_payload_exc_msg = (
+                            f"{client_payload_exc_msg !s}. "
+                            f"{original_connection_error !r}"
+                        )
+                    set_exception(
+                        self._payload,
+                        ClientPayloadError(client_payload_exc_msg),
+                        underlying_exc,
+                    )
 
         if not self.is_eof():
-            if isinstance(exc, OSError):
-                exc = ClientOSError(*exc.args)
-            if exc is None:
-                exc = ServerDisconnectedError(uncompleted)
+            if isinstance(original_connection_error, OSError):
+                reraised_exc = ClientOSError(*original_connection_error.args)
+            if connection_closed_cleanly:
+                reraised_exc = ServerDisconnectedError(uncompleted)
             # assigns self._should_close to True as side effect,
             # we do it anyway below
-            self.set_exception(exc)
+            underlying_non_eof_exc = (
+                _EXC_SENTINEL
+                if connection_closed_cleanly
+                else original_connection_error
+            )
+            assert underlying_non_eof_exc is not None
+            assert reraised_exc is not None
+            self.set_exception(reraised_exc, underlying_non_eof_exc)
 
         self._should_close = True
         self._parser = None
@@ -114,7 +146,7 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:
         self._payload_parser = None
         self._reading_paused = False
 
-        super().connection_lost(exc)
+        super().connection_lost(reraised_exc)
 
     def eof_received(self) -> None:
         # should call parser.feed_eof() most likely
@@ -128,10 +160,14 @@ def resume_reading(self) -> None:
         super().resume_reading()
         self._reschedule_timeout()
 
-    def set_exception(self, exc: BaseException) -> None:
+    def set_exception(
+        self,
+        exc: BaseException,
+        exc_cause: BaseException = _EXC_SENTINEL,
+    ) -> None:
         self._should_close = True
         self._drop_timeout()
-        super().set_exception(exc)
+        super().set_exception(exc, exc_cause)
 
     def set_parser(self, parser: Any, payload: Any) -> None:
         # TODO: actual types are:
@@ -208,7 +244,7 @@ def _on_read_timeout(self) -> None:
         exc = SocketTimeoutError("Timeout on reading data from socket")
         self.set_exception(exc)
         if self._payload is not None:
-            self._payload.set_exception(exc)
+            set_exception(self._payload, exc)
 
     def data_received(self, data: bytes) -> None:
         self._reschedule_timeout()
@@ -234,14 +270,14 @@ def data_received(self, data: bytes) -> None:
             # parse http messages
             try:
                 messages, upgraded, tail = self._parser.feed_data(data)
-            except BaseException as exc:
+            except BaseException as underlying_exc:
                 if self.transport is not None:
                     # connection.release() could be called BEFORE
                     # data_received(), the transport is already
                     # closed in this case
                     self.transport.close()
                 # should_close is True after the call
-                self.set_exception(exc)
+                self.set_exception(HttpProcessingError(), underlying_exc)
                 return
 
             self._upgraded = upgraded
diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py
index 7645310a6ba..ad77f32959f 100644
--- a/aiohttp/client_reqrep.py
+++ b/aiohttp/client_reqrep.py
@@ -53,6 +53,7 @@
     noop,
     parse_mimetype,
     reify,
+    set_exception,
     set_result,
 )
 from .http import (
@@ -566,20 +567,29 @@ async def write_bytes(
 
                 for chunk in self.body:
                     await writer.write(chunk)  # type: ignore[arg-type]
-        except OSError as exc:
-            if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
-                protocol.set_exception(exc)
-            else:
-                new_exc = ClientOSError(
-                    exc.errno, "Can not write request body for %s" % self.url
+        except OSError as underlying_exc:
+            reraised_exc = underlying_exc
+
+            exc_is_not_timeout = underlying_exc.errno is not None or not isinstance(
+                underlying_exc, asyncio.TimeoutError
+            )
+            if exc_is_not_timeout:
+                reraised_exc = ClientOSError(
+                    underlying_exc.errno,
+                    f"Can not write request body for {self.url !s}",
                 )
-                new_exc.__context__ = exc
-                new_exc.__cause__ = exc
-                protocol.set_exception(new_exc)
+
+            set_exception(protocol, reraised_exc, underlying_exc)
         except asyncio.CancelledError:
             await writer.write_eof()
-        except Exception as exc:
-            protocol.set_exception(exc)
+        except Exception as underlying_exc:
+            set_exception(
+                protocol,
+                ClientConnectionError(
+                    f"Failed to send bytes into the underlying connection {conn !s}",
+                ),
+                underlying_exc,
+            )
         else:
             await writer.write_eof()
             protocol.start_timeout()
@@ -1019,7 +1029,7 @@ def _notify_content(self) -> None:
         content = self.content
         # content can be None here, but the types are cheated elsewhere.
         if content and content.exception() is None:  # type: ignore[truthy-bool]
-            content.set_exception(ClientConnectionError("Connection closed"))
+            set_exception(content, ClientConnectionError("Connection closed"))
         self._released = True
 
     async def wait_for_close(self) -> None:
diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py
index 5435e2f9e07..abd4c7fc75d 100644
--- a/aiohttp/helpers.py
+++ b/aiohttp/helpers.py
@@ -797,9 +797,39 @@ def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
         fut.set_result(result)
 
 
-def set_exception(fut: "asyncio.Future[_T]", exc: BaseException) -> None:
-    if not fut.done():
-        fut.set_exception(exc)
+_EXC_SENTINEL = BaseException()
+
+
+class ErrorableProtocol(Protocol):
+    def set_exception(
+        self,
+        exc: BaseException,
+        exc_cause: BaseException = ...,
+    ) -> None:
+        ...  # pragma: no cover
+
+
+def set_exception(
+    fut: "asyncio.Future[_T] | ErrorableProtocol",
+    exc: BaseException,
+    exc_cause: BaseException = _EXC_SENTINEL,
+) -> None:
+    """Set future exception.
+
+    If the future is marked as complete, this function is a no-op.
+
+    :param exc_cause: An exception that is a direct cause of ``exc``.
+                      Only set if provided.
+    """
+    if asyncio.isfuture(fut) and fut.done():
+        return
+
+    exc_is_sentinel = exc_cause is _EXC_SENTINEL
+    exc_causes_itself = exc is exc_cause
+    if not exc_is_sentinel and not exc_causes_itself:
+        exc.__cause__ = exc_cause
+
+    fut.set_exception(exc)
 
 
 @functools.total_ordering
diff --git a/aiohttp/http_parser.py b/aiohttp/http_parser.py
index 70804563b31..28f8edcf09d 100644
--- a/aiohttp/http_parser.py
+++ b/aiohttp/http_parser.py
@@ -28,10 +28,12 @@
 from .base_protocol import BaseProtocol
 from .compression_utils import HAS_BROTLI, BrotliDecompressor, ZLibDecompressor
 from .helpers import (
+    _EXC_SENTINEL,
     DEBUG,
     NO_EXTENSIONS,
     BaseTimerContext,
     method_must_be_empty_body,
+    set_exception,
     status_code_must_be_empty_body,
 )
 from .http_exceptions import (
@@ -439,13 +441,16 @@ def get_content_length() -> Optional[int]:
                 assert self._payload_parser is not None
                 try:
                     eof, data = self._payload_parser.feed_data(data[start_pos:], SEP)
-                except BaseException as exc:
+                except BaseException as underlying_exc:
+                    reraised_exc = underlying_exc
                     if self.payload_exception is not None:
-                        self._payload_parser.payload.set_exception(
-                            self.payload_exception(str(exc))
-                        )
-                    else:
-                        self._payload_parser.payload.set_exception(exc)
+                        reraised_exc = self.payload_exception(str(underlying_exc))
+
+                    set_exception(
+                        self._payload_parser.payload,
+                        reraised_exc,
+                        underlying_exc,
+                    )
 
                     eof = True
                     data = b""
@@ -826,7 +831,7 @@ def feed_data(
                         exc = TransferEncodingError(
                             chunk[:pos].decode("ascii", "surrogateescape")
                         )
-                        self.payload.set_exception(exc)
+                        set_exception(self.payload, exc)
                         raise exc
                     size = int(bytes(size_b), 16)
 
@@ -929,8 +934,12 @@ def __init__(self, out: StreamReader, encoding: Optional[str]) -> None:
         else:
             self.decompressor = ZLibDecompressor(encoding=encoding)
 
-    def set_exception(self, exc: BaseException) -> None:
-        self.out.set_exception(exc)
+    def set_exception(
+        self,
+        exc: BaseException,
+        exc_cause: BaseException = _EXC_SENTINEL,
+    ) -> None:
+        set_exception(self.out, exc, exc_cause)
 
     def feed_data(self, chunk: bytes, size: int) -> None:
         if not size:
diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py
index d81d7269658..1a9ca7e6ea0 100644
--- a/aiohttp/http_websocket.py
+++ b/aiohttp/http_websocket.py
@@ -25,7 +25,7 @@
 
 from .base_protocol import BaseProtocol
 from .compression_utils import ZLibCompressor, ZLibDecompressor
-from .helpers import NO_EXTENSIONS
+from .helpers import NO_EXTENSIONS, set_exception
 from .streams import DataQueue
 
 __all__ = (
@@ -305,7 +305,7 @@ def feed_data(self, data: bytes) -> Tuple[bool, bytes]:
             return self._feed_data(data)
         except Exception as exc:
             self._exc = exc
-            self.queue.set_exception(exc)
+            set_exception(self.queue, exc)
             return True, b""
 
     def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
diff --git a/aiohttp/streams.py b/aiohttp/streams.py
index 8ee088d3b5c..84aa7383de9 100644
--- a/aiohttp/streams.py
+++ b/aiohttp/streams.py
@@ -14,7 +14,13 @@
 )
 
 from .base_protocol import BaseProtocol
-from .helpers import BaseTimerContext, TimerNoop, set_exception, set_result
+from .helpers import (
+    _EXC_SENTINEL,
+    BaseTimerContext,
+    TimerNoop,
+    set_exception,
+    set_result,
+)
 from .log import internal_logger
 
 __all__ = (
@@ -146,19 +152,23 @@ def get_read_buffer_limits(self) -> Tuple[int, int]:
     def exception(self) -> Optional[BaseException]:
         return self._exception
 
-    def set_exception(self, exc: BaseException) -> None:
+    def set_exception(
+        self,
+        exc: BaseException,
+        exc_cause: BaseException = _EXC_SENTINEL,
+    ) -> None:
         self._exception = exc
         self._eof_callbacks.clear()
 
         waiter = self._waiter
         if waiter is not None:
             self._waiter = None
-            set_exception(waiter, exc)
+            set_exception(waiter, exc, exc_cause)
 
         waiter = self._eof_waiter
         if waiter is not None:
             self._eof_waiter = None
-            set_exception(waiter, exc)
+            set_exception(waiter, exc, exc_cause)
 
     def on_eof(self, callback: Callable[[], None]) -> None:
         if self._eof:
@@ -499,7 +509,11 @@ def __repr__(self) -> str:
     def exception(self) -> Optional[BaseException]:
         return None
 
-    def set_exception(self, exc: BaseException) -> None:
+    def set_exception(
+        self,
+        exc: BaseException,
+        exc_cause: BaseException = _EXC_SENTINEL,
+    ) -> None:
         pass
 
     def on_eof(self, callback: Callable[[], None]) -> None:
@@ -574,14 +588,18 @@ def at_eof(self) -> bool:
     def exception(self) -> Optional[BaseException]:
         return self._exception
 
-    def set_exception(self, exc: BaseException) -> None:
+    def set_exception(
+        self,
+        exc: BaseException,
+        exc_cause: BaseException = _EXC_SENTINEL,
+    ) -> None:
         self._eof = True
         self._exception = exc
 
         waiter = self._waiter
         if waiter is not None:
             self._waiter = None
-            set_exception(waiter, exc)
+            set_exception(waiter, exc, exc_cause)
 
     def feed_data(self, data: _T, size: int = 0) -> None:
         self._size += size
diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py
index 36d05e90fd9..cad01512c08 100644
--- a/aiohttp/web_protocol.py
+++ b/aiohttp/web_protocol.py
@@ -25,7 +25,7 @@
 
 from .abc import AbstractAccessLogger, AbstractAsyncAccessLogger, AbstractStreamWriter
 from .base_protocol import BaseProtocol
-from .helpers import ceil_timeout
+from .helpers import ceil_timeout, set_exception
 from .http import (
     HttpProcessingError,
     HttpRequestParser,
@@ -577,7 +577,7 @@ async def start(self) -> None:
                         self.log_debug("Uncompleted request.")
                         self.close()
 
-                payload.set_exception(PayloadAccessError())
+                set_exception(payload, PayloadAccessError())
 
             except asyncio.CancelledError:
                 self.log_debug("Ignored premature client disconnection ")
diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py
index 6f14f73ad19..da86fdff18b 100644
--- a/aiohttp/web_request.py
+++ b/aiohttp/web_request.py
@@ -42,6 +42,7 @@
     parse_http_date,
     reify,
     sentinel,
+    set_exception,
     set_result,
 )
 from .http_parser import RawRequestMessage
@@ -813,7 +814,7 @@ async def _prepare_hook(self, response: StreamResponse) -> None:
         return
 
     def _cancel(self, exc: BaseException) -> None:
-        self._payload.set_exception(exc)
+        set_exception(self._payload, exc)
         for fut in self._disconnection_waiters:
             set_result(fut, None)
 
diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py
index 3ffecc2c06f..3950e569c4a 100644
--- a/aiohttp/web_ws.py
+++ b/aiohttp/web_ws.py
@@ -11,7 +11,7 @@
 
 from . import hdrs
 from .abc import AbstractStreamWriter
-from .helpers import call_later, set_result
+from .helpers import call_later, set_exception, set_result
 from .http import (
     WS_CLOSED_MESSAGE,
     WS_CLOSING_MESSAGE,
@@ -548,4 +548,4 @@ async def __anext__(self) -> WSMessage:
 
     def _cancel(self, exc: BaseException) -> None:
         if self._reader is not None:
-            self._reader.set_exception(exc)
+            set_exception(self._reader, exc)
diff --git a/tests/test_base_protocol.py b/tests/test_base_protocol.py
index 7254c6ff8cb..2cf7280e4b0 100644
--- a/tests/test_base_protocol.py
+++ b/tests/test_base_protocol.py
@@ -186,9 +186,9 @@ async def test_lost_drain_waited_exception() -> None:
     assert pr._drain_waiter is not None
     exc = RuntimeError()
    pr.connection_lost(exc)
-    with pytest.raises(RuntimeError) as cm:
+    with pytest.raises(ConnectionError, match=r"^Connection lost$") as cm:
         await t
-    assert cm.value is exc
+    assert cm.value.__cause__ is exc
     assert pr._drain_waiter is None
 
 
diff --git a/tests/test_client_request.py b/tests/test_client_request.py
index 5483d7be984..8ef90b5a1d5 100644
--- a/tests/test_client_request.py
+++ b/tests/test_client_request.py
@@ -14,6 +14,7 @@
 
 import aiohttp
 from aiohttp import BaseConnector, hdrs, helpers, payload
+from aiohttp.client_exceptions import ClientConnectionError
 from aiohttp.client_reqrep import (
     ClientRequest,
     ClientResponse,
@@ -1050,9 +1051,8 @@ async def throw_exc():
     # assert connection.close.called
     assert conn.protocol.set_exception.called
     outer_exc = conn.protocol.set_exception.call_args[0][0]
-    assert isinstance(outer_exc, ValueError)
-    assert inner_exc is outer_exc
-    assert inner_exc is outer_exc
+    assert isinstance(outer_exc, ClientConnectionError)
+    assert outer_exc.__cause__ is inner_exc
     await req.close()
 
 
diff --git a/tests/test_http_parser.py b/tests/test_http_parser.py
index e5214035280..a9e58295886 100644
--- a/tests/test_http_parser.py
+++ b/tests/test_http_parser.py
@@ -279,6 +279,7 @@ def test_parse_headers_longline(parser: Any) -> None:
     header_name = b"Test" + invalid_unicode_byte + b"Header" + b"A" * 8192
     text = b"GET /test HTTP/1.1\r\n" + header_name + b": test\r\n" + b"\r\n" + b"\r\n"
     with pytest.raises((http_exceptions.LineTooLong, http_exceptions.BadHttpMessage)):
+        # FIXME: `LineTooLong` doesn't seem to actually be happening
        parser.feed_data(text)
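
Reviewer note (not part of the patch): below is a minimal sketch of the cause-chaining behaviour that the reworked ``aiohttp.helpers.set_exception()`` provides, assuming a checkout with this change applied. The helper lives in an internal module, so the import is for illustration only; the variable names mirror the ``underlying_exc`` / ``reraised_exc`` naming used throughout the diff.

    # Illustrative sketch only -- assumes aiohttp with this patch applied.
    import asyncio

    from aiohttp.helpers import set_exception  # internal helper shown in the diff above


    async def main() -> None:
        fut = asyncio.get_running_loop().create_future()

        underlying_exc = OSError("socket went away")  # low-level failure
        reraised_exc = ConnectionError("Connection lost")  # user-facing wrapper

        # The third argument becomes ``reraised_exc.__cause__`` before the
        # exception is assigned to the future; an already-done future is
        # left untouched.
        set_exception(fut, reraised_exc, underlying_exc)

        try:
            await fut
        except ConnectionError as exc:
            assert exc.__cause__ is underlying_exc  # the original error is preserved


    asyncio.run(main())

Defaulting ``exc_cause`` to the module-level ``_EXC_SENTINEL`` rather than ``None`` presumably lets wrappers such as ``ResponseHandler.set_exception()`` forward "no cause given" transparently without ever assigning ``__cause__ = None``, which would also suppress the implicit exception context.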