From 82bcf8b6c2cd1bbbbd3466c10ec43f40492f379f Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 13 Jun 2022 14:29:34 +0100 Subject: [PATCH] Everything else --- distributed/active_memory_manager.py | 3 +- distributed/diagnostics/plugin.py | 2 +- distributed/node.py | 7 +- distributed/scheduler.py | 5 +- distributed/shuffle/shuffle_extension.py | 2 +- .../tests/test_active_memory_manager.py | 6 +- distributed/tests/test_client.py | 3 +- distributed/tests/test_utils_test.py | 25 +++---- distributed/tests/test_worker.py | 42 ++++++----- .../tests/test_worker_state_machine.py | 69 +++++++++++++++++++ distributed/utils_test.py | 21 +++--- distributed/worker_memory.py | 11 ++- docs/source/worker.rst | 10 +++ 13 files changed, 146 insertions(+), 60 deletions(-) diff --git a/distributed/active_memory_manager.py b/distributed/active_memory_manager.py index 6af1fbe3be5..7fef5c85eeb 100644 --- a/distributed/active_memory_manager.py +++ b/distributed/active_memory_manager.py @@ -416,8 +416,9 @@ def run( ) -> SuggestionGenerator: """This method is invoked by the ActiveMemoryManager every few seconds, or whenever the user invokes ``client.amm.run_once``. + It is an iterator that must emit - :class:`~distributed.active_memory_manager.Suggestion`s: + :class:`~distributed.active_memory_manager.Suggestion` objects: - ``Suggestion("replicate", )`` - ``Suggestion("replicate", , {subset of potential workers to replicate to})`` diff --git a/distributed/diagnostics/plugin.py b/distributed/diagnostics/plugin.py index 2f431da3ae8..bec5f72c7a5 100644 --- a/distributed/diagnostics/plugin.py +++ b/distributed/diagnostics/plugin.py @@ -334,7 +334,7 @@ def __init__(self, filepath): async def setup(self, worker): response = await worker.upload_file( - comm=None, filename=self.filename, data=self.data, load=True + filename=self.filename, data=self.data, load=True ) assert len(self.data) == response["nbytes"] diff --git a/distributed/node.py b/distributed/node.py index 6fedd1b8ace..922125d8875 100644 --- a/distributed/node.py +++ b/distributed/node.py @@ -77,15 +77,16 @@ def stop_services(self): def service_ports(self): return {k: v.port for k, v in self.services.items()} - def _setup_logging(self, logger): + def _setup_logging(self, *loggers): self._deque_handler = DequeHandler( n=dask.config.get("distributed.admin.log-length") ) self._deque_handler.setFormatter( logging.Formatter(dask.config.get("distributed.admin.log-format")) ) - logger.addHandler(self._deque_handler) - weakref.finalize(self, logger.removeHandler, self._deque_handler) + for logger in loggers: + logger.addHandler(self._deque_handler) + weakref.finalize(self, logger.removeHandler, self._deque_handler) def get_logs(self, start=0, n=None, timestamps=False): """ diff --git a/distributed/scheduler.py b/distributed/scheduler.py index a51144d5b28..26c5146d388 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -347,7 +347,10 @@ def _to_dict(self, *, exclude: Container[str] = ()) -> dict: class WorkerState: - """A simple object holding information about a worker.""" + """A simple object holding information about a worker. + + Not to be confused with :class:`distributed.worker_state_machine.WorkerState`. + """ #: This worker's unique key. This can be its connected address #: (such as ``"tcp://127.0.0.1:8891"``) or an alias (such as ``"alice"``). diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/shuffle_extension.py index dbf69460019..3f26c595a80 100644 --- a/distributed/shuffle/shuffle_extension.py +++ b/distributed/shuffle/shuffle_extension.py @@ -230,7 +230,7 @@ def __init__(self, worker: Worker) -> None: # Initialize self.worker: Worker = worker self.shuffles: dict[ShuffleId, Shuffle] = {} - self.executor = ThreadPoolExecutor(worker.nthreads) + self.executor = ThreadPoolExecutor(worker.state.nthreads) # Handlers ########## diff --git a/distributed/tests/test_active_memory_manager.py b/distributed/tests/test_active_memory_manager.py index 8fcc3f31ced..27768854bc2 100644 --- a/distributed/tests/test_active_memory_manager.py +++ b/distributed/tests/test_active_memory_manager.py @@ -866,8 +866,8 @@ async def test_RetireWorker_no_recipients(c, s, w1, w2, w3, w4): assert set(out) in ({w1.address, w3.address}, {w1.address, w4.address}) assert not s.extensions["amm"].policies assert set(s.workers) in ({w2.address, w3.address}, {w2.address, w4.address}) - # After a Scheduler -> Worker -> WorkerState roundtrip, workers that failed to - # retire went back from closing_gracefully to running and can run tasks + # After a Scheduler -> Worker -> Scheduler roundtrip, workers that failed to retire + # went back from closing_gracefully to running and can run tasks while any(ws.status != Status.running for ws in s.workers.values()): await asyncio.sleep(0.01) assert await c.submit(inc, 1) == 2 @@ -896,7 +896,7 @@ async def test_RetireWorker_all_recipients_are_paused(c, s, a, b): assert not s.extensions["amm"].policies assert set(s.workers) == {a.address, b.address} - # After a Scheduler -> Worker -> WorkerState roundtrip, workers that failed to + # After a Scheduler -> Worker -> Scheduler roundtrip, workers that failed to # retire went back from closing_gracefully to running and can run tasks while ws_a.status != Status.running: await asyncio.sleep(0.01) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 28d490f15b5..f1d5dba2342 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -1612,6 +1612,7 @@ def g(): os.remove("myfile.zip") +@pytest.mark.slow @gen_cluster(client=True) async def test_upload_file_egg(c, s, a, b): pytest.importorskip("setuptools") @@ -6810,7 +6811,7 @@ async def test_workers_collection_restriction(c, s, a, b): assert a.data and not b.data -@gen_cluster(client=True, nthreads=[("127.0.0.1", 0)]) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) async def test_get_client_functions_spawn_clusters(c, s, a): # see gh4565 diff --git a/distributed/tests/test_utils_test.py b/distributed/tests/test_utils_test.py index 585fbad0811..78523561d99 100755 --- a/distributed/tests/test_utils_test.py +++ b/distributed/tests/test_utils_test.py @@ -8,6 +8,7 @@ import threading from contextlib import contextmanager from time import sleep +from unittest import mock import pytest import yaml @@ -44,7 +45,8 @@ from distributed.worker_state_machine import ( InvalidTaskState, InvalidTransition, - StateMachineEvent, + PauseEvent, + WorkerState, ) @@ -656,22 +658,17 @@ def test_start_failure_scheduler(): def test_invalid_transitions(capsys): - class BrokenEvent(StateMachineEvent): - pass - - class MyWorker(Worker): - @Worker._handle_event.register - def _(self, ev: BrokenEvent): - ts = next(iter(self.tasks.values())) - return {ts: "foo"}, [] - - @gen_cluster(client=True, Worker=MyWorker, nthreads=[("", 1)]) + @gen_cluster(client=True, nthreads=[("", 1)]) async def test_log_invalid_transitions(c, s, a): x = c.submit(inc, 1, key="task-name") await x - - with pytest.raises(InvalidTransition): - a.handle_stimulus(BrokenEvent(stimulus_id="test")) + ts = a.tasks["task-name"] + ev = PauseEvent(stimulus_id="test") + with mock.patch.object( + WorkerState, "_handle_event", return_value=({ts: "foo"}, []) + ): + with pytest.raises(InvalidTransition): + a.handle_stimulus(ev) while not s.events["invalid-worker-transition"]: await asyncio.sleep(0.01) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 32cee6209b0..0f1c29d95fb 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1557,7 +1557,9 @@ async def f(ev): task for task in asyncio.all_tasks() if "execute(f1)" in task.get_name() ) start = time() - with captured_logger("distributed.worker", level=logging.ERROR) as logger: + with captured_logger( + "distributed.worker_state_machine", level=logging.ERROR + ) as logger: await a.close(timeout=1) assert "Failed to cancel asyncio task" in logger.getvalue() assert time() - start < 5 @@ -2030,7 +2032,7 @@ async def test_gather_dep_from_remote_workers_if_all_local_workers_are_busy( assert_story(a.story("receive-dep"), [("receive-dep", rw.address, {"f"})]) -@gen_cluster(client=True, nthreads=[("127.0.0.1", 0)]) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) async def test_worker_client_uses_default_no_close(c, s, a): """ If a default client is available in the process, the worker will pick this @@ -2057,7 +2059,7 @@ def get_worker_client_id(): assert c is c_def -@gen_cluster(nthreads=[("127.0.0.1", 0)]) +@gen_cluster(nthreads=[("127.0.0.1", 1)]) async def test_worker_client_closes_if_created_on_worker_one_worker(s, a): async with Client(s.address, set_as_default=False, asynchronous=True) as c: with pytest.raises(ValueError): @@ -2542,7 +2544,7 @@ def raise_exc(*args): await asyncio.sleep(0.01) -@gen_cluster(client=True, nthreads=[("127.0.0.1", x) for x in range(4)]) +@gen_cluster(client=True, nthreads=[("", x) for x in (1, 2, 3, 4)]) async def test_hold_on_to_replicas(c, s, *workers): f1 = c.submit(inc, 1, workers=[workers[0].address], key="f1") f2 = c.submit(inc, 2, workers=[workers[1].address], key="f2") @@ -3283,14 +3285,28 @@ async def test_Worker__to_dict(c, s, a): "type", "id", "scheduler", - "nthreads", "address", "status", "thread_id", + "logs", + "config", + "incoming_transfer_log", + "outgoing_transfer_log", + # Attributes of WorkerMemoryManager + "data", + "max_spill", + "memory_limit", + "memory_monitor_interval", + "memory_pause_fraction", + "memory_spill_fraction", + "memory_target_fraction", + # Attributes of WorkerState + "nthreads", + "running", "ready", "constrained", + "executing", "long_running", - "executing_count", "in_flight_tasks", "in_flight_workers", "busy_workers", @@ -3298,23 +3314,11 @@ async def test_Worker__to_dict(c, s, a): "stimulus_log", "transition_counter", "tasks", - "logs", - "config", - "incoming_transfer_log", - "outgoing_transfer_log", "data_needed", "data_needed_per_worker", - # attributes of WorkerMemoryManager - "data", - "max_spill", - "memory_limit", - "memory_monitor_interval", - "memory_pause_fraction", - "memory_spill_fraction", - "memory_target_fraction", } assert d["tasks"]["x"]["key"] == "x" - assert d["data"] == ["x"] + assert d["data"] == {"x": None} @gen_cluster(nthreads=[]) diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index 4fa937581cd..61720834085 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -32,6 +32,7 @@ TaskState, TaskStateState, UpdateDataEvent, + WorkerState, merge_recs_instructions, ) @@ -72,6 +73,74 @@ def test_TaskState__to_dict(): ] +def test_WorkerState__to_dict(): + ws = WorkerState(8) + ws.address = "127.0.0.1.1234" + ws.handle_stimulus( + AcquireReplicasEvent(who_has={"x": ["127.0.0.1:1235"]}, stimulus_id="s1") + ) + ws.handle_stimulus( + UpdateDataEvent(data={"y": object()}, report=False, stimulus_id="s2") + ) + + actual = recursive_to_dict(ws) + # Remove timestamps + for ev in actual["log"]: + del ev[-1] + for stim in actual["stimulus_log"]: + del stim["handled"] + + expect = { + "address": "127.0.0.1.1234", + "busy_workers": [], + "constrained": [], + "data": {"y": None}, + "data_needed": ["x"], + "data_needed_per_worker": {"127.0.0.1:1235": ["x"]}, + "executing": [], + "in_flight_tasks": [], + "in_flight_workers": {}, + "log": [ + ["x", "ensure-task-exists", "released", "s1"], + ["x", "released", "fetch", "fetch", {}, "s1"], + ["y", "put-in-memory", "s2"], + ["y", "receive-from-scatter", "s2"], + ], + "long_running": [], + "nthreads": 8, + "ready": [], + "running": True, + "stimulus_log": [ + { + "cls": "AcquireReplicasEvent", + "stimulus_id": "s1", + "who_has": {"x": ["127.0.0.1:1235"]}, + }, + { + "cls": "UpdateDataEvent", + "data": {"y": None}, + "report": False, + "stimulus_id": "s2", + }, + ], + "tasks": { + "x": { + "key": "x", + "priority": [1], + "state": "fetch", + "who_has": ["127.0.0.1:1235"], + }, + "y": { + "key": "y", + "nbytes": 16, + "state": "memory", + }, + }, + "transition_counter": 1, + } + assert actual == expect + + def traverse_subclasses(cls: type) -> Iterator[type]: yield cls for subcls in cls.__subclasses__(): diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 34956571281..42e90de0ebf 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -70,7 +70,8 @@ reset_logger_locks, sync, ) -from distributed.worker import WORKER_ANY_RUNNING, InvalidTransition, Worker +from distributed.worker import WORKER_ANY_RUNNING, Worker +from distributed.worker_state_machine import InvalidTransition try: import ssl @@ -1271,8 +1272,10 @@ def validate_state(*servers: Scheduler | Worker | Nanny) -> None: Excludes workers wrapped by Nannies and workers manually started by the test. """ for s in servers: - if s.validate and hasattr(s, "validate_state"): - s.validate_state() # type: ignore + if isinstance(s, Scheduler) and s.validate: + s.validate_state() + elif isinstance(s, Worker) and s.state.validate: + s.validate_state() def raises(func, exc=Exception): @@ -2322,13 +2325,13 @@ def freeze_data_fetching(w: Worker, *, jump_start: bool = False): If True, trigger ensure_communicating on exit; this simulates e.g. an unrelated worker moving out of in_flight_workers. """ - old_out_connections = w.total_out_connections - old_comm_threshold = w.comm_threshold_bytes - w.total_out_connections = 0 - w.comm_threshold_bytes = 0 + old_out_connections = w.state.total_out_connections + old_comm_threshold = w.state.comm_threshold_bytes + w.state.total_out_connections = 0 + w.state.comm_threshold_bytes = 0 yield - w.total_out_connections = old_out_connections - w.comm_threshold_bytes = old_comm_threshold + w.state.total_out_connections = old_out_connections + w.state.comm_threshold_bytes = old_comm_threshold if jump_start: w.status = Status.paused w.status = Status.running diff --git a/distributed/worker_memory.py b/distributed/worker_memory.py index 5132afb2a3e..e3aaad21b24 100644 --- a/distributed/worker_memory.py +++ b/distributed/worker_memory.py @@ -68,6 +68,7 @@ def __init__( self, worker: Worker, *, + nthreads: int, memory_limit: str | float = "auto", # This should be None most of the times, short of a power user replacing the # SpillBuffer with their own custom dict-like @@ -84,7 +85,7 @@ def __init__( memory_spill_fraction: float | Literal[False] | None = None, memory_pause_fraction: float | Literal[False] | None = None, ): - self.memory_limit = parse_memory_limit(memory_limit, worker.nthreads) + self.memory_limit = parse_memory_limit(memory_limit, nthreads) self.memory_target_fraction = _parse_threshold( "distributed.worker.memory.target", @@ -293,12 +294,8 @@ async def _maybe_spill(self, worker: Worker, memory: int) -> None: ) def _to_dict(self, *, exclude: Container[str] = ()) -> dict: - info = { - k: v - for k, v in self.__dict__.items() - if not k.startswith("_") and k != "data" and k not in exclude - } - info["data"] = list(self.data) + info = {k: v for k, v in self.__dict__.items() if not k.startswith("_")} + info["data"] = dict.fromkeys(self.data) return info diff --git a/docs/source/worker.rst b/docs/source/worker.rst index 91cac09947a..b3d9f8c0cbd 100644 --- a/docs/source/worker.rst +++ b/docs/source/worker.rst @@ -162,8 +162,18 @@ process. API Documentation ----------------- +.. currentmodule:: distributed.worker_state_machine + .. autoclass:: distributed.worker_state_machine.TaskState :members: +.. autoclass:: distributed.worker_state_machine.WorkerState + :members: + +.. autoclass:: distributed.worker_state_machine.BaseWorker + :members: + +.. currentmodule:: distributed.worker + .. autoclass:: distributed.worker.Worker :members: