Skip to content

Commit

Permalink
Increase close timeout of Nanny in LocalCUDACluster (#1260)
Browse files Browse the repository at this point in the history
Tests in CI have been failing more often, but those errors can't be reproduced locally. This is possibly related to `Nanny`'s internal mechanism to establish timeouts to kill processes, perhaps due to higher load on the servers, tasks take longer and killing processes takes into account the overall time taken to establish a timeout, which is then drastically reduced leaving little time to actually shutdown processes. It is also not possible to programatically set a different timeout given existing Distributed's API, which currently calls `close()` without arguments in `SpecCluster._correct_state_internal()`.

Given the limitations described above, a new class is added by this change with the sole purpose of rewriting the timeout for `Nanny.close()` method with an increased value, and then use the new class when launching `LocalCUDACluster` via the `worker_class` argument.

Authors:
  - Peter Andreas Entschev (https://github.com/pentschev)

Approvers:
  - Benjamin Zaitlen (https://github.com/quasiben)

URL: #1260
  • Loading branch information
pentschev authored Oct 12, 2023
1 parent 2ffd1d6 commit 48de0c5
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 8 deletions.
22 changes: 14 additions & 8 deletions dask_cuda/local_cuda_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import logging
import os
import warnings
from functools import partial
from typing import Literal

import dask
from distributed import LocalCluster, Nanny, Worker
Expand All @@ -23,6 +25,13 @@
)


class IncreasedCloseTimeoutNanny(Nanny):
async def close( # type:ignore[override]
self, timeout: float = 10.0, reason: str = "nanny-close"
) -> Literal["OK"]:
return await super().close(timeout=timeout, reason=reason)


class LoggedWorker(Worker):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand All @@ -32,7 +41,7 @@ async def start(self):
self.data.set_address(self.address)


class LoggedNanny(Nanny):
class LoggedNanny(IncreasedCloseTimeoutNanny):
def __init__(self, *args, **kwargs):
super().__init__(*args, worker_class=LoggedWorker, **kwargs)

Expand Down Expand Up @@ -333,13 +342,10 @@ def __init__(
enable_rdmacm=enable_rdmacm,
)

if worker_class is not None:
from functools import partial

worker_class = partial(
LoggedNanny if log_spilling is True else Nanny,
worker_class=worker_class,
)
worker_class = partial(
LoggedNanny if log_spilling is True else IncreasedCloseTimeoutNanny,
worker_class=worker_class,
)

self.pre_import = pre_import

Expand Down
7 changes: 7 additions & 0 deletions dask_cuda/tests/test_explicit_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import dask_cuda
from dask_cuda.explicit_comms import comms
from dask_cuda.explicit_comms.dataframe.shuffle import shuffle as explicit_comms_shuffle
from dask_cuda.local_cuda_cluster import IncreasedCloseTimeoutNanny

mp = mp.get_context("spawn") # type: ignore
ucp = pytest.importorskip("ucp")
Expand All @@ -35,6 +36,7 @@ def _test_local_cluster(protocol):
dashboard_address=None,
n_workers=4,
threads_per_worker=1,
worker_class=IncreasedCloseTimeoutNanny,
processes=True,
) as cluster:
with Client(cluster) as client:
Expand All @@ -56,6 +58,7 @@ def _test_dataframe_merge_empty_partitions(nrows, npartitions):
dashboard_address=None,
n_workers=npartitions,
threads_per_worker=1,
worker_class=IncreasedCloseTimeoutNanny,
processes=True,
) as cluster:
with Client(cluster):
Expand Down Expand Up @@ -102,6 +105,7 @@ def _test_dataframe_shuffle(backend, protocol, n_workers, _partitions):
dashboard_address=None,
n_workers=n_workers,
threads_per_worker=1,
worker_class=IncreasedCloseTimeoutNanny,
processes=True,
) as cluster:
with Client(cluster) as client:
Expand Down Expand Up @@ -204,6 +208,7 @@ def check_shuffle():
dashboard_address=None,
n_workers=2,
threads_per_worker=1,
worker_class=IncreasedCloseTimeoutNanny,
processes=True,
) as cluster:
with Client(cluster):
Expand All @@ -221,6 +226,7 @@ def _test_dataframe_shuffle_merge(backend, protocol, n_workers):
dashboard_address=None,
n_workers=n_workers,
threads_per_worker=1,
worker_class=IncreasedCloseTimeoutNanny,
processes=True,
) as cluster:
with Client(cluster):
Expand Down Expand Up @@ -327,6 +333,7 @@ def test_lock_workers():
dashboard_address=None,
n_workers=4,
threads_per_worker=5,
worker_class=IncreasedCloseTimeoutNanny,
processes=True,
) as cluster:
ps = []
Expand Down

0 comments on commit 48de0c5

Please sign in to comment.