From 9d4f0bf2fc804f955a869febd3b51423c4382908 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 1 Jun 2021 15:18:04 +0100 Subject: [PATCH] O(1) rebalance (#4774) * partial prototype * incomplete poc * poc (incomplete) * complete POC * polish * polish * bugfix * fixes * fix * Use arbitrary measure in rebalance * Code review * renames * suggest tweaking malloc_trim * self-review * test_tls_functional * test_memory to use gen_cluster * test_memory to use gen_cluster * half memory * tests * tests * tests * tests * make Cython happy * test_rebalance_managed_memory * tests * robustness * improve test stability * tests stability * trivial * reload dask.config on Scheduler.__init__ * code review --- distributed/client.py | 9 +- distributed/distributed-schema.yaml | 42 +- distributed/distributed.yaml | 38 +- distributed/scheduler.py | 546 ++++++++++++++++------- distributed/tests/test_client.py | 162 ++++--- distributed/tests/test_scheduler.py | 287 +++++++++++- distributed/tests/test_tls_functional.py | 40 +- docs/source/memory.rst | 3 +- 8 files changed, 893 insertions(+), 234 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index b1aa94032c..b577126693 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -3057,11 +3057,14 @@ def upload_file(self, filename, **kwargs): ) async def _rebalance(self, futures=None, workers=None): - await _wait(futures) - keys = list({stringify(f.key) for f in self.futures_of(futures)}) + if futures is not None: + await _wait(futures) + keys = list({stringify(f.key) for f in self.futures_of(futures)}) + else: + keys = None result = await self.scheduler.rebalance(keys=keys, workers=workers) if result["status"] == "missing-data": - raise ValueError( + raise KeyError( f"During rebalance {len(result['keys'])} keys were found to be missing" ) assert result["status"] == "OK" diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml index eac459d7aa..86c00a8f31 100644 --- a/distributed/distributed-schema.yaml +++ b/distributed/distributed-schema.yaml @@ -395,7 +395,7 @@ properties: description: >- Settings for memory management properties: - recent_to_old_time: + recent-to-old-time: type: string description: >- When there is an increase in process memory (as observed by the @@ -403,6 +403,46 @@ properties: the worker, ignore it for this long before considering it in non-time-sensitive heuristics. This should be set to be longer than the duration of most dask tasks. + rebalance: + type: object + description: >- + Settings for memory rebalance operations + properties: + measure: + enum: + - process + - optimistic + - managed + - managed_in_memory + description: >- + Which of the properties of distributed.scheduler.MemoryState + should be used for measuring worker memory usage + sender-min: + type: number + minimum: 0 + maximum: 1 + description: >- + Fraction of worker process memory at which we start potentially + transferring data to other workers. + recipient-max: + type: number + minimum: 0 + maximum: 1 + description: >- + Fraction of worker process memory at which we stop potentially + receiving data from other workers. Ignored when max_memory is not + set. + sender-recipient-gap: + type: number + minimum: 0 + maximum: 1 + description: >- + Fraction of worker process memory, around the cluster mean, where + a worker is neither a sender nor a recipient of data during a + rebalance operation. E.g. if the mean cluster occupation is 50%, + sender-recipient-gap=0.1 means that only nodes above 55% will + donate data and only nodes below 45% will receive them. This helps + avoid data from bouncing around the cluster repeatedly. target: oneOf: diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index ff34e48bc3..a55700cd08 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -88,7 +88,43 @@ distributed: # system) that is not accounted for by the dask keys stored on the worker, ignore # it for this long before considering it in non-critical memory measures. # This should be set to be longer than the duration of most dask tasks. - recent_to_old_time: 30s + recent-to-old-time: 30s + + rebalance: + # Memory measure to rebalance upon. Possible choices are: + # process + # Total process memory, as measured by the OS. + # optimistic + # Managed by dask (instantaneous) + unmanaged (without any increases + # happened in the last ). + # Recommended for use on CPython with large (2MiB+) numpy-based data chunks. + # managed_in_memory + # Only consider the data allocated by dask in RAM. Recommended if RAM is not + # released in a timely fashion back to the OS after the Python objects are + # dereferenced, but remains available for reuse by PyMalloc. + # + # If this is your problem on Linux, you should alternatively consider + # setting the MALLOC_TRIM_THRESHOLD_ environment variable (note the final + # underscore) to a low value; refer to the mallopt man page and to the + # comments about M_TRIM_THRESHOLD on + # https://sourceware.org/git/?p=glibc.git;a=blob;f=malloc/malloc.c + # managed + # Only consider data allocated by dask, including that spilled to disk. + # Recommended if disk occupation of the spill file is an issue. + measure: optimistic + # Fraction of worker process memory at which we start potentially sending + # data to other workers. Ignored when max_memory is not set. + sender-min: 0.30 + # Fraction of worker process memory at which we stop potentially accepting + # data from other workers. Ignored when max_memory is not set. + recipient-max: 0.60 + # Fraction of worker process memory, around the cluster mean, where a worker is + # neither a sender nor a recipient of data during a rebalance operation. E.g. if + # the mean cluster occupation is 50%, sender-recipient-gap=0.10 means that only + # nodes above 55% will donate data and only nodes below 45% will receive them. + # This helps avoid data from bouncing around the cluster repeatedly. + # Ignored when max_memory is not set. + sender-recipient-gap: 0.10 # Fractions of worker process memory at which we take action to avoid memory # blowup. Set any of the values to False to turn off the behavior entirely. diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 0603ac0b1c..de3843c17f 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -1,4 +1,5 @@ import asyncio +import heapq import html import inspect import itertools @@ -12,11 +13,12 @@ import warnings import weakref from collections import defaultdict, deque -from collections.abc import Mapping, Set +from collections.abc import Hashable, Iterable, Iterator, Mapping, Set from contextlib import suppress from datetime import timedelta from functools import partial from numbers import Number +from typing import Optional import psutil import sortedcontainers @@ -161,14 +163,6 @@ def nogil(func): DEFAULT_DATA_SIZE = declare( Py_ssize_t, parse_bytes(dask.config.get("distributed.scheduler.default-data-size")) ) -UNKNOWN_TASK_DURATION = declare( - double, - parse_timedelta(dask.config.get("distributed.scheduler.unknown-task-duration")), -) -MEMORY_RECENT_TO_OLD_TIME = declare( - double, - parse_timedelta(dask.config.get("distributed.worker.memory.recent_to_old_time")), -) DEFAULT_EXTENSIONS = [ LockExtension, @@ -292,7 +286,7 @@ class MemoryState: unmanaged_old Minimum of the 'unmanaged' measures over the last - ``distributed.memory.recent_to_old_time`` seconds + ``distributed.memory.recent-to-old-time`` seconds unmanaged_recent unmanaged - unmanaged_old; in other words process memory that has been recently allocated but is not accounted for by dask; hopefully it's mostly a temporary @@ -419,7 +413,7 @@ class WorkerState: .. attribute:: has_what: {TaskState} - The set of tasks which currently reside on this worker. + An insertion-sorted set-like of tasks which currently reside on this worker. All the tasks here are in the "memory" state. This is the reverse mapping of :class:`TaskState.who_has`. @@ -479,7 +473,9 @@ class WorkerState: _bandwidth: double _executing: dict _extra: dict - _has_what: set + # _has_what is a dict with all values set to None as rebalance() relies on the + # property of Python >=3.7 dicts to be insertion-sorted. + _has_what: dict _hash: Py_hash_t _last_seen: double _local_directory: str @@ -567,7 +563,7 @@ def __init__( ) self._actors = set() - self._has_what = set() + self._has_what = {} self._processing = {} self._executing = {} self._resources = {} @@ -608,8 +604,8 @@ def extra(self): return self._extra @property - def has_what(self): - return self._has_what + def has_what(self) -> "Set[TaskState]": + return self._has_what.keys() @property def host(self): @@ -1780,6 +1776,14 @@ class SchedulerState: _workers: object _workers_dv: dict + # Variables from dask.config, cached by __init__ for performance + UNKNOWN_TASK_DURATION: double + MEMORY_RECENT_TO_OLD_TIME: double + MEMORY_REBALANCE_MEASURE: str + MEMORY_REBALANCE_SENDER_MIN: double + MEMORY_REBALANCE_RECIPIENT_MAX: double + MEMORY_REBALANCE_HALF_GAP: double + def __init__( self, aliases: dict = None, @@ -1854,6 +1858,28 @@ def __init__( else: self._workers = sortedcontainers.SortedDict() self._workers_dv: dict = cast(dict, self._workers) + + # Variables from dask.config, cached by __init__ for performance + self.UNKNOWN_TASK_DURATION = parse_timedelta( + dask.config.get("distributed.scheduler.unknown-task-duration") + ) + self.MEMORY_RECENT_TO_OLD_TIME = parse_timedelta( + dask.config.get("distributed.worker.memory.recent-to-old-time") + ) + self.MEMORY_REBALANCE_MEASURE = dask.config.get( + "distributed.worker.memory.rebalance.measure" + ) + self.MEMORY_REBALANCE_SENDER_MIN = dask.config.get( + "distributed.worker.memory.rebalance.sender-min" + ) + self.MEMORY_REBALANCE_RECIPIENT_MAX = dask.config.get( + "distributed.worker.memory.rebalance.recipient-max" + ) + self.MEMORY_REBALANCE_HALF_GAP = ( + dask.config.get("distributed.worker.memory.rebalance.sender-recipient-gap") + / 2.0 + ) + super().__init__(**kwargs) @property @@ -2608,7 +2634,7 @@ def transition_memory_released(self, key, safe: bint = False): "report": False, } for ws in ts._who_has: - ws._has_what.remove(ts) + del ws._has_what[ts] ws._nbytes -= ts_nbytes ts._group._nbytes_in_memory -= ts_nbytes worker_msgs[ws._address] = [worker_msg] @@ -3061,7 +3087,7 @@ def get_comm_cost(self, ts: TaskState, ws: WorkerState) -> double: on the given worker. """ dts: TaskState - deps: set = ts._dependencies - ws._has_what + deps: set = ts._dependencies.difference(ws._has_what) nbytes: Py_ssize_t = 0 for dts in deps: nbytes += dts._nbytes @@ -3074,18 +3100,14 @@ def get_task_duration(self, ts: TaskState, default: double = -1) -> double: (not including any communication cost). """ duration: double = ts._prefix._duration_average - if duration < 0: - s: set = self._unknown_durations.get(ts._prefix._name) - if s is None: - self._unknown_durations[ts._prefix._name] = s = set() - s.add(ts) + if duration >= 0: + return duration - if default < 0: - duration = UNKNOWN_TASK_DURATION - else: - duration = default - - return duration + s: set = self._unknown_durations.get(ts._prefix._name) + if s is None: + self._unknown_durations[ts._prefix._name] = s = set() + s.add(ts) + return default if default >= 0 else self.UNKNOWN_TASK_DURATION @ccall @exceptval(check=False) @@ -3867,7 +3889,7 @@ def heartbeat_worker( # Calculate RSS - dask keys, separating "old" and "new" usage # See MemoryState for details - max_memory_unmanaged_old_hist_age = local_now - MEMORY_RECENT_TO_OLD_TIME + max_memory_unmanaged_old_hist_age = local_now - parent.MEMORY_RECENT_TO_OLD_TIME memory_unmanaged_old = ws._memory_unmanaged_old while ws._memory_other_history: timestamp, size = ws._memory_other_history[0] @@ -4498,7 +4520,7 @@ def stimulus_missing_data( ws: WorkerState cts_nbytes: Py_ssize_t = cts.get_nbytes() for ws in cts._who_has: # TODO: this behavior is extreme - ws._has_what.remove(cts) + del ws._has_what[ts] ws._nbytes -= cts_nbytes cts._who_has.clear() recommendations[cause] = "released" @@ -5084,7 +5106,7 @@ def handle_missing_data(self, key=None, errant_worker=None, **kwargs): ws: WorkerState = parent._workers_dv.get(errant_worker) if ws is not None and ws in ts._who_has: ts._who_has.remove(ws) - ws._has_what.remove(ts) + del ws._has_what[ts] ws._nbytes -= ts.get_nbytes() if not ts._who_has: if ts._run_spec: @@ -5096,12 +5118,12 @@ def release_worker_data(self, comm=None, keys=None, worker=None): parent: SchedulerState = cast(SchedulerState, self) ws: WorkerState = parent._workers_dv[worker] tasks: set = {parent._tasks[k] for k in keys} - removed_tasks: set = tasks & ws._has_what - ws._has_what -= removed_tasks + removed_tasks: set = tasks.intersection(ws._has_what) ts: TaskState recommendations: dict = {} for ts in removed_tasks: + del ws._has_what[ts] ws._nbytes -= ts.get_nbytes() wh: set = ts._who_has wh.remove(ws) @@ -5342,7 +5364,7 @@ async def gather(self, comm=None, keys=None, serializers=None): for worker in workers: ws = parent._workers_dv.get(worker) if ws is not None and ts in ws._has_what: - ws._has_what.remove(ts) + del ws._has_what[ts] ts._who_has.remove(ws) ws._nbytes -= ts_nbytes parent._transitions( @@ -5508,144 +5530,364 @@ async def _delete_worker_data(self, worker_address, keys): ws: WorkerState = parent._workers_dv[worker_address] ts: TaskState tasks: set = {parent._tasks[key] for key in keys} - ws._has_what -= tasks for ts in tasks: + del ws._has_what[ts] ts._who_has.remove(ws) ws._nbytes -= ts.get_nbytes() self.log_event(ws._address, {"action": "remove-worker-data", "keys": keys}) - async def rebalance(self, comm=None, keys=None, workers=None): - """Rebalance keys so that each worker stores roughly equal bytes - - **Policy** + async def rebalance( + self, + comm=None, + keys: "Iterable[Hashable]" = None, + workers: "Iterable[str]" = None, + ) -> dict: + """Rebalance keys so that each worker ends up with roughly the same process + memory (managed+unmanaged). + + FIXME this method is not robust when the cluster is not idle. + + **Algorithm** + + #. Find the mean occupancy of the cluster, defined as data managed by dask + + unmanaged process memory that has been there for at least 30 seconds + (``distributed.worker.memory.recent-to-old-time``). + This lets us ignore temporary spikes caused by task heap usage. + #. Discard workers whose occupancy is within 5% of the mean cluster occupancy + (``distributed.worker.memory.rebalance.sender-recipient-gap`` / 2). + This helps avoid data from bouncing around the cluster repeatedly. + #. Workers above the mean are senders; those below are recipients. + #. Discard senders whose absolute occupancy is below 40% + (``distributed.worker.memory.rebalance.sender-min``). In other words, no data + is moved regardless of imbalancing as long as all workers are below 40%. + #. Discard recipients whose absolute occupancy is above 60% + (``distributed.worker.memory.rebalance.recipient-max``). + Note that this threshold by default is the same as + ``distributed.worker.memory.target`` to prevent workers from accepting data + and immediately spilling it out to disk. + #. Iteratively pick the sender and recipient that are farthest from the mean and + move the *least recently inserted* key between the two, until either all + senders or all recipients fall within 5% of the mean. + + A recipient will be skipped if it already has a copy of the data. In other + words, this method does not degrade replication. + A key will be skipped if there are no recipients that have both enough memory + to accept and don't already hold a copy. + + The least recently insertd (LRI) policy is a greedy choice with the advantage of + being O(1), trivial to implement (it relies on python dict insertion-sorting) + and hopefully good enough in most cases. Discarded alternative policies were: + + - Largest first. O(n*log(n)) save for non-trivial additional data structures and + risks causing the largest chunks of data to repeatedly move around the + cluster like pinballs. + - Least recently utilized. This information is currently available on the + workers only and not trivial to replicate on the scheduler; transmitting it + over the network would be very expensive. Also, note that dask will go out of + its way to minimise the amount of time intermediate keys are held in memory, + so in such a case LRI is a close approximation of LRU. - This orders the workers by what fraction of bytes of the existing keys - they have. It walks down this list from most-to-least. At each worker - it sends the largest results it can find and sends them to the least - occupied worker until either the sender or the recipient are at the - average expected load. + Parameters + ---------- + keys: optional + whitelist of dask keys that should be considered for moving. All other keys + will be ignored. Note that this offers no guarantee that a key will actually + be moved (e.g. because it is unnecessary or because there are no viable + recipient workers for it). + workers: optional + whitelist of workers addresses to be considered as senders or recipients. + All other workers will be ignored. The mean cluster occupancy will be + calculated only using the whitelisted workers. """ - parent: SchedulerState = cast(SchedulerState, self) - ts: TaskState - with log_errors(): - async with self._lock: - if keys: - tasks = {parent._tasks[k] for k in keys} - missing_data = [ts._key for ts in tasks if not ts._who_has] - if missing_data: - return {"status": "missing-data", "keys": missing_data} - else: - tasks = set(parent._tasks.values()) + parent: SchedulerState = self - if workers: - workers = {parent._workers_dv[w] for w in workers} - workers_by_task = {ts: ts._who_has & workers for ts in tasks} - else: - workers = set(parent._workers_dv.values()) - workers_by_task = {ts: ts._who_has for ts in tasks} + with log_errors(): + if workers is not None: + workers = [parent._workers_dv[w] for w in workers] + else: + workers = parent._workers_dv.values() + if not workers: + return {"status": "OK"} - ws: WorkerState - tasks_by_worker = {ws: set() for ws in workers} + if keys is not None: + if not isinstance(keys, Set): + keys = set(keys) # unless already a set-like + if not keys: + return {"status": "OK"} + missing_data = [ + k + for k in keys + if k not in parent._tasks or not parent._tasks[k].who_has + ] + if missing_data: + return {"status": "missing-data", "keys": missing_data} - for k, v in workers_by_task.items(): - for vv in v: - tasks_by_worker[vv].add(k) + msgs = self._rebalance_find_msgs(keys, workers) + if not msgs: + return {"status": "OK"} - worker_bytes = { - ws: sum(ts.get_nbytes() for ts in v) - for ws, v in tasks_by_worker.items() - } + async with self._lock: + return await self._rebalance_move_data(msgs) + + def _rebalance_find_msgs( + self: SchedulerState, + keys: "Optional[Set[Hashable]]", + workers: "Iterable[WorkerState]", + ) -> "list[tuple[WorkerState, WorkerState, TaskState]]": + """Identify workers that need to lose keys and those that can receive them, + together with how many bytes each needs to lose/receive. Then, pair a sender + worker with a recipient worker for each key, until the cluster is rebalanced. + + This method only defines the work to be performed; it does not start any network + transfers itself. + + The big-O complexity is O(wt + ke*log(we)), where + + - wt is the total number of workers on the cluster (or the number of whitelisted + workers, if explicitly stated by the user) + - we is the number of workers that are eligible to be senders or recipients + - kt is the total number of keys on the cluster (or on the whitelisted workers) + - ke is the number of keys that need to be moved in order to achieve a balanced + cluster + + There is a degenerate edge case O(wt + kt*log(we)) when kt is much greater than + the number of whitelisted keys, or when most keys are replicated or cannot be + moved for some other reason. + + Returns list of tuples to feed into _rebalance_move_data: + + - sender worker + - recipient worker + - task to be transferred + """ + parent: SchedulerState = self + ts: TaskState + ws: WorkerState - avg = sum(worker_bytes.values()) / len(worker_bytes) + # Heaps of workers, managed by the heapq module, that need to send/receive data, + # with how many bytes each needs to send/receive. + # + # Each element of the heap is a tuple constructed as follows: + # - snd_bytes_max/rec_bytes_max: maximum number of bytes to send or receive. + # This number is negative, so that the workers farthest from the cluster mean + # are at the top of the smallest-first heaps. + # - snd_bytes_min/rec_bytes_min: minimum number of bytes after sending/receiving + # which the worker should not be considered anymore. This is also negative. + # - arbitrary unique number, there just to to make sure that WorkerState objects + # are never used for sorting in the unlikely event that two processes have + # exactly the same number of bytes allocated. + # - WorkerState + # - iterator of all tasks in memory on the worker (senders only), insertion + # sorted (least recently inserted first). + # Note that this iterator will typically *not* be exhausted. It will only be + # exhausted if, after moving away from the worker all keys that can be moved, + # is insufficient to drop snd_bytes_min above 0. + senders: "list[tuple[int, int, int, WorkerState, Iterator[TaskState]]]" = [] + recipients: "list[tuple[int, int, int, WorkerState]]" = [] + + # Output: [(sender, recipient, task), ...] + msgs: "list[tuple[WorkerState, WorkerState, TaskState]]" = [] + + # By default, this is the optimistic memory, meaning total process memory minus + # unmanaged memory that appeared over the last 30 seconds + # (distributed.worker.memory.recent-to-old-time). + # This lets us ignore temporary spikes caused by task heap usage. + memory_by_worker = [ + (ws, getattr(ws.memory, parent.MEMORY_REBALANCE_MEASURE)) for ws in workers + ] + mean_memory = sum(m for _, m in memory_by_worker) // len(memory_by_worker) - sorted_workers = list( - map(first, sorted(worker_bytes.items(), key=second, reverse=True)) + for ws, ws_memory in memory_by_worker: + if ws.memory_limit: + half_gap = int(parent.MEMORY_REBALANCE_HALF_GAP * ws.memory_limit) + sender_min = parent.MEMORY_REBALANCE_SENDER_MIN * ws.memory_limit + recipient_max = parent.MEMORY_REBALANCE_RECIPIENT_MAX * ws.memory_limit + else: + half_gap = 0 + sender_min = 0.0 + recipient_max = math.inf + + if ( + ws._has_what + and ws_memory >= mean_memory + half_gap + and ws_memory >= sender_min + ): + # This may send the worker below sender_min (by design) + snd_bytes_max = mean_memory - ws_memory # negative + snd_bytes_min = snd_bytes_max + half_gap # negative + # See definition of senders above + senders.append( + (snd_bytes_max, snd_bytes_min, id(ws), ws, iter(ws._has_what)) ) + elif ws_memory < mean_memory - half_gap and ws_memory < recipient_max: + # This may send the worker above recipient_max (by design) + rec_bytes_max = ws_memory - mean_memory # negative + rec_bytes_min = rec_bytes_max + half_gap # negative + # See definition of recipients above + recipients.append((rec_bytes_max, rec_bytes_min, id(ws), ws)) + + # Fast exit in case no transfers are necessary or possible + if not senders or not recipients: + self.log_event( + "all", + { + "action": "rebalance", + "senders": len(senders), + "recipients": len(recipients), + "moved_keys": 0, + }, + ) + return [] - recipients = reversed(sorted_workers) - recipient = next(recipients) - msgs = [] # (sender, recipient, key) - for sender in sorted_workers[: len(workers) // 2]: - sender_keys = { - ts: ts.get_nbytes() for ts in tasks_by_worker[sender] - } - sender_keys = iter( - sorted(sender_keys.items(), key=second, reverse=True) - ) + heapq.heapify(senders) + heapq.heapify(recipients) - try: - while avg < worker_bytes[sender]: - while worker_bytes[recipient] < avg < worker_bytes[sender]: - ts, nb = next(sender_keys) - if ts not in tasks_by_worker[recipient]: - tasks_by_worker[recipient].add(ts) - # tasks_by_worker[sender].remove(ts) - msgs.append((sender, recipient, ts)) - worker_bytes[sender] -= nb - worker_bytes[recipient] += nb - if avg < worker_bytes[sender]: - recipient = next(recipients) - except StopIteration: + snd_ws: WorkerState + rec_ws: WorkerState + + while senders and recipients: + snd_bytes_max, snd_bytes_min, _, snd_ws, ts_iter = senders[0] + + # Iterate through tasks in memory, least recently inserted first + for ts in ts_iter: + if keys is not None and ts.key not in keys: + continue + nbytes = ts.nbytes + if nbytes + snd_bytes_max > 0: + # Moving this task would cause the sender to go below mean and + # potentially risk becoming a recipient, which would cause tasks to + # bounce around. Move on to the next task of the same sender. + continue + + # Find the recipient, farthest from the mean, which + # 1. has enough available RAM for this task, and + # 2. doesn't hold a copy of this task already + # There may not be any that satisfies these conditions; in this case + # this task won't be moved. + skipped_recipients = [] + use_recipient = False + while recipients and not use_recipient: + rec_bytes_max, rec_bytes_min, _, rec_ws = recipients[0] + if nbytes + rec_bytes_max > 0: + # recipients are sorted by rec_bytes_max. + # The next ones will be worse; no reason to continue iterating break + use_recipient = ts not in rec_ws._has_what + if not use_recipient: + skipped_recipients.append(heapq.heappop(recipients)) - to_recipients = defaultdict(lambda: defaultdict(list)) - to_senders = defaultdict(list) - for sender, recipient, ts in msgs: - to_recipients[recipient.address][ts._key].append(sender.address) - to_senders[sender.address].append(ts._key) + for recipient in skipped_recipients: + heapq.heappush(recipients, recipient) - result = await asyncio.gather( - *( - retry_operation(self.rpc(addr=r).gather, who_has=v) - for r, v in to_recipients.items() + if not use_recipient: + # This task has no recipients available. Leave it on the sender and + # move on to the next task of the same sender. + continue + + # Schedule task for transfer from sender to receiver + msgs.append((snd_ws, rec_ws, ts)) + + # *_bytes_max/min are all negative for heap sorting + snd_bytes_max += nbytes + snd_bytes_min += nbytes + rec_bytes_max += nbytes + rec_bytes_min += nbytes + + # Stop iterating on the tasks of this sender for now and, if it still + # has bytes to lose, push it back into the senders heap; it may or may + # not come back on top again. + if snd_bytes_min < 0: + # See definition of senders above + heapq.heapreplace( + senders, + (snd_bytes_max, snd_bytes_min, id(snd_ws), snd_ws, ts_iter), ) - ) - for r, v in to_recipients.items(): - self.log_event(r, {"action": "rebalance", "who_has": v}) + else: + heapq.heappop(senders) + + # If receiver still has bytes to gain, push it back into the receivers + # heap; it may or may not come back on top again. + if rec_bytes_min < 0: + # See definition of recipients above + heapq.heapreplace( + recipients, + (rec_bytes_max, rec_bytes_min, id(rec_ws), rec_ws), + ) + else: + heapq.heappop(recipients) - self.log_event( - "all", - { - "action": "rebalance", - "total-keys": len(tasks), - "senders": valmap(len, to_senders), - "recipients": valmap(len, to_recipients), - "moved_keys": len(msgs), - }, - ) + # Move to next sender with the most data to lose. + # It may or may not be the same sender again. + break - if any(r["status"] != "OK" for r in result): - return { - "status": "missing-data", - "keys": tuple( - concat( - r["keys"].keys() - for r in result - if r["status"] == "missing-data" - ) - ), - } + else: # for ts in ts_iter + # Exhausted tasks on this sender + heapq.heappop(senders) - for sender, recipient, ts in msgs: - assert ts._state == "memory" - ts._who_has.add(recipient) - recipient.has_what.add(ts) - recipient.nbytes += ts.get_nbytes() - self.log.append( - ( - "rebalance", - ts._key, - time(), - sender.address, - recipient.address, - ) + return msgs + + async def _rebalance_move_data( + self, msgs: "list[tuple[WorkerState, WorkerState, TaskState]]" + ) -> dict: + """Perform the actual transfer of data across the network in rebalance(). + Takes in input the output of _rebalance_find_msgs(). + + FIXME this method is not robust when the cluster is not idle. + """ + ts: TaskState + snd_ws: WorkerState + rec_ws: WorkerState + + to_recipients = defaultdict(lambda: defaultdict(list)) + to_senders = defaultdict(list) + for sender, recipient, ts in msgs: + to_recipients[recipient.address][ts._key].append(sender.address) + to_senders[sender.address].append(ts._key) + + result = await asyncio.gather( + *( + retry_operation(self.rpc(addr=r).gather, who_has=v) + for r, v in to_recipients.items() + ) + ) + for r, v in to_recipients.items(): + self.log_event(r, {"action": "rebalance", "who_has": v}) + + self.log_event( + "all", + { + "action": "rebalance", + "senders": valmap(len, to_senders), + "recipients": valmap(len, to_recipients), + "moved_keys": len(msgs), + }, + ) + + if any(r["status"] != "OK" for r in result): + return { + "status": "missing-data", + "keys": list( + concat( + r["keys"].keys() + for r in result + if r["status"] == "missing-data" ) + ), + } - await asyncio.gather( - *(self._delete_worker_data(r, v) for r, v in to_senders.items()) - ) + for snd_ws, rec_ws, ts in msgs: + assert ts._state == "memory" + ts._who_has.add(rec_ws) + rec_ws._has_what[ts] = None + rec_ws.nbytes += ts.get_nbytes() + self.log.append( + ("rebalance", ts._key, time(), snd_ws.address, rec_ws.address) + ) - return {"status": "OK"} + await asyncio.gather( + *(self._delete_worker_data(r, v) for r, v in to_senders.items()) + ) + return {"status": "OK"} async def replicate( self, @@ -5976,7 +6218,7 @@ async def retire_workers( logger.info("Retire workers %s", workers) # Keys orphaned by retiring those workers - keys = set.union(*[w.has_what for w in workers]) + keys = {k for w in workers for k in w.has_what} keys = {ts._key for ts in keys if ts._who_has.issubset(workers)} if keys: @@ -6030,7 +6272,7 @@ def add_keys(self, comm=None, worker=None, keys=()): if ts is not None and ts._state == "memory": if ts not in ws._has_what: ws._nbytes += ts.get_nbytes() - ws._has_what.add(ts) + ws._has_what[ts] = None ts._who_has.add(ws) else: self.worker_send( @@ -6075,7 +6317,7 @@ def update_data( ws: WorkerState = parent._workers_dv[w] if ts not in ws._has_what: ws._nbytes += ts_nbytes - ws._has_what.add(ts) + ws._has_what[ts] = None ts._who_has.add(ws) self.report( {"op": "key-in-memory", "key": key, "workers": list(workers)} @@ -6202,7 +6444,7 @@ def get_has_what(self, comm=None, workers=None): } else: return { - w: [ts._key for ts in ws._has_what] + w: [ts._key for ts in ws.has_what] for w, ws in parent._workers_dv.items() } @@ -6975,7 +7217,7 @@ def _add_to_memory( assert ts not in ws._has_what ts._who_has.add(ws) - ws._has_what.add(ts) + ws._has_what[ts] = None ws._nbytes += ts.get_nbytes() deps: list = list(ts._dependents) @@ -7058,7 +7300,7 @@ def _propagate_forgotten( ws: WorkerState for ws in ts._who_has: - ws._has_what.remove(ts) + del ws._has_what[ts] ws._nbytes -= ts_nbytes w: str = ws._address if w in state._workers_dv: # in case worker has died diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index f17de5949b..2e3d1ef2aa 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -2892,92 +2892,128 @@ def __reduce__(self): raise BadlySerializedException("hello world") x = c.submit(f) + with pytest.raises(Exception, match="hello world"): + await x - try: - result = await x - except Exception as e: - assert "hello world" in str(e) - else: - assert False +@gen_cluster( + client=True, + Worker=Nanny, + worker_kwargs={"memory_limit": "1 GiB"}, + config={"distributed.worker.memory.rebalance.sender-min": 0.3}, +) +async def test_rebalance(c, s, *_): + """Test Client.rebalance(). These are just to test the Client wrapper around + Scheduler.rebalance(); for more thorough tests on the latter see test_scheduler.py. + """ + # We used nannies to have separate processes for each worker + a, b = s.workers -@gen_cluster(client=True) -async def test_rebalance(c, s, a, b): - aws = s.workers[a.address] - bws = s.workers[b.address] + # Generate 10 buffers worth 512 MiB total on worker a. This sends its memory + # utilisation slightly above 50% (after counting unmanaged) which is above the + # distributed.worker.memory.rebalance.sender-min threshold. + futures = c.map(lambda _: "x" * (2 ** 29 // 10), range(10), workers=[a]) + await wait(futures) + # Wait for heartbeats + while s.memory.process < 2 ** 29: + await asyncio.sleep(0.1) - x, y = await c.scatter([1, 2], workers=[a.address]) - assert len(a.data) == 2 - assert len(b.data) == 0 + assert await c.run(lambda dask_worker: len(dask_worker.data)) == {a: 10, b: 0} - s.validate_state() await c.rebalance() - s.validate_state() - assert len(b.data) == 1 - assert {ts.key for ts in bws.has_what} == set(b.data) - assert bws in s.tasks[x.key].who_has or bws in s.tasks[y.key].who_has + ndata = await c.run(lambda dask_worker: len(dask_worker.data)) + # Allow for some uncertainty as the unmanaged memory is not stable + assert sum(ndata.values()) == 10 + assert 3 <= ndata[a] <= 7 + assert 3 <= ndata[b] <= 7 - assert len(a.data) == 1 - assert {ts.key for ts in aws.has_what} == set(a.data) - assert aws not in s.tasks[x.key].who_has or aws not in s.tasks[y.key].who_has - - -@gen_cluster(nthreads=[("127.0.0.1", 1)] * 4, client=True) -async def test_rebalance_workers(e, s, a, b, c, d): - w, x, y, z = await e.scatter([1, 2, 3, 4], workers=[a.address]) - assert len(a.data) == 4 - assert len(b.data) == 0 - assert len(c.data) == 0 - assert len(d.data) == 0 - - await e.rebalance([x, y], workers=[a.address, c.address]) - assert len(a.data) == 3 - assert len(b.data) == 0 - assert len(c.data) == 1 - assert len(d.data) == 0 - assert c.data == {x.key: 2} or c.data == {y.key: 3} - - await e.rebalance() - assert len(a.data) == 1 - assert len(b.data) == 1 - assert len(c.data) == 1 - assert len(d.data) == 1 - s.validate_state() +@gen_cluster( + nthreads=[("127.0.0.1", 1)] * 3, + client=True, + Worker=Nanny, + worker_kwargs={"memory_limit": "1 GiB"}, +) +async def test_rebalance_workers_and_keys(client, s, *_): + """Test Client.rebalance(). These are just to test the Client wrapper around + Scheduler.rebalance(); for more thorough tests on the latter see test_scheduler.py. + """ + a, b, c = s.workers + futures = client.map(lambda _: "x" * (2 ** 29 // 10), range(10), workers=[a]) + await wait(futures) + # Wait for heartbeats + while s.memory.process < 2 ** 29: + await asyncio.sleep(0.1) -@gen_cluster(client=True) -async def test_rebalance_execution(c, s, a, b): - futures = c.map(inc, range(10), workers=a.address) - await c.rebalance(futures) - assert len(a.data) == len(b.data) == 5 - s.validate_state() + # Passing empty iterables is not the same as omitting the arguments + await client.rebalance([]) + await client.rebalance(workers=[]) + assert await client.run(lambda dask_worker: len(dask_worker.data)) == { + a: 10, + b: 0, + c: 0, + } + # Limit rebalancing to two arbitrary keys and two arbitrary workers. + await client.rebalance([futures[3], futures[7]], [a, b]) + assert await client.run(lambda dask_worker: len(dask_worker.data)) == { + a: 8, + b: 2, + c: 0, + } -def test_rebalance_sync(c, s, a, b): - futures = c.map(inc, range(10), workers=[a["address"]]) - c.rebalance(futures) + with pytest.raises(KeyError): + await client.rebalance(workers=["notexist"]) + + +def test_rebalance_sync(): + # can't use the 'c' fixture because we need workers to run in a separate process + with Client(n_workers=2, memory_limit="1 GiB") as c: + s = c.cluster.scheduler + a, b = [ws.address for ws in s.workers.values()] + futures = c.map(lambda _: "x" * (2 ** 29 // 10), range(10), workers=[a]) + wait(futures) + # Wait for heartbeat + while s.memory.process < 2 ** 29: + sleep(0.1) - has_what = c.has_what() - assert len(has_what) == 2 - assert list(valmap(len, has_what).values()) == [5, 5] + assert c.run(lambda dask_worker: len(dask_worker.data)) == {a: 10, b: 0} + c.rebalance() + ndata = c.run(lambda dask_worker: len(dask_worker.data)) + # Allow for some uncertainty as the unmanaged memory is not stable + assert sum(ndata.values()) == 10 + assert 3 <= ndata[a] <= 7 + assert 3 <= ndata[b] <= 7 @gen_cluster(client=True) async def test_rebalance_unprepared(c, s, a, b): + """Client.rebalance() internally waits for unfinished futures""" futures = c.map(slowinc, range(10), delay=0.05, workers=a.address) + # Let the futures reach the scheduler await asyncio.sleep(0.1) + # We didn't wait enough for futures to complete. However, Client.rebalance() will + # block until all futures are completed before invoking Scheduler.rebalance(). await c.rebalance(futures) s.validate_state() -@gen_cluster(client=True) -async def test_rebalance_raises_missing_data(c, s, a, b): - with pytest.raises(ValueError, match="keys were found to be missing"): - futures = await c.scatter(range(100)) - keys = [f.key for f in futures] - del futures - await c.rebalance(keys) +@gen_cluster(client=True, Worker=Nanny, worker_kwargs={"memory_limit": "1 GiB"}) +async def test_rebalance_raises_missing_data(c, s, *_): + a, b = s.workers + futures = c.map(lambda _: "x" * (2 ** 29 // 10), range(10), workers=[a]) + await wait(futures) + # Wait for heartbeats + while s.memory.process < 2 ** 29: + await asyncio.sleep(0.1) + + # Descoping the futures enqueues a coroutine to release the data on the server + del futures + with pytest.raises(KeyError, match="keys were found to be missing"): + # During the synchronous part of rebalance, the futures still exist, but they + # will be (partially) gone by the time the actual transferring happens. + await c.rebalance() @gen_cluster(client=True) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index ccd7ce31dd..f1ddae3615 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -2354,7 +2354,7 @@ async def assert_memory(scheduler_or_workerstate, attr: str, min_, max_, timeout # This test is heavily influenced by hard-to-control factors such as memory management # by the Python interpreter and the OS, so it occasionally glitches @pytest.mark.flaky(reruns=3, reruns_delay=5) -# ~33s runtime, or distributed.memory.recent_to_old_time + 3s +# ~33s runtime, or distributed.memory.recent-to-old-time + 3s @pytest.mark.slow @gen_cluster( client=True, Worker=Nanny, worker_kwargs={"memory_limit": "500 MiB"}, timeout=60 @@ -2407,7 +2407,7 @@ async def test_memory(c, s, *_): await assert_memory(s, "managed_spilled", 1, 999) # Wait for the spilling to finish. Note that this does not make the test take - # longer as we're waiting for recent_to_old_time anyway. + # longer as we're waiting for recent-to-old-time anyway. await asyncio.sleep(10) # Delete spilled keys @@ -2505,3 +2505,286 @@ async def test_close_scheduler__close_workers_Nanny(s, a, b): await asyncio.sleep(0.05) log = log.getvalue() assert "retry" not in log + + +async def assert_ndata(client, by_addr, total=None): + """Test that the number of elements in Worker.data is as expected. + To be used when the worker is wrapped by a nanny. + + by_addr: dict of either exact numbers or (min, max) tuples + total: optional exact match on the total number of keys (with duplicates) across all + workers + """ + out = await client.run(lambda dask_worker: len(dask_worker.data)) + try: + for k, v in by_addr.items(): + if isinstance(v, tuple): + assert v[0] <= out[k] <= v[1] + else: + assert out[k] == v + if total is not None: + assert sum(out.values()) == total + except AssertionError: + raise AssertionError(f"Expected {by_addr}, total={total}; got {out}") + + +@gen_cluster( + client=True, + Worker=Nanny, + worker_kwargs={"memory_limit": "1 GiB"}, + config={"distributed.worker.memory.rebalance.sender-min": 0.3}, +) +async def test_rebalance(c, s, *_): + # We used nannies to have separate processes for each worker + a, b = s.workers + + # Generate 10 buffers worth 512 MiB total on worker a. This sends its memory + # utilisation slightly above 50% (after counting unmanaged) which is above the + # distributed.worker.memory.rebalance.sender-min threshold. + futures = c.map(lambda _: "x" * (2 ** 29 // 10), range(10), workers=[a]) + await wait(futures) + # Wait for heartbeats + await assert_memory(s, "process", 512, 1024) + await assert_ndata(c, {a: 10, b: 0}) + await s.rebalance() + # Allow for some uncertainty as the unmanaged memory is not stable + await assert_ndata(c, {a: (3, 7), b: (3, 7)}, total=10) + + # rebalance() when there is nothing to do + await s.rebalance() + await assert_ndata(c, {a: (3, 7), b: (3, 7)}, total=10) + s.validate_state() + + +@gen_cluster( + nthreads=[("127.0.0.1", 1)] * 3, + client=True, + Worker=Nanny, + worker_kwargs={"memory_limit": "1 GiB"}, +) +async def test_rebalance_workers_and_keys(client, s, *_): + a, b, c = s.workers + futures = client.map(lambda _: "x" * (2 ** 29 // 10), range(10), workers=[a]) + await wait(futures) + # Wait for heartbeats + await assert_memory(s, "process", 512, 1024) + + # Passing empty iterables is not the same as omitting the arguments + await s.rebalance(keys=[]) + await assert_ndata(client, {a: 10, b: 0, c: 0}) + await s.rebalance(workers=[]) + await assert_ndata(client, {a: 10, b: 0, c: 0}) + # Limit operation to workers that have nothing to do + await s.rebalance(workers=[b, c]) + await assert_ndata(client, {a: 10, b: 0, c: 0}) + + # Limit rebalancing to two arbitrary keys and two arbitrary workers + await s.rebalance(keys=[futures[3].key, futures[7].key], workers=[a, b]) + await assert_ndata(client, {a: 8, b: 2, c: 0}, total=10) + + with pytest.raises(KeyError): + await s.rebalance(workers=["notexist"]) + + s.validate_state() + + +@gen_cluster() +async def test_rebalance_missing_data1(s, a, b): + """key never existed""" + out = await s.rebalance(keys=["notexist"]) + assert out == {"status": "missing-data", "keys": ["notexist"]} + s.validate_state() + + +@gen_cluster(client=True) +async def test_rebalance_missing_data2(c, s, a, b): + """keys exist but belong to unfinished futures. Unlike Client.rebalance(), + Scheduler.rebalance() does not wait for unfinished futures. + """ + futures = c.map(slowinc, range(10), delay=0.05, workers=a.address) + await asyncio.sleep(0.1) + out = await s.rebalance(keys=[f.key for f in futures]) + assert out["status"] == "missing-data" + assert 8 <= len(out["keys"]) <= 10 + s.validate_state() + + +@gen_cluster(client=True, Worker=Nanny, worker_kwargs={"memory_limit": "1 GiB"}) +async def test_rebalance_raises_missing_data3(c, s, *_): + """keys exist when the sync part of rebalance runs, but are gone by the time the + actual data movement runs + """ + a, _ = s.workers + futures = c.map(lambda _: "x" * (2 ** 29 // 10), range(10), workers=[a]) + await wait(futures) + # Wait for heartbeats + await assert_memory(s, "process", 512, 1024) + del futures + out = await s.rebalance() + assert out["status"] == "missing-data" + assert 1 <= len(out["keys"]) <= 10 + s.validate_state() + + +@gen_cluster(nthreads=[]) +async def test_rebalance_no_workers(s): + await s.rebalance() + s.validate_state() + + +@gen_cluster( + client=True, + Worker=Nanny, + worker_kwargs={"memory_limit": "1000 MiB"}, + config={ + "distributed.worker.memory.rebalance.measure": "managed", + "distributed.worker.memory.rebalance.sender-min": 0.3, + }, +) +async def test_rebalance_managed_memory(c, s, *_): + a, b = s.workers + # Generate 100 buffers worth 400 MiB total on worker a. This sends its memory + # utilisation to exactly 40%, ignoring unmanaged, which is above the + # distributed.worker.memory.rebalance.sender-min threshold. + futures = c.map(lambda _: "x" * (2 ** 22), range(100), workers=[a]) + await wait(futures) + # Even if we're just using managed memory, which is instantaneously accounted for as + # soon as the tasks finish, MemoryState.managed is still capped by the process + # memory, so we need to wait for the heartbeat. + await assert_memory(s, "managed", 400, 401) + await assert_ndata(c, {a: 100, b: 0}) + await s.rebalance() + # We can expect an exact, stable result because we are completely bypassing the + # unpredictability of unmanaged memory. + await assert_ndata(c, {a: 62, b: 38}) + s.validate_state() + + +@gen_cluster( + client=True, + worker_kwargs={"memory_limit": 0}, + config={"distributed.worker.memory.rebalance.measure": "managed"}, +) +async def test_rebalance_no_limit(c, s, a, b): + # See notes in test_rebalance_managed_memory + futures = c.map(lambda _: "x", range(100), workers=[a.address]) + await wait(futures) + # No reason to wait for memory here as we're allocating hundreds of bytes, so + # there's plenty of unmanaged process memory to pad it out + await assert_ndata(c, {a.address: 100, b.address: 0}) + await s.rebalance() + # Disabling memory_limit made us ignore all % thresholds set in the config + await assert_ndata(c, {a.address: 50, b.address: 50}) + s.validate_state() + + +@gen_cluster( + client=True, + Worker=Nanny, + worker_kwargs={"memory_limit": "1000 MiB"}, + config={ + "distributed.worker.memory.rebalance.measure": "managed", + "distributed.worker.memory.rebalance.recipient-max": 0.4, + }, +) +async def test_rebalance_no_recipients(c, s, *_): + """There are sender workers, but no recipient workers""" + a, b = s.workers + futures = [ + c.submit(lambda: "x" * (400 * 2 ** 20), pure=False, workers=[a]), # 40% + c.submit(lambda: "x" * (400 * 2 ** 20), pure=False, workers=[b]), # 40% + ] + c.map( + lambda _: "x" * (2 ** 21), range(100), workers=[a] + ) # 20% + await wait(futures) + await assert_memory(s, "managed", 1000, 1001) + await assert_ndata(c, {a: 101, b: 1}) + await s.rebalance() + await assert_ndata(c, {a: 101, b: 1}) + s.validate_state() + + +@gen_cluster( + nthreads=[("127.0.0.1", 1)] * 3, + client=True, + worker_kwargs={"memory_limit": 0}, + config={"distributed.worker.memory.rebalance.measure": "managed"}, +) +async def test_rebalance_skip_recipient(client, s, a, b, c): + """A recipient is skipped because it already holds a copy of the key to be sent""" + futures = client.map(lambda _: "x", range(10), workers=[a.address]) + await wait(futures) + await client.replicate(futures[0:2], workers=[a.address, b.address]) + await client.replicate(futures[2:4], workers=[a.address, c.address]) + await assert_ndata(client, {a.address: 10, b.address: 2, c.address: 2}) + await client.rebalance(futures[:2]) + await assert_ndata(client, {a.address: 8, b.address: 2, c.address: 4}) + s.validate_state() + + +@gen_cluster( + client=True, + worker_kwargs={"memory_limit": 0}, + config={"distributed.worker.memory.rebalance.measure": "managed"}, +) +async def test_rebalance_skip_all_recipients(c, s, a, b): + """All recipients are skipped because they already hold copies""" + futures = c.map(lambda _: "x", range(10), workers=[a.address]) + await wait(futures) + await c.replicate([futures[0]]) + await assert_ndata(c, {a.address: 10, b.address: 1}) + await c.rebalance(futures[:2]) + await assert_ndata(c, {a.address: 9, b.address: 2}) + s.validate_state() + + +@gen_cluster( + client=True, + Worker=Nanny, + worker_kwargs={"memory_limit": "1000 MiB"}, + config={"distributed.worker.memory.rebalance.measure": "managed"}, +) +async def test_rebalance_sender_below_mean(c, s, *_): + """A task remains on the sender because moving it would send it below the mean""" + a, b = s.workers + f1 = c.submit(lambda: "x" * (400 * 2 ** 20), workers=[a]) + await wait([f1]) + f2 = c.submit(lambda: "x" * (10 * 2 ** 20), workers=[a]) + await wait([f2]) + await assert_memory(s, "managed", 410, 411) + await assert_ndata(c, {a: 2, b: 0}) + await s.rebalance() + assert await c.has_what() == {a: (f1.key,), b: (f2.key,)} + + +@gen_cluster( + client=True, + Worker=Nanny, + worker_kwargs={"memory_limit": "1000 MiB"}, + config={ + "distributed.worker.memory.rebalance.measure": "managed", + "distributed.worker.memory.rebalance.sender-min": 0.3, + }, +) +async def test_rebalance_least_recently_inserted_sender_min(c, s, *_): + """ + 1. keys are picked using a least recently inserted policy + 2. workers below sender-min are never senders + """ + a, b = s.workers + small_futures = c.map(lambda _: "x", range(10), workers=[a]) + await wait(small_futures) + await assert_ndata(c, {a: 10, b: 0}) + await s.rebalance() + await assert_ndata(c, {a: 10, b: 0}) + + large_future = c.submit(lambda: "x" * (300 * 2 ** 20), workers=[a]) + await wait([large_future]) + await assert_memory(s, "managed", 300, 301) + await assert_ndata(c, {a: 11, b: 0}) + await s.rebalance() + await assert_ndata(c, {a: 1, b: 10}) + assert await c.has_what() == { + a: (large_future.key,), + b: tuple(f.key for f in small_futures), + } diff --git a/distributed/tests/test_tls_functional.py b/distributed/tests/test_tls_functional.py index 59e0bbb429..84a0f5922d 100644 --- a/distributed/tests/test_tls_functional.py +++ b/distributed/tests/test_tls_functional.py @@ -6,11 +6,9 @@ import pytest -from distributed import Client, Nanny, Queue, Scheduler, Worker, worker_client -from distributed.client import wait +from distributed import Client, Nanny, Queue, Scheduler, Worker, wait, worker_client from distributed.core import Status from distributed.metrics import time -from distributed.nanny import Nanny from distributed.utils_test import ( # noqa: F401 cleanup, double, @@ -101,16 +99,36 @@ async def test_nanny(c, s, a, b): assert result == 11 -@gen_tls_cluster(client=True) -async def test_rebalance(c, s, a, b): - x, y = await c._scatter([1, 2], workers=[a.address]) - assert len(a.data) == 2 - assert len(b.data) == 0 +@gen_tls_cluster( + client=True, + Worker=Nanny, + worker_kwargs={"memory_limit": "1 GiB"}, + config={"distributed.worker.memory.rebalance.sender-min": 0.3}, +) +async def test_rebalance(c, s, *_): + # We used nannies to have separate processes for each worker + a, b = s.workers + assert a.startswith("tls://") + + # Generate 10 buffers worth 512 MiB total on worker a. This sends its memory + # utilisation slightly above 50% (after counting unmanaged) which is above the + # distributed.worker.memory.rebalance.sender-min threshold. + futures = c.map(lambda _: "x" * (2 ** 29 // 10), range(10), workers=[a]) + await wait(futures) + + # Wait for heartbeats + while s.memory.process < 2 ** 29: + await asyncio.sleep(0.1) + + assert await c.run(lambda dask_worker: len(dask_worker.data)) == {a: 10, b: 0} - await c._rebalance() + await c.rebalance() - assert len(a.data) == 1 - assert len(b.data) == 1 + ndata = await c.run(lambda dask_worker: len(dask_worker.data)) + # Allow for some uncertainty as the unmanaged memory is not stable + assert sum(ndata.values()) == 10 + assert 3 <= ndata[a] <= 7 + assert 3 <= ndata[b] <= 7 @gen_tls_cluster(client=True, nthreads=[("tls://127.0.0.1", 2)] * 2) diff --git a/docs/source/memory.rst b/docs/source/memory.rst index 2ae16698ea..849b71ec61 100644 --- a/docs/source/memory.rst +++ b/docs/source/memory.rst @@ -166,7 +166,8 @@ copied to another worker node in the course of normal computation if that result is required by another task that is intended to by run by a different worker. This occurs if a task requires two pieces of data on different machines (at least one must move) or through work stealing. In these cases it -is the policy for the second machine to maintain its redundant copy of the data. This helps to organically spread around data that is in high demand. +is the policy for the second machine to maintain its redundant copy of the data. +This helps to organically spread around data that is in high demand. However, advanced users may want to control the location, replication, and balancing of data more directly throughout the cluster. They may know ahead of