From 6c81a413d671c2b051c5fad99357d716c817995f Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Mon, 12 Jun 2023 08:47:37 -0700 Subject: [PATCH] Update comms.py (#1587) This PR is same as https://github.com/rapidsai/raft/pull/1573 but targetted for branch-23.06 as a hotfix CC: @rlratzel Previously, `dask-raft` non-deterministically maps a process to a GPU. In this PR, we assign a deterministic order to each worker based on the CUDA_VISIBLE_DEVICES environment variable. as NCCL>1.11 expects a process with `rank r` to be mapped to `r % num_gpus_per_node` . This fixes https://github.com/rapidsai/cugraph/issues/3478 and this raft-test in MNMG setting https://github.com/rapidsai/raft/blob/c1a7b7c0e33b11d2e7ff3bc5014e3b410a2edd0d/python/raft-dask/raft_dask/test/test_comms.py#L82-L84 Authors: - Vibhu Jawa (https://github.com/VibhuJawa) Approvers: - Rick Ratzel (https://github.com/rlratzel) - Corey J. Nolet (https://github.com/cjnolet) --- python/raft-dask/raft_dask/common/comms.py | 79 ++++++++++++++++++++-- 1 file changed, 74 insertions(+), 5 deletions(-) diff --git a/python/raft-dask/raft_dask/common/comms.py b/python/raft-dask/raft_dask/common/comms.py index ebe9a8dc4f..7a0b786ec4 100644 --- a/python/raft-dask/raft_dask/common/comms.py +++ b/python/raft-dask/raft_dask/common/comms.py @@ -14,12 +14,14 @@ # import logging +import os import time import uuid import warnings -from collections import OrderedDict +from collections import Counter, OrderedDict from dask.distributed import default_client +from dask_cuda.utils import nvml_device_index from pylibraft.common.handle import Handle @@ -155,7 +157,7 @@ def worker_info(self, workers): Builds a dictionary of { (worker_address, worker_port) : (worker_rank, worker_port ) } """ - ranks = _func_worker_ranks(workers) + ranks = _func_worker_ranks(self.client) ports = ( _func_ucp_ports(self.client, workers) if self.comms_p2p else None ) @@ -686,8 +688,75 @@ def _func_ucp_ports(client, workers): return client.run(_func_ucp_listener_port, workers=workers) -def _func_worker_ranks(workers): +def _func_worker_ranks(client): """ - Builds a dictionary of { (worker_address, worker_port) : worker_rank } + For each worker connected to the client, + compute a global rank which is the sum + of the NVML device index and the worker rank offset. + + Parameters + ---------- + client (object): Dask client object. + """ + ranks = client.run(_get_nvml_device_index) + worker_ips = [_get_worker_ip(worker_address) for worker_address in ranks] + worker_ip_offset_dict = _get_rank_offset_across_nodes(worker_ips) + return _append_rank_offset(ranks, worker_ip_offset_dict) + + +def _get_nvml_device_index(): + """ + Return NVML device index based on environment variable + 'CUDA_VISIBLE_DEVICES'. + """ + CUDA_VISIBLE_DEVICES = os.getenv("CUDA_VISIBLE_DEVICES") + return nvml_device_index(0, CUDA_VISIBLE_DEVICES) + + +def _get_worker_ip(worker_address): + """ + Extract the worker IP address from the worker address string. + + Parameters + ---------- + worker_address (str): Full address string of the worker + """ + return ":".join(worker_address.split(":")[0:2]) + + +def _get_rank_offset_across_nodes(worker_ips): + """ + Get a dictionary of worker IP addresses mapped to the cumulative count of + their occurrences in the worker_ips list. The cumulative count serves as + the rank offset. + + Parameters + ---------- + worker_ips (list): List of worker IP addresses. + """ + worker_count_dict = Counter(worker_ips) + worker_offset_dict = {} + current_offset = 0 + for worker_ip, worker_count in worker_count_dict.items(): + worker_offset_dict[worker_ip] = current_offset + current_offset += worker_count + return worker_offset_dict + + +def _append_rank_offset(rank_dict, worker_ip_offset_dict): + """ + For each worker address in the rank dictionary, add the + corresponding worker offset from the worker_ip_offset_dict + to the rank value. + + Parameters + ---------- + rank_dict (dict): Dictionary of worker addresses mapped to their ranks. + worker_ip_offset_dict (dict): Dictionary of worker IP addresses + mapped to their offsets. """ - return dict(list(zip(workers, range(len(workers))))) + for worker_ip, worker_offset in worker_ip_offset_dict.items(): + for worker_address in rank_dict: + if worker_ip in worker_address: + rank_dict[worker_address] += worker_offset + return rank_dict