Skip to content

Commit

Permalink
Refactor ensure_communicating
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Apr 27, 2022
1 parent 9bad573 commit 738f7c6
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 69 deletions.
10 changes: 1 addition & 9 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1412,21 +1412,13 @@ def assert_amm_transfer_story(key: str, w_from: Worker, w_to: Worker) -> None:
assert_story(
w_to.story(key),
[
(key, "ensure-task-exists", "released"),
(key, "released", "fetch", "fetch", {}),
("gather-dependencies", w_from.address, lambda set_: key in set_),
(key, "fetch", "flight", "flight", {}),
("request-dep", w_from.address, lambda set_: key in set_),
("receive-dep", w_from.address, lambda set_: key in set_),
(key, "put-in-memory"),
(key, "flight", "memory", "memory", {}),
],
# There may be additional ('missing', 'fetch', 'fetch') events if transfers
# are slow enough that the Active Memory Manager ends up requesting them a
# second time. Here we're asserting that no matter how slow CI is, all
# transfers will be completed within 2 seconds (hardcoded interval in
# Scheduler.retire_worker when AMM is not enabled).
strict=True,
strict=False,
)
assert key in w_to.data
# The key may or may not still be in w_from.data, depending if the AMM had the
Expand Down
142 changes: 94 additions & 48 deletions distributed/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from pickle import PicklingError
from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast

from tlz import first, keymap, merge, pluck # noqa: F401
from tlz import first, keymap, merge, peekn, pluck # noqa: F401
from tornado.ioloop import IOLoop, PeriodicCallback

import dask
Expand Down Expand Up @@ -114,6 +114,8 @@
Execute,
ExecuteFailureEvent,
ExecuteSuccessEvent,
GatherDep,
GatherDepDoneEvent,
Instructions,
InvalidTransition,
LongRunningMsg,
Expand Down Expand Up @@ -1201,7 +1203,7 @@ async def heartbeat(self):

