Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add validation of HTTP status line, header keys and values #5452

1 change: 1 addition & 0 deletions CHANGES/4818.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add validation of HTTP header keys and values to prevent header injection.
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ Felix Yan
Fernanda Guimarães
FichteFoll
Florian Scheffler
Franek Magiera
Frederik Gladhorn
Frederik Peter Aalund
Gabriel Tremblay
Expand Down
12 changes: 12 additions & 0 deletions aiohttp/_http_writer.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,14 @@ cdef str to_str(object s):
return str(s)


cdef void _safe_header(str string) except *:
if "\r" in string or "\n" in string:
raise ValueError(
"Newline or carriage return character detected in HTTP status message or "
"header. This is a potential security issue."
)


def _serialize_headers(str status_line, headers):
cdef Writer writer
cdef object key
Expand All @@ -119,6 +127,10 @@ def _serialize_headers(str status_line, headers):

_init_writer(&writer)

for key, val in headers.items():
_safe_header(to_str(key))
_safe_header(to_str(val))

try:
if _write_str(&writer, status_line) < 0:
raise
Expand Down
18 changes: 12 additions & 6 deletions aiohttp/http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,19 @@ async def drain(self) -> None:
await self._protocol._drain_helper()


def _safe_header(string: str) -> str:
if "\r" in string or "\n" in string:
raise ValueError(
"Newline or carriage return detected in headers. "
"Potential header injection attack."
)
return string


def _py_serialize_headers(status_line: str, headers: "CIMultiDict[str]") -> bytes:
line = (
status_line
+ "\r\n"
+ "".join([k + ": " + v + "\r\n" for k, v in headers.items()])
)
return line.encode("utf-8") + b"\r\n"
headers_gen = (_safe_header(k) + ": " + _safe_header(v) for k, v in headers.items())
line = status_line + "\r\n" + "\r\n".join(headers_gen) + "\r\n\r\n"
return line.encode("utf-8")


_serialize_headers = _py_serialize_headers
Expand Down
14 changes: 14 additions & 0 deletions tests/test_http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from unittest import mock

import pytest
from multidict import CIMultiDict

from aiohttp import http
from aiohttp.test_utils import make_mocked_coro
Expand Down Expand Up @@ -272,3 +273,16 @@ async def test_drain_no_transport(protocol: Any, transport: Any, loop: Any) -> N
msg._protocol.transport = None
await msg.drain()
assert not protocol._drain_helper.called


async def test_write_headers_prevents_injection(
protocol: Any, transport: Any, loop: Any
) -> None:
msg = http.StreamWriter(protocol, loop)
status_line = "HTTP/1.1 200 OK"
wrong_headers = CIMultiDict({"Set-Cookie: abc=123\r\nContent-Length": "256"})
with pytest.raises(ValueError):
await msg.write_headers(status_line, wrong_headers)
wrong_headers = CIMultiDict({"Content-Length": "256\r\nSet-Cookie: abc=123"})
with pytest.raises(ValueError):
await msg.write_headers(status_line, wrong_headers)
8 changes: 2 additions & 6 deletions tests/test_web_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from aiohttp import HttpVersion, HttpVersion10, HttpVersion11, hdrs
from aiohttp.helpers import ETag
from aiohttp.http_writer import _serialize_headers
from aiohttp.payload import BytesPayload
from aiohttp.test_utils import make_mocked_coro, make_mocked_request
from aiohttp.web import ContentCoding, Response, StreamResponse, json_response
Expand Down Expand Up @@ -59,12 +60,7 @@ def write(chunk):
buf.extend(chunk)

async def write_headers(status_line, headers):
headers = (
status_line
+ "\r\n"
+ "".join([k + ": " + v + "\r\n" for k, v in headers.items()])
)
headers = headers.encode("utf-8") + b"\r\n"
headers = _serialize_headers(status_line, headers)
buf.extend(headers)

async def write_eof(chunk=b""):
Expand Down