Validate and debug state machine on handle_compute_task #6327

Merged (8 commits) on May 13, 2022

23 changes: 12 additions & 11 deletions distributed/scheduler.py
```diff
@@ -1841,7 +1841,7 @@ def transition_waiting_processing(self, key, stimulus_id):
             assert not ts.processing_on
             assert not ts.has_lost_dependencies
             assert ts not in self.unrunnable
-            assert all([dts.who_has for dts in ts.dependencies])
+            assert all(dts.who_has for dts in ts.dependencies)
 
         ws = self.decide_worker(ts)
         if ws is None:
@@ -7125,6 +7125,9 @@ def request_acquire_replicas(self, addr: str, keys: list, *, stimulus_id: str):
             ts = self.tasks[key]
             who_has[key] = {ws.address for ws in ts.who_has}
 
+        if self.validate:
+            assert all(who_has.values())
+
         self.stream_comms[addr].send(
             {
                 "op": "acquire-replicas",
@@ -7329,21 +7332,19 @@ def _task_to_msg(
         "priority": ts.priority,
         "duration": duration,
         "stimulus_id": f"compute-task-{time()}",
-        "who_has": {},
+        "who_has": {
+            dts.key: [ws.address for ws in dts.who_has] for dts in ts.dependencies
+        },
+        "nbytes": {dts.key: dts.nbytes for dts in ts.dependencies},
     }
+    if state.validate:
+        assert all(msg["who_has"].values())
+
    if ts.resource_restrictions:
        msg["resource_restrictions"] = ts.resource_restrictions
    if ts.actor:
        msg["actor"] = True
 
-    deps = ts.dependencies
-    if deps:
-        msg["who_has"] = {dts.key: [ws.address for ws in dts.who_has] for dts in deps}
-        msg["nbytes"] = {dts.key: dts.nbytes for dts in deps}
-
-        if state.validate:
-            assert all(msg["who_has"].values())
-
    if isinstance(ts.run_spec, dict):
        msg.update(ts.run_spec)
    else:
@@ -7480,7 +7481,7 @@ def validate_task_state(ts: TaskState) -> None:
     assert bool(ts.who_has) == (ts.state == "memory"), (ts, ts.who_has, ts.state)
 
     if ts.state == "processing":
-        assert all([dts.who_has for dts in ts.dependencies]), (
+        assert all(dts.who_has for dts in ts.dependencies), (
             "task processing without all deps",
             str(ts),
             str(ts.dependencies),
```
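Both `all([...])` to `all(...)` tweaks in this file swap a list comprehension for a generator expression, so `all()` can short-circuit at the first dependency without replicas instead of materializing the whole list first. A standalone illustration (the names below are illustrative, not distributed's API):

```python
def who_has(dep: set) -> set:
    # Stand-in for dts.who_has: the set of workers holding a replica.
    print("checked", dep)
    return dep

deps = [{"tcp://a:1234"}, set(), {"tcp://b:1234"}]  # second dep has no replicas

# Generator expression: all() stops at the first falsy element,
# so only the first two dependencies are ever checked.
assert not all(who_has(d) for d in deps)

# List comprehension: the full list is built before all() runs,
# so every dependency is checked.
assert not all([who_has(d) for d in deps])
```
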
2 changes: 1 addition & 1 deletion distributed/tests/test_cancelled_state.py
```diff
@@ -83,7 +83,7 @@ def f(ev):
     assert_story(
         a.story("f1"),
         [
-            ("f1", "compute-task"),
+            ("f1", "compute-task", "released"),
             ("f1", "released", "waiting", "waiting", {"f1": "ready"}),
             ("f1", "waiting", "ready", "ready", {"f1": "executing"}),
             ("f1", "ready", "executing", "executing", {}),
```
2 changes: 1 addition & 1 deletion distributed/tests/test_stories.py
```diff
@@ -142,7 +142,7 @@ async def test_worker_story_with_deps(c, s, a, b):
 
     # This is a simple transition log
     expected = [
-        ("res", "compute-task"),
+        ("res", "compute-task", "released"),
         ("res", "released", "waiting", "waiting", {"dep": "fetch"}),
         ("res", "waiting", "ready", "ready", {"res": "executing"}),
         ("res", "ready", "executing", "executing", {}),
```
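Both story fixtures track the worker-side change later in this diff: the `compute-task` log entry now records the task's state at the moment the stimulus arrives (here `released`). The shape of the new log entry, with illustrative values for the stimulus id and timestamp:

```python
# (key, event, state on arrival, stimulus_id, wall-clock time)
("res", "compute-task", "released", "compute-task-1652440000.0", 1652440000.1)
```
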
41 changes: 14 additions & 27 deletions distributed/tests/test_stress.py
```diff
@@ -12,7 +12,6 @@
 from distributed import Client, Nanny, wait
 from distributed.chaos import KillWorker
 from distributed.compatibility import WINDOWS
-from distributed.config import config
 from distributed.metrics import time
 from distributed.utils import CancelledError
 from distributed.utils_test import (
@@ -121,58 +120,46 @@ async def create_and_destroy_worker(delay):
     assert await c.compute(z) == 8000884.93
 
 
-@gen_cluster(nthreads=[("127.0.0.1", 1)] * 10, client=True, timeout=60)
+@gen_cluster(nthreads=[("", 1)] * 10, client=True)
```
crusaderky (author) commented:

> This test is functionally identical to before - all changes are just cosmetic.

```diff
 async def test_stress_scatter_death(c, s, *workers):
     import random
 
     s.allowed_failures = 1000
     np = pytest.importorskip("numpy")
-    L = await c.scatter([np.random.random(10000) for i in range(len(workers))])
+    L = await c.scatter(
+        {f"scatter-{i}": np.random.random(10000) for i in range(len(workers))}
+    )
+    L = list(L.values())
     await c.replicate(L, n=2)
 
     adds = [
-        delayed(slowadd, pure=True)(
+        delayed(slowadd)(
             random.choice(L),
             random.choice(L),
             delay=0.05,
-            dask_key_name="slowadd-1-%d" % i,
+            dask_key_name=f"slowadd-1-{i}",
         )
         for i in range(50)
     ]
 
     adds = [
-        delayed(slowadd, pure=True)(a, b, delay=0.02, dask_key_name="slowadd-2-%d" % i)
+        delayed(slowadd)(a, b, delay=0.02, dask_key_name=f"slowadd-2-{i}")
         for i, (a, b) in enumerate(sliding_window(2, adds))
     ]
 
     futures = c.compute(adds)
-    L = adds = None
-
-    alive = list(workers)
-
-    from distributed.scheduler import logger
+    del L
+    del adds
 
-    for i in range(7):
+    for w in random.sample(workers, 7):
+        s.validate_state()
+        for w2 in workers:
+            w2.validate_state()
+
         await asyncio.sleep(0.1)
-        try:
-            s.validate_state()
-        except Exception as c:
-            logger.exception(c)
-            if config.get("log-on-err"):
-                import pdb
-
-                pdb.set_trace()
-            else:
-                raise
-        w = random.choice(alive)
         await w.close()
-        alive.remove(w)
 
     with suppress(CancelledError):
         await c.gather(futures)
 
-    futures = None
-
 
 def vsum(*args):
     return sum(args)
```
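The scatter change above passes a dict: `Client.scatter` then uses the dict keys verbatim as future keys (a list gets content-hashed names), which gives the test deterministic task names. A minimal sketch, assuming a connected async client `c`:

```python
# Inside an async test body with a connected Client `c`:
futs = await c.scatter({f"scatter-{i}": i * 10 for i in range(3)})
assert set(futs) == {"scatter-0", "scatter-1", "scatter-2"}
L = list(futs.values())  # the list-of-futures shape the rest of the test uses
```
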
4 changes: 2 additions & 2 deletions distributed/tests/test_worker.py
```diff
@@ -3157,7 +3157,7 @@ async def test_task_flight_compute_oserror(c, s, a, b):
 
     sum_story = b.story("f1")
     expected_sum_story = [
-        ("f1", "compute-task"),
+        ("f1", "compute-task", "released"),
         (
             "f1",
             "released",
@@ -3174,7 +3174,7 @@ async def test_task_flight_compute_oserror(c, s, a, b):
         ("f1", "waiting", "released", "released", lambda msg: msg["f1"] == "forgotten"),
         ("f1", "released", "forgotten", "forgotten", {}),
         # Now, we actually compute the task *once*. This must not cycle back
-        ("f1", "compute-task"),
+        ("f1", "compute-task", "released"),
         ("f1", "released", "waiting", "waiting", {"f1": "ready"}),
         ("f1", "waiting", "ready", "ready", {"f1": "executing"}),
         ("f1", "ready", "executing", "executing", {}),
```
6 changes: 3 additions & 3 deletions distributed/tests/test_worker_state_machine.py
```diff
@@ -268,17 +268,17 @@ async def test_fetch_to_compute(c, s, a, b):
         # FIXME: This log should be replaced with an
         # StateMachineEvent/Instruction log
         [
-            (f2.key, "compute-task"),
+            (f2.key, "compute-task", "released"),
             # This is a "please fetch" request. We don't have anything like
             # this, yet. We don't see the request-dep signal in here because we
             # do not wait for the key to be actually scheduled
             (f1.key, "ensure-task-exists", "released"),
             # After the worker failed, we're instructed to forget f2 before
             # something new comes in
             ("free-keys", (f2.key,)),
-            (f1.key, "compute-task"),
+            (f1.key, "compute-task", "released"),
             (f1.key, "put-in-memory"),
-            (f2.key, "compute-task"),
+            (f2.key, "compute-task", "released"),
         ],
     )
```
99 changes: 58 additions & 41 deletions distributed/worker.py
```diff
@@ -1855,6 +1855,10 @@ def handle_acquire_replicas(
         who_has: dict[str, Collection[str]],
         stimulus_id: str,
     ) -> None:
+        if self.validate:
+            assert set(keys) == who_has.keys()
+            assert all(who_has.values())
+
         recommendations: Recs = {}
         for key in keys:
             ts = self.ensure_task_exists(
@@ -1872,6 +1876,10 @@ def handle_acquire_replicas(
         self.update_who_has(who_has)
         self.transitions(recommendations, stimulus_id=stimulus_id)
 
+        if self.validate:
+            for key in keys:
+                assert self.tasks[key].state != "released", self.story(key)
+
     def ensure_task_exists(
         self, key: str, *, priority: tuple[int, ...], stimulus_id: str
     ) -> TaskState:
@@ -1892,19 +1900,18 @@ def handle_compute_task(
         *,
         key: str,
         who_has: dict[str, Collection[str]],
+        nbytes: dict[str, int],
         priority: tuple[int, ...],
         duration: float,
         function=None,
         args=None,
         kwargs=None,
         task=no_value,  # distributed.scheduler.TaskState.run_spec
-        nbytes: dict[str, int] | None = None,
         resource_restrictions: dict[str, float] | None = None,
         actor: bool = False,
         annotations: dict | None = None,
         stimulus_id: str,
     ) -> None:
-        self.log.append((key, "compute-task", stimulus_id, time()))
         try:
             ts = self.tasks[key]
             logger.debug(
@@ -1913,47 +1920,14 @@ def handle_compute_task(
             )
         except KeyError:
             self.tasks[key] = ts = TaskState(key)
-
-        ts.run_spec = SerializedTask(function, args, kwargs, task)
-
-        assert isinstance(priority, tuple)
-        priority = priority + (self.generation,)
-        self.generation -= 1
-
-        if actor:
-            self.actors[ts.key] = None
-
-        ts.exception = None
-        ts.traceback = None
-        ts.exception_text = ""
-        ts.traceback_text = ""
-        ts.priority = priority
-        ts.duration = duration
-        if resource_restrictions:
-            ts.resource_restrictions = resource_restrictions
-        ts.annotations = annotations
+        self.log.append((key, "compute-task", ts.state, stimulus_id, time()))
 
         recommendations: Recs = {}
         instructions: Instructions = []
-        for dependency in who_has:
-            dep_ts = self.ensure_task_exists(
-                key=dependency,
-                priority=priority,
-                stimulus_id=stimulus_id,
-            )
-
-            # link up to child / parents
-            ts.dependencies.add(dep_ts)
-            dep_ts.dependents.add(ts)
-
-        if nbytes is not None:
-            for key, value in nbytes.items():
-                self.tasks[key].nbytes = value
-
-        if ts.state in READY | {"executing", "waiting", "resumed"}:
+        if ts.state in READY | {"executing", "long-running", "waiting", "resumed"}:
```
crusaderky (author) commented:

> I omitted a unit test for this - something to write after the state machine refactor for sure.

```diff
             pass
         elif ts.state == "memory":
             recommendations[ts] = "memory"
             instructions.append(
                 self._get_task_finished_msg(ts, stimulus_id=stimulus_id)
             )
@@ -1966,12 +1940,56 @@ def handle_compute_task(
             "error",
         }:
             recommendations[ts] = "waiting"
-        else:  # pragma: no cover
+
+            ts.run_spec = SerializedTask(function, args, kwargs, task)
+
+            assert isinstance(priority, tuple)
+            priority = priority + (self.generation,)
+            self.generation -= 1
+
+            if actor:
+                self.actors[ts.key] = None
+
+            ts.exception = None
+            ts.traceback = None
+            ts.exception_text = ""
+            ts.traceback_text = ""
+            ts.priority = priority
+            ts.duration = duration
+            if resource_restrictions:
+                ts.resource_restrictions = resource_restrictions
+            ts.annotations = annotations
+
+            if self.validate:
+                assert who_has.keys() == nbytes.keys()
+                assert all(who_has.values())
+
+            for dep_key, dep_workers in who_has.items():
+                dep_ts = self.ensure_task_exists(
+                    key=dep_key,
+                    priority=priority,
+                    stimulus_id=stimulus_id,
+                )
+                # link up to child / parents
+                ts.dependencies.add(dep_ts)
+                dep_ts.dependents.add(ts)
+
+            for dep_key, value in nbytes.items():
+                self.tasks[dep_key].nbytes = value
+
+            self.update_who_has(who_has)
```
crusaderky (author) commented:

> This move prevents deps from being created in resumed state by ensure_task_exists and then remaining there because nothing actually needs them.

```diff
+        else:  # pragma: nocover
             raise RuntimeError(f"Unexpected task state encountered {ts} {stimulus_id}")
 
-        self._handle_instructions(instructions)
-        self.update_who_has(who_has)
         self.transitions(recommendations, stimulus_id=stimulus_id)
+        self._handle_instructions(instructions)
+
+        if self.validate:
+            # All previously unknown tasks that were created above by
+            # ensure_tasks_exists() have been transitioned to fetch or flight
+            assert all(
+                ts2.state != "released" for ts2 in (ts, *ts.dependencies)
+            ), self.story(ts, *ts.dependencies)
```
crusaderky (author) commented on May 13, 2022:

> At the moment of writing, this assertion fails in test_stress_scatter_death 0.4% of the time on a fast desktop. Explanation in #6305. Resolution is out of scope for this PR.

A member commented:

> FYI, this assert is not 100% correct. There is a case for valid tasks left in released in the case of cancelled/resumed tasks. I'll open a follow-up PR with a case reproducing this condition.

fjetter (member) commented on May 18, 2022:

```mermaid
flowchart TD
  A1[A1 - forgotten / not known] --> B1[B1 - flight]
  A2[A2 - forgotten / not known] --> B1[B1 - flight]
  B1 --> C1[C1 - waiting]
```

free-keys / cancel B1:

```mermaid
flowchart TD
  A1[A1 - forgotten / not known] --> B1[B1 - cancelled]
  A2[A2 - forgotten / not known] --> B1[B1 - cancelled]
  B1 --> C1[C1 - forgotten]
```

compute-task B1:

```mermaid
flowchart TD
  A1[A1 - released] --> B1[B1 - resumed]
  A2[A2 - released] --> B1[B1 - resumed]
  B1 --> C1[C1 - forgotten]
```

gather-dep finishes w/ Error:

```mermaid
flowchart TD
  A1[A1 - fetch] --> B1[B1 - waiting]
  A2[A2 - fetch] --> B1[B1 - waiting]
  B1 --> C1[C1 - forgotten]
```

crusaderky (author) commented:

> I don't think I understand from the above diagrams how this can end up with a released state by the end of compute-task?


```diff
 ########################
 # Worker State Machine #
@@ -3436,7 +3454,6 @@ async def find_missing(self) -> None:
                 self.scheduler.who_has,
                 keys=[ts.key for ts in self._missing_dep_flight],
             )
-            who_has = {k: v for k, v in who_has.items() if v}
             self.update_who_has(who_has)
             recommendations: Recs = {}
             for ts in self._missing_dep_flight:
```

crusaderky (author) commented on the deleted line:

> Redundant - update_who_has already throws away empty lists of workers.
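The redundancy noted above is easy to model: if `update_who_has` already skips entries whose worker list is empty, pre-filtering changes nothing. A minimal sketch under that assumption (the function below is a hypothetical stand-in, not the real worker method):

```python
def update_who_has(who_has: dict[str, list[str]]) -> None:
    # Hypothetical stand-in for Worker.update_who_has: entries with an
    # empty worker list are discarded here, so callers need not pre-filter.
    for key, workers in who_has.items():
        if not workers:
            continue
        print(f"{key} has replicas on {workers}")

who_has = {"x": ["tcp://10.0.0.1:40000"], "y": []}
update_who_has(who_has)                                  # "y" is ignored
update_who_has({k: v for k, v in who_has.items() if v})  # deleted pre-filter: same output
```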