diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 630bd64d722..f857e2e8711 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -1034,6 +1034,8 @@ class TaskGroup:
     #: subsequent tasks until a new worker is chosen.
     last_worker_tasks_left: int

+    _has_restrictions: bool
+
     prefix: TaskPrefix | None
     start: float
     stop: float
@@ -1052,6 +1054,7 @@ def __init__(self, name: str):
         self.start = 0.0
         self.stop = 0.0
         self.all_durations = defaultdict(float)
+        self._has_restrictions = False
         self.last_worker = None
         self.last_worker_tasks_left = 0

@@ -1067,6 +1070,13 @@ def add_duration(self, action: str, start: float, stop: float) -> None:
         self.prefix.add_duration(action, start, stop)

     def add(self, other: TaskState) -> None:
+        if (
+            other.resource_restrictions
+            or other.worker_restrictions
+            or other.host_restrictions
+            or other.actor
+        ):
+            self._has_restrictions = True
         self.states[other.state] += 1
         other.group = self

@@ -1096,6 +1106,15 @@ def _to_dict_no_nest(self, *, exclude: Container[str] = ()) -> dict[str, Any]:
         """
         return recursive_to_dict(self, exclude=exclude, members=True)

+    @property
+    def rootish(self) -> bool:
+        return (
+            not self._has_restrictions
+            and len(self) >= 5
+            and len(self.dependencies) < 5
+            and sum(map(len, self.dependencies)) < 5
+        )
+

 class TaskState:
     """A simple object holding information about a task.
@@ -2075,15 +2094,6 @@ def decide_worker_rootish_queuing_disabled(

         return ws

-    def worker_objective_rootish_queuing(self, ws, ts):
-        # FIXME: This is basically the ordinary worker_objective but with task
-        # counts instead of occupancy.
-        comm_bytes = sum(
-            dts.get_nbytes() for dts in ts.dependencies if ws not in dts.who_has
-        )
-        # See test_nbytes_determines_worker
-        return (len(ws.processing) / ws.nthreads, comm_bytes, ws.nbytes)
-
     def decide_worker_rootish_queuing_enabled(self, ts) -> WorkerState | None:
         """Pick a worker for a runnable root-ish task, if not all are busy.

@@ -2109,11 +2119,7 @@ def decide_worker_rootish_queuing_enabled(self, ts) -> WorkerState | None:
         """
         if self.validate:
-            # We don't `assert self.is_rootish(ts)` here, because that check is
-            # dependent on cluster size. It's possible a task looked root-ish when it
-            # was queued, but the cluster has since scaled up and it no longer does when
-            # coming out of the queue. If `is_rootish` changes to a static definition,
-            # then add that assertion here (and actually pass in the task).
+            assert ts.group.rootish
             assert not math.isinf(self.WORKER_SATURATION)

         if not self.idle_task_count:
@@ -2124,7 +2130,7 @@ def decide_worker_rootish_queuing_enabled(self, ts) -> WorkerState | None:
         # NOTE: this will lead to worst-case scheduling with regards to co-assignment.
         ws = min(
             self.idle_task_count,
-            key=partial(self.worker_objective_rootish_queuing, ts=ts),
+            key=lambda ws: len(ws.processing) / ws.nthreads,
         )
         if self.validate:
             assert not _worker_full(ws, self.WORKER_SATURATION), (
@@ -2216,7 +2222,7 @@ def transition_waiting_processing(self, key: str, stimulus_id: str) -> RecsMsgs:
         """
         ts = self.tasks[key]

-        if self.is_rootish(ts):
+        if ts.group.rootish:
             # NOTE: having two root-ish methods is temporary. When the feature flag is
             # removed, there should only be one, which combines co-assignment and
             # queuing. Eventually, special-casing root tasks might be removed entirely,
@@ -2822,23 +2828,6 @@ def story(self, *keys_or_tasks_or_stimuli: str | TaskState) -> list[Transition]:
     # Assigning Tasks to Workers #
     ##############################

-    def is_rootish(self, ts: TaskState) -> bool:
-        """
-        Whether ``ts`` is a root or root-like task.
-
-        Root-ish tasks are part of a group that's much larger than the cluster,
-        and have few or no dependencies.
-        """
-        if (
-            ts.resource_restrictions
-            or ts.worker_restrictions
-            or ts.host_restrictions
-            or ts.actor
-        ):
-            return False
-        tg = ts.group
-        return len(tg.dependencies) < 5 and sum(map(len, tg.dependencies)) < 5
-
     def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0) -> None:
         """Update the status of the idle and saturated state
diff --git a/distributed/tests/test_rootish.py b/distributed/tests/test_rootish.py
new file mode 100644
index 00000000000..3f188618a91
--- /dev/null
+++ b/distributed/tests/test_rootish.py
@@ -0,0 +1,154 @@
+from __future__ import annotations
+
+from typing import Iterable
+
+import pytest
+
+import dask
+
+from distributed.scheduler import TaskGroup, TaskState
+
+
+@pytest.fixture()
+def abcde():
+    return "abcde"
+
+
+def f(*args):
+    return None
+
+
+def dummy_dsk_to_taskstate(dsk: dict) -> tuple[list[TaskState], dict[str, TaskGroup]]:
+    task_groups: dict[str, TaskGroup] = {}
+    tasks = dict()
+    priority = dask.order.order(dsk)
+    for key in dsk:
+        tasks[key] = ts = TaskState(key, None, "released")
+        ts.group = task_groups.get(ts.group_key, TaskGroup(ts.group_key))
+        task_groups[ts.group_key] = ts.group
+        ts.group.add(ts)
+        ts.priority = priority[key]
+    for key, vals in dsk.items():
+        stack = list(vals[1:])
+        while stack:
+            d = stack.pop()
+            if isinstance(d, list):
+                stack.extend(d)
+                continue
+            assert isinstance(d, (str, tuple, int))
+            if d not in tasks:
+                raise ValueError(f"Malformed example. {d} not part of dsk")
+            tasks[key].add_dependency(tasks[d])
+    return sorted(tasks.values(), key=lambda ts: ts.priority), task_groups
+
+
+def _to_keys(prefix: str, suffix: Iterable[str]) -> list[str]:
+    return list(prefix + "-" + i for i in suffix)
+
+
+def test_tree_reduce(abcde):
+    a, b, c, _, _ = abcde
+    a_ = _to_keys(a, "123456789")
+    b_ = _to_keys(b, "1234")
+    dsk = {
+        a_[0]: (f,),
+        a_[1]: (f,),
+        a_[2]: (f,),
+        b_[0]: (f, a_[0], a_[1], a_[2]),
+        a_[3]: (f,),
+        a_[4]: (f,),
+        a_[5]: (f,),
+        b_[1]: (
+            f,
+            a_[3],
+            a_[4],
+            a_[5],
+        ),
+        a_[6]: (f,),
+        a_[7]: (f,),
+        a_[8]: (f,),
+        b_[2]: (f, a_[6], a_[7], a_[8]),
+        c: (f, b_[0], b_[1], b_[2]),
+    }
+    _, groups = dummy_dsk_to_taskstate(dsk)
+    assert len(groups) == 3
+    assert len(groups["a"]) == 9
+    assert groups["a"].rootish
+    assert not groups["b"].rootish
+    assert not groups["c"].rootish
+
+
+@pytest.mark.parametrize("num_Bs, BRootish", [(4, False), (5, True)])
+def test_nearest_neighbor(abcde, num_Bs, BRootish):
+    r"""
+    a1 a2 a3 a4 a5 a6 a7 a8 a9
+     \ | /  \ | /  \ | /  \ | /
+      b1     b2     b3     b4
+    """
+    a, b, c, _, _ = abcde
+    a_ = _to_keys(a, "0123456789")
+    aa_ = _to_keys(a, ["10", "11", "12"])
+    b_ = _to_keys(b, "012345")
+
+    dsk = {
+        b_[1]: (f,),
+        b_[2]: (f,),
+        b_[3]: (f,),
+        b_[4]: (f,),
+        a_[1]: (f, b_[1]),
+        a_[2]: (f, b_[1]),
+        a_[3]: (f, b_[1], b_[2]),
+        a_[4]: (f, b_[2]),
+        a_[5]: (f, b_[2], b_[3]),
+        a_[6]: (f, b_[3]),
+        a_[7]: (f, b_[3], b_[4]),
+        a_[8]: (f, b_[4]),
+        a_[9]: (f, b_[4]),
+    }
+    if num_Bs == 5:
+        dsk[b_[5]] = (f,)
+        dsk[a_[9]] = (f, b_[4], b_[5])
+        dsk[aa_[0]] = (f, b_[5])
+        dsk[aa_[1]] = (f, b_[5])
+    _, groups = dummy_dsk_to_taskstate(dsk)
+    assert len(groups) == 2
+
+    if BRootish:
+        assert not groups["a"].rootish
+        assert groups["b"].rootish
+    else:
+        assert groups["a"].rootish
+        assert not groups["b"].rootish
+
+
+@pytest.mark.parametrize("num_Bs, rootish", [(4, False), (5, True)])
+def test_base_of_reduce_preferred(abcde, num_Bs, rootish):
+    r"""
+               a4
+              /|
+            a3 |
+           /|  |
+         a2 |  |
+        /|  |  |
+      a1 |  |  |
+     /|  |  |  |
+    a0 |  |  |  |
+    |  |  |  |  |
+    b0 b1 b2 b3 b4
+      \ \  /  / /
+          c
+    """
+    a, b, c, d, e = abcde
+    dsk = {(a, i): (f, (a, i - 1), (b, i)) for i in range(1, num_Bs)}
+    dsk[(a, 0)] = (f, (b, 0))
+    dsk.update({(b, i): (f, c) for i in range(num_Bs)})
+    dsk[c] = (f,)
+
+    _, groups = dummy_dsk_to_taskstate(dsk)
+    assert len(groups) == 3
+    assert not groups["a"].rootish
+    if rootish:
+        assert groups["b"].rootish
+    else:
+        assert not groups["b"].rootish
+    assert not groups["c"].rootish
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 2b1bcfca95c..5148ca795bc 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -300,7 +300,7 @@ async def test_decide_worker_rootish_while_last_worker_is_retiring(c, s, a):
     # - TaskGroup(y) has more than 4 tasks (total_nthreads * 2)
     # - TaskGroup(y) has less than 5 dependency groups
     # - TaskGroup(y) has less than 5 dependency tasks
-    assert s.is_rootish(s.tasks["y-2"])
+    assert s.tasks["y-2"].group.rootish

     await evx[0].set()
     await wait_for_state("y-0", "processing", s)
@@ -4276,7 +4276,7 @@ def submit_tasks():
     def assert_rootish():
         # Just to verify our assumptions in case the definition changes. This is
         # currently a bit brittle
-        assert all(s.is_rootish(s.tasks[k]) for k in keys)
+        assert all(s.tasks[k].group.rootish for k in keys)

     f1 = submit_tasks()
     # Make sure that the worker is properly saturated
diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py
index d28c21e2b90..024c5a0e96a 100644
--- a/distributed/tests/test_steal.py
+++ b/distributed/tests/test_steal.py
@@ -593,9 +593,6 @@ async def test_dont_steal_executing_tasks_2(c, s, a, b):
     assert not b.state.executing_count


-@pytest.mark.skip(
-    reason="submitted tasks are root-ish. Stealing works very differently for root-ish tasks. If queued, stealing is disabled entirely"
-)
 @gen_cluster(
     client=True,
     nthreads=[("127.0.0.1", 1)] * 10,
diff --git a/distributed/tests/test_worker_memory.py b/distributed/tests/test_worker_memory.py
index e92bace547f..60cd82364f1 100644
--- a/distributed/tests/test_worker_memory.py
+++ b/distributed/tests/test_worker_memory.py
@@ -504,7 +504,6 @@ def __sizeof__(self):
         "distributed.worker.memory.target": False,
         "distributed.worker.memory.spill": False,
         "distributed.worker.memory.pause": False,
-        "distributed.scheduler.worker-saturation": "inf",
     },
 )
 async def test_pause_executor_manual(c, s, a):
@@ -567,7 +566,6 @@ def f(ev):
         "distributed.worker.memory.spill": False,
         "distributed.worker.memory.pause": 0.8,
         "distributed.worker.memory.monitor-interval": "10ms",
-        "distributed.scheduler.worker-saturation": "inf",
     },
 )
 async def test_pause_executor_with_memory_monitor(c, s, a):
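For reviewers who want to try the new heuristic in isolation, here is a small standalone sketch (not part of the patch) that mirrors the TaskGroup.rootish property added above. The function name and its count arguments are illustrative only, not scheduler APIs; the threshold values are the ones used in the diff, and the example numbers come from the tree-reduce graph in test_rootish.py.

# Illustrative restatement of TaskGroup.rootish -- not scheduler code.
# A group is treated as root-ish when it is unrestricted, has at least 5
# tasks, and depends on fewer than 5 groups totalling fewer than 5 tasks.
def looks_rootish(
    group_size: int,
    n_dependency_groups: int,
    n_dependency_tasks: int,
    has_restrictions: bool = False,
) -> bool:
    return (
        not has_restrictions
        and group_size >= 5
        and n_dependency_groups < 5
        and n_dependency_tasks < 5
    )


# Tree-reduce example from test_rootish.py: nine "a" leaves feed three
# "b" reducers, which feed a single "c" task.
assert looks_rootish(9, 0, 0)       # group "a": large, no dependencies
assert not looks_rootish(3, 1, 9)   # group "b": only 3 tasks, heavy deps
assert not looks_rootish(1, 1, 3)   # group "c": single task

Because the thresholds are static (no dependence on cluster size), the property can be asserted when a task comes off the queue in decide_worker_rootish_queuing_enabled, which the old per-task is_rootish check could not guarantee.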