Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache root-ish-ness for consistency #7262

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
59 changes: 50 additions & 9 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1311,6 +1311,10 @@ class TaskState:
#: Cached hash of :attr:`~TaskState.client_key`
_hash: int

#: Cached while tasks are in `queued` or `no-worker`; set in
#: `transition_waiting_processing` and `_add_to_processing`
_rootish: bool | None

# Support for weakrefs to a class with __slots__
__weakref__: Any = None
__slots__ = tuple(__annotations__)
Expand Down Expand Up @@ -1352,6 +1356,7 @@ def __init__(self, key: str, run_spec: object, state: TaskStateState):
self.metadata = {}
self.annotations = {}
self.erred_on = set()
self._rootish = None
TaskState._instances.add(self)

def __hash__(self) -> int:
Expand Down Expand Up @@ -1511,10 +1516,14 @@ class SchedulerState:
#: All tasks currently known to the scheduler
tasks: dict[str, TaskState]

#: Tasks in the "queued" state, ordered by priority
#: Tasks in the "queued" state, ordered by priority.
#: They are all root-ish.
#: Always empty if `worker-saturation` is set to `inf`.
queued: HeapSet[TaskState]

#: Tasks in the "no-worker" state
#: Tasks in the "no-worker" state.
#: They may or may not have restrictions.
#: Only contains root-ish tasks if `worker-saturation` is set to `inf`.
unrunnable: set[TaskState]

#: Subset of tasks that exist in memory on more than one worker
Expand Down Expand Up @@ -2014,11 +2023,19 @@ def transition_no_worker_processing(self, key, stimulus_id):
assert not ts.actor, f"Actors can't be in `no-worker`: {ts}"
assert ts in self.unrunnable

if ws := self.decide_worker_non_rootish(ts):
decide_worker = (
self.decide_worker_rootish_queuing_disabled
if self.is_rootish(ts)
else self.decide_worker_non_rootish
)
if ws := decide_worker(ts):
self.unrunnable.discard(ts)
worker_msgs = _add_to_processing(self, ts, ws)
# If no worker, task just stays in `no-worker`

if self.validate and self.is_rootish(ts):
assert ws is not None

