Skip to content

Commit

Permalink
Remove usage of Dask's get_worker
Browse files Browse the repository at this point in the history
In dask/distributed#7580, `get_worker` was modified to return the worker of
a task, so it can no longer be used inside functions submitted via
`client.run`; such functions must instead accept a `dask_worker` keyword
argument, which Dask supplies automatically, to obtain the worker.
  • Loading branch information
pentschev committed Mar 22, 2023
1 parent 7e328ff commit 608064b
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 55 deletions.
123 changes: 72 additions & 51 deletions python/raft-dask/raft_dask/common/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import warnings
from collections import OrderedDict

from dask.distributed import default_client, get_worker
from dask.distributed import default_client

from pylibraft.common.handle import Handle

Expand Down Expand Up @@ -242,7 +242,7 @@ def destroy(self):
self.ucx_initialized = False


def local_handle(sessionId, dask_worker=None):
    """
    Simple helper function for retrieving the local handle_t instance
    for a comms session on a worker.

    Parameters
    ----------
    sessionId : str
        session identifier from an initialized comms instance
    dask_worker : dask_worker object
        (Note: if called by client.run(), this is supplied by Dask
        and not the client)

    Returns
    -------
    handle : raft.Handle or None
        The handle for this session, or None if no handle has been
        built yet on this worker.
    """
    state = get_raft_comm_state(sessionId, dask_worker)
    # dict.get returns None when no handle has been built for this session
    return state.get("handle")


def get_raft_comm_state(sessionId, state_object=None):
def get_raft_comm_state(sessionId, state_object=None, dask_worker=None):
"""
Retrieves cuML comms state on the scheduler node, for the given sessionId,
creating a new session if it does not exist. If no session id is given,
Expand All @@ -271,13 +274,16 @@ def get_raft_comm_state(sessionId, state_object=None):
sessionId : SessionId value to retrieve from the dask_scheduler instances
state_object : Object (either Worker, or Scheduler) on which the raft
comm state will retrieved (or created)
dask_worker : dask_worker object
(Note: if called by client.run(), this is supplied by Dask
and not the client)
Returns
-------
session state : str
session state associated with sessionId
"""
state_object = state_object if state_object is not None else get_worker()
state_object = state_object if state_object is not None else dask_worker

if not hasattr(state_object, "_raft_comm_state"):
state_object._raft_comm_state = {}
Expand Down Expand Up @@ -308,13 +314,19 @@ def set_nccl_root(sessionId, state_object):
return raft_comm_state["nccl_uid"]


def get_ucx():
def get_ucx(dask_worker=None):
"""
A simple convenience wrapper to make sure UCP listener and
endpoints are only ever assigned once per worker.
Parameters
----------
dask_worker : dask_worker object
(Note: if called by client.run(), this is supplied by Dask
and not the client)
"""
raft_comm_state = get_raft_comm_state(
sessionId="ucp", state_object=get_worker()
sessionId="ucp", state_object=dask_worker
)
if "ucx" not in raft_comm_state:
raft_comm_state["ucx"] = UCX.get()
Expand Down Expand Up @@ -371,7 +383,7 @@ def _func_set_scheduler_as_nccl_root(sessionId, verbose, dask_scheduler):
return nccl_uid


