diff --git a/distributed/dashboard/components/scheduler.py b/distributed/dashboard/components/scheduler.py index 01ca786cef3..e677fe99a41 100644 --- a/distributed/dashboard/components/scheduler.py +++ b/distributed/dashboard/components/scheduler.py @@ -1968,7 +1968,7 @@ def update(self): current = len(self.scheduler.events["stealing"]) n = current - self.last - log = [log[-i][1] for i in range(1, n + 1) if isinstance(log[-i][1], list)] + log = [log[-i][1][1] for i in range(1, n + 1) if log[-i][1][0] == "request"] self.last = current if log: diff --git a/distributed/scheduler.py b/distributed/scheduler.py index cddada81986..af77964c948 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -425,11 +425,6 @@ class WorkerState: #: (i.e. the tasks in this worker's :attr:`~WorkerState.has_what`). nbytes: int - #: The total expected runtime, in seconds, of all tasks currently processing on this - #: worker. This is the sum of all the costs in this worker's - # :attr:`~WorkerState.processing` dictionary. - occupancy: float - #: Worker memory unknown to the worker, in bytes, which has been there for more than #: 30 seconds. See :class:`MemoryState`. _memory_unmanaged_old: int @@ -455,23 +450,13 @@ class WorkerState: #: Underlying data of :meth:`WorkerState.has_what` _has_what: dict[TaskState, None] - #: A dictionary of tasks that have been submitted to this worker. Each task state is - #: associated with the expected cost in seconds of running that task, summing both - #: the task's expected computation time and the expected communication time of its - #: result. - #: - #: If a task is already executing on the worker and the excecution time is twice the - #: learned average TaskGroup duration, this will be set to twice the current - #: executing time. If the task is unknown, the default task duration is used instead - #: of the TaskGroup average. - #: - #: Multiple tasks may be submitted to a worker in advance and the worker will run - #: them eventually, depending on its execution resources (but see - #: :doc:`work-stealing`). + #: A set of tasks that have been submitted to this worker. Multiple tasks may be + # submitted to a worker in advance and the worker will run them eventually, + # depending on its execution resources (but see :doc:`work-stealing`). #: #: All the tasks here are in the "processing" state. #: This attribute is kept in sync with :attr:`TaskState.processing_on`. - processing: dict[TaskState, float] + processing: set[TaskState] #: Running tasks that invoked :func:`distributed.secede` long_running: set[TaskState] @@ -497,6 +482,19 @@ class WorkerState: # The unique server ID this WorkerState is referencing server_id: str + # Reference to scheduler task_groups + scheduler_ref: weakref.ref[SchedulerState] | None + task_groups_count: defaultdict[str, int] + _network_occ: float + _occupancy_cache: float | None + + #: Keys that may need to be fetched to this worker, and the number of tasks that need them. + #: All tasks are currently in `memory` on a worker other than this one. + #: Much like `processing`, this does not exactly reflect worker state: + #: keys here may be queued to fetch, in flight, or already in memory + #: on the worker. + needs_what: dict[TaskState, int] + __slots__ = tuple(__annotations__) def __init__( @@ -514,6 +512,7 @@ def __init__( services: dict[str, int] | None = None, versions: dict[str, Any] | None = None, extra: dict[str, Any] | None = None, + scheduler: SchedulerState | None = None, ): self.server_id = server_id self.address = address @@ -528,7 +527,6 @@ def __init__( self.status = status self._hash = hash(self.server_id) self.nbytes = 0 - self.occupancy = 0 self._memory_unmanaged_old = 0 self._memory_unmanaged_history = deque() self.metrics = {} @@ -537,12 +535,17 @@ def __init__( self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth")) self.actors = set() self._has_what = {} - self.processing = {} + self.processing = set() self.long_running = set() self.executing = {} self.resources = {} self.used_resources = {} self.extra = extra or {} + self.scheduler_ref = weakref.ref(scheduler) if scheduler else None + self.task_groups_count = defaultdict(int) + self.needs_what = {} + self._network_occ = 0 + self._occupancy_cache = None def __hash__(self) -> int: return self._hash @@ -594,9 +597,8 @@ def clean(self) -> WorkerState: extra=self.extra, server_id=self.server_id, ) - ws.processing = { - ts.key: cost for ts, cost in self.processing.items() # type: ignore - } + ws._occupancy_cache = self.occupancy + ws.executing = { ts.key: duration for ts, duration in self.executing.items() # type: ignore } @@ -654,6 +656,121 @@ def _to_dict_no_nest(self, *, exclude: Container[str] = ()) -> dict[str, Any]: members=True, ) + @property + def scheduler(self): + assert self.scheduler_ref + s = self.scheduler_ref() + assert s + return s + + def add_to_processing(self, ts: TaskState) -> None: + """Assign a task to this worker for compute.""" + if self.scheduler.validate: + assert ts not in self.processing + + tg = ts.group + self.task_groups_count[tg.name] += 1 + self.scheduler.task_groups_count_global[tg.name] += 1 + self.processing.add(ts) + for dts in ts.dependencies: + if self not in dts.who_has: + self._inc_needs_replica(dts) + + def add_to_long_running(self, ts: TaskState) -> None: + if self.scheduler.validate: + assert ts in self.processing + assert ts not in self.long_running + + self._remove_from_task_groups_count(ts) + # Cannot remove from processing since we're using this for things like + # idleness detection. Idle workers are typically targeted for + # downscaling but we should not downscale workers with long running + # tasks + self.long_running.add(ts) + + def remove_from_processing(self, ts: TaskState) -> None: + """Remove a task from a workers processing""" + if self.scheduler.validate: + assert ts in self.processing + + if ts in self.long_running: + self.long_running.discard(ts) + else: + self._remove_from_task_groups_count(ts) + self.processing.remove(ts) + for dts in ts.dependencies: + if dts in self.needs_what: + self._dec_needs_replica(dts) + + def _remove_from_task_groups_count(self, ts: TaskState) -> None: + count = self.task_groups_count[ts.group.name] - 1 + if count: + self.task_groups_count[ts.group.name] = count + else: + del self.task_groups_count[ts.group.name] + + count = self.scheduler.task_groups_count_global[ts.group.name] - 1 + if count: + self.scheduler.task_groups_count_global[ts.group.name] = count + else: + del self.scheduler.task_groups_count_global[ts.group.name] + + def remove_replica(self, ts: TaskState) -> None: + """The worker no longer has a task in memory""" + if self.scheduler.validate: + assert self in ts.who_has + assert ts in self.has_what + assert ts not in self.needs_what + + self.nbytes -= ts.get_nbytes() + del self._has_what[ts] + ts.who_has.remove(self) + + def _inc_needs_replica(self, ts: TaskState) -> None: + """Assign a task fetch to this worker and update network occupancies""" + if self.scheduler.validate: + assert self not in ts.who_has + assert ts not in self.has_what + if ts not in self.needs_what: + self.needs_what[ts] = 1 + nbytes = ts.get_nbytes() + self._network_occ += nbytes + self.scheduler._network_occ_global += nbytes + else: + self.needs_what[ts] += 1 + + def _dec_needs_replica(self, ts: TaskState) -> None: + if self.scheduler.validate: + assert ts in self.needs_what + + self.needs_what[ts] -= 1 + if self.needs_what[ts] == 0: + del self.needs_what[ts] + nbytes = ts.get_nbytes() + self._network_occ -= nbytes + self.scheduler._network_occ_global -= nbytes + + def add_replica(self, ts: TaskState) -> None: + """The worker acquired a replica of task""" + if self.scheduler.validate: + assert self not in ts.who_has + assert ts not in self.has_what + + nbytes = ts.get_nbytes() + if ts in self.needs_what: + del self.needs_what[ts] + self._network_occ -= nbytes + self.scheduler._network_occ_global -= nbytes + ts.who_has.add(self) + self.nbytes += nbytes + self._has_what[ts] = None + + @property + def occupancy(self) -> float: + return self._occupancy_cache or self.scheduler._calc_occupancy( + self.task_groups_count, self._network_occ + ) + @dataclasses.dataclass class ErredTask: @@ -751,6 +868,10 @@ class TaskPrefix: #: Store timings for each prefix-action all_durations: defaultdict[str, float] + #: This measures the maximum recorded live execution time and can be used to + #: detect outliers + max_exec_time: float + #: Task groups associated to this prefix groups: list[TaskGroup] @@ -765,8 +886,14 @@ def __init__(self, name: str): self.duration_average = parse_timedelta(task_durations[self.name]) else: self.duration_average = -1 + self.max_exec_time = -1 self.suspicious = 0 + def add_exec_time(self, duration: float): + self.max_exec_time = max(duration, self.max_exec_time) + if duration > 2 * self.duration_average: + self.duration_average = -1 + def add_duration(self, action: str, start: float, stop: float) -> None: duration = stop - start self.all_durations[action] += duration @@ -1312,7 +1439,6 @@ class SchedulerState: #: Workers that are fully utilized. May include non-running workers. saturated: set[WorkerState] total_nthreads: int - total_occupancy: float #: Cluster-wide resources. {resource name: {worker address: amount}} resources: dict[str, dict[str, float]] @@ -1365,6 +1491,8 @@ class SchedulerState: #: In production, it should always be set to False. transition_counter_max: int | Literal[False] + task_groups_count_global: defaultdict[str, int] + _network_occ_global: float ###################### # Cached configuration ###################### @@ -1426,12 +1554,13 @@ def __init__( self.task_prefixes = {} self.task_metadata = {} self.total_nthreads = 0 - self.total_occupancy = 0.0 self.unknown_durations = {} self.queued = queued self.unrunnable = unrunnable self.validate = validate self.workers = workers + self.task_groups_count_global = defaultdict(int) + self._network_occ_global = 0.0 self.running = { ws for ws in self.workers.values() if ws.status == Status.running } @@ -1542,6 +1671,34 @@ def _clear_task_state(self): ]: collection.clear() + @property + def total_occupancy(self) -> float: + return self._calc_occupancy( + self.task_groups_count_global, + self._network_occ_global, + ) + + def _calc_occupancy( + self, + task_groups_count: dict[str, int], + network_occ: float, + ) -> float: + res = 0.0 + for group_name, count in task_groups_count.items(): + # TODO: Deal with unknown tasks better + prefix = self.task_groups[group_name].prefix + assert prefix is not None + duration = prefix.duration_average + if duration < 0: + if prefix.max_exec_time > 0: + duration = 2 * prefix.max_exec_time + else: + duration = self.UNKNOWN_TASK_DURATION + res += duration * count + occ = res + network_occ / self.bandwidth + assert occ >= 0, occ + return occ + ##################### # State Transitions # ##################### @@ -2132,7 +2289,10 @@ def transition_processing_memory( if self.validate: assert ts.processing_on - assert ts in ts.processing_on.processing + wss = ts.processing_on + assert wss + assert ts in wss.processing + del wss assert not ts.waiting_on assert not ts.who_has, (ts, ts.who_has) assert not ts.exception_blame @@ -2172,10 +2332,9 @@ def transition_processing_memory( s: set = self.unknown_durations.pop(ts.prefix.name, set()) tts: TaskState steal = self.extensions.get("stealing") - for tts in s: - if tts.processing_on: - self._set_duration_estimate(tts, tts.processing_on) - if steal: + if steal: + for tts in s: + if tts.processing_on: steal.recalculate_cost(tts) ############################ @@ -2834,34 +2993,6 @@ def is_rootish(self, ts: TaskState) -> bool: and sum(map(len, tg.dependencies)) < 5 ) - def _set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> None: - """Estimate task duration using worker state and task state. - - If a task takes longer than twice the current average duration we - estimate the task duration to be 2x current-runtime, otherwise we set it - to be the average duration. - - See also ``_remove_from_processing`` - """ - # Long running tasks do not contribute to occupancy calculations and we - # do not set any task duration estimates - if ts in ws.long_running: - return - - exec_time: float = ws.executing.get(ts, 0) - duration: float = self.get_task_duration(ts) - total_duration: float - if exec_time > 2 * duration: - total_duration = 2 * exec_time - else: - comm: float = self.get_comm_cost(ts, ws) - total_duration = duration + comm - - old = ws.processing.get(ts, 0) - ws.processing[ts] = total_duration - self.total_occupancy += total_duration - old - ws.occupancy += total_duration - old - def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0): """Update the status of the idle and saturated state @@ -3068,21 +3199,13 @@ def worker_objective(self, ts: TaskState, ws: WorkerState) -> tuple: def add_replica(self, ts: TaskState, ws: WorkerState): """Note that a worker holds a replica of a task with state='memory'""" - if self.validate: - assert ws not in ts.who_has - assert ts not in ws.has_what - - ws.nbytes += ts.get_nbytes() - ws._has_what[ts] = None - ts.who_has.add(ws) + ws.add_replica(ts) if len(ts.who_has) == 2: self.replicated_tasks.add(ts) def remove_replica(self, ts: TaskState, ws: WorkerState): """Note that a worker no longer holds a replica of a task""" - ws.nbytes -= ts.get_nbytes() - del ws._has_what[ts] - ts.who_has.remove(ws) + ws.remove_replica(ts) if len(ts.who_has) == 1: self.replicated_tasks.remove(ts) @@ -3097,21 +3220,6 @@ def remove_all_replicas(self, ts: TaskState): self.replicated_tasks.remove(ts) ts.who_has.clear() - def _reevaluate_occupancy_worker(self, ws: WorkerState): - """See reevaluate_occupancy""" - ts: TaskState - old = ws.occupancy - for ts in ws.processing: - self._set_duration_estimate(ts, ws) - - self.check_idle_saturated(ws) - steal = self.extensions.get("stealing") - if steal is None: - return - if ws.occupancy > old * 1.3 or old > ws.occupancy * 1.3: - for ts in ws.processing: - steal.recalculate_cost(ts) - def bulk_schedule_after_adding_worker(self, ws: WorkerState) -> Recs: """Send ``queued`` or ``no-worker`` tasks to ``processing`` that this worker can handle. @@ -3724,8 +3832,6 @@ async def start_unsafe(self): for k, v in self.services.items(): logger.info("%11s at: %25s", k, "%s:%d" % (listen_ip, v.port)) - self._ongoing_background_tasks.call_soon(self.reevaluate_occupancy) - if self.scheduler_file: with open(self.scheduler_file, "w") as f: json.dump(self.identity(), f, indent=2) @@ -3894,11 +4000,13 @@ def heartbeat_worker( ws.last_seen = local_now if executing is not None: - ws.executing = { - self.tasks[key]: duration - for key, duration in executing.items() - if key in self.tasks - } + # NOTE: the executing dict is unused + ws.executing = {} + for key, duration in executing.items(): + if key in self.tasks: + ts = self.tasks[key] + ws.executing[ts] = duration + ts.prefix.add_exec_time(duration) ws.metrics = metrics @@ -4023,6 +4131,7 @@ async def add_worker( nanny=nanny, extra=extra, server_id=server_id, + scheduler=self, ) if ws.status == Status.running: self.running.add(ws) @@ -4595,7 +4704,7 @@ async def remove_worker( event_msg = { "action": "remove-worker", - "processing-tasks": {ts.key: cost for ts, cost in ws.processing.items()}, + "processing-tasks": {ts.key for ts in ws.processing}, } self.log_event(address, event_msg.copy()) event_msg["worker"] = address @@ -4624,7 +4733,6 @@ async def remove_worker( del self.workers[address] ws.status = Status.closed self.running.discard(ws) - self.total_occupancy -= ws.occupancy recommendations: dict = {} @@ -4899,6 +5007,7 @@ def validate_state(self, allow_overlap: bool = False) -> None: self.running, list(self.idle.values()), ) + task_group_counts: defaultdict[str, int] = defaultdict(int) for w, ws in self.workers.items(): assert isinstance(w, str), (type(w), w) assert isinstance(ws, WorkerState), (type(ws), ws) @@ -4910,7 +5019,22 @@ def validate_state(self, allow_overlap: bool = False) -> None: assert not ws.occupancy if ws.status == Status.running: assert ws.address in self.idle + assert not ws.needs_what.keys() & ws.has_what + actual_needs_what: defaultdict[TaskState, int] = defaultdict(int) + for ts in ws.processing: + for tss in ts.dependencies: + if tss not in ws.has_what: + actual_needs_what[tss] += 1 + assert actual_needs_what == ws.needs_what assert (ws.status == Status.running) == (ws in self.running) + for name, count in ws.task_groups_count.items(): + task_group_counts[name] += count + + assert task_group_counts.keys() == self.task_groups_count_global.keys() + for name, global_count in self.task_groups_count_global.items(): + assert ( + task_group_counts[name] == global_count + ), f"{name}: {task_group_counts[name]} (wss), {global_count} (global)" for ws in self.running: assert ws.status == Status.running @@ -4939,22 +5063,6 @@ def validate_state(self, allow_overlap: bool = False) -> None: } assert a == b, (a, b) - actual_total_occupancy = 0.0 - for worker, ws in self.workers.items(): - ws_processing_total = sum( - cost for ts, cost in ws.processing.items() if ts not in ws.long_running - ) - assert abs(ws_processing_total - ws.occupancy) < 1e-8, ( - worker, - ws_processing_total, - ws.occupancy, - ) - actual_total_occupancy += ws.occupancy - - assert abs(actual_total_occupancy - self.total_occupancy) < 1e-8, ( - actual_total_occupancy, - self.total_occupancy, - ) if self.transition_counter_max: assert self.transition_counter < self.transition_counter_max @@ -5165,15 +5273,7 @@ def handle_long_running( else: ts.prefix.duration_average = (old_duration + compute_duration) / 2 - occ = ws.processing[ts] - ws.occupancy -= occ - self.total_occupancy -= occ - # Cannot remove from processing since we're using this for things like - # idleness detection. Idle workers are typically targeted for - # downscaling but we should not downscale workers with long running - # tasks - ws.processing[ts] = 0 - ws.long_running.add(ts) + ws.add_to_long_running(ts) self.check_idle_saturated(ws) def handle_worker_status_change( @@ -7501,50 +7601,6 @@ async def get_worker_monitor_info(self, recent=False, starts=None): # Cleanup # ########### - async def reevaluate_occupancy(self, worker_index: int = 0): - """Periodically reassess task duration time - - The expected duration of a task can change over time. Unfortunately we - don't have a good constant-time way to propagate the effects of these - changes out to the summaries that they affect, like the total expected - runtime of each of the workers, or what tasks are stealable. - - In this coroutine we walk through all of the workers and re-align their - estimates with the current state of tasks. We do this periodically - rather than at every transition, and we only do it if the scheduler - process isn't under load (using psutil.Process.cpu_percent()). This - lets us avoid this fringe optimization when we have better things to - think about. - """ - try: - while self.status != Status.closed: - last = time() - delay = 0.1 - - if self.proc.cpu_percent() < 50: - workers: list = list(self.workers.values()) - nworkers: int = len(workers) - i: int - for _ in range(nworkers): - ws: WorkerState = workers[worker_index % nworkers] - worker_index += 1 - try: - if ws is None or not ws.processing: - continue - self._reevaluate_occupancy_worker(ws) - finally: - del ws # lose ref - - duration = time() - last - if duration > 0.005: # 5ms since last release - delay = duration * 5 # 25ms gap - break - await asyncio.sleep(delay) - - except Exception: - logger.error("Error in reevaluate occupancy", exc_info=True) - raise - async def check_worker_ttl(self): now = time() stimulus_id = f"check-worker-ttl-{now}" @@ -7719,11 +7775,10 @@ def _add_to_processing( """Set a task as processing on a worker and return the worker messages to send.""" if state.validate: _validate_ready(state, ts) - assert ts not in ws.processing assert ws in state.running, state.running assert (o := state.workers.get(ws.address)) is ws, (ws, o) - state._set_duration_estimate(ts, ws) + ws.add_to_processing(ts) ts.processing_on = ws ts.state = "processing" state.acquire_resources(ts, ws) @@ -7754,18 +7809,10 @@ def _exit_processing_common( assert ws ts.processing_on = None + ws.remove_from_processing(ts) if state.workers.get(ws.address) is not ws: # may have been removed return None - duration = ws.processing.pop(ts) - ws.long_running.discard(ts) - if not ws.processing: - state.total_occupancy -= ws.occupancy - ws.occupancy = 0 - else: - state.total_occupancy -= duration - ws.occupancy -= duration - state.check_idle_saturated(ws) state.release_resources(ts, ws) diff --git a/distributed/stealing.py b/distributed/stealing.py index 2d0917710ad..8abb07b3b08 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -257,6 +257,7 @@ def steal_time_ratio(self, ts: TaskState) -> tuple[float, int] | tuple[None, Non cost_multiplier = transfer_time / compute_time level = int(round(log2(cost_multiplier) + 6)) + if level < 1: level = 1 elif level >= len(self.cost_multipliers): @@ -286,8 +287,11 @@ def move_task_request( thief.occupancy, ) - victim_duration = victim.processing[ts] - + # TODO: occupancy no longer concats linearily so we can't easily + # assume that the network cost would go down by that much + victim_duration = self.scheduler.get_task_duration( + ts + ) + self.scheduler.get_comm_cost(ts, victim) thief_duration = self.scheduler.get_task_duration( ts ) + self.scheduler.get_comm_cost(ts, thief) @@ -343,10 +347,7 @@ async def move_task_confirm( try: _log_msg = [key, state, victim.address, thief.address, stimulus_id] - if ts.state != "processing": - self.scheduler._reevaluate_occupancy_worker(thief) - self.scheduler._reevaluate_occupancy_worker(victim) - elif ( + if ( state in _WORKER_STATE_UNDEFINED # If our steal information is somehow stale we need to reschedule or state in _WORKER_STATE_CONFIRM @@ -367,15 +368,8 @@ async def move_task_confirm( elif state in _WORKER_STATE_CONFIRM: self.remove_key_from_stealable(ts) ts.processing_on = thief - duration = victim.processing.pop(ts) - victim.occupancy -= duration - self.scheduler.total_occupancy -= duration - if not victim.processing: - self.scheduler.total_occupancy -= victim.occupancy - victim.occupancy = 0 - thief.processing[ts] = info["thief_duration"] - thief.occupancy += info["thief_duration"] - self.scheduler.total_occupancy += info["thief_duration"] + victim.remove_from_processing(ts) + thief.add_to_processing(ts) self.put_key_in_stealable(ts) self.scheduler.send_task_to_worker(thief.address, ts) @@ -446,19 +440,19 @@ def balance(self) -> None: i += 1 if not (thief := _get_thief(s, ts, potential_thieves)): continue - task_occ_on_victim = victim.processing.get(ts) - if task_occ_on_victim is None: + if ts not in victim.processing: stealable.discard(ts) continue occ_thief = self._combined_occupancy(thief) occ_victim = self._combined_occupancy(victim) - comm_cost = self.scheduler.get_comm_cost(ts, thief) + comm_cost_thief = self.scheduler.get_comm_cost(ts, thief) + comm_cost_victim = self.scheduler.get_comm_cost(ts, victim) compute = self.scheduler.get_task_duration(ts) if ( - occ_thief + comm_cost + compute - <= occ_victim - task_occ_on_victim / 2 + occ_thief + comm_cost_thief + compute + <= occ_victim - (comm_cost_victim + compute) / 2 ): self.move_task_request(ts, victim, thief) log.append( @@ -466,7 +460,7 @@ def balance(self) -> None: start, level, ts.key, - task_occ_on_victim, + comm_cost_victim + compute, victim.address, occ_victim, thief.address, @@ -487,7 +481,7 @@ def balance(self) -> None: ) if log: - self.log(log) + self.log(("request", log)) self.count += 1 stop = time() if s.digests: @@ -510,7 +504,9 @@ def story(self, *keys_or_ts: str | TaskState) -> list: keys = {key.key if not isinstance(key, str) else key for key in keys_or_ts} out = [] for _, L in self.scheduler.get_events(topic="stealing"): - if not isinstance(L, list): + if L[0] == "request": + L = L[1] + else: L = [L] for t in L: if any(x in keys for x in t): diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index 3336fc8481d..7687b1b16f5 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -1034,7 +1034,8 @@ def f(ev1, ev2, ev3, ev4): await ev1.wait() ts = a.state.tasks["x"] assert ts.state == "executing" - assert sum(ws.processing.values()) > 0 + assert ws.processing + assert not ws.long_running x.release() await wait_for_state("x", "cancelled", a) @@ -1050,8 +1051,8 @@ def f(ev1, ev2, ev3, ev4): # Test that the scheduler receives a delayed {op: long-running} assert ws.processing - while sum(ws.processing.values()): - await asyncio.sleep(0.1) + while not ws.long_running: + await asyncio.sleep(0) assert ws.processing await ev4.set() diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index aa8a001ac5e..21372451dfe 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5158,13 +5158,10 @@ def long_running(lock, entered): await asyncio.sleep(0.01) await a.heartbeat() - s._set_duration_estimate(ts, ws) assert s.workers[a.address].occupancy == 0 assert s.total_occupancy == 0 assert ws.occupancy == 0 - s._ongoing_background_tasks.call_soon(s.reevaluate_occupancy, 0) - assert s.workers[a.address].occupancy == 0 await l.release() with ( diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 95bebd674ea..ad18405e5f2 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -50,6 +50,7 @@ cluster, dec, div, + freeze_data_fetching, gen_cluster, gen_test, inc, @@ -657,14 +658,14 @@ async def test_scheduler_init_pulls_blocked_handlers_from_config(s): @gen_cluster() async def test_feed(s, a, b): def func(scheduler): - return dumps(dict(scheduler.workers)) + return dumps({addr: ws.clean() for addr, ws in scheduler.workers.items()}) comm = await connect(s.address) await comm.write({"op": "feed", "function": dumps(func), "interval": 0.01}) for _ in range(5): response = await comm.read() - expected = dict(s.workers) + expected = {addr: ws.clean() for addr, ws in s.workers.items()} assert cloudpickle.loads(response) == expected await comm.close() @@ -1454,21 +1455,6 @@ async def test_learn_occupancy_2(c, s, a, b): assert nproc * 0.1 < s.total_occupancy < nproc * 0.4 -@gen_cluster(client=True) -async def test_occupancy_cleardown(c, s, a, b): - s.validate = False - - # Inject excess values in s.occupancy - s.workers[a.address].occupancy = 2 - s.total_occupancy += 2 - futures = c.map(slowinc, range(100), delay=0.01) - await wait(futures) - - # Verify that occupancy values have been zeroed out - assert abs(s.total_occupancy) < 0.01 - assert all(ws.occupancy == 0 for ws in s.workers.values()) - - @nodebug @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 30) async def test_balance_many_workers(c, s, *workers): @@ -1496,31 +1482,40 @@ async def test_balance_many_workers_2(c, s, *workers): @gen_cluster(client=True) -async def test_learn_occupancy_multiple_workers(c, s, a, b): - x = c.submit(slowinc, 1, delay=0.2, workers=a.address) - await asyncio.sleep(0.05) - futures = c.map(slowinc, range(100), delay=0.2) - - await wait(x) - - assert not any(v == 0.5 for w in s.workers.values() for v in w.processing.values()) +async def test_include_communication_in_occupancy(c, s, a, b): + x = c.submit(operator.mul, b"0", int(s.bandwidth) * 2, workers=a.address) + y = c.submit(operator.mul, b"1", int(s.bandwidth * 3), workers=b.address) + event = Event() + def add_blocked(x, y, event): + event.wait() + return x + y -@gen_cluster(client=True) -async def test_include_communication_in_occupancy(c, s, a, b): - await c.submit(slowadd, 1, 2, delay=0) - x = c.submit(operator.mul, b"0", int(s.bandwidth), workers=a.address) - y = c.submit(operator.mul, b"1", int(s.bandwidth * 1.5), workers=b.address) + with freeze_data_fetching(b): + z = c.submit(add_blocked, x, y, event=event, pure=False) + while z.key not in s.tasks or not s.tasks[z.key].processing_on: + await asyncio.sleep(0.01) - z = c.submit(slowadd, x, y, delay=1) - while z.key not in s.tasks or not s.tasks[z.key].processing_on: + ts = s.tasks[z.key] + ws = s.workers[b.address] + assert ts.processing_on == ws + # Occ should be 0.5s (CPU, unknown) + 2s (network) + occ = ws.occupancy + assert occ == 2.5 + z2 = c.submit(add_blocked, x, y, event=event, pure=False, workers=b.address) + while z2.key not in s.tasks or not s.tasks[z2.key].processing_on: + await asyncio.sleep(0.01) + # Occ should be 2 * 0.5 (CPU, unknown) + 2s (network) + # Network cost for the same key should only cost once + occ2 = ws.occupancy + assert occ2 == 3 + while s.tasks[x.key] not in ws.has_what: await asyncio.sleep(0.01) - - ts = s.tasks[z.key] - assert ts.processing_on == s.workers[b.address] - assert s.workers[b.address].processing[ts] > 1 + occ3 = ws.occupancy + # Occ should be 2 * 0.5 (CPU, unknown) + assert occ3 == 1 + await event.set() await wait(z) - del z @gen_cluster(nthreads=[]) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index efb4892758f..e975a6a0e4d 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -17,7 +17,7 @@ from tlz import merge, sliding_window import dask -from dask.utils import key_split +from dask.utils import key_split, parse_bytes from distributed import ( Client, @@ -62,11 +62,6 @@ teardown_module = nodebug_teardown_module -@pytest.fixture(params=[True, False]) -def recompute_saturation(request): - yield request.param - - @gen_cluster(client=True, nthreads=[("", 2), ("", 2)]) async def test_work_stealing(c, s, a, b): [x] = await c._scatter([1], workers=a.address) @@ -666,7 +661,7 @@ def slow2(x): assert any(future.key in w.state.tasks for w in rest) -async def assert_balanced(inp, expected, recompute_saturation, c, s, *workers): +async def assert_balanced(inp, expected, c, s, *workers): steal = s.extensions["stealing"] await steal.stop() ev = Event() @@ -700,9 +695,7 @@ def block(*args, event, **kwargs): while len([ts for ts in s.tasks.values() if ts.processing_on]) < len(futures): await asyncio.sleep(0.001) - if recompute_saturation: - for ws in s.workers.values(): - s._reevaluate_occupancy_worker(ws) + try: for _ in range(10): steal.balance() @@ -767,9 +760,9 @@ def block(*args, event, **kwargs): ), ], ) -def test_balance(inp, expected, recompute_saturation): +def test_balance(inp, expected): async def test_balance_(*args, **kwargs): - await assert_balanced(inp, expected, recompute_saturation, *args, **kwargs) + await assert_balanced(inp, expected, *args, **kwargs) config = { "distributed.scheduler.default-task-durations": {str(i): 1 for i in range(10)} @@ -834,20 +827,25 @@ def block_reduce(x, y, event): @gen_cluster( client=True, - config={"distributed.scheduler.default-task-durations": {"slowadd": 0.001}}, + config={"distributed.scheduler.default-task-durations": {"blocked_add": 0.001}}, ) async def test_steal_communication_heavy_tasks(c, s, a, b): steal = s.extensions["stealing"] await steal.stop() x = c.submit(mul, b"0", int(s.bandwidth), workers=a.address) y = c.submit(mul, b"1", int(s.bandwidth), workers=b.address) + event = Event() + + def blocked_add(x, y, event): + event.wait() + return x + y futures = [ c.submit( - slowadd, + blocked_add, x, y, - delay=1, + event=event, pure=False, workers=a.address, allow_other_workers=True, @@ -858,9 +856,11 @@ async def test_steal_communication_heavy_tasks(c, s, a, b): while not any(f.key in s.tasks and s.tasks[f.key].processing_on for f in futures): await asyncio.sleep(0.01) + await steal.start() steal.balance() - await steal.stop() + await event.set() + await c.gather(futures) @gen_cluster(client=True) @@ -962,7 +962,7 @@ def long(delay): await c.submit(inc, 1) # learn duration long_tasks = c.map(long, [0.5, 0.6], workers=a.address, allow_other_workers=True) - while sum(len(ws.processing) for ws in s.workers.values()) < 2: # let them start + while sum(len(ws.long_running) for ws in s.workers.values()) < 2: # let them start await asyncio.sleep(0.01) start = time() @@ -976,7 +976,6 @@ def long(delay): incs = c.map(inc, range(100), workers=a.address, allow_other_workers=True) await asyncio.sleep(0.2) - await wait(long_tasks) for t in long_tasks: @@ -1152,15 +1151,11 @@ async def test_steal_concurrent_simple(c, s, *workers): assert not ws2.has_what -@gen_cluster( - client=True, - config={ - "distributed.scheduler.work-stealing-interval": 1_000_000, - }, -) +@gen_cluster(client=True) async def test_steal_reschedule_reset_in_flight_occupancy(c, s, *workers): # https://github.com/dask/distributed/issues/5370 steal = s.extensions["stealing"] + await steal.stop() w0 = workers[0] roots = c.map( inc, @@ -1356,6 +1351,7 @@ async def test_reschedule_concurrent_requests_deadlock(c, s, *workers): assert msgs in (expect1, expect2, expect3) +@pytest.mark.skip("executing heartbeats not considered yet") @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) async def test_correct_bad_time_estimate(c, s, *workers): """Initial time estimation causes the task to not be considered for @@ -1446,7 +1442,7 @@ def func(*args): assert (ntasks_per_worker < ideal * 1.5).all(), (ideal, ntasks_per_worker) -def test_balance_even_with_replica(recompute_saturation): +def test_balance_even_with_replica(): dependencies = {"a": 1} dependency_placement = [["a"], ["a"]] task_placement = [[["a"], ["a"]], []] @@ -1463,11 +1459,10 @@ def _correct_placement(actual): dependency_placement, task_placement, _correct_placement, - recompute_saturation, ) -def test_balance_to_replica(recompute_saturation): +def test_balance_to_replica(): dependencies = {"a": 2} dependency_placement = [["a"], ["a"], []] task_placement = [[["a"], ["a"]], [], []] @@ -1485,11 +1480,10 @@ def _correct_placement(actual): dependency_placement, task_placement, _correct_placement, - recompute_saturation, ) -def test_balance_multiple_to_replica(recompute_saturation): +def test_balance_multiple_to_replica(): dependencies = {"a": 6} dependency_placement = [["a"], ["a"], []] task_placement = [[["a"], ["a"], ["a"], ["a"], ["a"], ["a"], ["a"], ["a"]], [], []] @@ -1514,11 +1508,10 @@ def _correct_placement(actual): dependency_placement, task_placement, _correct_placement, - recompute_saturation, ) -def test_balance_to_larger_dependency(recompute_saturation): +def test_balance_to_larger_dependency(): dependencies = {"a": 2, "b": 1} dependency_placement = [["a", "b"], ["a"], ["b"]] task_placement = [[["a", "b"], ["a", "b"], ["a", "b"]], [], []] @@ -1536,12 +1529,10 @@ def _correct_placement(actual): dependency_placement, task_placement, _correct_placement, - recompute_saturation, ) def test_balance_prefers_busier_with_dependency(): - recompute_saturation = True dependencies = {"a": 5, "b": 1} dependency_placement = [["a"], ["a", "b"], []] task_placement = [ @@ -1570,7 +1561,6 @@ def _correct_placement(actual): dependency_placement, task_placement, _correct_placement, - recompute_saturation, # This test relies on disabling queueing to flag workers as idle config={ "distributed.scheduler.worker-saturation": float("inf"), @@ -1583,7 +1573,6 @@ def _run_dependency_balance_test( dependency_placement: list[list[str]], task_placement: list[list[list[str]]], correct_placement_fn: Callable[[list[list[list[str]]]], bool], - recompute_saturation: bool, config: dict | None = None, ) -> None: """Run a test for balancing with task dependencies according to the provided @@ -1604,8 +1593,6 @@ def _run_dependency_balance_test( index of the outer list. Each task is a list of names of dependencies. correct_placement_fn Callable used to determine if stealing placed the tasks as expected. - recompute_saturation - Whether to recompute worker saturation before stealing. config Optional configuration to apply to the test. See Also @@ -1625,7 +1612,6 @@ async def _run( dependency_placement, task_placement, correct_placement_fn, - recompute_saturation, permutation, *args, **kwargs, @@ -1635,6 +1621,7 @@ async def _run( client=True, nthreads=[("", 1)] * len(task_placement), config=merge( + NO_AMM, config or {}, { "distributed.scheduler.unknown-task-duration": "1s", @@ -1648,7 +1635,6 @@ async def _dependency_balance_test_permutation( dependency_placement: list[list[str]], task_placement: list[list[list[str]]], correct_placement_fn: Callable[[list[list[list[str]]]], bool], - recompute_saturation: bool, permutation: list[int], c: Client, s: Scheduler, @@ -1669,8 +1655,6 @@ async def _dependency_balance_test_permutation( index of the outer list. Each task is a list of names of dependencies. correct_placement_fn Callable used to determine if stealing placed the tasks as expected. - recompute_saturation - Whether to recompute worker saturation before stealing. permutation Permutation of workers to use for this run. @@ -1698,9 +1682,12 @@ async def _dependency_balance_test_permutation( workers, ) - if recompute_saturation: - for ws in s.workers.values(): - s._reevaluate_occupancy_worker(ws) + # Re-evaluate idle/saturated classification to avoid outdated classifications due to + # the initialization order of workers. On a real cluster, this would get constantly + # updated by tasks being added or completing. + for ws in s.workers.values(): + s.check_idle_saturated(ws) + try: for _ in range(20): steal.balance() @@ -1869,3 +1856,17 @@ def assert_task_placement(expected, s, workers): """Assert that tasks are placed on the workers as expected.""" actual = _get_task_placement(s, workers) assert _equal_placement(actual, expected) + + +# Reproducer from https://github.com/dask/distributed/issues/6573 +@gen_cluster( + client=True, + nthreads=[("", 1)] * 4, +) +async def test_trivial_workload_should_not_cause_work_stealing(c, s, *workers): + root = dask.delayed(lambda n: "x" * n)(parse_bytes("1MiB"), dask_key_name="root") + results = [dask.delayed(lambda *args: None)(root, i) for i in range(1000)] + futs = c.compute(results) + await c.gather(futs) + events = s.events["stealing"] + assert len(events) == 0 diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 3c182061e04..25fce6d67bc 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -3120,7 +3120,7 @@ async def test_worker_status_sync(s, a): "prev-status": "running", "status": "closing_gracefully", }, - {"action": "remove-worker", "processing-tasks": {}}, + {"action": "remove-worker", "processing-tasks": set()}, {"action": "retired"}, ]