From 9e1c95a75fafb773143ab42e59dbec269157a61b Mon Sep 17 00:00:00 2001 From: Will Date: Thu, 15 Oct 2020 14:04:14 +0100 Subject: [PATCH] Fix for race condition on connections in BaseConnector (#4937) (cherry picked from commit ad00c2e44b97e4e69827a8166d299d5ab171f7b9) --- CHANGES/4936.bugfix | 1 + CONTRIBUTORS.txt | 1 + aiohttp/connector.py | 5 +++-- tests/test_connector.py | 3 ++- 4 files changed, 7 insertions(+), 3 deletions(-) create mode 100644 CHANGES/4936.bugfix diff --git a/CHANGES/4936.bugfix b/CHANGES/4936.bugfix new file mode 100644 index 00000000000..b3a0c6d8e80 --- /dev/null +++ b/CHANGES/4936.bugfix @@ -0,0 +1 @@ +Fix for race condition on connections in BaseConnector that leads to exceeding the connection limit. diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 4c85d255b72..79db9e68b3e 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -269,6 +269,7 @@ Weiwei Wang Will McGugan Willem de Groot William Grzybowski +William S. Wilson Ong Yang Zhou Yannick Koechlin diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 3efea213b2f..1a4d12d6b1e 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -474,8 +474,9 @@ async def connect(self, req: 'ClientRequest', key = req.connection_key available = self._available_connections(key) - # Wait if there are no available connections. - if available <= 0: + # Wait if there are no available connections or if there are/were + # waiters (i.e. don't steal connection from a waiter about to wake up) + if available <= 0 or key in self._waiters: fut = self._loop.create_future() # This connection will now count towards the limit. diff --git a/tests/test_connector.py b/tests/test_connector.py index d854890dd4d..b3522f17f81 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -1713,7 +1713,7 @@ async def create_connection(req, traces, timeout): # with multiple concurrent requests and stops when it hits a # predefined maximum number of requests. - max_requests = 10 + max_requests = 50 num_requests = 0 start_requests = max_connections + 1 @@ -1726,6 +1726,7 @@ async def f(start=True): connection = await conn.connect(req, None, ClientTimeout()) await asyncio.sleep(0) connection.release() + await asyncio.sleep(0) tasks = [ loop.create_task(f(start=False)) for i in range(start_requests)