diff --git a/dask_cuda/local_cuda_cluster.py b/dask_cuda/local_cuda_cluster.py index 324484331..ef15dcce3 100644 --- a/dask_cuda/local_cuda_cluster.py +++ b/dask_cuda/local_cuda_cluster.py @@ -2,6 +2,8 @@ import logging import os import warnings +from functools import partial +from typing import Literal import dask from distributed import LocalCluster, Nanny, Worker @@ -23,6 +25,13 @@ ) +class IncreasedCloseTimeoutNanny(Nanny): + async def close( # type:ignore[override] + self, timeout: float = 10.0, reason: str = "nanny-close" + ) -> Literal["OK"]: + return await super().close(timeout=timeout, reason=reason) + + class LoggedWorker(Worker): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -32,7 +41,7 @@ async def start(self): self.data.set_address(self.address) -class LoggedNanny(Nanny): +class LoggedNanny(IncreasedCloseTimeoutNanny): def __init__(self, *args, **kwargs): super().__init__(*args, worker_class=LoggedWorker, **kwargs) @@ -333,13 +342,10 @@ def __init__( enable_rdmacm=enable_rdmacm, ) - if worker_class is not None: - from functools import partial - - worker_class = partial( - LoggedNanny if log_spilling is True else Nanny, - worker_class=worker_class, - ) + worker_class = partial( + LoggedNanny if log_spilling is True else IncreasedCloseTimeoutNanny, + worker_class=worker_class, + ) self.pre_import = pre_import diff --git a/dask_cuda/tests/test_explicit_comms.py b/dask_cuda/tests/test_explicit_comms.py index ae4e3332c..d9cd6dfb2 100644 --- a/dask_cuda/tests/test_explicit_comms.py +++ b/dask_cuda/tests/test_explicit_comms.py @@ -17,6 +17,7 @@ import dask_cuda from dask_cuda.explicit_comms import comms from dask_cuda.explicit_comms.dataframe.shuffle import shuffle as explicit_comms_shuffle +from dask_cuda.local_cuda_cluster import IncreasedCloseTimeoutNanny mp = mp.get_context("spawn") # type: ignore ucp = pytest.importorskip("ucp") @@ -35,6 +36,7 @@ def _test_local_cluster(protocol): dashboard_address=None, n_workers=4, threads_per_worker=1, + worker_class=IncreasedCloseTimeoutNanny, processes=True, ) as cluster: with Client(cluster) as client: @@ -56,6 +58,7 @@ def _test_dataframe_merge_empty_partitions(nrows, npartitions): dashboard_address=None, n_workers=npartitions, threads_per_worker=1, + worker_class=IncreasedCloseTimeoutNanny, processes=True, ) as cluster: with Client(cluster): @@ -102,6 +105,7 @@ def _test_dataframe_shuffle(backend, protocol, n_workers, _partitions): dashboard_address=None, n_workers=n_workers, threads_per_worker=1, + worker_class=IncreasedCloseTimeoutNanny, processes=True, ) as cluster: with Client(cluster) as client: @@ -204,6 +208,7 @@ def check_shuffle(): dashboard_address=None, n_workers=2, threads_per_worker=1, + worker_class=IncreasedCloseTimeoutNanny, processes=True, ) as cluster: with Client(cluster): @@ -221,6 +226,7 @@ def _test_dataframe_shuffle_merge(backend, protocol, n_workers): dashboard_address=None, n_workers=n_workers, threads_per_worker=1, + worker_class=IncreasedCloseTimeoutNanny, processes=True, ) as cluster: with Client(cluster): @@ -327,6 +333,7 @@ def test_lock_workers(): dashboard_address=None, n_workers=4, threads_per_worker=5, + worker_class=IncreasedCloseTimeoutNanny, processes=True, ) as cluster: ps = []