Merge branch 'main' into ensure_communicating
crusaderky committed Apr 27, 2022
2 parents 672aaf0 + 9bad573 commit df6f189
Showing 6 changed files with 215 additions and 7 deletions.
3 changes: 2 additions & 1 deletion .github/PULL_REQUEST_TEMPLATE.md
@@ -1,3 +1,4 @@
- - [ ] Closes #xxxx
+ Closes #xxxx
+
- [ ] Tests added / passed
- [ ] Passes `pre-commit run --all-files`
5 changes: 3 additions & 2 deletions distributed/nanny.py
@@ -853,12 +853,13 @@ def watch_stop_q():
"""
try:
msg = child_stop_q.get()
- except (TypeError, OSError):
+ except (TypeError, OSError, EOFError):
logger.error("Worker process died unexpectedly")
msg = {"op": "stop"}
finally:
child_stop_q.close()
- assert msg.pop("op") == "stop"
+ assert msg["op"] == "stop", msg
+ del msg["op"]
loop.add_callback(do_stop, **msg)

thread = threading.Thread(
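
For context on the nanny change above: moving the pop() out of the assert likely matters because assert statements are stripped when Python runs with -O, so any side effect buried inside them silently disappears. A minimal, hypothetical sketch of the difference:

msg = {"op": "stop", "timeout": 2}

# Old pattern: the dict mutation only happens when asserts are enabled.
#     assert msg.pop("op") == "stop"
# Under `python -O` the whole statement is removed, "op" stays in msg,
# and do_stop(**msg) later receives an unexpected keyword argument.

# New pattern: the check and the mutation are separate statements, and a
# failing assert now also prints the offending message.
assert msg["op"] == "stop", msg
del msg["op"]
# do_stop(**msg) now only sees the real keyword arguments, e.g. timeout=2
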
182 changes: 181 additions & 1 deletion distributed/tests/test_worker_memory.py
@@ -2,7 +2,8 @@

import asyncio
import logging
- from collections import UserDict
+ import threading
+ from collections import Counter, UserDict
from time import sleep

import pytest
@@ -12,6 +13,7 @@
import distributed.system
from distributed import Client, Event, Nanny, Worker, wait
from distributed.core import Status
from distributed.metrics import monotonic
from distributed.spill import has_zict_210
from distributed.utils_test import captured_logger, gen_cluster, inc
from distributed.worker_memory import parse_memory_limit
@@ -478,6 +480,73 @@ def f(ev):
assert "Resuming worker" in logger.getvalue()


@gen_cluster(
client=True,
nthreads=[("", 1), ("", 1)],
config={
"distributed.worker.memory.target": False,
"distributed.worker.memory.spill": False,
"distributed.worker.memory.pause": False,
},
)
async def test_pause_prevents_deps_fetch(c, s, a, b):
"""A worker is paused while there are dependencies ready to fetch, but all other
workers are in flight
"""
a_addr = a.address

class X:
def __sizeof__(self):
return 2**40 # Disable clustering in select_keys_for_gather

def __reduce__(self):
return X.pause_on_unpickle, ()

@staticmethod
def pause_on_unpickle():
# Note: outside of task execution, distributed.get_worker()
# returns a random worker running in the process
for w in Worker._instances:
if w.address == a_addr:
w.status = Status.paused
return X()
assert False

x = c.submit(X, key="x", workers=[b.address])
y = c.submit(inc, 1, key="y", workers=[b.address])
await wait([x, y])
w = c.submit(lambda _: None, x, key="w", priority=1, workers=[a.address])
z = c.submit(inc, y, key="z", priority=0, workers=[a.address])

# - w and z reach worker a within the same message
# - w and z respectively make x and y go into fetch state.
# w has a higher priority than z, therefore w's dependency x has a higher priority
# than z's dependency y.
# a.data_needed = ["x", "y"]
# - ensure_communicating decides to fetch x but not to fetch y together with it, as
# it thinks x is 1TB in size
# - x fetch->flight; a is added to in_flight_workers
# - y is skipped by ensure_communicating since all workers that hold a replica are
# in flight
# - x reaches a and sends a into paused state
# - x flight->memory; a is removed from in_flight_workers
# - ensure_communicating is triggered again
# - ensure_communicating refuses to fetch y because the worker is paused

while "y" not in a.tasks or a.tasks["y"].state != "fetch":
await asyncio.sleep(0.01)
await asyncio.sleep(0.1)
assert a.tasks["y"].state == "fetch"
assert "y" not in a.data
assert [ts.key for ts in a.data_needed] == ["y"]

# Unpausing kicks off ensure_communicating again
a.status = Status.running
assert await z == 3
assert a.tasks["y"].state == "memory"
assert "y" in a.data

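The X helper above leans on a standard pickle hook: __reduce__ returns a (callable, args) pair and pickle invokes that callable on the receiving side, which is what lets the test pause worker a at the exact moment the key is deserialized. A self-contained illustration of the same trick, with hypothetical names:

import pickle

class Tattletale:
    def __reduce__(self):
        # announce() runs wherever the bytes are unpickled,
        # not where the object was created
        return Tattletale.announce, ()

    @staticmethod
    def announce():
        print("unpickled!")
        return Tattletale()

roundtripped = pickle.loads(pickle.dumps(Tattletale()))  # prints "unpickled!"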

@gen_cluster(
client=True,
nthreads=[("", 1)],
@@ -636,6 +705,117 @@ def leak():
assert "memory" in out.lower()


@gen_cluster(
nthreads=[("", 1)],
client=True,
worker_kwargs={"memory_limit": "10 GiB"},
config={
"distributed.worker.memory.target": False,
"distributed.worker.memory.spill": 0.7,
"distributed.worker.memory.pause": 0.9,
"distributed.worker.memory.monitor-interval": "10ms",
},
)
async def test_pause_while_spilling(c, s, a):
N_PAUSE = 3
N_TOTAL = 5

def get_process_memory():
if len(a.data) < N_PAUSE:
# Don't trigger spilling until after all tasks have completed
return 0
elif a.data.fast and not a.data.slow:
# Trigger spilling
return 8 * 2**30
else:
# Trigger pause, but only after we started spilling
return 10 * 2**30

a.monitor.get_process_memory = get_process_memory

class SlowSpill:
def __init__(self, _):
# Can't pickle a Semaphore, so instead of a default value, we create it
# here. Don't worry about race conditions; the worker is single-threaded.
if not hasattr(type(self), "sem"):
type(self).sem = threading.Semaphore(N_PAUSE)
# Block if there are N_PAUSE tasks in a.data.fast
self.sem.acquire()

def __reduce__(self):
paused = distributed.get_worker().status == Status.paused
if not paused:
sleep(0.1)
self.sem.release()
return bool, (paused,)

futs = c.map(SlowSpill, range(N_TOTAL))
while len(a.data.slow) < N_PAUSE + 1:
await asyncio.sleep(0.01)

assert a.status == Status.paused
# Worker should have become paused after the first `SlowSpill` was evicted, because
# the spill to disk took longer than the memory monitor interval.
assert len(a.data.fast) == 0
assert len(a.data.slow) == N_PAUSE + 1
n_spilled_while_paused = sum(paused is True for paused in a.data.slow.values())
assert N_PAUSE <= n_spilled_while_paused <= N_PAUSE + 1

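For readers following the numbers in test_pause_while_spilling: the fake get_process_memory values are chosen to straddle the two configured thresholds. A quick back-of-the-envelope check (plain arithmetic, not distributed code):

memory_limit = 10 * 2**30                 # worker_kwargs={"memory_limit": "10 GiB"}
spill_threshold = 0.7 * memory_limit      # ~7 GiB: evict to disk above this
pause_threshold = 0.9 * memory_limit      # ~9 GiB: pause the worker above this

assert spill_threshold < 8 * 2**30 < pause_threshold   # the "trigger spilling" value
assert 10 * 2**30 > pause_threshold                    # the "trigger pause" value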

@pytest.mark.slow
@gen_cluster(
nthreads=[("", 1)],
client=True,
worker_kwargs={"memory_limit": "10 GiB"},
config={
"distributed.worker.memory.target": False,
"distributed.worker.memory.spill": 0.6,
"distributed.worker.memory.pause": False,
"distributed.worker.memory.monitor-interval": "10ms",
},
)
async def test_release_evloop_while_spilling(c, s, a):
N = 100

def get_process_memory():
if len(a.data) < N:
# Don't trigger spilling until after all tasks have completed
return 0
return 10 * 2**30

a.monitor.get_process_memory = get_process_memory

class SlowSpill:
def __reduce__(self):
sleep(0.01)
return SlowSpill, ()

futs = [c.submit(SlowSpill, pure=False) for _ in range(N)]
while len(a.data) < N:
await asyncio.sleep(0)

ts = [monotonic()]
while a.data.fast:
await asyncio.sleep(0)
ts.append(monotonic())

# 100 tasks taking 0.01s to pickle each = 2s to spill everything
# (this is because everything is pickled twice:
# https://github.com/dask/distributed/issues/1371).
# We should regain control of the event loop every 0.5s.
c = Counter(round(t1 - t0, 1) for t0, t1 in zip(ts, ts[1:]))
# Depending on the implementation of WorkerMemoryMonitor._maybe_spill:
# if it calls sleep(0) every 0.5s:
# {0.0: 315, 0.5: 4}
# if it calls sleep(0) after spilling each key:
# {0.0: 233}
# if it never yields:
# {0.0: 359, 2.0: 1}
# Make sure we remain in the first use case.
assert 1 < sum(v for k, v in c.items() if 0.5 <= k <= 1.9), dict(c)
assert not any(v for k, v in c.items() if k >= 2.0), dict(c)

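The Counter at the end of the previous test buckets the gaps between consecutive event-loop wake-ups; on fabricated timestamps (illustrative only) the idea looks like this:

from collections import Counter

ts = [0.00, 0.01, 0.02, 0.55, 0.56, 0.57, 1.10]   # hypothetical monotonic() samples
gaps = Counter(round(t1 - t0, 1) for t0, t1 in zip(ts, ts[1:]))
# gaps == Counter({0.0: 4, 0.5: 2}): mostly tight polling, with ~0.5 s stalls while
# a batch of keys is spilled. A single multi-second gap would instead mean the
# event loop was never released for the whole spill cycle.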

@pytest.mark.parametrize(
"cls,name,value",
[
5 changes: 5 additions & 0 deletions distributed/utils_test.py
@@ -77,6 +77,7 @@
except ImportError:
pass

from pytest_timeout import is_debugging

logger = logging.getLogger(__name__)

@@ -798,6 +799,8 @@ async def test_foo():
"timeout should always be set and it should be smaller than the global one from"
"pytest-timeout"
)
if is_debugging():
timeout = 3600

def _(func):
def test_func(*args, **kwargs):
@@ -956,6 +959,8 @@ async def test_foo(scheduler, worker1, worker2, pytest_fixture_a, pytest_fixture
"timeout should always be set and it should be smaller than the global one from"
"pytest-timeout"
)
if is_debugging():
timeout = 3600

scheduler_kwargs = merge(
{"dashboard": False, "dashboard_address": ":0"}, scheduler_kwargs
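
The utils_test.py change above relaxes test timeouts while a debugger is attached; is_debugging comes from pytest-timeout. A rough sketch of the resulting behaviour (hypothetical helper, not the actual decorator code):

from pytest_timeout import is_debugging

def effective_timeout(requested: float) -> float:
    # A 30 s test timeout would kill the session while you are stepping through
    # code, so the two decorators patched here stretch it to an hour instead.
    return 3600 if is_debugging() else requested
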
5 changes: 4 additions & 1 deletion distributed/worker.py
@@ -2763,6 +2763,9 @@ def stimulus_story(
return [ev for ev in self.stimulus_log if getattr(ev, "key", None) in keys]

def _ensure_communicating(self) -> RecsInstrs:
if self.status != Status.running:
return {}, []

stimulus_id = f"ensure-communicating-{time()}"
skipped_worker_in_flight_or_busy = []

Expand Down Expand Up @@ -3528,7 +3531,7 @@ async def _maybe_deserialize_task(
raise

def _ensure_computing(self) -> RecsInstrs:
- if self.status in (Status.paused, Status.closing_gracefully):
+ if self.status != Status.running:
return {}, []

recs: Recs = {}
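
The worker.py change tightens both guards to a single "only when running" check. A hedged illustration of why that is broader than listing the bad states (Status members come from distributed.core; the helper name is made up):

from distributed.core import Status

def may_start_new_activity(status: Status) -> bool:
    # Same intent as the old `not in (paused, closing_gracefully)` check,
    # but it also stays False for states such as closing, closed or failed.
    return status == Status.running
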
22 changes: 20 additions & 2 deletions distributed/worker_memory.py
@@ -39,6 +39,7 @@

from distributed import system
from distributed.core import Status
from distributed.metrics import monotonic
from distributed.spill import ManualEvictProto, SpillBuffer
from distributed.utils import log_errors
from distributed.utils_perf import ThrottledGC
@@ -234,6 +235,8 @@ async def _maybe_spill(self, worker: Worker, memory: int) -> None:
)
count = 0
need = memory - target
last_checked_for_pause = last_yielded = monotonic()

while memory > target:
if not data.fast:
logger.warning(
@@ -255,7 +258,6 @@ async def _maybe_spill(self, worker: Worker, memory: int) -> None:

total_spilled += weight
count += 1
- await asyncio.sleep(0)

memory = worker.monitor.get_process_memory()
if total_spilled > need and memory > target:
@@ -265,7 +267,23 @@ async def _maybe_spill(self, worker: Worker, memory: int) -> None:
self._throttled_gc.collect()
memory = worker.monitor.get_process_memory()

- self._maybe_pause_or_unpause(worker, memory)
now = monotonic()

# Spilling may potentially take multiple seconds; we may pass the pause
# threshold in the meantime.
if now - last_checked_for_pause > self.memory_monitor_interval:
self._maybe_pause_or_unpause(worker, memory)
last_checked_for_pause = now

# Increase spilling aggressiveness when the fast buffer is filled with a lot
# of small values. This artificially chokes the rest of the event loop -
# namely, the reception of new data from other workers. While this is
# somewhat of an ugly hack, DO NOT tweak this without a thorough cycle of
# stress testing. See: https://github.com/dask/distributed/issues/6110.
if now - last_yielded > 0.5:
await asyncio.sleep(0)
last_yielded = monotonic()

if count:
logger.debug(
"Moved %d tasks worth %s to disk",
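
The _maybe_spill change above replaces the per-key `await asyncio.sleep(0)` with a throttled yield. The pattern in isolation, as a generic sketch (not the actual _maybe_spill code):

import asyncio
from time import monotonic, sleep

async def spill_like_loop(items):
    last_yielded = monotonic()
    for item in items:
        sleep(0.01)                    # stand-in for a slow, synchronous evict/pickle
        if monotonic() - last_yielded > 0.5:
            await asyncio.sleep(0)     # hand control back so heartbeats and comms run
            last_yielded = monotonic()

asyncio.run(spill_like_loop(range(100)))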
