Elevate to TG property and add tests
fjetter committed Feb 10, 2023
1 parent 9e8e5d2 commit 25473fe
Showing 5 changed files with 178 additions and 40 deletions.
55 changes: 22 additions & 33 deletions distributed/scheduler.py
@@ -1034,6 +1034,8 @@ class TaskGroup:
#: subsequent tasks until a new worker is chosen.
last_worker_tasks_left: int

_has_restrictions: bool

prefix: TaskPrefix | None
start: float
stop: float
@@ -1052,6 +1054,7 @@ def __init__(self, name: str):
self.start = 0.0
self.stop = 0.0
self.all_durations = defaultdict(float)
self._has_restrictions = False
self.last_worker = None
self.last_worker_tasks_left = 0

@@ -1067,6 +1070,13 @@ def add_duration(self, action: str, start: float, stop: float) -> None:
self.prefix.add_duration(action, start, stop)

def add(self, other: TaskState) -> None:
if (
other.resource_restrictions
or other.worker_restrictions
or other.host_restrictions
or other.actor
):
self._has_restrictions = True
self.states[other.state] += 1
other.group = self

@@ -1096,6 +1106,15 @@ def _to_dict_no_nest(self, *, exclude: Container[str] = ()) -> dict[str, Any]:
"""
return recursive_to_dict(self, exclude=exclude, members=True)

@property
def rootish(self):
return (
not self._has_restrictions
and len(self) >= 5
and len(self.dependencies) < 5
and sum(map(len, self.dependencies)) < 5
)
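
The new property classifies a whole task group as root-ish from static information only: no per-task restrictions, at least five tasks in the group, fewer than five dependency groups, and fewer than five dependency tasks in total. A minimal standalone sketch of the same heuristic (the SimpleGroup class, its field names, and the example numbers are illustrative, not part of the scheduler API):

    from dataclasses import dataclass, field

    @dataclass
    class SimpleGroup:
        """Illustrative stand-in for TaskGroup; only what the heuristic needs."""

        n_tasks: int                                          # len(self) in the real property
        dependency_sizes: list[int] = field(default_factory=list)  # one entry per dependency group
        has_restrictions: bool = False                        # worker/host/resource/actor restrictions

        @property
        def rootish(self) -> bool:
            # Mirrors TaskGroup.rootish: a large group with few, small dependencies.
            return (
                not self.has_restrictions
                and self.n_tasks >= 5
                and len(self.dependency_sizes) < 5
                and sum(self.dependency_sizes) < 5
            )

    # A 100-task group fed by three tiny input groups is root-ish;
    # the same group with any restriction is not.
    assert SimpleGroup(100, [1, 1, 1]).rootish
    assert not SimpleGroup(100, [1, 1, 1], has_restrictions=True).rootish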


class TaskState:
"""A simple object holding information about a task.
Expand Down Expand Up @@ -2075,15 +2094,6 @@ def decide_worker_rootish_queuing_disabled(

return ws

def worker_objective_rootish_queuing(self, ws, ts):
# FIXME: This is basically the ordinary worker_objective but with task
# counts instead of occupancy.
comm_bytes = sum(
dts.get_nbytes() for dts in ts.dependencies if ws not in dts.who_has
)
# See test_nbytes_determines_worker
return (len(ws.processing) / ws.nthreads, comm_bytes, ws.nbytes)

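The helper removed above ranked candidate workers by a three-part objective (task-count load, bytes to transfer, stored bytes); the hunk further down replaces it with a single key, the ratio of processing tasks to threads. A rough sketch of that selection rule, using a made-up WorkerStats stand-in rather than the real WorkerState:

    from dataclasses import dataclass

    @dataclass
    class WorkerStats:
        """Illustrative stand-in for WorkerState; only the fields the rule needs."""

        name: str
        n_processing: int   # len(ws.processing) in the scheduler
        nthreads: int

    def pick_idle_worker(idle: list[WorkerStats]) -> WorkerStats:
        # Same shape as the new key function: least-loaded worker by tasks per thread.
        return min(idle, key=lambda ws: ws.n_processing / ws.nthreads)

    workers = [WorkerStats("a", 3, 2), WorkerStats("b", 1, 1), WorkerStats("c", 2, 4)]
    assert pick_idle_worker(workers).name == "c"   # 2/4 = 0.5 is the smallest ratio
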
def decide_worker_rootish_queuing_enabled(self, ts) -> WorkerState | None:
"""Pick a worker for a runnable root-ish task, if not all are busy.
@@ -2109,11 +2119,7 @@ def decide_worker_rootish_queuing_enabled(self, ts) -> WorkerState | None:
"""
if self.validate:
# We don't `assert self.is_rootish(ts)` here, because that check is
# dependent on cluster size. It's possible a task looked root-ish when it
# was queued, but the cluster has since scaled up and it no longer does when
# coming out of the queue. If `is_rootish` changes to a static definition,
# then add that assertion here (and actually pass in the task).
assert ts.group.rootish
assert not math.isinf(self.WORKER_SATURATION)

if not self.idle_task_count:
@@ -2124,7 +2130,7 @@ def decide_worker_rootish_queuing_enabled(self, ts) -> WorkerState | None:
# NOTE: this will lead to worst-case scheduling with regards to co-assignment.
ws = min(
self.idle_task_count,
key=partial(self.worker_objective_rootish_queuing, ts=ts),
key=lambda ws: len(ws.processing) / ws.nthreads,
)
if self.validate:
assert not _worker_full(ws, self.WORKER_SATURATION), (
Expand Down Expand Up @@ -2216,7 +2222,7 @@ def transition_waiting_processing(self, key: str, stimulus_id: str) -> RecsMsgs:
"""
ts = self.tasks[key]

if self.is_rootish(ts):
if ts.group.rootish:
# NOTE: having two root-ish methods is temporary. When the feature flag is
# removed, there should only be one, which combines co-assignment and
# queuing. Eventually, special-casing root tasks might be removed entirely,
@@ -2822,23 +2828,6 @@ def story(self, *keys_or_tasks_or_stimuli: str | TaskState) -> list[Transition]:
# Assigning Tasks to Workers #
##############################

def is_rootish(self, ts: TaskState) -> bool:
"""
Whether ``ts`` is a root or root-like task.
Root-ish tasks are part of a group that's much larger than the cluster,
and have few or no dependencies.
"""
if (
ts.resource_restrictions
or ts.worker_restrictions
or ts.host_restrictions
or ts.actor
):
return False
tg = ts.group
return len(tg.dependencies) < 5 and sum(map(len, tg.dependencies)) < 5

def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0) -> None:
"""Update the status of the idle and saturated state
154 changes: 154 additions & 0 deletions distributed/tests/test_rootish.py
@@ -0,0 +1,154 @@
from __future__ import annotations

from typing import Iterable

import pytest

import dask

from distributed.scheduler import TaskGroup, TaskState


@pytest.fixture()
def abcde():
return "abcde"


def f(*args):
return None


def dummy_dsk_to_taskstate(dsk: dict) -> tuple[list[TaskState], dict[str, TaskGroup]]:
task_groups: dict[str, TaskGroup] = {}
tasks = dict()
priority = dask.order.order(dsk)
for key in dsk:
tasks[key] = ts = TaskState(key, None, "released")
ts.group = task_groups.get(ts.group_key, TaskGroup(ts.group_key))
task_groups[ts.group_key] = ts.group
ts.group.add(ts)
ts.priority = priority[key]
for key, vals in dsk.items():
stack = list(vals[1:])
while stack:
d = stack.pop()
if isinstance(d, list):
stack.extend(d)
continue
assert isinstance(d, (str, tuple, int))
if d not in tasks:
raise ValueError(f"Malformed example. {d} not part of dsk")
tasks[key].add_dependency(tasks[d])
return sorted(tasks.values(), key=lambda ts: ts.priority), task_groups
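
The helper returns the TaskState objects in priority order together with a mapping from group key to TaskGroup, so a test can build a small graph and query rootish without running a scheduler. A brief usage sketch (the two-group graph is invented; it assumes f and dummy_dsk_to_taskstate above are in scope):

    # Ten independent "x" tasks feeding one "y" task: "x" is root-ish, "y" is not.
    dsk = {f"x-{i}": (f,) for i in range(10)}
    dsk["y-0"] = (f, *[f"x-{i}" for i in range(10)])

    tasks, groups = dummy_dsk_to_taskstate(dsk)
    assert [ts.key for ts in tasks]      # TaskStates, sorted by dask.order priority
    assert groups["x"].rootish           # 10 tasks, no dependency groups
    assert not groups["y"].rootish       # single task, large dependency group
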


def _to_keys(prefix: str, suffix: Iterable[str]) -> list[str]:
return list(prefix + "-" + i for i in suffix)


def test_tree_reduce(abcde):
a, b, c, _, _ = abcde
a_ = _to_keys(a, "123456789")
b_ = _to_keys(b, "1234")
dsk = {
a_[0]: (f,),
a_[1]: (f,),
a_[2]: (f,),
b_[0]: (f, a_[0], a_[1], a_[2]),
a_[3]: (f,),
a_[4]: (f,),
a_[5]: (f,),
b_[1]: (
f,
a_[6],
a_[7],
a_[8],
),
a_[6]: (f,),
a_[7]: (f,),
a_[8]: (f,),
b_[2]: (f, a_[6], a_[7], a_[8]),
c: (f, b_[0], b_[1], b_[2]),
}
_, groups = dummy_dsk_to_taskstate(dsk)
assert len(groups) == 3
assert len(groups["a"]) == 9
assert groups["a"].rootish
assert not groups["b"].rootish
assert not groups["c"].rootish


@pytest.mark.parametrize("num_Bs, BRootish", [(4, False), (5, True)])
def test_nearest_neighbor(abcde, num_Bs, BRootish):
r"""
a1 a2 a3 a4 a5 a6 a7 a8 a9
\ | / \ | / \ | / \ | /
b1 b2 b3 b4
"""
a, b, c, _, _ = abcde
a_ = _to_keys(a, "0123456789")
aa_ = _to_keys(a, ["10", "11", "12"])
b_ = _to_keys(b, "012345")

dsk = {
b_[1]: (f,),
b_[2]: (f,),
b_[3]: (f,),
b_[4]: (f,),
a_[1]: (f, b_[1]),
a_[2]: (f, b_[1]),
a_[3]: (f, b_[1], b_[2]),
a_[4]: (f, b_[2]),
a_[5]: (f, b_[2], b_[3]),
a_[6]: (f, b_[3]),
a_[7]: (f, b_[3], b_[4]),
a_[8]: (f, b_[4]),
a_[9]: (f, b_[4]),
}
if num_Bs == 5:
    dsk[b_[5]] = (f,)
    dsk[a_[9]] = (f, b_[4], b_[5])
    dsk[aa_[0]] = (f, b_[5])
    dsk[aa_[1]] = (f, b_[5])
_, groups = dummy_dsk_to_taskstate(dsk)
assert len(groups) == 2

if BRootish:
assert not groups["a"].rootish
assert groups["b"].rootish
else:
assert groups["a"].rootish
assert not groups["b"].rootish


@pytest.mark.parametrize("num_Bs, rootish", [(4, False), (5, True)])
def test_base_of_reduce_preferred(abcde, num_Bs, rootish):
r"""
a4
/|
a3 |
/| |
a2 | |
/| | |
a1 | | |
/| | | |
a0 | | | |
| | | | |
b0 b1 b2 b3 b4
\ \ / / /
c
"""
a, b, c, d, e = abcde
dsk = {(a, i): (f, (a, i - 1), (b, i)) for i in range(1, num_Bs + 1)}
dsk[(a, 0)] = (f, (b, 0))
dsk.update({(b, i): (f, c) for i in range(num_Bs + 1)})
dsk[c] = (f,)

_, groups = dummy_dsk_to_taskstate(dsk)
assert len(groups) == 3
assert not groups["a"].rootish
if rootish:
assert groups["b"].rootish
else:
assert not groups["b"].rootish
assert not groups["c"].rootish
4 changes: 2 additions & 2 deletions distributed/tests/test_scheduler.py
@@ -300,7 +300,7 @@ async def test_decide_worker_rootish_while_last_worker_is_retiring(c, s, a):
# - TaskGroup(y) has more than 4 tasks (total_nthreads * 2)
# - TaskGroup(y) has less than 5 dependency groups
# - TaskGroup(y) has less than 5 dependency tasks
assert s.is_rootish(s.tasks["y-2"])
assert s.tasks["y-2"].group.rootish

await evx[0].set()
await wait_for_state("y-0", "processing", s)
@@ -4276,7 +4276,7 @@ def submit_tasks():
def assert_rootish():
# Just to verify our assumptions in case the definition changes. This is
# currently a bit brittle
assert all(s.is_rootish(s.tasks[k]) for k in keys)
assert all(s.tasks[k].group.rootish for k in keys)

f1 = submit_tasks()
# Make sure that the worker is properly saturated
3 changes: 0 additions & 3 deletions distributed/tests/test_steal.py
@@ -593,9 +593,6 @@ async def test_dont_steal_executing_tasks_2(c, s, a, b):
assert not b.state.executing_count


@pytest.mark.skip(
reason="submitted tasks are root-ish. Stealing works very differently for root-ish tasks. If queued, stealing is disabled entirely"
)
@gen_cluster(
client=True,
nthreads=[("127.0.0.1", 1)] * 10,
2 changes: 0 additions & 2 deletions distributed/tests/test_worker_memory.py
@@ -504,7 +504,6 @@ def __sizeof__(self):
"distributed.worker.memory.target": False,
"distributed.worker.memory.spill": False,
"distributed.worker.memory.pause": False,
"distributed.scheduler.worker-saturation": "inf",
},
)
async def test_pause_executor_manual(c, s, a):
@@ -567,7 +566,6 @@ def f(ev):
"distributed.worker.memory.spill": False,
"distributed.worker.memory.pause": 0.8,
"distributed.worker.memory.monitor-interval": "10ms",
"distributed.scheduler.worker-saturation": "inf",
},
)
async def test_pause_executor_with_memory_monitor(c, s, a):
