diff --git a/aiohttp/client.py b/aiohttp/client.py index 4feff51423d..ae90d681f0c 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -1,11 +1,11 @@ """HTTP Client for asyncio.""" - import asyncio import base64 import dataclasses import hashlib import json import os +import pprint import sys import traceback import warnings @@ -574,13 +574,20 @@ async def _request( try: try: + pprint.pprint(["send"]) resp = await req.send(conn) try: + pprint.pprint(["start"]) + await resp.start(conn) - except BaseException: + except BaseException as ex: + pprint.pprint(["close 1", ex]) + resp.close() raise - except BaseException: + except BaseException as ex: + pprint.pprint(["close 2", ex]) + conn.close() raise except (ClientOSError, ServerDisconnectedError): diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index ad77f32959f..47a38f8f7a0 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -652,6 +652,10 @@ async def send(self, conn: "Connection") -> "ClientResponse": if connection is not None: self.headers[hdrs.CONNECTION] = connection + import pprint + + pprint.pprint(["send status line"]) + # status + headers status_line = "{0} {1} HTTP/{v.major}.{v.minor}".format( self.method, path, v=self.version @@ -905,7 +909,13 @@ async def start(self, connection: "Connection") -> "ClientResponse": # read response try: protocol = self._protocol + import pprint + + pprint.pprint(["read response"]) message, payload = await protocol.read() # type: ignore[union-attr] + import pprint + + pprint.pprint(["protocol read", message, payload]) except http.HttpProcessingError as exc: raise ClientResponseError( self.request_info, diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py index abd4c7fc75d..6bf7b6777e5 100644 --- a/aiohttp/helpers.py +++ b/aiohttp/helpers.py @@ -1120,3 +1120,8 @@ def should_remove_content_length(method: str, code: int) -> bool: or 100 <= code < 200 or (200 <= code < 300 and method.upper() == hdrs.METH_CONNECT) ) + + +def is_supported_upgrade(headers: CIMultiDict[str]) -> bool: + """Check if the upgrade header is supported.""" + return headers.get(hdrs.UPGRADE, "").lower() in ("tcp", "websocket") diff --git a/aiohttp/http_parser.py b/aiohttp/http_parser.py index 28f8edcf09d..49940ea9f4e 100644 --- a/aiohttp/http_parser.py +++ b/aiohttp/http_parser.py @@ -32,6 +32,7 @@ DEBUG, NO_EXTENSIONS, BaseTimerContext, + is_supported_upgrade, method_must_be_empty_body, set_exception, status_code_must_be_empty_body, @@ -302,6 +303,9 @@ def feed_data( data_len = len(data) start_pos = 0 loop = self.loop + import pprint + + pprint.pprint(["feed_data", data]) while start_pos < data_len: # read HTTP message (request/response line + headers), \r\n\r\n @@ -347,21 +351,46 @@ def get_content_length() -> Optional[int]: if SEC_WEBSOCKET_KEY1 in msg.headers: raise InvalidHeader(SEC_WEBSOCKET_KEY1) - self._upgraded = msg.upgrade - method = getattr(msg, "method", self.method) # code is only present on responses code = getattr(msg, "code", 0) + import pprint + + pprint.pprint(["upgraded", msg.upgrade, msg]) + # If response is not 101 than we didn't upgrade + # if its 0 its not a response + supported_upgrade = msg.upgrade and is_supported_upgrade( + msg.headers + ) + self._upgraded = supported_upgrade and code in (0, 101) + assert self.protocol is not None # calculate payload empty_body = status_code_must_be_empty_body(code) or bool( method and method_must_be_empty_body(method) ) + import pprint + + pprint.pprint( + [ + "empty_body", + empty_body, + code, + msg.upgrade, + "self._upgraded", + self._upgraded, + msg, + "length", + length, + "msg.chunked", + msg.chunked, + ] + ) + # self._upgraded=False if not empty_body and ( - (length is not None and length > 0) - or msg.chunked - and not msg.upgrade + ((length is not None and length > 0) or msg.chunked) + and not supported_upgrade ): payload = StreamReader( self.protocol, @@ -514,6 +543,9 @@ def parse_headers( close_conn = False # https://www.rfc-editor.org/rfc/rfc9110.html#name-101-switching-protocols elif v == "upgrade" and headers.get(hdrs.UPGRADE): + import pprint + + pprint.pprint(["upgrade", headers.get(hdrs.UPGRADE)]) upgrade = True # encoding diff --git a/aiohttp/streams.py b/aiohttp/streams.py index 84aa7383de9..8e54e9d94e0 100644 --- a/aiohttp/streams.py +++ b/aiohttp/streams.py @@ -624,7 +624,12 @@ async def read(self) -> _T: self._waiter = self._loop.create_future() try: await self._waiter - except (asyncio.CancelledError, asyncio.TimeoutError): + except (asyncio.CancelledError, asyncio.TimeoutError) as ex: + import pprint + + pprint.pprint( + ["read waiter", self._waiter, self._buffer, self._eof, ex] + ) self._waiter = None raise diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index cad01512c08..4cf6321525c 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -477,6 +477,9 @@ async def _handle_request( except asyncio.CancelledError: raise except asyncio.TimeoutError as exc: + import pprint + + pprint.pprint(["request handler timed out"]) self.log_debug("Request handler timed out.", exc_info=exc) resp = self.handle_error(request, 504) reset = await self.finish_response(request, resp, start_time) diff --git a/tests/test_web_server.py b/tests/test_web_server.py index f26f0537ec7..22a54d31d2c 100644 --- a/tests/test_web_server.py +++ b/tests/test_web_server.py @@ -1,5 +1,6 @@ # type: ignore import asyncio +import sys from contextlib import suppress from typing import Any from unittest import mock @@ -8,6 +9,11 @@ from aiohttp import client, helpers, web +if sys.version_info >= (3, 11): + import asyncio as async_timeout +else: + import async_timeout + async def test_simple_server(aiohttp_raw_server: Any, aiohttp_client: Any) -> None: async def handler(request): @@ -31,13 +37,29 @@ async def test_unsupported_upgrade(aiohttp_raw_server, aiohttp_client) -> None: # don't fail if a client probes for an unsupported protocol upgrade # https://github.com/aio-libs/aiohttp/issues/6446#issuecomment-999032039 async def handler(request: web.Request): - return web.Response(body=await request.read()) + try: + import pprint + + pprint.pprint(["handler called"]) + async with async_timeout.timeout(10): + result = await request.read() + pprint.pprint(["handler read", result]) + return web.Response(body=result) + except Exception as e: + import pprint + + pprint.pprint(["handler except", e]) + raise upgrade_headers = {"Connection": "Upgrade", "Upgrade": "unsupported_proto"} server = await aiohttp_raw_server(handler) - cli = await aiohttp_client(server) + cli: client.ClientSession = await aiohttp_client(server) test_data = b"Test" - resp = await cli.post("/path/to", data=test_data, headers=upgrade_headers) + async with async_timeout.timeout(10): + resp = await cli.post("/path/to", data=test_data, headers=upgrade_headers) + import pprint + + pprint.pprint(resp.headers) assert resp.status == 200 data = await resp.read() assert data == test_data