Skip to content

Commit

Permalink
Don't invoke log_event from state machine (#6512)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky authored Jun 7, 2022
1 parent febd41d commit bde90af
Show file tree
Hide file tree
Showing 9 changed files with 146 additions and 142 deletions.
15 changes: 7 additions & 8 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from functools import partial
from numbers import Number
from queue import Queue as pyQueue
from typing import ClassVar, Literal
from typing import Any, ClassVar, Literal

from tlz import first, groupby, keymap, merge, partition_all, valmap

Expand All @@ -51,6 +51,7 @@
from tornado import gen
from tornado.ioloop import PeriodicCallback

import distributed.utils
from distributed import cluster_dump, preloading
from distributed import versions as version_module
from distributed.batched import BatchedSend
Expand Down Expand Up @@ -80,8 +81,6 @@
from distributed.sizeof import sizeof
from distributed.threadpoolexecutor import rejoin
from distributed.utils import (
All,
Any,
CancelledError,
LoopRunner,
NoOpAwaitable,
Expand Down Expand Up @@ -2028,7 +2027,7 @@ async def wait(k):
logger.debug("Waiting on futures to clear before gather")

with suppress(AllExit):
await All(
await distributed.utils.All(
[wait(key) for key in keys if key in self.futures],
quiet_exceptions=AllExit,
)
Expand Down Expand Up @@ -4053,12 +4052,12 @@ def benchmark_hardware(self) -> dict:
"""
return self.sync(self.scheduler.benchmark_hardware)

def log_event(self, topic, msg):
def log_event(self, topic: str | Collection[str], msg: Any):
"""Log an event under a given topic
Parameters
----------
topic : str, list
topic : str, list[str]
Name of the topic under which to log an event. To log the same
event under multiple topics, pass a list of topic names.
msg
Expand Down Expand Up @@ -4648,9 +4647,9 @@ async def _wait(fs, timeout=None, return_when=ALL_COMPLETED):
)
fs = futures_of(fs)
if return_when == ALL_COMPLETED:
wait_for = All
wait_for = distributed.utils.All
elif return_when == FIRST_COMPLETED:
wait_for = Any
wait_for = distributed.utils.Any
else:
raise NotImplementedError(
"Only return_when='ALL_COMPLETED' and 'FIRST_COMPLETED' are supported"
Expand Down
28 changes: 16 additions & 12 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3070,7 +3070,7 @@ def __init__(
"get_logs": self.get_logs,
"logs": self.get_logs,
"worker_logs": self.get_worker_logs,
"log_event": self.log_worker_event,
"log_event": self.log_event,
"events": self.get_events,
"nbytes": self.get_nbytes,
"versions": self.versions,
Expand Down Expand Up @@ -6223,7 +6223,11 @@ async def feed(
if teardown:
teardown(self, state)

def log_worker_event(self, worker=None, topic=None, msg=None):
def log_worker_event(
self, worker: str, topic: str | Collection[str], msg: Any
) -> None:
if isinstance(msg, dict):
msg["worker"] = worker
self.log_event(topic, msg)

def subscribe_worker_status(self, comm=None):
Expand Down Expand Up @@ -6905,21 +6909,21 @@ async def get_worker_logs(self, n=None, workers=None, nanny=False):
)
return results

def log_event(self, name, msg):
def log_event(self, topic: str | Collection[str], msg: Any) -> None:
event = (time(), msg)
if isinstance(name, (list, tuple)):
for n in name:
self.events[n].append(event)
self.event_counts[n] += 1
self._report_event(n, event)
if not isinstance(topic, str):
for t in topic:
self.events[t].append(event)
self.event_counts[t] += 1
self._report_event(t, event)
else:
self.events[name].append(event)
self.event_counts[name] += 1
self._report_event(name, event)
self.events[topic].append(event)
self.event_counts[topic] += 1
self._report_event(topic, event)

for plugin in list(self.plugins.values()):
try:
plugin.log_event(name, msg)
plugin.log_event(topic, msg)
except Exception:
logger.info("Plugin failed with exception", exc_info=True)

Expand Down
10 changes: 5 additions & 5 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6603,8 +6603,8 @@ def setup(self, worker=None):
await c.register_worker_plugin(MyPlugin())


@gen_cluster(client=True)
async def test_log_event(c, s, a, b):
@gen_cluster(client=True, nthreads=[("", 1)])
async def test_log_event(c, s, a):

# Log an event from inside a task
def foo():
Expand All @@ -6614,7 +6614,7 @@ def foo():
await c.submit(foo)
events = await c.get_events("topic1")
assert len(events) == 1
assert events[0][1] == {"foo": "bar"}
assert events[0][1] == {"foo": "bar", "worker": a.address}

# Log an event while on the scheduler
def log_scheduler(dask_scheduler):
Expand Down Expand Up @@ -7135,7 +7135,7 @@ def user_event_handler(event):

time_, msg = log[0]
assert isinstance(time_, float)
assert msg == {"important": "event"}
assert msg == {"important": "event", "worker": a.address}

c.unsubscribe_topic("test-topic")

Expand Down Expand Up @@ -7166,7 +7166,7 @@ async def async_user_event_handler(event):
assert len(log) == 2
time_, msg = log[1]
assert isinstance(time_, float)
assert msg == {"async": "event"}
assert msg == {"async": "event", "worker": a.address}

# Even though the middle event was not subscribed to, the scheduler still
# knows about all and we can retrieve them
Expand Down
41 changes: 25 additions & 16 deletions distributed/tests/test_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@
raises_with_cause,
tls_only_security,
)
from distributed.worker import InvalidTransition, fail_hard
from distributed.worker import fail_hard
from distributed.worker_state_machine import (
InvalidTaskState,
InvalidTransition,
StateMachineEvent,
)


def test_bare_cluster(loop):
Expand Down Expand Up @@ -645,18 +650,22 @@ def test_start_failure_scheduler():


def test_invalid_transitions(capsys):
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)])
class BrokenEvent(StateMachineEvent):
pass