def _func_set_worker_as_nccl_root(sessionId, verbose):
def _func_set_worker_as_nccl_root(sessionId, verbose, dask_worker=None):
"""
Creates a persistent nccl uniqueId on the scheduler node.
Expand All @@ -380,23 +392,25 @@ def _func_set_worker_as_nccl_root(sessionId, verbose):
----------
sessionId : Associated session to attach the unique ID to.
verbose : Indicates whether or not to emit additional information
dask_worker : dask_worker object
(Note: if called by client.run(), this is supplied by Dask
and not the client)
Return
------
uniqueId : byte str
NCCL uniqueId, associating this DASK worker as its root node.
"""
worker = get_worker()
if verbose:
worker.log_event(
dask_worker.log_event(
topic="info",
msg=f"Setting worker as NCCL root for session, '{sessionId}'",
)

nccl_uid = set_nccl_root(sessionId=sessionId, state_object=worker)
nccl_uid = set_nccl_root(sessionId=sessionId, state_object=dask_worker)

if verbose:
worker.log_event(
dask_worker.log_event(
topic="info", msg="Done setting scheduler as NCCL root."
)

Expand All @@ -408,54 +422,53 @@ def _func_ucp_listener_port():


async def _func_init_all(
    sessionId,
    uniqueId,
    comms_p2p,
    worker_info,
    verbose,
    streams_per_handle,
    dask_worker=None,
):
    """
    Initialize comms state on this worker: always NCCL, and additionally
    UCX endpoints plus a p2p-enabled handle when comms_p2p is True.

    Parameters
    ----------
    sessionId : str
        session identifier from an initialized comms instance
    uniqueId : array[byte]
        NCCL unique Id generated by the root node
    comms_p2p : bool
        whether to initialize UCX endpoints for point-to-point comms
    worker_info : dict
        Maps worker addresses to NCCL ranks & UCX ports
    verbose : bool
        print verbose logging output
    streams_per_handle : int
        number of internal streams to create on the handle
    dask_worker : dask_worker object
        (Note: if called by client.run(), this is supplied by Dask
        and not the client)
    """
    raft_comm_state = get_raft_comm_state(
        sessionId=sessionId, state_object=dask_worker
    )
    raft_comm_state["nccl_uid"] = uniqueId
    raft_comm_state["wid"] = worker_info[dask_worker.address]["rank"]
    raft_comm_state["nworkers"] = len(worker_info)

    if verbose:
        dask_worker.log_event(topic="info", msg="Initializing NCCL.")
        start = time.time()

    _func_init_nccl(sessionId, uniqueId, dask_worker=dask_worker)

    if verbose:
        elapsed = time.time() - start
        dask_worker.log_event(
            topic="info", msg=f"NCCL Initialization took: {elapsed} seconds."
        )

    if comms_p2p:
        if verbose:
            dask_worker.log_event(
                topic="info", msg="Initializing UCX Endpoints"
            )

        if verbose:
            start = time.time()
        await _func_ucp_create_endpoints(
            sessionId, worker_info, dask_worker=dask_worker
        )

        if verbose:
            elapsed = time.time() - start
            msg = (
                f"Done initializing UCX endpoints. "
                f"Took: {elapsed} seconds.\nBuilding handle."
            )
            dask_worker.log_event(topic="info", msg=msg)

        # BUG FIX: dask_worker must be forwarded, otherwise the callee
        # falls back to its dask_worker=None default and fails on
        # None.log_event / get_ucx(None).
        _func_build_handle_p2p(
            sessionId, streams_per_handle, verbose, dask_worker=dask_worker
        )

        if verbose:
            dask_worker.log_event(topic="info", msg="Done building handle.")

    else:
        _func_build_handle(
            sessionId, streams_per_handle, verbose, dask_worker=dask_worker
        )


def _func_init_nccl(sessionId, uniqueId, dask_worker=None):
    """
    Initialize ncclComm_t on worker

    Parameters
    ----------
    sessionId : str
        session identifier from an initialized comms instance
    uniqueId : array[byte]
        The NCCL unique Id generated from the client.
    dask_worker : dask_worker object
        (Note: if called by client.run(), this is supplied by Dask
        and not the client)
    """
    # state_object is supplied explicitly, so the redundant
    # dask_worker=dask_worker kwarg (which is only a fallback when
    # state_object is None) is dropped.
    raft_comm_state = get_raft_comm_state(
        sessionId=sessionId, state_object=dask_worker
    )
    wid = raft_comm_state["wid"]
    nWorkers = raft_comm_state["nworkers"]

    try:
        n = nccl()
        n.init(nWorkers, uniqueId, wid)
        raft_comm_state["nccl"] = n
    except Exception as e:
        dask_worker.log_event(
            topic="error", msg=f"An error occurred initializing NCCL: {e}."
        )
        raise


def _func_build_handle_p2p(
    sessionId, streams_per_handle, verbose, dask_worker=None
):
    """
    Builds a handle_t on the current worker given the initialized comms

    Parameters
    ----------
    sessionId : str
        id to reference state for current comms instance.
    streams_per_handle : int
        number of internal streams to create
    verbose : bool
        print verbose logging output
    dask_worker : dask_worker object
        (Note: if called by client.run(), this is supplied by Dask
        and not the client)
    """
    if verbose:
        dask_worker.log_event(topic="info", msg="Building p2p handle.")

    ucp_worker = get_ucx(dask_worker).get_worker()
    raft_comm_state = get_raft_comm_state(
        sessionId=sessionId, state_object=dask_worker
    )

    handle = Handle(n_streams=streams_per_handle)
    # NOTE(review): a few lines here were collapsed in the diff view;
    # these session-state lookups are reconstructed from the arguments
    # required by inject_comms_on_handle below — confirm against the
    # full file.
    eps = raft_comm_state["ucp_eps"]
    nccl_comm = raft_comm_state["nccl"]
    nWorkers = raft_comm_state["nworkers"]
    workerId = raft_comm_state["wid"]

    if verbose:
        dask_worker.log_event(topic="info", msg="Injecting comms on handle.")

    inject_comms_on_handle(
        handle, nccl_comm, ucp_worker, eps, nWorkers, workerId, verbose
    )

    if verbose:
        dask_worker.log_event(
            topic="info", msg="Finished injecting comms on handle."
        )

    raft_comm_state["handle"] = handle


def _func_build_handle(sessionId, streams_per_handle, verbose):
def _func_build_handle(sessionId, streams_per_handle, verbose, dask_worker=None):
"""
Builds a handle_t on the current worker given the initialized comms
Expand All @@ -535,17 +552,19 @@ def _func_build_handle(sessionId, streams_per_handle, verbose):
sessionId : str id to reference state for current comms instance.
streams_per_handle : int number of internal streams to create
verbose : bool print verbose logging output
dask_worker : dask_worker object
(Note: if called by client.run(), this is supplied by Dask
and not the client)
"""
worker = get_worker()
if verbose:
worker.log_event(
dask_worker.log_event(
topic="info", msg="Finished injecting comms on handle."
)

handle = Handle(n_streams=streams_per_handle)

raft_comm_state = get_raft_comm_state(
sessionId=sessionId, state_object=worker
sessionId=sessionId, state_object=dask_worker
)

workerId = raft_comm_state["wid"]
Expand All @@ -558,16 +577,16 @@ def _func_build_handle(sessionId, streams_per_handle, verbose):
raft_comm_state["handle"] = handle


def _func_store_initial_state(
    nworkers, sessionId, uniqueId, wid, dask_worker=None
):
    """
    Store the initial comms metadata (NCCL uid, rank, world size) in this
    worker's raft comm state for the given session.

    Parameters
    ----------
    nworkers : int
        total number of workers participating in the session
    sessionId : str
        session identifier from an initialized comms instance
    uniqueId : array[byte]
        The NCCL unique Id generated from the client.
    wid : int
        NCCL rank of this worker within the session
    dask_worker : dask_worker object
        (Note: if called by client.run(), this is supplied by Dask
        and not the client)
    """
    raft_comm_state = get_raft_comm_state(
        sessionId=sessionId, state_object=dask_worker
    )
    raft_comm_state["nccl_uid"] = uniqueId
    raft_comm_state["wid"] = wid
    raft_comm_state["nworkers"] = nworkers


async def _func_ucp_create_endpoints(sessionId, worker_info):
async def _func_ucp_create_endpoints(sessionId, worker_info, dask_worker):
"""
Runs on each worker to create ucp endpoints to all other workers
Expand All @@ -577,6 +596,9 @@ async def _func_ucp_create_endpoints(sessionId, worker_info):
uuid unique id for this instance
worker_info : dict
Maps worker addresses to NCCL ranks & UCX ports
dask_worker : dask_worker object
(Note: if called by client.run(), this is supplied by Dask
and not the client)
"""
eps = [None] * len(worker_info)
count = 1
Expand All @@ -590,34 +612,33 @@ async def _func_ucp_create_endpoints(sessionId, worker_info):
count += 1

raft_comm_state = get_raft_comm_state(
sessionId=sessionId, state_object=get_worker()
sessionId=sessionId, state_object=dask_worker
)
raft_comm_state["ucp_eps"] = eps


async def _func_destroy_all(sessionId, comms_p2p, verbose=False):
worker = get_worker()
async def _func_destroy_all(sessionId, comms_p2p, verbose=False, dask_worker=None):
if verbose:
worker.log_event(topic="info", msg="Destroying NCCL session state.")
dask_worker.log_event(topic="info", msg="Destroying NCCL session state.")

raft_comm_state = get_raft_comm_state(
sessionId=sessionId, state_object=worker
sessionId=sessionId, state_object=dask_worker
)
if "nccl" in raft_comm_state:
raft_comm_state["nccl"].destroy()
del raft_comm_state["nccl"]
if verbose:
worker.log_event(topic="info", msg="NCCL session state destroyed.")
dask_worker.log_event(topic="info", msg="NCCL session state destroyed.")
else:
if verbose:
worker.log_event(
dask_worker.log_event(
topic="warning",
msg=f"Session state for, '{sessionId}', "
f"does not contain expected 'nccl' element",
)

if verbose:
worker.log_event(
dask_worker.log_event(
topic="info",
msg=f"Destroying CUDA handle for sessionId, '{sessionId}.'",
)
Expand All @@ -626,7 +647,7 @@ async def _func_destroy_all(sessionId, comms_p2p, verbose=False):
del raft_comm_state["handle"]
else:
if verbose:
worker.log_event(
dask_worker.log_event(
topic="warning",
msg=f"Session state for, '{sessionId}', "
f"does not contain expected 'handle' element",
Expand Down
6 changes: 2 additions & 4 deletions python/raft-dask/raft_dask/test/test_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,9 @@ def func_check_uid_on_scheduler(sessionId, uniqueId, dask_scheduler):
)


def func_check_uid_on_worker(sessionId, uniqueId, dask_worker=None):
    """
    Check the NCCL uniqueId stored on this worker for the given session.

    dask_worker is injected by Dask when this function is submitted via
    client.run(); it is used as the state object instead of the removed
    get_worker() call.
    """
    return func_check_uid(
        sessionId=sessionId, uniqueId=uniqueId, state_object=dask_worker
    )


Expand Down

0 comments on commit 608064b

Please sign in to comment.