Skip to content

Commit

Permalink
MAINT: Simplify NCCL worker rank identification
Browse files Browse the repository at this point in the history
This is a follow-up to gh-1926, since the rank sorting seemed
a bit hard to understand.
It does modify the logic in the sense that the hosts are now sorted
by IP as a way to group workers by host.  But I don't really think that
host sorting was ever a goal?

If the goal is really about being deterministic, then this should
be more (or at least more clearly) deterministic about the order of
worker IPs.
OTOH, if the NVML device order doesn't matter, we could just sort
the workers directly.

The original gh-1587 mentions:

    NCCL>1.11 expects a process with rank r to be mapped to r % num_gpus_per_node

which is something that neither approach seems to quite guarantee.
If such a requirement exists, I would want to do one of:
* Ensure we can guarantee this, but this requires initializing
  workers that are not involved in the operation.
* At least raise an error, because if NCCL ends up raising
  the error itself it will be very confusing.
  • Loading branch information
seberg committed Feb 13, 2024
1 parent 04a9c95 commit 0b69050
Showing 1 changed file with 10 additions and 81 deletions.
91 changes: 10 additions & 81 deletions python/raft-dask/raft_dask/common/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
import time
import uuid
import warnings
from collections import Counter, OrderedDict, defaultdict
from typing import Dict
from collections import OrderedDict

from dask.distributed import default_client
from dask_cuda.utils import nvml_device_index
Expand Down Expand Up @@ -691,9 +690,9 @@ def _func_ucp_ports(client, workers):

def _func_worker_ranks(client, workers):
"""
For each worker connected to the client,
compute a global rank which is the sum
of the NVML device index and the worker rank offset.
For each worker connected to the client, compute a global rank which takes
into account the NVML device index and the worker IP
(group workers on same host and order by NVML device).
Parameters
----------
Expand All @@ -703,13 +702,13 @@ def _func_worker_ranks(client, workers):
# TODO: Add a test for this function
# Running into build issues preventing testing
nvml_device_index_d = client.run(_get_nvml_device_index, workers=workers)
worker_ips = [
_get_worker_ip(worker_address)
for worker_address in nvml_device_index_d
# Sort workers first by IP and then by the nvml device index:
worker_info_list = [
(_get_worker_ip(worker), nvml_device_index, worker)
for worker, nvml_device_index in nvml_device_index_d.items()
]
ranks = _map_nvml_device_id_to_contiguous_range(nvml_device_index_d)
worker_ip_offset_dict = _get_rank_offset_across_nodes(worker_ips)
return _append_rank_offset(ranks, worker_ip_offset_dict)
worker_info_list.sort()
return {wi[2]: i for i, wi in enumerate(worker_info_list)}


def _get_nvml_device_index():
Expand All @@ -730,73 +729,3 @@ def _get_worker_ip(worker_address):
worker_address (str): Full address string of the worker
"""
return ":".join(worker_address.split(":")[0:2])


def _map_nvml_device_id_to_contiguous_range(nvml_device_index_d: dict) -> dict:
    """
    Replace each worker's NVML device index with a per-host rank.

    Workers are visited in ascending order of NVML device index and each
    one receives the next unused rank for its host (the host key is whatever
    ``_get_worker_ip`` returns for the worker address). On every host the
    ranks therefore run from 0 up to one less than the number of workers
    on that host.

    Parameters
    ----------
    nvml_device_index_d : dict
        Dictionary of worker addresses mapped to their nvml device index.
        The dictionary is modified in place.

    Returns
    -------
    dict
        The same dictionary object, with worker addresses mapped to
        their rank.
    """
    next_rank_for_host: Dict[str, int] = defaultdict(int)

    # Fix the visiting order up front (ascending NVML index); the values
    # that define this order are overwritten inside the loop.
    ordered_workers = sorted(nvml_device_index_d, key=nvml_device_index_d.get)

    for address in ordered_workers:
        host = _get_worker_ip(address)
        nvml_device_index_d[address] = next_rank_for_host[host]
        next_rank_for_host[host] += 1

    return nvml_device_index_d


def _get_rank_offset_across_nodes(worker_ips):
"""
Get a dictionary of worker IP addresses mapped to the cumulative count of
their occurrences in the worker_ips list. The cumulative count serves as
the rank offset.
Parameters
----------
worker_ips (list): List of worker IP addresses.
"""
worker_count_dict = Counter(worker_ips)
worker_offset_dict = {}
current_offset = 0
for worker_ip, worker_count in worker_count_dict.items():
worker_offset_dict[worker_ip] = current_offset
current_offset += worker_count
return worker_offset_dict


def _append_rank_offset(rank_dict, worker_ip_offset_dict):
"""
For each worker address in the rank dictionary, add the
corresponding worker offset from the worker_ip_offset_dict
to the rank value.
Parameters
----------
rank_dict (dict): Dictionary of worker addresses mapped to their ranks.
worker_ip_offset_dict (dict): Dictionary of worker IP addresses
mapped to their offsets.
"""
for worker_ip, worker_offset in worker_ip_offset_dict.items():
for worker_address in rank_dict:
if worker_ip in worker_address:
rank_dict[worker_address] += worker_offset
return rank_dict

0 comments on commit 0b69050

Please sign in to comment.