Skip to content

Commit

Permalink
Remove worker reconnect (#6361)
Browse files Browse the repository at this point in the history
When a worker disconnects from the scheduler, close it immediately instead of trying to reconnect.

Also prohibit workers from joining if they have data in memory, as an alternative to #6341.

Closes #6350
  • Loading branch information
gjoseph92 authored May 20, 2022
1 parent fb3589c commit f669f06
Show file tree
Hide file tree
Showing 8 changed files with 259 additions and 537 deletions.
19 changes: 17 additions & 2 deletions distributed/cli/dask_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@
)
@click.option(
"--reconnect/--no-reconnect",
default=True,
help="Reconnect to scheduler if disconnected [default: --reconnect]",
default=None,
help="Deprecated, has no effect. Passing --reconnect is an error. [default: --no-reconnect]",
)
@click.option(
"--nanny/--no-nanny",
Expand Down Expand Up @@ -280,6 +280,7 @@ def main(
dashboard_address,
worker_class,
preload_nanny,
reconnect,
**kwargs,
):
g0, g1, g2 = gc.get_threshold() # https://github.com/dask/distributed/issues/1653
Expand All @@ -299,6 +300,20 @@ def main(
"The --bokeh/--no-bokeh flag has been renamed to --dashboard/--no-dashboard. "
)
dashboard = bokeh
if reconnect is not None:
if reconnect:
logger.error(
"The `--reconnect` option has been removed. "
"To improve cluster stability, workers now always shut down in the face of network disconnects. "
"For details, or if this is an issue for you, see https://github.com/dask/distributed/issues/6350."
)
sys.exit(1)
else:
logger.warning(
"The `--no-reconnect/--reconnect` flag is deprecated, and will be removed in a future release. "
"Worker reconnection is now always disabled, so `--no-reconnect` is unnecessary. "
"See https://github.com/dask/distributed/issues/6350 for details.",
)

sec = {
k: v
Expand Down
68 changes: 23 additions & 45 deletions distributed/cli/tests/test_dask_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from distributed import Client
from distributed.cli.dask_worker import _apportion_ports, main
from distributed.compatibility import LINUX, WINDOWS, to_thread
from distributed.compatibility import LINUX, WINDOWS
from distributed.deploy.utils import nprocesses_nthreads
from distributed.metrics import time
from distributed.utils_test import gen_cluster, popen, requires_ipv6
Expand Down Expand Up @@ -275,56 +275,34 @@ async def test_no_nanny(c, s):
await c.wait_for_workers(1)


@pytest.mark.slow
@pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"])
@gen_cluster(client=True, nthreads=[])
async def test_no_reconnect(c, s, nanny):
async def test_reconnect_deprecated(c, s):
with popen(
[
"dask-worker",
s.address,
"--no-reconnect",
nanny,
"--no-dashboard",
]
["dask-worker", s.address, "--reconnect"],
flush_output=False,
) as worker:
# roundtrip works
assert await c.submit(lambda x: x + 1, 10) == 11

(comm,) = s.stream_comms.values()
comm.abort()

# worker terminates as soon as the connection is aborted
await to_thread(worker.wait, timeout=5)
assert worker.returncode == 0

for _ in range(10):
line = worker.stdout.readline()
print(line)
if b"`--reconnect` option has been removed" in line:
break
else:
raise AssertionError("Message not printed, see stdout")
assert worker.wait() == 1

@pytest.mark.slow
@pytest.mark.parametrize("nanny", ["--nanny", "--no-nanny"])
@gen_cluster(client=True, nthreads=[])
async def test_reconnect(c, s, nanny):
with popen(
[
"dask-worker",
s.address,
"--reconnect",
nanny,
"--no-dashboard",
]
["dask-worker", s.address, "--no-reconnect"],
flush_output=False,
) as worker:
# roundtrip works
assert await c.submit(lambda x: x + 1, 10) == 11

(comm,) = s.stream_comms.values()
comm.abort()

# roundtrip still works, which means the worker reconnected
assert await c.submit(lambda x: x + 1, 11) == 12

# closing the scheduler cleanly does terminate the worker
await s.close()
await to_thread(worker.wait, timeout=5)
assert worker.returncode == 0
for _ in range(10):
line = worker.stdout.readline()
print(line)
if b"flag is deprecated, and will be removed" in line:
break
else:
raise AssertionError("Message not printed, see stdout")
await c.wait_for_workers(1)
await c.shutdown()


@pytest.mark.slow
Expand Down
1 change: 0 additions & 1 deletion distributed/nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,6 @@ async def instantiate(self) -> Status:
nanny=self.address,
name=self.name,
memory_limit=self.memory_manager.memory_limit,
reconnect=self.reconnect,
resources=self.resources,
validate=self.validate,
silence_logs=self.silence_logs,
Expand Down
84 changes: 15 additions & 69 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3456,6 +3456,7 @@ def heartbeat_worker(
address = normalize_address(address)
ws = self.workers.get(address)
if ws is None:
logger.warning(f"Received heartbeat from unregistered worker {address!r}.")
return {"status": "missing"}

host = get_address_host(address)
Expand Down Expand Up @@ -3581,6 +3582,16 @@ async def add_worker(
if address in self.workers:
raise ValueError("Worker already exists %s" % address)

if nbytes:
err = (
f"Worker {address!r} connected with {len(nbytes)} key(s) in memory! Worker reconnection is not supported. "
f"Keys: {list(nbytes)}"
)
logger.error(err)
if comm:
await comm.write({"status": "error", "message": err, "time": time()})
return

if name in self.aliases:
logger.warning("Worker tried to connect with a duplicate name: %s", name)
msg = {
Expand Down Expand Up @@ -3655,51 +3666,8 @@ async def add_worker(
except Exception as e:
logger.exception(e)

recommendations: dict = {}
client_msgs: dict = {}
worker_msgs: dict = {}
if nbytes:
assert isinstance(nbytes, dict)
already_released_keys = []
for key in nbytes:
ts: TaskState = self.tasks.get(key) # type: ignore
if ts is not None and ts.state != "released":
if ts.state == "memory":
self.add_keys(worker=address, keys=[key])
else:
t: tuple = self._transition(
key,
"memory",
stimulus_id,
worker=address,
nbytes=nbytes[key],
typename=types[key],
)
recommendations, client_msgs, worker_msgs = t
self._transitions(
recommendations, client_msgs, worker_msgs, stimulus_id
)
recommendations = {}
else:
already_released_keys.append(key)
if already_released_keys:
if address not in worker_msgs:
worker_msgs[address] = []
worker_msgs[address].append(
{
"op": "remove-replicas",
"keys": already_released_keys,
"stimulus_id": stimulus_id,
}
)

if ws.status == Status.running:
recommendations.update(self.bulk_schedule_after_adding_worker(ws))

if recommendations:
self._transitions(recommendations, client_msgs, worker_msgs, stimulus_id)

self.send_all(client_msgs, worker_msgs)
self.transitions(self.bulk_schedule_after_adding_worker(ws), stimulus_id)

logger.info("Register worker %s", ws)

Expand Down Expand Up @@ -5060,43 +5028,21 @@ async def gather(self, keys, serializers=None):
)
result = {"status": "error", "keys": missing_keys}
with log_errors():
# Remove suspicious workers from the scheduler but allow them to
# reconnect.
# Remove suspicious workers from the scheduler and shut them down.
await asyncio.gather(
*(
self.remove_worker(
address=worker, close=False, stimulus_id=stimulus_id
address=worker, close=True, stimulus_id=stimulus_id
)
for worker in missing_workers
)
)
recommendations: dict
client_msgs: dict = {}
worker_msgs: dict = {}
for key, workers in missing_keys.items():
# Task may already be gone if it was held by a
# `missing_worker`
ts: TaskState = self.tasks.get(key)
logger.exception(
"Workers don't have promised key: %s, %s",
"Shut down workers that don't have promised key: %s, %s",
str(workers),
str(key),
)
if not workers or ts is None:
continue
recommendations: dict = {key: "released"}
for worker in workers:
ws = self.workers.get(worker)
if ws is not None and ws in ts.who_has:
# FIXME: This code path is not tested
self.remove_replica(ts, ws)
self._transitions(
recommendations,
client_msgs,
worker_msgs,
stimulus_id=stimulus_id,
)
self.send_all(client_msgs, worker_msgs)

self.log_event("all", {"action": "gather", "count": len(keys)})
return result
Expand Down
11 changes: 6 additions & 5 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3590,7 +3590,7 @@ async def test_scatter_raises_if_no_workers(c, s):
await c.scatter(1, timeout=0.5)


@pytest.mark.flaky(reruns=2)
@pytest.mark.flaky(reruns=2) # due to random port
@gen_test()
async def test_reconnect():
port = random.randint(10000, 50000)
Expand All @@ -3609,11 +3609,8 @@ async def hard_stop(s):
s.stop()
await Server.close(s)

futures = []
w = Worker(f"127.0.0.1:{port}")
futures.append(asyncio.ensure_future(w.start()))

s = await Scheduler(port=port)
w = await Worker(f"127.0.0.1:{port}")
c = await Client(f"127.0.0.1:{port}", asynchronous=True)
await c.wait_for_workers(1, timeout=10)
x = c.submit(inc, 1)
Expand All @@ -3634,6 +3631,10 @@ async def hard_stop(s):
while c.status != "running":
await asyncio.sleep(0.1)
assert time() < start + 10

await w.finished()
w = await Worker(f"127.0.0.1:{port}")

start = time()
while len(await c.nthreads()) != 1:
await asyncio.sleep(0.05)
Expand Down
Loading

0 comments on commit f669f06

Please sign in to comment.