return recommendations, client_msgs, worker_msgs
except Exception as e:
logger.exception(e)
Expand Down Expand Up @@ -2052,8 +2069,8 @@ def decide_worker_rootish_queuing_disabled(
``no-worker``.
"""
if self.validate:
# See root-ish-ness note below in `decide_worker_rootish_queuing_enabled`
assert math.isinf(self.WORKER_SATURATION)
assert self.is_rootish(ts)

pool = self.idle.values() if self.idle else self.running
if not pool:
Expand Down Expand Up @@ -2113,11 +2130,6 @@ def decide_worker_rootish_queuing_enabled(self) -> WorkerState | None:

"""
if self.validate:
# We don't `assert self.is_rootish(ts)` here, because that check is dependent on
# cluster size. It's possible a task looked root-ish when it was queued, but the
# cluster has since scaled up and it no longer does when coming out of the queue.
# If `is_rootish` changes to a static definition, then add that assertion here
# (and actually pass in the task).
assert not math.isinf(self.WORKER_SATURATION)

if not self.idle:
Expand Down Expand Up @@ -2154,6 +2166,9 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None:
``ts`` or there are no running workers, returns None, in which case the task
should be transitioned to ``no-worker``.
"""
if self.validate:
assert not self.is_rootish(ts)

if not self.running:
return None

Expand Down Expand Up @@ -2222,13 +2237,15 @@ def transition_waiting_processing(self, key, stimulus_id):
# NOTE: having two root-ish methods is temporary. When the feature flag is removed,
# there should only be one, which combines co-assignment and queuing.
# Eventually, special-casing root tasks might be removed entirely, with better heuristics.
ts._rootish = True # cached until `processing`
if math.isinf(self.WORKER_SATURATION):
if not (ws := self.decide_worker_rootish_queuing_disabled(ts)):
return {ts.key: "no-worker"}, {}, {}
else:
if not (ws := self.decide_worker_rootish_queuing_enabled()):
return {ts.key: "queued"}, {}, {}
else:
ts._rootish = False # cached until `processing`
if not (ws := self.decide_worker_non_rootish(ts)):
return {ts.key: "no-worker"}, {}, {}

Expand Down Expand Up @@ -2988,6 +3005,15 @@ def is_rootish(self, ts: TaskState) -> bool:
Root-ish tasks are part of a group that's much larger than the cluster,
and have few or no dependencies.
"""
# NOTE: the result of `is_rootish` is cached in `waiting->processing`, and
# invalidated when entering `processing`. This is for the benefit of the
# `queued` and `no-worker` states. We cache `is_rootish` not for
# performance, but so it can't change when the `TaskGroup` or cluster size changes.
# That avoids annoying edge cases where a task does/doesn't look root-ish when it
# goes into `queued` or `unrunnable`, but that's flipped when it comes out.
if (cached := ts._rootish) is not None:
return cached

if ts.resource_restrictions or ts.worker_restrictions or ts.host_restrictions:
return False
tg = ts.group
Expand Down Expand Up @@ -4892,6 +4918,7 @@ def validate_released(self, key):
assert not any([ts in dts.waiters for dts in ts.dependencies])
assert ts not in self.unrunnable
assert ts not in self.queued
assert ts._rootish is None, ts._rootish

def validate_waiting(self, key):
ts: TaskState = self.tasks[key]
Expand All @@ -4900,6 +4927,7 @@ def validate_waiting(self, key):
assert not ts.processing_on
assert ts not in self.unrunnable
assert ts not in self.queued
assert ts._rootish is None, ts._rootish
for dts in ts.dependencies:
# We are waiting on a dependency iff it's not stored
assert bool(dts.who_has) != (dts in ts.waiting_on)
Expand All @@ -4912,6 +4940,7 @@ def validate_queued(self, key):
assert not ts.waiting_on
assert not ts.who_has
assert not ts.processing_on
assert ts._rootish is True, ts._rootish
assert not (
ts.worker_restrictions or ts.host_restrictions or ts.resource_restrictions
)
Expand All @@ -4928,6 +4957,7 @@ def validate_processing(self, key):
assert ts in ws.processing
assert not ts.who_has
assert ts not in self.queued
assert ts._rootish is None, ts._rootish
for dts in ts.dependencies:
assert dts.who_has
assert ts in dts.waiters
Expand All @@ -4941,6 +4971,7 @@ def validate_memory(self, key):
assert not ts.waiting_on
assert ts not in self.unrunnable
assert ts not in self.queued
assert ts._rootish is None, ts._rootish
for dts in ts.dependents:
assert (dts in ts.waiters) == (
dts.state in ("waiting", "queued", "processing", "no-worker")
Expand All @@ -4955,6 +4986,7 @@ def validate_no_worker(self, key):
assert not ts.processing_on
assert not ts.who_has
assert ts not in self.queued
assert ts._rootish is not None, ts._rootish
for dts in ts.dependencies:
assert dts.who_has

Expand All @@ -4963,6 +4995,7 @@ def validate_erred(self, key):
assert ts.exception_blame
assert not ts.who_has
assert ts not in self.queued
assert ts._rootish is None, ts._rootish

def validate_key(self, key, ts: TaskState | None = None):
try:
Expand Down Expand Up @@ -7052,6 +7085,8 @@ def get_metadata(self, keys: list[str], default=no_default):
def set_restrictions(self, worker: dict[str, Collection[str] | str]):
for key, restrictions in worker.items():
ts = self.tasks[key]
if ts._rootish is not None:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like that set_restrictions is a public API at all. Doesn't seem like something you should be able to do post-hoc.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fails test_reschedule_concurrent_requests_deadlock, which sets restrictions on a processing task.

@gen_cluster(
    client=True,
    nthreads=[("", 1)] * 3,
    config={"distributed.scheduler.work-stealing-interval": 1_000_000},
)
async def test_reschedule_concurrent_requests_deadlock(c, s, *workers):
    """Regression test for https://github.com/dask/distributed/issues/5370.

    A steal request A -> B is issued for a task, the task is then restricted to
    B and rescheduled, and finally a second steal request B -> C is issued.
    The task must not end up only on C, and the stealing story must match one
    of the known outcomes.
    """
    stealing = s.extensions["stealing"]
    first_worker = workers[0]
    blocker = Event()
    futures = c.map(
        lambda _, ev: ev.wait(),
        range(10),
        ev=blocker,
        key=[f"f1-{ix}" for ix in range(10)],
        workers=[first_worker.address],
        allow_other_workers=True,
    )

    # Wait until the first worker has picked up at least one task
    while not first_worker.active_keys:
        await asyncio.sleep(0.01)

    # ready is a heap but we don't need last, just not the next
    victim_key = next(iter(first_worker.active_keys))
    victim_ts = s.tasks[victim_key]
    ws_a = victim_ts.processing_on
    ws_b, ws_c = [ws for ws in s.workers.values() if ws != ws_a][:2]

    stealing.move_task_request(victim_ts, ws_a, ws_b)
    s.set_restrictions(worker={victim_key: [ws_b.address]})
    s._reschedule(victim_key, stimulus_id="test")
    assert ws_b == victim_ts.processing_on

    # move_task_request is not responsible for respecting worker restrictions
    stealing.move_task_request(victim_ts, ws_b, ws_c)

    # Let tasks finish
    await blocker.set()
    await c.gather(futures)
    assert victim_ts.who_has != {ws_c}

    # Drop the trailing random ID from every story entry
    story = [entry[:-1] for entry in stealing.story(victim_ts)]

    already_computing = (
        "already-computing",
        victim_key,
        "executing",
        ws_b.address,
        ws_c.address,
    )
    # There are three possible outcomes
    possible_stories = (
        [
            ("stale-response", victim_key, "executing", ws_a.address),
            already_computing,
        ],
        [
            already_computing,
            ("already-aborted", victim_key, "executing", ws_a.address),
        ],
        # This outcome appears only in ~2% of the runs
        [
            already_computing,
            ("already-aborted", victim_key, "memory", ws_a.address),
        ],
    )
    assert story in possible_stories

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like that set_restrictions is a public API at all. Doesn't seem like something you should be able to do post-hoc.

We can change it. First step is a deprecation warning

raise ValueError(f"cannot set restrictions on ready {ts}")
if isinstance(restrictions, str):
restrictions = {restrictions}
ts.worker_restrictions = set(restrictions)
Expand Down Expand Up @@ -7774,6 +7809,7 @@ def _validate_ready(state: SchedulerState, ts: TaskState) -> None:
assert ts not in state.unrunnable
assert ts not in state.queued
assert all(dts.who_has for dts in ts.dependencies)
assert ts._rootish is not None, ts._rootish


def _add_to_processing(
Expand All @@ -7785,6 +7821,7 @@ def _add_to_processing(
assert ws in state.running, state.running
assert (o := state.workers.get(ws.address)) is ws, (ws, o)

ts._rootish = None
ws.add_to_processing(ts)
ts.processing_on = ws
ts.state = "processing"
Expand Down Expand Up @@ -7812,6 +7849,9 @@ def _exit_processing_common(
--------
Scheduler._set_duration_estimate
"""
if state.validate:
assert ts._rootish is None, ts._rootish

ws = ts.processing_on
assert ws
ts.processing_on = None
Expand Down Expand Up @@ -7913,6 +7953,7 @@ def _propagate_released(
recommendations: Recs,
) -> None:
ts.state = "released"
ts._rootish = None
key = ts.key

if ts.has_lost_dependencies:
Expand Down
79 changes: 79 additions & 0 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pickle
import re
import sys
from contextlib import AsyncExitStack
from itertools import product
from textwrap import dedent
from time import sleep
Expand Down Expand Up @@ -481,6 +482,84 @@ async def test_queued_remove_add_worker(c, s, a, b):
await wait(fs)


@gen_cluster(
    client=True,
    nthreads=[("", 2)] * 2,
    config={
        "distributed.worker.memory.pause": False,
        "distributed.worker.memory.target": False,
        "distributed.worker.memory.spill": False,
        "distributed.scheduler.work-stealing": False,
    },
)
async def test_queued_rootish_changes_while_paused(c, s, a, b):
    """Submit a mix of root-ish and non-root-ish tasks while all workers are paused.

    Some tasks are root-ish, some aren't, so both `unrunnable` and `queued`
    contain non-restricted tasks. After un-pausing, everything must complete.
    """
    root = c.submit(inc, 1, key="root")
    await root

    # Pause both workers by hand and wait for the scheduler to notice
    for worker in (a, b):
        worker.status = Status.paused
    await async_wait_for(lambda: not s.running, 5)

    # Submit one task at a time so the first ones don't look root-ish
    # (`TaskGroup` too small) and only the last one does: N-1 tasks go to
    # `no-worker` and the last to `queued`. `is_rootish` is just messed up
    # like that.
    futures = []
    for i in range(s.total_nthreads * 2 + 1):
        futures.append(c.submit(inc, root, key=f"inc-{i}"))

    await async_wait_for(lambda: len(s.tasks) > len(futures), 5)

    # Resume both workers and wait until they are all running again
    for worker in (a, b):
        worker.status = Status.running
    await async_wait_for(lambda: len(s.running) == len(s.workers), 5)

    await c.gather(futures)


@gen_cluster(
    client=True,
    nthreads=[("", 1)],
    config={"distributed.scheduler.work-stealing": False},
)
async def test_queued_rootish_changes_scale_up(c, s, a):
    """Scale the cluster up while root-ish tasks are waiting behind a blocked task.

    Tasks are initially root-ish. After the cluster scales, they no longer meet
    the definition, but must still be reported as root-ish.
    """
    root = c.submit(inc, 1, key="root")

    gate = Event()
    clog = c.submit(gate.wait, key="clog")
    await wait_for_state(clog.key, "processing", s)

    futures = c.map(inc, [root] * 5, key=[f"inc-{i}" for i in range(5)])
    await async_wait_for(lambda: len(s.tasks) > len(futures), 5)

    def first_is_rootish():
        # Root-ish-ness of the first mapped task, as the scheduler reports it now
        return s.is_rootish(s.tasks[futures[0].key])

    if not first_is_rootish():
        pytest.fail(
            "Test assumptions have changed; task is not root-ish. Test may no longer be relevant."
        )
    if math.isfinite(s.WORKER_SATURATION):
        assert s.queued

    # NOTE(review): assumes everything below runs while the extra workers are
    # still alive, i.e. inside the `async with` — TODO confirm nesting.
    async with AsyncExitStack() as stack:
        # Add three more workers so the task group no longer dwarfs the cluster
        for _ in range(3):
            await stack.enter_async_context(Worker(s.address, nthreads=2))

        if not first_is_rootish():
            pytest.fail(
                "Test assumptions have changed; root-ish-ness has flipped. Test may no longer be relevant."
            )

        await gate.set()
        await clog

        # Just verify it doesn't deadlock
        await c.gather(futures)


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_secede_opens_slot(c, s, a):
first = Event()
Expand Down