class MyWorker(Worker):
@Worker._handle_event.register
def _(self, ev: BrokenEvent):
ts = next(iter(self.tasks.values()))
return {ts: "foo"}, []

@gen_cluster(client=True, Worker=MyWorker, nthreads=[("", 1)])
async def test_log_invalid_transitions(c, s, a):
x = c.submit(inc, 1, key="task-name")
y = c.submit(inc, x)
xkey = x.key
del x
await y
while a.tasks[xkey].state != "released":
await asyncio.sleep(0.01)
ts = a.tasks[xkey]
await x

with pytest.raises(InvalidTransition):
a._transition(ts, "foo", stimulus_id="bar")
a.handle_stimulus(BrokenEvent(stimulus_id="test"))

while not s.events["invalid-worker-transition"]:
await asyncio.sleep(0.01)
Expand All @@ -674,20 +683,20 @@ async def test_log_invalid_transitions(c, s, a):
assert "task-name" in out + err


def test_invalid_worker_states(capsys):
def test_invalid_worker_state(capsys):
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)])
async def test_log_invalid_worker_task_states(c, s, a):
async def test_log_invalid_worker_task_state(c, s, a):
x = c.submit(inc, 1, key="task-name")
await x
a.tasks[x.key].state = "released"
with pytest.raises(Exception):
a.validate_task(a.tasks[x.key])
with pytest.raises(InvalidTaskState):
a.validate_state()

while not s.events["invalid-worker-task-states"]:
while not s.events["invalid-worker-task-state"]:
await asyncio.sleep(0.01)

with pytest.raises(Exception) as info:
test_log_invalid_worker_task_states()
test_log_invalid_worker_task_state()

out, err = capsys.readouterr()

Expand Down
24 changes: 0 additions & 24 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@
slowsum,
)
from distributed.worker import (
InvalidTransition,
Worker,
benchmark_disk,
benchmark_memory,
Expand Down Expand Up @@ -3402,29 +3401,6 @@ async def test_tick_interval(c, s, a, b):
time.sleep(0.200)


@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)])
async def test_log_invalid_transitions(c, s, a):
x = c.submit(inc, 1)
y = c.submit(inc, x)
xkey = x.key
del x
await y
while a.tasks[xkey].state != "released":
await asyncio.sleep(0.01)
ts = a.tasks[xkey]
with pytest.raises(InvalidTransition):
a._transition(ts, "foo", stimulus_id="bar")

while not s.events["invalid-worker-transition"]:
await asyncio.sleep(0.01)

assert "foo" in str(s.events["invalid-worker-transition"])
assert a.address in str(s.events["invalid-worker-transition"])
assert ts.key in str(s.events["invalid-worker-transition"])

del s.events["invalid-worker-transition"] # for test cleanup


class BreakingWorker(Worker):
broke_once = False

Expand Down
3 changes: 2 additions & 1 deletion distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
ExecuteFailureEvent,
ExecuteSuccessEvent,
Instruction,
RecommendationsConflict,
ReleaseWorkerDataMsg,
RescheduleEvent,
RescheduleMsg,
Expand Down Expand Up @@ -123,7 +124,7 @@ def test_merge_recs_instructions():
{x: "memory"},
[],
)
with pytest.raises(ValueError):
with pytest.raises(RecommendationsConflict):
merge_recs_instructions(({x: "memory"}, []), ({x: "released"}, []))


Expand Down
4 changes: 2 additions & 2 deletions distributed/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -949,10 +949,10 @@ def check_invalid_worker_transitions(s: Scheduler) -> None:


def check_invalid_task_states(s: Scheduler) -> None:
if not s.events.get("invalid-worker-task-states"):
if not s.events.get("invalid-worker-task-state"):
return

for timestamp, msg in s.events["invalid-worker-task-states"]:
for timestamp, msg in s.events["invalid-worker-task-state"]:
print("Worker:", msg["worker"])
print("State:", msg["state"])
for line in msg["story"]:
Expand Down
Loading

0 comments on commit bde90af

Please sign in to comment.