Skip to content

Commit

Permalink
Retry connection (#7363) (#8038)
Browse files Browse the repository at this point in the history
Fixes #7297

(cherry picked from commit be9a3cc)
  • Loading branch information
Dreamsorcerer authored Jan 20, 2024
1 parent c465e85 commit 6e3e53c
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGES/7297.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added a feature to retry closed connections automatically for idempotent methods. -- by :user:`Dreamsorcerer`
10 changes: 10 additions & 0 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ class ClientTimeout:
# 5 Minute default read timeout
DEFAULT_TIMEOUT: Final[ClientTimeout] = ClientTimeout(total=5 * 60)

# https://www.rfc-editor.org/rfc/rfc9110#section-9.2.2
IDEMPOTENT_METHODS = frozenset({"GET", "HEAD", "OPTIONS", "TRACE", "PUT", "DELETE"})

_RetType = TypeVar("_RetType")
_CharsetResolver = Callable[[ClientResponse, bytes], str]

Expand Down Expand Up @@ -507,6 +510,8 @@ async def _request(
timer = tm.timer()
try:
with timer:
# https://www.rfc-editor.org/rfc/rfc9112.html#name-retrying-requests
retry_persistent_connection = method in IDEMPOTENT_METHODS
while True:
url, auth_from_url = strip_auth_from_url(url)
if auth and auth_from_url:
Expand Down Expand Up @@ -614,6 +619,11 @@ async def _request(
except BaseException:
conn.close()
raise
except (ClientOSError, ServerDisconnectedError):
if retry_persistent_connection:
retry_persistent_connection = False
continue
raise
except ClientError:
raise
except OSError as exc:
Expand Down
86 changes: 74 additions & 12 deletions tests/test_client_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import pathlib
import socket
import ssl
import sys
import time
from typing import Any, AsyncIterator
from unittest import mock

Expand Down Expand Up @@ -214,6 +216,67 @@ async def handler(request):
assert 0 == len(client._session.connector._conns)


async def test_keepalive_timeout_async_sleep() -> None:
async def handler(request):
body = await request.read()
assert b"" == body
return web.Response(body=b"OK")

app = web.Application()
app.router.add_route("GET", "/", handler)

runner = web.AppRunner(app, tcp_keepalive=True, keepalive_timeout=0.001)
await runner.setup()

port = unused_port()
site = web.TCPSite(runner, host="localhost", port=port)
await site.start()

try:
async with aiohttp.client.ClientSession() as sess:
resp1 = await sess.get(f"http://localhost:{port}/")
await resp1.read()
# wait for server keepalive_timeout
await asyncio.sleep(0.01)
resp2 = await sess.get(f"http://localhost:{port}/")
await resp2.read()
finally:
await asyncio.gather(runner.shutdown(), site.stop())


@pytest.mark.skipif(
sys.version_info[:2] == (3, 11),
reason="https://github.com/pytest-dev/pytest/issues/10763",
)
async def test_keepalive_timeout_sync_sleep() -> None:
async def handler(request):
body = await request.read()
assert b"" == body
return web.Response(body=b"OK")

app = web.Application()
app.router.add_route("GET", "/", handler)

runner = web.AppRunner(app, tcp_keepalive=True, keepalive_timeout=0.001)
await runner.setup()

port = unused_port()
site = web.TCPSite(runner, host="localhost", port=port)
await site.start()

try:
async with aiohttp.client.ClientSession() as sess:
resp1 = await sess.get(f"http://localhost:{port}/")
await resp1.read()
# wait for server keepalive_timeout
# time.sleep is a more challenging scenario than asyncio.sleep
time.sleep(0.01)
resp2 = await sess.get(f"http://localhost:{port}/")
await resp2.read()
finally:
await asyncio.gather(runner.shutdown(), site.stop())


async def test_release_early(aiohttp_client) -> None:
async def handler(request):
await request.read()
Expand Down Expand Up @@ -3043,21 +3106,20 @@ def connection_lost(self, exc):

addr = server.sockets[0].getsockname()

connector = aiohttp.TCPConnector(limit=1)
session = aiohttp.ClientSession(connector=connector)
async with aiohttp.TCPConnector(limit=1) as connector:
async with aiohttp.ClientSession(connector=connector) as session:
url = "http://{}:{}/".format(*addr)

url = "http://{}:{}/".format(*addr)
r = await session.request("GET", url)
await r.read()
assert 1 == len(connector._conns)
closed_conn = next(iter(connector._conns.values()))

r = await session.request("GET", url)
await r.read()
assert 1 == len(connector._conns)
await session.request("GET", url)
assert 1 == len(connector._conns)
new_conn = next(iter(connector._conns.values()))
assert closed_conn is not new_conn

with pytest.raises(aiohttp.ClientConnectionError):
await session.request("GET", url)
assert 0 == len(connector._conns)

await session.close()
await connector.close()
server.close()
await server.wait_closed()

Expand Down

0 comments on commit 6e3e53c

Please sign in to comment.