Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid creating body writer task when there is no body #9757

Merged
merged 16 commits into from
Nov 10, 2024
1 change: 1 addition & 0 deletions CHANGES/9757.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improved performance of sending HTTP requests when there is no body -- by :user:`bdraco`.
4 changes: 4 additions & 0 deletions aiohttp/base_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def connected(self) -> bool:
"""Return True if the connection is open."""
return self.transport is not None

@property
def writing_paused(self) -> bool:
return self._paused

def pause_writing(self) -> None:
assert not self._paused
self._paused = True
Expand Down
32 changes: 19 additions & 13 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,22 +677,28 @@ async def send(self, conn: "Connection") -> "ClientResponse":
v = self.version
status_line = f"{self.method} {path} HTTP/{v.major}.{v.minor}"
await writer.write_headers(status_line, self.headers)
coro = self.write_bytes(writer, conn)

task: Optional["asyncio.Task[None]"]
if sys.version_info >= (3, 12):
# Optimization for Python 3.12, try to write
# bytes immediately to avoid having to schedule
# the task on the event loop.
task = asyncio.Task(coro, loop=self.loop, eager_start=True)
if self.body or self._continue is not None or protocol.writing_paused:
coro = self.write_bytes(writer, conn)
if sys.version_info >= (3, 12):
# Optimization for Python 3.12, try to write
# bytes immediately to avoid having to schedule
# the task on the event loop.
task = asyncio.Task(coro, loop=self.loop, eager_start=True)
else:
task = self.loop.create_task(coro)
if task.done():
task = None
else:
self._writer = task
else:
task = self.loop.create_task(coro)

if task.done():
# We have nothing to write because
# - there is no body
# - the protocol does not have writing paused
# - we are not waiting for a 100-continue response
protocol.start_timeout()
writer.set_eof()
task = None
else:
self._writer = task

response_class = self.response_class
assert response_class is not None
self.response = response_class(
Expand Down
4 changes: 4 additions & 0 deletions aiohttp/http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ async def write_headers(
buf = _serialize_headers(status_line, headers)
self._write(buf)

def set_eof(self) -> None:
"""Indicate that the message is complete."""
self._eof = True

async def write_eof(self, chunk: bytes = b"") -> None:
if self._eof:
return
Expand Down
2 changes: 2 additions & 0 deletions tests/test_base_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ async def test_pause_writing() -> None:
loop = asyncio.get_event_loop()
pr = BaseProtocol(loop)
assert not pr._paused
assert pr.writing_paused is False
pr.pause_writing()
assert pr._paused
assert pr.writing_paused is True # type: ignore[unreachable]


async def test_pause_reading_no_transport() -> None:
Expand Down
4 changes: 4 additions & 0 deletions tests/test_benchmarks_client_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ class MockProtocol(asyncio.BaseProtocol):
def __init__(self) -> None:
self.transport = MockTransport()

@property
def writing_paused(self) -> bool:
return False

async def _drain_helper(self) -> None:
"""Swallow drain."""

Expand Down
25 changes: 20 additions & 5 deletions tests/test_client_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1540,10 +1540,7 @@ async def handler(request: web.Request) -> web.Response:
assert 200 == resp.status


@pytest.mark.parametrize("data", (None, b""))
async def test_GET_DEFLATE(
aiohttp_client: AiohttpClient, data: Optional[bytes]
) -> None:
async def test_GET_DEFLATE(aiohttp_client: AiohttpClient) -> None:
async def handler(request: web.Request) -> web.Response:
return web.json_response({"ok": True})

Expand All @@ -1566,7 +1563,7 @@ async def write_bytes(
app.router.add_get("/", handler)
client = await aiohttp_client(app)

async with client.get("/", data=data, compress=True) as resp:
async with client.get("/", data=b"", compress=True) as resp:
assert resp.status == 200
content = await resp.json()
assert content == {"ok": True}
Expand All @@ -1576,6 +1573,24 @@ async def write_bytes(
write_mock.assert_not_called()


async def test_GET_DEFLATE_no_body(aiohttp_client: AiohttpClient) -> None:
async def handler(request: web.Request) -> web.Response:
return web.json_response({"ok": True})

with mock.patch.object(ClientRequest, "write_bytes") as mock_write_bytes:
app = web.Application()
app.router.add_get("/", handler)
client = await aiohttp_client(app)

async with client.get("/", data=None, compress=True) as resp:
assert resp.status == 200
content = await resp.json()
assert content == {"ok": True}

# No chunks should have been sent for an empty body.
mock_write_bytes.assert_not_called()


async def test_POST_DATA_DEFLATE(aiohttp_client: AiohttpClient) -> None:
async def handler(request: web.Request) -> web.Response:
data = await request.post()
Expand Down
7 changes: 2 additions & 5 deletions tests/test_client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,7 @@ def to_url(path: str) -> URL:
assert to_trace_urls(on_request_redirect) == []
assert to_trace_urls(on_request_end) == [to_url("/?x=0")]
assert to_trace_urls(on_request_exception) == []
assert to_trace_urls(on_request_chunk_sent) == [to_url("/?x=0")]
bdraco marked this conversation as resolved.
Show resolved Hide resolved
assert to_trace_urls(on_request_chunk_sent) == []
assert to_trace_urls(on_response_chunk_received) == [to_url("/?x=0")]
assert to_trace_urls(on_request_headers_sent) == [to_url("/?x=0")]

Expand All @@ -934,10 +934,7 @@ def to_url(path: str) -> URL:
assert to_trace_urls(on_request_redirect) == [to_url("/redirect?x=0")]
assert to_trace_urls(on_request_end) == [to_url("/")]
assert to_trace_urls(on_request_exception) == []
assert to_trace_urls(on_request_chunk_sent) == [
to_url("/redirect?x=0"),
to_url("/"),
]
assert to_trace_urls(on_request_chunk_sent) == []
assert to_trace_urls(on_response_chunk_received) == [to_url("/")]
assert to_trace_urls(on_request_headers_sent) == [
to_url("/redirect?x=0"),
Expand Down
16 changes: 16 additions & 0 deletions tests/test_http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,3 +360,19 @@ async def test_write_headers_prevents_injection(
wrong_headers = CIMultiDict({"Content-Length": "256\r\nSet-Cookie: abc=123"})
with pytest.raises(ValueError):
await msg.write_headers(status_line, wrong_headers)


async def test_set_eof_after_write_headers(
protocol: BaseProtocol,
transport: mock.Mock,
loop: asyncio.AbstractEventLoop,
) -> None:
msg = http.StreamWriter(protocol, loop)
status_line = "HTTP/1.1 200 OK"
good_headers = CIMultiDict({"Set-Cookie": "abc=123"})
await msg.write_headers(status_line, good_headers)
assert transport.write.called
transport.write.reset_mock()
msg.set_eof()
await msg.write_eof()
assert not transport.write.called
Loading