Skip to content

Commit

Permalink
Encapsulate spill buffer and memory_monitor (#5904)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky authored Mar 18, 2022
1 parent 5dd80f3 commit 2d3fddc
Show file tree
Hide file tree
Showing 26 changed files with 1,637 additions and 1,242 deletions.
6 changes: 6 additions & 0 deletions distributed/active_memory_manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
"""Implementation of the Active Memory Manager. This is a scheduler extension which
sends drop/replicate suggestions to the worker.
See also :mod:`distributed.worker_memory` and :mod:`distributed.spill`, which implement
spill/pause/terminate mechanics on the Worker side.
"""
from __future__ import annotations

import logging
Expand Down
9 changes: 5 additions & 4 deletions distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,10 @@ def set_thread_ident():

@property
def status(self):
return self._status
try:
return self._status
except AttributeError:
return Status.undefined

@status.setter
def status(self, new_status):
Expand Down Expand Up @@ -398,9 +401,7 @@ def port(self):
def identity(self) -> dict[str, str]:
return {"type": type(self).__name__, "id": self.id}

def _to_dict(
self, comm: Comm | None = None, *, exclude: Container[str] = ()
) -> dict:
def _to_dict(self, *, exclude: Container[str] = ()) -> dict:
"""Dictionary representation for debugging purposes.
Not type stable and not intended for roundtrips.
Expand Down
3 changes: 2 additions & 1 deletion distributed/deploy/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from distributed.nanny import Nanny
from distributed.scheduler import Scheduler
from distributed.security import Security
from distributed.worker import Worker, parse_memory_limit
from distributed.worker import Worker
from distributed.worker_memory import parse_memory_limit

logger = logging.getLogger(__name__)

Expand Down
5 changes: 5 additions & 0 deletions distributed/distributed-schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,11 @@ properties:
description: >-
Limit of number of bytes to be spilled on disk.
monitor-interval:
type: string
description: >-
Interval between checks for the spill, pause, and terminate thresholds
http:
type: object
description: Settings for Dask's embedded HTTP Server
Expand Down
4 changes: 4 additions & 0 deletions distributed/distributed.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ distributed:
# Set to false for no maximum.
max-spill: false

# Interval between checks for the spill, pause, and terminate thresholds.
# The target threshold is checked every time new data is inserted.
monitor-interval: 100ms

http:
routes:
- distributed.http.worker.prometheus
Expand Down
57 changes: 18 additions & 39 deletions distributed/nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
from inspect import isawaitable
from queue import Empty
from time import sleep as sync_sleep
from typing import TYPE_CHECKING, ClassVar, Literal
from typing import TYPE_CHECKING, ClassVar

import psutil
from tornado import gen
from tornado.ioloop import IOLoop, PeriodicCallback
from tornado.ioloop import IOLoop

import dask
from dask.system import CPU_COUNT
Expand Down Expand Up @@ -49,7 +49,12 @@
parse_ports,
silence_logging,
)
from distributed.worker import Worker, parse_memory_limit, run
from distributed.worker import Worker, run
from distributed.worker_memory import (
DeprecatedMemoryManagerAttribute,
DeprecatedMemoryMonitor,
NannyMemoryManager,
)

if TYPE_CHECKING:
from distributed.diagnostics.plugin import NannyPlugin
Expand Down Expand Up @@ -89,6 +94,7 @@ class Nanny(ServerNode):
_instances: ClassVar[weakref.WeakSet[Nanny]] = weakref.WeakSet()
process = None
status = Status.undefined
memory_manager: NannyMemoryManager

def __init__(
self,
Expand All @@ -103,7 +109,6 @@ def __init__(
services=None,
name=None,
memory_limit="auto",
memory_terminate_fraction: float | Literal[False] | None = None,
reconnect=True,
validate=False,
quiet=False,
Expand Down Expand Up @@ -192,7 +197,8 @@ def __init__(
config_environ = dask.config.get("distributed.nanny.environ", {})
if not isinstance(config_environ, dict):
raise TypeError(
f"distributed.nanny.environ configuration must be of type dict. Instead got {type(config_environ)}"
"distributed.nanny.environ configuration must be of type dict. "
f"Instead got {type(config_environ)}"
)
self.env = config_environ.copy()
for k in self.env:
Expand All @@ -213,19 +219,12 @@ def __init__(
self.worker_kwargs = worker_kwargs

self.contact_address = contact_address
self.memory_terminate_fraction = (
memory_terminate_fraction
if memory_terminate_fraction is not None
else dask.config.get("distributed.worker.memory.terminate")
)

self.services = services
self.name = name
self.quiet = quiet
self.auto_restart = True

self.memory_limit = parse_memory_limit(memory_limit, self.nthreads)

if silence_logs:
silence_logging(level=silence_logs)
self.silence_logs = silence_logs
Expand All @@ -250,10 +249,7 @@ def __init__(
)

self.scheduler = self.rpc(self.scheduler_addr)

if self.memory_limit:
pc = PeriodicCallback(self.memory_monitor, 100)
self.periodic_callbacks["memory"] = pc
self.memory_manager = NannyMemoryManager(self, memory_limit=memory_limit)

if (
not host
Expand All @@ -271,6 +267,11 @@ def __init__(
Nanny._instances.add(self)
self.status = Status.init

# Deprecated attributes; use Nanny.memory_manager.<name> instead
memory_limit = DeprecatedMemoryManagerAttribute()
memory_terminate_fraction = DeprecatedMemoryManagerAttribute()
memory_monitor = DeprecatedMemoryMonitor()

def __repr__(self):
return "<Nanny: %s, threads: %d>" % (self.worker_address, self.nthreads)

Expand Down Expand Up @@ -388,7 +389,7 @@ async def instantiate(self) -> Status:
services=self.services,
nanny=self.address,
name=self.name,
memory_limit=self.memory_limit,
memory_limit=self.memory_manager.memory_limit,
reconnect=self.reconnect,
resources=self.resources,
validate=self.validate,
Expand Down Expand Up @@ -502,28 +503,6 @@ def _psutil_process(self):

return self._psutil_process_obj

def memory_monitor(self):
"""Track worker's memory. Restart if it goes above terminate fraction"""
if self.status != Status.running:
return
if self.process is None or self.process.process is None:
return None
process = self.process.process

try:
proc = self._psutil_process
memory = proc.memory_info().rss
except (ProcessLookupError, psutil.NoSuchProcess, psutil.AccessDenied):
return
frac = memory / self.memory_limit

if self.memory_terminate_fraction and frac > self.memory_terminate_fraction:
logger.warning(
"Worker exceeded %d%% memory budget. Restarting",
100 * self.memory_terminate_fraction,
)
process.terminate()

def is_alive(self):
return self.process is not None and self.process.is_alive()

Expand Down
5 changes: 1 addition & 4 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@
from distributed.active_memory_manager import ActiveMemoryManagerExtension, RetireWorker
from distributed.batched import BatchedSend
from distributed.comm import (
Comm,
get_address_host,
normalize_address,
resolve_address,
Expand Down Expand Up @@ -4060,9 +4059,7 @@ def identity(self):
}
return d

def _to_dict(
self, comm: "Comm | None" = None, *, exclude: "Container[str]" = ()
) -> dict:
def _to_dict(self, *, exclude: "Container[str]" = ()) -> dict:
"""Dictionary representation for debugging purposes.
Not type stable and not intended for roundtrips.
Expand Down
46 changes: 40 additions & 6 deletions distributed/spill.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

import logging
import time
from collections.abc import Mapping, MutableMapping
from collections.abc import Mapping, MutableMapping, Sized
from contextlib import contextmanager
from functools import partial
from typing import Any, Literal, NamedTuple, cast
from typing import Any, Literal, NamedTuple, Protocol, cast

import zict
from packaging.version import parse as parse_version

import zict

from distributed.protocol import deserialize_bytes, serialize_bytelist
from distributed.sizeof import safe_sizeof

Expand All @@ -34,6 +35,36 @@ def __sub__(self, other: SpilledSize) -> SpilledSize: # type: ignore
return SpilledSize(self.memory - other.memory, self.disk - other.disk)


class ManualEvictProto(Protocol):
"""Duck-type API that a third-party alternative to SpillBuffer must respect (in
addition to MutableMapping) if it wishes to support spilling when the
``distributed.worker.memory.spill`` threshold is surpassed.
This is public API. At the moment of writing, Dask-CUDA implements this protocol in
the ProxifyHostFile class.
"""

@property
def fast(self) -> Sized | bool:
"""Access to fast memory. This is normally a MutableMapping, but for the purpose
of the manual eviction API it is just tested for emptiness to know if there is
anything to evict.
"""
... # pragma: nocover

def evict(self) -> int:
"""Manually evict a key/value pair from fast to slow memory.
Return size of the evicted value in fast memory.
If the eviction failed for whatever reason, return -1. This method must
guarantee that the key/value pair that caused the issue has been retained in
fast memory and that the problem has been logged internally.
This method never raises.
"""
... # pragma: nocover


# zict.Buffer[str, Any] requires zict >= 2.2.0
class SpillBuffer(zict.Buffer):
"""MutableMapping that automatically spills out dask key/value pairs to disk when
Expand Down Expand Up @@ -166,11 +197,14 @@ def __setitem__(self, key: str, value: Any) -> None:
assert key not in self.slow

def evict(self) -> int:
"""Manually evict the oldest key/value pair, even if target has not been reached.
Returns sizeof(value).
"""Implementation of :meth:`ManualEvictProto.evict`.
Manually evict the oldest key/value pair, even if target has not been
reached. Returns sizeof(value).
If the eviction failed (value failed to pickle, disk full, or max_spill
exceeded), return -1; the key/value pair that caused the issue will remain in
fast. This method never raises.
fast. The exception has been logged internally.
This method never raises.
"""
try:
with self.handle_errors(None):
Expand Down
2 changes: 1 addition & 1 deletion distributed/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
__all__ = ("memory_limit", "MEMORY_LIMIT")


def memory_limit():
def memory_limit() -> int:
"""Get the memory limit (in bytes) for this system.
Takes the minimum value from the following locations:
Expand Down
Loading

0 comments on commit 2d3fddc

Please sign in to comment.