diff --git a/python/raft-dask/raft_dask/common/comms.py b/python/raft-dask/raft_dask/common/comms.py index ebe9a8dc4f..7a0b786ec4 100644 --- a/python/raft-dask/raft_dask/common/comms.py +++ b/python/raft-dask/raft_dask/common/comms.py @@ -14,12 +14,14 @@ # import logging +import os import time import uuid import warnings -from collections import OrderedDict +from collections import Counter, OrderedDict from dask.distributed import default_client +from dask_cuda.utils import nvml_device_index from pylibraft.common.handle import Handle @@ -155,7 +157,7 @@ def worker_info(self, workers): Builds a dictionary of { (worker_address, worker_port) : (worker_rank, worker_port ) } """ - ranks = _func_worker_ranks(workers) + ranks = _func_worker_ranks(self.client) ports = ( _func_ucp_ports(self.client, workers) if self.comms_p2p else None ) @@ -686,8 +688,75 @@ def _func_ucp_ports(client, workers): return client.run(_func_ucp_listener_port, workers=workers) -def _func_worker_ranks(workers): +def _func_worker_ranks(client): """ - Builds a dictionary of { (worker_address, worker_port) : worker_rank } + For each worker connected to the client, + compute a global rank which is the sum + of the NVML device index and the worker rank offset. + + Parameters + ---------- + client (object): Dask client object. + """ + ranks = client.run(_get_nvml_device_index) + worker_ips = [_get_worker_ip(worker_address) for worker_address in ranks] + worker_ip_offset_dict = _get_rank_offset_across_nodes(worker_ips) + return _append_rank_offset(ranks, worker_ip_offset_dict) + + +def _get_nvml_device_index(): + """ + Return NVML device index based on environment variable + 'CUDA_VISIBLE_DEVICES'. + """ + CUDA_VISIBLE_DEVICES = os.getenv("CUDA_VISIBLE_DEVICES") + return nvml_device_index(0, CUDA_VISIBLE_DEVICES) + + +def _get_worker_ip(worker_address): + """ + Extract the worker IP address from the worker address string. + + Parameters + ---------- + worker_address (str): Full address string of the worker + """ + return ":".join(worker_address.split(":")[0:2]) + + +def _get_rank_offset_across_nodes(worker_ips): + """ + Get a dictionary of worker IP addresses mapped to the cumulative count of + their occurrences in the worker_ips list. The cumulative count serves as + the rank offset. + + Parameters + ---------- + worker_ips (list): List of worker IP addresses. + """ + worker_count_dict = Counter(worker_ips) + worker_offset_dict = {} + current_offset = 0 + for worker_ip, worker_count in worker_count_dict.items(): + worker_offset_dict[worker_ip] = current_offset + current_offset += worker_count + return worker_offset_dict + + +def _append_rank_offset(rank_dict, worker_ip_offset_dict): + """ + For each worker address in the rank dictionary, add the + corresponding worker offset from the worker_ip_offset_dict + to the rank value. + + Parameters + ---------- + rank_dict (dict): Dictionary of worker addresses mapped to their ranks. + worker_ip_offset_dict (dict): Dictionary of worker IP addresses + mapped to their offsets. """ - return dict(list(zip(workers, range(len(workers))))) + for worker_ip, worker_offset in worker_ip_offset_dict.items(): + for worker_address in rank_dict: + if worker_ip in worker_address: + rank_dict[worker_address] += worker_offset + return rank_dict