diff --git a/dask_cuda/tests/test_dgx.py b/dask_cuda/tests/test_dgx.py index 1fd6d0ebb..a7b79f327 100644 --- a/dask_cuda/tests/test_dgx.py +++ b/dask_cuda/tests/test_dgx.py @@ -128,7 +128,7 @@ def test_tcp_only(): def _test_ucx_infiniband_nvlink( - protocol, enable_infiniband, enable_nvlink, enable_rdmacm + skip_queue, protocol, enable_infiniband, enable_nvlink, enable_rdmacm ): cupy = pytest.importorskip("cupy") if protocol == "ucx": @@ -136,6 +136,14 @@ def _test_ucx_infiniband_nvlink( elif protocol == "ucxx": ucp = pytest.importorskip("ucxx") + if enable_infiniband and not any( + [at.startswith("rc") for at in ucp.get_active_transports()] + ): + skip_queue.put("No support available for 'rc' transport in UCX") + return + else: + skip_queue.put("ok") + if enable_infiniband is None and enable_nvlink is None and enable_rdmacm is None: enable_tcp_over_ucx = None cm_tls = ["all"] @@ -205,17 +213,16 @@ def check_ucx_options(): ) def test_ucx_infiniband_nvlink(protocol, params): if protocol == "ucx": - ucp = pytest.importorskip("ucp") + pytest.importorskip("ucp") elif protocol == "ucxx": - ucp = pytest.importorskip("ucxx") + pytest.importorskip("ucxx") - if params["enable_infiniband"]: - if not any([at.startswith("rc") for at in ucp.get_active_transports()]): - pytest.skip("No support available for 'rc' transport in UCX") + skip_queue = mp.Queue() p = mp.Process( target=_test_ucx_infiniband_nvlink, args=( + skip_queue, protocol, params["enable_infiniband"], params["enable_nvlink"], @@ -225,9 +232,8 @@ def test_ucx_infiniband_nvlink(protocol, params): p.start() p.join() - # Starting a new cluster on the same pytest process after an rdmacm cluster - # has been used may cause UCX-Py to complain about being already initialized. - if params["enable_rdmacm"] is True: - ucp.reset() + skip_msg = skip_queue.get() + if skip_msg != "ok": + pytest.skip(skip_msg) assert not p.exitcode