Skip to content

Commit

Permalink
Update comms.py (#1587)
Browse files Browse the repository at this point in the history
This PR is same as #1573 but targetted for branch-23.06 as a hotfix CC: @rlratzel 

Previously, `dask-raft` non-deterministically maps a process to a GPU. 

In this PR, we assign a deterministic order to each worker based on the CUDA_VISIBLE_DEVICES environment variable.
as NCCL>1.11 expects a process with `rank r` to be mapped to `r % num_gpus_per_node` . 


This  fixes rapidsai/cugraph#3478 
and this raft-test in MNMG setting  https://github.com/rapidsai/raft/blob/c1a7b7c0e33b11d2e7ff3bc5014e3b410a2edd0d/python/raft-dask/raft_dask/test/test_comms.py#L82-L84

Authors:
   - Vibhu Jawa (https://github.com/VibhuJawa)

Approvers:
   - Rick Ratzel (https://github.com/rlratzel)
   - Corey J. Nolet (https://github.com/cjnolet)
  • Loading branch information
VibhuJawa authored Jun 12, 2023
1 parent 4de0748 commit 6c81a41
Showing 1 changed file with 74 additions and 5 deletions.
79 changes: 74 additions & 5 deletions python/raft-dask/raft_dask/common/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
#

import logging
import os
import time
import uuid
import warnings
from collections import OrderedDict
from collections import Counter, OrderedDict

from dask.distributed import default_client
from dask_cuda.utils import nvml_device_index

from pylibraft.common.handle import Handle

Expand Down Expand Up @@ -155,7 +157,7 @@ def worker_info(self, workers):
Builds a dictionary of { (worker_address, worker_port) :
(worker_rank, worker_port ) }
"""
ranks = _func_worker_ranks(workers)
ranks = _func_worker_ranks(self.client)
ports = (
_func_ucp_ports(self.client, workers) if self.comms_p2p else None
)
Expand Down Expand Up @@ -686,8 +688,75 @@ def _func_ucp_ports(client, workers):
return client.run(_func_ucp_listener_port, workers=workers)


def _func_worker_ranks(workers):
def _func_worker_ranks(client):
"""
Builds a dictionary of { (worker_address, worker_port) : worker_rank }
For each worker connected to the client,
compute a global rank which is the sum
of the NVML device index and the worker rank offset.
Parameters
----------
client (object): Dask client object.
"""
ranks = client.run(_get_nvml_device_index)
worker_ips = [_get_worker_ip(worker_address) for worker_address in ranks]
worker_ip_offset_dict = _get_rank_offset_across_nodes(worker_ips)
return _append_rank_offset(ranks, worker_ip_offset_dict)


def _get_nvml_device_index():
"""
Return NVML device index based on environment variable
'CUDA_VISIBLE_DEVICES'.
"""
CUDA_VISIBLE_DEVICES = os.getenv("CUDA_VISIBLE_DEVICES")
return nvml_device_index(0, CUDA_VISIBLE_DEVICES)


def _get_worker_ip(worker_address):
"""
Extract the worker IP address from the worker address string.
Parameters
----------
worker_address (str): Full address string of the worker
"""
return ":".join(worker_address.split(":")[0:2])


def _get_rank_offset_across_nodes(worker_ips):
"""
Get a dictionary of worker IP addresses mapped to the cumulative count of
their occurrences in the worker_ips list. The cumulative count serves as
the rank offset.
Parameters
----------
worker_ips (list): List of worker IP addresses.
"""
worker_count_dict = Counter(worker_ips)
worker_offset_dict = {}
current_offset = 0
for worker_ip, worker_count in worker_count_dict.items():
worker_offset_dict[worker_ip] = current_offset
current_offset += worker_count
return worker_offset_dict


def _append_rank_offset(rank_dict, worker_ip_offset_dict):
"""
For each worker address in the rank dictionary, add the
corresponding worker offset from the worker_ip_offset_dict
to the rank value.
Parameters
----------
rank_dict (dict): Dictionary of worker addresses mapped to their ranks.
worker_ip_offset_dict (dict): Dictionary of worker IP addresses
mapped to their offsets.
"""
return dict(list(zip(workers, range(len(workers)))))
for worker_ip, worker_offset in worker_ip_offset_dict.items():
for worker_address in rank_dict:
if worker_ip in worker_address:
rank_dict[worker_address] += worker_offset
return rank_dict

0 comments on commit 6c81a41

Please sign in to comment.