Skip to content

Commit

Permalink
Try using contiguous rank to fix cuda_visible_devices (#1926)
Browse files Browse the repository at this point in the history
This PR attempts to solve rapidsai/cugraph#3889

Authors:
  - Vibhu Jawa (https://github.com/VibhuJawa)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1926
  • Loading branch information
VibhuJawa authored Oct 25, 2023
1 parent 53c2539 commit 3b87796
Showing 1 changed file with 46 additions and 6 deletions.
52 changes: 46 additions & 6 deletions python/raft-dask/raft_dask/common/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import time
import uuid
import warnings
from collections import Counter, OrderedDict
from collections import Counter, OrderedDict, defaultdict
from typing import Dict

from dask.distributed import default_client
from dask_cuda.utils import nvml_device_index
Expand Down Expand Up @@ -157,7 +158,7 @@ def worker_info(self, workers):
Builds a dictionary of { (worker_address, worker_port) :
(worker_rank, worker_port ) }
"""
ranks = _func_worker_ranks(self.client)
ranks = _func_worker_ranks(self.client, workers)
ports = (
_func_ucp_ports(self.client, workers) if self.comms_p2p else None
)
Expand Down Expand Up @@ -688,7 +689,7 @@ def _func_ucp_ports(client, workers):
return client.run(_func_ucp_listener_port, workers=workers)


def _func_worker_ranks(client):
def _func_worker_ranks(client, workers):
"""
For each worker connected to the client,
compute a global rank which is the sum
Expand All @@ -697,9 +698,16 @@ def _func_worker_ranks(client):
Parameters
----------
client (object): Dask client object.
"""
ranks = client.run(_get_nvml_device_index)
worker_ips = [_get_worker_ip(worker_address) for worker_address in ranks]
workers (list): List of worker addresses.
"""
# TODO: Add Test this function
# Running into build issues preventing testing
nvml_device_index_d = client.run(_get_nvml_device_index, workers=workers)
worker_ips = [
_get_worker_ip(worker_address)
for worker_address in nvml_device_index_d
]
ranks = _map_nvml_device_id_to_contiguous_range(nvml_device_index_d)
worker_ip_offset_dict = _get_rank_offset_across_nodes(worker_ips)
return _append_rank_offset(ranks, worker_ip_offset_dict)

Expand All @@ -724,6 +732,38 @@ def _get_worker_ip(worker_address):
return ":".join(worker_address.split(":")[0:2])


def _map_nvml_device_id_to_contiguous_range(nvml_device_index_d: dict) -> dict:
"""
For each worker address in nvml_device_index_d, map the corresponding
worker rank in the range(0, num_workers_per_node) where rank is decided
by the NVML device index. Worker with the lowest NVML device index gets
rank 0, and worker with the highest NVML device index gets rank
num_workers_per_node-1.
Parameters
----------
nvml_device_index_d : dict
Dictionary of worker addresses mapped to their nvml device index.
Returns
-------
dict
Updated dictionary with worker addresses mapped to their rank.
"""

rank_per_ip: Dict[str, int] = defaultdict(int)

# Sort by NVML index to ensure that the worker
# with the lowest NVML index gets rank 0.
for worker, _ in sorted(nvml_device_index_d.items(), key=lambda x: x[1]):
ip = _get_worker_ip(worker)

nvml_device_index_d[worker] = rank_per_ip[ip]
rank_per_ip[ip] += 1

return nvml_device_index_d


def _get_rank_offset_across_nodes(worker_ips):
"""
Get a dictionary of worker IP addresses mapped to the cumulative count of
Expand Down

0 comments on commit 3b87796

Please sign in to comment.