From fa090b6408d3675406156b5b4cd5fb5832fcc3dc Mon Sep 17 00:00:00 2001 From: fjetter Date: Wed, 22 Jun 2022 10:49:05 +0200 Subject: [PATCH 1/2] Ensure work stealing verifies movement confirmation belongs to active worker --- distributed/core.py | 15 ++++++- distributed/scheduler.py | 10 ++++- distributed/stealing.py | 6 +-- distributed/tests/test_steal.py | 71 +++++++++++++++++++++++++++++++++ distributed/worker.py | 1 + 5 files changed, 98 insertions(+), 5 deletions(-) diff --git a/distributed/core.py b/distributed/core.py index f2abb45d05a..85f72f49fb5 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -189,6 +189,7 @@ def __init__( self._address = None self._listen_address = None self._port = None + self._host = None self._comms = {} self.deserialize = deserialize self.monitor = SystemMonitor() @@ -438,6 +439,18 @@ def listen_address(self): self._listen_address = self.listener.listen_address return self._listen_address + @property + def host(self): + """ + The host this Server is running on. + + This will raise ValueError if the Server is listening on a + non-IP based protocol. + """ + if not self._host: + self._host, self._port = get_address_host_port(self.address) + return self._host + @property def port(self): """ @@ -447,7 +460,7 @@ def port(self): non-IP based protocol. """ if not self._port: - _, self._port = get_address_host_port(self.address) + self._host, self._port = get_address_host_port(self.address) return self._port def identity(self) -> dict[str, str]: diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 34f30938a51..c44b317216e 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -453,6 +453,9 @@ class WorkerState: #: Arbitrary additional metadata to be added to :meth:`~WorkerState.identity` extra: dict[str, Any] + # The unique server ID this WorkerState is referencing + server_id: str + __slots__ = tuple(__annotations__) def __init__( @@ -466,10 +469,12 @@ def __init__( memory_limit: int, local_directory: str, nanny: str, + server_id: str, services: dict[str, int] | None = None, versions: dict[str, Any] | None = None, extra: dict[str, Any] | None = None, ): + self.server_id = server_id self.address = address self.pid = pid self.name = name @@ -480,7 +485,7 @@ def __init__( self.versions = versions or {} self.nanny = nanny self.status = status - self._hash = hash((address, pid, name)) + self._hash = hash(self.server_id) self.nbytes = 0 self.occupancy = 0 self._memory_unmanaged_old = 0 @@ -548,6 +553,7 @@ def clean(self) -> WorkerState: services=self.services, nanny=self.nanny, extra=self.extra, + server_id=self.server_id, ) ws.processing = { ts.key: cost for ts, cost in self.processing.items() # type: ignore @@ -3576,6 +3582,7 @@ async def add_worker( *, address: str, status: str, + server_id: str, keys=(), nthreads=None, name=None, @@ -3639,6 +3646,7 @@ async def add_worker( versions=versions, nanny=nanny, extra=extra, + server_id=server_id, ) if ws.status == Status.running: self.running.add(ws) diff --git a/distributed/stealing.py b/distributed/stealing.py index 26e2c639c68..e3d5d256786 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -148,7 +148,7 @@ def log(self, msg): def add_worker(self, scheduler=None, worker=None): self.stealable[worker] = tuple(set() for _ in range(15)) - def remove_worker(self, scheduler=None, worker=None): + def remove_worker(self, scheduler: Scheduler, worker: str) -> None: del self.stealable[worker] def teardown(self): @@ -310,7 +310,6 @@ async def move_task_confirm(self, *, key, state, stimulus_id, worker=None): thief = d["thief"] victim = d["victim"] - logger.debug("Confirm move %s, %s -> %s. State: %s", key, victim, thief, state) self.in_flight_occupancy[thief] -= d["thief_duration"] @@ -331,8 +330,9 @@ async def move_task_confirm(self, *, key, state, stimulus_id, worker=None): self.scheduler._reevaluate_occupancy_worker(victim) elif ( state in _WORKER_STATE_UNDEFINED + # If our steal information is somehow stale we need to reschedule or state in _WORKER_STATE_CONFIRM - and thief.address not in self.scheduler.workers + and thief != self.scheduler.workers.get(thief.address) ): self.log( ( diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index ade4be652ef..62b2f40899a 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -21,6 +21,7 @@ from distributed.system import MEMORY_LIMIT from distributed.utils_test import ( captured_logger, + freeze_batched_send, gen_cluster, inc, nodebug_setup_module, @@ -1129,6 +1130,76 @@ async def test_get_story(c, s, a, b): assert all(isinstance(m, tuple) for m in msgs) +@gen_cluster( + client=True, + config={ + "distributed.scheduler.work-stealing-interval": 1_000_000, + }, +) +async def test_steal_worker_dies_same_ip(c, s, w0, w1): + # https://github.com/dask/distributed/issues/5370 + steal = s.extensions["stealing"] + ev = Event() + futs1 = c.map( + lambda _, ev: ev.wait(), + range(10), + ev=ev, + key=[f"f1-{ix}" for ix in range(10)], + workers=[w0.address], + allow_other_workers=True, + ) + while not w0.active_keys: + await asyncio.sleep(0.01) + + victim_key = list(w0.state.ready)[-1][1] + + victim_ts = s.tasks[victim_key] + + wsA = victim_ts.processing_on + assert wsA.address == w0.address + wsB = s.workers[w1.address] + + steal.move_task_request(victim_ts, wsA, wsB) + len_before = len(s.events["stealing"]) + with freeze_batched_send(w0.batched_stream): + while not any( + isinstance(event, StealRequestEvent) for event in w0.state.stimulus_log + ): + await asyncio.sleep(0.1) + async with contextlib.AsyncExitStack() as stack: + # Block batched stream of w0 to ensure the steal-confirmation doesn't + # arrive at the scheduler before we want it to + await w1.close() + # Kill worker wsB + # Restart new worker with same IP, name, etc. + while w1.address in s.workers: + await asyncio.sleep(0.1) + + w_new = await stack.enter_async_context( + Worker( + s.address, + host=w1.host, + port=w1.port, + name=w1.name, + ) + ) + wsB2 = s.workers[w_new.address] + assert wsB2.address == wsB.address + assert wsB2 is not wsB + assert wsB2 != wsB + assert hash(wsB2) != hash(wsB) + + # Wait for the steal response to arrive + while len_before == len(s.events["stealing"]): + await asyncio.sleep(0.1) + + assert victim_ts.processing_on != wsB + + await w_new.close(executor_wait=False) + await ev.set() + await c.gather(futs1) + + @gen_cluster( client=True, nthreads=[("", 1)] * 3, diff --git a/distributed/worker.py b/distributed/worker.py index c3754d9fbd0..89a059d976b 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -1094,6 +1094,7 @@ async def _register_with_scheduler(self) -> None: metrics=await self.get_metrics(), extra=await self.get_startup_information(), stimulus_id=f"worker-connect-{time()}", + server_id=self.id, ), serializers=["msgpack"], ) From b20ada66d354428e4489d4258cf3c63c229ef7c4 Mon Sep 17 00:00:00 2001 From: fjetter Date: Wed, 22 Jun 2022 12:55:58 +0200 Subject: [PATCH 2/2] tweak WorkerState.__eq__ --- distributed/scheduler.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/distributed/scheduler.py b/distributed/scheduler.py index c44b317216e..3af733923d5 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -507,9 +507,7 @@ def __hash__(self) -> int: return self._hash def __eq__(self, other: object) -> bool: - if not isinstance(other, WorkerState): - return False - return hash(self) == hash(other) + return isinstance(other, WorkerState) and other.server_id == self.server_id @property def has_what(self) -> Set[TaskState]: