From a0073459858908ce1c6c6e2931afd1e3b534ba88 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Tue, 27 Sep 2022 23:34:47 +0200 Subject: [PATCH 01/11] Fix transfer limiting in `_select_keys_for_gather` (#7071) --- .../tests/test_worker_state_machine.py | 52 +++++++++++--- distributed/worker.py | 3 +- distributed/worker_state_machine.py | 72 ++++++++++++------- 3 files changed, 91 insertions(+), 36 deletions(-) diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py index eaec94725c..9b46a14b6e 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -1010,30 +1010,40 @@ async def test_deprecated_worker_attributes(s, a, b): assert a.data_needed == set() +@pytest.mark.parametrize("n_remote_workers", [1, 2]) @pytest.mark.parametrize( - "nbytes,n_in_flight", + "nbytes,n_in_flight_per_worker", [ (int(10e6), 3), (int(20e6), 2), (int(30e6), 1), + (int(60e6), 1), ], ) -def test_aggregate_gather_deps(ws, nbytes, n_in_flight): +def test_aggregate_gather_deps(ws, nbytes, n_in_flight_per_worker, n_remote_workers): ws.transfer_message_bytes_limit = int(50e6) - ws2 = "127.0.0.1:2" + wss = [f"127.0.0.1:{2 + i}" for i in range(n_remote_workers)] + who_has = {f"x{i}": [wss[i // 3]] for i in range(3 * n_remote_workers)} instructions = ws.handle_stimulus( AcquireReplicasEvent( - who_has={"x1": [ws2], "x2": [ws2], "x3": [ws2]}, - nbytes={"x1": nbytes, "x2": nbytes, "x3": nbytes}, + who_has=who_has, + nbytes={task: nbytes for task in who_has.keys()}, stimulus_id="s1", ) ) - assert instructions == [GatherDep.match(worker=ws2, stimulus_id="s1")] - assert len(instructions[0].to_gather) == n_in_flight - assert len(ws.in_flight_tasks) == n_in_flight - assert ws.transfer_incoming_bytes == nbytes * n_in_flight - assert ws.transfer_incoming_count == 1 - assert ws.transfer_incoming_count_total == 1 + assert instructions == [ + GatherDep.match(worker=remote, stimulus_id="s1") for 
remote in wss + ] + assert all( + len(instruction.to_gather) == n_in_flight_per_worker + for instruction in instructions + ) + assert len(ws.in_flight_tasks) == n_in_flight_per_worker * n_remote_workers + assert ( + ws.transfer_incoming_bytes == nbytes * n_in_flight_per_worker * n_remote_workers + ) + assert ws.transfer_incoming_count == n_remote_workers + assert ws.transfer_incoming_count_total == n_remote_workers def test_gather_priority(ws): @@ -1358,6 +1368,7 @@ def test_throttling_does_not_affect_first_transfer(ws): ws.transfer_incoming_count_limit = 100 ws.transfer_incoming_bytes_limit = 100 ws.transfer_incoming_bytes_throttle_threshold = 1 + ws.transfer_message_bytes_limit = 100 ws2 = "127.0.0.1:2" ws.handle_stimulus( ComputeTaskEvent.dummy( @@ -1370,6 +1381,25 @@ def test_throttling_does_not_affect_first_transfer(ws): assert ws.tasks["a"].state == "flight" +def test_message_target_does_not_affect_first_transfer_on_different_worker(ws): + ws.transfer_incoming_count_limit = 100 + ws.transfer_incoming_bytes_limit = 600 + ws.transfer_message_bytes_limit = 100 + ws.transfer_incoming_bytes_throttle_threshold = 1 + ws2 = "127.0.0.1:2" + ws3 = "127.0.0.1:3" + ws.handle_stimulus( + ComputeTaskEvent.dummy( + "c", + who_has={"a": [ws2], "b": [ws3]}, + nbytes={"a": 200, "b": 200}, + stimulus_id="s1", + ) + ) + assert ws.tasks["a"].state == "flight" + assert ws.tasks["b"].state == "flight" + + def test_throttle_incoming_transfers_on_count_limit(ws): ws.transfer_incoming_count_limit = 1 ws.transfer_incoming_bytes_limit = 100_000 diff --git a/distributed/worker.py b/distributed/worker.py index 29ef7939f6..3c3a16b5b6 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -6,6 +6,7 @@ import errno import functools import logging +import math import os import pathlib import random @@ -748,7 +749,7 @@ def __init__( memory_pause_fraction=memory_pause_fraction, ) - transfer_incoming_bytes_limit = None + transfer_incoming_bytes_limit = math.inf 
transfer_incoming_bytes_fraction = dask.config.get( "distributed.worker.memory.transfer" ) diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 30b39f8d05..6add2f72d9 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -1234,7 +1234,7 @@ class WorkerState: transition_counter_max: int | Literal[False] #: Limit of bytes for incoming data transfers; this is used for throttling. - transfer_incoming_bytes_limit: int | None + transfer_incoming_bytes_limit: float #: Statically-seeded random state, used to guarantee determinism whenever a #: pseudo-random choice is required @@ -1254,7 +1254,7 @@ def __init__( transfer_incoming_count_limit: int = 9999, validate: bool = True, transition_counter_max: int | Literal[False] = False, - transfer_incoming_bytes_limit: int | None = None, + transfer_incoming_bytes_limit: float = math.inf, transfer_message_bytes_limit: float = math.inf, ): self.nthreads = nthreads @@ -1493,8 +1493,7 @@ def _should_throttle_incoming_transfers(self) -> bool: >= self.transfer_incoming_bytes_throttle_threshold ) reached_bytes_limit = ( - self.transfer_incoming_bytes_limit is not None - and self.transfer_incoming_bytes >= self.transfer_incoming_bytes_limit + self.transfer_incoming_bytes >= self.transfer_incoming_bytes_limit ) return reached_count_limit and reached_throttle_threshold or reached_bytes_limit @@ -1512,7 +1511,7 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs: for worker, available_tasks in self._select_workers_for_gather(): assert worker != self.address - to_gather_tasks, total_nbytes = self._select_keys_for_gather( + to_gather_tasks, message_nbytes = self._select_keys_for_gather( available_tasks ) # We always load at least one task @@ -1554,14 +1553,14 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs: GatherDep( worker=worker, to_gather=to_gather_keys, - total_nbytes=total_nbytes, + total_nbytes=message_nbytes, 
stimulus_id=stimulus_id, ) ) self.in_flight_workers[worker] = to_gather_keys self.transfer_incoming_count_total += 1 - self.transfer_incoming_bytes += total_nbytes + self.transfer_incoming_bytes += message_nbytes if self._should_throttle_incoming_transfers(): break @@ -1641,32 +1640,57 @@ def _select_keys_for_gather( for the size of incoming data transfers. """ to_gather: list[TaskState] = [] - total_nbytes = 0 - - if self.transfer_incoming_bytes_limit is not None: - bytes_left_to_fetch = min( - self.transfer_incoming_bytes_limit - self.transfer_incoming_bytes, - self.transfer_message_bytes_limit, - ) - else: - bytes_left_to_fetch = self.transfer_message_bytes_limit + message_nbytes = 0 while available: ts = available.peek() - if ( - # When there is no other traffic, the top-priority task is fetched - # regardless of its size to ensure progress - self.transfer_incoming_bytes - or to_gather - ) and total_nbytes + ts.get_nbytes() > bytes_left_to_fetch: + if self._task_exceeds_transfer_limits(ts, message_nbytes): break for worker in ts.who_has: # This also effectively pops from available self.data_needed[worker].remove(ts) to_gather.append(ts) - total_nbytes += ts.get_nbytes() + message_nbytes += ts.get_nbytes() + + return to_gather, message_nbytes + + def _task_exceeds_transfer_limits(self, ts: TaskState, message_nbytes: int) -> bool: + """Would asking to gather this task exceed transfer limits? + + Parameters + ---------- + ts + Candidate task for gathering + message_nbytes + Total number of bytes already scheduled for gathering in this message + Returns + ------- + exceeds_limit + True if gathering the task would exceed limits, False otherwise + (in which case the task can be gathered). 
+ """ + if self.transfer_incoming_bytes == 0 and message_nbytes == 0: + # When there is no other traffic, the top-priority task is fetched + # regardless of its size to ensure progress + return False + + incoming_bytes_allowance = ( + self.transfer_incoming_bytes_limit - self.transfer_incoming_bytes + ) + + # If message_nbytes == 0, i.e., this is the first task to gather in this + # message, ignore `self.transfer_message_bytes_limit` for the top-priority + # task to ensure progress. Otherwise: + if message_nbytes != 0: + incoming_bytes_allowance = ( + min( + incoming_bytes_allowance, + self.transfer_message_bytes_limit, + ) + - message_nbytes + ) - return to_gather, total_nbytes + return ts.get_nbytes() > incoming_bytes_allowance def _ensure_computing(self) -> RecsInstrs: if not self.running: From 54b0546be3731d3a78ff1cdad8fdd0b646a0289d Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 28 Sep 2022 12:52:32 +0100 Subject: [PATCH 02/11] AMM support for actors (#7072) --- distributed/active_memory_manager.py | 31 ++++++-- distributed/deploy/adaptive.py | 2 +- distributed/scheduler.py | 2 +- .../tests/test_active_memory_manager.py | 72 +++++++++++++++++++ distributed/worker_state_machine.py | 2 +- docs/source/actors.rst | 4 ++ 6 files changed, 106 insertions(+), 7 deletions(-) diff --git a/distributed/active_memory_manager.py b/distributed/active_memory_manager.py index 37fb22b948..b7b15f9445 100644 --- a/distributed/active_memory_manager.py +++ b/distributed/active_memory_manager.py @@ -230,6 +230,10 @@ def log_reject(msg: str) -> None: log_reject(f"ts.state = {ts.state}") return None + if ts.actor: + log_reject("task is an actor") + return None + if candidates is None: candidates = self.scheduler.running.copy() else: @@ -282,6 +286,10 @@ def log_reject(msg: str) -> None: log_reject("less than 2 replicas exist") return None + if ts.actor: + log_reject("task is an actor") + return None + if candidates is None: candidates = ts.who_has.copy() else: @@ -591,11 +599,26 
@@ def run(self) -> SuggestionGenerator: self.manager.policies.remove(self) return + if ws.actors: + logger.warning( + f"Tried retiring worker {self.address}, but it holds actor(s) " + f"{set(ws.actors)}, which can't be moved." + "The worker will not be retired." + ) + self.no_recipients = True + self.manager.policies.remove(self) + return + nrepl = 0 nno_rec = 0 logger.debug("Retiring %s", ws) for ts in ws.has_what: + if ts.actor: + # This is just a proxy Actor object; if there were any originals we + # would have stopped earlier + continue + if len(ts.who_has) > 1: # There are already replicas of this key on other workers. # Suggest dropping the replica from this worker. @@ -663,10 +686,10 @@ def run(self) -> SuggestionGenerator: def done(self) -> bool: """Return True if it is safe to close the worker down; False otherwise""" if self not in self.manager.policies: - # Either the no_recipients flag has been raised, or there were no unique replicas - # as of the latest AMM run. Note that due to tasks transitioning from running to - # memory there may be some now; it's OK to lose them and just recompute them - # somewhere else. + # Either the no_recipients flag has been raised, or there were no unique + # replicas as of the latest AMM run. Note that due to tasks transitioning + # from running to memory there may be some now; it's OK to lose them and + # just recompute them somewhere else. return True ws = self.manager.scheduler.workers.get(self.address) if ws is None: diff --git a/distributed/deploy/adaptive.py b/distributed/deploy/adaptive.py index 760fffc569..3f61964100 100644 --- a/distributed/deploy/adaptive.py +++ b/distributed/deploy/adaptive.py @@ -76,7 +76,7 @@ class Adaptive(AdaptiveCore): :meth:`Adaptive.workers_to_close` to control when the cluster should be resized. The default implementation checks if there are too many tasks per worker or too little memory available (see - :meth:`Scheduler.adaptive_target`). 
+ :meth:`distributed.Scheduler.adaptive_target`). The values for interval, min, max, wait_count and target_duration can be specified in the dask config under the distributed.adaptive key. ''' diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 1ef637de3f..cddada8198 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1108,7 +1108,7 @@ class TaskState: #: "processing" state and be sent for execution to another connected worker. loose_restrictions: bool - #: Whether or not this task is an Actor + #: Whether this task is an Actor actor: bool #: The group of tasks to which this one belongs diff --git a/distributed/tests/test_active_memory_manager.py b/distributed/tests/test_active_memory_manager.py index 4c3c513adc..1eef7b9b07 100644 --- a/distributed/tests/test_active_memory_manager.py +++ b/distributed/tests/test_active_memory_manager.py @@ -1056,6 +1056,78 @@ async def test_RetireWorker_faulty_recipient(c, s, *nannies): clog_fut.cancel() +class Counter: + def __init__(self): + self.n = 0 + + def increment(self): + self.n += 1 + + +@gen_cluster(client=True, config=demo_config("drop")) +async def test_dont_drop_actors(c, s, a, b): + x = c.submit(Counter, key="x", actor=True, workers=[a.address]) + y = c.submit(lambda cnt: cnt.increment(), x, key="y", workers=[b.address]) + await wait([x, y]) + assert len(s.tasks["x"].who_has) == 2 + s.extensions["amm"].run_once() + await asyncio.sleep(0.2) + assert len(s.tasks["x"].who_has) == 2 + + +@gen_cluster(client=True, config=demo_config("replicate")) +async def test_dont_replicate_actors(c, s, a, b): + x = c.submit(Counter, key="x", actor=True) + await wait(x) + assert len(s.tasks["x"].who_has) == 1 + s.extensions["amm"].run_once() + await asyncio.sleep(0.2) + assert len(s.tasks["x"].who_has) == 1 + + +@pytest.mark.parametrize("has_proxy", [False, True]) +@gen_cluster(client=True, config=NO_AMM_START) +async def test_RetireWorker_with_actor(c, s, a, b, has_proxy): + """A worker holding 
one or more original actor objects cannot be retired""" + x = c.submit(Counter, key="x", actor=True, workers=[a.address]) + await wait(x) + assert "x" in a.state.actors + + if has_proxy: + y = c.submit( + lambda cnt: cnt.increment().result(), x, key="y", workers=[b.address] + ) + await wait(y) + assert "x" in b.data + assert "y" in b.data + + with captured_logger("distributed.active_memory_manager", logging.WARNING) as log: + out = await c.retire_workers([a.address]) + assert out == {} + assert "it holds actor(s)" in log.getvalue() + assert "x" in a.state.actors + + if has_proxy: + assert "x" in b.data + assert "y" in b.data + + +@gen_cluster(client=True, config=NO_AMM_START) +async def test_RetireWorker_with_actor_proxy(c, s, a, b): + """A worker holding an Actor proxy object can be retired as normal.""" + x = c.submit(Counter, key="x", actor=True, workers=[a.address]) + y = c.submit(lambda cnt: cnt.increment().result(), x, key="y", workers=[b.address]) + await wait(y) + assert "x" in a.state.actors + assert "x" in b.data + assert "y" in b.data + + out = await c.retire_workers([b.address]) + assert out.keys() == {b.address} + assert "x" in a.state.actors + assert "y" in a.data + + class DropEverything(ActiveMemoryManagerPolicy): """Inanely suggest to drop every single key in the cluster""" diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 6add2f72d9..339e58a63e 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -2918,7 +2918,7 @@ def _gather_dep_done_common(self, ev: GatherDepDoneEvent) -> Iterator[TaskState] @_handle_event.register def _handle_gather_dep_success(self, ev: GatherDepSuccessEvent) -> RecsInstrs: """gather_dep terminated successfully. - The response may contain less keys than the request. + The response may contain fewer keys than the request. 
""" recommendations: Recs = {} for ts in self._gather_dep_done_common(ev): diff --git a/docs/source/actors.rst b/docs/source/actors.rst index 5a089c7ded..9f28de11f5 100644 --- a/docs/source/actors.rst +++ b/docs/source/actors.rst @@ -227,3 +227,7 @@ Actors offer advanced capabilities, but with some cost: computations no diagnostics are available about these computations. 3. **No Load balancing:** Actors are allocated onto workers evenly, without serious consideration given to avoiding communication. +4. **No dynamic clusters:** Actors cannot be migrated to other workers. + A worker holding an actor can be retired neither through + :meth:`~distributed.Client.retire_workers` nor through + :class:`~distributed.deploy.Adaptive`. From 162a7c0f028bdbad085e37cb10d9ad39cbee61c1 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 28 Sep 2022 12:55:57 +0100 Subject: [PATCH 03/11] Make AMM memory measure configurable (#7062) --- distributed/active_memory_manager.py | 36 +++- distributed/distributed-schema.yaml | 10 +- distributed/distributed.yaml | 10 +- .../tests/test_active_memory_manager.py | 178 +++++++++--------- docs/source/active_memory_manager.rst | 2 + 5 files changed, 134 insertions(+), 102 deletions(-) diff --git a/distributed/active_memory_manager.py b/distributed/active_memory_manager.py index b7b15f9445..95bf9548f4 100644 --- a/distributed/active_memory_manager.py +++ b/distributed/active_memory_manager.py @@ -46,14 +46,20 @@ class ActiveMemoryManagerExtension: ``distributed.scheduler.active-memory-manager``. """ + #: Back-reference to the scheduler holding this extension scheduler: Scheduler + #: All active policies policies: set[ActiveMemoryManagerPolicy] + #: Memory measure to use. Must be one of the attributes or properties of + #: :class:`distributed.scheduler.MemoryState`. 
+ measure: str + #: Run automatically every this many seconds interval: float - - # These attributes only exist within the scope of self.run() - # Current memory (in bytes) allocated on each worker, plus/minus pending actions + #: Current memory (in bytes) allocated on each worker, plus/minus pending actions + #: This attribute only exist within the scope of self.run(). workers_memory: dict[WorkerState, int] - # Pending replications and deletions for each task + #: Pending replications and deletions for each task + #: This attribute only exist within the scope of self.run(). pending: dict[TaskState, tuple[set[WorkerState], set[WorkerState]]] def __init__( @@ -63,6 +69,7 @@ def __init__( # away on the fly a specialized manager, separate from the main one. policies: set[ActiveMemoryManagerPolicy] | None = None, *, + measure: str | None = None, register: bool = True, start: bool | None = None, interval: float | None = None, @@ -83,6 +90,23 @@ def __init__( for policy in policies: self.add_policy(policy) + if not measure: + measure = dask.config.get( + "distributed.scheduler.active-memory-manager.measure" + ) + mem = scheduler.memory + measure_domain = { + name + for name in dir(mem) + if not name.startswith("_") and isinstance(getattr(mem, name), int) + } + if not isinstance(measure, str) or measure not in measure_domain: + raise ValueError( + "distributed.scheduler.active-memory-manager.measure " + "must be one of " + ", ".join(sorted(measure_domain)) + ) + self.measure = measure + if register: scheduler.extensions["amm"] = self scheduler.handlers["amm_handler"] = self.amm_handler @@ -92,6 +116,7 @@ def __init__( dask.config.get("distributed.scheduler.active-memory-manager.interval") ) self.interval = interval + if start is None: start = dask.config.get("distributed.scheduler.active-memory-manager.start") if start: @@ -140,8 +165,9 @@ def run_once(self) -> None: assert not hasattr(self, "pending") self.pending = {} + measure = self.measure self.workers_memory = { - 
w: w.memory.optimistic for w in self.scheduler.workers.values() + ws: getattr(ws.memory, measure) for ws in self.scheduler.workers.values() } try: # populate self.pending diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml index 8e15961680..790c80ea87 100644 --- a/distributed/distributed-schema.yaml +++ b/distributed/distributed-schema.yaml @@ -277,7 +277,7 @@ properties: active-memory-manager: type: object - required: [start, interval, policies] + required: [start, interval, measure, policies] additionalProperties: false properties: start: @@ -287,6 +287,14 @@ properties: type: string description: Time expression, e.g. "2s". Run the AMM cycle every . + measure: + enum: + - process + - optimistic + - managed + - managed_in_memory + description: + One of the attributes of distributed.scheduler.MemoryState policies: type: array items: diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index 0653af7f88..bc437e5903 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -67,11 +67,17 @@ distributed: # you'll have to either manually start it with client.amm.start() or run it once # with client.amm.run_once(). start: false + # Once started, run the AMM cycle every interval: 2s + + # Memory measure to use. Must be one of the attributes of + # distributed.scheduler.MemoryState. + measure: optimistic + + # Policies that should be executed at every cycle. Any additional keys in each + # object are passed as keyword arguments to the policy constructor. policies: - # Policies that should be executed at every cycle. Any additional keys in each - # object are passed as keyword arguments to the policy constructor. 
- class: distributed.active_memory_manager.ReduceReplicas worker: diff --git a/distributed/tests/test_active_memory_manager.py b/distributed/tests/test_active_memory_manager.py index 1eef7b9b07..0f13e95c51 100644 --- a/distributed/tests/test_active_memory_manager.py +++ b/distributed/tests/test_active_memory_manager.py @@ -3,14 +3,16 @@ import asyncio import logging import random +import warnings from collections.abc import Iterator from contextlib import contextmanager -from time import sleep from typing import Any, Literal import pytest -from distributed import Event, Lock, Nanny, wait +import dask.config + +from distributed import Event, Lock, Scheduler, wait from distributed.active_memory_manager import ( ActiveMemoryManagerExtension, ActiveMemoryManagerPolicy, @@ -18,14 +20,17 @@ ) from distributed.core import Status from distributed.utils_test import ( + BlockedGatherDep, assert_story, captured_logger, gen_cluster, + gen_test, inc, lock_inc, slowinc, wait_for_state, ) +from distributed.worker_state_machine import AcquireReplicasEvent NO_AMM_START = {"distributed.scheduler.active-memory-manager.start": False} @@ -87,11 +92,13 @@ def demo_config( candidates: list[int] | None = None, start: bool = False, interval: float = 0.1, + measure: str = "managed", ) -> dict[str, Any]: """Create a dask config for AMM with DemoPolicy""" return { "distributed.scheduler.active-memory-manager.start": start, "distributed.scheduler.active-memory-manager.interval": interval, + "distributed.scheduler.active-memory-manager.measure": measure, "distributed.scheduler.active-memory-manager.policies": [ { "class": "distributed.tests.test_active_memory_manager.DemoPolicy", @@ -349,25 +356,15 @@ async def test_double_drop_stress(c, s, a, b): assert len(s.tasks["x"].who_has) == 1 -@pytest.mark.slow -@gen_cluster( - nthreads=[("", 1)] * 4, - Worker=Nanny, - client=True, - worker_kwargs={"memory_limit": "2 GiB"}, - config=demo_config("drop", n=1), -) -async def 
test_drop_from_worker_with_least_free_memory(c, s, *nannies): - a1, a2, a3, a4 = s.workers.keys() +@gen_cluster(nthreads=[("", 1)] * 4, client=True, config=demo_config("drop", n=1)) +async def test_drop_from_worker_with_least_free_memory(c, s, *workers): ws1, ws2, ws3, ws4 = s.workers.values() futures = await c.scatter({"x": 1}, broadcast=True) assert s.tasks["x"].who_has == {ws1, ws2, ws3, ws4} - # Allocate enough RAM to be safely more than unmanaged memory - clog = c.submit(lambda: "x" * 2**29, workers=[a3]) # 512 MiB - # await wait(clog) is not enough; we need to wait for the heartbeats - while ws3.memory.optimistic < 2**29: - await asyncio.sleep(0.01) + clog = c.submit(lambda: "x" * 100, workers=[ws3.address]) + await wait(clog) + s.extensions["amm"].run_once() while s.tasks["x"].who_has != {ws1, ws2, ws4}: @@ -612,27 +609,14 @@ async def test_double_replicate_stress(c, s, a, b): await asyncio.sleep(0.01) -@pytest.mark.slow -@gen_cluster( - nthreads=[("", 1)] * 4, - Worker=Nanny, - client=True, - worker_kwargs={"memory_limit": "2 GiB"}, - config=demo_config("replicate", n=1), -) -async def test_replicate_to_worker_with_most_free_memory(c, s, *nannies): - a1, a2, a3, a4 = s.workers.keys() +@gen_cluster(nthreads=[("", 1)] * 4, client=True, config=demo_config("replicate", n=1)) +async def test_replicate_to_worker_with_most_free_memory(c, s, *workers): ws1, ws2, ws3, ws4 = s.workers.values() - futures = await c.scatter({"x": 1}, workers=[a1]) + x = await c.scatter({"x": 1}, workers=[ws1.address]) + clogs = await c.scatter([2, 3], workers=[ws2.address, ws4.address]) + assert s.tasks["x"].who_has == {ws1} - # Allocate enough RAM to be safely more than unmanaged memory - clog2 = c.submit(lambda: "x" * 2**29, workers=[a2]) # 512 MiB - clog4 = c.submit(lambda: "x" * 2**29, workers=[a4]) # 512 MiB - # await wait(clog) is not enough; we need to wait for the heartbeats - for ws in (ws2, ws4): - while ws.memory.optimistic < 2**29: - await asyncio.sleep(0.01) 
s.extensions["amm"].run_once() while s.tasks["x"].who_has != {ws1, ws3}: @@ -701,6 +685,17 @@ async def test_replicate_avoids_paused_workers_2(c, s, a, b): assert "x" not in b.data +@gen_test() +async def test_bad_measure(): + with dask.config.set( + {"distributed.scheduler.active-memory-manager.measure": "notexist"} + ): + with pytest.raises(ValueError) as e: + await Scheduler(dashboard_address=":0") + + assert "measure must be one of " in str(e.value) + + @gen_cluster( nthreads=[("", 1)] * 4, client=True, @@ -789,20 +784,19 @@ async def test_RetireWorker_no_remove(c, s, a, b): assert not s.extensions["amm"].policies -@pytest.mark.slow @pytest.mark.parametrize("use_ReduceReplicas", [False, True]) @gen_cluster( client=True, - Worker=Nanny, config={ "distributed.scheduler.active-memory-manager.start": True, "distributed.scheduler.active-memory-manager.interval": 0.1, + "distributed.scheduler.active-memory-manager.measure": "managed", "distributed.scheduler.active-memory-manager.policies": [ {"class": "distributed.active_memory_manager.ReduceReplicas"}, ], }, ) -async def test_RetireWorker_with_ReduceReplicas(c, s, *nannies, use_ReduceReplicas): +async def test_RetireWorker_with_ReduceReplicas(c, s, *workers, use_ReduceReplicas): """RetireWorker and ReduceReplicas work well with each other. 
If ReduceReplicas is enabled, @@ -823,12 +817,12 @@ async def test_RetireWorker_with_ReduceReplicas(c, s, *nannies, use_ReduceReplic if not use_ReduceReplicas: s.extensions["amm"].policies.clear() - x = c.submit(lambda: "x" * 2**26, key="x", workers=[ws_a.address]) # 64 MiB - y = c.submit(lambda: "y" * 2**26, key="y", workers=[ws_a.address]) # 64 MiB + x = c.submit(lambda: "x", key="x", workers=[ws_a.address]) + y = c.submit(lambda: "y", key="y", workers=[ws_a.address]) z = c.submit(lambda x: None, x, key="z", workers=[ws_b.address]) # copy x to ws_b # Make sure that the worker NOT being retired has the most RAM usage to test that # it is not being picked first since there's a retiring worker. - w = c.submit(lambda: "w" * 2**28, key="w", workers=[ws_b.address]) # 256 MiB + w = c.submit(lambda: "w" * 100, key="w", workers=[ws_b.address]) await wait([x, y, z, w]) await c.retire_workers([ws_a.address], remove=False) @@ -960,8 +954,9 @@ async def test_RetireWorker_new_keys_arrive_after_all_keys_moved_away(c, s, a, b while not len(ws_b.has_what) == len(xs): await asyncio.sleep(0) - # `_track_retire_worker` _should_ now be sleeping for 0.5s, because there were >=200 keys on A. - # In this test, everything from the beginning of the transfers needs to happen within 0.5s. + # `_track_retire_worker` _should_ now be sleeping for 0.5s, because there were >=200 + # keys on A. In this test, everything from the beginning of the transfers needs to + # happen within 0.5s. # Simulate the policy running again. 
Because the default 2s AMM interval is longer # than the 0.5s wait, what we're about to trigger is unlikely, but still possible @@ -1008,52 +1003,53 @@ async def test_RetireWorker_new_keys_arrive_after_all_keys_moved_away(c, s, a, b await extra.result() -# FIXME can't drop runtime of this test below 10s; see distributed#5585 -@pytest.mark.slow @gen_cluster( client=True, - Worker=Nanny, - nthreads=[("", 1)] * 3, config={ "distributed.scheduler.worker-ttl": "500ms", "distributed.scheduler.active-memory-manager.start": True, - "distributed.scheduler.active-memory-manager.interval": 0.1, + "distributed.scheduler.active-memory-manager.interval": 0.05, + "distributed.scheduler.active-memory-manager.measure": "managed", "distributed.scheduler.active-memory-manager.policies": [], }, ) -async def test_RetireWorker_faulty_recipient(c, s, *nannies): - """RetireWorker requests to replicate a key onto a unresponsive worker. +async def test_RetireWorker_faulty_recipient(c, s, w1, w2): + """RetireWorker requests to replicate a key onto an unresponsive worker. The AMM will iterate multiple times, repeating the command, until eventually the scheduler declares the worker dead and removes it from the pool; at that point the AMM will choose another valid worker and complete the job. 
""" - # ws1 is being retired - # ws2 has the lowest RAM usage and is chosen as a recipient, but is unresponsive - ws1, ws2, ws3 = s.workers.values() - f = c.submit(lambda: "x", key="x", workers=[ws1.address]) - await wait(f) - assert s.tasks["x"].who_has == {ws1} + # w1 is being retired + # w3 has the lowest RAM usage and is chosen as a recipient, but is unresponsive - # Fill ws3 with 200 MB of managed memory - # We're using plenty to make sure it's safely more than the unmanaged memory of ws2 - clutter = c.map(lambda i: "x" * 4_000_000, range(50), workers=[ws3.address]) - await wait([f] + clutter) - while ws3.memory.process < 200_000_000: - # Wait for heartbeat - await asyncio.sleep(0.01) - assert ws2.memory.process < ws3.memory.process + x = c.submit(lambda: 123, key="x", workers=[w1.address]) + await wait(x) + # Fill w2 with dummy data so that it's got the highest memory usage + # among the workers that are not being retired (w2 and w3). + clutter = await c.scatter(456, workers=[w2.address]) + + async with BlockedGatherDep(s.address) as w3: + await c.wait_for_workers(3) + + retire_fut = asyncio.create_task(c.retire_workers([w1.address])) + # w3 is chosen as the recipient for x, because it's got the lowest memory usage + await w3.in_gather_dep.wait() + + # AMM unfruitfully sends to w3 a new {op: acquire-replicas} message every 0.05s + while ( + sum(isinstance(ev, AcquireReplicasEvent) for ev in w3.state.stimulus_log) + < 3 + ): + await asyncio.sleep(0.01) - # Make ws2 unresponsive - clog_fut = asyncio.create_task(c.run(sleep, 3600, workers=[ws2.address])) - await asyncio.sleep(0.2) - assert ws2.address in s.workers + assert not retire_fut.done() - await c.retire_workers([ws1.address]) - assert ws1.address not in s.workers - # The AMM tried over and over to send the data to ws2, until it was declared dead - assert ws2.address not in s.workers - assert s.tasks["x"].who_has == {ws3} - clog_fut.cancel() + # w3 has been shut down. At this point, AMM switches to w2. 
+ await retire_fut + + assert w1.address not in s.workers + assert w3.address not in s.workers + assert dict(w2.data) == {"x": 123, clutter.key: 456} class Counter: @@ -1154,20 +1150,21 @@ async def tensordot_stress(c): da = pytest.importorskip("dask.array") rng = da.random.RandomState(0) - a = rng.random((20, 20), chunks=(1, 1)) - b = (a @ a.T).sum().round(3) - assert await c.compute(b) == 2134.398 + a = rng.random((10, 10), chunks=(1, 1)) + # dask.array.core.PerformanceWarning: Increasing number of chunks by factor of 10 + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + b = (a @ a.T).sum().round(3) + assert await c.compute(b) == 245.394 @pytest.mark.slow -@pytest.mark.avoid_ci(reason="distributed#5371") @gen_cluster( client=True, nthreads=[("", 1)] * 4, - Worker=Nanny, config=NO_AMM_START, ) -async def test_noamm_stress(c, s, *nannies): +async def test_noamm_stress(c, s, *workers): """Test the tensordot_stress helper without AMM. This is to figure out if a stability issue is AMM-specific or not. """ @@ -1175,20 +1172,19 @@ async def test_noamm_stress(c, s, *nannies): @pytest.mark.slow -@pytest.mark.avoid_ci(reason="distributed#5371") @gen_cluster( client=True, nthreads=[("", 1)] * 4, - Worker=Nanny, config={ "distributed.scheduler.active-memory-manager.start": True, "distributed.scheduler.active-memory-manager.interval": 0.1, + "distributed.scheduler.active-memory-manager.measure": "managed", "distributed.scheduler.active-memory-manager.policies": [ {"class": "distributed.tests.test_active_memory_manager.DropEverything"}, ], }, ) -async def test_drop_stress(c, s, *nannies): +async def test_drop_stress(c, s, *workers): """A policy which suggests dropping everything won't break a running computation, but only slow it down. 
@@ -1198,20 +1194,19 @@ async def test_drop_stress(c, s, *nannies): @pytest.mark.slow -@pytest.mark.avoid_ci(reason="distributed#5371") @gen_cluster( client=True, nthreads=[("", 1)] * 4, - Worker=Nanny, config={ "distributed.scheduler.active-memory-manager.start": True, "distributed.scheduler.active-memory-manager.interval": 0.1, + "distributed.scheduler.active-memory-manager.measure": "managed", "distributed.scheduler.active-memory-manager.policies": [ {"class": "distributed.active_memory_manager.ReduceReplicas"}, ], }, ) -async def test_ReduceReplicas_stress(c, s, *nannies): +async def test_ReduceReplicas_stress(c, s, *workers): """Running ReduceReplicas compulsively won't break a running computation. Unlike test_drop_stress above, this test does not stop running after a few seconds - the policy must not disrupt the computation too much. @@ -1220,19 +1215,14 @@ async def test_ReduceReplicas_stress(c, s, *nannies): @pytest.mark.slow -@pytest.mark.avoid_ci(reason="distributed#5371") @pytest.mark.parametrize("use_ReduceReplicas", [False, True]) @gen_cluster( client=True, nthreads=[("", 1)] * 10, - Worker=Nanny, config={ "distributed.scheduler.active-memory-manager.start": True, - # If interval is too low, then the AMM will rerun while tasks have not yet have - # the time to migrate. This is OK if it happens occasionally, but if this - # setting is too aggressive the cluster will get flooded with repeated comm - # requests. 
- "distributed.scheduler.active-memory-manager.interval": 2.0, + "distributed.scheduler.active-memory-manager.interval": 0.1, + "distributed.scheduler.active-memory-manager.measure": "managed", "distributed.scheduler.active-memory-manager.policies": [ {"class": "distributed.active_memory_manager.ReduceReplicas"}, ], @@ -1240,7 +1230,7 @@ async def test_ReduceReplicas_stress(c, s, *nannies): scheduler_kwargs={"transition_counter_max": 500_000}, worker_kwargs={"transition_counter_max": 500_000}, ) -async def test_RetireWorker_stress(c, s, *nannies, use_ReduceReplicas): +async def test_RetireWorker_stress(c, s, *workers, use_ReduceReplicas): """It is safe to retire the best part of a cluster in the middle of a computation""" if not use_ReduceReplicas: s.extensions["amm"].policies.clear() diff --git a/docs/source/active_memory_manager.rst b/docs/source/active_memory_manager.rst index 10eb9c9290..cf5416b33a 100644 --- a/docs/source/active_memory_manager.rst +++ b/docs/source/active_memory_manager.rst @@ -36,6 +36,7 @@ The AMM can be enabled through the :doc:`Dask configuration file active-memory-manager: start: true interval: 2s + measure: optimistic The above is the recommended setup and will run all enabled *AMM policies* (see below) every two seconds. 
Alternatively, you can manually start/stop the AMM from the @@ -79,6 +80,7 @@ Individual policies are enabled, disabled, and configured through the Dask confi active-memory-manager: start: true interval: 2s + measure: optimistic policies: - class: distributed.active_memory_manager.ReduceReplicas - class: my_package.MyPolicy From 8f36aa5ce32ed5c2c07870698029a97621ed6c2b Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Wed, 28 Sep 2022 16:44:33 +0200 Subject: [PATCH 04/11] Improve documentation of `message-bytes-limit` (#7077) --- distributed/distributed-schema.yaml | 7 ++++++- distributed/worker_state_machine.py | 11 ++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml index 790c80ea87..36887f74e4 100644 --- a/distributed/distributed-schema.yaml +++ b/distributed/distributed-schema.yaml @@ -357,7 +357,12 @@ properties: - string - integer description: | - The maximum size of a message sent between workers + The maximum amount of data for a worker to request from another in a single gather operation + + Tasks are gathered in batches, and if the first task in a batch is larger than this value, + the task will still be gathered to ensure progress. Hence, this limit is not absolute. + Note that this limit applies to a single gather operation and a worker may gather data from + multiple workers in parallel. connections: type: object description: | diff --git a/distributed/worker_state_machine.py b/distributed/worker_state_machine.py index 339e58a63e..c8ad6cdf6f 100644 --- a/distributed/worker_state_machine.py +++ b/distributed/worker_state_machine.py @@ -1126,10 +1126,11 @@ class WorkerState: #: multiple entries in :attr:`~TaskState.who_has` will appear multiple times here. data_needed: defaultdict[str, HeapSet[TaskState]] - #: Number of bytes to fetch from the same worker in a single call to - #: :meth:`BaseWorker.gather_dep`. 
Multiple small tasks that can be fetched from the - #: same worker will be clustered in a single instruction as long as their combined - #: size doesn't exceed this value. + #: Number of bytes to gather from the same worker in a single call to + #: :meth:`BaseWorker.gather_dep`. Multiple small tasks that can be gathered from the + #: same worker will be batched in a single instruction as long as their combined + #: size doesn't exceed this value. If the first task to be gathered exceeds this + #: limit, it will still be gathered to ensure progress. Hence, this limit is not absolute. transfer_message_bytes_limit: float #: All and only tasks with ``TaskState.state == 'missing'``. @@ -1546,7 +1547,7 @@ def _ensure_communicating(self, *, stimulus_id: str) -> RecsInstrs: # A single invocation of _ensure_communicating may generate up to one # GatherDep instruction per worker. Multiple tasks from the same worker may - # be clustered in the same instruction by _select_keys_for_gather. But once + # be batched in the same instruction by _select_keys_for_gather. But once # a worker has been selected for a GatherDep and added to in_flight_workers, # it won't be selected again until the gather completes.
instructions.append( From 482941ebe6c0d5fd851efd4b193ea3392b7ce4a9 Mon Sep 17 00:00:00 2001 From: James Bourbeau Date: Wed, 28 Sep 2022 23:42:23 -0500 Subject: [PATCH 05/11] Allow timeout strings in `distributed.wait` (#7081) --- distributed/client.py | 8 +++++--- distributed/tests/test_client.py | 5 +++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index e4ead779fc..eafbe48f73 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -4741,9 +4741,9 @@ def wait(fs, timeout=None, return_when=ALL_COMPLETED): Parameters ---------- fs : List[Future] - timeout : number, optional - Time in seconds after which to raise a - ``dask.distributed.TimeoutError`` + timeout : number, string, optional + Time after which to raise a ``dask.distributed.TimeoutError``. + Can be a string like ``"10 minutes"`` or a number of seconds to wait. return_when : str, optional One of `ALL_COMPLETED` or `FIRST_COMPLETED` @@ -4751,6 +4751,8 @@ def wait(fs, timeout=None, return_when=ALL_COMPLETED): ------- Named tuple of completed, not completed """ + if timeout is not None and isinstance(timeout, (Number, str)): + timeout = parse_timedelta(timeout, default="s") client = default_client() result = client.sync(_wait, fs, timeout=timeout, return_when=return_when) return result diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 14ea3c7f88..0c8d83b152 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -749,6 +749,11 @@ async def test_wait_timeout(c, s, a, b): with pytest.raises(TimeoutError): await wait(future, timeout=0.01) + # Ensure timeout can be a string + future = c.submit(sleep, 0.3) + with pytest.raises(TimeoutError): + await wait(future, timeout="0.01 s") + def test_wait_sync(c): x = c.submit(inc, 1) From e7057c5dab1508d917c20ea643ae89882fb3264b Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 29 Sep 2022 16:02:33 +0100 Subject: [PATCH 06/11] 
Enable Active Memory Manager by default (#7042) --- distributed/client.py | 12 +- distributed/distributed.yaml | 2 +- distributed/protocol/tests/test_serialize.py | 4 +- .../tests/test_active_memory_manager.py | 56 ++++----- distributed/tests/test_actor.py | 4 +- distributed/tests/test_client.py | 119 +++++++++++------- distributed/tests/test_failed_workers.py | 52 ++++---- distributed/tests/test_resources.py | 17 ++- distributed/tests/test_scheduler.py | 67 ++++++---- distributed/tests/test_semaphore.py | 5 +- distributed/tests/test_steal.py | 92 +++++++------- distributed/tests/test_stories.py | 10 +- distributed/tests/test_stress.py | 7 +- distributed/tests/test_tls_functional.py | 16 ++- distributed/tests/test_worker.py | 85 +++++++------ distributed/tests/test_worker_memory.py | 22 +++- .../tests/test_worker_state_machine.py | 9 +- distributed/utils_test.py | 25 +++- docs/source/active_memory_manager.rst | 33 ++++- docs/source/resilience.rst | 13 +- 20 files changed, 396 insertions(+), 254 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index eafbe48f73..a09c2ca704 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -2370,6 +2370,11 @@ def scatter( broadcast : bool (defaults to False) Whether to send each data element to all workers. By default we round-robin based on number of cores. + + .. note:: + Setting this flag to True is incompatible with the Active Memory + Manager's :ref:`ReduceReplicas` policy. If you wish to use it, you must + first disable the policy or disable the AMM entirely. direct : bool (defaults to automatically check) Whether or not to connect directly to the workers, or to ask the scheduler to serve as intermediary. This can also be set when @@ -3513,12 +3518,17 @@ def replicate(self, futures, n=None, workers=None, branching_factor=2, **kwargs) """Set replication of futures within network Copy data onto many workers. 
This helps to broadcast frequently - accessed data and it helps to improve resilience. + accessed data and can improve resilience. This performs a tree copy of the data throughout the network individually on each piece of data. This operation blocks until complete. It does not guarantee replication of data to future workers. + .. note:: + This method is incompatible with the Active Memory Manager's + :ref:`ReduceReplicas` policy. If you wish to use it, you must first disable + the policy or disable the AMM entirely. + Parameters ---------- futures : list of futures diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index bc437e5903..e47f7cba8c 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -66,7 +66,7 @@ distributed: # Set to true to auto-start the Active Memory Manager on Scheduler start; if false # you'll have to either manually start it with client.amm.start() or run it once # with client.amm.run_once(). - start: false + start: true # Once started, run the AMM cycle every interval: 2s diff --git a/distributed/protocol/tests/test_serialize.py b/distributed/protocol/tests/test_serialize.py index 900a233885..311334bd2c 100644 --- a/distributed/protocol/tests/test_serialize.py +++ b/distributed/protocol/tests/test_serialize.py @@ -35,7 +35,7 @@ ) from distributed.protocol.serialize import check_dask_serializable from distributed.utils import ensure_memoryview, nbytes -from distributed.utils_test import gen_test, inc +from distributed.utils_test import NO_AMM, gen_test, inc class MyObj: @@ -208,7 +208,7 @@ async def test_object_in_graph(c, s, a, b): assert result.data == 123 -@gen_cluster(client=True) +@gen_cluster(client=True, config=NO_AMM) async def test_scatter(c, s, a, b): o = MyObj(123) [future] = await c._scatter([o]) diff --git a/distributed/tests/test_active_memory_manager.py b/distributed/tests/test_active_memory_manager.py index 0f13e95c51..edc07a546b 100644 --- 
a/distributed/tests/test_active_memory_manager.py +++ b/distributed/tests/test_active_memory_manager.py @@ -20,6 +20,7 @@ ) from distributed.core import Status from distributed.utils_test import ( + NO_AMM, BlockedGatherDep, assert_story, captured_logger, @@ -32,8 +33,6 @@ ) from distributed.worker_state_machine import AcquireReplicasEvent -NO_AMM_START = {"distributed.scheduler.active-memory-manager.start": False} - @contextmanager def assert_amm_log(expect: list[str]) -> Iterator[None]: @@ -250,7 +249,7 @@ async def test_multi_start(c, s, a, b): assert len(s.tasks["z"].who_has) == 1 -@gen_cluster(client=True, config=NO_AMM_START) +@gen_cluster(client=True, config=NO_AMM) async def test_not_registered(c, s, a, b): futures = await c.scatter({"x": 1}, broadcast=True) assert len(s.tasks["x"].who_has) == 2 @@ -267,16 +266,17 @@ def run(self): await asyncio.sleep(0.01) -def test_client_proxy_sync(client): - assert not client.amm.running() - client.amm.start() - assert client.amm.running() - client.amm.stop() - assert not client.amm.running() - client.amm.run_once() +def test_client_proxy_sync(client_no_amm): + c = client_no_amm + assert not c.amm.running() + c.amm.start() + assert c.amm.running() + c.amm.stop() + assert not c.amm.running() + c.amm.run_once() -@gen_cluster(client=True, config=NO_AMM_START) +@gen_cluster(client=True, config=NO_AMM) async def test_client_proxy_async(c, s, a, b): assert not await c.amm.running() await c.amm.start() @@ -318,7 +318,7 @@ async def test_drop_with_waiter(c, s, a, b): assert not y2.done() -@gen_cluster(client=True, config=NO_AMM_START) +@gen_cluster(client=True, config=NO_AMM) async def test_double_drop(c, s, a, b): """An AMM drop policy runs once to drop one of the two replicas of a key. 
Then it runs again, before the recommendations from the first iteration had the time @@ -832,7 +832,7 @@ async def test_RetireWorker_with_ReduceReplicas(c, s, *workers, use_ReduceReplic assert {ts.key for ts in ws_b.has_what} == {"x", "y", "z", "w"} -@gen_cluster(client=True, nthreads=[("", 1)] * 3, config=NO_AMM_START) +@gen_cluster(client=True, nthreads=[("", 1)] * 3, config=NO_AMM) async def test_RetireWorker_all_replicas_are_being_retired(c, s, w1, w2, w3): """There are multiple replicas of a key, but they all reside on workers that are being retired @@ -917,7 +917,6 @@ async def test_RetireWorker_all_recipients_are_paused(c, s, a, b): "distributed.scheduler.active-memory-manager.start": True, "distributed.scheduler.active-memory-manager.policies": [], }, - timeout=15, ) async def test_RetireWorker_new_keys_arrive_after_all_keys_moved_away(c, s, a, b): """ @@ -950,22 +949,22 @@ async def test_RetireWorker_new_keys_arrive_after_all_keys_moved_away(c, s, a, b t = asyncio.create_task(c.retire_workers([a.address])) + amm: ActiveMemoryManagerExtension = s.extensions["amm"] + while not amm.policies: + await asyncio.sleep(0) + policy = next(iter(amm.policies)) + assert isinstance(policy, RetireWorker) + # Wait for all `xs` to be replicated. - while not len(ws_b.has_what) == len(xs): + while len(ws_b.has_what) != len(xs): await asyncio.sleep(0) # `_track_retire_worker` _should_ now be sleeping for 0.5s, because there were >=200 # keys on A. In this test, everything from the beginning of the transfers needs to # happen within 0.5s. - # Simulate the policy running again. Because the default 2s AMM interval is longer - # than the 0.5s wait, what we're about to trigger is unlikely, but still possible - # for the times to line up. (Especially with a custom AMM interval.) 
- amm: ActiveMemoryManagerExtension = s.extensions["amm"] - assert len(amm.policies) == 1 - policy = next(iter(amm.policies)) - assert isinstance(policy, RetireWorker) - + # Simulate waiting for the policy to run again. + # Note that the interval at which the policy runs is inconsequential for this test. amm.run_once() # The policy has removed itself, because all `xs` have been replicated. @@ -979,8 +978,9 @@ async def test_RetireWorker_new_keys_arrive_after_all_keys_moved_away(c, s, a, b if a.address not in s.workers: # It took more than 0.5s to get here, and the scheduler closed our worker. Dang. - pytest.skip( - "Timing didn't work out: `_track_retire_worker` finished before `extra` completed." + pytest.xfail( + "Timing didn't work out: `_track_retire_worker` finished before " + "`extra` completed." ) # `retire_workers` doesn't hang @@ -1082,7 +1082,7 @@ async def test_dont_replicate_actors(c, s, a, b): @pytest.mark.parametrize("has_proxy", [False, True]) -@gen_cluster(client=True, config=NO_AMM_START) +@gen_cluster(client=True, config=NO_AMM) async def test_RetireWorker_with_actor(c, s, a, b, has_proxy): """A worker holding one or more original actor objects cannot be retired""" x = c.submit(Counter, key="x", actor=True, workers=[a.address]) @@ -1108,7 +1108,7 @@ async def test_RetireWorker_with_actor(c, s, a, b, has_proxy): assert "y" in b.data -@gen_cluster(client=True, config=NO_AMM_START) +@gen_cluster(client=True, config=NO_AMM) async def test_RetireWorker_with_actor_proxy(c, s, a, b): """A worker holding an Actor proxy object can be retired as normal.""" x = c.submit(Counter, key="x", actor=True, workers=[a.address]) @@ -1162,7 +1162,7 @@ async def tensordot_stress(c): @gen_cluster( client=True, nthreads=[("", 1)] * 4, - config=NO_AMM_START, + config=NO_AMM, ) async def test_noamm_stress(c, s, *workers): """Test the tensordot_stress helper without AMM. 
This is to figure out if a diff --git a/distributed/tests/test_actor.py b/distributed/tests/test_actor.py index 9c0f86bc4f..cbf708b0cf 100644 --- a/distributed/tests/test_actor.py +++ b/distributed/tests/test_actor.py @@ -477,9 +477,7 @@ def f(block, ps=None): print(format_time(end - start)) -@pytest.mark.slow -@pytest.mark.flaky(reruns=10, reruns_delay=5) -@gen_cluster(client=True, timeout=120) +@gen_cluster(client=True) async def test_compute(c, s, a, b): @dask.delayed def f(n, counter): diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 0c8d83b152..aa8a001ac5 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -85,6 +85,8 @@ tmp_text, ) from distributed.utils_test import ( + NO_AMM, + BlockedGatherDep, TaskStateMetadataPlugin, _UnhashableCallable, async_wait_for, @@ -112,6 +114,7 @@ tls_only_security, varying, wait_for, + wait_for_state, ) pytestmark = pytest.mark.ci1 @@ -932,7 +935,7 @@ async def test_tokenize_on_futures(c, s, a, b): @pytest.mark.skipif(not LINUX, reason="Need 127.0.0.2 to mean localhost") -@gen_cluster([("127.0.0.1", 1), ("127.0.0.2", 2)], client=True) +@gen_cluster([("127.0.0.1", 1), ("127.0.0.2", 2)], client=True, config=NO_AMM) async def test_restrictions_submit(c, s, a, b): x = c.submit(inc, 1, workers={a.ip}) y = c.submit(inc, x, workers={b.ip}) @@ -945,7 +948,7 @@ async def test_restrictions_submit(c, s, a, b): assert y.key in b.data -@gen_cluster(client=True) +@gen_cluster(client=True, config=NO_AMM) async def test_restrictions_ip_port(c, s, a, b): x = c.submit(inc, 1, workers={a.address}) y = c.submit(inc, x, workers={b.address}) @@ -983,7 +986,7 @@ async def test_restrictions_get(c, s, a, b): assert len(b.data) == 0 -@gen_cluster(client=True) +@gen_cluster(client=True, config=NO_AMM) async def test_restrictions_get_annotate(c, s, a, b): x = 1 with dask.annotate(workers=a.address): @@ -1441,7 +1444,7 @@ async def test_scatter_direct_numpy(c, s, a, b): assert not 
s.counters["op"].components[0]["scatter"] -@gen_cluster(client=True) +@gen_cluster(client=True, config=NO_AMM) async def test_scatter_direct_broadcast(c, s, a, b): future2 = await c.scatter(456, direct=True, broadcast=True) assert future2.key in a.data @@ -1458,7 +1461,7 @@ async def test_scatter_direct_balanced(c, s, *workers): assert sorted(len(w.data) for w in workers) == [0, 1, 1, 1] -@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4, config=NO_AMM) async def test_scatter_direct_broadcast_target(c, s, *workers): futures = await c.scatter([123, 456], direct=True, workers=workers[0].address) assert futures[0].key in workers[0].data @@ -1693,7 +1696,8 @@ def g(): assert result == (value, value) -@gen_cluster(client=True) +# _upload_large_file internally calls replicate, which makes it incompatible with AMM +@gen_cluster(client=True, config=NO_AMM) async def test_upload_large_file(c, s, a, b): assert a.local_directory assert b.local_directory @@ -2263,20 +2267,20 @@ async def test_multi_garbage_collection(s, a, b): assert not s.tasks -@gen_cluster(client=True) +@gen_cluster(client=True, config=NO_AMM) async def test__broadcast(c, s, a, b): x, y = await c.scatter([1, 2], broadcast=True) assert a.data == b.data == {x.key: 1, y.key: 2} -@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4, config=NO_AMM) async def test__broadcast_integer(c, s, *workers): x, y = await c.scatter([1, 2], broadcast=2) assert len(s.tasks[x.key].who_has) == 2 assert len(s.tasks[y.key].who_has) == 2 -@gen_cluster(client=True) +@gen_cluster(client=True, config=NO_AMM) async def test__broadcast_dict(c, s, a, b): d = await c.scatter({"x": 1}, broadcast=True) assert a.data == b.data == {"x": 1} @@ -2897,11 +2901,14 @@ def __reduce__(self): # Set rebalance() to work predictably on small amounts of managed memory. 
By default, it # uses optimistic memory, which would only be possible to test by allocating very large # amounts of managed memory, so that they would hide variations in unmanaged memory. -REBALANCE_MANAGED_CONFIG = { - "distributed.worker.memory.rebalance.measure": "managed", - "distributed.worker.memory.rebalance.sender-min": 0, - "distributed.worker.memory.rebalance.sender-recipient-gap": 0, -} +REBALANCE_MANAGED_CONFIG = merge( + NO_AMM, + { + "distributed.worker.memory.rebalance.measure": "managed", + "distributed.worker.memory.rebalance.sender-min": 0, + "distributed.worker.memory.rebalance.sender-recipient-gap": 0, + }, +) @gen_cluster(client=True, config=REBALANCE_MANAGED_CONFIG) @@ -2955,7 +2962,7 @@ def test_rebalance_sync(loop): assert len(b.data) == 50 -@gen_cluster(client=True) +@gen_cluster(client=True, config=NO_AMM) async def test_rebalance_unprepared(c, s, a, b): """Client.rebalance() internally waits for unfinished futures""" futures = c.map(slowinc, range(10), delay=0.05, workers=a.address) @@ -2967,7 +2974,7 @@ async def test_rebalance_unprepared(c, s, a, b): s.validate_state() -@gen_cluster(client=True) +@gen_cluster(client=True, config=NO_AMM) async def test_rebalance_raises_on_explicit_missing_data(c, s, a, b): """rebalance() raises KeyError if explicitly listed futures disappear""" f = Future("x", client=c, state="memory") @@ -3015,7 +3022,7 @@ async def test_add_worker_after_tasks(c, s): @pytest.mark.skipif(not LINUX, reason="Need 127.0.0.2 to mean localhost") -@gen_cluster([("127.0.0.1", 1), ("127.0.0.2", 2)], client=True) +@gen_cluster([("127.0.0.1", 1), ("127.0.0.2", 2)], client=True, config=NO_AMM) async def test_workers_register_indirect_data(c, s, a, b): [x] = await c.scatter([1], workers=a.address) y = c.submit(inc, x, workers=b.ip) @@ -3037,7 +3044,11 @@ async def test_submit_on_cancelled_future(c, s, a, b): c.submit(inc, x) -@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10) +@gen_cluster( + client=True, + 
nthreads=[("127.0.0.1", 1)] * 10, + config=NO_AMM, +) async def test_replicate(c, s, *workers): [a, b] = await c.scatter([1, 2]) await s.replicate(keys=[a.key, b.key], n=5) @@ -3050,7 +3061,7 @@ async def test_replicate(c, s, *workers): assert sum(b.key in w.data for w in workers) == 5 -@gen_cluster(client=True) +@gen_cluster(client=True, config=NO_AMM) async def test_replicate_tuple_keys(c, s, a, b): x = delayed(inc)(1, dask_key_name=("x", 1)) f = c.persist(x) @@ -3062,7 +3073,11 @@ async def test_replicate_tuple_keys(c, s, a, b): s.validate_state() -@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10) +@gen_cluster( + client=True, + nthreads=[("127.0.0.1", 1)] * 10, + config=NO_AMM, +) async def test_replicate_workers(c, s, *workers): [a, b] = await c.scatter([1, 2], workers=[workers[0].address]) @@ -3113,7 +3128,11 @@ def __getstate__(self): return self.n -@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10) +@gen_cluster( + client=True, + nthreads=[("127.0.0.1", 1)] * 10, + config=NO_AMM, +) async def test_replicate_tree_branching(c, s, *workers): obj = CountSerialization() [future] = await c.scatter([obj]) @@ -3123,7 +3142,11 @@ async def test_replicate_tree_branching(c, s, *workers): assert max_count > 1 -@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10) +@gen_cluster( + client=True, + nthreads=[("127.0.0.1", 1)] * 10, + config=NO_AMM, +) async def test_client_replicate(c, s, *workers): x = c.submit(inc, 1) y = c.submit(inc, 2) @@ -3148,6 +3171,7 @@ async def test_client_replicate(c, s, *workers): @gen_cluster( client=True, nthreads=[("127.0.0.1", 1), ("127.0.0.2", 1), ("127.0.0.2", 1)], + config=NO_AMM, ) async def test_client_replicate_host(client, s, a, b, c): aws = s.workers[a.address] @@ -3165,7 +3189,9 @@ async def test_client_replicate_host(client, s, a, b, c): assert s.tasks[x.key].who_has == {aws, bws, cws} -def test_client_replicate_sync(c): +def test_client_replicate_sync(client_no_amm): + c = client_no_amm + x = 
c.submit(inc, 1) y = c.submit(inc, 2) c.replicate([x, y], n=2) @@ -3232,7 +3258,7 @@ async def test_balanced_with_submit(c, s, *workers): assert len(w.data) == 1 -@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4, config=NO_AMM) async def test_balanced_with_submit_and_resident_data(c, s, *workers): [x] = await c.scatter([10], broadcast=True) L = [c.submit(slowinc, x, pure=False) for i in range(4)] @@ -3885,7 +3911,11 @@ async def test_lose_scattered_data(c, s, a, b): assert x.key not in s.tasks -@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) +@gen_cluster( + client=True, + nthreads=[("127.0.0.1", 1)] * 3, + config=NO_AMM, +) async def test_partially_lose_scattered_data(e, s, a, b, c): x = await e.scatter(1, workers=a.address) await e.replicate(x, n=2) @@ -3897,22 +3927,22 @@ async def test_partially_lose_scattered_data(e, s, a, b, c): assert s.get_task_status(keys=[x.key]) == {x.key: "memory"} -@gen_cluster(client=True) -async def test_scatter_compute_lose(c, s, a, b): - [x] = await c.scatter([[1, 2, 3, 4]], workers=a.address) - y = c.submit(inc, 1, workers=b.address) +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_scatter_compute_lose(c, s, a): + x = (await c.scatter({"x": 1}, workers=[a.address]))["x"] - z = c.submit(slowadd, x, y, delay=0.2) - await asyncio.sleep(0.1) + async with BlockedGatherDep(s.address) as b: + y = c.submit(inc, x, key="y", workers=[b.address]) + await wait_for_state("x", "flight", b) - await a.close() + await a.close() + b.block_gather_dep.set() - with pytest.raises(CancelledError): - await wait(z) + with pytest.raises(CancelledError): + await wait(y) - assert x.status == "cancelled" - assert y.status == "finished" - assert z.status == "cancelled" + assert x.status == "cancelled" + assert y.status == "cancelled" @gen_cluster(client=True) @@ -4097,7 +4127,7 @@ async def run2(): @nodebug # test timing is fragile 
-@gen_cluster(nthreads=[("127.0.0.1", 1)] * 3, client=True) +@gen_cluster(nthreads=[("127.0.0.1", 1)] * 3, client=True, config=NO_AMM) async def test_persist_workers_annotate(e, s, a, b, c): with dask.annotate(workers=a.address, allow_other_workers=False): L1 = [delayed(inc)(i) for i in range(4)] @@ -4304,10 +4334,13 @@ def f(x, y=0, z=0): @gen_cluster( client=True, nthreads=[("127.0.0.1", 1), ("127.0.0.1", 10)], - config={ - "distributed.scheduler.work-stealing": False, - "distributed.scheduler.default-task-durations": {"f": "10ms"}, - }, + config=merge( + NO_AMM, + { + "distributed.scheduler.work-stealing": False, + "distributed.scheduler.default-task-durations": {"f": "10ms"}, + }, + ), ) async def test_distribute_tasks_by_nthreads(c, s, a, b): def f(x, y=0): @@ -5850,7 +5883,7 @@ def bad_fn(x): @gen_cluster( client=True, nthreads=[("", 1)] * 10, - config={"distributed.worker.memory.pause": False}, + config=merge(NO_AMM, {"distributed.worker.memory.pause": False}), ) async def test_scatter_and_replicate_avoid_paused_workers( c, s, *workers, workers_arg, direct, broadcast diff --git a/distributed/tests/test_failed_workers.py b/distributed/tests/test_failed_workers.py index 48bb30af59..29a25fdc44 100644 --- a/distributed/tests/test_failed_workers.py +++ b/distributed/tests/test_failed_workers.py @@ -7,10 +7,13 @@ from time import sleep from unittest import mock +import psutil import pytest -from tlz import first, partition_all +from tlz import first, merge, partition_all +import dask.config from dask import delayed +from dask.utils import parse_bytes from distributed import Client, Nanny, profile, wait from distributed.comm import CommClosedError @@ -18,6 +21,7 @@ from distributed.metrics import time from distributed.utils import CancelledError, sync from distributed.utils_test import ( + NO_AMM, BlockedGatherDep, BlockedGetData, async_wait_for, @@ -235,7 +239,8 @@ async def test_forgotten_futures_dont_clean_up_new_futures(c, s, a, b): y = c.submit(inc, 1) del x - 
# Ensure that the profiler has stopped and released all references to x so that it can be garbage-collected + # Ensure that the profiler has stopped and released all references to x so that it + # can be garbage-collected with profile.lock: pass await asyncio.sleep(0.1) @@ -305,16 +310,11 @@ def __init__(self, data, delay=0.1): self.data = data def __reduce__(self): - import time - - time.sleep(self.delay) - return (SlowTransmitData, (self.delay,)) + sleep(self.delay) + return SlowTransmitData, (self.data, self.delay) def __sizeof__(self) -> int: # Ensure this is offloaded to avoid blocking loop - import dask - from dask.utils import parse_bytes - return parse_bytes(dask.config.get("distributed.comm.offload")) + 1 @@ -357,6 +357,7 @@ def sink(*args): @gen_cluster( client=True, nthreads=[("127.0.0.1", 1), ("127.0.0.1", 2), ("127.0.0.1", 3)], + config=NO_AMM, ) async def test_worker_same_host_replicas_missing(c, s, a, b, x): # See GH4784 @@ -462,11 +463,11 @@ async def test_forget_data_not_supposed_to_have(c, s, a): @gen_cluster( client=True, - nthreads=[("127.0.0.1", 1) for _ in range(3)], - config={"distributed.comm.timeouts.connect": "1s"}, + nthreads=[("", 1)] * 3, + config=merge(NO_AMM, {"distributed.comm.timeouts.connect": "1s"}), Worker=Nanny, ) -async def test_failing_worker_with_additional_replicas_on_cluster(c, s, *workers): +async def test_failing_worker_with_additional_replicas_on_cluster(c, s, n0, n1, n2): """ If a worker detects a missing dependency, the scheduler is notified. If no other replica is available, the dependency is rescheduled. A reschedule @@ -475,34 +476,33 @@ async def test_failing_worker_with_additional_replicas_on_cluster(c, s, *workers and correct its state. 
""" - def slow_transfer(x, delay=0.1): - return SlowTransmitData(x, delay=delay) - def dummy(*args, **kwargs): return - import psutil - - proc = psutil.Process(workers[1].pid) + proc1 = psutil.Process(n1.pid) f1 = c.submit( - slow_transfer, + SlowTransmitData, 1, + delay=0.1, key="f1", - workers=[workers[0].worker_address], + workers=[n0.worker_address], ) + await wait(f1) + # We'll schedule tasks on two workers, s.t. f1 is replicated. We will # suspend one of the workers and kill the origin worker of f1 such that a # comm failure causes the worker to handle a missing dependency. It will ask # the schedule such that it knows that a replica is available on f2 and # reschedules the fetch - f2 = c.submit(dummy, f1, pure=False, key="f2", workers=[workers[1].worker_address]) - f3 = c.submit(dummy, f1, pure=False, key="f3", workers=[workers[2].worker_address]) + f2 = c.submit(dummy, f1, key="f2", workers=[n1.worker_address]) + f3 = c.submit(dummy, f1, key="f3", workers=[n2.worker_address]) - await wait(f1) - proc.suspend() + proc1.suspend() await wait(f3) - await workers[0].close() + # Because of this line we need to disable AMM; otherwise it could choose to delete + # the replicas of f1 on n1 and n2 and keep the one on n0. 
+ await n0.close() - proc.resume() + proc1.resume() await c.gather([f1, f2, f3]) diff --git a/distributed/tests/test_resources.py b/distributed/tests/test_resources.py index 405f106cdd..ada72adbec 100644 --- a/distributed/tests/test_resources.py +++ b/distributed/tests/test_resources.py @@ -10,7 +10,7 @@ from distributed import Lock, Worker from distributed.client import wait -from distributed.utils_test import gen_cluster, inc, lock_inc, slowadd, slowinc +from distributed.utils_test import NO_AMM, gen_cluster, inc, lock_inc, slowadd, slowinc from distributed.worker_state_machine import ( ComputeTaskEvent, Execute, @@ -133,23 +133,20 @@ async def test_map(c, s, a, b): @gen_cluster( client=True, - nthreads=[ - ("127.0.0.1", 1, {"resources": {"A": 1}}), - ("127.0.0.1", 1, {"resources": {"B": 1}}), - ], + nthreads=[("", 1, {"resources": {"A": 1}}), ("", 1, {"resources": {"B": 1}})], + config=NO_AMM, ) async def test_persist(c, s, a, b): with dask.annotate(resources={"A": 1}): - x = delayed(inc)(1) + x = delayed(inc)(1, dask_key_name="x") with dask.annotate(resources={"B": 1}): - y = delayed(inc)(x) + y = delayed(inc)(x, dask_key_name="y") xx, yy = c.persist([x, y], optimize_graph=False) - await wait([xx, yy]) - assert x.key in a.data - assert y.key in b.data + assert set(a.data) == {"x"} + assert set(b.data) == {"x", "y"} @gen_cluster( diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 94ab7439f6..95bebd674e 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -43,6 +43,7 @@ from distributed.scheduler import KilledWorker, MemoryState, Scheduler, WorkerState from distributed.utils import TimeoutError from distributed.utils_test import ( + NO_AMM, BrokenComm, async_wait_for, captured_logger, @@ -1354,7 +1355,7 @@ def key(ws): # Assert that *total* byte count in group determines group priority av = await c.scatter("a" * 100, workers=workers[0].address) bv = await c.scatter("b" * 75, 
workers=workers[2].address) - bv2 = await c.scatter("b" * 75, workers=workers[3].address) + cv = await c.scatter("c" * 75, workers=workers[3].address) assert set(s.workers_to_close(key=key)) == {workers[0].address, workers[1].address} @@ -1897,7 +1898,7 @@ async def test_retries(c, s, a, b): exc_info.match("one") -@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3, config=NO_AMM) async def test_missing_data_errant_worker(c, s, w1, w2, w3): with dask.config.set({"distributed.comm.timeouts.connect": "1s"}): np = pytest.importorskip("numpy") @@ -2425,7 +2426,7 @@ def scheduler_delay(self, value): pass -@gen_cluster(client=True, Worker=NoSchedulerDelayWorker) +@gen_cluster(client=True, Worker=NoSchedulerDelayWorker, config=NO_AMM) async def test_task_groups(c, s, a, b): start = time() da = pytest.importorskip("dask.array") @@ -3116,7 +3117,7 @@ async def assert_ndata(client, by_addr, total=None): client=True, Worker=Nanny, worker_kwargs={"memory_limit": "1 GiB"}, - config={"distributed.worker.memory.rebalance.sender-min": 0.3}, + config=merge(NO_AMM, {"distributed.worker.memory.rebalance.sender-min": 0.3}), ) async def test_rebalance(c, s, a, b): # We used nannies to have separate processes for each worker @@ -3146,11 +3147,14 @@ async def test_rebalance(c, s, a, b): # Set rebalance() to work predictably on small amounts of managed memory. By default, it # uses optimistic memory, which would only be possible to test by allocating very large # amounts of managed memory, so that they would hide variations in unmanaged memory. 
-REBALANCE_MANAGED_CONFIG = { - "distributed.worker.memory.rebalance.measure": "managed", - "distributed.worker.memory.rebalance.sender-min": 0, - "distributed.worker.memory.rebalance.sender-recipient-gap": 0, -} +REBALANCE_MANAGED_CONFIG = merge( + NO_AMM, + { + "distributed.worker.memory.rebalance.measure": "managed", + "distributed.worker.memory.rebalance.sender-min": 0, + "distributed.worker.memory.rebalance.sender-recipient-gap": 0, + }, +) @gen_cluster(client=True, config=REBALANCE_MANAGED_CONFIG) @@ -3183,14 +3187,14 @@ async def test_rebalance_workers_and_keys(client, s, a, b, c): await s.rebalance(workers=["notexist"]) -@gen_cluster() +@gen_cluster(config=NO_AMM) async def test_rebalance_missing_data1(s, a, b): """key never existed""" out = await s.rebalance(keys=["notexist"]) assert out == {"status": "partial-fail", "keys": ["notexist"]} -@gen_cluster(client=True) +@gen_cluster(client=True, config=NO_AMM) async def test_rebalance_missing_data2(c, s, a, b): """keys exist but belong to unfinished futures. Unlike Client.rebalance(), Scheduler.rebalance() does not wait for unfinished futures. 
@@ -3239,7 +3243,7 @@ async def test_rebalance_no_workers(s): @gen_cluster( client=True, worker_kwargs={"memory_limit": 0}, - config={"distributed.worker.memory.rebalance.measure": "managed"}, + config=merge(NO_AMM, {"distributed.worker.memory.rebalance.measure": "managed"}), ) async def test_rebalance_no_limit(c, s, a, b): futures = await c.scatter(range(100), workers=[a.address]) @@ -3255,11 +3259,14 @@ async def test_rebalance_no_limit(c, s, a, b): client=True, Worker=Nanny, worker_kwargs={"memory_limit": "1000 MiB"}, - config={ - "distributed.worker.memory.rebalance.measure": "managed", - "distributed.worker.memory.rebalance.sender-min": 0.2, - "distributed.worker.memory.rebalance.recipient-max": 0.1, - }, + config=merge( + NO_AMM, + { + "distributed.worker.memory.rebalance.measure": "managed", + "distributed.worker.memory.rebalance.sender-min": 0.2, + "distributed.worker.memory.rebalance.recipient-max": 0.1, + }, + ), ) async def test_rebalance_no_recipients(c, s, a, b): """There are sender workers, but no recipient workers""" @@ -3277,7 +3284,7 @@ async def test_rebalance_no_recipients(c, s, a, b): nthreads=[("", 1)] * 3, client=True, worker_kwargs={"memory_limit": 0}, - config={"distributed.worker.memory.rebalance.measure": "managed"}, + config=merge(NO_AMM, {"distributed.worker.memory.rebalance.measure": "managed"}), ) async def test_rebalance_skip_recipient(client, s, a, b, c): """A recipient is skipped because it already holds a copy of the key to be sent""" @@ -3292,7 +3299,7 @@ async def test_rebalance_skip_recipient(client, s, a, b, c): @gen_cluster( client=True, worker_kwargs={"memory_limit": 0}, - config={"distributed.worker.memory.rebalance.measure": "managed"}, + config=merge(NO_AMM, {"distributed.worker.memory.rebalance.measure": "managed"}), ) async def test_rebalance_skip_all_recipients(c, s, a, b): """All recipients are skipped because they already hold copies""" @@ -3308,7 +3315,7 @@ async def test_rebalance_skip_all_recipients(c, s, a, b): 
client=True, Worker=Nanny, worker_kwargs={"memory_limit": "1000 MiB"}, - config={"distributed.worker.memory.rebalance.measure": "managed"}, + config=merge(NO_AMM, {"distributed.worker.memory.rebalance.measure": "managed"}), ) async def test_rebalance_sender_below_mean(c, s, *_): """A task remains on the sender because moving it would send it below the mean""" @@ -3327,10 +3334,13 @@ async def test_rebalance_sender_below_mean(c, s, *_): client=True, Worker=Nanny, worker_kwargs={"memory_limit": "1000 MiB"}, - config={ - "distributed.worker.memory.rebalance.measure": "managed", - "distributed.worker.memory.rebalance.sender-min": 0.3, - }, + config=merge( + NO_AMM, + { + "distributed.worker.memory.rebalance.measure": "managed", + "distributed.worker.memory.rebalance.sender-min": 0.3, + }, + ), ) async def test_rebalance_least_recently_inserted_sender_min(c, s, *_): """ @@ -3433,7 +3443,7 @@ async def test_gather_on_worker_key_not_on_sender_replicated( assert c.data[x.key] == "x" -@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3, config=NO_AMM) async def test_gather_on_worker_duplicate_task(client, s, a, b, c): """Race condition where the recipient worker receives the same task twice. Test that the task nbytes are not double-counted on the recipient. @@ -3458,7 +3468,10 @@ async def test_gather_on_worker_duplicate_task(client, s, a, b, c): @gen_cluster( - client=True, nthreads=[("127.0.0.1", 1)] * 3, scheduler_kwargs={"timeout": "100ms"} + client=True, + nthreads=[("127.0.0.1", 1)] * 3, + scheduler_kwargs={"timeout": "100ms"}, + config=NO_AMM, ) async def test_rebalance_dead_recipient(client, s, a, b, c): """A key fails to be rebalanced due to recipient failure. 
@@ -3483,7 +3496,7 @@ async def test_rebalance_dead_recipient(client, s, a, b, c): assert await client.has_what() == {a.address: (y.key,), b.address: (x.key,)} -@gen_cluster(client=True) +@gen_cluster(client=True, config=NO_AMM) async def test_delete_worker_data(c, s, a, b): # delete only copy of x # delete one of the copies of y diff --git a/distributed/tests/test_semaphore.py b/distributed/tests/test_semaphore.py index 457d51920a..88a24e18a0 100644 --- a/distributed/tests/test_semaphore.py +++ b/distributed/tests/test_semaphore.py @@ -557,8 +557,9 @@ async def test_release_retry(c, s, a, b): }, ) async def test_release_failure(c, s, a, b): - """Don't raise even if release fails: lease will be cleaned up by the lease-validation after - a specified interval anyways (see config parameters used).""" + """Don't raise even if release fails: lease will be cleaned up by the + lease-validation after a specified interval anyway (see config parameters used). + """ with dask.config.set({"distributed.comm.retry.count": 1}): pool = await FlakyConnectionPool(failing_connections=5) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index cc00b50428..e10bd3f021 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -23,6 +23,8 @@ from distributed.metrics import time from distributed.system import MEMORY_LIMIT from distributed.utils_test import ( + NO_AMM, + BlockedGetData, SizeOf, captured_logger, freeze_batched_send, @@ -88,7 +90,7 @@ async def test_steal_cheap_data_slow_computation(c, s, a, b): @pytest.mark.slow -@gen_cluster(client=True, nthreads=[("", 1)] * 2) +@gen_cluster(client=True, nthreads=[("", 1)] * 2, config=NO_AMM) async def test_steal_expensive_data_slow_computation(c, s, a, b): np = pytest.importorskip("numpy") @@ -105,7 +107,7 @@ async def test_steal_expensive_data_slow_computation(c, s, a, b): assert b.data # not empty -@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10) 
+@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10, config=NO_AMM) async def test_worksteal_many_thieves(c, s, *workers): x = c.submit(slowinc, -1, delay=0.1) await x @@ -300,44 +302,50 @@ def do_nothing(x, y=None): assert len(s.workers[workers[0].address].has_what) == len(xs) + len(futures) -@gen_cluster(client=True) -async def test_dont_steal_fast_tasks_blocklist(c, s, a, b): - # create a dependency - x = c.submit(slowinc, 1, workers=[b.address]) +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_dont_steal_fast_tasks_blocklist(c, s, a): + async with BlockedGetData(s.address) as b: + # create a dependency + x = c.submit(inc, 1, workers=[b.address], key="x") + await wait(x) - # If the blocklist of fast tasks is tracked somewhere else, this needs to be - # changed. This test requires *any* key which is blocked. - from distributed.stealing import fast_tasks + # If the blocklist of fast tasks is tracked somewhere else, this needs to be + # changed. This test requires *any* key which is blocked. + from distributed.stealing import fast_tasks - blocked_key = next(iter(fast_tasks)) - - def fast_blocked(x, y=None): - # The task should observe a certain computation time such that we can - # ensure that it is not stolen due to the blocking. If it is too - # fast, the standard mechanism shouldn't allow stealing - import time + blocked_key = next(iter(fast_tasks)) - time.sleep(0.01) + def fast_blocked(i, x): + # The task should observe a certain computation time such that we can + # ensure that it is not stolen due to the blocking. If it is too + # fast, the standard mechanism shouldn't allow stealing + sleep(0.01) - futures = c.map( - fast_blocked, - range(100), - y=x, - # Submit the task to one worker but allow it to be distributed else, - # i.e. 
this is not a task restriction - workers=[a.address], - allow_other_workers=True, - key=blocked_key, - ) + futures = c.map( + fast_blocked, + range(50), + x=x, + # Submit the task to one worker but allow it to be distributed elsewhere, + # i.e. this is not a task restriction + workers=[a.address], + allow_other_workers=True, + key=blocked_key, + ) - await wait(futures) + while len(s.tasks) < 51: + await asyncio.sleep(0.01) + b.block_get_data.set() + await wait(futures) - # The +1 is the dependency we initially submitted to worker B - assert len(s.workers[a.address].has_what) == 101 - assert len(s.workers[b.address].has_what) == 1 + # Note: x may now be on a, b, or both, depending if the Active Memory Manager + # got to run or not + ws_a = s.workers[a.address] + for ts in s.tasks.values(): + if ts.key.startswith(blocked_key): + assert ts.who_has == {ws_a} -@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)]) +@gen_cluster(client=True, nthreads=[("", 1)], config=NO_AMM) async def test_new_worker_steals(c, s, a): await wait(c.submit(slowinc, 1, delay=0.01)) @@ -346,13 +354,16 @@ async def test_new_worker_steals(c, s, a): while len(a.state.tasks) < 10: await asyncio.sleep(0.01) - async with Worker(s.address, nthreads=1, memory_limit=MEMORY_LIMIT) as b: + async with Worker(s.address, nthreads=1) as b: result = await total assert result == sum(map(inc, range(100))) for w in (a, b): assert all(isinstance(v, int) for v in w.data.values()) + # This requires AMM to be off. Otherwise, if b reports higher optimistic memory + # than a and `total` happens to be computed on a, then all keys on b will be + # replicated onto a and then deleted by the AMM. assert b.data @@ -1115,11 +1126,6 @@ async def test_steal_concurrent_simple(c, s, *workers): assert not ws2.has_what -# FIXME shouldn't consistently fail, may be an actual bug? 
-@pytest.mark.skipif( - math.isfinite(dask.config.get("distributed.scheduler.worker-saturation")), - reason="flaky with queuing active", -) @gen_cluster( client=True, config={ @@ -1142,8 +1148,8 @@ def block(x, event): event = Event() futs1 = [ - c.submit(block, f, event=event, key=f"f1-{ix}") - for f in roots + c.submit(block, r, event=event, key=f"f{ir}-{ix}") + for ir, r in enumerate(roots) for ix in range(4) ] while not w0.state.ready: @@ -1335,9 +1341,9 @@ async def test_correct_bad_time_estimate(c, s, *workers): steal = s.extensions["stealing"] future = c.submit(slowinc, 1, delay=0) await wait(future) - futures = [c.submit(slowinc, future, delay=0.1, pure=False) for i in range(20)] - while not any(f.key in s.tasks for f in futures): - await asyncio.sleep(0.001) + futures = [c.submit(slowinc, future, delay=0.1, pure=False) for _ in range(20)] + while len(s.tasks) < 21: + await asyncio.sleep(0) assert not any(s.tasks[f.key] in steal.key_stealable for f in futures) await asyncio.sleep(0.5) assert any(s.tasks[f.key] in steal.key_stealable for f in futures) diff --git a/distributed/tests/test_stories.py b/distributed/tests/test_stories.py index 36e7739b4e..b9e178c813 100644 --- a/distributed/tests/test_stories.py +++ b/distributed/tests/test_stories.py @@ -6,7 +6,13 @@ from distributed import Worker from distributed.comm import CommClosedError -from distributed.utils_test import assert_story, assert_valid_story, gen_cluster, inc +from distributed.utils_test import ( + NO_AMM, + assert_story, + assert_valid_story, + gen_cluster, + inc, +) @gen_cluster(client=True, nthreads=[("", 1)]) @@ -124,7 +130,7 @@ async def test_client_story_failed_worker(c, s, a, b, on_error): raise ValueError(on_error) -@gen_cluster(client=True) +@gen_cluster(client=True, config=NO_AMM) async def test_worker_story_with_deps(c, s, a, b): """ Assert that the structure of the story does not change unintentionally and diff --git a/distributed/tests/test_stress.py 
b/distributed/tests/test_stress.py index adce31c100..a0a880737c 100644 --- a/distributed/tests/test_stress.py +++ b/distributed/tests/test_stress.py @@ -17,6 +17,7 @@ from distributed.metrics import time from distributed.utils import CancelledError from distributed.utils_test import ( + NO_AMM, bump_rlimit, cluster, gen_cluster, @@ -122,7 +123,11 @@ async def create_and_destroy_worker(delay): assert await c.compute(z) == 8000884.93 -@gen_cluster(nthreads=[("", 1)] * 10, client=True) +@gen_cluster( + nthreads=[("", 1)] * 10, + client=True, + config=NO_AMM, +) async def test_stress_scatter_death(c, s, *workers): s.allowed_failures = 1000 np = pytest.importorskip("numpy") diff --git a/distributed/tests/test_tls_functional.py b/distributed/tests/test_tls_functional.py index b72ca18666..ac7a0a17ab 100644 --- a/distributed/tests/test_tls_functional.py +++ b/distributed/tests/test_tls_functional.py @@ -6,10 +6,13 @@ import asyncio +from tlz import merge + from distributed import Client, Nanny, Queue, Scheduler, Worker, wait, worker_client from distributed.core import Status from distributed.metrics import time from distributed.utils_test import ( + NO_AMM, double, gen_test, gen_tls_cluster, @@ -101,11 +104,14 @@ async def test_nanny(c, s, a, b): @gen_tls_cluster( client=True, - config={ - "distributed.worker.memory.rebalance.measure": "managed", - "distributed.worker.memory.rebalance.sender-min": 0, - "distributed.worker.memory.rebalance.sender-recipient-gap": 0, - }, + config=merge( + NO_AMM, + { + "distributed.worker.memory.rebalance.measure": "managed", + "distributed.worker.memory.rebalance.sender-min": 0, + "distributed.worker.memory.rebalance.sender-recipient-gap": 0, + }, + ), ) async def test_rebalance(c, s, a, b): """Test Client.rebalance(). 
This test is just to test the TLS Client wrapper around diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index 5e71c49900..3c182061e0 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -20,7 +20,7 @@ import psutil import pytest -from tlz import first, pluck, sliding_window +from tlz import first, merge, pluck, sliding_window from tornado.ioloop import IOLoop import dask @@ -49,6 +49,7 @@ from distributed.protocol import pickle from distributed.scheduler import Scheduler from distributed.utils_test import ( + NO_AMM, BlockedExecute, BlockedGatherDep, BlockedGetData, @@ -645,29 +646,20 @@ async def test_inter_worker_communication(c, s, a, b): assert result == 3 -@gen_cluster(client=True) -async def test_clean(c, s, a, b): - x = c.submit(inc, 1, workers=a.address) - y = c.submit(inc, x, workers=b.address) - - await y +@gen_cluster(client=True, nthreads=[("", 1)]) +async def test_clean(c, s, a): + x = c.submit(inc, 1) + await x - collections = [ - a.state.tasks, - a.data, - a.threads, - ] - for c in collections: - assert c + collections = [a.state.tasks, a.data, a.threads] + assert all(collections) x.release() - y.release() while x.key in a.state.tasks: await asyncio.sleep(0.01) - for c in collections: - assert not c + assert not any(collections) @gen_cluster(client=True) @@ -806,7 +798,11 @@ async def test_multiple_transfers(c, s, w1, w2, w3): @pytest.mark.xfail(reason="very high flakiness") -@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) +@gen_cluster( + client=True, + nthreads=[("127.0.0.1", 1)] * 3, + config=NO_AMM, +) async def test_share_communication(c, s, w1, w2, w3): x = c.submit( mul, b"1", int(w3.transfer_message_bytes_limit + 1), workers=w1.address @@ -815,7 +811,7 @@ async def test_share_communication(c, s, w1, w2, w3): mul, b"2", int(w3.transfer_message_bytes_limit + 1), workers=w2.address ) await wait([x, y]) - await c._replicate([x, y], workers=[w1.address, w2.address]) + 
await c.replicate([x, y], workers=[w1.address, w2.address]) z = c.submit(add, x, y, workers=w3.address) await wait(z) assert len(w3.transfer_incoming_log) == 2 @@ -867,7 +863,7 @@ async def test_clean_up_dependencies(c, s, a, b): assert set(a.data) | set(b.data) == {zz.key} -@gen_cluster(client=True) +@gen_cluster(client=True, config=NO_AMM) async def test_hold_onto_dependents(c, s, a, b): x = c.submit(inc, 1, workers=a.address) y = c.submit(inc, x, workers=b.address) @@ -1258,7 +1254,9 @@ async def test_wait_for_outgoing(c, s, a, b): @pytest.mark.skipif(not LINUX, reason="Need 127.0.0.2 to mean localhost") @gen_cluster( - nthreads=[("127.0.0.1", 1), ("127.0.0.1", 1), ("127.0.0.2", 1)], client=True + nthreads=[("127.0.0.1", 1), ("127.0.0.1", 1), ("127.0.0.2", 1)], + client=True, + config=NO_AMM, ) async def test_prefer_gather_from_local_address(c, s, w1, w2, w3): x = await c.scatter(123, workers=[w1.address, w3.address], broadcast=True) @@ -1947,8 +1945,8 @@ def f(x): assert not C.instances -@pytest.mark.slow -@gen_cluster(client=True) +# @pytest.mark.slow +@gen_cluster(client=True, config=NO_AMM) async def test_gather_dep_one_worker_always_busy(c, s, a, b): # Ensure that both dependencies for H are on another worker than H itself. 
# The worker where the dependencies are on is then later blocked such that @@ -1967,10 +1965,7 @@ async def test_gather_dep_one_worker_always_busy(c, s, a, b): h = c.submit(add, f, g, key="h", workers=[b.address]) - while h.key not in b.state.tasks: - await asyncio.sleep(0.01) - - assert b.state.tasks[h.key].state == "waiting" + await wait_for_state(h.key, "waiting", b) assert b.state.tasks[f.key].state in ("flight", "fetch") assert b.state.tasks[g.key].state in ("flight", "fetch") @@ -1999,6 +1994,7 @@ async def test_gather_dep_one_worker_always_busy(c, s, a, b): @gen_cluster( client=True, nthreads=[("127.0.0.1", 1)] * 2 + [("127.0.0.2", 2)] * 10, # type: ignore + config=NO_AMM, ) async def test_gather_dep_local_workers_first(c, s, a, lw, *rws): f = ( @@ -2015,6 +2011,7 @@ async def test_gather_dep_local_workers_first(c, s, a, lw, *rws): @gen_cluster( client=True, nthreads=[("127.0.0.2", 1)] + [("127.0.0.1", 1)] * 10, # type: ignore + config=NO_AMM, ) async def test_gather_dep_from_remote_workers_if_all_local_workers_are_busy( c, s, rw, a, *lws @@ -2244,7 +2241,9 @@ async def test_gpu_executor(c, s, w): assert "gpu" not in w.executors -async def assert_task_states_on_worker(expected, worker): +async def assert_task_states_on_worker( + expected: dict[str, str], worker: Worker +) -> None: active_exc = None for _ in range(10): try: @@ -2270,7 +2269,7 @@ async def assert_task_states_on_worker(expected, worker): raise active_exc -@gen_cluster(client=True) +@gen_cluster(client=True, config=NO_AMM) async def test_worker_state_error_release_error_last(c, s, a, b): """ Create a chain of tasks and err one of them. Then release tasks in a certain @@ -2337,7 +2336,7 @@ def raise_exc(*args): await asyncio.sleep(0.01) -@gen_cluster(client=True) +@gen_cluster(client=True, config=NO_AMM) async def test_worker_state_error_release_error_first(c, s, a, b): """ Create a chain of tasks and err one of them. 
Then release tasks in a certain @@ -2402,7 +2401,7 @@ def raise_exc(*args): await asyncio.sleep(0.01) -@gen_cluster(client=True) +@gen_cluster(client=True, config=NO_AMM) async def test_worker_state_error_release_error_int(c, s, a, b): """ Create a chain of tasks and err one of them. Then release tasks in a certain @@ -2467,7 +2466,7 @@ def raise_exc(*args): await asyncio.sleep(0.01) -@gen_cluster(client=True) +@gen_cluster(client=True, config=NO_AMM) async def test_worker_state_error_long_chain(c, s, a, b): def raise_exc(*args): raise RuntimeError() @@ -2550,7 +2549,7 @@ def raise_exc(*args): await asyncio.sleep(0.01) -@gen_cluster(client=True, nthreads=[("", x) for x in (1, 2, 3, 4)]) +@gen_cluster(client=True, nthreads=[("", x) for x in (1, 2, 3, 4)], config=NO_AMM) async def test_hold_on_to_replicas(c, s, *workers): f1 = c.submit(inc, 1, workers=[workers[0].address], key="f1") f2 = c.submit(inc, 2, workers=[workers[1].address], key="f2") @@ -2673,7 +2672,7 @@ def __reduce__(self): assert "return lambda: 1 / 0, ()" in logvalue -@gen_cluster(client=True) +@gen_cluster(client=True, config=NO_AMM) async def test_acquire_replicas(c, s, a, b): fut = c.submit(inc, 1, workers=[a.address]) await fut @@ -2693,7 +2692,7 @@ async def test_acquire_replicas(c, s, a, b): await asyncio.sleep(0.005) -@gen_cluster(client=True) +@gen_cluster(client=True, config=NO_AMM) async def test_acquire_replicas_same_channel(c, s, a, b): futA = c.submit(inc, 1, workers=[a.address], key="f-A") futB = c.submit(inc, 2, workers=[a.address], key="f-B") @@ -2725,7 +2724,7 @@ async def test_acquire_replicas_same_channel(c, s, a, b): assert any(fut.key in msg["keys"] for msg in b.transfer_incoming_log) -@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3, config=NO_AMM) async def test_acquire_replicas_many(c, s, w1, w2, w3): futs = c.map(inc, range(10), workers=[w1.address]) res = c.submit(sum, futs, workers=[w2.address]) @@ -2755,7 
+2754,7 @@ async def test_acquire_replicas_many(c, s, w1, w2, w3): await asyncio.sleep(0.001) -@gen_cluster(client=True, nthreads=[("", 1)]) +@gen_cluster(client=True, nthreads=[("", 1)], config=NO_AMM) async def test_acquire_replicas_already_in_flight(c, s, a): """Trying to acquire a replica that is already in flight is a no-op""" async with BlockedGatherDep(s.address) as b: @@ -2783,7 +2782,7 @@ async def test_acquire_replicas_already_in_flight(c, s, a): @pytest.mark.slow -@gen_cluster(client=True) +@gen_cluster(client=True, config=NO_AMM) async def test_forget_acquire_replicas(c, s, a, b): """ 1. The scheduler sends acquire-replicas to the worker @@ -2811,7 +2810,7 @@ async def test_forget_acquire_replicas(c, s, a, b): assert "x" not in s.tasks -@gen_cluster(client=True) +@gen_cluster(client=True, config=NO_AMM) async def test_remove_replicas_simple(c, s, a, b): futs = c.map(inc, range(10), workers=[a.address]) await wait(futs) @@ -2839,7 +2838,7 @@ async def test_remove_replicas_simple(c, s, a, b): @gen_cluster( client=True, nthreads=[("", 1), ("", 6)], # Up to 5 threads of b will get stuck; read below - config={"distributed.comm.recent-messages-log-length": 1_000}, + config=merge(NO_AMM, {"distributed.comm.recent-messages-log-length": 1_000}), ) async def test_remove_replicas_while_computing(c, s, a, b): futs = c.map(inc, range(10), workers=[a.address]) @@ -2939,7 +2938,7 @@ def answer_sent(key): await asyncio.sleep(0.01) -@gen_cluster(client=True, nthreads=[("", 1)] * 3) +@gen_cluster(client=True, nthreads=[("", 1)] * 3, config=NO_AMM) async def test_who_has_consistent_remove_replicas(c, s, *workers): a = workers[0] other_workers = {w for w in workers if w != a} @@ -2975,7 +2974,7 @@ async def test_who_has_consistent_remove_replicas(c, s, *workers): assert s.tasks[f1.key].suspicious == 0 -@gen_cluster(client=True) +@gen_cluster(client=True, config=NO_AMM) async def test_acquire_replicas_with_no_priority(c, s, a, b): """Scattered tasks have no priority. 
When they transit to another worker through acquire-replicas, they end up in the Worker.data_needed heap together with tasks @@ -2998,7 +2997,7 @@ async def test_acquire_replicas_with_no_priority(c, s, a, b): assert b.state.tasks["x"].priority is not None -@gen_cluster(client=True, nthreads=[("", 1)]) +@gen_cluster(client=True, nthreads=[("", 1)], config=NO_AMM) async def test_acquire_replicas_large_data(c, s, a): """When acquire-replicas is used to acquire multiple sizeable tasks, it respects transfer_message_bytes_limit and acquires them over multiple iterations. diff --git a/distributed/tests/test_worker_memory.py b/distributed/tests/test_worker_memory.py index 8e27f613af..ec13d69ec8 100644 --- a/distributed/tests/test_worker_memory.py +++ b/distributed/tests/test_worker_memory.py @@ -11,6 +11,7 @@ import psutil import pytest +from tlz import merge import dask.config @@ -20,7 +21,13 @@ from distributed.core import Status from distributed.metrics import monotonic from distributed.spill import has_zict_210 -from distributed.utils_test import captured_logger, gen_cluster, inc, wait_for_state +from distributed.utils_test import ( + NO_AMM, + captured_logger, + gen_cluster, + inc, + wait_for_state, +) from distributed.worker_memory import parse_memory_limit from distributed.worker_state_machine import ( ComputeTaskEvent, @@ -589,11 +596,14 @@ def f(ev): @gen_cluster( client=True, nthreads=[("", 1), ("", 1)], - config={ - "distributed.worker.memory.target": False, - "distributed.worker.memory.spill": False, - "distributed.worker.memory.pause": False, - }, + config=merge( + NO_AMM, + { + "distributed.worker.memory.target": False, + "distributed.worker.memory.spill": False, + "distributed.worker.memory.pause": False, + }, + ), ) async def test_pause_prevents_deps_fetch(c, s, a, b): """A worker is paused while there are dependencies ready to fetch, but all other diff --git a/distributed/tests/test_worker_state_machine.py b/distributed/tests/test_worker_state_machine.py 
index 9b46a14b6e..5d7cd35221 100644 --- a/distributed/tests/test_worker_state_machine.py +++ b/distributed/tests/test_worker_state_machine.py @@ -17,6 +17,7 @@ from distributed.scheduler import TaskState as SchedulerTaskState from distributed.utils import recursive_to_dict from distributed.utils_test import ( + NO_AMM, _LockedCommPool, assert_story, freeze_data_fetching, @@ -598,7 +599,11 @@ async def test_fetch_via_amm_to_compute(c, s, a, b): @pytest.mark.parametrize("as_deps", [False, True]) -@gen_cluster(client=True, nthreads=[("", 1)] * 3) +@gen_cluster( + client=True, + nthreads=[("", 1)] * 3, + config=NO_AMM, +) async def test_lose_replica_during_fetch(c, s, w1, w2, w3, as_deps): """ as_deps=True @@ -779,7 +784,7 @@ async def test_cancelled_while_in_flight(c, s, a, b): await asyncio.sleep(0.01) -@gen_cluster(client=True) +@gen_cluster(client=True, config=NO_AMM) async def test_in_memory_while_in_flight(c, s, a, b): """ 1. A client scatters x to a diff --git a/distributed/utils_test.py b/distributed/utils_test.py index 79b1eeab27..8a92f1db1d 100644 --- a/distributed/utils_test.py +++ b/distributed/utils_test.py @@ -112,6 +112,11 @@ def is_debugging() -> bool: _TEST_TIMEOUT = 30 _offload_executor.submit(lambda: None).result() # create thread during import +# Dask configuration to completely disable the Active Memory Manager. +# This is typically used with @gen_cluster(config=NO_AMM) +# or @gen_cluster(config=merge(NO_AMM, { 5: + if time() - start > 5: # pragma: nocover raise Exception("Timeout on cluster creation") _run_and_close_tornado(wait_for_workers) diff --git a/docs/source/active_memory_manager.rst b/docs/source/active_memory_manager.rst index cf5416b33a..0201e2424d 100644 --- a/docs/source/active_memory_manager.rst +++ b/docs/source/active_memory_manager.rst @@ -27,7 +27,8 @@ causes increased overall memory usage across the cluster. 
Enabling the Active Memory Manager ---------------------------------- -The AMM can be enabled through the :doc:`Dask configuration file `: +The AMM is enabled by default. It can be disabled or tweaked through the :doc:`Dask +configuration file `: .. code-block:: yaml @@ -96,6 +97,9 @@ config and see if it is fit for purpose for you before you tweak individual poli Built-in policies ----------------- + +.. _ReduceReplicas: + ReduceReplicas ++++++++++++++ class @@ -114,6 +118,30 @@ computation, this policy drops all excess replicas. run this policy, it will delete all replicas but one (but not necessarily the new ones). +RetireWorker +++++++++++++ +class + :class:`distributed.active_memory_manager.RetireWorker` +parameters + address : str + The address of the worker being retired. + +This is a special policy, which should never appear in the Dask configuration file. + +It is injected on the fly by :meth:`distributed.Client.retire_workers` and whenever +an adaptive cluster is being scaled down. +This policy supervises moving all tasks, that are in memory exclusively on the worker +being retired, to different workers. Once the worker does not uniquely hold the data for +any task, this policy uninstalls itself automatically from the Active Memory Manager and +the worker is shut down. + +If multiple workers are being retired at the same time, there will be multiple instances +of this policy installed in the AMM. + +If the Active Memory Manager is disabled, :meth:`distributed.Client.retire_workers` and +adaptive scaling will start a temporary one, install this policy into it, and then shut +it down once it's finished. + Custom policies --------------- @@ -128,7 +156,8 @@ define two methods: ``run`` This method accepts no parameters and is invoked by the AMM every 2 seconds (or whatever the AMM interval is). 
- It must yield zero or more of the following :class:`~distributed.active_memory_manager.Suggestion` namedtuples: + It must yield zero or more of the following + :class:`~distributed.active_memory_manager.Suggestion` namedtuples: ``yield Suggestion("replicate", <TaskState>)`` Create one replica of the target task on the worker with the lowest memory usage diff --git a/docs/source/resilience.rst b/docs/source/resilience.rst index 1936d7ee99..c7219e51a7 100644 --- a/docs/source/resilience.rst +++ b/docs/source/resilience.rst @@ -48,12 +48,13 @@ This has some fail cases. causes a segmentation fault, then that bad function will repeatedly be called on other workers. This function will be marked as "bad" after it kills a fixed number of workers (defaults to three). -3. Data sent out directly to the workers via a call to ``scatter()`` (instead - of being created from a Dask task graph via other Dask functions) is not - kept in the scheduler, as it is often quite large, and so the loss of this - data is irreparable. You may wish to call ``Client.replicate`` on the data - with a suitable replication factor to ensure that it remains long-lived or - else back the data off of some resilient store, like a file system. +3. Data sent out directly to the workers via a call to + :meth:`~distributed.client.Client.scatter` (instead of being created from a Dask + task graph via other Dask functions) is not kept in the scheduler, as it is often + quite large, and so the loss of this data is irreparable. You may wish to call + :meth:`~distributed.client.Client.replicate` on the data with a suitable replication + factor to ensure that it remains long-lived or else back the data off on some + resilient store, like a file system. 
Hardware Failures From 4052350a33cafd98a02c99e388dbf70f3401bf60 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 29 Sep 2022 18:03:18 +0200 Subject: [PATCH 07/11] Smarter stealing with dependencies (#7024) --- distributed/stealing.py | 9 +- distributed/tests/test_steal.py | 494 ++++++++++++++++++++++++++++++-- 2 files changed, 480 insertions(+), 23 deletions(-) diff --git a/distributed/stealing.py b/distributed/stealing.py index cdbcce30c4..2d0917710a 100644 --- a/distributed/stealing.py +++ b/distributed/stealing.py @@ -4,6 +4,7 @@ import logging from collections import defaultdict, deque from collections.abc import Container +from functools import partial from math import log2 from time import time from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast @@ -522,12 +523,12 @@ def _get_thief( ) -> WorkerState | None: valid_workers = scheduler.valid_workers(ts) if valid_workers is not None: - subset = potential_thieves & valid_workers - if subset: - return next(iter(subset)) + valid_thieves = potential_thieves & valid_workers + if valid_thieves: + potential_thieves = valid_thieves elif not ts.loose_restrictions: return None - return next(iter(potential_thieves)) + return min(potential_thieves, key=partial(scheduler.worker_objective, ts)) fast_tasks = {"split-shuffle"} diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index e10bd3f021..fd42f5bae8 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -7,17 +7,30 @@ import math import random import weakref +from collections import defaultdict from operator import mul from time import sleep +from typing import Callable, Iterable, Mapping, Sequence import numpy as np import pytest -from tlz import sliding_window +from tlz import merge, sliding_window import dask from dask.utils import key_split -from distributed import Event, Lock, Nanny, Worker, profile, wait, worker_client +from distributed import ( + Client, + Event, + Lock, + Nanny, + 
Scheduler, + Worker, + profile, + wait, + worker_client, +) +from distributed.client import Future from distributed.compatibility import LINUX from distributed.core import Status from distributed.metrics import time @@ -25,7 +38,6 @@ from distributed.utils_test import ( NO_AMM, BlockedGetData, - SizeOf, captured_logger, freeze_batched_send, gen_cluster, @@ -50,6 +62,11 @@ teardown_module = nodebug_teardown_module +@pytest.fixture(params=[True, False]) +def recompute_saturation(request): + yield request.param + + @gen_cluster(client=True, nthreads=[("", 2), ("", 2)]) async def test_work_stealing(c, s, a, b): [x] = await c._scatter([1], workers=a.address) @@ -664,7 +681,7 @@ def block(*args, event, **kwargs): for t in sorted(ts, reverse=True): if t: [dat] = await c.scatter( - [SizeOf(int(t * s.bandwidth))], workers=w.address + [gen_nbytes(int(t * s.bandwidth))], workers=w.address ) else: dat = 123 @@ -710,24 +727,37 @@ def block(*args, event, **kwargs): raise Exception(f"Expected: {expected2}; got: {result2}") -@pytest.mark.parametrize("recompute_saturation", [True, False]) @pytest.mark.parametrize( "inp,expected", [ - ([[1], []], [[1], []]), # don't move unnecessarily - ([[0, 0], []], [[0], [0]]), # balance - ([[0.1, 0.1], []], [[0], [0]]), # balance even if results in even - ([[0, 0, 0], []], [[0, 0], [0]]), # don't over balance - ([[0, 0], [0, 0, 0], []], [[0, 0], [0, 0], [0]]), # move from larger - ([[0, 0, 0], [0], []], [[0, 0], [0], [0]]), # move to smaller - ([[0, 1], []], [[1], [0]]), # choose easier first - ([[0, 0, 0, 0], [], []], [[0, 0], [0], [0]]), # spread evenly - ([[1, 0, 2, 0], [], []], [[2, 1], [0], [0]]), # move easier - ([[1, 1, 1], []], [[1, 1], [1]]), # be willing to move costly items - ([[1, 1, 1, 1], []], [[1, 1, 1], [1]]), # but don't move too many - ( - [[0, 0], [0, 0], [0, 0], []], # no one clearly saturated + pytest.param([[1], []], [[1], []], id="don't move unnecessarily"), + pytest.param([[0, 0], []], [[0], [0]], id="balance"), + 
pytest.param( + [[0, 0, 0, 0, 0, 0, 0, 0], []], + [[0, 0, 0, 0, 0, 0], [0, 0]], + id="balance until none idle", + ), + pytest.param( + [[0.1, 0.1], []], [[0], [0]], id="balance even if results in even" + ), + pytest.param([[0, 0, 0], []], [[0, 0], [0]], id="don't over balance"), + pytest.param( + [[0, 0], [0, 0, 0], []], [[0, 0], [0, 0], [0]], id="move from larger" + ), + pytest.param([[0, 0, 0], [0], []], [[0, 0], [0], [0]], id="move to smaller"), + pytest.param([[0, 1], []], [[1], [0]], id="choose easier first"), + pytest.param([[0, 0, 0, 0], [], []], [[0, 0], [0], [0]], id="spread evenly"), + pytest.param([[1, 0, 2, 0], [], []], [[2, 1], [0], [0]], id="move easier"), + pytest.param( + [[1, 1, 1], []], [[1, 1], [1]], id="be willing to move costly items" + ), + pytest.param( + [[1, 1, 1, 1], []], [[1, 1, 1], [1]], id="but don't move too many" + ), + pytest.param( + [[0, 0], [0, 0], [0, 0], []], [[0, 0], [0, 0], [0], [0]], + id="no one clearly saturated", ), # NOTE: There is a timing issue that workers may already start executing # tasks before we call balance, i.e. 
the workers will reject the @@ -735,9 +765,10 @@ def block(*args, event, **kwargs): # Particularly tests with many input tasks are more likely to fail since # the test setup takes longer and allows the workers more time to # schedule a task on the threadpool - ( + pytest.param( [[4, 2, 2, 2, 2, 1, 1], [4, 2, 1, 1], [], [], []], [[4, 2, 2, 2], [4, 2, 1, 1], [2], [1], [1]], + id="balance multiple saturated workers", ), ], ) @@ -1418,3 +1449,428 @@ def func(*args): ideal = ntasks / len(workers) assert (ntasks_per_worker > ideal * 0.5).all(), (ideal, ntasks_per_worker) assert (ntasks_per_worker < ideal * 1.5).all(), (ideal, ntasks_per_worker) + + +def test_balance_even_with_replica(recompute_saturation): + dependencies = {"a": 1} + dependency_placement = [["a"], ["a"]] + task_placement = [[["a"], ["a"]], []] + + def _correct_placement(actual): + actual_task_counts = [len(placed) for placed in actual] + return actual_task_counts == [ + 1, + 1, + ] + + _run_dependency_balance_test( + dependencies, + dependency_placement, + task_placement, + _correct_placement, + recompute_saturation, + ) + + +def test_balance_to_replica(recompute_saturation): + dependencies = {"a": 2} + dependency_placement = [["a"], ["a"], []] + task_placement = [[["a"], ["a"]], [], []] + + def _correct_placement(actual): + actual_task_counts = [len(placed) for placed in actual] + return actual_task_counts == [ + 1, + 1, + 0, + ] + + _run_dependency_balance_test( + dependencies, + dependency_placement, + task_placement, + _correct_placement, + recompute_saturation, + ) + + +def test_balance_multiple_to_replica(recompute_saturation): + dependencies = {"a": 6} + dependency_placement = [["a"], ["a"], []] + task_placement = [[["a"], ["a"], ["a"], ["a"], ["a"], ["a"], ["a"], ["a"]], [], []] + + def _correct_placement(actual): + actual_task_counts = [len(placed) for placed in actual] + # FIXME: A better task placement would be even but the current balancing + # logic aborts as soon as a worker is no longer 
classified as idle + # return actual_task_counts == [ + # 4, + # 4, + # 0, + # ] + return actual_task_counts == [ + 6, + 2, + 0, + ] + + _run_dependency_balance_test( + dependencies, + dependency_placement, + task_placement, + _correct_placement, + recompute_saturation, + ) + + +def test_balance_to_larger_dependency(recompute_saturation): + dependencies = {"a": 2, "b": 1} + dependency_placement = [["a", "b"], ["a"], ["b"]] + task_placement = [[["a", "b"], ["a", "b"], ["a", "b"]], [], []] + + def _correct_placement(actual): + actual_task_counts = [len(placed) for placed in actual] + return actual_task_counts == [ + 2, + 1, + 0, + ] + + _run_dependency_balance_test( + dependencies, + dependency_placement, + task_placement, + _correct_placement, + recompute_saturation, + ) + + +def test_balance_prefers_busier_with_dependency(): + recompute_saturation = True + dependencies = {"a": 5, "b": 1} + dependency_placement = [["a"], ["a", "b"], []] + task_placement = [ + [["a"], ["a"], ["a"], ["a"], ["a"], ["a"]], + [["b"]], + [], + ] + + def _correct_placement(actual): + actual_task_placements = [sorted(placed) for placed in actual] + # FIXME: A better task placement would be even but the current balancing + # logic aborts as soon as a worker is no longer classified as idle + # return actual_task_placements == [ + # [["a"], ["a"], ["a"], ["a"]], + # [["a"], ["a"], ["b"]], + # [], + # ] + return actual_task_placements == [ + [["a"], ["a"], ["a"], ["a"], ["a"]], + [["a"], ["b"]], + [], + ] + + _run_dependency_balance_test( + dependencies, + dependency_placement, + task_placement, + _correct_placement, + recompute_saturation, + # This test relies on disabling queueing to flag workers as idle + config={ + "distributed.scheduler.worker-saturation": float("inf"), + }, + ) + + +def _run_dependency_balance_test( + dependencies: Mapping[str, int], + dependency_placement: list[list[str]], + task_placement: list[list[list[str]]], + correct_placement_fn: Callable[[list[list[list[str]]]], 
bool], + recompute_saturation: bool, + config: dict | None = None, +) -> None: + """Run a test for balancing with task dependencies according to the provided + specifications. + + This method executes the test logic for all permutations of worker placements + and generates a new cluster for each one. + + Parameters + ---------- + dependencies + Mapping of task dependencies to their weight. + dependency_placement + List of list of dependencies to be placed on the worker corresponding + to the index of the outer list. + task_placement + List of list of tasks to be placed on the worker corresponding to the + index of the outer list. Each task is a list of names of dependencies. + correct_placement_fn + Callable used to determine if stealing placed the tasks as expected. + recompute_saturation + Whether to recompute worker saturation before stealing. + config + Optional configuration to apply to the test. + See Also + -------- + _dependency_balance_test_permutation + """ + nworkers = len(task_placement) + for permutation in itertools.permutations(range(nworkers)): + + async def _run( + *args, + permutation=permutation, + **kwargs, + ): + await _dependency_balance_test_permutation( + dependencies, + dependency_placement, + task_placement, + correct_placement_fn, + recompute_saturation, + permutation, + *args, + **kwargs, + ) + + gen_cluster( + client=True, + nthreads=[("", 1)] * len(task_placement), + config=merge( + config or {}, + { + "distributed.scheduler.unknown-task-duration": "1s", + }, + ), + )(_run)() + + +async def _dependency_balance_test_permutation( + dependencies: Mapping[str, int], + dependency_placement: list[list[str]], + task_placement: list[list[list[str]]], + correct_placement_fn: Callable[[list[list[list[str]]]], bool], + recompute_saturation: bool, + permutation: list[int], + c: Client, + s: Scheduler, + *workers: Worker, +) -> None: + """Run a test for balancing with task dependencies according to the provided + specifications and worker 
permutations. + + Parameters + ---------- + dependencies + Mapping of task dependencies to their weight. + dependency_placement + List of list of dependencies to be placed on the worker corresponding + to the index of the outer list. + task_placement + List of list of tasks to be placed on the worker corresponding to the + index of the outer list. Each task is a list of names of dependencies. + correct_placement_fn + Callable used to determine if stealing placed the tasks as expected. + recompute_saturation + Whether to recompute worker saturation before stealing. + permutation + Permutation of workers to use for this run. + + See Also + -------- + _run_dependency_balance_test + """ + steal = s.extensions["stealing"] + await steal.stop() + + inverse = [permutation.index(i) for i in range(len(permutation))] + permutated_dependency_placement = [dependency_placement[i] for i in permutation] + permutated_task_placement = [task_placement[i] for i in permutation] + + dependency_futures = await _place_dependencies( + dependencies, permutated_dependency_placement, c, s, workers + ) + + ev, futures = await _place_tasks( + permutated_task_placement, + permutated_dependency_placement, + dependency_futures, + c, + s, + workers, + ) + + if recompute_saturation: + for ws in s.workers.values(): + s._reevaluate_occupancy_worker(ws) + try: + for _ in range(20): + steal.balance() + await steal.stop() + + permutated_actual_placement = _get_task_placement(s, workers) + actual_placement = [permutated_actual_placement[i] for i in inverse] + + if correct_placement_fn(actual_placement): + return + finally: + # Release the threadpools + await ev.set() + await c.gather(futures) + + raise AssertionError(actual_placement, permutation) + + +async def _place_dependencies( + dependencies: Mapping[str, int], + placement: list[list[str]], + c: Client, + s: Scheduler, + workers: Sequence[Worker], +) -> dict[str, Future]: + """Places the dependencies on the workers as specified. 
+ + Parameters + ---------- + dependencies + Mapping of task dependencies to their weight. + placement + List of list of dependencies to be placed on the worker corresponding to the + index of the outer list. + + Returns + ------- + Dictionary of futures matching the input dependencies. + + See Also + -------- + _run_dependency_balance_test + """ + dependencies_to_workers = defaultdict(set) + for worker_idx, placed in enumerate(placement): + for dependency in placed: + dependencies_to_workers[dependency].add(workers[worker_idx].address) + + futures = {} + for name, multiplier in dependencies.items(): + worker_addresses = dependencies_to_workers[name] + futs = await c.scatter( + {name: gen_nbytes(int(multiplier * s.bandwidth))}, + workers=worker_addresses, + broadcast=True, + ) + futures[name] = futs[name] + + await c.gather(futures.values()) + + _assert_dependency_placement(placement, workers) + + return futures + + +def _assert_dependency_placement(expected, workers): + """Assert that dependencies are placed on the workers as expected.""" + actual = [] + for worker in workers: + actual.append(list(worker.state.tasks.keys())) + + assert actual == expected + + +async def _place_tasks( + placement: list[list[list[str]]], + dependency_placement: list[list[str]], + dependency_futures: Mapping[str, Future], + c: Client, + s: Scheduler, + workers: Sequence[Worker], +) -> tuple[Event, list[Future]]: + """Places the tasks on the workers as specified. + + Parameters + ---------- + placement + List of list of tasks to be placed on the worker corresponding to the + index of the outer list. Each task is a list of names of dependencies. + dependency_placement + List of list of dependencies to be placed on the worker corresponding to the + index of the outer list. + dependency_futures + Mapping of dependency names to their corresponding futures. + + Returns + ------- + Tuple of the event blocking the placed tasks and list of futures matching + the input task placement. 
+ + See Also + -------- + _run_dependency_balance_test + """ + ev = Event() + + def block(*args, event, **kwargs): + event.wait() + + counter = itertools.count() + futures = [] + for worker_idx, tasks in enumerate(placement): + for dependencies in tasks: + i = next(counter) + dep_key = "".join(sorted(dependencies)) + key = f"{dep_key}-{i}" + f = c.submit( + block, + [dependency_futures[dependency] for dependency in dependencies], + event=ev, + key=key, + workers=workers[worker_idx].address, + allow_other_workers=True, + pure=False, + priority=-i, + ) + futures.append(f) + + while len([ts for ts in s.tasks.values() if ts.processing_on]) < len(futures): + await asyncio.sleep(0.001) + + while any( + len(w.state.tasks) < (len(tasks) + len(dependencies)) + for w, dependencies, tasks in zip(workers, dependency_placement, placement) + ): + await asyncio.sleep(0.001) + + assert_task_placement(placement, s, workers) + + return ev, futures + + +def _get_task_placement( + s: Scheduler, workers: Iterable[Worker] +) -> list[list[list[str]]]: + """Return the placement of tasks on this worker""" + actual = [] + for w in workers: + actual.append( + [list(key_split(ts.key)) for ts in s.workers[w.address].processing] + ) + return _deterministic_placement(actual) + + +def _equal_placement(left, right): + """Return True IFF the two input placements are equal.""" + return _deterministic_placement(left) == _deterministic_placement(right) + + +def _deterministic_placement(placement): + """Return a deterministic ordering of the tasks or dependencies on each worker.""" + return [sorted(placed) for placed in placement] + + +def assert_task_placement(expected, s, workers): + """Assert that tasks are placed on the workers as expected.""" + actual = _get_task_placement(s, workers) + assert _equal_placement(actual, expected) From ee433094f766097a958894a98629177524226fed Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Thu, 29 Sep 2022 19:17:21 +0200 Subject: [PATCH 08/11] Remove failing test 
case (#7087) --- distributed/tests/test_steal.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py index fd42f5bae8..efb4892758 100644 --- a/distributed/tests/test_steal.py +++ b/distributed/tests/test_steal.py @@ -732,11 +732,6 @@ def block(*args, event, **kwargs): [ pytest.param([[1], []], [[1], []], id="don't move unnecessarily"), pytest.param([[0, 0], []], [[0], [0]], id="balance"), - pytest.param( - [[0, 0, 0, 0, 0, 0, 0, 0], []], - [[0, 0, 0, 0, 0, 0], [0, 0]], - id="balance until none idle", - ), pytest.param( [[0.1, 0.1], []], [[0], [0]], id="balance even if results in even" ), From 75b966b5bd70f68df793443ccf36b26fedec3412 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 29 Sep 2022 21:10:18 +0100 Subject: [PATCH 09/11] dask-worker-space (#7054) --- distributed/diskutils.py | 42 +++++++++++++++++++++----- distributed/nanny.py | 7 +++-- distributed/pytest_resourceleaks.py | 3 +- distributed/tests/test_asyncprocess.py | 3 +- distributed/tests/test_diskutils.py | 18 +++++++++++ distributed/tests/test_nanny.py | 15 +++++++++ distributed/worker.py | 4 +-- 7 files changed, 76 insertions(+), 16 deletions(-) diff --git a/distributed/diskutils.py b/distributed/diskutils.py index a85fa21fa0..d54d97b701 100644 --- a/distributed/diskutils.py +++ b/distributed/diskutils.py @@ -5,6 +5,7 @@ import os import shutil import stat +import sys import tempfile import weakref from typing import ClassVar @@ -115,21 +116,46 @@ class WorkSpace: this will be detected and the directories purged. 
""" + base_dir: str + _global_lock_path: str + _purge_lock_path: str + # Keep track of all locks known to this process, to avoid several # WorkSpaces to step on each other's toes _known_locks: ClassVar[set[str]] = set() - def __init__(self, base_dir): - self.base_dir = os.path.abspath(base_dir) - self._init_workspace() + def __init__(self, base_dir: str): + self.base_dir = self._init_workspace(base_dir) self._global_lock_path = os.path.join(self.base_dir, "global.lock") self._purge_lock_path = os.path.join(self.base_dir, "purge.lock") - def _init_workspace(self): - try: - os.mkdir(self.base_dir) - except FileExistsError: - pass + def _init_workspace(self, base_dir: str) -> str: + """Create base_dir if it doesn't exist. + If base_dir already exists but it's not writeable, change the name. + """ + base_dir = os.path.abspath(base_dir) + try_dirs = [base_dir] + # Note: can't use WINDOWS constant as it upsets mypy + if sys.platform != "win32": + # - os.getlogin() raises OSError on containerized environments + # - os.getuid() does not exist in Windows + try_dirs.append(f"{base_dir}-{os.getuid()}") + + for try_dir in try_dirs: + try: + os.makedirs(try_dir) + except FileExistsError: + try: + with tempfile.TemporaryFile(dir=try_dir): + pass + except PermissionError: + continue + return try_dir + + # If we reached this, we're likely in a containerized environment where /tmp + # has been shared between containers through a mountpoint, every container + # has an external $UID, but the internal one is the same for all. 
+ return tempfile.mkdtemp(prefix=base_dir + "-") def _global_lock(self, **kwargs): return locket.lock_file(self._global_lock_path, **kwargs) diff --git a/distributed/nanny.py b/distributed/nanny.py index e0ad58c343..43523694cc 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -37,6 +37,7 @@ error_message, ) from distributed.diagnostics.plugin import _get_plugin_name +from distributed.diskutils import WorkSpace from distributed.metrics import time from distributed.node import ServerNode from distributed.process import AsyncProcess @@ -179,9 +180,9 @@ def __init__( # type: ignore[no-untyped-def] else: self._original_local_dir = local_directory - self.local_directory = local_directory - if not os.path.exists(self.local_directory): - os.makedirs(self.local_directory, exist_ok=True) + # Create directory if it doesn't exist and test for write access. + # In case of PermissionError, change the name. + self.local_directory = WorkSpace(local_directory).base_dir self.preload = preload if self.preload is None: diff --git a/distributed/pytest_resourceleaks.py b/distributed/pytest_resourceleaks.py index aa5443583d..6352eb45fd 100644 --- a/distributed/pytest_resourceleaks.py +++ b/distributed/pytest_resourceleaks.py @@ -158,7 +158,8 @@ def measure(self) -> int: if sys.platform == "win32": # Don't use num_handles(); you'll get tens of thousands of reported leaks return 0 - return psutil.Process().num_fds() + else: + return psutil.Process().num_fds() def has_leak(self, before: int, after: int) -> bool: return after > before diff --git a/distributed/tests/test_asyncprocess.py b/distributed/tests/test_asyncprocess.py index 2b076a432c..e2af214f09 100644 --- a/distributed/tests/test_asyncprocess.py +++ b/distributed/tests/test_asyncprocess.py @@ -163,7 +163,8 @@ async def test_exitcode(): def assert_exit_code(proc: AsyncProcess, expect: signal.Signals) -> None: - if WINDOWS: + # Note: can't use WINDOWS constant as it upsets mypy + if sys.platform == "win32": # 
multiprocessing.Process.terminate() sets exit code -15 like in Linux, but # os.kill(pid, signal.SIGTERM) sets exit code +15 assert proc.exitcode in (-expect, expect) diff --git a/distributed/tests/test_diskutils.py b/distributed/tests/test_diskutils.py index ccf30507a5..95eb6eb85e 100644 --- a/distributed/tests/test_diskutils.py +++ b/distributed/tests/test_diskutils.py @@ -286,3 +286,21 @@ def test_workspace_concurrency(tmpdir): # We attempted to purge most directories at some point assert n_purged >= 0.5 * n_created > 0 + + +@pytest.mark.skipif(WINDOWS, reason="Need POSIX filesystem permissions and UIDs") +def test_unwritable_base_dir(tmpdir): + os.mkdir(f"{tmpdir}/bad", mode=0o500) + with pytest.raises(PermissionError): + open(f"{tmpdir}/bad/tryme", "w") + + ws = WorkSpace(f"{tmpdir}/bad") + assert ws.base_dir == f"{tmpdir}/bad-{os.getuid()}" + + os.chmod(f"{tmpdir}/bad-{os.getuid()}", 0o500) + with pytest.raises(PermissionError): + open(f"{tmpdir}/bad-{os.getuid()}/tryme", "w") + + ws = WorkSpace(f"{tmpdir}/bad") + assert ws.base_dir.startswith(f"{tmpdir}/bad-") + assert ws.base_dir != f"{tmpdir}/bad-{os.getuid()}" diff --git a/distributed/tests/test_nanny.py b/distributed/tests/test_nanny.py index 0e14a6fc8c..1de2d758a1 100644 --- a/distributed/tests/test_nanny.py +++ b/distributed/tests/test_nanny.py @@ -369,6 +369,21 @@ async def test_local_directory(s): assert n.process.worker_dir.count("dask-worker-space") == 1 +@pytest.mark.skipif(WINDOWS, reason="Need POSIX filesystem permissions and UIDs") +@gen_cluster(nthreads=[]) +async def test_unwriteable_dask_worker_space(s, tmpdir): + os.mkdir(f"{tmpdir}/dask-worker-space", mode=0o500) + with pytest.raises(PermissionError): + open(f"{tmpdir}/dask-worker-space/tryme", "w") + + with dask.config.set(temporary_directory=tmpdir): + async with Nanny(s.address) as n: + assert n.local_directory == os.path.join( + tmpdir, f"dask-worker-space-{os.getuid()}" + ) + assert 
n.process.worker_dir.count(f"dask-worker-space-{os.getuid()}") == 1 + + def _noop(x): """Define here because closures aren't pickleable.""" pass diff --git a/distributed/worker.py b/distributed/worker.py index 3c3a16b5b6..9f281a00d3 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -569,8 +569,6 @@ def __init__( local_directory = ( dask.config.get("temporary-directory") or tempfile.gettempdir() ) - - os.makedirs(local_directory, exist_ok=True) local_directory = os.path.join(local_directory, "dask-worker-space") with warn_on_duration( @@ -580,7 +578,7 @@ def __init__( "Consider specifying a local-directory to point workers to write " "scratch data to a local disk.", ): - self._workspace = WorkSpace(os.path.abspath(local_directory)) + self._workspace = WorkSpace(local_directory) self._workdir = self._workspace.new_work_dir(prefix="worker-") self.local_directory = self._workdir.dir_path From 5bfa08c71faeaccc1f036bed92b151033511e902 Mon Sep 17 00:00:00 2001 From: jakirkham Date: Thu, 29 Sep 2022 20:11:53 -0700 Subject: [PATCH 10/11] Type platform constants for mypy (#7091) --- distributed/compatibility.py | 6 +++--- distributed/tests/test_asyncprocess.py | 3 +-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/distributed/compatibility.py b/distributed/compatibility.py index 997236106f..6957e8732d 100644 --- a/distributed/compatibility.py +++ b/distributed/compatibility.py @@ -7,9 +7,9 @@ logging_names.update(logging._levelToName) # type: ignore logging_names.update(logging._nameToLevel) # type: ignore -LINUX = sys.platform == "linux" -MACOS = sys.platform == "darwin" -WINDOWS = sys.platform == "win32" +LINUX: bool = sys.platform == "linux" +MACOS: bool = sys.platform == "darwin" +WINDOWS: bool = sys.platform == "win32" if sys.version_info >= (3, 9): diff --git a/distributed/tests/test_asyncprocess.py b/distributed/tests/test_asyncprocess.py index e2af214f09..2b076a432c 100644 --- a/distributed/tests/test_asyncprocess.py +++ 
b/distributed/tests/test_asyncprocess.py @@ -163,8 +163,7 @@ async def test_exitcode(): def assert_exit_code(proc: AsyncProcess, expect: signal.Signals) -> None: - # Note: can't use WINDOWS constant as it upsets mypy - if sys.platform == "win32": + if WINDOWS: # multiprocessing.Process.terminate() sets exit code -15 like in Linux, but # os.kill(pid, signal.SIGTERM) sets exit code +15 assert proc.exitcode in (-expect, expect) From 68e5a6a33539228c5f57e1882e1794c1fb25e25a Mon Sep 17 00:00:00 2001 From: Graham Markall <535640+gmarkall@users.noreply.github.com> Date: Fri, 30 Sep 2022 13:36:32 +0100 Subject: [PATCH 11/11] test_serialize_numba: Workaround issue with np.empty_like in NP 1.23 (#7089) In NumPy 1.23, the strides of empty arrays are 0 instead of the item size, due to https://github.com/numpy/numpy/pull/21477 - however, `np.empty_like` seems to create non-zero-strided arrays from a zero-strided empty array, and copying to the host from a device array with zero strides fails a compatibility check in Numba. This commit works around the issue by calling `copy_to_host()` with no arguments, allowing Numba to create an array on the host that is compatible with the device array - the resulting implementation is functionally equivalent and slightly simpler, so I believe this change could remain permanent rather than requiring a revert later. 
--- distributed/protocol/tests/test_numba.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/distributed/protocol/tests/test_numba.py b/distributed/protocol/tests/test_numba.py index f145da3b41..b1b05d2e0b 100644 --- a/distributed/protocol/tests/test_numba.py +++ b/distributed/protocol/tests/test_numba.py @@ -31,10 +31,8 @@ def test_serialize_numba(shape, dtype, order, serializers): elif serializers[0] == "dask": assert all(isinstance(f, memoryview) for f in frames) - hx = np.empty_like(ary) - hy = np.empty_like(ary) - x.copy_to_host(hx) - y.copy_to_host(hy) + hx = x.copy_to_host() + hy = y.copy_to_host() assert (hx == hy).all()