From 48de0c5cb28d4a691ebebcdd0539226e74f4f69c Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 12 Oct 2023 21:01:22 +0200 Subject: [PATCH] Increase close timeout of `Nanny` in `LocalCUDACluster` (#1260) Tests in CI have been failing more often, but those errors can't be reproduced locally. This is possibly related to `Nanny`'s internal mechanism to establish timeouts to kill processes, perhaps due to higher load on the servers, tasks take longer and killing processes takes into account the overall time taken to establish a timeout, which is then drastically reduced leaving little time to actually shutdown processes. It is also not possible to programatically set a different timeout given existing Distributed's API, which currently calls `close()` without arguments in `SpecCluster._correct_state_internal()`. Given the limitations described above, a new class is added by this change with the sole purpose of rewriting the timeout for `Nanny.close()` method with an increased value, and then use the new class when launching `LocalCUDACluster` via the `worker_class` argument. Authors: - Peter Andreas Entschev (https://github.com/pentschev) Approvers: - Benjamin Zaitlen (https://github.com/quasiben) URL: https://github.com/rapidsai/dask-cuda/pull/1260 --- dask_cuda/local_cuda_cluster.py | 22 ++++++++++++++-------- dask_cuda/tests/test_explicit_comms.py | 7 +++++++ 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/dask_cuda/local_cuda_cluster.py b/dask_cuda/local_cuda_cluster.py index 324484331..ef15dcce3 100644 --- a/dask_cuda/local_cuda_cluster.py +++ b/dask_cuda/local_cuda_cluster.py @@ -2,6 +2,8 @@ import logging import os import warnings +from functools import partial +from typing import Literal import dask from distributed import LocalCluster, Nanny, Worker @@ -23,6 +25,13 @@ ) +class IncreasedCloseTimeoutNanny(Nanny): + async def close( # type:ignore[override] + self, timeout: float = 10.0, reason: str = "nanny-close" + ) -> Literal["OK"]: + return await super().close(timeout=timeout, reason=reason) + + class LoggedWorker(Worker): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -32,7 +41,7 @@ async def start(self): self.data.set_address(self.address) -class LoggedNanny(Nanny): +class LoggedNanny(IncreasedCloseTimeoutNanny): def __init__(self, *args, **kwargs): super().__init__(*args, worker_class=LoggedWorker, **kwargs) @@ -333,13 +342,10 @@ def __init__( enable_rdmacm=enable_rdmacm, ) - if worker_class is not None: - from functools import partial - - worker_class = partial( - LoggedNanny if log_spilling is True else Nanny, - worker_class=worker_class, - ) + worker_class = partial( + LoggedNanny if log_spilling is True else IncreasedCloseTimeoutNanny, + worker_class=worker_class, + ) self.pre_import = pre_import diff --git a/dask_cuda/tests/test_explicit_comms.py b/dask_cuda/tests/test_explicit_comms.py index ae4e3332c..d9cd6dfb2 100644 --- a/dask_cuda/tests/test_explicit_comms.py +++ b/dask_cuda/tests/test_explicit_comms.py @@ -17,6 +17,7 @@ import dask_cuda from dask_cuda.explicit_comms import comms from dask_cuda.explicit_comms.dataframe.shuffle import shuffle as explicit_comms_shuffle +from dask_cuda.local_cuda_cluster import IncreasedCloseTimeoutNanny mp = mp.get_context("spawn") # type: ignore ucp = pytest.importorskip("ucp") @@ -35,6 +36,7 @@ def _test_local_cluster(protocol): dashboard_address=None, n_workers=4, threads_per_worker=1, + worker_class=IncreasedCloseTimeoutNanny, processes=True, ) as cluster: with Client(cluster) as client: @@ -56,6 +58,7 @@ def _test_dataframe_merge_empty_partitions(nrows, npartitions): dashboard_address=None, n_workers=npartitions, threads_per_worker=1, + worker_class=IncreasedCloseTimeoutNanny, processes=True, ) as cluster: with Client(cluster): @@ -102,6 +105,7 @@ def _test_dataframe_shuffle(backend, protocol, n_workers, _partitions): dashboard_address=None, n_workers=n_workers, threads_per_worker=1, + worker_class=IncreasedCloseTimeoutNanny, processes=True, ) as cluster: with Client(cluster) as client: @@ -204,6 +208,7 @@ def check_shuffle(): dashboard_address=None, n_workers=2, threads_per_worker=1, + worker_class=IncreasedCloseTimeoutNanny, processes=True, ) as cluster: with Client(cluster): @@ -221,6 +226,7 @@ def _test_dataframe_shuffle_merge(backend, protocol, n_workers): dashboard_address=None, n_workers=n_workers, threads_per_worker=1, + worker_class=IncreasedCloseTimeoutNanny, processes=True, ) as cluster: with Client(cluster): @@ -327,6 +333,7 @@ def test_lock_workers(): dashboard_address=None, n_workers=4, threads_per_worker=5, + worker_class=IncreasedCloseTimeoutNanny, processes=True, ) as cluster: ps = []