diff --git a/python/raft-dask/raft_dask/common/comms.py b/python/raft-dask/raft_dask/common/comms.py index bbae24a4f7..ed600410fd 100644 --- a/python/raft-dask/raft_dask/common/comms.py +++ b/python/raft-dask/raft_dask/common/comms.py @@ -693,6 +693,8 @@ def _func_worker_ranks(client, workers): For each worker connected to the client, compute a global rank which takes into account the NVML device index and the worker IP (group workers on same host and order by NVML device). + Note that the reason for sorting was nvbug 4149999 and is presumably + fixed after NCCL 2.19.3. Parameters ---------- diff --git a/python/raft-dask/raft_dask/test/test_comms.py b/python/raft-dask/raft_dask/test/test_comms.py index 68c9fee556..467cc8fe0e 100644 --- a/python/raft-dask/raft_dask/test/test_comms.py +++ b/python/raft-dask/raft_dask/test/test_comms.py @@ -246,6 +246,17 @@ def test_collectives(client, func, root_location): cb.destroy() +@pytest.mark.nccl +@pytest.mark.parametrize("subset", [-1, 1, slice(None, None, -2)]) +def test_comm_init_worker_subset(client, subset): + # Basic test that initializing a subset of workers is fine + cb = Comms(comms_p2p=True, verbose=True) + + workers = list(client.scheduler_info()["workers"].keys()) + workers = workers[subset] + cb.init(workers=workers) + + @pytest.mark.nccl def test_comm_split(client):