Skip to content

Commit

Permalink
Abort closed ssl client transports, broken servers can keep socket op…
Browse files Browse the repository at this point in the history
…en unlimit time #1568
  • Loading branch information
Nikolay Kim committed Feb 7, 2017
1 parent e651668 commit 5a4e7ab
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 23 deletions.
4 changes: 3 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,13 @@ CHANGES

- Return 504 if request handle raises TimeoutError.

- Refactor how we use keep-alive and clone lingering timeouts.
- Refactor how we use keep-alive and close lingering timeouts.

- Close response connection if we can not consume whole http
message during client response release

- Abort closed ssl client transports, broken servers can keep socket open unlimit time #1568

- Log warning instead of `RuntimeError` is websocket connection is closed.

- Deprecated: `aiohttp.protocol.HttpPrefixParser`
Expand Down
51 changes: 49 additions & 2 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,19 @@ class BaseConnector(object):
force_close - Set to True to force close and do reconnect
after each request (and between redirects).
limit - The limit of simultaneous connections to the same endpoint.
disable_cleanup_closed - Disable clean-up closed ssl transports.
loop - Optional event loop.
"""

_closed = True # prevent AttributeError in __del__ if ctor was failed
_source_traceback = None

# abort transport after 2 seconds (cleanup broken connections)
_cleanup_closed_period = 2.0

def __init__(self, *, conn_timeout=None, keepalive_timeout=sentinel,
force_close=False, limit=20, time_service=None, loop=None):
force_close=False, limit=20, time_service=None,
disable_cleanup_closed=False, loop=None):

if force_close:
if keepalive_timeout is not None and \
Expand Down Expand Up @@ -151,11 +156,18 @@ def __init__(self, *, conn_timeout=None, keepalive_timeout=sentinel,

self.cookies = SimpleCookie()

# start keep-alive connection cleanup task
self._cleanup_handle = None
if (keepalive_timeout is not sentinel and
keepalive_timeout is not None):
self._cleanup()

# start cleanup closed transports task
self._cleanup_closed_handle = None
self._cleanup_closed_disabled = disable_cleanup_closed
self._cleanup_closed_transports = []
self._cleanup_closed()

def __del__(self, _warnings=warnings):
if self._closed:
return
Expand Down Expand Up @@ -219,6 +231,10 @@ def _cleanup(self):
if proto.is_connected():
if use_time - deadline < 0:
transport.close()
if (key[-1] and
not self._cleanup_closed_disabled):
self._cleanup_closed_transports.append(
transport)
else:
alive.append((transport, proto, use_time))

Expand All @@ -230,6 +246,22 @@ def _cleanup(self):
self._cleanup_handle = self._time_service.call_later(
self._keepalive_timeout / 2.0, self._cleanup)

def _cleanup_closed(self):
"""Double confirmation for transport close.
Some broken ssl servers may leave socket open without proper close.
"""
if self._cleanup_closed_handle:
self._cleanup_closed_handle.close()

for transport in self._cleanup_closed_transports:
transport.abort()

self._cleanup_closed_transports = []

if not self._cleanup_closed_disabled:
self._cleanup_closed_handle = self._time_service.call_later(
self._cleanup_closed_period, self._cleanup_closed)

def close(self):
"""Close all opened transports."""
ret = helpers.create_future(self._loop)
Expand All @@ -252,13 +284,24 @@ def close(self):
for transport in chain(*self._acquired.values()):
transport.close()

# cacnel cleanup task
if self._cleanup_handle:
self._cleanup_handle.cancel()

# cacnel cleanup close task
if self._cleanup_closed_handle:
self._cleanup_closed_handle.cancel()

for transport in self._cleanup_closed_transports:
transport.abort()

finally:
self._conns.clear()
self._acquired.clear()
self._cleanup_handle = None
self._cleanup_closed_transports.clear()
self._cleanup_closed_handle = None

return ret

@property
Expand Down Expand Up @@ -335,7 +378,8 @@ def _get(self, key):
if transport is not None and proto.is_connected():
if t1 - t0 > self._keepalive_timeout:
transport.close()
transport = None
if key[-1] and not self._cleanup_closed_disabled:
self._cleanup_closed_transports.append(transport)
else:
if not conns:
# The very last connection was reclaimed: drop the key
Expand Down Expand Up @@ -391,6 +435,9 @@ def _release(self, key, req, transport, protocol, *, should_close=False):
reader = protocol.reader
if should_close or (reader.output and not reader.output.at_eof()):
transport.close()

if key[-1] and not self._cleanup_closed_disabled:
self._cleanup_closed_transports.append(transport)
else:
conns = self._conns.get(key)
if conns is None:
Expand Down
116 changes: 96 additions & 20 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@ def test_del(loop):
@pytest.mark.xfail
@asyncio.coroutine
def test_del_with_scheduled_cleanup(loop):
loop.set_debug(True)
conn = aiohttp.BaseConnector(loop=loop, keepalive_timeout=0.01)
transp = unittest.mock.Mock()
conn._conns['a'] = [(transp, 'proto', 123)]

conns_impl = conn._conns
conn._start_cleanup_task()
exc_handler = unittest.mock.Mock()
loop.set_exception_handler(exc_handler)

Expand Down Expand Up @@ -112,7 +112,8 @@ def test_create_conn(loop):

def test_ctor_loop():
with unittest.mock.patch('aiohttp.connector.asyncio') as m_asyncio:
session = aiohttp.BaseConnector()
session = aiohttp.BaseConnector(time_service=unittest.mock.Mock())

assert session._loop is m_asyncio.get_event_loop.return_value


Expand All @@ -121,7 +122,7 @@ def test_close(loop):

conn = aiohttp.BaseConnector(loop=loop)
assert not conn.closed
conn._conns[1] = [(tr, object(), object())]
conn._conns[('host', 8080, False)] = [(tr, object(), object())]
conn.close()

assert not conn._conns
Expand Down Expand Up @@ -171,12 +172,24 @@ def test_get(loop):

def test_get_expired(loop):
conn = aiohttp.BaseConnector(loop=loop)
assert conn._get(1) == (None, None)
assert conn._get(('localhost', 80, False)) == (None, None)

tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
conn._conns[1] = [(tr, proto, loop.time() - 1000)]
assert conn._get(1) == (None, None)
conn._conns[('localhost', 80, False)] = [(tr, proto, loop.time() - 1000)]
assert conn._get(('localhost', 80, False)) == (None, None)
assert not conn._conns
conn.close()


def test_get_expired_ssl(loop):
conn = aiohttp.BaseConnector(loop=loop)
assert conn._get(('localhost', 80, True)) == (None, None)

tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
conn._conns[('localhost', 80, True)] = [(tr, proto, loop.time() - 1000)]
assert conn._get(('localhost', 80, True)) == (None, None)
assert not conn._conns
assert conn._cleanup_closed_transports == [tr]
conn.close()


Expand Down Expand Up @@ -215,7 +228,6 @@ def test_release(loop):
loop.time = mock.Mock(return_value=10)

conn = aiohttp.BaseConnector(loop=loop)
conn._start_cleanup_task = unittest.mock.Mock()
conn._release_waiter = unittest.mock.Mock()
req = unittest.mock.Mock()
resp = req.response = unittest.mock.Mock()
Expand All @@ -227,6 +239,24 @@ def test_release(loop):
conn._release(key, req, tr, proto)
assert conn._release_waiter.called
assert conn._conns[1][0] == (tr, proto, 10)
assert not conn._cleanup_closed_transports
conn.close()


def test_release_ssl_transport(loop):
loop.time = mock.Mock(return_value=10)

conn = aiohttp.BaseConnector(loop=loop)
conn._release_waiter = unittest.mock.Mock()
req = unittest.mock.Mock()
resp = req.response = unittest.mock.Mock()
resp._should_close = True

tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
key = ('localhost', 80, True)
conn._acquired[key].add(tr)
conn._release(key, req, tr, proto, should_close=True)
assert conn._cleanup_closed_transports == [tr]
conn.close()


Expand All @@ -238,14 +268,12 @@ def test_release_already_closed(loop):
conn._acquired[key].add(tr)
conn.close()

conn._start_cleanup_task = unittest.mock.Mock()
conn._release_waiter = unittest.mock.Mock()
conn._release_acquired = unittest.mock.Mock()
req = unittest.mock.Mock()

conn._release(key, req, tr, proto)
assert not conn._release_waiter.called
assert not conn._start_cleanup_task.called
assert not conn._release_acquired.called


Expand Down Expand Up @@ -289,7 +317,7 @@ def test_release_close(loop):
req.response = resp

tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
key = 1
key = ('localhost', 80, False)
conn._acquired[key].add(tr)
conn._release(key, req, tr, proto)
assert not conn._conns
Expand Down Expand Up @@ -394,7 +422,7 @@ def test_release_not_opened(loop):
req.response.message = None

tr, proto = unittest.mock.Mock(), unittest.mock.Mock()
key = 1
key = ('localhost', 80, False)
conn._acquired[key].add(tr)
conn._release(key, req, tr, proto)
assert tr.close.called
Expand Down Expand Up @@ -461,13 +489,14 @@ def test_ctor_cleanup():


def test_cleanup():
key = ('localhost', 80, False)
testset = {
1: [(unittest.mock.Mock(), unittest.mock.Mock(), 10),
(unittest.mock.Mock(), unittest.mock.Mock(), 300),
(None, unittest.mock.Mock(), 300)],
key: [(unittest.mock.Mock(), unittest.mock.Mock(), 10),
(unittest.mock.Mock(), unittest.mock.Mock(), 300),
(None, unittest.mock.Mock(), 300)],
}
testset[1][0][1].is_connected.return_value = True
testset[1][1][1].is_connected.return_value = False
testset[key][0][1].is_connected.return_value = True
testset[key][1][1].is_connected.return_value = False

loop = unittest.mock.Mock()
time_service = unittest.mock.Mock()
Expand Down Expand Up @@ -502,9 +531,10 @@ def test_cleanup2():


def test_cleanup3():
testset = {1: [(unittest.mock.Mock(), unittest.mock.Mock(), 290.1),
(unittest.mock.Mock(), unittest.mock.Mock(), 305.1)]}
testset[1][0][1].is_connected.return_value = True
key = ('localhost', 80, False)
testset = {key: [(unittest.mock.Mock(), unittest.mock.Mock(), 290.1),
(unittest.mock.Mock(), unittest.mock.Mock(), 305.1)]}
testset[key][0][1].is_connected.return_value = True

loop = unittest.mock.Mock()
time_service = unittest.mock.Mock()
Expand All @@ -515,13 +545,40 @@ def test_cleanup3():
conn._conns = testset

conn._cleanup()
assert conn._conns == {1: [testset[1][1]]}
assert conn._conns == {key: [testset[key][1]]}

assert conn._cleanup_handle is not None
time_service.call_later.assert_called_with(5, conn._cleanup)
conn.close()


def test_cleanup_closed(loop):
ts = unittest.mock.Mock()
conn = aiohttp.BaseConnector(loop=loop, time_service=ts)

ts = conn._time_service = unittest.mock.Mock()
tr = unittest.mock.Mock()
conn._cleanup_closed_transports = [tr]
conn._cleanup_closed()
assert tr.abort.called
assert not conn._cleanup_closed_transports
assert ts.call_later.called


def test_cleanup_closed_disabled(loop):
ts = unittest.mock.Mock()
conn = aiohttp.BaseConnector(
loop=loop, time_service=ts, disable_cleanup_closed=True)

ts = conn._time_service = unittest.mock.Mock()
tr = unittest.mock.Mock()
conn._cleanup_closed_transports = [tr]
conn._cleanup_closed()
assert tr.abort.called
assert not conn._cleanup_closed_transports
assert not ts.call_later.called


def test_tcp_connector_ctor(loop):
conn = aiohttp.TCPConnector(loop=loop)
assert conn.verify_ssl
Expand Down Expand Up @@ -626,6 +683,25 @@ def test_close_cancels_cleanup_handle(loop):
assert conn._cleanup_handle is None


def test_close_abort_closed_transports(loop):
tr = unittest.mock.Mock()

conn = aiohttp.BaseConnector(loop=loop)
conn._cleanup_closed_transports.append(tr)
conn.close()

assert not conn._cleanup_closed_transports
assert tr.abort.called
assert conn.closed


def test_close_cancels_cleanup_closed_handle(loop):
conn = aiohttp.BaseConnector(loop=loop)
assert conn._cleanup_closed_handle is not None
conn.close()
assert conn._cleanup_closed_handle is None


def test_ctor_with_default_loop():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
Expand Down

0 comments on commit 5a4e7ab

Please sign in to comment.