diff --git a/dask_cuda/tests/test_explicit_comms.py b/dask_cuda/tests/test_explicit_comms.py index 21b35e48..ed34f21f 100644 --- a/dask_cuda/tests/test_explicit_comms.py +++ b/dask_cuda/tests/test_explicit_comms.py @@ -1,6 +1,9 @@ import asyncio import multiprocessing as mp import os +import signal +import time +from functools import partial from unittest.mock import patch import numpy as np @@ -175,7 +178,7 @@ def test_dataframe_shuffle(backend, protocol, nworkers, _partitions): @pytest.mark.parametrize("in_cluster", [True, False]) -def test_dask_use_explicit_comms(in_cluster): +def _test_dask_use_explicit_comms(in_cluster): def check_shuffle(): """Check if shuffle use explicit-comms by search for keys named 'explicit-comms-shuffle' @@ -217,6 +220,31 @@ def check_shuffle(): check_shuffle() +@pytest.mark.parametrize("in_cluster", [True, False]) +def test_dask_use_explicit_comms(in_cluster): + def _timeout(process, function, timeout): + if process.is_alive(): + function() + timeout = time.time() + timeout + while process.is_alive() and time.time() < timeout: + time.sleep(0.1) + + p = mp.Process(target=_test_dask_use_explicit_comms, args=(in_cluster,)) + p.start() + + # Timeout before killing process + _timeout(p, lambda: None, 60.0) + + # Send SIGINT (i.e., KeyboardInterrupt) hoping we get a stack trace. + _timeout(p, partial(p._popen._send_signal, signal.SIGINT), 3.0) + + # SIGINT didn't work, kill process. + _timeout(p, p.kill, 3.0) + + assert not p.is_alive() + assert p.exitcode == 0 + + def _test_dataframe_shuffle_merge(backend, protocol, n_workers): if backend == "cudf": cudf = pytest.importorskip("cudf")