Skip to content

Commit

Permalink
Remove hard coded connect handshake timeouts (#4176)
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter authored Nov 3, 2020
1 parent debe1b8 commit 4205280
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 100 deletions.
136 changes: 67 additions & 69 deletions distributed/comm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,15 @@ async def _():

async def on_connection(self, comm: Comm, handshake_overrides=None):
local_info = {**comm.handshake_info(), **(handshake_overrides or {})}

timeout = dask.config.get("distributed.comm.timeouts.connect")
timeout = parse_timedelta(timeout, default="seconds")
try:
write = await asyncio.wait_for(comm.write(local_info), 1)
handshake = await asyncio.wait_for(comm.read(), 1)
# Timeout is to ensure that we'll terminate connections eventually.
# Connector side will employ smaller timeouts and we should only
# reach this if the comm is dead anyhow.
write = await asyncio.wait_for(comm.write(local_info), timeout=timeout)
handshake = await asyncio.wait_for(comm.read(), timeout=timeout)
# This would be better, but connections leak if worker is closed quickly
# write, handshake = await asyncio.gather(comm.write(local_info), comm.read())
except Exception as e:
Expand Down Expand Up @@ -262,79 +268,71 @@ async def connect(
comm = None

start = time()
deadline = start + timeout
error = None

def _raise(error):
error = error or "connect() didn't finish in time"
msg = "Timed out trying to connect to %r after %s s: %s" % (
addr,
timeout,
error,
)
raise IOError(msg)

backoff = 0.01
if timeout and timeout / 20 < backoff:
backoff = timeout / 20
def time_left():
deadline = start + timeout
return max(0, deadline - time())

retry_timeout_backoff = random.randrange(140, 160) / 100
backoff_base = 0.01
attempt = 0

# This starts a thread
while True:
# Prefer multiple small attempts than one long attempt. This should protect
# primarily from DNS race conditions
# gh3104, gh4176, gh4167
intermediate_cap = timeout / 5
active_exception = None
while time_left() > 0:
try:
while deadline - time() > 0:

async def _():
comm = await connector.connect(
loc, deserialize=deserialize, **connection_args
)
local_info = {
**comm.handshake_info(),
**(handshake_overrides or {}),
}
try:
handshake = await asyncio.wait_for(comm.read(), 1)
write = await asyncio.wait_for(comm.write(local_info), 1)
# This would be better, but connections leak if worker is closed quickly
# write, handshake = await asyncio.gather(comm.write(local_info), comm.read())
except Exception as e:
with suppress(Exception):
await comm.close()
raise CommClosedError() from e

comm.remote_info = handshake
comm.remote_info["address"] = comm._peer_addr
comm.local_info = local_info
comm.local_info["address"] = comm._local_addr

comm.handshake_options = comm.handshake_configuration(
comm.local_info, comm.remote_info
)
return comm

with suppress(TimeoutError):
comm = await asyncio.wait_for(
_(), timeout=min(deadline - time(), retry_timeout_backoff)
)
break
if not comm:
_raise(error)
comm = await asyncio.wait_for(
connector.connect(loc, deserialize=deserialize, **connection_args),
timeout=min(intermediate_cap, time_left()),
)
break
except FatalCommClosedError:
raise
except EnvironmentError as e:
error = str(e)
if time() < deadline:
logger.debug("Could not connect, waiting before retrying")
await asyncio.sleep(backoff)
backoff *= random.randrange(140, 160) / 100
retry_timeout_backoff *= random.randrange(140, 160) / 100
backoff = min(backoff, 1) # wait at most one second
else:
_raise(error)
else:
break

# CommClosed, EnvironmentError inherit from OSError
except (TimeoutError, OSError) as exc:
active_exception = exc

# The intermediate capping is mostly relevant for the initial
# connect. Afterwards we should be more forgiving
intermediate_cap = intermediate_cap * 1.5
# FullJitter see https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/

upper_cap = min(time_left(), backoff_base * (2 ** attempt))
backoff = random.uniform(0, upper_cap)
attempt += 1
logger.debug("Could not connect, waiting for %s before retrying", backoff)
await asyncio.sleep(backoff)
else:
raise IOError(
f"Timed out trying to connect to {addr} after {timeout} s"
) from active_exception

local_info = {
**comm.handshake_info(),
**(handshake_overrides or {}),
}
try:
# This would be better, but connections leak if worker is closed quickly
# write, handshake = await asyncio.gather(comm.write(local_info), comm.read())
handshake = await asyncio.wait_for(comm.read(), time_left())
await asyncio.wait_for(comm.write(local_info), time_left())
except Exception as exc:
with suppress(Exception):
await comm.close()
raise IOError(
f"Timed out during handshake while connecting to {addr} after {timeout} s"
) from exc

comm.remote_info = handshake
comm.remote_info["address"] = comm._peer_addr
comm.local_info = local_info
comm.local_info["address"] = comm._local_addr

comm.handshake_options = comm.handshake_configuration(
comm.local_info, comm.remote_info
)
return comm


Expand Down
4 changes: 3 additions & 1 deletion distributed/comm/tcp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import errno
import logging
import socket
from ssl import SSLError
import struct
import sys
from tornado import gen
Expand Down Expand Up @@ -349,7 +350,6 @@ async def connect(self, address, deserialize=True, **connection_args):
stream = await self.client.connect(
ip, port, max_buffer_size=MAX_BUFFER_SIZE, **kwargs
)

# Under certain circumstances tornado will have a closed connnection with an error and not raise
# a StreamClosedError.
#
Expand All @@ -360,6 +360,8 @@ async def connect(self, address, deserialize=True, **connection_args):
except StreamClosedError as e:
# The socket connect() call failed
convert_stream_closed_error(self, e)
except SSLError as err:
raise FatalCommClosedError() from err

local_address = self.prefix + get_stream_address(stream)
comm = self.comm_class(
Expand Down
140 changes: 110 additions & 30 deletions distributed/comm/tests/test_comms.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,42 @@
import asyncio
import types
from functools import partial
import os
import sys
import threading
import types
import warnings
from functools import partial

import distributed
import pkg_resources
import pytest

from tornado import ioloop
from tornado.concurrent import Future

import distributed
from distributed.metrics import time
from distributed.utils import get_ip, get_ipv6
from distributed.utils_test import (
requires_ipv6,
has_ipv6,
get_cert,
get_server_ssl_context,
get_client_ssl_context,
)
from distributed.utils_test import loop # noqa: F401

from distributed.protocol import to_serialize, Serialized, serialize, deserialize

from distributed.comm.registry import backends, get_backend
from distributed.comm import (
tcp,
inproc,
CommClosedError,
connect,
get_address_host,
get_local_address_for,
inproc,
listen,
CommClosedError,
parse_address,
parse_host_port,
unparse_host_port,
resolve_address,
get_address_host,
get_local_address_for,
tcp,
unparse_host_port,
)
from distributed.comm.registry import backends, get_backend
from distributed.comm.tcp import TCP, TCPBackend, TCPConnector
from distributed.metrics import time
from distributed.protocol import Serialized, deserialize, serialize, to_serialize
from distributed.utils import get_ip, get_ipv6
from distributed.utils_test import loop # noqa: F401
from distributed.utils_test import (
get_cert,
get_client_ssl_context,
get_server_ssl_context,
has_ipv6,
requires_ipv6,
)
from tornado import ioloop
from tornado.concurrent import Future

EXTERNAL_IP4 = get_ip()
if has_ipv6():
Expand Down Expand Up @@ -218,7 +215,7 @@ async def handle_comm(comm):
await comm.write(msg)
await comm.close()

listener = await tcp.TCPListener("localhost", handle_comm)
listener = await tcp.TCPListener("127.0.0.1", handle_comm)
host, port = listener.get_host_port()
assert host in ("localhost", "127.0.0.1", "::1")
assert port > 0
Expand Down Expand Up @@ -264,7 +261,7 @@ async def handle_comm(comm):
server_ctx = get_server_ssl_context()
client_ctx = get_client_ssl_context()

listener = await tcp.TLSListener("localhost", handle_comm, ssl_context=server_ctx)
listener = await tcp.TLSListener("127.0.0.1", handle_comm, ssl_context=server_ctx)
host, port = listener.get_host_port()
assert host in ("localhost", "127.0.0.1", "::1")
assert port > 0
Expand Down Expand Up @@ -665,7 +662,8 @@ async def handle_comm(comm):

with pytest.raises(EnvironmentError) as excinfo:
await connect(listener.contact_address, timeout=2, ssl_context=cli_ctx)
assert "certificate verify failed" in str(excinfo.value)

assert "certificate verify failed" in str(excinfo.value.__cause__)


#
Expand Down Expand Up @@ -797,6 +795,88 @@ async def handle_comm(comm):
#


async def echo(comm):
message = await comm.read()
await comm.write(message)


@pytest.mark.asyncio
async def test_retry_connect(monkeypatch):
async def echo(comm):
message = await comm.read()
await comm.write(message)

class UnreliableConnector(TCPConnector):
def __init__(self):

self.num_failures = 2
self.failures = 0
super().__init__()

async def connect(self, address, deserialize=True, **connection_args):
if self.failures > self.num_failures:
return await super().connect(address, deserialize, **connection_args)
else:
self.failures += 1
raise IOError()

class UnreliableBackend(TCPBackend):
_connector_class = UnreliableConnector

monkeypatch.setitem(backends, "tcp", UnreliableBackend())

listener = await listen("tcp://127.0.0.1:1234", echo)
try:
comm = await connect(listener.contact_address)
await comm.write(b"test")
msg = await comm.read()
assert msg == b"test"
finally:
listener.stop()


@pytest.mark.asyncio
async def test_handshake_slow_comm(monkeypatch):
class SlowComm(TCP):
def __init__(self, *args, delay_in_comm=0.5, **kwargs):
super().__init__(*args, **kwargs)
self.delay_in_comm = delay_in_comm

async def read(self, *args, **kwargs):
await asyncio.sleep(self.delay_in_comm)
return await super().read(*args, **kwargs)

async def write(self, *args, **kwargs):
await asyncio.sleep(self.delay_in_comm)
res = await super(type(self), self).write(*args, **kwargs)
return res

class SlowConnector(TCPConnector):
comm_class = SlowComm

class SlowBackend(TCPBackend):
_connector_class = SlowConnector

monkeypatch.setitem(backends, "tcp", SlowBackend())

listener = await listen("tcp://127.0.0.1:1234", echo)
try:
comm = await connect(listener.contact_address)
await comm.write(b"test")
msg = await comm.read()
assert msg == b"test"

import dask

with dask.config.set({"distributed.comm.timeouts.connect": "100ms"}):
with pytest.raises(
IOError, match="Timed out during handshake while connecting to"
):
await connect(listener.contact_address)
finally:
listener.stop()


async def check_connect_timeout(addr):
t1 = time()
with pytest.raises(IOError):
Expand Down

0 comments on commit 4205280

Please sign in to comment.