Don't stop Adaptive on error #8871

Merged · 6 commits · Sep 27, 2024
Changes from all commits
115 changes: 100 additions & 15 deletions distributed/deploy/adaptive.py
@@ -1,20 +1,39 @@
from __future__ import annotations

import logging
from collections.abc import Hashable
from datetime import timedelta
from inspect import isawaitable
from typing import TYPE_CHECKING, Any, Callable, Literal, cast

from tornado.ioloop import IOLoop

import dask.config
from dask.utils import parse_timedelta

from distributed.compatibility import PeriodicCallback
from distributed.core import Status
from distributed.deploy.adaptive_core import AdaptiveCore
from distributed.protocol import pickle
from distributed.utils import log_errors

if TYPE_CHECKING:
from typing_extensions import TypeAlias

from distributed.deploy.cluster import Cluster
from distributed.scheduler import WorkerState

logger = logging.getLogger(__name__)


AdaptiveStateState: TypeAlias = Literal[
"starting",
"running",
"stopped",
"inactive",
]


class Adaptive(AdaptiveCore):
'''
Adaptively allocate workers based on scheduler load. A superclass.
@@ -81,16 +100,21 @@ class Adaptive(AdaptiveCore):
specified in the dask config under the distributed.adaptive key.
'''

interval: float | None
periodic_callback: PeriodicCallback | None
#: Whether this adaptive strategy is periodically adapting
state: AdaptiveStateState

def __init__(
self,
cluster=None,
interval=None,
minimum=None,
maximum=None,
wait_count=None,
target_duration=None,
worker_key=None,
**kwargs,
cluster: Cluster,
interval: str | float | timedelta | None = None,
minimum: int | None = None,
maximum: int | float | None = None,
wait_count: int | None = None,
target_duration: str | float | timedelta | None = None,
worker_key: Callable[[WorkerState], Hashable] | None = None,
**kwargs: Any,
):
self.cluster = cluster
self.worker_key = worker_key
@@ -99,20 +123,78 @@ def __init__(
if interval is None:
interval = dask.config.get("distributed.adaptive.interval")
if minimum is None:
minimum = dask.config.get("distributed.adaptive.minimum")
minimum = cast(int, dask.config.get("distributed.adaptive.minimum"))
if maximum is None:
maximum = dask.config.get("distributed.adaptive.maximum")
maximum = cast(float, dask.config.get("distributed.adaptive.maximum"))
if wait_count is None:
wait_count = dask.config.get("distributed.adaptive.wait-count")
wait_count = cast(int, dask.config.get("distributed.adaptive.wait-count"))
if target_duration is None:
target_duration = dask.config.get("distributed.adaptive.target-duration")
target_duration = cast(
str, dask.config.get("distributed.adaptive.target-duration")
)

self.interval = parse_timedelta(interval, "seconds")
self.periodic_callback = None

if self.interval and self.cluster:
import weakref

self_ref = weakref.ref(self)

async def _adapt():
adaptive = self_ref()
if not adaptive or adaptive.state != "running":
return
if adaptive.cluster.status != Status.running:
adaptive.stop(reason="cluster-not-running")
return
try:
await adaptive.adapt()
except Exception:
logger.warning(
"Adaptive encountered an error while adapting", exc_info=True
)

self.periodic_callback = PeriodicCallback(_adapt, self.interval * 1000)
self.state = "starting"
self.loop.add_callback(self._start)
else:
self.state = "inactive"

self.target_duration = parse_timedelta(target_duration)

super().__init__(
minimum=minimum, maximum=maximum, wait_count=wait_count, interval=interval
super().__init__(minimum=minimum, maximum=maximum, wait_count=wait_count)

def _start(self) -> None:
if self.state != "starting":
return

assert self.periodic_callback is not None
self.periodic_callback.start()
self.state = "running"
logger.info(
"Adaptive scaling started: minimum=%s maximum=%s",
self.minimum,
self.maximum,
)

def stop(self, reason: str = "unknown") -> None:
if self.state in ("inactive", "stopped"):
return

if self.state == "running":
assert self.periodic_callback is not None
self.periodic_callback.stop()
logger.info(
"Adaptive scaling stopped: minimum=%s maximum=%s. Reason: %s",
self.minimum,
self.maximum,
reason,
)

self.periodic_callback = None
self.state = "stopped"

@property
def scheduler(self):
return self.cluster.scheduler_comm
@@ -210,6 +292,9 @@ async def scale_up(self, n):
def loop(self) -> IOLoop:
"""Override Adaptive.loop"""
if self.cluster:
return self.cluster.loop
return self.cluster.loop # type: ignore[return-value]
else:
return IOLoop.current()

def __del__(self):
self.stop(reason="adaptive-deleted")
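
Taken together, the adaptive.py changes move the periodic callback and its _start/stop lifecycle from AdaptiveCore onto Adaptive, gate each tick on state == "running" and on the cluster's Status, and log errors raised by adapt() as warnings instead of stopping. A minimal sketch of the resulting lifecycle, assuming a build with this change installed and that a LocalCluster can start locally; the "manual-stop" reason string is a hypothetical example:

    from distributed import LocalCluster

    cluster = LocalCluster(n_workers=0)
    adaptive = cluster.adapt(minimum=0, maximum=10)  # returns an Adaptive instance

    # After this change, an exception inside adapt() is logged as
    # "Adaptive encountered an error while adapting" and the periodic
    # callback keeps running; it no longer tears itself down on error.

    adaptive.stop(reason="manual-stop")  # stopping is explicit and carries a reason
    cluster.close()

As before, unset arguments fall back to the distributed.adaptive.* dask config keys (interval, minimum, maximum, wait-count, target-duration) read in __init__.
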
116 changes: 20 additions & 96 deletions distributed/deploy/adaptive_core.py
@@ -2,37 +2,24 @@

import logging
import math
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from collections.abc import Iterable
from datetime import timedelta
from typing import TYPE_CHECKING, Literal, cast
from typing import TYPE_CHECKING, cast

import tlz as toolz
from tornado.ioloop import IOLoop

import dask.config
from dask.utils import parse_timedelta

from distributed.compatibility import PeriodicCallback
from distributed.metrics import time

if TYPE_CHECKING:
from typing_extensions import TypeAlias

from distributed.scheduler import WorkerState

logger = logging.getLogger(__name__)


AdaptiveStateState: TypeAlias = Literal[
"starting",
"running",
"stopped",
"inactive",
]


class AdaptiveCore:
class AdaptiveCore(ABC):
"""
The core logic for adaptive deployments, with none of the cluster details

@@ -91,54 +78,22 @@ class AdaptiveCore:
minimum: int
maximum: int | float
wait_count: int
interval: int | float
periodic_callback: PeriodicCallback | None
plan: set[WorkerState]
requested: set[WorkerState]
observed: set[WorkerState]
close_counts: defaultdict[WorkerState, int]
_adapting: bool
#: Whether this adaptive strategy is periodically adapting
_state: AdaptiveStateState
log: deque[tuple[float, dict]]
_adapting: bool

def __init__(
self,
minimum: int = 0,
maximum: int | float = math.inf,
wait_count: int = 3,
interval: str | int | float | timedelta = "1s",
):
if not isinstance(maximum, int) and not math.isinf(maximum):
raise TypeError(f"maximum must be int or inf; got {maximum}")
raise ValueError(f"maximum must be int or inf; got {maximum}")

self.minimum = minimum
self.maximum = maximum
self.wait_count = wait_count
self.interval = parse_timedelta(interval, "seconds")
self.periodic_callback = None

if self.interval:
import weakref

self_ref = weakref.ref(self)

async def _adapt():
core = self_ref()
if core:
await core.adapt()

self.periodic_callback = PeriodicCallback(_adapt, self.interval * 1000)
self._state = "starting"
self.loop.add_callback(self._start)
else:
self._state = "inactive"
try:
self.plan = set()
self.requested = set()
self.observed = set()
except Exception:
pass

# internal state
self.close_counts = defaultdict(int)
@@ -147,38 +102,22 @@ async def _adapt():
maxlen=dask.config.get("distributed.admin.low-level-log-length")
)

def _start(self) -> None:
if self._state != "starting":
return

assert self.periodic_callback is not None
self.periodic_callback.start()
self._state = "running"
logger.info(
"Adaptive scaling started: minimum=%s maximum=%s",
self.minimum,
self.maximum,
)

def stop(self) -> None:
if self._state in ("inactive", "stopped"):
return
@property
@abstractmethod
def plan(self) -> set[WorkerState]: ...

if self._state == "running":
assert self.periodic_callback is not None
self.periodic_callback.stop()
logger.info(
"Adaptive scaling stopped: minimum=%s maximum=%s",
self.minimum,
self.maximum,
)
@property
@abstractmethod
def requested(self) -> set[WorkerState]: ...

self.periodic_callback = None
self._state = "stopped"
@property
@abstractmethod
def observed(self) -> set[WorkerState]: ...

@abstractmethod
async def target(self) -> int:
"""The target number of workers that should exist"""
raise NotImplementedError()
...

async def workers_to_close(self, target: int) -> list:
"""
@@ -198,11 +137,11 @@ async def safe_target(self) -> int:

return n

async def scale_down(self, n: int) -> None:
raise NotImplementedError()
@abstractmethod
async def scale_down(self, n: int) -> None: ...

async def scale_up(self, workers: Iterable) -> None:
raise NotImplementedError()
@abstractmethod
async def scale_up(self, workers: Iterable) -> None: ...

async def recommendations(self, target: int) -> dict:
"""
@@ -270,20 +209,5 @@ async def adapt(self) -> None:
await self.scale_up(**recommendations)
if status == "down":
await self.scale_down(**recommendations)
except OSError:
if status != "down":
logger.error("Adaptive stopping due to error", exc_info=True)
self.stop()
else:
logger.error(
"Error during adaptive downscaling. Ignoring.", exc_info=True
)
finally:
self._adapting = False

def __del__(self):
self.stop()

@property
def loop(self) -> IOLoop:
return IOLoop.current()
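
Since AdaptiveCore is now an ABC, a concrete subclass must supply the plan/requested/observed properties plus the target, scale_up, and scale_down coroutines; the periodic-callback machinery is gone from this class entirely. A toy in-memory sketch of a conforming subclass, under the assumption that plain strings stand in for real WorkerState objects (every name besides the abstract API is hypothetical):

    from __future__ import annotations

    from collections.abc import Iterable

    from distributed.deploy.adaptive_core import AdaptiveCore


    class ToyAdaptive(AdaptiveCore):
        def __init__(self) -> None:
            super().__init__(minimum=0, maximum=4, wait_count=1)
            self._workers: set[str] = set()  # stands in for WorkerState objects

        @property
        def plan(self) -> set:
            return self._workers

        @property
        def requested(self) -> set:
            return self._workers

        @property
        def observed(self) -> set:
            return self._workers

        async def target(self) -> int:
            return 2  # pretend load analysis always wants two workers

        async def scale_up(self, workers: Iterable) -> None:
            self._workers |= set(workers)

        async def scale_down(self, n: int) -> None:
            # drop n arbitrary workers from the toy pool
            for _ in range(min(n, len(self._workers))):
                self._workers.pop()

The method signatures above copy the abstract declarations from this diff verbatim; note that Adaptive itself overrides scale_up(self, n) with a worker count, so the base-class annotations are best treated as loose.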