From 60a5dc252d15d9d9ff971979ddf27b330d2c4653 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Sat, 5 Mar 2022 01:01:00 +0000 Subject: [PATCH] Encapsulate spill buffer and memory_monitor --- distributed/active_memory_manager.py | 6 + distributed/core.py | 9 +- distributed/deploy/local.py | 3 +- distributed/distributed-schema.yaml | 6 + distributed/distributed.yaml | 6 + distributed/nanny.py | 52 +- distributed/scheduler.py | 5 +- distributed/spill.py | 7 +- distributed/system.py | 2 +- .../tests/test_active_memory_manager.py | 100 ++- distributed/tests/test_cancelled_state.py | 59 +- distributed/tests/test_client.py | 4 +- distributed/tests/test_nanny.py | 92 +-- distributed/tests/test_scheduler.py | 47 +- distributed/tests/test_spill.py | 16 +- distributed/tests/test_steal.py | 4 +- distributed/tests/test_worker.py | 471 +------------ distributed/tests/test_worker_memory.py | 620 ++++++++++++++++++ distributed/utils_test.py | 1 - distributed/worker.py | 312 ++------- distributed/worker_memory.py | 384 +++++++++++ setup.cfg | 2 +- 22 files changed, 1215 insertions(+), 993 deletions(-) create mode 100644 distributed/tests/test_worker_memory.py create mode 100644 distributed/worker_memory.py diff --git a/distributed/active_memory_manager.py b/distributed/active_memory_manager.py index 4a616095908..7d10fbcdb15 100644 --- a/distributed/active_memory_manager.py +++ b/distributed/active_memory_manager.py @@ -1,3 +1,9 @@ +"""Implementation of the Active Memory Manager. This is a scheduler extension which +sends drop/replicate suggestions to the worker. + +See also :mod:`distributed.worker_memory` and :mod:`distributed.spill`, which implement +spill/pause/terminate mechanics on the Worker side. +""" from __future__ import annotations import logging diff --git a/distributed/core.py b/distributed/core.py index cf209fd5cae..1dd164ccdd0 100644 --- a/distributed/core.py +++ b/distributed/core.py @@ -264,7 +264,10 @@ def set_thread_ident(): @property def status(self): - return self._status + try: + return self._status + except AttributeError: + return Status.undefined @status.setter def status(self, new_status): @@ -399,9 +402,7 @@ def port(self): def identity(self) -> dict[str, str]: return {"type": type(self).__name__, "id": self.id} - def _to_dict( - self, comm: Comm | None = None, *, exclude: Container[str] = () - ) -> dict: + def _to_dict(self, *, exclude: Container[str] = ()) -> dict: """Dictionary representation for debugging purposes. Not type stable and not intended for roundtrips. diff --git a/distributed/deploy/local.py b/distributed/deploy/local.py index 1180c931085..26b0c07880f 100644 --- a/distributed/deploy/local.py +++ b/distributed/deploy/local.py @@ -10,7 +10,8 @@ from ..nanny import Nanny from ..scheduler import Scheduler from ..security import Security -from ..worker import Worker, parse_memory_limit +from ..worker import Worker +from ..worker_memory import parse_memory_limit from .spec import SpecCluster from .utils import nprocesses_nthreads diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml index d0fd06474e4..df0605ac496 100644 --- a/distributed/distributed-schema.yaml +++ b/distributed/distributed-schema.yaml @@ -512,6 +512,12 @@ properties: description: >- Limit of number of bytes to be spilled on disk. 
+ monitor-interval: + type: object + properties: + spill-pause: {type: string} + terminate: {type: string} + http: type: object description: Settings for Dask's embedded HTTP Server diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 7076d5c3364..0449f7ba97d 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -149,6 +149,12 @@ distributed: # Set to false for no maximum. max-spill: false + monitor-interval: + # Interval between checks for the spill, pause, and terminate thresholds. + # The target threshold is checked every time new data is inserted. + spill-pause: 200ms # memory monitor on the Worker + terminate: 100ms # memory monitor on the Nanny + http: routes: - distributed.http.worker.prometheus diff --git a/distributed/nanny.py b/distributed/nanny.py index 65a2d303785..201ef0c380d 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -13,11 +13,11 @@ from inspect import isawaitable from queue import Empty from time import sleep as sync_sleep -from typing import TYPE_CHECKING, ClassVar, Literal +from typing import TYPE_CHECKING, ClassVar import psutil from tornado import gen -from tornado.ioloop import IOLoop, PeriodicCallback +from tornado.ioloop import IOLoop import dask from dask.system import CPU_COUNT @@ -43,7 +43,8 @@ parse_ports, silence_logging, ) -from .worker import Worker, parse_memory_limit, run +from .worker import Worker, run +from .worker_memory import DeprecatedMMAccessor, NannyMemoryManager if TYPE_CHECKING: from .diagnostics.plugin import NannyPlugin @@ -83,6 +84,7 @@ class Nanny(ServerNode): _instances: ClassVar[weakref.WeakSet[Nanny]] = weakref.WeakSet() process = None status = Status.undefined + memory_manager: NannyMemoryManager def __init__( self, @@ -97,7 +99,6 @@ def __init__( services=None, name=None, memory_limit="auto", - memory_terminate_fraction: float | Literal[False] | None = None, reconnect=True, validate=False, quiet=False, @@ -186,7 +187,8 @@ def __init__( config_environ = dask.config.get("distributed.nanny.environ", {}) if not isinstance(config_environ, dict): raise TypeError( - f"distributed.nanny.environ configuration must be of type dict. Instead got {type(config_environ)}" + "distributed.nanny.environ configuration must be of type dict. 
" + f"Instead got {type(config_environ)}" ) self.env = config_environ.copy() for k in self.env: @@ -207,19 +209,12 @@ def __init__( self.worker_kwargs = worker_kwargs self.contact_address = contact_address - self.memory_terminate_fraction = ( - memory_terminate_fraction - if memory_terminate_fraction is not None - else dask.config.get("distributed.worker.memory.terminate") - ) self.services = services self.name = name self.quiet = quiet self.auto_restart = True - self.memory_limit = parse_memory_limit(memory_limit, self.nthreads) - if silence_logs: silence_logging(level=silence_logs) self.silence_logs = silence_logs @@ -244,10 +239,7 @@ def __init__( ) self.scheduler = self.rpc(self.scheduler_addr) - - if self.memory_limit: - pc = PeriodicCallback(self.memory_monitor, 100) - self.periodic_callbacks["memory"] = pc + self.memory_manager = NannyMemoryManager(self, memory_limit=memory_limit) if ( not host @@ -265,6 +257,10 @@ def __init__( Nanny._instances.add(self) self.status = Status.init + # Deprecated attribute; use Nanny.memory_manager.memory_terminate_fraction instead + memory_limit = DeprecatedMMAccessor() + memory_terminate_fraction = DeprecatedMMAccessor() + def __repr__(self): return "" % (self.worker_address, self.nthreads) @@ -382,7 +378,7 @@ async def instantiate(self) -> Status: services=self.services, nanny=self.address, name=self.name, - memory_limit=self.memory_limit, + memory_limit=self.memory_manager.memory_limit, reconnect=self.reconnect, resources=self.resources, validate=self.validate, @@ -496,28 +492,6 @@ def _psutil_process(self): return self._psutil_process_obj - def memory_monitor(self): - """Track worker's memory. Restart if it goes above terminate fraction""" - if self.status != Status.running: - return - if self.process is None or self.process.process is None: - return None - process = self.process.process - - try: - proc = self._psutil_process - memory = proc.memory_info().rss - except (ProcessLookupError, psutil.NoSuchProcess, psutil.AccessDenied): - return - frac = memory / self.memory_limit - - if self.memory_terminate_fraction and frac > self.memory_terminate_fraction: - logger.warning( - "Worker exceeded %d%% memory budget. Restarting", - 100 * self.memory_terminate_fraction, - ) - process.terminate() - def is_alive(self): return self.process is not None and self.process.is_alive() diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 59cd77419b7..e1b9ca88889 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -58,7 +58,6 @@ from .active_memory_manager import ActiveMemoryManagerExtension, RetireWorker from .batched import BatchedSend from .comm import ( - Comm, get_address_host, normalize_address, resolve_address, @@ -4053,9 +4052,7 @@ def identity(self): } return d - def _to_dict( - self, comm: "Comm | None" = None, *, exclude: "Container[str]" = () - ) -> dict: + def _to_dict(self, *, exclude: "Container[str]" = ()) -> dict: """Dictionary representation for debugging purposes. Not type stable and not intended for roundtrips. 
diff --git a/distributed/spill.py b/distributed/spill.py index b734d4bfe00..296e1231bca 100644 --- a/distributed/spill.py +++ b/distributed/spill.py @@ -7,14 +7,15 @@ from functools import partial from typing import Any, Literal, NamedTuple -import zict from packaging.version import parse as parse_version +import zict + from .protocol import deserialize_bytes, serialize_bytelist from .sizeof import safe_sizeof logger = logging.getLogger(__name__) -has_zict_210 = parse_version(zict.__version__) > parse_version("2.0.0") +has_zict_210 = parse_version(zict.__version__) >= parse_version("2.1.0") class SpilledSize(NamedTuple): @@ -62,7 +63,7 @@ def __init__( ): if max_spill is not False and not has_zict_210: - raise ValueError("zict > 2.0.0 required to set max_weight") + raise ValueError("zict >= 2.1.0 required to set max-spill") super().__init__( fast={}, diff --git a/distributed/system.py b/distributed/system.py index 2b032a34024..ad981e8b1cf 100644 --- a/distributed/system.py +++ b/distributed/system.py @@ -5,7 +5,7 @@ __all__ = ("memory_limit", "MEMORY_LIMIT") -def memory_limit(): +def memory_limit() -> int: """Get the memory limit (in bytes) for this system. Takes the minimum value from the following locations: diff --git a/distributed/tests/test_active_memory_manager.py b/distributed/tests/test_active_memory_manager.py index 8a909701e2c..8fcc3f31ced 100644 --- a/distributed/tests/test_active_memory_manager.py +++ b/distributed/tests/test_active_memory_manager.py @@ -5,6 +5,7 @@ import random from contextlib import contextmanager from time import sleep +from typing import Literal import pytest @@ -43,7 +44,13 @@ def assert_amm_log(expect: list[str]): class DemoPolicy(ActiveMemoryManagerPolicy): """Drop or replicate a key n times""" - def __init__(self, action, key, n, candidates): + def __init__( + self, + action: Literal["drop", "replicate"], + key: str, + n: int, + candidates: list[int] | None, + ): self.action = action self.key = key self.n = n @@ -63,7 +70,14 @@ def run(self): yield self.action, ts, candidates -def demo_config(action, key="x", n=10, candidates=None, start=False, interval=0.1): +def demo_config( + action: Literal["drop", "replicate"], + key: str = "x", + n: int = 10, + candidates: list[int] | None = None, + start: bool = False, + interval: float = 0.1, +): """Create a dask config for AMM with DemoPolicy""" return { "distributed.scheduler.active-memory-manager.start": start, @@ -77,6 +91,8 @@ def demo_config(action, key="x", n=10, candidates=None, start=False, interval=0. 
"candidates": candidates, }, ], + # If pause is required, do it manually by setting Worker.status = Status.paused + "distributed.worker.memory.pause": False, } @@ -351,7 +367,7 @@ async def test_drop_from_worker_with_least_free_memory(c, s, *nannies): @gen_cluster( nthreads=[("", 1)] * 8, client=True, - config=demo_config("drop", n=1, candidates={5, 6}), + config=demo_config("drop", n=1, candidates=[5, 6]), ) async def test_drop_with_candidates(c, s, *workers): futures = await c.scatter({"x": 1}, broadcast=True) @@ -363,7 +379,7 @@ async def test_drop_with_candidates(c, s, *workers): await asyncio.sleep(0.01) -@gen_cluster(client=True, config=demo_config("drop", candidates=set())) +@gen_cluster(client=True, config=demo_config("drop", candidates=[])) async def test_drop_with_empty_candidates(c, s, a, b): """Key is not dropped as the plugin proposes an empty set of candidates, not to be confused with None @@ -375,7 +391,9 @@ async def test_drop_with_empty_candidates(c, s, a, b): @gen_cluster( - client=True, nthreads=[("", 1)] * 3, config=demo_config("drop", candidates={2}) + client=True, + nthreads=[("", 1)] * 3, + config=demo_config("drop", candidates=[2]), ) async def test_drop_from_candidates_without_key(c, s, *workers): """Key is not dropped as none of the candidates hold a replica""" @@ -390,7 +408,7 @@ async def test_drop_from_candidates_without_key(c, s, *workers): assert s.tasks["x"].who_has == {ws0, ws1} -@gen_cluster(client=True, config=demo_config("drop", candidates={0})) +@gen_cluster(client=True, config=demo_config("drop", candidates=[0])) async def test_drop_with_bad_candidates(c, s, a, b): """Key is not dropped as all candidates hold waiter tasks""" ws0, ws1 = s.workers.values() # Not necessarily a, b; it could be b, a! @@ -404,18 +422,13 @@ async def test_drop_with_bad_candidates(c, s, a, b): assert s.tasks["x"].who_has == {ws0, ws1} -@gen_cluster( - client=True, - nthreads=[("", 1)] * 10, - config=demo_config("drop", n=1), - worker_kwargs={"memory_monitor_interval": "20ms"}, -) +@gen_cluster(client=True, nthreads=[("", 1)] * 10, config=demo_config("drop", n=1)) async def test_drop_prefers_paused_workers(c, s, *workers): x = await c.scatter({"x": 1}, broadcast=True) ts = s.tasks["x"] assert len(ts.who_has) == 10 ws = s.workers[workers[3].address] - workers[3].memory_pause_fraction = 1e-15 + workers[3].status = Status.paused while ws.status != Status.paused: await asyncio.sleep(0.01) @@ -426,11 +439,7 @@ async def test_drop_prefers_paused_workers(c, s, *workers): @pytest.mark.slow -@gen_cluster( - client=True, - config=demo_config("drop"), - worker_kwargs={"memory_monitor_interval": "20ms"}, -) +@gen_cluster(client=True, config=demo_config("drop")) async def test_drop_with_paused_workers_with_running_tasks_1(c, s, a, b): """If there is exactly 1 worker that holds a replica of a task that isn't paused or retiring, and there are 1+ paused/retiring workers with the same task, don't drop @@ -445,7 +454,7 @@ async def test_drop_with_paused_workers_with_running_tasks_1(c, s, a, b): while "y" not in a.tasks or a.tasks["y"].state != "executing": await asyncio.sleep(0.01) - a.memory_pause_fraction = 1e-15 + a.status = Status.paused while s.workers[a.address].status != Status.paused: await asyncio.sleep(0.01) assert a.tasks["y"].state == "executing" @@ -455,11 +464,7 @@ async def test_drop_with_paused_workers_with_running_tasks_1(c, s, a, b): assert len(s.tasks["x"].who_has) == 2 -@gen_cluster( - client=True, - config=demo_config("drop"), - worker_kwargs={"memory_monitor_interval": 
"20ms"}, -) +@gen_cluster(client=True, config=demo_config("drop")) async def test_drop_with_paused_workers_with_running_tasks_2(c, s, a, b): """If there is exactly 1 worker that holds a replica of a task that isn't paused or retiring, and there are 1+ paused/retiring workers with the same task, don't drop @@ -470,7 +475,7 @@ async def test_drop_with_paused_workers_with_running_tasks_2(c, s, a, b): b is running and has no dependent tasks """ x = (await c.scatter({"x": 1}, broadcast=True))["x"] - a.memory_pause_fraction = 1e-15 + a.status = Status.paused while s.workers[a.address].status != Status.paused: await asyncio.sleep(0.01) @@ -481,11 +486,7 @@ async def test_drop_with_paused_workers_with_running_tasks_2(c, s, a, b): @pytest.mark.slow @pytest.mark.parametrize("pause", [True, False]) -@gen_cluster( - client=True, - config=demo_config("drop"), - worker_kwargs={"memory_monitor_interval": "20ms"}, -) +@gen_cluster(client=True, config=demo_config("drop")) async def test_drop_with_paused_workers_with_running_tasks_3_4(c, s, a, b, pause): """If there is exactly 1 worker that holds a replica of a task that isn't paused or retiring, and there are 1+ paused/retiring workers with the same task, don't drop @@ -505,8 +506,8 @@ async def test_drop_with_paused_workers_with_running_tasks_3_4(c, s, a, b, pause await asyncio.sleep(0.01) if pause: - a.memory_pause_fraction = 1e-15 - b.memory_pause_fraction = 1e-15 + a.status = Status.paused + b.status = Status.paused while any(ws.status != Status.paused for ws in s.workers.values()): await asyncio.sleep(0.01) @@ -519,12 +520,7 @@ async def test_drop_with_paused_workers_with_running_tasks_3_4(c, s, a, b, pause @pytest.mark.slow -@gen_cluster( - client=True, - nthreads=[("", 1)] * 3, - config=demo_config("drop"), - worker_kwargs={"memory_monitor_interval": "20ms"}, -) +@gen_cluster(client=True, nthreads=[("", 1)] * 3, config=demo_config("drop")) async def test_drop_with_paused_workers_with_running_tasks_5(c, s, w1, w2, w3): """If there is exactly 1 worker that holds a replica of a task that isn't paused or retiring, and there are 1+ paused/retiring workers with the same task, don't drop @@ -549,7 +545,7 @@ def executing() -> bool: while not executing(): await asyncio.sleep(0.01) - w1.memory_pause_fraction = 1e-15 + w1.status = Status.paused while s.workers[w1.address].status != Status.paused: await asyncio.sleep(0.01) assert executing() @@ -635,7 +631,7 @@ async def test_replicate_to_worker_with_most_free_memory(c, s, *nannies): @gen_cluster( nthreads=[("", 1)] * 8, client=True, - config=demo_config("replicate", n=1, candidates={5, 6}), + config=demo_config("replicate", n=1, candidates=[5, 6]), ) async def test_replicate_with_candidates(c, s, *workers): wss = list(s.workers.values()) @@ -647,7 +643,7 @@ async def test_replicate_with_candidates(c, s, *workers): await asyncio.sleep(0.01) -@gen_cluster(client=True, config=demo_config("replicate", candidates=set())) +@gen_cluster(client=True, config=demo_config("replicate", candidates=[])) async def test_replicate_with_empty_candidates(c, s, a, b): """Key is not replicated as the plugin proposes an empty set of candidates, not to be confused with None @@ -658,7 +654,7 @@ async def test_replicate_with_empty_candidates(c, s, a, b): assert len(s.tasks["x"].who_has) == 1 -@gen_cluster(client=True, config=demo_config("replicate", candidates={0})) +@gen_cluster(client=True, config=demo_config("replicate", candidates=[0])) async def test_replicate_to_candidates_with_key(c, s, a, b): """Key is not replicated as all 
candidates already hold replicas""" ws0, ws1 = s.workers.values() # Not necessarily a, b; it could be b, a! @@ -668,14 +664,9 @@ async def test_replicate_to_candidates_with_key(c, s, a, b): assert s.tasks["x"].who_has == {ws0} -@gen_cluster( - client=True, - nthreads=[("", 1)] * 3, - config=demo_config("replicate"), - worker_kwargs={"memory_monitor_interval": "20ms"}, -) +@gen_cluster(client=True, nthreads=[("", 1)] * 3, config=demo_config("replicate")) async def test_replicate_avoids_paused_workers_1(c, s, w0, w1, w2): - w1.memory_pause_fraction = 1e-15 + w1.status = Status.paused while s.workers[w1.address].status != Status.paused: await asyncio.sleep(0.01) @@ -687,13 +678,9 @@ async def test_replicate_avoids_paused_workers_1(c, s, w0, w1, w2): assert "x" not in w1.data -@gen_cluster( - client=True, - config=demo_config("replicate"), - worker_kwargs={"memory_monitor_interval": "20ms"}, -) +@gen_cluster(client=True, config=demo_config("replicate")) async def test_replicate_avoids_paused_workers_2(c, s, a, b): - b.memory_pause_fraction = 1e-15 + b.status = Status.paused while s.workers[b.address].status != Status.paused: await asyncio.sleep(0.01) @@ -892,13 +879,14 @@ async def test_RetireWorker_no_recipients(c, s, w1, w2, w3, w4): "distributed.scheduler.active-memory-manager.start": True, "distributed.scheduler.active-memory-manager.interval": 999, "distributed.scheduler.active-memory-manager.policies": [], + "distributed.worker.memory.pause": False, }, ) async def test_RetireWorker_all_recipients_are_paused(c, s, a, b): ws_a = s.workers[a.address] ws_b = s.workers[b.address] - b.memory_pause_fraction = 1e-15 + b.status = Status.paused while ws_b.status != Status.paused: await asyncio.sleep(0.01) diff --git a/distributed/tests/test_cancelled_state.py b/distributed/tests/test_cancelled_state.py index 31062a40039..ef53e9e1ccb 100644 --- a/distributed/tests/test_cancelled_state.py +++ b/distributed/tests/test_cancelled_state.py @@ -1,10 +1,7 @@ import asyncio from unittest import mock -import pytest - import distributed -from distributed import Worker from distributed.core import CommClosedError from distributed.utils_test import _LockedCommPool, gen_cluster, inc, slowinc @@ -208,58 +205,4 @@ async def wait_and_raise(*args, **kwargs): await asyncio.sleep(0.01) # Everything should still be executing as usual after this - await c.submit(sum, c.map(inc, range(10))) == sum(map(inc, range(10))) - - -class LargeButForbiddenSerialization: - def __reduce__(self): - raise RuntimeError("I will never serialize!") - - def __sizeof__(self) -> int: - """Ensure this is immediately tried to spill""" - return 1_000_000_000_000 - - -def test_ensure_spilled_immediately(tmpdir): - """See also test_value_raises_during_spilling""" - import sys - - from distributed.spill import SpillBuffer - - mem_target = 1000 - buf = SpillBuffer(tmpdir, target=mem_target) - buf["key"] = 1 - - obj = LargeButForbiddenSerialization() - assert sys.getsizeof(obj) > mem_target - with pytest.raises( - TypeError, - match=f"Could not serialize object of type {LargeButForbiddenSerialization.__name__}", - ): - buf["error"] = obj - - -@gen_cluster(client=True, nthreads=[]) -async def test_value_raises_during_spilling(c, s): - """See also test_ensure_spilled_immediately""" - - # Use a worker with a default memory limit - async with Worker( - s.address, - ) as w: - - def produce_evil_data(): - return LargeButForbiddenSerialization() - - fut = c.submit(produce_evil_data) - - await wait_for_state(fut.key, "error", w) - - with pytest.raises( - 
TypeError, - match=f"Could not serialize object of type {LargeButForbiddenSerialization.__name__}", - ): - await fut - - # Everything should still be executing as usual after this - await c.submit(sum, c.map(inc, range(10))) == sum(map(inc, range(10))) + assert await c.submit(sum, c.map(inc, range(10))) == sum(map(inc, range(10))) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index b2f32d48ced..f2b28d42235 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -5877,14 +5877,14 @@ def bad_fn(x): @gen_cluster( client=True, nthreads=[("", 1)] * 10, - worker_kwargs={"memory_monitor_interval": "20ms"}, + config={"distributed.worker.memory.pause": False}, ) async def test_scatter_and_replicate_avoid_paused_workers( c, s, *workers, workers_arg, direct, broadcast ): paused_workers = [w for i, w in enumerate(workers) if i not in (3, 7)] for w in paused_workers: - w.memory_pause_fraction = 1e-15 + w.status = Status.paused while any(s.workers[w.address].status != Status.paused for w in paused_workers): await asyncio.sleep(0.01) diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index afc2dba7cd3..934e52c4fe7 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -6,7 +6,6 @@ import random import sys from contextlib import suppress -from time import sleep from unittest import mock import psutil @@ -27,7 +26,7 @@ from distributed.metrics import time from distributed.protocol.pickle import dumps from distributed.utils import TimeoutError, parse_ports -from distributed.utils_test import captured_logger, gen_cluster, gen_test, inc +from distributed.utils_test import captured_logger, gen_cluster, gen_test pytestmark = pytest.mark.ci1 @@ -265,55 +264,25 @@ async def test_nanny_timeout(c, s, a): @gen_cluster( - nthreads=[("127.0.0.1", 1)], - client=True, - Worker=Nanny, - worker_kwargs={"memory_limit": "400 MiB"}, -) -async def test_nanny_terminate(c, s, a): - def leak(): - L = [] - while True: - L.append(b"0" * 5_000_000) - sleep(0.01) - - before = a.process.pid - with captured_logger(logging.getLogger("distributed.nanny")) as logger: - future = c.submit(leak) - while a.process.pid == before: - await asyncio.sleep(0.01) - - out = logger.getvalue() - assert "restart" in out.lower() - assert "memory" in out.lower() - - -@gen_cluster( - nthreads=[("127.0.0.1", 1)] * 8, + nthreads=[("", 1)] * 8, client=True, - Worker=Worker, clean_kwargs={"threads": False}, + config={"distributed.worker.memory.pause": False}, ) -async def test_throttle_outgoing_connections(c, s, a, *workers): - # But a bunch of small data on worker a - await c.run(lambda: logging.getLogger("distributed.worker").setLevel(logging.DEBUG)) +async def test_throttle_outgoing_connections(c, s, a, *other_workers): + # Put a bunch of small data on worker a + logging.getLogger("distributed.worker").setLevel(logging.DEBUG) remote_data = c.map( lambda x: b"0" * 10000, range(10), pure=False, workers=[a.address] ) await wait(remote_data) - def pause(dask_worker): - # Patch paused and memory_monitor on the one worker - # This is is very fragile, since a refactor of memory_monitor to - # remove _memory_monitoring will break this test. 
- dask_worker._memory_monitoring = True - dask_worker.status = Status.paused - dask_worker.outgoing_current_count = 2 + a.status = Status.paused + a.outgoing_current_count = 2 - await c.run(pause, workers=[a.address]) requests = [ await a.get_data(await w.rpc.connect(w.address), keys=[f.key], who=w.address) - for w in workers + for w in other_workers for f in remote_data ] await wait(requests) @@ -322,36 +291,13 @@ def pause(dask_worker): assert "throttling" in wlogs.lower() -@gen_cluster(nthreads=[], client=True) -async def test_avoid_memory_monitor_if_zero_limit(c, s): - nanny = await Nanny(s.address, loop=s.loop, memory_limit=0) - typ = await c.run(lambda dask_worker: type(dask_worker.data)) - assert typ == {nanny.worker_address: dict} - pcs = await c.run(lambda dask_worker: list(dask_worker.periodic_callbacks)) - assert "memory" not in pcs - assert "memory" not in nanny.periodic_callbacks - - future = c.submit(inc, 1) - assert await future == 2 - await asyncio.sleep(0.02) - - await c.submit(inc, 2) # worker doesn't pause - - await nanny.close() - - -@gen_cluster(nthreads=[], client=True) -async def test_scheduler_address_config(c, s): +@gen_cluster(nthreads=[]) +async def test_scheduler_address_config(s): with dask.config.set({"scheduler-address": s.address}): - nanny = await Nanny(loop=s.loop) - assert nanny.scheduler.address == s.address - - start = time() - while not s.workers: - await asyncio.sleep(0.1) - assert time() < start + 10 - - await nanny.close() + async with Nanny() as nanny: + assert nanny.scheduler.address == s.address + while not s.workers: + await asyncio.sleep(0.01) @pytest.mark.slow @@ -421,14 +367,6 @@ async def test_environment_variable_config(c, s, monkeypatch): assert results[n.worker_address]["D"] == "123" -@gen_cluster(nthreads=[], client=True) -async def test_data_types(c, s): - w = await Nanny(s.address, data=dict) - r = await c.run(lambda dask_worker: type(dask_worker.data)) - assert r[w.worker_address] == dict - await w.close() - - @gen_cluster(nthreads=[]) async def test_local_directory(s): with tmpfile() as fn: diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 6d5081ef0bd..c0f9b84aad5 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -19,7 +19,15 @@ from dask import delayed from dask.utils import apply, parse_timedelta, stringify, tmpfile, typename -from distributed import Client, Nanny, Worker, fire_and_forget, wait +from distributed import ( + Client, + Lock, + Nanny, + SchedulerPlugin, + Worker, + fire_and_forget, + wait, +) from distributed.compatibility import LINUX, WINDOWS from distributed.core import ConnectionPool, Status, clean_exception, connect, rpc from distributed.metrics import time @@ -2178,7 +2186,6 @@ async def test_gather_allow_worker_reconnect( """ # GH3246 if reschedule_different_worker: - from distributed.diagnostics.plugin import SchedulerPlugin class SwitchRestrictions(SchedulerPlugin): def __init__(self, scheduler): @@ -2191,8 +2198,6 @@ def transition(self, key, start, finish, **kwargs): plugin = SwitchRestrictions(s) s.add_plugin(plugin) - from distributed import Lock - b_address = b.address def inc_slow(x, lock): @@ -2214,8 +2219,9 @@ def reducer(*args): def finalizer(addr): if swap_data_insert_order: w = get_worker() - new_data = {k: w.data[k] for k in list(w.data.keys())[::-1]} - w.data = new_data + new_data = dict(reversed(list(w.data.items()))) + w.data.clear() + w.data.update(new_data) return addr z = c.submit(reducer, x, key="reducer", 
workers=[a.address]) @@ -3288,10 +3294,10 @@ async def test_set_restrictions(c, s, a, b): @gen_cluster( client=True, nthreads=[("", 1)] * 3, - worker_kwargs={"memory_monitor_interval": "20ms"}, + config={"distributed.worker.memory.pause": False}, ) async def test_avoid_paused_workers(c, s, w1, w2, w3): - w2.memory_pause_fraction = 1e-15 + w2.status = Status.paused while s.workers[w2.address].status != Status.paused: await asyncio.sleep(0.01) futures = c.map(slowinc, range(8), delay=0.1) @@ -3302,25 +3308,6 @@ async def test_avoid_paused_workers(c, s, w1, w2, w3): assert len(w1.data) + len(w3.data) == 8 -@gen_cluster( - client=True, - nthreads=[("", 1)], - worker_kwargs={"memory_monitor_interval": "20ms"}, -) -async def test_unpause_schedules_unrannable_tasks(c, s, a): - a.memory_pause_fraction = 1e-15 - while s.workers[a.address].status != Status.paused: - await asyncio.sleep(0.01) - - fut = c.submit(inc, 1, key="x") - while not s.unrunnable: - await asyncio.sleep(0.001) - assert next(iter(s.unrunnable)).key == "x" - - a.memory_pause_fraction = 0.8 - assert await fut == 2 - - @gen_cluster(client=True, nthreads=[("", 1)]) async def test_Scheduler__to_dict(c, s, a): futs = c.map(inc, range(2)) @@ -3403,9 +3390,6 @@ async def test_TaskState__to_dict(c, s): @gen_cluster(nthreads=[]) async def test_idempotent_plugins(s): - - from distributed.diagnostics.plugin import SchedulerPlugin - class IdempotentPlugin(SchedulerPlugin): def __init__(self, instance=None): self.name = "idempotentplugin" @@ -3429,9 +3413,6 @@ def start(self, scheduler): @gen_cluster(nthreads=[]) async def test_non_idempotent_plugins(s): - - from distributed.diagnostics.plugin import SchedulerPlugin - class NonIdempotentPlugin(SchedulerPlugin): def __init__(self, instance=None): self.name = "nonidempotentplugin" diff --git a/distributed/tests/test_spill.py b/distributed/tests/test_spill.py index c30aa6cefc6..55bbb6ad8ad 100644 --- a/distributed/tests/test_spill.py +++ b/distributed/tests/test_spill.py @@ -5,16 +5,18 @@ import pytest -zict = pytest.importorskip("zict") -from packaging.version import parse as parse_version - from dask.sizeof import sizeof from distributed.compatibility import WINDOWS from distributed.protocol import serialize_bytelist -from distributed.spill import SpillBuffer +from distributed.spill import SpillBuffer, has_zict_210 from distributed.utils_test import captured_logger +requires_zict_210 = pytest.mark.skipif( + not has_zict_210, + reason="requires zict version >= 2.1.0", +) + def psize(*objs) -> tuple[int, int]: return ( @@ -105,12 +107,6 @@ def test_spillbuffer(tmpdir): assert buf.slow.total_weight == psize(d, e) -requires_zict_210 = pytest.mark.skipif( - parse_version(zict.__version__) <= parse_version("2.0.0"), - reason="requires zict version > 2.0.0", -) - - @requires_zict_210 def test_spillbuffer_maxlim(tmpdir): buf = SpillBuffer(str(tmpdir), target=200, max_spill=600, min_log_interval=0) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index 2a6a220bb97..e6469eae8b4 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -834,10 +834,10 @@ async def test_steal_twice(c, s, a, b): @gen_cluster( client=True, nthreads=[("", 1)] * 3, - worker_kwargs={"memory_monitor_interval": "20ms"}, + config={"distributed.worker.memory.pause": False}, ) async def test_paused_workers_must_not_steal(c, s, w1, w2, w3): - w2.memory_pause_fraction = 1e-15 + w2.status = Status.paused while s.workers[w2.address].status != Status.paused: await 
asyncio.sleep(0.01) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 2cb36819713..7f92abe67c2 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -16,7 +16,6 @@ import psutil import pytest -from packaging.version import parse as parse_version from tlz import first, pluck, sliding_window import dask @@ -27,7 +26,6 @@ import distributed from distributed import ( Client, - Event, Nanny, Reschedule, default_client, @@ -59,28 +57,10 @@ slowinc, slowsum, ) -from distributed.worker import ( - TaskState, - UniqueTaskHeap, - Worker, - error_message, - logger, - parse_memory_limit, -) +from distributed.worker import TaskState, UniqueTaskHeap, Worker, error_message, logger pytestmark = pytest.mark.ci1 -try: - import zict -except ImportError: - zict = None - -requires_zict = pytest.mark.skipif(not zict, reason="requires zict") -requires_zict_210 = pytest.mark.skipif( - not zict or parse_version(zict.__version__) <= parse_version("2.0.0"), - reason="requires zict version > 2.0.0", -) - @gen_cluster(nthreads=[]) async def test_worker_nthreads(s): @@ -912,109 +892,6 @@ def __sizeof__(self): assert result.data == 123 -class FailToPickle: - def __init__(self, *, reported_size=0, actual_size=0): - self.reported_size = int(reported_size) - self.data = "x" * int(actual_size) - - def __getstate__(self): - raise TypeError() - - def __sizeof__(self): - return self.reported_size - - -async def assert_basic_futures(c: Client) -> None: - futures = c.map(inc, range(10)) - results = await c.gather(futures) - assert results == list(map(inc, range(10))) - - -@requires_zict -@gen_cluster(client=True) -async def test_fail_write_to_disk_target_1(c, s, a, b): - """Test failure to spill triggered by key which is individually larger - than target. The data is lost and the task is marked as failed; - the worker remains in usable condition. - """ - future = c.submit(FailToPickle, reported_size=100e9) - await wait(future) - - assert future.status == "error" - - with pytest.raises(TypeError, match="Could not serialize"): - await future - - await assert_basic_futures(c) - - -@requires_zict -@gen_cluster( - client=True, - nthreads=[("", 1)], - worker_kwargs=dict( - memory_limit="1 kiB", - memory_target_fraction=0.5, - memory_spill_fraction=False, - memory_pause_fraction=False, - ), -) -async def test_fail_write_to_disk_target_2(c, s, a): - """Test failure to spill triggered by key which is individually smaller - than target, so it is not spilled immediately. The data is retained and - the task is NOT marked as failed; the worker remains in usable condition. 
- """ - x = c.submit(FailToPickle, reported_size=256, key="x") - await wait(x) - assert x.status == "finished" - assert set(a.data.memory) == {"x"} - - y = c.submit(lambda: "y" * 256, key="y") - await wait(y) - if parse_version(zict.__version__) <= parse_version("2.0.0"): - assert set(a.data.memory) == {"y"} - else: - assert set(a.data.memory) == {"x", "y"} - assert not a.data.disk - - await assert_basic_futures(c) - - -@requires_zict_210 -@gen_cluster( - client=True, - nthreads=[("", 1)], - worker_kwargs=dict( - memory_monitor_interval="10ms", - memory_limit="1 kiB", # Spill everything - memory_target_fraction=False, - memory_spill_fraction=0.7, - memory_pause_fraction=False, - ), -) -async def test_fail_write_to_disk_spill(c, s, a): - """Test failure to evict a key, triggered by the spill threshold""" - with captured_logger(logging.getLogger("distributed.spill")) as logs: - bad = c.submit(FailToPickle, actual_size=1_000_000, key="bad") - await wait(bad) - - # Must wait for memory monitor to kick in - while True: - logs_value = logs.getvalue() - if logs_value: - break - await asyncio.sleep(0.01) - - assert "Failed to pickle" in logs_value - assert "Traceback" in logs_value - - # key is in fast - assert bad.status == "finished" - assert bad.key in a.data.fast - - await assert_basic_futures(c) - - @gen_cluster() async def test_pid(s, a, b): assert s.workers[a.address].pid == os.getpid() @@ -1193,245 +1070,6 @@ async def test_statistical_profiling_2(c, s, a, b): break -@requires_zict -@gen_cluster( - client=True, - nthreads=[("", 1)], - worker_kwargs=dict( - memory_limit=1200 / 0.6, - memory_target_fraction=0.6, - memory_spill_fraction=False, - memory_pause_fraction=False, - ), -) -async def test_spill_target_threshold(c, s, a): - """Test distributed.worker.memory.target threshold. Note that in this test we - disabled spill and pause thresholds, which work on the process memory, and just left - the target threshold, which works on managed memory so it is unperturbed by the - several hundreds of MB of unmanaged memory that are typical of the test suite. 
- """ - x = c.submit(lambda: "x" * 500, key="x") - await wait(x) - y = c.submit(lambda: "y" * 500, key="y") - await wait(y) - - assert set(a.data) == {"x", "y"} - assert set(a.data.memory) == {"x", "y"} - - z = c.submit(lambda: "z" * 500, key="z") - await wait(z) - assert set(a.data) == {"x", "y", "z"} - assert set(a.data.memory) == {"y", "z"} - assert set(a.data.disk) == {"x"} - - await x - assert set(a.data.memory) == {"x", "z"} - assert set(a.data.disk) == {"y"} - - -@requires_zict_210 -@gen_cluster( - client=True, - nthreads=[("", 1)], - worker_kwargs=dict( - memory_limit=1600, - max_spill=600, - memory_target_fraction=0.6, - memory_spill_fraction=False, - memory_pause_fraction=False, - ), -) -async def test_spill_constrained(c, s, w): - """Test distributed.worker.memory.max-spill parameter""" - # spills starts at 1600*0.6=960 bytes of managed memory - - # size in memory ~200; size on disk ~400 - x = c.submit(lambda: "x" * 200, key="x") - await wait(x) - # size in memory ~500; size on disk ~700 - y = c.submit(lambda: "y" * 500, key="y") - await wait(y) - - assert set(w.data) == {x.key, y.key} - assert set(w.data.memory) == {x.key, y.key} - - z = c.submit(lambda: "z" * 500, key="z") - await wait(z) - - assert set(w.data) == {x.key, y.key, z.key} - - # max_spill has not been reached - assert set(w.data.memory) == {y.key, z.key} - assert set(w.data.disk) == {x.key} - - # zb is individually larger than max_spill - zb = c.submit(lambda: "z" * 1700, key="zb") - await wait(zb) - - assert set(w.data.memory) == {y.key, z.key, zb.key} - assert set(w.data.disk) == {x.key} - - del zb - while "zb" in w.data: - await asyncio.sleep(0.01) - - # zc is individually smaller than max_spill, but the evicted key together with - # x it exceeds max_spill - zc = c.submit(lambda: "z" * 500, key="zc") - await wait(zc) - assert set(w.data.memory) == {y.key, z.key, zc.key} - assert set(w.data.disk) == {x.key} - - -@requires_zict -@gen_cluster( - nthreads=[("", 1)], - client=True, - worker_kwargs=dict( - memory_limit="1000 MB", - memory_monitor_interval="10ms", - memory_target_fraction=False, - memory_spill_fraction=0.7, - memory_pause_fraction=False, - ), -) -async def test_spill_spill_threshold(c, s, a): - """Test distributed.worker.memory.spill threshold. - Test that the spill threshold uses the process memory and not the managed memory - reported by sizeof(), which may be inaccurate. - """ - a.monitor.get_process_memory = lambda: 800_000_000 if a.data.fast else 0 - x = c.submit(inc, 0, key="x") - while not a.data.disk: - await asyncio.sleep(0.01) - assert await x == 1 - - -@requires_zict -@pytest.mark.parametrize( - "memory_target_fraction,managed,expect_spilled", - [ - # no target -> no hysteresis - # Over-report managed memory to test that the automated LRU eviction based on - # target is never triggered - (False, int(10e9), 1), - # Under-report managed memory, so that we reach the spill threshold for process - # memory without first reaching the target threshold for managed memory - # target == spill -> no hysteresis - (0.7, 0, 1), - # target < spill -> hysteresis from spill to target - (0.4, 0, 7), - ], -) -@gen_cluster(nthreads=[], client=True) -async def test_spill_hysteresis(c, s, memory_target_fraction, managed, expect_spilled): - """ - 1. Test that you can enable the spill threshold while leaving the target threshold - to False - 2. 
Test the hysteresis system where, once you reach the spill threshold, the worker - won't stop spilling until the target threshold is reached - """ - - class C: - def __sizeof__(self): - return managed - - async with Worker( - s.address, - memory_limit="1000 MB", - memory_monitor_interval="10ms", - memory_target_fraction=memory_target_fraction, - memory_spill_fraction=0.7, - memory_pause_fraction=False, - ) as a: - a.monitor.get_process_memory = lambda: 50_000_000 * len(a.data.fast) - - # Add 500MB (reported) process memory. Spilling must not happen. - futures = [c.submit(C, pure=False) for _ in range(10)] - await wait(futures) - await asyncio.sleep(0.1) - assert not a.data.disk - - # Add another 250MB unmanaged memory. This must trigger the spilling. - futures += [c.submit(C, pure=False) for _ in range(5)] - await wait(futures) - - # Wait until spilling starts. Then, wait until it stops. - prev_n = 0 - while not a.data.disk or len(a.data.disk) > prev_n: - prev_n = len(a.data.disk) - await asyncio.sleep(0) - - assert len(a.data.disk) == expect_spilled - - -@gen_cluster( - nthreads=[("", 1)], - client=True, - worker_kwargs=dict( - memory_limit="1000 MB", - memory_monitor_interval="10ms", - memory_target_fraction=False, - memory_spill_fraction=False, - memory_pause_fraction=0.8, - ), -) -async def test_pause_executor(c, s, a): - mocked_rss = 0 - a.monitor.get_process_memory = lambda: mocked_rss - - # Task that is running when the worker pauses - ev_x = Event() - - def f(ev): - ev.wait() - return 1 - - x = c.submit(f, ev_x, key="x") - while a.executing_count != 1: - await asyncio.sleep(0.01) - - with captured_logger(logging.getLogger("distributed.worker")) as logger: - # Task that is queued on the worker when the worker pauses - y = c.submit(inc, 1, key="y") - while "y" not in a.tasks: - await asyncio.sleep(0.01) - - # Hog the worker with 900MB unmanaged memory - mocked_rss = 900_000_000 - while s.workers[a.address].status != Status.paused: - await asyncio.sleep(0.01) - - assert "Pausing worker" in logger.getvalue() - - # Task that is queued on the scheduler when the worker pauses. - # It is not sent to the worker. - z = c.submit(inc, 2, key="z") - while "z" not in s.tasks or s.tasks["z"].state != "no-worker": - await asyncio.sleep(0.01) - - # Test that a task that already started when the worker paused can complete - # and its output can be retrieved. Also test that the now free slot won't be - # used by other tasks. - await ev_x.set() - assert await x == 1 - await asyncio.sleep(0.05) - - assert a.executing_count == 0 - assert len(a.ready) == 1 - assert a.tasks["y"].state == "ready" - assert "z" not in a.tasks - - # Release the memory. Tasks that were queued on the worker are executed. - # Tasks that were stuck on the scheduler are sent to the worker and executed. 
- mocked_rss = 0 - assert await y == 2 - assert await z == 3 - - assert a.status == Status.running - assert "Resuming worker" in logger.getvalue() - - @gen_cluster(client=True, worker_kwargs={"profile_cycle_interval": "50 ms"}) async def test_statistical_profiling_cycle(c, s, a, b): futures = c.map(slowinc, range(20), delay=0.05) @@ -1492,31 +1130,6 @@ async def test_deque_handler(s): assert any(msg.msg == "foo456" for msg in deque_handler.deque) -@gen_cluster( - client=True, - nthreads=[("", 1)], - worker_kwargs={"memory_limit": 0, "memory_monitor_interval": "10ms"}, -) -async def test_avoid_memory_monitor_if_zero_limit(c, s, a): - assert type(a.data) is dict - assert "memory" not in a.periodic_callbacks - future = c.submit(inc, 1) - assert (await future) == 2 - await asyncio.sleep(0.05) - await c.submit(inc, 2) # worker doesn't pause - - -@gen_cluster( - nthreads=[("127.0.0.1", 1)], - config={ - "distributed.worker.memory.spill": False, - "distributed.worker.memory.target": False, - }, -) -async def test_dict_data_if_no_spill_to_disk(s, w): - assert type(w.data) is dict - - def test_get_worker_name(client): def f(): get_client().submit(inc, 1).result() @@ -1532,11 +1145,6 @@ def func(dask_scheduler): assert time() < start + 10 -@gen_cluster(nthreads=[("127.0.0.1", 1)], worker_kwargs={"memory_limit": "2e3 MB"}) -async def test_parse_memory_limit(s, w): - assert w.memory_limit == 2e9 - - @gen_cluster(nthreads=[], client=True) async def test_scheduler_address_config(c, s): with dask.config.set({"scheduler-address": s.address}): @@ -1676,28 +1284,6 @@ async def test_register_worker_callbacks_err(c, s, a, b): await c.register_worker_callbacks(setup=lambda: 1 / 0) -@gen_cluster(nthreads=[]) -async def test_data_types(s): - w = await Worker(s.address, data=dict) - assert isinstance(w.data, dict) - await w.close() - - data = dict() - w = await Worker(s.address, data=data) - assert w.data is data - await w.close() - - class Data(dict): - def __init__(self, x, y): - self.x = x - self.y = y - - w = await Worker(s.address, data=(Data, {"x": 123, "y": 456})) - assert w.data.x == 123 - assert w.data.y == 456 - await w.close() - - @gen_cluster(nthreads=[]) async def test_local_directory(s): with tmpfile() as fn: @@ -1729,16 +1315,6 @@ async def test_host_address(c, s): await n.close() -def test_resource_limit(monkeypatch): - assert parse_memory_limit("250MiB", 1, total_cores=1) == 1024 * 1024 * 250 - - new_limit = 1024 * 1024 * 200 - import distributed.worker - - monkeypatch.setattr(distributed.system, "MEMORY_LIMIT", new_limit) - assert parse_memory_limit("250MiB", 1, total_cores=1) == new_limit - - @pytest.mark.asyncio @pytest.mark.parametrize("Worker", [Worker, Nanny]) async def test_interface_async(cleanup, loop, Worker): @@ -3396,38 +2972,18 @@ async def test_missing_released_zombie_tasks_2(c, s, a, b): ) -@pytest.mark.slow -@gen_cluster( - client=True, - Worker=Nanny, - nthreads=[("", 1)], - config={"distributed.worker.memory.pause": 0.5}, - worker_kwargs={"memory_limit": 2**29}, # 500 MiB -) -async def test_worker_status_sync(c, s, a): - (ws,) = s.workers.values() - - while ws.status != Status.running: - await asyncio.sleep(0.01) - - def leak(): - distributed._test_leak = "x" * 2**28 # 250 MiB - - def clear_leak(): - del distributed._test_leak - - await c.run(leak) - +@gen_cluster(nthreads=[("", 1)], config={"distributed.worker.memory.pause": False}) +async def test_worker_status_sync(s, a): + ws = s.workers[a.address] + a.status = Status.paused while ws.status != Status.paused: await 
asyncio.sleep(0.01) - await c.run(clear_leak) - + a.status = Status.running while ws.status != Status.running: await asyncio.sleep(0.01) await s.retire_workers() - while ws.status != Status.closed: await asyncio.sleep(0.01) @@ -3716,12 +3272,11 @@ async def test_Worker__to_dict(c, s, a): x = c.submit(inc, 1, key="x") await wait(x) d = a._to_dict() - assert d.keys() == { + assert set(d) == { "type", "id", "scheduler", "nthreads", - "memory_limit", "address", "status", "thread_id", @@ -3733,17 +3288,23 @@ async def test_Worker__to_dict(c, s, a): "in_flight_workers", "log", "tasks", - "memory_target_fraction", - "memory_spill_fraction", - "memory_pause_fraction", "logs", "config", "incoming_transfer_log", "outgoing_transfer_log", "data_needed", "pending_data_per_worker", + # attributes of WorkerMemoryManager + "data", + "max_spill", + "memory_limit", + "memory_monitor_interval", + "memory_pause_fraction", + "memory_spill_fraction", + "memory_target_fraction", } assert d["tasks"]["x"]["key"] == "x" + assert d["data"] == ["x"] @gen_cluster(client=True, nthreads=[("", 1)]) diff --git a/distributed/tests/test_worker_memory.py b/distributed/tests/test_worker_memory.py new file mode 100644 index 00000000000..094cbe15a75 --- /dev/null +++ b/distributed/tests/test_worker_memory.py @@ -0,0 +1,620 @@ +from __future__ import annotations + +import asyncio +import logging +from time import sleep + +import pytest + +import dask.config + +import distributed.system +from distributed import Client, Event, Nanny, Worker, wait +from distributed.core import Status +from distributed.spill import has_zict_210 +from distributed.utils_test import captured_logger, gen_cluster, inc +from distributed.worker_memory import parse_memory_limit + +requires_zict_210 = pytest.mark.skipif( + not has_zict_210, + reason="requires zict version >= 2.1.0", +) + + +def memory_monitor_running(dask_worker: Worker | Nanny) -> bool: + return "memory_monitor" in dask_worker.periodic_callbacks + + +def test_parse_memory_limit_zero(): + assert parse_memory_limit(0, 1) is None + assert parse_memory_limit("0", 1) is None + assert parse_memory_limit(None, 1) is None + + +def test_resource_limit(monkeypatch): + assert parse_memory_limit("250MiB", 1, total_cores=1) == 1024 * 1024 * 250 + + new_limit = 1024 * 1024 * 200 + monkeypatch.setattr(distributed.system, "MEMORY_LIMIT", new_limit) + assert parse_memory_limit("250MiB", 1, total_cores=1) == new_limit + + +@gen_cluster(nthreads=[("", 1)], worker_kwargs={"memory_limit": "2e3 MB"}) +async def test_parse_memory_limit_worker(s, w): + assert w.memory_manager.memory_limit == 2e9 + + +@gen_cluster( + client=True, + nthreads=[("", 1)], + Worker=Nanny, + worker_kwargs={"memory_limit": "2e3 MB"}, +) +async def test_parse_memory_limit_nanny(c, s, n): + assert n.memory_manager.memory_limit == 2e9 + out = await c.run(lambda dask_worker: dask_worker.memory_manager.memory_limit) + assert out[n.worker_address] == 2e9 + + +@gen_cluster( + nthreads=[("127.0.0.1", 1)], + config={ + "distributed.worker.memory.spill": False, + "distributed.worker.memory.target": False, + }, +) +async def test_dict_data_if_no_spill_to_disk(s, w): + assert type(w.data) is dict + + +class FailToPickle: + def __init__(self, *, reported_size=0, actual_size=0): + self.reported_size = int(reported_size) + self.data = "x" * int(actual_size) + + def __getstate__(self): + raise TypeError() + + def __sizeof__(self): + return self.reported_size + + +async def assert_basic_futures(c: Client) -> None: + futures = c.map(inc, range(10)) + 
results = await c.gather(futures) + assert results == list(map(inc, range(10))) + + +@gen_cluster(client=True) +async def test_fail_write_to_disk_target_1(c, s, a, b): + """Test failure to spill triggered by key which is individually larger + than target. The data is lost and the task is marked as failed; + the worker remains in usable condition. + """ + future = c.submit(FailToPickle, reported_size=100e9) + await wait(future) + + assert future.status == "error" + + with pytest.raises(TypeError, match="Could not serialize"): + await future + + await assert_basic_futures(c) + + +@gen_cluster( + client=True, + nthreads=[("", 1)], + worker_kwargs={"memory_limit": "1 kiB"}, + config={ + "distributed.worker.memory.target": 0.5, + "distributed.worker.memory.spill": False, + "distributed.worker.memory.pause": False, + }, +) +async def test_fail_write_to_disk_target_2(c, s, a): + """Test failure to spill triggered by key which is individually smaller + than target, so it is not spilled immediately. The data is retained and + the task is NOT marked as failed; the worker remains in usable condition. + """ + x = c.submit(FailToPickle, reported_size=256, key="x") + await wait(x) + assert x.status == "finished" + assert set(a.data.memory) == {"x"} + + y = c.submit(lambda: "y" * 256, key="y") + await wait(y) + if has_zict_210: + assert set(a.data.memory) == {"x", "y"} + else: + assert set(a.data.memory) == {"y"} + + assert not a.data.disk + + await assert_basic_futures(c) + + +@requires_zict_210 +@gen_cluster( + client=True, + nthreads=[("", 1)], + worker_kwargs={"memory_limit": "1 kiB"}, # Spill everything + config={ + "distributed.worker.memory.target": False, + "distributed.worker.memory.spill": 0.7, + "distributed.worker.memory.pause": False, + "distributed.worker.memory.monitor-interval.spill-pause": "10ms", + }, +) +async def test_fail_write_to_disk_spill(c, s, a): + """Test failure to evict a key, triggered by the spill threshold""" + with captured_logger(logging.getLogger("distributed.spill")) as logs: + bad = c.submit(FailToPickle, actual_size=1_000_000, key="bad") + await wait(bad) + + # Must wait for memory monitor to kick in + while True: + logs_value = logs.getvalue() + if logs_value: + break + await asyncio.sleep(0.01) + + assert "Failed to pickle" in logs_value + assert "Traceback" in logs_value + + # key is in fast + assert bad.status == "finished" + assert bad.key in a.data.fast + + await assert_basic_futures(c) + + +@gen_cluster( + client=True, + nthreads=[("", 1)], + worker_kwargs={"memory_limit": 1200 / 0.6}, + config={ + "distributed.worker.memory.target": 0.6, + "distributed.worker.memory.spill": False, + "distributed.worker.memory.pause": False, + }, +) +async def test_spill_target_threshold(c, s, a): + """Test distributed.worker.memory.target threshold. Note that in this test we + disabled spill and pause thresholds, which work on the process memory, and just left + the target threshold, which works on managed memory so it is unperturbed by the + several hundreds of MB of unmanaged memory that are typical of the test suite. 
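+    Since the target threshold is enforced synchronously when data is inserted, and
+    spill and pause are disabled here, the memory monitor is not started at all,
+    as asserted below.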
+ """ + assert not memory_monitor_running(a) + + x = c.submit(lambda: "x" * 500, key="x") + await wait(x) + y = c.submit(lambda: "y" * 500, key="y") + await wait(y) + + assert set(a.data) == {"x", "y"} + assert set(a.data.memory) == {"x", "y"} + + z = c.submit(lambda: "z" * 500, key="z") + await wait(z) + assert set(a.data) == {"x", "y", "z"} + assert set(a.data.memory) == {"y", "z"} + assert set(a.data.disk) == {"x"} + + await x + assert set(a.data.memory) == {"x", "z"} + assert set(a.data.disk) == {"y"} + + +@requires_zict_210 +@gen_cluster( + client=True, + nthreads=[("", 1)], + worker_kwargs={"memory_limit": 1600}, + config={ + "distributed.worker.memory.target": 0.6, + "distributed.worker.memory.spill": False, + "distributed.worker.memory.pause": False, + "distributed.worker.memory.max-spill": 600, + }, +) +async def test_spill_constrained(c, s, w): + """Test distributed.worker.memory.max-spill parameter""" + # spills starts at 1600*0.6=960 bytes of managed memory + + # size in memory ~200; size on disk ~400 + x = c.submit(lambda: "x" * 200, key="x") + await wait(x) + # size in memory ~500; size on disk ~700 + y = c.submit(lambda: "y" * 500, key="y") + await wait(y) + + assert set(w.data) == {x.key, y.key} + assert set(w.data.memory) == {x.key, y.key} + + z = c.submit(lambda: "z" * 500, key="z") + await wait(z) + + assert set(w.data) == {x.key, y.key, z.key} + + # max_spill has not been reached + assert set(w.data.memory) == {y.key, z.key} + assert set(w.data.disk) == {x.key} + + # zb is individually larger than max_spill + zb = c.submit(lambda: "z" * 1700, key="zb") + await wait(zb) + + assert set(w.data.memory) == {y.key, z.key, zb.key} + assert set(w.data.disk) == {x.key} + + del zb + while "zb" in w.data: + await asyncio.sleep(0.01) + + # zc is individually smaller than max_spill, but the evicted key together with + # x it exceeds max_spill + zc = c.submit(lambda: "z" * 500, key="zc") + await wait(zc) + assert set(w.data.memory) == {y.key, z.key, zc.key} + assert set(w.data.disk) == {x.key} + + +@gen_cluster( + nthreads=[("", 1)], + client=True, + worker_kwargs={"memory_limit": "1000 MB"}, + config={ + "distributed.worker.memory.target": False, + "distributed.worker.memory.spill": 0.7, + "distributed.worker.memory.pause": False, + "distributed.worker.memory.monitor-interval.spill-pause": "10ms", + }, +) +async def test_spill_spill_threshold(c, s, a): + """Test distributed.worker.memory.spill threshold. + Test that the spill threshold uses the process memory and not the managed memory + reported by sizeof(), which may be inaccurate. 
+ """ + assert memory_monitor_running(a) + a.monitor.get_process_memory = lambda: 800_000_000 if a.data.fast else 0 + x = c.submit(inc, 0, key="x") + while not a.data.disk: + await asyncio.sleep(0.01) + assert await x == 1 + + +@pytest.mark.parametrize( + "target,managed,expect_spilled", + [ + # no target -> no hysteresis + # Over-report managed memory to test that the automated LRU eviction based on + # target is never triggered + (False, int(10e9), 1), + # Under-report managed memory, so that we reach the spill threshold for process + # memory without first reaching the target threshold for managed memory + # target == spill -> no hysteresis + (0.7, 0, 1), + # target < spill -> hysteresis from spill to target + (0.4, 0, 7), + ], +) +@gen_cluster( + nthreads=[], + client=True, + config={ + "distributed.worker.memory.spill": 0.7, + "distributed.worker.memory.pause": False, + "distributed.worker.memory.monitor-interval.spill-pause": "10ms", + }, +) +async def test_spill_hysteresis(c, s, target, managed, expect_spilled): + """ + 1. Test that you can enable the spill threshold while leaving the target threshold + to False + 2. Test the hysteresis system where, once you reach the spill threshold, the worker + won't stop spilling until the target threshold is reached + """ + + class C: + def __sizeof__(self): + return managed + + with dask.config.set({"distributed.worker.memory.target": target}): + async with Worker(s.address, memory_limit="1000 MB") as a: + a.monitor.get_process_memory = lambda: 50_000_000 * len(a.data.fast) + + # Add 500MB (reported) process memory. Spilling must not happen. + futures = [c.submit(C, pure=False) for _ in range(10)] + await wait(futures) + await asyncio.sleep(0.1) + assert not a.data.disk + + # Add another 250MB unmanaged memory. This must trigger the spilling. + futures += [c.submit(C, pure=False) for _ in range(5)] + await wait(futures) + + # Wait until spilling starts. Then, wait until it stops. + prev_n = 0 + while not a.data.disk or len(a.data.disk) > prev_n: + prev_n = len(a.data.disk) + await asyncio.sleep(0) + + assert len(a.data.disk) == expect_spilled + + +@gen_cluster( + nthreads=[("", 1)], + client=True, + config={ + "distributed.worker.memory.target": False, + "distributed.worker.memory.spill": False, + "distributed.worker.memory.pause": False, + }, +) +async def test_pause_executor_manual(c, s, a): + assert not memory_monitor_running(a) + + # Task that is running when the worker pauses + ev_x = Event() + + def f(ev): + ev.wait() + return 1 + + # Task that is running on the worker when the worker pauses + x = c.submit(f, ev_x, key="x") + while a.executing_count != 1: + await asyncio.sleep(0.01) + + # Task that is queued on the worker when the worker pauses + y = c.submit(inc, 1, key="y") + while "y" not in a.tasks: + await asyncio.sleep(0.01) + + a.status = Status.paused + # Wait for sync to scheduler + while s.workers[a.address].status != Status.paused: + await asyncio.sleep(0.01) + + # Task that is queued on the scheduler when the worker pauses. + # It is not sent to the worker. + z = c.submit(inc, 2, key="z") + while "z" not in s.tasks or s.tasks["z"].state != "no-worker": + await asyncio.sleep(0.01) + assert s.unrunnable == {s.tasks["z"]} + + # Test that a task that already started when the worker paused can complete + # and its output can be retrieved. Also test that the now free slot won't be + # used by other tasks. 
+ await ev_x.set() + assert await x == 1 + await asyncio.sleep(0.05) + + assert a.executing_count == 0 + assert len(a.ready) == 1 + assert a.tasks["y"].state == "ready" + assert "z" not in a.tasks + + # Unpause. Tasks that were queued on the worker are executed. + # Tasks that were stuck on the scheduler are sent to the worker and executed. + a.status = Status.running + assert await y == 2 + assert await z == 3 + + +@gen_cluster( + nthreads=[("", 1)], + client=True, + worker_kwargs={"memory_limit": "1000 MB"}, + config={ + "distributed.worker.memory.target": False, + "distributed.worker.memory.spill": False, + "distributed.worker.memory.pause": 0.8, + "distributed.worker.memory.monitor-interval.spill-pause": "10ms", + }, +) +async def test_pause_executor_with_memory_monitor(c, s, a): + assert memory_monitor_running(a) + mocked_rss = 0 + a.monitor.get_process_memory = lambda: mocked_rss + + # Task that is running when the worker pauses + ev_x = Event() + + def f(ev): + ev.wait() + return 1 + + # Task that is running on the worker when the worker pauses + x = c.submit(f, ev_x, key="x") + while a.executing_count != 1: + await asyncio.sleep(0.01) + + with captured_logger(logging.getLogger("distributed.worker_memory")) as logger: + # Task that is queued on the worker when the worker pauses + y = c.submit(inc, 1, key="y") + while "y" not in a.tasks: + await asyncio.sleep(0.01) + + # Hog the worker with 900MB unmanaged memory + mocked_rss = 900_000_000 + while s.workers[a.address].status != Status.paused: + await asyncio.sleep(0.01) + + assert "Pausing worker" in logger.getvalue() + + # Task that is queued on the scheduler when the worker pauses. + # It is not sent to the worker. + z = c.submit(inc, 2, key="z") + while "z" not in s.tasks or s.tasks["z"].state != "no-worker": + await asyncio.sleep(0.01) + assert s.unrunnable == {s.tasks["z"]} + + # Test that a task that already started when the worker paused can complete + # and its output can be retrieved. Also test that the now free slot won't be + # used by other tasks. + await ev_x.set() + assert await x == 1 + await asyncio.sleep(0.05) + + assert a.executing_count == 0 + assert len(a.ready) == 1 + assert a.tasks["y"].state == "ready" + assert "z" not in a.tasks + + # Release the memory. Tasks that were queued on the worker are executed. + # Tasks that were stuck on the scheduler are sent to the worker and executed. 
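+        # The memory monitor re-reads the mocked RSS every 10ms (see
+        # monitor-interval.spill-pause in the config above); once it observes memory
+        # below the pause threshold it resumes the worker.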
+ mocked_rss = 0 + assert await y == 2 + assert await z == 3 + + assert a.status == Status.running + assert "Resuming worker" in logger.getvalue() + + +@gen_cluster( + client=True, + nthreads=[("", 1)], + worker_kwargs={"memory_limit": 0}, + config={"distributed.worker.memory.monitor-interval.spill-pause": "10ms"}, +) +async def test_avoid_memory_monitor_if_zero_limit_worker(c, s, a): + assert type(a.data) is dict + assert not memory_monitor_running(a) + + future = c.submit(inc, 1) + assert await future == 2 + await asyncio.sleep(0.05) + assert await c.submit(inc, 2) == 3 # worker doesn't pause + + +@gen_cluster( + client=True, + nthreads=[("", 1)], + Worker=Nanny, + worker_kwargs={"memory_limit": 0}, + config={ + "distributed.worker.memory.monitor-interval.spill-pause": "10ms", + "distributed.worker.memory.monitor-interval.terminate": "10ms", + }, +) +async def test_avoid_memory_monitor_if_zero_limit_nanny(c, s, nanny): + typ = await c.run(lambda dask_worker: type(dask_worker.data)) + assert typ == {nanny.worker_address: dict} + assert not memory_monitor_running(nanny) + assert not (await c.run(memory_monitor_running))[nanny.worker_address] + + future = c.submit(inc, 1) + assert await future == 2 + await asyncio.sleep(0.02) + assert await c.submit(inc, 2) == 3 # worker doesn't pause + + +@gen_cluster(nthreads=[]) +async def test_override_data_worker(s): + async with Worker(s.address, data=dict) as w: + assert type(w.data) is dict + + data = {"x": 1} + async with Worker(s.address, data=data) as w: + assert w.data is data + assert w.data == {"x": 1} + + class Data(dict): + def __init__(self, x, y): + self.x = x + self.y = y + + async with Worker(s.address, data=(Data, {"x": 123, "y": 456})) as w: + assert w.data.x == 123 + assert w.data.y == 456 + + +@gen_cluster( + client=True, + nthreads=[("", 1)], + Worker=Nanny, + worker_kwargs={"data": dict}, +) +async def test_override_data_nanny(c, s, n): + r = await c.run(lambda dask_worker: type(dask_worker.data)) + assert r[n.worker_address] is dict + + +@gen_cluster( + client=True, + nthreads=[("", 1)], + worker_kwargs={"memory_limit": 1000, "data": dict}, + config={ + "distributed.worker.memory.pause": False, + "distributed.worker.memory.monitor-interval.spill-pause": "10ms", + }, +) +async def test_override_data_does_not_spill(c, s, a): + assert memory_monitor_running(a) + a.monitor.get_process_memory = lambda: 10000 + # Push a key that would normally trip both the target and the spill thresholds + x = c.submit(lambda: "x" * 2000) + await wait(x) + await asyncio.sleep(0.05) + assert type(a.data) is dict + assert a.data == {x.key: "x" * 2000} + + +@pytest.mark.slow +@gen_cluster( + nthreads=[("", 1)], + client=True, + Worker=Nanny, + worker_kwargs={"memory_limit": "400 MiB"}, + config={"distributed.worker.memory.monitor-interval.terminate": "10ms"}, +) +async def test_nanny_terminate(c, s, a): + def leak(): + L = [] + while True: + L.append(b"0" * 5_000_000) + sleep(0.01) + + before = a.process.pid + with captured_logger(logging.getLogger("distributed.worker_memory")) as logger: + future = c.submit(leak) + while a.process.pid == before: + await asyncio.sleep(0.01) + + out = logger.getvalue() + assert "restart" in out.lower() + assert "memory" in out.lower() + + +@pytest.mark.parametrize( + "cls,name,value", + [ + (Worker, "memory_limit", 123e9), + (Worker, "memory_target_fraction", 0.789), + (Worker, "memory_spill_fraction", 0.789), + (Worker, "memory_pause_fraction", 0.789), + (Nanny, "memory_limit", 123e9), + (Nanny, 
"memory_terminate_fraction", 0.789), + ], +) +@gen_cluster(nthreads=[]) +async def test_deprecated_attributes(s, cls, name, value): + async with cls(s.address) as a: + with pytest.warns(FutureWarning, match=name): + setattr(a, name, value) + with pytest.warns(FutureWarning, match=name): + assert getattr(a, name) == value + assert getattr(a.memory_manager, name) == value + + +@pytest.mark.parametrize( + "name", + ["memory_target_fraction", "memory_spill_fraction", "memory_pause_fraction"], +) +@gen_cluster(nthreads=[]) +async def test_deprecated_params(s, name): + with pytest.warns(FutureWarning, match=name): + async with Worker(s.address, **{name: 0.789}) as a: + assert getattr(a.memory_manager, name) == 0.789 diff --git a/distributed/utils_test.py b/distributed/utils_test.py index f6625e74180..ab28ad534cf 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -1093,7 +1093,6 @@ def get_unclosed(): # zict backends can fail if their storage directory # was already removed pass - del w.data return result diff --git a/distributed/worker.py b/distributed/worker.py index eb0d07c6505..ade48a760d3 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -45,9 +45,9 @@ typename, ) -from . import comm, preloading, profile, system, utils +from . import comm, preloading, profile, utils from .batched import BatchedSend -from .comm import Comm, connect, get_address_host +from .comm import connect, get_address_host from .comm.addressing import address_from_user_args, parse_address from .comm.utils import OFFLOAD_THRESHOLD from .core import ( @@ -92,8 +92,9 @@ warn_on_duration, ) from .utils_comm import gather_from_workers, pack_data, retry_operation -from .utils_perf import ThrottledGC, disable_gc_diagnosis, enable_gc_diagnosis +from .utils_perf import disable_gc_diagnosis, enable_gc_diagnosis from .versions import get_versions +from .worker_memory import DeprecatedMMAccessor, WorkerMemoryManager if TYPE_CHECKING: from typing_extensions import TypeAlias @@ -414,17 +415,6 @@ class Worker(ServerNode): * **tasks**: ``{key: TaskState}`` The tasks currently executing on this worker (and any dependencies of those tasks) - * **data:** ``{key: object}``: - Prefer using the **host** attribute instead of this, unless - memory_limit and at least one of memory_target_fraction or - memory_spill_fraction values are defined, in that case, this attribute - is a zict.Buffer, from which information on LRU cache can be queried. - * **data.memory:** ``{key: object}``: - Dictionary mapping keys to actual values stored in memory. Only - available if condition for **data** being a zict.Buffer is met. - * **data.disk:** ``{key: object}``: - Dictionary mapping keys to actual values stored on disk. Only - available if condition for **data** being a zict.Buffer is met. 
* **data_needed**: UniqueTaskHeap The tasks which still require data in order to execute, prioritized as a heap * **ready**: [keys] @@ -607,12 +597,6 @@ class Worker(ServerNode): extensions: dict security: Security connection_args: dict[str, Any] - memory_limit: int | None - memory_target_fraction: float | Literal[False] - memory_spill_fraction: float | Literal[False] - memory_pause_fraction: float | Literal[False] - max_spill: int | Literal[False] - data: MutableMapping[str, Any] # {task key: task payload} actors: dict[str, Actor | None] loop: IOLoop reconnect: bool @@ -631,9 +615,6 @@ class Worker(ServerNode): low_level_profiler: bool scheduler: Any execution_state: dict[str, Any] - memory_monitor_interval: float | None - _memory_monitoring: bool - _throttled_gc: ThrottledGC plugins: dict[str, WorkerPlugin] _pending_plugins: tuple[WorkerPlugin, ...] @@ -650,7 +631,6 @@ def __init__( services: dict | None = None, name: Any | None = None, reconnect: bool = True, - memory_limit: str | float = "auto", executor: Executor | dict[str, Executor] | Literal["offload"] | None = None, resources: dict[str, float] | None = None, silence_logs: int | None = None, @@ -660,24 +640,11 @@ def __init__( security: Security | dict[str, Any] | None = None, contact_address: str | None = None, heartbeat_interval: Any = "1s", - memory_monitor_interval: Any = "200ms", - memory_target_fraction: float | Literal[False] | None = None, - memory_spill_fraction: float | Literal[False] | None = None, - memory_pause_fraction: float | Literal[False] | None = None, - max_spill: float | str | Literal[False] | None = None, extensions: list[type] | None = None, metrics: Mapping[str, Callable[[Worker], Any]] = DEFAULT_METRICS, startup_information: Mapping[ str, Callable[[Worker], Any] ] = DEFAULT_STARTUP_INFORMATION, - data: ( - MutableMapping[str, Any] # pre-initialised - | Callable[[], MutableMapping[str, Any]] # constructor - | tuple[ - Callable[..., MutableMapping[str, Any]], dict[str, Any] - ] # (constructor, kwargs to constructor) - | None # create internatlly - ) = None, interface: str | None = None, host: str | None = None, port: int | None = None, @@ -693,6 +660,18 @@ def __init__( lifetime: Any | None = None, lifetime_stagger: Any | None = None, lifetime_restart: bool | None = None, + ################################### + # Parameters to WorkerMemoryManager + memory_limit: str | float = "auto", + # Allow overriding the dict-like that stores the task outputs. + # This is meant for power users only. See WorkerMemoryManager for details. + data=None, + # Deprecated parameters; please use dask config instead. 
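+        # (passing any of these emits a FutureWarning pointing at the matching
+        # dask config key; see _parse_threshold in worker_memory.py)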
+ memory_target_fraction: float | Literal[False] | None = None, + memory_spill_fraction: float | Literal[False] | None = None, + memory_pause_fraction: float | Literal[False] | None = None, + ################################### + # Parameters to Server **kwargs, ): self.tasks = {} @@ -892,54 +871,6 @@ def __init__( assert isinstance(self.security, Security) self.connection_args = self.security.get_connection_args("worker") - self.memory_limit = parse_memory_limit(memory_limit, self.nthreads) - - self.memory_target_fraction = ( - memory_target_fraction - if memory_target_fraction is not None - else dask.config.get("distributed.worker.memory.target") - ) - self.memory_spill_fraction = ( - memory_spill_fraction - if memory_spill_fraction is not None - else dask.config.get("distributed.worker.memory.spill") - ) - self.memory_pause_fraction = ( - memory_pause_fraction - if memory_pause_fraction is not None - else dask.config.get("distributed.worker.memory.pause") - ) - - if max_spill is None: - max_spill = dask.config.get("distributed.worker.memory.max-spill") - self.max_spill = False if max_spill is False else parse_bytes(max_spill) - - if isinstance(data, MutableMapping): - self.data = data - elif callable(data): - self.data = data() - elif isinstance(data, tuple): - self.data = data[0](**data[1]) - elif self.memory_limit and ( - self.memory_target_fraction or self.memory_spill_fraction - ): - from .spill import SpillBuffer - - if self.memory_target_fraction: - target = int( - self.memory_limit - * (self.memory_target_fraction or self.memory_spill_fraction) - ) - else: - target = sys.maxsize - self.data = SpillBuffer( - os.path.join(self.local_directory, "storage"), - target=target, - max_spill=self.max_spill, - ) - else: - self.data = {} - self.actors = {} self.loop = loop or IOLoop.current() self.reconnect = reconnect @@ -1057,24 +988,19 @@ def __init__( self._address = contact_address - self.memory_monitor_interval = parse_timedelta( - memory_monitor_interval, default="ms" - ) - self._memory_monitoring = False - if self.memory_limit: - assert self.memory_monitor_interval is not None - pc = PeriodicCallback( - self.memory_monitor, # type: ignore - self.memory_monitor_interval * 1000, - ) - self.periodic_callbacks["memory"] = pc - if extensions is None: extensions = DEFAULT_EXTENSIONS for ext in extensions: ext(self) - self._throttled_gc = ThrottledGC(logger=logger) + self.memory_manager = WorkerMemoryManager( + self, + data=data, + memory_limit=memory_limit, + memory_target_fraction=memory_target_fraction, + memory_spill_fraction=memory_spill_fraction, + memory_pause_fraction=memory_pause_fraction, + ) setproctitle("dask-worker [not started]") @@ -1108,6 +1034,32 @@ def __init__( Worker._instances.add(self) + ################ + # Memory manager + ################ + memory_manager: WorkerMemoryManager + + @property + def data(self) -> MutableMapping[str, Any]: + """{task key: task payload} of all completed tasks, whether they were computed on + this Worker or computed somewhere else and then transferred here over the + network. + + When using the default configuration, this is a zict buffer that automatically + spills to disk whenever the target threshold is exceeded. + If spilling is disabled, it is a plain dict instead. + It could also be a user-defined arbitrary dict-like passed when initialising + the Worker or the Nanny. + Worker logic should treat this opaquely and stick to the MutableMapping API. 
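+
+        As an illustration, when the default SpillBuffer is in use one can inspect
+        which keys are held in RAM and which have been spilled to disk (these
+        attributes are specific to the SpillBuffer; a plain dict lacks them)::
+
+            list(worker.data.memory)  # keys held in RAM
+            list(worker.data.disk)    # keys spilled to disk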
+ """ + return self.memory_manager.data + + # Deprecated attributes moved to self.memory_manager. + memory_limit = DeprecatedMMAccessor() + memory_target_fraction = DeprecatedMMAccessor() + memory_spill_fraction = DeprecatedMMAccessor() + memory_pause_fraction = DeprecatedMMAccessor() + ################## # Administrative # ################## @@ -1150,23 +1102,21 @@ def worker_address(self): """For API compatibility with Nanny""" return self.address - @property - def local_dir(self): - """For API compatibility with Nanny""" - warnings.warn( - "The local_dir attribute has moved to local_directory", stacklevel=2 - ) - return self.local_directory - @property def executor(self): return self.executors["default"] @ServerNode.status.setter # type: ignore def status(self, value): - """Override Server.status to notify the Scheduler of status changes""" + """Override Server.status to notify the Scheduler of status changes. + Also handles unpausing. + """ + prev_status = self.status ServerNode.status.__set__(self, value) self._send_worker_status_change() + if prev_status == Status.paused and value == Status.running: + self.ensure_computing() + self.ensure_communicating() def _send_worker_status_change(self) -> None: if ( @@ -1235,12 +1185,10 @@ def identity(self): "id": self.id, "scheduler": self.scheduler.address, "nthreads": self.nthreads, - "memory_limit": self.memory_limit, + "memory_limit": self.memory_manager.memory_limit, } - def _to_dict( - self, comm: Comm | None = None, *, exclude: Container[str] = () - ) -> dict: + def _to_dict(self, *, exclude: Container[str] = ()) -> dict: """Dictionary representation for debugging purposes. Not type stable and not intended for roundtrips. @@ -1265,16 +1213,13 @@ def _to_dict( "in_flight_workers": self.in_flight_workers, "log": self.log, "tasks": self.tasks, - "memory_limit": self.memory_limit, - "memory_target_fraction": self.memory_target_fraction, - "memory_spill_fraction": self.memory_spill_fraction, - "memory_pause_fraction": self.memory_pause_fraction, "logs": self.get_logs(), "config": dask.config.config, "incoming_transfer_log": self.incoming_transfer_log, "outgoing_transfer_log": self.outgoing_transfer_log, } info.update(extra) + info.update(self.memory_manager._to_dict(exclude=exclude)) info = {k: v for k, v in info.items() if k not in exclude} return recursive_to_dict(info, exclude=exclude) @@ -1315,7 +1260,7 @@ async def _register_with_scheduler(self): types={k: typename(v) for k, v in self.data.items()}, now=time(), resources=self.total_resources, - memory_limit=self.memory_limit, + memory_limit=self.memory_manager.memory_limit, local_directory=self.local_directory, services=self.service_ports, nanny=self.nanny, @@ -1606,8 +1551,11 @@ async def start(self): logger.info("Waiting to connect to: %26s", self.scheduler.address) logger.info("-" * 49) logger.info(" Threads: %26d", self.nthreads) - if self.memory_limit: - logger.info(" Memory: %26s", format_bytes(self.memory_limit)) + if self.memory_manager.memory_limit: + logger.info( + " Memory: %26s", + format_bytes(self.memory_manager.memory_limit), + ) logger.info(" Local Directory: %26s", self.local_directory) setproctitle("dask-worker [%s]" % self.address) @@ -3828,115 +3776,6 @@ def _prepare_args_for_execution( ################## # Administrative # ################## - - async def memory_monitor(self) -> None: - """Track this process's memory usage and act accordingly - - If we rise above 70% memory use, start dumping data to disk. 
- - If we rise above 80% memory use, stop execution of new tasks - """ - if self._memory_monitoring: - return - self._memory_monitoring = True - assert self.memory_limit - total = 0 - - memory = self.monitor.get_process_memory() - frac = memory / self.memory_limit - - def check_pause(memory): - frac = memory / self.memory_limit - # Pause worker threads if above 80% memory use - if self.memory_pause_fraction and frac > self.memory_pause_fraction: - # Try to free some memory while in paused state - self._throttled_gc.collect() - if self.status == Status.running: - logger.warning( - "Worker is at %d%% memory usage. Pausing worker. " - "Process memory: %s -- Worker memory limit: %s", - int(frac * 100), - format_bytes(memory), - format_bytes(self.memory_limit) - if self.memory_limit is not None - else "None", - ) - self.status = Status.paused - elif self.status == Status.paused: - logger.warning( - "Worker is at %d%% memory usage. Resuming worker. " - "Process memory: %s -- Worker memory limit: %s", - int(frac * 100), - format_bytes(memory), - format_bytes(self.memory_limit) - if self.memory_limit is not None - else "None", - ) - self.status = Status.running - self.ensure_computing() - self.ensure_communicating() - - check_pause(memory) - # Dump data to disk if above 70% - if self.memory_spill_fraction and frac > self.memory_spill_fraction: - from .spill import SpillBuffer - - assert isinstance(self.data, SpillBuffer) - - logger.debug( - "Worker is at %.0f%% memory usage. Start spilling data to disk.", - frac * 100, - ) - # Implement hysteresis cycle where spilling starts at the spill threshold - # and stops at the target threshold. Normally that here the target threshold - # defines process memory, whereas normally it defines reported managed - # memory (e.g. output of sizeof() ). - # If target=False, disable hysteresis. - target = self.memory_limit * ( - self.memory_target_fraction or self.memory_spill_fraction - ) - count = 0 - need = memory - target - while memory > target: - if not self.data.fast: - logger.warning( - "Unmanaged memory use is high. This may indicate a memory leak " - "or the memory may not be released to the OS; see " - "https://distributed.dask.org/en/latest/worker.html#memtrim " - "for more information. " - "-- Unmanaged memory: %s -- Worker memory limit: %s", - format_bytes(memory), - format_bytes(self.memory_limit), - ) - break - weight = self.data.evict() - if weight == -1: - # Failed to evict: - # disk full, spill size limit exceeded, or pickle error - break - - total += weight - count += 1 - await asyncio.sleep(0) - - memory = self.monitor.get_process_memory() - if total > need and memory > target: - # Issue a GC to ensure that the evicted data is actually - # freed from memory and taken into account by the monitor - # before trying to evict even more data. 
- self._throttled_gc.collect() - memory = self.monitor.get_process_memory() - - check_pause(memory) - if count: - logger.debug( - "Moved %d tasks worth %s to disk", - count, - format_bytes(total), - ) - - self._memory_monitoring = False - def cycle_profile(self) -> None: now = time() + self.scheduler_delay prof, self.profile_recent = self.profile_recent, profile.create() @@ -4487,25 +4326,6 @@ class Reschedule(Exception): """ -def parse_memory_limit(memory_limit, nthreads, total_cores=CPU_COUNT) -> int | None: - if memory_limit is None: - return None - - if memory_limit == "auto": - memory_limit = int(system.MEMORY_LIMIT * min(1, nthreads / total_cores)) - with suppress(ValueError, TypeError): - memory_limit = float(memory_limit) - if isinstance(memory_limit, float) and memory_limit <= 1: - memory_limit = int(memory_limit * system.MEMORY_LIMIT) - - if isinstance(memory_limit, str): - memory_limit = parse_bytes(memory_limit) - else: - memory_limit = int(memory_limit) - - return min(memory_limit, system.MEMORY_LIMIT) - - async def get_data_from_worker( rpc, keys, diff --git a/distributed/worker_memory.py b/distributed/worker_memory.py new file mode 100644 index 00000000000..10cfc5a5293 --- /dev/null +++ b/distributed/worker_memory.py @@ -0,0 +1,384 @@ +"""Encapsulated manager for in-memory tasks on a worker. + +This module covers: +- spill/unspill data depending on the 'distributed.worker.memory.target' threshold +- spill/unspill data depending on the 'distributed.worker.memory.spill' threshold +- pause/unpause the worker depending on the 'distributed.worker.memory.pause' threshold +- kill the worker depending on the 'distributed.worker.memory.terminate' threshold + +This module does *not* cover: +- Changes in behaviour in Worker, Scheduler, task stealing, Active Memory Manager, etc. + caused by the Worker being in paused status +- Worker restart after it's been killed +- Scheduler-side heuristics regarding memory usage, e.g. the Active Memory Manager + +See also: +- :mod:`distributed.spill`, which implements the spill-to-disk mechanism and is wrapped + around by this module. Unlike this module, :mod:`distributed.spill` is agnostic to the + Worker. +- :mod:`distributed.active_memory_manager`, which runs on the scheduler side +""" +from __future__ import annotations + +import asyncio +import logging +import os +import sys +import warnings +from collections.abc import Callable, MutableMapping +from contextlib import suppress +from functools import partial +from typing import TYPE_CHECKING, Any, Container, Literal + +import psutil +from tornado.ioloop import PeriodicCallback + +import dask.config +from dask.system import CPU_COUNT +from dask.utils import format_bytes, parse_bytes, parse_timedelta + +from . 
import system
+from .core import Status
+from .spill import SpillBuffer
+from .utils_perf import ThrottledGC
+
+if TYPE_CHECKING:
+    # Circular imports
+    from .nanny import Nanny
+    from .worker import Worker
+
+logger = logging.getLogger(__name__)
+
+
+class WorkerMemoryManager:
+    data: MutableMapping[str, Any]  # {task key: task payload}
+    memory_limit: int | None
+    memory_target_fraction: float | Literal[False]
+    memory_spill_fraction: float | Literal[False]
+    memory_pause_fraction: float | Literal[False]
+    max_spill: int | Literal[False]
+    memory_monitor_interval: float
+    _memory_monitoring: bool
+    _throttled_gc: ThrottledGC
+
+    def __init__(
+        self,
+        worker: Worker,
+        *,
+        memory_limit: str | float = "auto",
+        # This should be None most of the time, unless a power user replaces the
+        # SpillBuffer with their own custom dict-like
+        data: (
+            MutableMapping[str, Any]  # pre-initialised
+            | Callable[[], MutableMapping[str, Any]]  # constructor
+            | tuple[
+                Callable[..., MutableMapping[str, Any]], dict[str, Any]
+            ]  # (constructor, kwargs to constructor)
+            | None  # create internally
+        ) = None,
+        # Deprecated parameters; use dask.config instead
+        memory_target_fraction: float | Literal[False] | None = None,
+        memory_spill_fraction: float | Literal[False] | None = None,
+        memory_pause_fraction: float | Literal[False] | None = None,
+    ):
+        self.memory_limit = parse_memory_limit(memory_limit, worker.nthreads)
+
+        self.memory_target_fraction = _parse_threshold(
+            "distributed.worker.memory.target",
+            "memory_target_fraction",
+            memory_target_fraction,
+        )
+        self.memory_spill_fraction = _parse_threshold(
+            "distributed.worker.memory.spill",
+            "memory_spill_fraction",
+            memory_spill_fraction,
+        )
+        self.memory_pause_fraction = _parse_threshold(
+            "distributed.worker.memory.pause",
+            "memory_pause_fraction",
+            memory_pause_fraction,
+        )
+
+        max_spill = dask.config.get("distributed.worker.memory.max-spill")
+        self.max_spill = False if max_spill is False else parse_bytes(max_spill)
+
+        if isinstance(data, MutableMapping):
+            self.data = data
+        elif callable(data):
+            self.data = data()
+        elif isinstance(data, tuple):
+            self.data = data[0](**data[1])
+        elif self.memory_limit and (
+            self.memory_target_fraction or self.memory_spill_fraction
+        ):
+            if self.memory_target_fraction:
+                target = int(
+                    self.memory_limit
+                    * (self.memory_target_fraction or self.memory_spill_fraction)
+                )
+            else:
+                target = sys.maxsize
+            self.data = SpillBuffer(
+                os.path.join(worker.local_directory, "storage"),
+                target=target,
+                max_spill=self.max_spill,
+            )
+        else:
+            self.data = {}
+
+        self._memory_monitoring = False
+
+        self.memory_monitor_interval = parse_timedelta(
+            dask.config.get("distributed.worker.memory.monitor-interval.spill-pause"),
+            default=None,
+        )
+        assert isinstance(self.memory_monitor_interval, (int, float))
+
+        if self.memory_limit and (
+            self.memory_spill_fraction is not False
+            or self.memory_pause_fraction is not False
+        ):
+            assert self.memory_monitor_interval is not None
+            pc = PeriodicCallback(
+                # Don't store worker as self.worker to avoid creating a circular
+                # dependency. We could have alternatively used a weakref.
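+                # The interval is memory_monitor_interval (seconds, parsed above from
+                # monitor-interval.spill-pause); PeriodicCallback expects milliseconds.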
+                # FIXME annotations: https://github.com/tornadoweb/tornado/issues/3117
+                partial(self.memory_monitor, worker),  # type: ignore
+                self.memory_monitor_interval * 1000,
+            )
+            worker.periodic_callbacks["memory_monitor"] = pc
+
+        self._throttled_gc = ThrottledGC(logger=logger)
+
+    async def memory_monitor(self, worker: Worker) -> None:
+        """Track this process's memory usage and act accordingly.
+        If process memory rises above the spill threshold (70%), start dumping data to
+        disk until it goes below the target threshold (60%).
+        If process memory rises above the pause threshold (80%), stop execution of new
+        tasks.
+        """
+        if self._memory_monitoring:
+            return
+        self._memory_monitoring = True
+        # Don't use psutil directly; instead read from the same API that is used to
+        # send info to the Scheduler (e.g. for the benefit of Active Memory Manager)
+        # and which can be easily mocked in unit tests.
+        memory = worker.monitor.get_process_memory()
+        self._maybe_pause_or_unpause(worker, memory)
+        await self._maybe_spill(worker, memory)
+        self._memory_monitoring = False
+
+    def _maybe_pause_or_unpause(self, worker: Worker, memory: int) -> None:
+        if self.memory_pause_fraction is False:
+            return
+
+        assert self.memory_limit
+        frac = memory / self.memory_limit
+        # Pause worker threads if above 80% memory use
+        if frac > self.memory_pause_fraction:
+            # Try to free some memory while in paused state
+            self._throttled_gc.collect()
+            if worker.status == Status.running:
+                logger.warning(
+                    "Worker is at %d%% memory usage. Pausing worker. "
+                    "Process memory: %s -- Worker memory limit: %s",
+                    int(frac * 100),
+                    format_bytes(memory),
+                    format_bytes(self.memory_limit)
+                    if self.memory_limit is not None
+                    else "None",
+                )
+                worker.status = Status.paused
+        elif worker.status == Status.paused:
+            logger.warning(
+                "Worker is at %d%% memory usage. Resuming worker. "
+                "Process memory: %s -- Worker memory limit: %s",
+                int(frac * 100),
+                format_bytes(memory),
+                format_bytes(self.memory_limit)
+                if self.memory_limit is not None
+                else "None",
+            )
+            worker.status = Status.running
+
+    async def _maybe_spill(self, worker: Worker, memory: int) -> None:
+        if self.memory_spill_fraction is False:
+            return
+        if not isinstance(self.data, SpillBuffer):
+            return
+
+        assert self.memory_limit
+        frac = memory / self.memory_limit
+        if frac <= self.memory_spill_fraction:
+            return
+
+        total_spilled = 0
+        logger.debug(
+            "Worker is at %.0f%% memory usage. Start spilling data to disk.",
+            frac * 100,
+        )
+        # Implement hysteresis cycle where spilling starts at the spill threshold and
+        # stops at the target threshold. Note that here the target threshold defines
+        # process memory, whereas normally it defines reported managed memory (e.g.
+        # output of sizeof() ). If target=False, disable hysteresis.
+        target = self.memory_limit * (
+            self.memory_target_fraction or self.memory_spill_fraction
+        )
+        count = 0
+        need = memory - target
+        while memory > target:
+            if not self.data.fast:
+                logger.warning(
+                    "Unmanaged memory use is high. This may indicate a memory leak "
+                    "or the memory may not be released to the OS; see "
+                    "https://distributed.dask.org/en/latest/worker.html#memtrim "
+                    "for more information. 
" + "-- Unmanaged memory: %s -- Worker memory limit: %s", + format_bytes(memory), + format_bytes(self.memory_limit), + ) + break + weight = self.data.evict() + if weight == -1: + # Failed to evict: + # disk full, spill size limit exceeded, or pickle error + break + + total_spilled += weight + count += 1 + await asyncio.sleep(0) + + memory = worker.monitor.get_process_memory() + if total_spilled > need and memory > target: + # Issue a GC to ensure that the evicted data is actually + # freed from memory and taken into account by the monitor + # before trying to evict even more data. + self._throttled_gc.collect() + memory = worker.monitor.get_process_memory() + + self._maybe_pause_or_unpause(worker, memory) + if count: + logger.debug( + "Moved %d tasks worth %s to disk", + count, + format_bytes(total_spilled), + ) + + def _to_dict(self, *, exclude: Container[str] = ()) -> dict: + info = { + k: v + for k, v in self.__dict__.items() + if not k.startswith("_") and k != "data" and k not in exclude + } + info["data"] = list(self.data) + return info + + +class NannyMemoryManager: + memory_limit: int | None + memory_terminate_fraction: float | Literal[False] + memory_monitor_interval: float | None + + def __init__( + self, + nanny: Nanny, + *, + memory_limit: str | float = "auto", + ): + self.memory_limit = parse_memory_limit(memory_limit, nanny.nthreads) + self.memory_terminate_fraction = dask.config.get( + "distributed.worker.memory.terminate" + ) + self.memory_monitor_interval = parse_timedelta( + dask.config.get("distributed.worker.memory.monitor-interval.terminate"), + default=None, + ) + assert isinstance(self.memory_monitor_interval, (int, float)) + if self.memory_limit and self.memory_terminate_fraction is not False: + pc = PeriodicCallback( + partial(self.memory_monitor, nanny), + self.memory_monitor_interval * 1000, + ) + nanny.periodic_callbacks["memory_monitor"] = pc + + def memory_monitor(self, nanny: Nanny) -> None: + """Track worker's memory. Restart if it goes above terminate fraction.""" + if nanny.status != Status.running: + return # pragma: nocover + if nanny.process is None or nanny.process.process is None: + return # pragma: nocover + process = nanny.process.process + try: + proc = nanny._psutil_process + memory = proc.memory_info().rss + except (ProcessLookupError, psutil.NoSuchProcess, psutil.AccessDenied): + return # pragma: nocover + + if memory / self.memory_limit > self.memory_terminate_fraction: + logger.warning( + "Worker exceeded %d%% memory budget. 
Restarting",
+                100 * self.memory_terminate_fraction,
+            )
+            process.terminate()
+
+
+def parse_memory_limit(
+    memory_limit: str | float | None, nthreads: int, total_cores: int = CPU_COUNT
+) -> int | None:
+    if memory_limit is None:
+        return None
+
+    if memory_limit == "auto":
+        memory_limit = int(system.MEMORY_LIMIT * min(1, nthreads / total_cores))
+    with suppress(ValueError, TypeError):
+        memory_limit = float(memory_limit)
+    if isinstance(memory_limit, float) and memory_limit <= 1:
+        memory_limit = int(memory_limit * system.MEMORY_LIMIT)
+
+    if isinstance(memory_limit, str):
+        memory_limit = parse_bytes(memory_limit)
+    else:
+        memory_limit = int(memory_limit)
+
+    assert isinstance(memory_limit, int)
+    if memory_limit == 0:
+        return None
+    return min(memory_limit, system.MEMORY_LIMIT)
+
+
+def _parse_threshold(
+    config_key: str,
+    deprecated_param_name: str,
+    deprecated_param_value: float | Literal[False] | None,
+) -> float | Literal[False]:
+    if deprecated_param_value is not None:
+        warnings.warn(
+            f"Parameter {deprecated_param_name} has been deprecated and will be "
+            f"removed in a future version; please use dask config key {config_key} "
+            "instead",
+            FutureWarning,
+        )
+        return deprecated_param_value
+    return dask.config.get(config_key)
+
+
+class DeprecatedMMAccessor:
+    name: str
+
+    def __set_name__(self, owner: type, name: str) -> None:
+        self.name = name
+
+    def __get__(self, instance: Nanny | Worker, _):
+        self._warn(instance)
+        return getattr(instance.memory_manager, self.name)
+
+    def __set__(self, instance: Nanny | Worker, value) -> None:
+        self._warn(instance)
+        setattr(instance.memory_manager, self.name, value)
+
+    def _warn(self, instance: Nanny | Worker) -> None:
+        warnings.warn(
+            f"The `{type(instance).__name__}.{self.name}` attribute has been moved to "
+            f"`{type(instance).__name__}.memory_manager.{self.name}`",
+            FutureWarning,
+        )
diff --git a/setup.cfg b/setup.cfg
index aebd0a81dee..dd99eccfc7a 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -28,7 +28,7 @@ skip_gitignore = true
 force_to_top = true
 default_section = THIRDPARTY
 known_first_party = distributed
-known_distributed = dask
+known_distributed = dask,zict
 
 [versioneer]
 VCS = git