Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DO NOT MERGE: debug test_unsupported_upgrade #7960

Closed
wants to merge 12 commits into from
13 changes: 10 additions & 3 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""HTTP Client for asyncio."""

import asyncio
import base64
import dataclasses
import hashlib
import json
import os
import pprint
import sys
import traceback
import warnings
Expand Down Expand Up @@ -574,13 +574,20 @@

try:
try:
pprint.pprint(["send"])
resp = await req.send(conn)
try:
pprint.pprint(["start"])

await resp.start(conn)
except BaseException:
except BaseException as ex:

Check notice

Code scanning / CodeQL

Except block handles 'BaseException' Note

Except block directly handles BaseException.
pprint.pprint(["close 1", ex])

resp.close()
raise
except BaseException:
except BaseException as ex:

Check notice

Code scanning / CodeQL

Except block handles 'BaseException' Note

Except block directly handles BaseException.
pprint.pprint(["close 2", ex])

conn.close()
raise
except (ClientOSError, ServerDisconnectedError):
Expand Down
10 changes: 10 additions & 0 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,10 @@ async def send(self, conn: "Connection") -> "ClientResponse":
if connection is not None:
self.headers[hdrs.CONNECTION] = connection

import pprint

pprint.pprint(["send status line"])

# status + headers
status_line = "{0} {1} HTTP/{v.major}.{v.minor}".format(
self.method, path, v=self.version
Expand Down Expand Up @@ -905,7 +909,13 @@ async def start(self, connection: "Connection") -> "ClientResponse":
# read response
try:
protocol = self._protocol
import pprint

pprint.pprint(["read response"])
message, payload = await protocol.read() # type: ignore[union-attr]
import pprint

pprint.pprint(["protocol read", message, payload])
except http.HttpProcessingError as exc:
raise ClientResponseError(
self.request_info,
Expand Down
19 changes: 17 additions & 2 deletions aiohttp/http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,9 @@ def feed_data(
data_len = len(data)
start_pos = 0
loop = self.loop
import pprint

pprint.pprint(["feed_data", data])

while start_pos < data_len:
# read HTTP message (request/response line + headers), \r\n\r\n
Expand Down Expand Up @@ -347,17 +350,26 @@ def get_content_length() -> Optional[int]:
if SEC_WEBSOCKET_KEY1 in msg.headers:
raise InvalidHeader(SEC_WEBSOCKET_KEY1)

self._upgraded = msg.upgrade

method = getattr(msg, "method", self.method)
# code is only present on responses
code = getattr(msg, "code", 0)

import pprint

pprint.pprint(["upgraded", msg.upgrade, msg])
# If response is not 101 than we didn't upgrade
# if its 0 its not a response
self._upgraded = msg.upgrade and code in (0, 101)

assert self.protocol is not None
# calculate payload
empty_body = status_code_must_be_empty_body(code) or bool(
method and method_must_be_empty_body(method)
)
import pprint

pprint.pprint(["empty_body", empty_body, code, msg.upgrade])
# self._upgraded=False
if not empty_body and (
(length is not None and length > 0)
or msg.chunked
Expand Down Expand Up @@ -514,6 +526,9 @@ def parse_headers(
close_conn = False
# https://www.rfc-editor.org/rfc/rfc9110.html#name-101-switching-protocols
elif v == "upgrade" and headers.get(hdrs.UPGRADE):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this the source of the problem for #6446? Just because the Upgrade header has content, doesn't mean there will be an upgrade. The server has every right to ignore it if the scheme is unsupported.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you think you can provide a fix, that'd be great. I believe there is already an xfail test from that issue, so as long as you can get that test passing..

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One problem at a time for this newbie... Brotli first 😉

import pprint

pprint.pprint(["upgrade", headers.get(hdrs.UPGRADE)])
upgrade = True

# encoding
Expand Down
7 changes: 6 additions & 1 deletion aiohttp/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,12 @@ async def read(self) -> _T:
self._waiter = self._loop.create_future()
try:
await self._waiter
except (asyncio.CancelledError, asyncio.TimeoutError):
except (asyncio.CancelledError, asyncio.TimeoutError) as ex:
import pprint

pprint.pprint(
["read waiter", self._waiter, self._buffer, self._eof, ex]
)
self._waiter = None
raise

Expand Down
3 changes: 3 additions & 0 deletions aiohttp/web_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,9 @@ async def _handle_request(
except asyncio.CancelledError:
raise
except asyncio.TimeoutError as exc:
import pprint

pprint.pprint(["request handler timed out"])
self.log_debug("Request handler timed out.", exc_info=exc)
resp = self.handle_error(request, 504)
reset = await self.finish_response(request, resp, start_time)
Expand Down
28 changes: 25 additions & 3 deletions tests/test_web_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# type: ignore
import asyncio
import sys
from contextlib import suppress
from typing import Any
from unittest import mock
Expand All @@ -8,6 +9,11 @@

from aiohttp import client, helpers, web

if sys.version_info >= (3, 11):
import asyncio as async_timeout
else:
import async_timeout


async def test_simple_server(aiohttp_raw_server: Any, aiohttp_client: Any) -> None:
async def handler(request):
Expand All @@ -31,13 +37,29 @@
# don't fail if a client probes for an unsupported protocol upgrade
# https://github.com/aio-libs/aiohttp/issues/6446#issuecomment-999032039
async def handler(request: web.Request):
return web.Response(body=await request.read())
try:
import pprint

pprint.pprint(["handler called"])
async with async_timeout.timeout(10):
result = await request.read()
pprint.pprint(["handler read", result])
return web.Response(body=result)
except Exception as e:
import pprint

Check warning on line 49 in tests/test_web_server.py

View check run for this annotation

Codecov / codecov/patch

tests/test_web_server.py#L48-L49

Added lines #L48 - L49 were not covered by tests

pprint.pprint(["handler except", e])
raise

Check warning on line 52 in tests/test_web_server.py

View check run for this annotation

Codecov / codecov/patch

tests/test_web_server.py#L51-L52

Added lines #L51 - L52 were not covered by tests

upgrade_headers = {"Connection": "Upgrade", "Upgrade": "unsupported_proto"}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what is parsing the Upgrade header, but the fact that it's not a valid token at all should also be considered (IANA upgrade token registry. In fact, it even contains an invalid character, "_", for URL schemes (RFC 1738).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would assume that would be handled in the parsers (i.e. llhttp or http_parser.py, depending on whether extensions are enabled).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah okay. My point is, especially it it happens externally, the test should know and account for the behavior.

I would hope, no matter which parser is used, that "unsupported_proto" would simply be thrown away and equivalent to "". Then the test should also test a valid, but unsupported, token.

server = await aiohttp_raw_server(handler)
cli = await aiohttp_client(server)
cli: client.ClientSession = await aiohttp_client(server)
test_data = b"Test"
resp = await cli.post("/path/to", data=test_data, headers=upgrade_headers)
async with async_timeout.timeout(10):
resp = await cli.post("/path/to", data=test_data, headers=upgrade_headers)
import pprint

pprint.pprint(resp.headers)
assert resp.status == 200
data = await resp.read()
assert data == test_data
Expand Down
Loading