Ensure steal requests from same-IP but distinct workers are rejected #6585

Merged · 2 commits · Jun 22, 2022
15 changes: 14 additions & 1 deletion distributed/core.py
@@ -189,6 +189,7 @@ def __init__(
         self._address = None
         self._listen_address = None
         self._port = None
+        self._host = None
         self._comms = {}
         self.deserialize = deserialize
         self.monitor = SystemMonitor()
@@ -438,6 +439,18 @@ def listen_address(self):
             self._listen_address = self.listener.listen_address
         return self._listen_address
 
+    @property
+    def host(self):
+        """
+        The host this Server is running on.
+
+        This will raise ValueError if the Server is listening on a
+        non-IP based protocol.
+        """
+        if not self._host:
+            self._host, self._port = get_address_host_port(self.address)
+        return self._host
+
     @property
     def port(self):
         """
@@ -447,7 +460,7 @@ def port(self):
         non-IP based protocol.
         """
         if not self._port:
-            _, self._port = get_address_host_port(self.address)
+            self._host, self._port = get_address_host_port(self.address)
         return self._port
 
     def identity(self) -> dict[str, str]:
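
The new host property mirrors the existing port property: both lazily parse Server.address via get_address_host_port and cache the resulting pair, so a single parse fills both fields. A minimal standalone sketch of the caching pattern, assuming get_address_host_port is importable from distributed.comm as it is in core.py (this is not the actual Server class):

    # Sketch of the lazy host/port caching pattern used above.
    # get_address_host_port raises ValueError for non-IP protocols
    # such as inproc://.
    from distributed.comm import get_address_host_port


    class AddressCache:
        def __init__(self, address: str):
            self.address = address
            self._host = None
            self._port = None

        @property
        def host(self) -> str:
            if not self._host:
                # One parse fills both cached fields
                self._host, self._port = get_address_host_port(self.address)
            return self._host

        @property
        def port(self) -> int:
            if not self._port:
                self._host, self._port = get_address_host_port(self.address)
            return self._port


    assert AddressCache("tcp://127.0.0.1:8786").host == "127.0.0.1"
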
14 changes: 10 additions & 4 deletions distributed/scheduler.py
@@ -453,6 +453,9 @@ class WorkerState:
     #: Arbitrary additional metadata to be added to :meth:`~WorkerState.identity`
     extra: dict[str, Any]
 
+    #: The unique server ID this WorkerState is referencing
+    server_id: str
+
     __slots__ = tuple(__annotations__)
 
     def __init__(
@@ -466,10 +469,12 @@ def __init__(
         memory_limit: int,
         local_directory: str,
         nanny: str,
+        server_id: str,
         services: dict[str, int] | None = None,
         versions: dict[str, Any] | None = None,
         extra: dict[str, Any] | None = None,
     ):
+        self.server_id = server_id
         self.address = address
         self.pid = pid
         self.name = name
@@ -480,7 +485,7 @@ def __init__(
         self.versions = versions or {}
         self.nanny = nanny
         self.status = status
-        self._hash = hash((address, pid, name))
+        self._hash = hash(self.server_id)
         self.nbytes = 0
         self.occupancy = 0
         self._memory_unmanaged_old = 0
@@ -502,9 +507,7 @@ def __hash__(self) -> int:
         return self._hash
 
     def __eq__(self, other: object) -> bool:
-        if not isinstance(other, WorkerState):
-            return False
-        return hash(self) == hash(other)
+        return isinstance(other, WorkerState) and other.server_id == self.server_id

Collaborator:
#6593 instead makes this explicit and explains the reasoning behind the decision. I think we should agree on one approach or the other and be consistent.

def __hash__(self) -> int:
    # TODO explain
    return id(self)

def __eq__(self, other: object) -> bool:
    # TODO explain
    return other is self

Collaborator:
#6593 adds to worker_state_machine.WorkerState.validate_state a check that you can never have two instances of a TaskState with the same key in any of its sets. I think scheduler.SchedulerState.validate_state should have the same logic.

Member (author):
I'm not entirely convinced this is necessary. There are many such collections, and this particular problem was not even in a collection on the scheduler itself but rather in an extension.
Specifically, without https://github.com/dask/distributed/pull/6585/files#r899011537 this condition would not even hold, and I don't think we should implement the full remove_worker cleanup in the stealing extension.
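
For reference, the kind of invariant check being discussed could look roughly like this (a hypothetical sketch; the actual #6593 implementation may differ):

    def validate_unique_task_states(task_sets) -> None:
        # Hypothetical invariant: across all tracked sets, a key must always
        # map to the same TaskState instance, never to a duplicate object.
        seen = {}
        for ts_set in task_sets:
            for ts in ts_set:
                other = seen.setdefault(ts.key, ts)
                assert other is ts, f"duplicate TaskState for key {ts.key!r}"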

     @property
     def has_what(self) -> Set[TaskState]:
@@ -548,6 +551,7 @@ def clean(self) -> WorkerState:
             services=self.services,
             nanny=self.nanny,
             extra=self.extra,
+            server_id=self.server_id,
         )
         ws.processing = {
             ts.key: cost for ts, cost in self.processing.items()  # type: ignore
@@ -3576,6 +3580,7 @@ async def add_worker(
         *,
         address: str,
         status: str,
+        server_id: str,
         keys=(),
         nthreads=None,
         name=None,
@@ -3639,6 +3644,7 @@ async def add_worker(
             versions=versions,
             nanny=nanny,
             extra=extra,
+            server_id=server_id,
         )
         if ws.status == Status.running:
             self.running.add(ws)
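
Keying both __hash__ and __eq__ on server_id means two WorkerState objects describing the same address no longer compare equal once the underlying server process has been replaced. A self-contained sketch of the semantics, using a stripped-down stand-in rather than the real class:

    from uuid import uuid4


    class WS:  # stand-in for scheduler.WorkerState
        def __init__(self, address: str):
            self.address = address
            self.server_id = str(uuid4())

        def __hash__(self) -> int:
            return hash(self.server_id)

        def __eq__(self, other: object) -> bool:
            return isinstance(other, WS) and other.server_id == self.server_id


    old = WS("tcp://10.0.0.5:40000")  # original worker process
    new = WS("tcp://10.0.0.5:40000")  # restarted process, same IP and port

    assert old != new              # distinct server_id -> distinct identity
    assert hash(old) != hash(new)  # mirrors the assertions in the new test
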
6 changes: 3 additions & 3 deletions distributed/stealing.py
@@ -148,7 +148,7 @@ def log(self, msg):
     def add_worker(self, scheduler=None, worker=None):
         self.stealable[worker] = tuple(set() for _ in range(15))
 
-    def remove_worker(self, scheduler=None, worker=None):
+    def remove_worker(self, scheduler: Scheduler, worker: str) -> None:
         del self.stealable[worker]
 
     def teardown(self):
@@ -310,7 +310,6 @@ async def move_task_confirm(self, *, key, state, stimulus_id, worker=None):
 
         thief = d["thief"]
         victim = d["victim"]
-
         logger.debug("Confirm move %s, %s -> %s. State: %s", key, victim, thief, state)
 
         self.in_flight_occupancy[thief] -= d["thief_duration"]
@@ -331,8 +330,9 @@ async def move_task_confirm(self, *, key, state, stimulus_id, worker=None):
             self.scheduler._reevaluate_occupancy_worker(victim)
         elif (
             state in _WORKER_STATE_UNDEFINED
+            # If our steal information is somehow stale we need to reschedule
             or state in _WORKER_STATE_CONFIRM
-            and thief.address not in self.scheduler.workers
+            and thief != self.scheduler.workers.get(thief.address)
         ):
             self.log(
                 (
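
The rewritten guard is subtle: self.scheduler.workers.get(thief.address) returns the current WorkerState registered for that address (or None), so the comparison rejects a stale steal both when the thief's address has left the cluster and when the address is now occupied by a different server process. A short self-contained sketch of the two cases:

    class WS:
        # Minimal stand-in: identity is the server_id, not the address
        def __init__(self, address: str, server_id: str):
            self.address = address
            self.server_id = server_id

        def __eq__(self, other: object) -> bool:
            return isinstance(other, WS) and other.server_id == self.server_id

        def __hash__(self) -> int:
            return hash(self.server_id)


    workers: dict[str, WS] = {}
    thief = WS("tcp://10.0.0.5:40000", server_id="a")

    # Case 1: thief departed entirely -> .get() returns None, comparison fails
    assert thief != workers.get(thief.address)

    # Case 2: address reused by a new process -> same key, different server_id
    workers[thief.address] = WS("tcp://10.0.0.5:40000", server_id="b")
    assert thief != workers.get(thief.address)

The old check, thief.address not in self.scheduler.workers, would have missed the second case and confirmed the steal against the wrong worker.
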
71 changes: 71 additions & 0 deletions distributed/tests/test_steal.py
@@ -21,6 +21,7 @@
 from distributed.system import MEMORY_LIMIT
 from distributed.utils_test import (
     captured_logger,
+    freeze_batched_send,
     gen_cluster,
     inc,
     nodebug_setup_module,
@@ -1129,6 +1130,76 @@ async def test_get_story(c, s, a, b):
     assert all(isinstance(m, tuple) for m in msgs)


+@gen_cluster(
+    client=True,
+    config={
+        "distributed.scheduler.work-stealing-interval": 1_000_000,
+    },
+)
+async def test_steal_worker_dies_same_ip(c, s, w0, w1):
+    # https://github.com/dask/distributed/issues/5370
+    steal = s.extensions["stealing"]
+    ev = Event()
+    futs1 = c.map(
+        lambda _, ev: ev.wait(),
+        range(10),
+        ev=ev,
+        key=[f"f1-{ix}" for ix in range(10)],
+        workers=[w0.address],
+        allow_other_workers=True,
+    )
+    while not w0.active_keys:
+        await asyncio.sleep(0.01)
+
+    victim_key = list(w0.state.ready)[-1][1]
+
+    victim_ts = s.tasks[victim_key]
+
+    wsA = victim_ts.processing_on
+    assert wsA.address == w0.address
+    wsB = s.workers[w1.address]
+
+    steal.move_task_request(victim_ts, wsA, wsB)
+    len_before = len(s.events["stealing"])
+    # Block the batched stream of w0 to ensure the steal-confirmation doesn't
+    # arrive at the scheduler before we want it to
+    with freeze_batched_send(w0.batched_stream):
+        while not any(
+            isinstance(event, StealRequestEvent) for event in w0.state.stimulus_log
+        ):
+            await asyncio.sleep(0.1)
+        async with contextlib.AsyncExitStack() as stack:
+            # Kill worker wsB ...
+            await w1.close()
+            while w1.address in s.workers:
+                await asyncio.sleep(0.1)
+
+            # ... and restart a new worker with the same IP, port, and name
+            w_new = await stack.enter_async_context(
+                Worker(
+                    s.address,
+                    host=w1.host,
+                    port=w1.port,
+                    name=w1.name,
+                )
+            )
+            wsB2 = s.workers[w_new.address]
+            assert wsB2.address == wsB.address
+            assert wsB2 is not wsB
+            assert wsB2 != wsB
+            assert hash(wsB2) != hash(wsB)
+
+    # Wait for the steal response to arrive once the stream unfreezes
+    while len_before == len(s.events["stealing"]):
+        await asyncio.sleep(0.1)
+
+    assert victim_ts.processing_on != wsB
+
+    await w_new.close(executor_wait=False)
+    await ev.set()
+    await c.gather(futs1)

 @gen_cluster(
     client=True,
     nthreads=[("", 1)] * 3,
1 change: 1 addition & 0 deletions distributed/worker.py
@@ -1094,6 +1094,7 @@ async def _register_with_scheduler(self) -> None:
                     metrics=await self.get_metrics(),
                     extra=await self.get_startup_information(),
                     stimulus_id=f"worker-connect-{time()}",
+                    server_id=self.id,
                 ),
                 serializers=["msgpack"],
             )
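
On the worker side the change is a single field: the worker's own Server.id now travels with the registration message, giving the scheduler a process-unique identity to build the WorkerState around. A rough, abbreviated sketch of the payload (field values are illustrative, and the real handshake carries many more fields):

    # Hypothetical, heavily abbreviated registration message
    msg = dict(
        op="register-worker",
        address="tcp://10.0.0.5:40000",
        status="running",
        # Worker.id: unique per server instance, unlike the address,
        # which can be reused by a restarted process
        server_id="worker-5f3a9c1e",
    )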