diff --git a/CHANGES/7764.bugfix b/CHANGES/7764.bugfix new file mode 100644 index 00000000000..6e4c7aa5ba8 --- /dev/null +++ b/CHANGES/7764.bugfix @@ -0,0 +1 @@ +Fixed an issue when a client request is closed before completing a chunked payload -- by :user:`Dreamsorcerer` diff --git a/aiohttp/client.py b/aiohttp/client.py index de33ce29407..93e9b34c6ab 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -1203,6 +1203,7 @@ async def __aexit__( # explicitly. Otherwise connection error handling should kick in # and close/recycle the connection as required. self._resp.release() + await self._resp.wait_for_close() class _WSRequestContextManager(_BaseRequestContextManager[ClientWebSocketResponse]): diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index e35ddc01a3d..6aa45aacd41 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -584,8 +584,11 @@ async def write_bytes( """Support coroutines that yields bytes objects.""" # 100 response if self._continue is not None: - await writer.drain() - await self._continue + try: + await writer.drain() + await self._continue + except asyncio.CancelledError: + return protocol = conn.protocol assert protocol is not None @@ -598,8 +601,6 @@ async def write_bytes( for chunk in self.body: await writer.write(chunk) # type: ignore[arg-type] - - await writer.write_eof() except OSError as exc: if exc.errno is None and isinstance(exc, asyncio.TimeoutError): protocol.set_exception(exc) @@ -610,12 +611,12 @@ async def write_bytes( new_exc.__context__ = exc new_exc.__cause__ = exc protocol.set_exception(new_exc) - except asyncio.CancelledError as exc: - if not conn.closed: - protocol.set_exception(exc) + except asyncio.CancelledError: + await writer.write_eof() except Exception as exc: protocol.set_exception(exc) else: + await writer.write_eof() protocol.start_timeout() finally: self._writer = None @@ -704,7 +705,8 @@ async def send(self, conn: "Connection") -> "ClientResponse": async def close(self) -> None: if self._writer is not None: try: - await self._writer + with contextlib.suppress(asyncio.CancelledError): + await self._writer finally: self._writer = None @@ -973,8 +975,7 @@ def _response_eof(self) -> None: ): return - self._connection.release() - self._connection = None + self._release_connection() self._closed = True self._cleanup_writer() @@ -986,30 +987,22 @@ def closed(self) -> bool: def close(self) -> None: if not self._released: self._notify_content() - if self._closed: - return self._closed = True if self._loop is None or self._loop.is_closed(): return - if self._connection is not None: - self._connection.close() - self._connection = None self._cleanup_writer() + self._release_connection() def release(self) -> Any: if not self._released: self._notify_content() - if self._closed: - return noop() self._closed = True - if self._connection is not None: - self._connection.release() - self._connection = None self._cleanup_writer() + self._release_connection() return noop() @property @@ -1034,10 +1027,28 @@ def raise_for_status(self) -> None: headers=self.headers, ) + def _release_connection(self) -> None: + if self._connection is not None: + if self._writer is None: + self._connection.release() + self._connection = None + else: + self._writer.add_done_callback(lambda f: self._release_connection()) + + async def _wait_released(self) -> None: + if self._writer is not None: + try: + await self._writer + finally: + self._writer = None + self._release_connection() + def _cleanup_writer(self) -> None: if self._writer is not None: - self._writer.cancel() - self._writer = None + if self._writer.done(): + self._writer = None + else: + self._writer.cancel() self._session = None def _notify_content(self) -> None: @@ -1066,9 +1077,10 @@ async def read(self) -> bytes: except BaseException: self.close() raise - elif self._released: + elif self._released: # Response explicity released raise ClientConnectionError("Connection closed") + await self._wait_released() # Underlying connection released return self._body # type: ignore[no-any-return] def get_encoding(self) -> str: @@ -1151,3 +1163,4 @@ async def __aexit__( # for exceptions, response object can close connection # if state is broken self.release() + await self.wait_for_close() diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 7063615a942..06fa3c265ec 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -205,6 +205,7 @@ async def handler(request): client = await aiohttp_client(app) resp = await client.get("/") assert resp.closed + await resp.wait_for_close() assert 1 == len(client._session.connector._conns) @@ -224,6 +225,60 @@ async def handler(request): assert content == b"" +async def test_stream_request_on_server_eof(aiohttp_client) -> None: + async def handler(request): + return web.Response(text="OK", status=200) + + app = web.Application() + app.add_routes([web.get("/", handler)]) + app.add_routes([web.put("/", handler)]) + + client = await aiohttp_client(app) + + async def data_gen(): + for _ in range(2): + yield b"just data" + await asyncio.sleep(0.1) + + async with client.put("/", data=data_gen()) as resp: + assert 200 == resp.status + assert len(client.session.connector._acquired) == 1 + conn = next(iter(client.session.connector._acquired)) + + async with client.get("/") as resp: + assert 200 == resp.status + + # Connection should have been reused + conns = next(iter(client.session.connector._conns.values())) + assert len(conns) == 1 + assert conns[0][0] is conn + + +async def test_stream_request_on_server_eof_nested(aiohttp_client) -> None: + async def handler(request): + return web.Response(text="OK", status=200) + + app = web.Application() + app.add_routes([web.get("/", handler)]) + app.add_routes([web.put("/", handler)]) + + client = await aiohttp_client(app) + + async def data_gen(): + for _ in range(2): + yield b"just data" + await asyncio.sleep(0.1) + + async with client.put("/", data=data_gen()) as resp: + assert 200 == resp.status + async with client.get("/") as resp: + assert 200 == resp.status + + # Should be 2 separate connections + conns = next(iter(client.session.connector._conns.values())) + assert len(conns) == 2 + + async def test_HTTP_304_WITH_BODY(aiohttp_client) -> None: async def handler(request): body = await request.read() @@ -306,8 +361,8 @@ async def handler(request): client = await aiohttp_client(app) with io.BytesIO(data) as file_handle: - resp = await client.post("/", data=file_handle) - assert 200 == resp.status + async with client.post("/", data=file_handle) as resp: + assert 200 == resp.status async def test_post_data_with_bytesio_file(aiohttp_client) -> None: diff --git a/tests/test_client_response.py b/tests/test_client_response.py index 7d07f13e46c..74027fcaf76 100644 --- a/tests/test_client_response.py +++ b/tests/test_client_response.py @@ -15,6 +15,14 @@ from aiohttp.test_utils import make_mocked_coro +class WriterMock(mock.AsyncMock): + def __await__(self) -> None: + return self().__await__() + + def done(self) -> bool: + return True + + @pytest.fixture def session(): return mock.Mock() @@ -27,7 +35,7 @@ async def test_http_processing_error(session) -> None: "get", URL("http://del-cl-resp.org"), request_info=request_info, - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -55,7 +63,7 @@ def test_del(session) -> None: "get", URL("http://del-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -82,7 +90,7 @@ def test_close(loop, session) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -103,7 +111,7 @@ def test_wait_for_100_1(loop, session) -> None: URL("http://python.org"), continue100=object(), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), timer=TimerNoop(), traces=[], loop=loop, @@ -119,7 +127,7 @@ def test_wait_for_100_2(loop, session) -> None: URL("http://python.org"), request_info=mock.Mock(), continue100=None, - writer=mock.Mock(), + writer=WriterMock(), timer=TimerNoop(), traces=[], loop=loop, @@ -134,7 +142,7 @@ def test_repr(loop, session) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -151,7 +159,7 @@ def test_repr_non_ascii_url() -> None: "get", URL("http://fake-host.org/\u03bb"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -166,7 +174,7 @@ def test_repr_non_ascii_reason() -> None: "get", URL("http://fake-host.org/path"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -200,7 +208,7 @@ async def test_read_and_release_connection(loop, session) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -226,7 +234,7 @@ async def test_read_and_release_connection_with_error(loop, session) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -247,7 +255,7 @@ async def test_release(loop, session) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -276,7 +284,7 @@ def run(conn): "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -296,7 +304,7 @@ async def test_response_eof(loop, session) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=None, continue100=None, timer=TimerNoop(), traces=[], @@ -317,7 +325,7 @@ async def test_response_eof_upgraded(loop, session) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -338,7 +346,7 @@ async def test_response_eof_after_connection_detach(loop, session) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=None, continue100=None, timer=TimerNoop(), traces=[], @@ -359,7 +367,7 @@ async def test_text(loop, session) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -386,7 +394,7 @@ async def test_text_bad_encoding(loop, session) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -416,7 +424,7 @@ async def test_text_custom_encoding(loop, session) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -447,7 +455,7 @@ async def test_text_charset_resolver(content_type: str, loop, session) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -476,7 +484,7 @@ async def test_get_encoding_body_none(loop, session) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -506,7 +514,7 @@ async def test_text_after_read(loop, session) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -533,7 +541,7 @@ async def test_json(loop, session) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -560,7 +568,7 @@ async def test_json_extended_content_type(loop, session) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -589,7 +597,7 @@ async def test_json_custom_content_type(loop, session) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -616,7 +624,7 @@ async def test_json_custom_loader(loop, session) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -638,7 +646,7 @@ async def test_json_invalid_content_type(loop, session) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -659,7 +667,7 @@ async def test_json_no_content(loop, session) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -678,7 +686,7 @@ async def test_json_override_encoding(loop, session) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -707,7 +715,7 @@ def test_get_encoding_unknown(loop, session) -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -724,7 +732,7 @@ def test_raise_for_status_2xx() -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -741,7 +749,7 @@ def test_raise_for_status_4xx() -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -762,7 +770,7 @@ def test_raise_for_status_4xx_without_reason() -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -783,7 +791,7 @@ def test_resp_host() -> None: "get", URL("http://del-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -798,7 +806,7 @@ def test_content_type() -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -815,7 +823,7 @@ def test_content_type_no_header() -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -832,7 +840,7 @@ def test_charset() -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -849,7 +857,7 @@ def test_charset_no_header() -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -866,7 +874,7 @@ def test_charset_no_charset() -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -883,7 +891,7 @@ def test_content_disposition_full() -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -906,7 +914,7 @@ def test_content_disposition_no_parameters() -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -925,7 +933,7 @@ def test_content_disposition_no_header() -> None: "get", URL("http://def-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -944,7 +952,7 @@ def test_response_request_info() -> None: "get", URL(url), request_info=RequestInfo(url, "get", headers), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -963,7 +971,7 @@ def test_request_info_in_exception() -> None: "get", URL(url), request_info=RequestInfo(url, "get", headers), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -984,7 +992,7 @@ def test_no_redirect_history_in_exception() -> None: "get", URL(url), request_info=RequestInfo(url, "get", headers), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -1007,7 +1015,7 @@ def test_redirect_history_in_exception() -> None: "get", URL(url), request_info=RequestInfo(url, "get", headers), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -1021,7 +1029,7 @@ def test_redirect_history_in_exception() -> None: "get", URL(hist_url), request_info=RequestInfo(url, "get", headers), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -1050,7 +1058,7 @@ async def test_response_read_triggers_callback(loop, session) -> None: response_method, response_url, request_info=mock.Mock, - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), loop=loop, @@ -1083,7 +1091,7 @@ def test_response_real_url(loop, session) -> None: "get", url, request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -1100,7 +1108,7 @@ def test_response_links_comma_separated(loop, session) -> None: "get", url, request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -1130,7 +1138,7 @@ def test_response_links_multiple_headers(loop, session) -> None: "get", url, request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -1155,7 +1163,7 @@ def test_response_links_no_rel(loop, session) -> None: "get", url, request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -1174,7 +1182,7 @@ def test_response_links_quoted(loop, session) -> None: "get", url, request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -1197,7 +1205,7 @@ def test_response_links_relative(loop, session) -> None: "get", url, request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -1220,7 +1228,7 @@ def test_response_links_empty(loop, session) -> None: "get", url, request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], @@ -1236,7 +1244,7 @@ def test_response_not_closed_after_get_ok(mocker) -> None: "get", URL("http://del-cl-resp.org"), request_info=mock.Mock(), - writer=mock.Mock(), + writer=WriterMock(), continue100=None, timer=TimerNoop(), traces=[], diff --git a/tests/test_proxy_functional.py b/tests/test_proxy_functional.py index 475b6a13e26..61e30841cc1 100644 --- a/tests/test_proxy_functional.py +++ b/tests/test_proxy_functional.py @@ -398,7 +398,8 @@ async def test_proxy_http_acquired_cleanup(proxy_test_server, loop) -> None: assert 0 == len(conn._acquired) - resp = await sess.get(url, proxy=proxy.url) + async with sess.get(url, proxy=proxy.url) as resp: + pass assert resp.closed assert 0 == len(conn._acquired) diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index f93bb5d9e9c..28d97d9694c 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -1889,7 +1889,6 @@ async def handler(request): resp = await session.get(server.make_url("/")) async with resp: assert resp.status == 200 - assert resp.connection is None assert resp.connection is None await session.close()