From 93dd3657a667d6367848ee0bb45bbab2ac25472d Mon Sep 17 00:00:00 2001 From: Jonas Obrist Date: Thu, 29 Aug 2019 16:38:39 +0900 Subject: [PATCH] fixed resolve_host race conditions --- aiohttp/connector.py | 11 ++++++----- tests/test_connector.py | 32 +++++++++++++++++++++++++++++++- 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 706bc99adb7..c1b0f6a90a6 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -778,24 +778,25 @@ async def _resolve_host(self, if (key in self._cached_hosts) and \ (not self._cached_hosts.expired(key)): + result = self._cached_hosts.next_addrs(key) if traces: for trace in traces: await trace.send_dns_cache_hit(host) - - return self._cached_hosts.next_addrs(key) + return result if key in self._throttle_dns_events: + event = self._throttle_dns_events[key] if traces: for trace in traces: await trace.send_dns_cache_hit(host) - await self._throttle_dns_events[key].wait() + await event.wait() else: + self._throttle_dns_events[key] = \ + EventResultOrError(self._loop) if traces: for trace in traces: await trace.send_dns_cache_miss(host) - self._throttle_dns_events[key] = \ - EventResultOrError(self._loop) try: if traces: diff --git a/tests/test_connector.py b/tests/test_connector.py index e40559d0c6f..6390c311f12 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -18,8 +18,9 @@ from aiohttp import client, web from aiohttp.client import ClientRequest, ClientTimeout from aiohttp.client_reqrep import ConnectionKey -from aiohttp.connector import Connection, _DNSCacheTable +from aiohttp.connector import Connection, _DNSCacheTable, TCPConnector from aiohttp.helpers import PY_37 +from aiohttp.locks import EventResultOrError from aiohttp.test_utils import make_mocked_coro, unused_port from aiohttp.tracing import Trace from conftest import needs_unix @@ -2257,3 +2258,32 @@ def test_next_addrs_single(self, dns_cache_table) -> None: addrs = dns_cache_table.next_addrs('foo') assert addrs == ['127.0.0.1'] + + +async def test_connector_cache_trace_race(): + class DummyTracer: + async def send_dns_cache_hit(self, *args, **kwargs): + connector._cached_hosts.remove(("", 0)) + + token = object() + connector = TCPConnector() + connector._cached_hosts.add(("", 0), [token]) + + traces = [DummyTracer()] + assert await connector._resolve_host("", 0, traces) == [token] + + +async def test_connector_throttle_trace_race(loop): + key = ("", 0) + token = object() + + class DummyTracer: + async def send_dns_cache_hit(self, *args, **kwargs): + event = connector._throttle_dns_events.pop(key) + event.set() + connector._cached_hosts.add(key, [token]) + + connector = TCPConnector() + connector._throttle_dns_events[key] = EventResultOrError(loop) + traces = [DummyTracer()] + assert await connector._resolve_host("", 0, traces) == [token]