Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assigning Deterministic rank to Dask Workers Based on CUDA_VISIBLE_DEVICES (branch-23.06) #1587

Merged
merged 1 commit into from
Jun 12, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 74 additions & 5 deletions python/raft-dask/raft_dask/common/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
#

import logging
import os
import time
import uuid
import warnings
from collections import OrderedDict
from collections import Counter, OrderedDict

from dask.distributed import default_client
from dask_cuda.utils import nvml_device_index

from pylibraft.common.handle import Handle

Expand Down Expand Up @@ -155,7 +157,7 @@ def worker_info(self, workers):
Builds a dictionary of { (worker_address, worker_port) :
(worker_rank, worker_port ) }
"""
ranks = _func_worker_ranks(workers)
ranks = _func_worker_ranks(self.client)
ports = (
_func_ucp_ports(self.client, workers) if self.comms_p2p else None
)
Expand Down Expand Up @@ -686,8 +688,75 @@ def _func_ucp_ports(client, workers):
return client.run(_func_ucp_listener_port, workers=workers)


def _func_worker_ranks(workers):
def _func_worker_ranks(client):
"""
Builds a dictionary of { (worker_address, worker_port) : worker_rank }
For each worker connected to the client,
compute a global rank which is the sum
of the NVML device index and the worker rank offset.

Parameters
----------
client (object): Dask client object.
"""
ranks = client.run(_get_nvml_device_index)
worker_ips = [_get_worker_ip(worker_address) for worker_address in ranks]
worker_ip_offset_dict = _get_rank_offset_across_nodes(worker_ips)
return _append_rank_offset(ranks, worker_ip_offset_dict)


def _get_nvml_device_index():
"""
Return NVML device index based on environment variable
'CUDA_VISIBLE_DEVICES'.
"""
CUDA_VISIBLE_DEVICES = os.getenv("CUDA_VISIBLE_DEVICES")
return nvml_device_index(0, CUDA_VISIBLE_DEVICES)


def _get_worker_ip(worker_address):
"""
Extract the worker IP address from the worker address string.

Parameters
----------
worker_address (str): Full address string of the worker
"""
return ":".join(worker_address.split(":")[0:2])


def _get_rank_offset_across_nodes(worker_ips):
"""
Get a dictionary of worker IP addresses mapped to the cumulative count of
their occurrences in the worker_ips list. The cumulative count serves as
the rank offset.

Parameters
----------
worker_ips (list): List of worker IP addresses.
"""
worker_count_dict = Counter(worker_ips)
worker_offset_dict = {}
current_offset = 0
for worker_ip, worker_count in worker_count_dict.items():
worker_offset_dict[worker_ip] = current_offset
current_offset += worker_count
return worker_offset_dict


def _append_rank_offset(rank_dict, worker_ip_offset_dict):
"""
For each worker address in the rank dictionary, add the
corresponding worker offset from the worker_ip_offset_dict
to the rank value.

Parameters
----------
rank_dict (dict): Dictionary of worker addresses mapped to their ranks.
worker_ip_offset_dict (dict): Dictionary of worker IP addresses
mapped to their offsets.
"""
return dict(list(zip(workers, range(len(workers)))))
for worker_ip, worker_offset in worker_ip_offset_dict.items():
for worker_address in rank_dict:
if worker_ip in worker_address:
rank_dict[worker_address] += worker_offset
return rank_dict