Skip to content

Commit

Permalink
Fix missing eof when writer cancelled (aio-libs#7764) (aio-libs#7781)
Browse files Browse the repository at this point in the history
Fixes aio-libs#5220.

I believe this is a better fix than aio-libs#5238. That PR detects that we
didn't finish sending a chunked response and then closes the connection.
This PR ensures that we simply complete the chunked response by sending
the EOF bytes, allowing the connection to remain open and be reused
normally.

(cherry picked from commit 9c07121)
  • Loading branch information
Dreamsorcerer authored Nov 3, 2023
1 parent cdfed8b commit 79f5266
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 82 deletions.
1 change: 1 addition & 0 deletions CHANGES/7764.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed an issue when a client request is closed before completing a chunked payload -- by :user:`Dreamsorcerer`
1 change: 1 addition & 0 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,6 +1203,7 @@ async def __aexit__(
# explicitly. Otherwise connection error handling should kick in
# and close/recycle the connection as required.
self._resp.release()
await self._resp.wait_for_close()


class _WSRequestContextManager(_BaseRequestContextManager[ClientWebSocketResponse]):
Expand Down
59 changes: 36 additions & 23 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,8 +584,11 @@ async def write_bytes(
"""Support coroutines that yields bytes objects."""
# 100 response
if self._continue is not None:
await writer.drain()
await self._continue
try:
await writer.drain()
await self._continue
except asyncio.CancelledError:
return

protocol = conn.protocol
assert protocol is not None
Expand All @@ -598,8 +601,6 @@ async def write_bytes(

for chunk in self.body:
await writer.write(chunk) # type: ignore[arg-type]

await writer.write_eof()
except OSError as exc:
if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
protocol.set_exception(exc)
Expand All @@ -610,12 +611,12 @@ async def write_bytes(
new_exc.__context__ = exc
new_exc.__cause__ = exc
protocol.set_exception(new_exc)
except asyncio.CancelledError as exc:
if not conn.closed:
protocol.set_exception(exc)
except asyncio.CancelledError:
await writer.write_eof()
except Exception as exc:
protocol.set_exception(exc)
else:
await writer.write_eof()
protocol.start_timeout()
finally:
self._writer = None
Expand Down Expand Up @@ -704,7 +705,8 @@ async def send(self, conn: "Connection") -> "ClientResponse":
async def close(self) -> None:
if self._writer is not None:
try:
await self._writer
with contextlib.suppress(asyncio.CancelledError):
await self._writer
finally:
self._writer = None

Expand Down Expand Up @@ -973,8 +975,7 @@ def _response_eof(self) -> None:
):
return

self._connection.release()
self._connection = None
self._release_connection()

self._closed = True
self._cleanup_writer()
Expand All @@ -986,30 +987,22 @@ def closed(self) -> bool:
def close(self) -> None:
if not self._released:
self._notify_content()
if self._closed:
return

self._closed = True
if self._loop is None or self._loop.is_closed():
return

if self._connection is not None:
self._connection.close()
self._connection = None
self._cleanup_writer()
self._release_connection()

def release(self) -> Any:
if not self._released:
self._notify_content()
if self._closed:
return noop()

self._closed = True
if self._connection is not None:
self._connection.release()
self._connection = None

self._cleanup_writer()
self._release_connection()
return noop()

@property
Expand All @@ -1034,10 +1027,28 @@ def raise_for_status(self) -> None:
headers=self.headers,
)

def _release_connection(self) -> None:
if self._connection is not None:
if self._writer is None:
self._connection.release()
self._connection = None
else:
self._writer.add_done_callback(lambda f: self._release_connection())

async def _wait_released(self) -> None:
if self._writer is not None:
try:
await self._writer
finally:
self._writer = None
self._release_connection()

def _cleanup_writer(self) -> None:
if self._writer is not None:
self._writer.cancel()
self._writer = None
if self._writer.done():
self._writer = None
else:
self._writer.cancel()
self._session = None

def _notify_content(self) -> None:
Expand Down Expand Up @@ -1066,9 +1077,10 @@ async def read(self) -> bytes:
except BaseException:
self.close()
raise
elif self._released:
elif self._released: # Response explicity released
raise ClientConnectionError("Connection closed")

await self._wait_released() # Underlying connection released
return self._body # type: ignore[no-any-return]

def get_encoding(self) -> str:
Expand Down Expand Up @@ -1151,3 +1163,4 @@ async def __aexit__(
# for exceptions, response object can close connection
# if state is broken
self.release()
await self.wait_for_close()
59 changes: 57 additions & 2 deletions tests/test_client_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ async def handler(request):
client = await aiohttp_client(app)
resp = await client.get("/")
assert resp.closed
await resp.wait_for_close()
assert 1 == len(client._session.connector._conns)


Expand All @@ -224,6 +225,60 @@ async def handler(request):
assert content == b""


async def test_stream_request_on_server_eof(aiohttp_client) -> None:
async def handler(request):
return web.Response(text="OK", status=200)

app = web.Application()
app.add_routes([web.get("/", handler)])
app.add_routes([web.put("/", handler)])

client = await aiohttp_client(app)

async def data_gen():
for _ in range(2):
yield b"just data"
await asyncio.sleep(0.1)

async with client.put("/", data=data_gen()) as resp:
assert 200 == resp.status
assert len(client.session.connector._acquired) == 1
conn = next(iter(client.session.connector._acquired))

async with client.get("/") as resp:
assert 200 == resp.status

# Connection should have been reused
conns = next(iter(client.session.connector._conns.values()))
assert len(conns) == 1
assert conns[0][0] is conn


async def test_stream_request_on_server_eof_nested(aiohttp_client) -> None:
async def handler(request):
return web.Response(text="OK", status=200)

app = web.Application()
app.add_routes([web.get("/", handler)])
app.add_routes([web.put("/", handler)])

client = await aiohttp_client(app)

async def data_gen():
for _ in range(2):
yield b"just data"
await asyncio.sleep(0.1)

async with client.put("/", data=data_gen()) as resp:
assert 200 == resp.status
async with client.get("/") as resp:
assert 200 == resp.status

# Should be 2 separate connections
conns = next(iter(client.session.connector._conns.values()))
assert len(conns) == 2


async def test_HTTP_304_WITH_BODY(aiohttp_client) -> None:
async def handler(request):
body = await request.read()
Expand Down Expand Up @@ -306,8 +361,8 @@ async def handler(request):
client = await aiohttp_client(app)

with io.BytesIO(data) as file_handle:
resp = await client.post("/", data=file_handle)
assert 200 == resp.status
async with client.post("/", data=file_handle) as resp:
assert 200 == resp.status


async def test_post_data_with_bytesio_file(aiohttp_client) -> None:
Expand Down
Loading

0 comments on commit 79f5266

Please sign in to comment.