Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix py http parser not treating 204/304/1xx as an empty body #7755

Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
TimeoutHandle,
ceil_timeout,
get_env_proxy_for_url,
method_must_be_empty_body,
sentinel,
strip_auth_from_url,
)
Expand Down Expand Up @@ -526,7 +527,7 @@ async def _request(
assert conn.protocol is not None
conn.protocol.set_response_params(
timer=timer,
skip_payload=method.upper() == "HEAD",
skip_payload=method_must_be_empty_body(method.upper()),
bdraco marked this conversation as resolved.
Show resolved Hide resolved
read_until_eof=read_until_eof,
auto_decompress=auto_decompress,
read_timeout=real_timeout.sock_read,
Expand Down
11 changes: 9 additions & 2 deletions aiohttp/client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
ServerDisconnectedError,
ServerTimeoutError,
)
from .helpers import BaseTimerContext, set_exception, set_result
from .helpers import (
BaseTimerContext,
set_exception,
set_result,
status_code_must_be_empty_body,
)
from .http import HttpResponseParser, RawResponseMessage, WebSocketReader
from .streams import EMPTY_PAYLOAD, DataQueue, StreamReader

Expand Down Expand Up @@ -248,7 +253,9 @@ def data_received(self, data: bytes) -> None:

self._payload = payload

if self._skip_payload or message.code in (204, 304):
if self._skip_payload or status_code_must_be_empty_body(
message.code
):
self.feed_data((message, EMPTY_PAYLOAD), 0)
else:
self.feed_data((message, payload), 0)
Expand Down
19 changes: 19 additions & 0 deletions aiohttp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,3 +1063,22 @@ def parse_http_date(date_str: Optional[str]) -> Optional[datetime.datetime]:
with suppress(ValueError):
return datetime.datetime(*timetuple[:6], tzinfo=datetime.timezone.utc)
return None


def must_be_empty_body(method: str, code: int) -> bool:
"""Check if a request must return an empty body."""
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3
return status_code_must_be_empty_body(code) or method_must_be_empty_body(method)


def method_must_be_empty_body(method: str) -> bool:
"""Check if a method must return an empty body."""
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3
return method in (hdrs.METH_CONNECT, hdrs.METH_HEAD)


def status_code_must_be_empty_body(code: int) -> bool:
"""Check if a status code must return an empty body."""
# 204, 304, 1xx should not have a body per
# https://datatracker.ietf.org/doc/html/rfc9112#section-6.3
return code in (204, 304) or 100 <= code < 200
56 changes: 27 additions & 29 deletions aiohttp/http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from . import hdrs
from .base_protocol import BaseProtocol
from .compression_utils import HAS_BROTLI, BrotliDecompressor, ZLibDecompressor
from .helpers import DEBUG, NO_EXTENSIONS, BaseTimerContext
from .helpers import DEBUG, NO_EXTENSIONS, BaseTimerContext, must_be_empty_body
from .http_exceptions import (
BadHttpMessage,
BadStatusLine,
Expand Down Expand Up @@ -338,10 +338,13 @@ def get_content_length() -> Optional[int]:
self._upgraded = msg.upgrade

method = getattr(msg, "method", self.method)
# code is only present on responses
code = getattr(msg, "code", 0)

assert self.protocol is not None
# calculate payload
if (
empty_body = must_be_empty_body(method, code)
if not empty_body and (
(length is not None and length > 0)
or msg.chunked
and not msg.upgrade
Expand Down Expand Up @@ -383,34 +386,29 @@ def get_content_length() -> Optional[int]:
auto_decompress=self._auto_decompress,
lax=self.lax,
)
elif not empty_body and length is None and self.read_until_eof:
payload = StreamReader(
self.protocol,
timer=self.timer,
loop=loop,
limit=self._limit,
)
payload_parser = HttpPayloadParser(
payload,
length=length,
chunked=msg.chunked,
method=method,
compression=msg.compression,
code=self.code,
readall=True,
response_with_body=self.response_with_body,
auto_decompress=self._auto_decompress,
lax=self.lax,
)
if not payload_parser.done:
self._payload_parser = payload_parser
else:
if (
getattr(msg, "code", 100) >= 199
and length is None
and self.read_until_eof
):
payload = StreamReader(
self.protocol,
timer=self.timer,
loop=loop,
limit=self._limit,
)
payload_parser = HttpPayloadParser(
payload,
length=length,
chunked=msg.chunked,
method=method,
compression=msg.compression,
code=self.code,
readall=True,
response_with_body=self.response_with_body,
auto_decompress=self._auto_decompress,
lax=self.lax,
)
if not payload_parser.done:
self._payload_parser = payload_parser
else:
payload = EMPTY_PAYLOAD
payload = EMPTY_PAYLOAD

messages.append((msg, payload))
else:
Expand Down
126 changes: 126 additions & 0 deletions tests/test_http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
HttpPayloadParser,
HttpRequestParserPy,
HttpResponseParserPy,
HttpVersion,
)

try:
Expand Down Expand Up @@ -1060,6 +1061,131 @@ def test_parse_no_length_payload(parser: Any) -> None:
assert payload.is_eof()


def test_parse_content_length_payload_multiple(response: Any) -> None:
text = b"HTTP/1.1 200 OK\r\ncontent-length: 5\r\n\r\nfirst"
msg, payload = response.feed_data(text)[0][0]
assert msg.version == HttpVersion(major=1, minor=1)
assert msg.code == 200
assert msg.reason == "OK"
assert msg.headers == CIMultiDict(
[
("Content-Length", "5"),
]
)
assert msg.raw_headers == ((b"content-length", b"5"),)
assert not msg.should_close
assert msg.compression is None
assert not msg.upgrade
assert not msg.chunked
assert payload.is_eof()
assert b"first" == b"".join(d for d in payload._buffer)

text = b"HTTP/1.1 200 OK\r\ncontent-length: 6\r\n\r\nsecond"
msg, payload = response.feed_data(text)[0][0]
assert msg.version == HttpVersion(major=1, minor=1)
assert msg.code == 200
assert msg.reason == "OK"
assert msg.headers == CIMultiDict(
[
("Content-Length", "6"),
]
)
assert msg.raw_headers == ((b"content-length", b"6"),)
assert not msg.should_close
assert msg.compression is None
assert not msg.upgrade
assert not msg.chunked
assert payload.is_eof()
assert b"second" == b"".join(d for d in payload._buffer)


def test_parse_content_length_than_chunked_payload(response: Any) -> None:
text = b"HTTP/1.1 200 OK\r\ncontent-length: 5\r\n\r\nfirst"
msg, payload = response.feed_data(text)[0][0]
assert msg.version == HttpVersion(major=1, minor=1)
assert msg.code == 200
assert msg.reason == "OK"
assert msg.headers == CIMultiDict(
[
("Content-Length", "5"),
]
)
assert msg.raw_headers == ((b"content-length", b"5"),)
assert not msg.should_close
assert msg.compression is None
assert not msg.upgrade
assert not msg.chunked
assert payload.is_eof()
assert b"first" == b"".join(d for d in payload._buffer)

text = (
b"HTTP/1.1 200 OK\r\n"
b"transfer-encoding: chunked\r\n\r\n"
b"6\r\nsecond\r\n0\r\n\r\n"
)
msg, payload = response.feed_data(text)[0][0]
assert msg.version == HttpVersion(major=1, minor=1)
assert msg.code == 200
assert msg.reason == "OK"
assert msg.headers == CIMultiDict(
[
("Transfer-Encoding", "chunked"),
]
)
assert msg.raw_headers == ((b"transfer-encoding", b"chunked"),)
assert not msg.should_close
assert msg.compression is None
assert not msg.upgrade
assert msg.chunked
assert payload.is_eof()
assert b"second" == b"".join(d for d in payload._buffer)


@pytest.mark.parametrize("code", [204, 304, 101, 102])
def test_parse_chunked_payload_empty_body_than_another_chunked(
response: Any, code: int
) -> None:
head = f"HTTP/1.1 {code} OK\r\n".encode()
text = head + b"transfer-encoding: chunked\r\n\r\n"
msg, payload = response.feed_data(text)[0][0]
assert msg.version == HttpVersion(major=1, minor=1)
assert msg.code == code
assert msg.reason == "OK"
assert msg.headers == CIMultiDict(
[
("Transfer-Encoding", "chunked"),
]
)
assert msg.raw_headers == ((b"transfer-encoding", b"chunked"),)
assert not msg.should_close
assert msg.compression is None
assert not msg.upgrade
assert msg.chunked
assert payload.is_eof()

text = (
b"HTTP/1.1 200 OK\r\n"
b"transfer-encoding: chunked\r\n\r\n"
b"6\r\nsecond\r\n0\r\n\r\n"
)
msg, payload = response.feed_data(text)[0][0]
assert msg.version == HttpVersion(major=1, minor=1)
assert msg.code == 200
assert msg.reason == "OK"
assert msg.headers == CIMultiDict(
[
("Transfer-Encoding", "chunked"),
]
)
assert msg.raw_headers == ((b"transfer-encoding", b"chunked"),)
assert not msg.should_close
assert msg.compression is None
assert not msg.upgrade
assert msg.chunked
assert payload.is_eof()
assert b"second" == b"".join(d for d in payload._buffer)


def test_partial_url(parser: Any) -> None:
messages, upgrade, tail = parser.feed_data(b"GET /te")
assert len(messages) == 0
Expand Down
Loading