From 608064b1c5024a6232ee0420bc7366fdc9e42810 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Wed, 22 Mar 2023 04:23:41 -0700 Subject: [PATCH] Remove usage of Dask's `get_worker` In dask/distributed#7580 get_worker was modified to return the worker of a task, thus it cannot be used by client.run, and we must now use dask_worker as the first argument to client.run to obtain the worker. --- python/raft-dask/raft_dask/common/comms.py | 123 ++++++++++-------- python/raft-dask/raft_dask/test/test_comms.py | 6 +- 2 files changed, 74 insertions(+), 55 deletions(-) diff --git a/python/raft-dask/raft_dask/common/comms.py b/python/raft-dask/raft_dask/common/comms.py index 56e40b98da..5cb9f65629 100644 --- a/python/raft-dask/raft_dask/common/comms.py +++ b/python/raft-dask/raft_dask/common/comms.py @@ -19,7 +19,7 @@ import warnings from collections import OrderedDict -from dask.distributed import default_client, get_worker +from dask.distributed import default_client from pylibraft.common.handle import Handle @@ -242,7 +242,7 @@ def destroy(self): self.ucx_initialized = False -def local_handle(sessionId): +def local_handle(sessionId, dask_worker=None): """ Simple helper function for retrieving the local handle_t instance for a comms session on a worker. @@ -251,16 +251,19 @@ def local_handle(sessionId): ---------- sessionId : str session identifier from an initialized comms instance + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) Returns ------- handle : raft.Handle or None """ - state = get_raft_comm_state(sessionId, get_worker()) + state = get_raft_comm_state(sessionId, dask_worker) return state["handle"] if "handle" in state else None -def get_raft_comm_state(sessionId, state_object=None): +def get_raft_comm_state(sessionId, state_object=None, dask_worker=None): """ Retrieves cuML comms state on the scheduler node, for the given sessionId, creating a new session if it does not exist. If no session id is given, @@ -271,13 +274,16 @@ def get_raft_comm_state(sessionId, state_object=None): sessionId : SessionId value to retrieve from the dask_scheduler instances state_object : Object (either Worker, or Scheduler) on which the raft comm state will retrieved (or created) + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) Returns ------- session state : str session state associated with sessionId """ - state_object = state_object if state_object is not None else get_worker() + state_object = state_object if state_object is not None else dask_worker if not hasattr(state_object, "_raft_comm_state"): state_object._raft_comm_state = {} @@ -308,13 +314,19 @@ def set_nccl_root(sessionId, state_object): return raft_comm_state["nccl_uid"] -def get_ucx(): +def get_ucx(dask_worker=None): """ A simple convenience wrapper to make sure UCP listener and endpoints are only ever assigned once per worker. + + Parameters + ---------- + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) """ raft_comm_state = get_raft_comm_state( - sessionId="ucp", state_object=get_worker() + sessionId="ucp", state_object=dask_worker ) if "ucx" not in raft_comm_state: raft_comm_state["ucx"] = UCX.get() @@ -371,7 +383,7 @@ def _func_set_scheduler_as_nccl_root(sessionId, verbose, dask_scheduler): return nccl_uid -def _func_set_worker_as_nccl_root(sessionId, verbose): +def _func_set_worker_as_nccl_root(sessionId, verbose, dask_worker=None): """ Creates a persistent nccl uniqueId on the scheduler node. @@ -380,23 +392,25 @@ def _func_set_worker_as_nccl_root(sessionId, verbose): ---------- sessionId : Associated session to attach the unique ID to. verbose : Indicates whether or not to emit additional information + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) Return ------ uniqueId : byte str NCCL uniqueId, associating this DASK worker as its root node. """ - worker = get_worker() if verbose: - worker.log_event( + dask_worker.log_event( topic="info", msg=f"Setting worker as NCCL root for session, '{sessionId}'", ) - nccl_uid = set_nccl_root(sessionId=sessionId, state_object=worker) + nccl_uid = set_nccl_root(sessionId=sessionId, state_object=dask_worker) if verbose: - worker.log_event( + dask_worker.log_event( topic="info", msg="Done setting scheduler as NCCL root." ) @@ -408,35 +422,34 @@ def _func_ucp_listener_port(): async def _func_init_all( - sessionId, uniqueId, comms_p2p, worker_info, verbose, streams_per_handle + sessionId, uniqueId, comms_p2p, worker_info, verbose, streams_per_handle, dask_worker=None ): - worker = get_worker() raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=worker + sessionId=sessionId, state_object=dask_worker ) raft_comm_state["nccl_uid"] = uniqueId - raft_comm_state["wid"] = worker_info[get_worker().address]["rank"] + raft_comm_state["wid"] = worker_info[dask_worker.address]["rank"] raft_comm_state["nworkers"] = len(worker_info) if verbose: - worker.log_event(topic="info", msg="Initializing NCCL.") + dask_worker.log_event(topic="info", msg="Initializing NCCL.") start = time.time() - _func_init_nccl(sessionId, uniqueId) + _func_init_nccl(sessionId, uniqueId, dask_worker=dask_worker) if verbose: elapsed = time.time() - start - worker.log_event( + dask_worker.log_event( topic="info", msg=f"NCCL Initialization took: {elapsed} seconds." ) if comms_p2p: if verbose: - worker.log_event(topic="info", msg="Initializing UCX Endpoints") + dask_worker.log_event(topic="info", msg="Initializing UCX Endpoints") if verbose: start = time.time() - await _func_ucp_create_endpoints(sessionId, worker_info) + await _func_ucp_create_endpoints(sessionId, worker_info, dask_worker=dask_worker) if verbose: elapsed = time.time() - start @@ -444,18 +457,18 @@ async def _func_init_all( f"Done initializing UCX endpoints." f"Took: {elapsed} seconds.\nBuilding handle." ) - worker.log_event(topic="info", msg=msg) + dask_worker.log_event(topic="info", msg=msg) _func_build_handle_p2p(sessionId, streams_per_handle, verbose) if verbose: - worker.log_event(topic="info", msg="Done building handle.") + dask_worker.log_event(topic="info", msg="Done building handle.") else: - _func_build_handle(sessionId, streams_per_handle, verbose) + _func_build_handle(sessionId, streams_per_handle, verbose, dask_worker=dask_worker) -def _func_init_nccl(sessionId, uniqueId): +def _func_init_nccl(sessionId, uniqueId, dask_worker=None): """ Initialize ncclComm_t on worker @@ -466,11 +479,13 @@ def _func_init_nccl(sessionId, uniqueId): uniqueId : array[byte] The NCCL unique Id generated from the client. + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) """ - worker = get_worker() raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=get_worker() + sessionId=sessionId, state_object=dask_worker, dask_worker=dask_worker ) wid = raft_comm_state["wid"] nWorkers = raft_comm_state["nworkers"] @@ -480,13 +495,13 @@ def _func_init_nccl(sessionId, uniqueId): n.init(nWorkers, uniqueId, wid) raft_comm_state["nccl"] = n except Exception as e: - worker.log_event( + dask_worker.log_event( topic="error", msg=f"An error occurred initializing NCCL: {e}." ) raise -def _func_build_handle_p2p(sessionId, streams_per_handle, verbose): +def _func_build_handle_p2p(sessionId, streams_per_handle, verbose, dask_worker=None): """ Builds a handle_t on the current worker given the initialized comms @@ -495,14 +510,16 @@ def _func_build_handle_p2p(sessionId, streams_per_handle, verbose): sessionId : str id to reference state for current comms instance. streams_per_handle : int number of internal streams to create verbose : bool print verbose logging output + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) """ - worker = get_worker() if verbose: - worker.log_event(topic="info", msg="Building p2p handle.") + dask_worker.log_event(topic="info", msg="Building p2p handle.") - ucp_worker = get_ucx().get_worker() + ucp_worker = get_ucx(dask_worker).get_worker() raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=worker + sessionId=sessionId, state_object=dask_worker ) handle = Handle(n_streams=streams_per_handle) @@ -512,21 +529,21 @@ def _func_build_handle_p2p(sessionId, streams_per_handle, verbose): workerId = raft_comm_state["wid"] if verbose: - worker.log_event(topic="info", msg="Injecting comms on handle.") + dask_worker.log_event(topic="info", msg="Injecting comms on handle.") inject_comms_on_handle( handle, nccl_comm, ucp_worker, eps, nWorkers, workerId, verbose ) if verbose: - worker.log_event( + dask_worker.log_event( topic="info", msg="Finished injecting comms on handle." ) raft_comm_state["handle"] = handle -def _func_build_handle(sessionId, streams_per_handle, verbose): +def _func_build_handle(sessionId, streams_per_handle, verbose, dask_worker=None): """ Builds a handle_t on the current worker given the initialized comms @@ -535,17 +552,19 @@ def _func_build_handle(sessionId, streams_per_handle, verbose): sessionId : str id to reference state for current comms instance. streams_per_handle : int number of internal streams to create verbose : bool print verbose logging output + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) """ - worker = get_worker() if verbose: - worker.log_event( + dask_worker.log_event( topic="info", msg="Finished injecting comms on handle." ) handle = Handle(n_streams=streams_per_handle) raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=worker + sessionId=sessionId, state_object=dask_worker ) workerId = raft_comm_state["wid"] @@ -558,16 +577,16 @@ def _func_build_handle(sessionId, streams_per_handle, verbose): raft_comm_state["handle"] = handle -def _func_store_initial_state(nworkers, sessionId, uniqueId, wid): +def _func_store_initial_state(nworkers, sessionId, uniqueId, wid, dask_worker=None): raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=get_worker() + sessionId=sessionId, state_object=dask_worker ) raft_comm_state["nccl_uid"] = uniqueId raft_comm_state["wid"] = wid raft_comm_state["nworkers"] = nworkers -async def _func_ucp_create_endpoints(sessionId, worker_info): +async def _func_ucp_create_endpoints(sessionId, worker_info, dask_worker): """ Runs on each worker to create ucp endpoints to all other workers @@ -577,6 +596,9 @@ async def _func_ucp_create_endpoints(sessionId, worker_info): uuid unique id for this instance worker_info : dict Maps worker addresses to NCCL ranks & UCX ports + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) """ eps = [None] * len(worker_info) count = 1 @@ -590,34 +612,33 @@ async def _func_ucp_create_endpoints(sessionId, worker_info): count += 1 raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=get_worker() + sessionId=sessionId, state_object=dask_worker ) raft_comm_state["ucp_eps"] = eps -async def _func_destroy_all(sessionId, comms_p2p, verbose=False): - worker = get_worker() +async def _func_destroy_all(sessionId, comms_p2p, verbose=False, dask_worker=None): if verbose: - worker.log_event(topic="info", msg="Destroying NCCL session state.") + dask_worker.log_event(topic="info", msg="Destroying NCCL session state.") raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=worker + sessionId=sessionId, state_object=dask_worker ) if "nccl" in raft_comm_state: raft_comm_state["nccl"].destroy() del raft_comm_state["nccl"] if verbose: - worker.log_event(topic="info", msg="NCCL session state destroyed.") + dask_worker.log_event(topic="info", msg="NCCL session state destroyed.") else: if verbose: - worker.log_event( + dask_worker.log_event( topic="warning", msg=f"Session state for, '{sessionId}', " f"does not contain expected 'nccl' element", ) if verbose: - worker.log_event( + dask_worker.log_event( topic="info", msg=f"Destroying CUDA handle for sessionId, '{sessionId}.'", ) @@ -626,7 +647,7 @@ async def _func_destroy_all(sessionId, comms_p2p, verbose=False): del raft_comm_state["handle"] else: if verbose: - worker.log_event( + dask_worker.log_event( topic="warning", msg=f"Session state for, '{sessionId}', " f"does not contain expected 'handle' element", diff --git a/python/raft-dask/raft_dask/test/test_comms.py b/python/raft-dask/raft_dask/test/test_comms.py index 74ec446e94..d84abad5e8 100644 --- a/python/raft-dask/raft_dask/test/test_comms.py +++ b/python/raft-dask/raft_dask/test/test_comms.py @@ -114,11 +114,9 @@ def func_check_uid_on_scheduler(sessionId, uniqueId, dask_scheduler): ) -def func_check_uid_on_worker(sessionId, uniqueId): - from dask.distributed import get_worker - +def func_check_uid_on_worker(sessionId, uniqueId, dask_worker=None): return func_check_uid( - sessionId=sessionId, uniqueId=uniqueId, state_object=get_worker() + sessionId=sessionId, uniqueId=uniqueId, state_object=dask_worker )