async def handle_scheduler(self, comm):
try:
await self.handle_stream(comm, every_cycle=[self.ensure_communicating])
await self.handle_stream(comm)
except Exception as e:
logger.exception(e)
raise
Expand Down Expand Up @@ -1937,6 +1939,12 @@ def handle_compute_task(
for key, value in nbytes.items():
self.tasks[key].nbytes = value

def _add_to_data_needed(self, ts: TaskState) -> RecsInstrs:
self.data_needed.push(ts)
for w in ts.who_has:
self.data_needed_per_worker[w].push(ts)
return self._ensure_communicating()

def transition_missing_fetch(
self, ts: TaskState, *, stimulus_id: str
) -> RecsInstrs:
Expand All @@ -1947,10 +1955,7 @@ def transition_missing_fetch(
self._missing_dep_flight.discard(ts)
ts.state = "fetch"
ts.done = False
self.data_needed.push(ts)
for w in ts.who_has:
self.data_needed_per_worker[w].push(ts)
return {}, []
return self._add_to_data_needed(ts)

def transition_missing_released(
self, ts: TaskState, *, stimulus_id: str
Expand Down Expand Up @@ -1987,10 +1992,7 @@ def transition_released_fetch(
assert ts.priority is not None
ts.state = "fetch"
ts.done = False
self.data_needed.push(ts)
for w in ts.who_has:
self.data_needed_per_worker[w].push(ts)
return {}, []
return self._add_to_data_needed(ts)

def transition_generic_released(
self, ts: TaskState, *, stimulus_id: str
Expand Down Expand Up @@ -2426,17 +2428,13 @@ def transition_flight_fetch(self, ts: TaskState, *, stimulus_id: str) -> RecsIns
if not ts.done:
return {}, []

recommendations: Recs = {}
ts.state = "fetch"
ts.coming_from = None
ts.done = False
if not ts.who_has:
recommendations[ts] = "missing"
if ts.who_has:
return self._add_to_data_needed(ts)
else:
self.data_needed.push(ts)
for w in ts.who_has:
self.data_needed_per_worker[w].push(ts)
return recommendations, []
return {ts: "missing"}, []

def transition_flight_error(
self,
Expand Down Expand Up @@ -2699,18 +2697,35 @@ def _handle_instructions(self, instructions: Instructions) -> None:
# TODO this method is temporary.
# See final design: https://github.com/dask/distributed/issues/5894
for inst in instructions:
task = None
if isinstance(inst, SendMessageToScheduler):
self.batched_stream.send(inst.to_dict())
elif isinstance(inst, Execute):
task = asyncio.create_task(
self.execute(inst.key, stimulus_id=inst.stimulus_id),
name=f"execute({inst.key})",
)
self._async_instructions.add(task)
task.add_done_callback(self._handle_stimulus_from_task)
elif isinstance(inst, GatherDep):
assert inst.to_gather
keys_str = ", ".join(peekn(27, inst.to_gather)[0])
if len(keys_str) > 80:
keys_str = keys_str[:77] + "..."
task = asyncio.create_task(
self.gather_dep(
inst.worker,
inst.to_gather,
total_nbytes=inst.total_nbytes,
stimulus_id=inst.stimulus_id,
),
name=f"gather_dep({inst.worker}, {{{keys_str}}})",
)
else:
raise TypeError(inst) # pragma: nocover

if task is not None:
self._async_instructions.add(task)
task.add_done_callback(self._handle_stimulus_from_task)

def maybe_transition_long_running(
self, ts: TaskState, *, compute_duration: float, stimulus_id: str
):
Expand Down Expand Up @@ -2747,13 +2762,17 @@ def stimulus_story(
keys = {e.key if isinstance(e, TaskState) else e for e in keys_or_tasks}
return [ev for ev in self.stimulus_log if getattr(ev, "key", None) in keys]

def ensure_communicating(self) -> None:
def _ensure_communicating(self) -> RecsInstrs:
if self.status != Status.running:
return
return {}, []

stimulus_id = f"ensure-communicating-{time()}"
skipped_worker_in_flight_or_busy = []

recommendations: Recs = {}
instructions: Instructions = []
all_keys_to_gather: set[str] = set()

while self.data_needed and (
len(self.in_flight_workers) < self.total_out_connections
or self.comm_nbytes < self.comm_threshold_bytes
Expand All @@ -2768,7 +2787,7 @@ def ensure_communicating(self) -> None:

ts = self.data_needed.pop()

if ts.state != "fetch":
if ts.state != "fetch" or ts.key in all_keys_to_gather:
continue

if self.validate:
Expand All @@ -2788,30 +2807,41 @@ def ensure_communicating(self) -> None:
local = [w for w in workers if get_address_host(w) == host]
worker = random.choice(local or workers)

to_gather, total_nbytes = self.select_keys_for_gather(worker, ts.key)
to_gather, total_nbytes = self._select_keys_for_gather(
worker, ts.key, all_keys_to_gather
)
all_keys_to_gather |= to_gather

self.log.append(
("gather-dependencies", worker, to_gather, stimulus_id, time())
)

self.comm_nbytes += total_nbytes
self.in_flight_workers[worker] = to_gather
recommendations: Recs = {
self.tasks[d]: ("flight", worker) for d in to_gather
}
self.transitions(recommendations, stimulus_id=stimulus_id)
for d_key in to_gather:
d_ts = self.tasks[d_key]
if self.validate:
assert d_ts.state == "fetch"
assert d_ts not in recommendations
recommendations[d_ts] = ("flight", worker)

self.loop.add_callback(
self.gather_dep,
worker=worker,
to_gather=to_gather,
total_nbytes=total_nbytes,
stimulus_id=stimulus_id,
# Note: given n tasks that must be fetched from the same worker, this method
# may generate anywhere between 1 and n GatherDep instructions, as multiple
# tasks may be clustered in the same instruction by _select_keys_for_gather
instructions.append(
GatherDep(
worker=worker,
to_gather=to_gather,
total_nbytes=total_nbytes,
stimulus_id=stimulus_id,
)
)

for ts in skipped_worker_in_flight_or_busy:
self.data_needed.push(ts)

return recommendations, instructions

def _get_task_finished_msg(
self, ts: TaskState, stimulus_id: str
) -> TaskFinishedMsg:
Expand Down Expand Up @@ -2896,16 +2926,22 @@ def _put_key_in_memory(self, ts: TaskState, value, *, stimulus_id: str) -> Recs:
self.log.append((ts.key, "put-in-memory", stimulus_id, time()))
return recommendations

def select_keys_for_gather(self, worker, dep):
assert isinstance(dep, str)
def _select_keys_for_gather(
self, worker: str, dep: str, all_keys_to_gather: Container[str]
) -> tuple[set[str], int]:
"""``_ensure_communicating`` decided to fetch a single task from a worker,
following priority. In order to minimise overhead, request fetching other tasks
from the same worker within the message, following priority for the single
worker but ignoring higher priority tasks from other workers, up to
``target_message_size``.
"""
deps = {dep}

total_bytes = self.tasks[dep].get_nbytes()
tasks = self.data_needed_per_worker[worker]

while tasks:
ts = tasks.peek()
if ts.state != "fetch":
if ts.state != "fetch" or ts.key in all_keys_to_gather:
tasks.pop()
continue
if total_bytes + ts.get_nbytes() > self.target_message_size:
Expand Down Expand Up @@ -3031,7 +3067,7 @@ async def gather_dep(
total_nbytes: int,
*,
stimulus_id: str,
) -> None:
) -> StateMachineEvent | None:
"""Gather dependencies for a task from a worker who has them
Parameters
Expand All @@ -3046,7 +3082,7 @@ async def gather_dep(
Total number of bytes for all the dependencies in to_gather combined
"""
if self.status not in Status.ANY_RUNNING: # type: ignore
return
return None

recommendations: Recs = {}
response = {}
Expand All @@ -3061,7 +3097,7 @@ async def gather_dep(
self.log.append(
("nothing-to-gather", worker, to_gather, stimulus_id, time())
)
return
return GatherDepDoneEvent(stimulus_id=stimulus_id)

assert cause
# Keep namespace clean since this func is long and has many
Expand All @@ -3084,7 +3120,7 @@ async def gather_dep(
)
stop = time()
if response["status"] == "busy":
return
return GatherDepDoneEvent(stimulus_id=stimulus_id)

self._update_metrics_received_data(
start=start,
Expand All @@ -3096,6 +3132,7 @@ async def gather_dep(
self.log.append(
("receive-dep", worker, set(response["data"]), stimulus_id, time())
)
return GatherDepDoneEvent(stimulus_id=stimulus_id)

except OSError:
logger.exception("Worker stream died during communication: %s", worker)
Expand All @@ -3112,6 +3149,8 @@ async def gather_dep(
self.log.append(
("missing-who-has", worker, ts.key, stimulus_id, time())
)
return GatherDepDoneEvent(stimulus_id=stimulus_id)

except Exception as e:
logger.exception(e)
if self.batched_stream and LOG_PDB:
Expand All @@ -3122,7 +3161,8 @@ async def gather_dep(
for k in self.in_flight_workers[worker]:
ts = self.tasks[k]
recommendations[ts] = tuple(msg.values())
raise
return GatherDepDoneEvent(stimulus_id=stimulus_id)

finally:
self.comm_nbytes -= total_nbytes
busy = response.get("status", "") == "busy"
Expand Down Expand Up @@ -3180,12 +3220,12 @@ async def gather_dep(
)
self.update_who_has(who_has)

self.ensure_communicating()

@log_errors
def _readd_busy_worker(self, worker: str) -> None:
self.busy_workers.remove(worker)
self.ensure_communicating()
self.handle_stimulus(
GatherDepDoneEvent(stimulus_id=f"readd-busy-worker-{time()}")
)

@log_errors
async def find_missing(self) -> None:
Expand Down Expand Up @@ -3214,7 +3254,6 @@ async def find_missing(self) -> None:
self.periodic_callbacks[
"find-missing"
].callback_time = self.periodic_callbacks["heartbeat"].callback_time
self.ensure_communicating()

def update_who_has(self, who_has: dict[str, Collection[str]]) -> None:
try:
Expand Down Expand Up @@ -3686,8 +3725,15 @@ def _(self, ev: UnpauseEvent) -> RecsInstrs:
Worker.status back to running.
"""
assert self.status == Status.running
self.ensure_communicating()
return self._ensure_computing()
return merge_recs_instructions(
self._ensure_computing(),
self._ensure_communicating(),
)

@handle_event.register
def _(self, ev: GatherDepDoneEvent) -> RecsInstrs:
"""Temporary hack - to be removed"""
return self._ensure_communicating()

@handle_event.register
def _(self, ev: CancelComputeEvent) -> RecsInstrs:
Expand Down
26 changes: 14 additions & 12 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,18 +259,13 @@ class Instruction:
__slots__ = ()


# TODO https://github.com/dask/distributed/issues/5736

# @dataclass
# class GatherDep(Instruction):
# __slots__ = ("worker", "to_gather")
# worker: str
# to_gather: set[str]


# @dataclass
# class FindMissing(Instruction):
# __slots__ = ()
@dataclass
class GatherDep(Instruction):
worker: str
to_gather: set[str]
total_nbytes: int
stimulus_id: str
__slots__ = tuple(__annotations__) # type: ignore


@dataclass
Expand Down Expand Up @@ -434,6 +429,13 @@ class UnpauseEvent(StateMachineEvent):
__slots__ = ()


@dataclass
class GatherDepDoneEvent(StateMachineEvent):
"""Temporary hack - to be removed"""

__slots__ = ()


@dataclass
class ExecuteSuccessEvent(StateMachineEvent):
key: str
Expand Down

0 comments on commit 738f7c6

Please sign in to comment.