diff --git a/distributed/profile.py b/distributed/profile.py
index f0535ef0da..e085511545 100644
--- a/distributed/profile.py
+++ b/distributed/profile.py
@@ -24,12 +24,15 @@
      'children': {...}}}
 }
 """
+from __future__ import annotations
+
 import bisect
 import linecache
 import sys
 import threading
 from collections import defaultdict, deque
 from time import sleep
+from typing import Any

 import tlz as toolz
@@ -152,7 +155,7 @@ def merge(*args):
     }


-def create():
+def create() -> dict[str, Any]:
     return {
         "count": 0,
         "children": {},
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index b222e09fee..e08720ec36 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -1947,7 +1947,7 @@ class NoSchedulerDelayWorker(Worker):
     comparisons using times reported from workers.
     """

-    @property
+    @property  # type: ignore
     def scheduler_delay(self):
         return 0

diff --git a/distributed/tests/test_stress.py b/distributed/tests/test_stress.py
index 0e6cd66fa1..1219f76727 100644
--- a/distributed/tests/test_stress.py
+++ b/distributed/tests/test_stress.py
@@ -99,9 +99,6 @@ async def test_stress_creation_and_deletion(c, s):
     # Assertions are handled by the validate mechanism in the scheduler
     da = pytest.importorskip("dask.array")

-    def _disable_suspicious_counter(dask_worker):
-        dask_worker._suspicious_count_limit = None
-
     rng = da.random.RandomState(0)
     x = rng.random(size=(2000, 2000), chunks=(100, 100))
     y = ((x + 1).T + (x * 2) - x.mean(axis=1)).sum().round(2)
@@ -111,14 +108,12 @@ async def create_and_destroy_worker(delay):
         start = time()
         while time() < start + 5:
             async with Nanny(s.address, nthreads=2) as n:
-                await c.run(_disable_suspicious_counter, workers=[n.worker_address])
                 await asyncio.sleep(delay)
             print("Killed nanny")

     await asyncio.gather(*(create_and_destroy_worker(0.1 * i) for i in range(20)))

     async with Nanny(s.address, nthreads=2):
-        await c.run(_disable_suspicious_counter)
         assert await c.compute(z) == 8000884.93

diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index 81e25d8d03..47233bc135 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -2968,7 +2968,7 @@ async def test_who_has_consistent_remove_replica(c, s, *workers):

     await f2

-    assert ("missing-dep", f1.key) in a.story(f1.key)
+    assert (f1.key, "missing-dep") in a.story(f1.key)
     assert a.tasks[f1.key].suspicious_count == 0
     assert s.tasks[f1.key].suspicious == 0

diff --git a/distributed/utils_test.py b/distributed/utils_test.py
index d6d3821a01..3d4a55cbef 100644
--- a/distributed/utils_test.py
+++ b/distributed/utils_test.py
@@ -22,6 +22,7 @@
 import warnings
 import weakref
 from collections import defaultdict
+from collections.abc import Callable
 from contextlib import contextmanager, nullcontext, suppress
 from glob import glob
 from itertools import count
@@ -54,6 +55,7 @@
 from .diagnostics.plugin import WorkerPlugin
 from .metrics import time
 from .nanny import Nanny
+from .node import ServerNode
 from .proctitle import enable_proctitle_on_children
 from .security import Security
 from .utils import (
@@ -770,7 +772,7 @@ async def disconnect_all(addresses, timeout=3, rpc_kwargs=None):
     await asyncio.gather(*(disconnect(addr, timeout, rpc_kwargs) for addr in addresses))


-def gen_test(timeout=_TEST_TIMEOUT):
+def gen_test(timeout: float = _TEST_TIMEOUT) -> Callable[[Callable], Callable]:
     """Coroutine test

     @gen_test(timeout=5)
@@ -797,14 +799,14 @@ def test_func():


 async def start_cluster(
-    nthreads,
-    scheduler_addr,
-    loop,
-    security=None,
-    Worker=Worker,
-    scheduler_kwargs={},
-    worker_kwargs={},
-):
+    nthreads: list[tuple[str, int] | tuple[str, int, dict]],
+    scheduler_addr: str,
+    loop: IOLoop,
+    security: Security | dict[str, Any] | None = None,
+    Worker: type[ServerNode] = Worker,
+    scheduler_kwargs: dict[str, Any] = {},
+    worker_kwargs: dict[str, Any] = {},
+) -> tuple[Scheduler, list[ServerNode]]:
     s = await Scheduler(
         loop=loop,
         validate=True,
@@ -813,6 +815,7 @@ async def start_cluster(
         host=scheduler_addr,
         **scheduler_kwargs,
     )
+
     workers = [
         Worker(
             s.address,
@@ -822,7 +825,11 @@ async def start_cluster(
             loop=loop,
             validate=True,
             host=ncore[0],
-            **(merge(worker_kwargs, ncore[2]) if len(ncore) > 2 else worker_kwargs),
+            **(
+                merge(worker_kwargs, ncore[2])  # type: ignore
+                if len(ncore) > 2
+                else worker_kwargs
+            ),
         )
         for i, ncore in enumerate(nthreads)
     ]
@@ -854,21 +861,24 @@ async def end_worker(w):


 def gen_cluster(
-    nthreads=[("127.0.0.1", 1), ("127.0.0.1", 2)],
-    ncores=None,
+    nthreads: list[tuple[str, int] | tuple[str, int, dict]] = [
+        ("127.0.0.1", 1),
+        ("127.0.0.1", 2),
+    ],
+    ncores: None = None,  # deprecated
     scheduler="127.0.0.1",
-    timeout=_TEST_TIMEOUT,
-    security=None,
-    Worker=Worker,
-    client=False,
-    scheduler_kwargs={},
-    worker_kwargs={},
-    client_kwargs={},
-    active_rpc_timeout=1,
-    config={},
-    clean_kwargs={},
-    allow_unclosed=False,
-):
+    timeout: float = _TEST_TIMEOUT,
+    security: Security | dict[str, Any] | None = None,
+    Worker: type[ServerNode] = Worker,
+    client: bool = False,
+    scheduler_kwargs: dict[str, Any] = {},
+    worker_kwargs: dict[str, Any] = {},
+    client_kwargs: dict[str, Any] = {},
+    active_rpc_timeout: float = 1,
+    config: dict[str, Any] = {},
+    clean_kwargs: dict[str, Any] = {},
+    allow_unclosed: bool = False,
+) -> Callable[[Callable], Callable]:
     from distributed import Client

     """ Coroutine test with small cluster
diff --git a/distributed/worker.py b/distributed/worker.py
index ebd62e9cd9..9fca11e9f8 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -3,7 +3,6 @@
 import asyncio
 import bisect
 import builtins
-import concurrent.futures
 import errno
 import heapq
 import logging
@@ -14,15 +13,20 @@
 import warnings
 import weakref
 from collections import defaultdict, deque, namedtuple
-from collections.abc import Callable, Hashable, Iterable, MutableMapping
+from collections.abc import Callable, Iterable, Mapping, MutableMapping
+from concurrent.futures import Executor
 from contextlib import suppress
 from datetime import timedelta
 from inspect import isawaitable
 from pickle import PicklingError
-from typing import TYPE_CHECKING, Any, ClassVar
+from typing import TYPE_CHECKING, Any, ClassVar, cast

 if TYPE_CHECKING:
+    from typing_extensions import Literal
+    from .diagnostics.plugin import WorkerPlugin
+    from .actor import Actor
     from .client import Client
+    from .nanny import Nanny

 from tlz import first, keymap, merge, pluck  # noqa: F401
 from tornado.ioloop import IOLoop, PeriodicCallback
@@ -55,7 +59,7 @@
 )
 from .diagnostics import nvml
 from .diagnostics.plugin import _get_plugin_name
-from .diskutils import WorkSpace
+from .diskutils import WorkDir, WorkSpace
 from .http import get_handlers
 from .metrics import time
 from .node import ServerNode
@@ -206,6 +210,7 @@ def __init__(self, key, runspec=None):
         self.nbytes = None
         self.annotations = None
         self.done = False
+        self._previous = None
         self._next = None

     def __repr__(self):
@@ -251,7 +256,7 @@ class Worker(ServerNode):
     * **nthreads:** ``int``:
         Number of nthreads used by this worker process
-    * **executors:** ``Dict[str, concurrent.futures.Executor]``:
+    * **executors:** ``dict[str, concurrent.futures.Executor]``:
         Executors used to perform computation. Always contains the default
         executor.
     * **local_directory:** ``path``:
@@ -332,8 +337,9 @@ class Worker(ServerNode):

     Parameters
     ----------
-    scheduler_ip: str
-    scheduler_port: int
+    scheduler_ip: str, optional
+    scheduler_port: int, optional
+    scheduler_file: str, optional
     ip: str, optional
     data: MutableMapping, type, None
         The object to use for storage, builds a disk-backed LRU dict by default
@@ -347,13 +353,16 @@ class Worker(ServerNode):
         Set to zero for no limit.
         Set to 'auto' to calculate as system.MEMORY_LIMIT * min(1, nthreads / total_cores)
         Use strings or numbers like 5GB or 5e9
-    memory_target_fraction: float
+    memory_target_fraction: float or False
         Fraction of memory to try to stay beneath
-    memory_spill_fraction: float
+        (default: read from config key distributed.worker.memory.target)
+    memory_spill_fraction: float or False
         Fraction of memory at which we start spilling to disk
-    memory_pause_fraction: float
+        (default: read from config key distributed.worker.memory.spill)
+    memory_pause_fraction: float or False
         Fraction of memory at which we stop running new tasks
-    executor: concurrent.futures.Executor, dict[str, concurrent.futures.Executor], str
+        (default: read from config key distributed.worker.memory.pause)
+    executor: concurrent.futures.Executor, dict[str, concurrent.futures.Executor], "offload"
         The executor(s) to use. Depending on the type, it has the following meanings:
             - Executor instance: The default executor.
             - Dict[str, Executor]: mapping names to Executor instances. If the
@@ -376,6 +385,8 @@ class Worker(ServerNode):
     lifetime_restart: bool
         Whether or not to restart a worker after it has reached its lifetime
        Default False
+    kwargs: optional
+        Additional parameters to ServerNode constructor

     Examples
     --------
@@ -398,50 +409,156 @@ class Worker(ServerNode):

     _instances: ClassVar[weakref.WeakSet[Worker]] = weakref.WeakSet()
     _initialized_clients: ClassVar[weakref.WeakSet[Client]] = weakref.WeakSet()

+    tasks: dict[str, TaskState]
+    waiting_for_data_count: int
+    has_what: defaultdict[str, set[str]]  # {worker address: {ts.key, ...}}
+    pending_data_per_worker: defaultdict[str, deque[str]]
+    nanny: Nanny | None
+    _lock: threading.Lock
+    data_needed: list[tuple[int, str]]  # heap[(ts.priority, ts.key)]
+    in_flight_workers: dict[str, set[str]]  # {worker address: {ts.key, ...}}
+    total_out_connections: int
+    total_in_connections: int
+    comm_threshold_bytes: int
+    comm_nbytes: int
+    _missing_dep_flight: set[TaskState]
+    threads: dict[str, int]  # {ts.key: thread ID}
+    active_threads_lock: threading.Lock
+    active_threads: dict[int, str]  # {thread ID: ts.key}
+    active_keys: set[str]
+    profile_keys: defaultdict[str, dict[str, Any]]
+    profile_keys_history: deque[tuple[float, dict[str, dict[str, Any]]]]
+    profile_recent: dict[str, Any]
+    profile_history: deque[tuple[float, dict[str, Any]]]
+    generation: int
+    ready: list[str]
+    constrained: deque[str]
+    _executing: set[TaskState]
+    _in_flight_tasks: set[TaskState]
+    executed_count: int
+    long_running: set[TaskState]
+    log: deque[tuple]
+    incoming_transfer_log: deque[dict[str, Any]]
+    outgoing_transfer_log: deque[dict[str, Any]]
+    target_message_size: int
+    validate: bool
+    _transitions_table: dict[tuple[str, str], Callable]
+    _transition_counter: int
+    incoming_count: int
+    outgoing_count: int
+    outgoing_current_count: int
+    repetitively_busy: int
+    bandwidth: float
+    latency: float
+    profile_cycle_interval: float
+    workspace: WorkSpace
+    _workdir: WorkDir
+    local_directory: str
+    _client: Client | None
+    bandwidth_workers: defaultdict[str, tuple[float, int]]
+    bandwidth_types: defaultdict[type, tuple[float, int]]
+    preloads: list[preloading.Preload]
+    contact_address: str | None
+    _start_port: int | None
+    _start_host: str | None
+    _interface: str | None
+    _protocol: str
+    _dashboard_address: str | None
+    _dashboard: bool
+    _http_prefix: str
+    nthreads: int
+    total_resources: dict[str, float]
+    available_resources: dict[str, float]
+    death_timeout: float | None
+    lifetime: float | None
+    lifetime_stagger: float | None
+    lifetime_restart: bool
+    extensions: dict
+    security: Security
+    connection_args: dict[str, Any]
+    memory_limit: int | None
+    memory_target_fraction: float | Literal[False]
+    memory_spill_fraction: float | Literal[False]
+    memory_pause_fraction: float | Literal[False]
+    data: MutableMapping[str, Any]  # {task key: task payload}
+    actors: dict[str, Actor | None]
+    loop: IOLoop
+    reconnect: bool
+    executors: dict[str, Executor]
+    batched_stream: BatchedSend
+    name: Any
+    scheduler_delay: float
+    stream_comms: dict[str, BatchedSend]
+    heartbeat_active: bool
+    _ipython_kernel: Any | None = None
+    services: dict[str, Any] = {}
+    service_specs: dict[str, Any]
+    metrics: dict[str, Callable[[Worker], Any]]
+    startup_information: dict[str, Callable[[Worker], Any]]
+    low_level_profiler: bool
+    scheduler: Any
+    execution_state: dict[str, Any]
+    memory_monitor_interval: float | None
+    _memory_monitoring: bool
+    _throttled_gc: ThrottledGC
+    plugins: dict[str, WorkerPlugin]
+    _pending_plugins: tuple[WorkerPlugin, ...]
+
     def __init__(
         self,
-        scheduler_ip=None,
-        scheduler_port=None,
-        scheduler_file=None,
-        ncores=None,
-        nthreads=None,
-        loop=None,
-        local_dir=None,
-        local_directory=None,
-        services=None,
-        service_ports=None,
-        service_kwargs=None,
-        name=None,
-        reconnect=True,
-        memory_limit="auto",
-        executor=None,
-        resources=None,
-        silence_logs=None,
-        death_timeout=None,
-        preload=None,
-        preload_argv=None,
-        security=None,
-        contact_address=None,
-        memory_monitor_interval="200ms",
-        extensions=None,
-        metrics=DEFAULT_METRICS,
-        startup_information=DEFAULT_STARTUP_INFORMATION,
-        data=None,
-        interface=None,
-        host=None,
-        port=None,
-        protocol=None,
-        dashboard_address=None,
-        dashboard=False,
-        http_prefix="/",
-        nanny=None,
-        plugins=(),
-        low_level_profiler=dask.config.get("distributed.worker.profile.low-level"),
-        validate=None,
+        scheduler_ip: str | None = None,
+        scheduler_port: int | None = None,
+        *,
+        scheduler_file: str | None = None,
+        ncores: None = None,  # Deprecated, use nthreads instead
+        nthreads: int | None = None,
+        loop: IOLoop | None = None,
+        local_dir: None = None,  # Deprecated, use local_directory instead
+        local_directory: str | None = None,
+        services: dict | None = None,
+        name: Any | None = None,
+        reconnect: bool = True,
+        memory_limit: str | float = "auto",
+        executor: Executor | dict[str, Executor] | Literal["offload"] | None = None,
+        resources: dict[str, float] | None = None,
+        silence_logs: int | None = None,
+        death_timeout: Any | None = None,
+        preload: list[str] | None = None,
+        preload_argv: list[str] | list[list[str]] | None = None,
+        security: Security | dict[str, Any] | None = None,
+        contact_address: str | None = None,
+        memory_monitor_interval: Any = "200ms",
+        memory_target_fraction: float | Literal[False] | None = None,
+        memory_spill_fraction: float | Literal[False] | None = None,
+        memory_pause_fraction: float | Literal[False] | None = None,
+        extensions: list[type] | None = None,
+        metrics: Mapping[str, Callable[[Worker], Any]] = DEFAULT_METRICS,
+        startup_information: Mapping[
+            str, Callable[[Worker], Any]
+        ] = DEFAULT_STARTUP_INFORMATION,
+        data: (
+            MutableMapping[str, Any]  # pre-initialised
+            | Callable[[], MutableMapping[str, Any]]  # constructor
+            | tuple[
+                Callable[..., MutableMapping[str, Any]], dict[str, Any]
+            ]  # (constructor, kwargs to constructor)
+            | None  # create internally
+        ) = None,
+        interface: str | None = None,
+        host: str | None = None,
+        port: int | None = None,
+        protocol: str | None = None,
+        dashboard_address: str | None = None,
+        dashboard: bool = False,
+        http_prefix: str = "/",
+        nanny: Nanny | None = None,
+        plugins: tuple[WorkerPlugin, ...] = (),
+        low_level_profiler: bool | None = None,
+        validate: bool | None = None,
         profile_cycle_interval=None,
-        lifetime=None,
-        lifetime_stagger=None,
-        lifetime_restart=None,
+        lifetime: Any | None = None,
+        lifetime_stagger: Any | None = None,
+        lifetime_restart: bool | None = None,
         **kwargs,
     ):
         self.tasks = {}
@@ -460,7 +577,7 @@ def __init__(
         self.total_in_connections = dask.config.get(
             "distributed.worker.connections.incoming"
         )
-        self.comm_threshold_bytes = 10e6
+        self.comm_threshold_bytes = int(10e6)
         self.comm_nbytes = 0
         self._missing_dep_flight = set()

@@ -483,10 +600,7 @@ def __init__(
         self.executed_count = 0
         self.long_running = set()

-        self.recent_messages_log = deque(
-            maxlen=dask.config.get("distributed.comm.recent-messages-log-length")
-        )
-        self.target_message_size = 50e6  # 50 MB
+        self.target_message_size = int(50e6)  # 50 MB

         self.log = deque(maxlen=100000)
         if validate is None:
@@ -559,6 +673,7 @@ def __init__(
         if profile_cycle_interval is None:
             profile_cycle_interval = dask.config.get("distributed.worker.profile.cycle")
         profile_cycle_interval = parse_timedelta(profile_cycle_interval, default="ms")
+        assert profile_cycle_interval

         self._setup_logging(logger)

@@ -587,6 +702,8 @@ def __init__(
             preload = dask.config.get("distributed.worker.preload")
         if not preload_argv:
             preload_argv = dask.config.get("distributed.worker.preload-argv")
+        assert preload is not None
+        assert preload_argv is not None
         self.preloads = preloading.process_preloads(
             self, preload, preload_argv, file_dir=self.local_directory
         )
@@ -606,6 +723,7 @@ def __init__(
             protocol_address = scheduler_addr.split("://")
             if len(protocol_address) == 2:
                 protocol = protocol_address[0]
+            assert protocol

         self._start_port = port
         self._start_host = host
@@ -627,6 +745,7 @@ def __init__(
         self.nthreads = nthreads or CPU_COUNT
         if resources is None:
             resources = dask.config.get("distributed.worker.resources", None)
+            assert isinstance(resources, dict)
         self.total_resources = resources or {}
         self.available_resources = (resources or {}).copy()

@@ -644,24 +763,21 @@ def __init__(

         self.memory_limit = parse_memory_limit(memory_limit, self.nthreads)

-        if "memory_target_fraction" in kwargs:
-            self.memory_target_fraction = kwargs.pop("memory_target_fraction")
-        else:
-            self.memory_target_fraction = dask.config.get(
-                "distributed.worker.memory.target"
-            )
-        if "memory_spill_fraction" in kwargs:
-            self.memory_spill_fraction = kwargs.pop("memory_spill_fraction")
-        else:
-            self.memory_spill_fraction = dask.config.get(
-                "distributed.worker.memory.spill"
-            )
-        if "memory_pause_fraction" in kwargs:
-            self.memory_pause_fraction = kwargs.pop("memory_pause_fraction")
-        else:
-            self.memory_pause_fraction = dask.config.get(
"distributed.worker.memory.pause" - ) + self.memory_target_fraction = ( + memory_target_fraction + if memory_target_fraction is not None + else dask.config.get("distributed.worker.memory.target") + ) + self.memory_spill_fraction = ( + memory_spill_fraction + if memory_spill_fraction is not None + else dask.config.get("distributed.worker.memory.spill") + ) + self.memory_pause_fraction = ( + memory_pause_fraction + if memory_pause_fraction is not None + else dask.config.get("distributed.worker.memory.pause") + ) if isinstance(data, MutableMapping): self.data = data @@ -690,7 +806,7 @@ def __init__( self.reconnect = reconnect # Common executors always available - self.executors: dict[str, concurrent.futures.Executor] = { + self.executors = { "offload": utils._offload_executor, "actor": ThreadPoolExecutor(1, thread_name_prefix="Dask-Actor-Threads"), } @@ -733,6 +849,8 @@ def __init__( dict(startup_information) if startup_information else {} ) + if low_level_profiler is None: + low_level_profiler = dask.config.get("distributed.worker.profile.low-level") self.low_level_profiler = low_level_profiler handlers = { @@ -794,14 +912,14 @@ def __init__( pc = PeriodicCallback(self.find_missing, 1000) self.periodic_callbacks["find-missing"] = pc - self._suspicious_count_limit = 10 self._address = contact_address self.memory_monitor_interval = parse_timedelta( memory_monitor_interval, default="ms" ) + self._memory_monitoring = False if self.memory_limit: - self._memory_monitoring = False + assert self.memory_monitor_interval is not None pc = PeriodicCallback( self.memory_monitor, self.memory_monitor_interval * 1000 ) @@ -828,19 +946,18 @@ def __init__( self.plugins = {} self._pending_plugins = plugins - self.lifetime = lifetime or dask.config.get( - "distributed.worker.lifetime.duration" - ) - lifetime_stagger = lifetime_stagger or dask.config.get( - "distributed.worker.lifetime.stagger" - ) - self.lifetime_restart = lifetime_restart or dask.config.get( - "distributed.worker.lifetime.restart" - ) - if isinstance(self.lifetime, str): - self.lifetime = parse_timedelta(self.lifetime) - if isinstance(lifetime_stagger, str): - lifetime_stagger = parse_timedelta(lifetime_stagger) + if lifetime is None: + lifetime = dask.config.get("distributed.worker.lifetime.duration") + self.lifetime = parse_timedelta(lifetime) + + if lifetime_stagger is None: + lifetime_stagger = dask.config.get("distributed.worker.lifetime.stagger") + lifetime_stagger = parse_timedelta(lifetime_stagger) + + if lifetime_restart is None: + lifetime_restart = dask.config.get("distributed.worker.lifetime.restart") + self.lifetime_restart = lifetime_restart + if self.lifetime: self.lifetime += (random.random() * 2 - 1) * lifetime_stagger self.io_loop.call_later(self.lifetime, self.close_gracefully) @@ -1644,7 +1761,7 @@ def handle_remove_replicas(self, keys, stimulus_id): if ts is None or ts.state != "memory": continue if not ts.is_protected(): - self.log.append(("remove-replica-confirmed", ts.key, stimulus_id)) + self.log.append((ts.key, "remove-replica-confirmed", stimulus_id)) recommendations[ts] = "released" if ts.dependents else "forgotten" else: rejected.append(key) @@ -2624,7 +2741,7 @@ async def gather_dep( deps_to_iter = set(self.in_flight_workers.pop(worker)) & to_gather_keys for d in deps_to_iter: - ts = self.tasks.get(d) + ts = cast(TaskState, self.tasks.get(d)) assert ts, (d, self.story(d)) ts.done = True if d in data: @@ -2632,7 +2749,7 @@ async def gather_dep( elif not busy: ts.who_has.discard(worker) 
                     self.has_what[worker].discard(ts.key)
-                    self.log.append(("missing-dep", d))
+                    self.log.append((d, "missing-dep"))
                     self.batched_stream.send(
                         {"op": "missing-data", "errant_worker": worker, "key": d}
                     )
@@ -2747,7 +2864,7 @@ def handle_steal_request(self, key, stimulus_id):

     def release_key(
         self,
-        key: Hashable,
+        key: str,
         cause: TaskState | None = None,
         reason: str | None = None,
         report: bool = True,
@@ -3565,7 +3682,7 @@ def client(self) -> Client:
         else:
             return self._get_client()

-    def _get_client(self, timeout=None) -> Client:
+    def _get_client(self, timeout: float | None = None) -> Client:
         """Get local client attached to this worker

         If no such client exists, create one
@@ -3625,7 +3742,7 @@ def _get_client(self, timeout=None) -> Client:

         return self._client

-    def get_current_task(self):
+    def get_current_task(self) -> str:
         """Get the key of the task we are currently running

         This only makes sense to run within a task
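Reviewer note: the memory fractions move from `**kwargs` pops inside `__init__` to explicit keyword arguments typed `float | Literal[False] | None`, with `None` still deferring to the `distributed.worker.memory.*` config values. Below is a minimal sketch of how a call site reads after this change; the throwaway scheduler, the port choices, the fraction values, and `data=dict` are illustrative assumptions, not taken from this diff.

```python
import asyncio

from distributed import Scheduler, Worker


async def main() -> None:
    # Throwaway scheduler so the worker has something to connect to;
    # port=0 / ":0" let the OS pick free ports (illustrative only).
    async with Scheduler(port=0, dashboard_address=":0") as s:
        async with Worker(
            s.address,
            nthreads=2,
            # Previously smuggled in via **kwargs and popped in __init__;
            # now declared keyword arguments. None (the default) still means
            # "read distributed.worker.memory.* from the dask config".
            memory_target_fraction=0.6,
            memory_spill_fraction=0.7,
            memory_pause_fraction=False,  # False disables pausing outright
            # `data` accepts a ready-made mapping, a zero-argument constructor,
            # or a (constructor, kwargs) tuple, per the new annotation.
            data=dict,
        ) as w:
            assert w.memory_pause_fraction is False


if __name__ == "__main__":
    asyncio.run(main())
```

The bare `*` added after `scheduler_port` makes every other parameter keyword-only, so the sketch passes everything except the scheduler address by name.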