From 608064b1c5024a6232ee0420bc7366fdc9e42810 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Wed, 22 Mar 2023 04:23:41 -0700 Subject: [PATCH 1/5] Remove usage of Dask's `get_worker` In dask/distributed#7580 get_worker was modified to return the worker of a task, thus it cannot be used by client.run, and we must now use dask_worker as the first argument to client.run to obtain the worker. --- python/raft-dask/raft_dask/common/comms.py | 123 ++++++++++-------- python/raft-dask/raft_dask/test/test_comms.py | 6 +- 2 files changed, 74 insertions(+), 55 deletions(-) diff --git a/python/raft-dask/raft_dask/common/comms.py b/python/raft-dask/raft_dask/common/comms.py index 56e40b98da..5cb9f65629 100644 --- a/python/raft-dask/raft_dask/common/comms.py +++ b/python/raft-dask/raft_dask/common/comms.py @@ -19,7 +19,7 @@ import warnings from collections import OrderedDict -from dask.distributed import default_client, get_worker +from dask.distributed import default_client from pylibraft.common.handle import Handle @@ -242,7 +242,7 @@ def destroy(self): self.ucx_initialized = False -def local_handle(sessionId): +def local_handle(sessionId, dask_worker=None): """ Simple helper function for retrieving the local handle_t instance for a comms session on a worker. @@ -251,16 +251,19 @@ def local_handle(sessionId): ---------- sessionId : str session identifier from an initialized comms instance + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) Returns ------- handle : raft.Handle or None """ - state = get_raft_comm_state(sessionId, get_worker()) + state = get_raft_comm_state(sessionId, dask_worker) return state["handle"] if "handle" in state else None -def get_raft_comm_state(sessionId, state_object=None): +def get_raft_comm_state(sessionId, state_object=None, dask_worker=None): """ Retrieves cuML comms state on the scheduler node, for the given sessionId, creating a new session if it does not exist. If no session id is given, @@ -271,13 +274,16 @@ def get_raft_comm_state(sessionId, state_object=None): sessionId : SessionId value to retrieve from the dask_scheduler instances state_object : Object (either Worker, or Scheduler) on which the raft comm state will retrieved (or created) + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) Returns ------- session state : str session state associated with sessionId """ - state_object = state_object if state_object is not None else get_worker() + state_object = state_object if state_object is not None else dask_worker if not hasattr(state_object, "_raft_comm_state"): state_object._raft_comm_state = {} @@ -308,13 +314,19 @@ def set_nccl_root(sessionId, state_object): return raft_comm_state["nccl_uid"] -def get_ucx(): +def get_ucx(dask_worker=None): """ A simple convenience wrapper to make sure UCP listener and endpoints are only ever assigned once per worker. 
+ + Parameters + ---------- + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) """ raft_comm_state = get_raft_comm_state( - sessionId="ucp", state_object=get_worker() + sessionId="ucp", state_object=dask_worker ) if "ucx" not in raft_comm_state: raft_comm_state["ucx"] = UCX.get() @@ -371,7 +383,7 @@ def _func_set_scheduler_as_nccl_root(sessionId, verbose, dask_scheduler): return nccl_uid -def _func_set_worker_as_nccl_root(sessionId, verbose): +def _func_set_worker_as_nccl_root(sessionId, verbose, dask_worker=None): """ Creates a persistent nccl uniqueId on the scheduler node. @@ -380,23 +392,25 @@ def _func_set_worker_as_nccl_root(sessionId, verbose): ---------- sessionId : Associated session to attach the unique ID to. verbose : Indicates whether or not to emit additional information + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) Return ------ uniqueId : byte str NCCL uniqueId, associating this DASK worker as its root node. """ - worker = get_worker() if verbose: - worker.log_event( + dask_worker.log_event( topic="info", msg=f"Setting worker as NCCL root for session, '{sessionId}'", ) - nccl_uid = set_nccl_root(sessionId=sessionId, state_object=worker) + nccl_uid = set_nccl_root(sessionId=sessionId, state_object=dask_worker) if verbose: - worker.log_event( + dask_worker.log_event( topic="info", msg="Done setting scheduler as NCCL root." ) @@ -408,35 +422,34 @@ def _func_ucp_listener_port(): async def _func_init_all( - sessionId, uniqueId, comms_p2p, worker_info, verbose, streams_per_handle + sessionId, uniqueId, comms_p2p, worker_info, verbose, streams_per_handle, dask_worker=None ): - worker = get_worker() raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=worker + sessionId=sessionId, state_object=dask_worker ) raft_comm_state["nccl_uid"] = uniqueId - raft_comm_state["wid"] = worker_info[get_worker().address]["rank"] + raft_comm_state["wid"] = worker_info[dask_worker.address]["rank"] raft_comm_state["nworkers"] = len(worker_info) if verbose: - worker.log_event(topic="info", msg="Initializing NCCL.") + dask_worker.log_event(topic="info", msg="Initializing NCCL.") start = time.time() - _func_init_nccl(sessionId, uniqueId) + _func_init_nccl(sessionId, uniqueId, dask_worker=dask_worker) if verbose: elapsed = time.time() - start - worker.log_event( + dask_worker.log_event( topic="info", msg=f"NCCL Initialization took: {elapsed} seconds." ) if comms_p2p: if verbose: - worker.log_event(topic="info", msg="Initializing UCX Endpoints") + dask_worker.log_event(topic="info", msg="Initializing UCX Endpoints") if verbose: start = time.time() - await _func_ucp_create_endpoints(sessionId, worker_info) + await _func_ucp_create_endpoints(sessionId, worker_info, dask_worker=dask_worker) if verbose: elapsed = time.time() - start @@ -444,18 +457,18 @@ async def _func_init_all( f"Done initializing UCX endpoints." f"Took: {elapsed} seconds.\nBuilding handle." 
) - worker.log_event(topic="info", msg=msg) + dask_worker.log_event(topic="info", msg=msg) _func_build_handle_p2p(sessionId, streams_per_handle, verbose) if verbose: - worker.log_event(topic="info", msg="Done building handle.") + dask_worker.log_event(topic="info", msg="Done building handle.") else: - _func_build_handle(sessionId, streams_per_handle, verbose) + _func_build_handle(sessionId, streams_per_handle, verbose, dask_worker=dask_worker) -def _func_init_nccl(sessionId, uniqueId): +def _func_init_nccl(sessionId, uniqueId, dask_worker=None): """ Initialize ncclComm_t on worker @@ -466,11 +479,13 @@ def _func_init_nccl(sessionId, uniqueId): uniqueId : array[byte] The NCCL unique Id generated from the client. + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) """ - worker = get_worker() raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=get_worker() + sessionId=sessionId, state_object=dask_worker, dask_worker=dask_worker ) wid = raft_comm_state["wid"] nWorkers = raft_comm_state["nworkers"] @@ -480,13 +495,13 @@ def _func_init_nccl(sessionId, uniqueId): n.init(nWorkers, uniqueId, wid) raft_comm_state["nccl"] = n except Exception as e: - worker.log_event( + dask_worker.log_event( topic="error", msg=f"An error occurred initializing NCCL: {e}." ) raise -def _func_build_handle_p2p(sessionId, streams_per_handle, verbose): +def _func_build_handle_p2p(sessionId, streams_per_handle, verbose, dask_worker=None): """ Builds a handle_t on the current worker given the initialized comms @@ -495,14 +510,16 @@ def _func_build_handle_p2p(sessionId, streams_per_handle, verbose): sessionId : str id to reference state for current comms instance. streams_per_handle : int number of internal streams to create verbose : bool print verbose logging output + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) """ - worker = get_worker() if verbose: - worker.log_event(topic="info", msg="Building p2p handle.") + dask_worker.log_event(topic="info", msg="Building p2p handle.") - ucp_worker = get_ucx().get_worker() + ucp_worker = get_ucx(dask_worker).get_worker() raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=worker + sessionId=sessionId, state_object=dask_worker ) handle = Handle(n_streams=streams_per_handle) @@ -512,21 +529,21 @@ def _func_build_handle_p2p(sessionId, streams_per_handle, verbose): workerId = raft_comm_state["wid"] if verbose: - worker.log_event(topic="info", msg="Injecting comms on handle.") + dask_worker.log_event(topic="info", msg="Injecting comms on handle.") inject_comms_on_handle( handle, nccl_comm, ucp_worker, eps, nWorkers, workerId, verbose ) if verbose: - worker.log_event( + dask_worker.log_event( topic="info", msg="Finished injecting comms on handle." ) raft_comm_state["handle"] = handle -def _func_build_handle(sessionId, streams_per_handle, verbose): +def _func_build_handle(sessionId, streams_per_handle, verbose, dask_worker=None): """ Builds a handle_t on the current worker given the initialized comms @@ -535,17 +552,19 @@ def _func_build_handle(sessionId, streams_per_handle, verbose): sessionId : str id to reference state for current comms instance. 
streams_per_handle : int number of internal streams to create verbose : bool print verbose logging output + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) """ - worker = get_worker() if verbose: - worker.log_event( + dask_worker.log_event( topic="info", msg="Finished injecting comms on handle." ) handle = Handle(n_streams=streams_per_handle) raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=worker + sessionId=sessionId, state_object=dask_worker ) workerId = raft_comm_state["wid"] @@ -558,16 +577,16 @@ def _func_build_handle(sessionId, streams_per_handle, verbose): raft_comm_state["handle"] = handle -def _func_store_initial_state(nworkers, sessionId, uniqueId, wid): +def _func_store_initial_state(nworkers, sessionId, uniqueId, wid, dask_worker=None): raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=get_worker() + sessionId=sessionId, state_object=dask_worker ) raft_comm_state["nccl_uid"] = uniqueId raft_comm_state["wid"] = wid raft_comm_state["nworkers"] = nworkers -async def _func_ucp_create_endpoints(sessionId, worker_info): +async def _func_ucp_create_endpoints(sessionId, worker_info, dask_worker): """ Runs on each worker to create ucp endpoints to all other workers @@ -577,6 +596,9 @@ async def _func_ucp_create_endpoints(sessionId, worker_info): uuid unique id for this instance worker_info : dict Maps worker addresses to NCCL ranks & UCX ports + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) """ eps = [None] * len(worker_info) count = 1 @@ -590,34 +612,33 @@ async def _func_ucp_create_endpoints(sessionId, worker_info): count += 1 raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=get_worker() + sessionId=sessionId, state_object=dask_worker ) raft_comm_state["ucp_eps"] = eps -async def _func_destroy_all(sessionId, comms_p2p, verbose=False): - worker = get_worker() +async def _func_destroy_all(sessionId, comms_p2p, verbose=False, dask_worker=None): if verbose: - worker.log_event(topic="info", msg="Destroying NCCL session state.") + dask_worker.log_event(topic="info", msg="Destroying NCCL session state.") raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=worker + sessionId=sessionId, state_object=dask_worker ) if "nccl" in raft_comm_state: raft_comm_state["nccl"].destroy() del raft_comm_state["nccl"] if verbose: - worker.log_event(topic="info", msg="NCCL session state destroyed.") + dask_worker.log_event(topic="info", msg="NCCL session state destroyed.") else: if verbose: - worker.log_event( + dask_worker.log_event( topic="warning", msg=f"Session state for, '{sessionId}', " f"does not contain expected 'nccl' element", ) if verbose: - worker.log_event( + dask_worker.log_event( topic="info", msg=f"Destroying CUDA handle for sessionId, '{sessionId}.'", ) @@ -626,7 +647,7 @@ async def _func_destroy_all(sessionId, comms_p2p, verbose=False): del raft_comm_state["handle"] else: if verbose: - worker.log_event( + dask_worker.log_event( topic="warning", msg=f"Session state for, '{sessionId}', " f"does not contain expected 'handle' element", diff --git a/python/raft-dask/raft_dask/test/test_comms.py b/python/raft-dask/raft_dask/test/test_comms.py index 74ec446e94..d84abad5e8 100644 --- a/python/raft-dask/raft_dask/test/test_comms.py +++ b/python/raft-dask/raft_dask/test/test_comms.py @@ -114,11 +114,9 @@ def func_check_uid_on_scheduler(sessionId, uniqueId, 
dask_scheduler): ) -def func_check_uid_on_worker(sessionId, uniqueId): - from dask.distributed import get_worker - +def func_check_uid_on_worker(sessionId, uniqueId, dask_worker=None): return func_check_uid( - sessionId=sessionId, uniqueId=uniqueId, state_object=get_worker() + sessionId=sessionId, uniqueId=uniqueId, state_object=dask_worker ) From 6a38498a38c938f40821325962429d381254bca5 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Wed, 22 Mar 2023 05:58:15 -0700 Subject: [PATCH 2/5] Fix linting --- python/raft-dask/raft_dask/common/comms.py | 44 ++++++++++++++----- python/raft-dask/raft_dask/test/test_comms.py | 2 +- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/python/raft-dask/raft_dask/common/comms.py b/python/raft-dask/raft_dask/common/comms.py index 5cb9f65629..9292213f44 100644 --- a/python/raft-dask/raft_dask/common/comms.py +++ b/python/raft-dask/raft_dask/common/comms.py @@ -422,7 +422,13 @@ def _func_ucp_listener_port(): async def _func_init_all( - sessionId, uniqueId, comms_p2p, worker_info, verbose, streams_per_handle, dask_worker=None + sessionId, + uniqueId, + comms_p2p, + worker_info, + verbose, + streams_per_handle, + dask_worker=None, ): raft_comm_state = get_raft_comm_state( sessionId=sessionId, state_object=dask_worker @@ -445,11 +451,15 @@ async def _func_init_all( if comms_p2p: if verbose: - dask_worker.log_event(topic="info", msg="Initializing UCX Endpoints") + dask_worker.log_event( + topic="info", msg="Initializing UCX Endpoints" + ) if verbose: start = time.time() - await _func_ucp_create_endpoints(sessionId, worker_info, dask_worker=dask_worker) + await _func_ucp_create_endpoints( + sessionId, worker_info, dask_worker=dask_worker + ) if verbose: elapsed = time.time() - start @@ -465,7 +475,9 @@ async def _func_init_all( dask_worker.log_event(topic="info", msg="Done building handle.") else: - _func_build_handle(sessionId, streams_per_handle, verbose, dask_worker=dask_worker) + _func_build_handle( + sessionId, streams_per_handle, verbose, dask_worker=dask_worker + ) def _func_init_nccl(sessionId, uniqueId, dask_worker=None): @@ -501,7 +513,9 @@ def _func_init_nccl(sessionId, uniqueId, dask_worker=None): raise -def _func_build_handle_p2p(sessionId, streams_per_handle, verbose, dask_worker=None): +def _func_build_handle_p2p( + sessionId, streams_per_handle, verbose, dask_worker=None +): """ Builds a handle_t on the current worker given the initialized comms @@ -543,7 +557,9 @@ def _func_build_handle_p2p(sessionId, streams_per_handle, verbose, dask_worker=N raft_comm_state["handle"] = handle -def _func_build_handle(sessionId, streams_per_handle, verbose, dask_worker=None): +def _func_build_handle( + sessionId, streams_per_handle, verbose, dask_worker=None +): """ Builds a handle_t on the current worker given the initialized comms @@ -577,7 +593,9 @@ def _func_build_handle(sessionId, streams_per_handle, verbose, dask_worker=None) raft_comm_state["handle"] = handle -def _func_store_initial_state(nworkers, sessionId, uniqueId, wid, dask_worker=None): +def _func_store_initial_state( + nworkers, sessionId, uniqueId, wid, dask_worker=None +): raft_comm_state = get_raft_comm_state( sessionId=sessionId, state_object=dask_worker ) @@ -617,9 +635,13 @@ async def _func_ucp_create_endpoints(sessionId, worker_info, dask_worker): raft_comm_state["ucp_eps"] = eps -async def _func_destroy_all(sessionId, comms_p2p, verbose=False, dask_worker=None): +async def _func_destroy_all( + sessionId, comms_p2p, verbose=False, dask_worker=None +): if verbose: - 
dask_worker.log_event(topic="info", msg="Destroying NCCL session state.") + dask_worker.log_event( + topic="info", msg="Destroying NCCL session state." + ) raft_comm_state = get_raft_comm_state( sessionId=sessionId, state_object=dask_worker @@ -628,7 +650,9 @@ async def _func_destroy_all(sessionId, comms_p2p, verbose=False, dask_worker=Non raft_comm_state["nccl"].destroy() del raft_comm_state["nccl"] if verbose: - dask_worker.log_event(topic="info", msg="NCCL session state destroyed.") + dask_worker.log_event( + topic="info", msg="NCCL session state destroyed." + ) else: if verbose: dask_worker.log_event( diff --git a/python/raft-dask/raft_dask/test/test_comms.py b/python/raft-dask/raft_dask/test/test_comms.py index d84abad5e8..6820a7c43b 100644 --- a/python/raft-dask/raft_dask/test/test_comms.py +++ b/python/raft-dask/raft_dask/test/test_comms.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2022, NVIDIA CORPORATION. +# Copyright (c) 2019-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From ec1a8b5168e8eba93fa8479666f159ebd10fdbd0 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Wed, 22 Mar 2023 10:46:16 -0700 Subject: [PATCH 3/5] Pass missing `dask_worker` to `get_ucx` calls --- python/raft-dask/raft_dask/common/comms.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/raft-dask/raft_dask/common/comms.py b/python/raft-dask/raft_dask/common/comms.py index 9292213f44..934a0ea89e 100644 --- a/python/raft-dask/raft_dask/common/comms.py +++ b/python/raft-dask/raft_dask/common/comms.py @@ -417,8 +417,8 @@ def _func_set_worker_as_nccl_root(sessionId, verbose, dask_worker=None): return nccl_uid -def _func_ucp_listener_port(): - return get_ucx().listener_port() +def _func_ucp_listener_port(dask_worker=None): + return get_ucx(dask_worker=dask_worker).listener_port() async def _func_init_all( @@ -469,7 +469,7 @@ async def _func_init_all( ) dask_worker.log_event(topic="info", msg=msg) - _func_build_handle_p2p(sessionId, streams_per_handle, verbose) + _func_build_handle_p2p(sessionId, streams_per_handle, verbose, dask_worker=dask_worker) if verbose: dask_worker.log_event(topic="info", msg="Done building handle.") @@ -624,7 +624,7 @@ async def _func_ucp_create_endpoints(sessionId, worker_info, dask_worker): for k in worker_info: ip, port = parse_host_port(k) - ep = await get_ucx().get_endpoint(ip, worker_info[k]["port"]) + ep = await get_ucx(dask_worker=dask_worker).get_endpoint(ip, worker_info[k]["port"]) eps[worker_info[k]["rank"]] = ep count += 1 From e6974ca92ffeff26e50debed54398ea0edbdc19c Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Wed, 22 Mar 2023 10:46:33 -0700 Subject: [PATCH 4/5] Use `get_worker()` in tests relying on `client.submit` --- ci/wheel_smoke_test_raft_dask.py | 6 +++--- python/raft-dask/raft_dask/test/test_comms.py | 16 ++++++++-------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/ci/wheel_smoke_test_raft_dask.py b/ci/wheel_smoke_test_raft_dask.py index 32c13e61ca..d022a9d644 100644 --- a/ci/wheel_smoke_test_raft_dask.py +++ b/ci/wheel_smoke_test_raft_dask.py @@ -1,4 +1,4 @@ -from dask.distributed import Client, wait +from dask.distributed import Client, get_worker, wait from dask_cuda import LocalCUDACluster, initialize from raft_dask.common import ( @@ -23,12 +23,12 @@ def func_test_send_recv(sessionId, n_trials): - handle = local_handle(sessionId) + handle = local_handle(sessionId, 
dask_worker=get_worker()) return perform_test_comms_send_recv(handle, n_trials) def func_test_collective(func, sessionId, root): - handle = local_handle(sessionId) + handle = local_handle(sessionId, dask_worker=get_worker()) return func(handle, root) diff --git a/python/raft-dask/raft_dask/test/test_comms.py b/python/raft-dask/raft_dask/test/test_comms.py index 6820a7c43b..3a430f9270 100644 --- a/python/raft-dask/raft_dask/test/test_comms.py +++ b/python/raft-dask/raft_dask/test/test_comms.py @@ -17,7 +17,7 @@ import pytest -from dask.distributed import Client, wait +from dask.distributed import Client, get_worker, wait try: from raft_dask.common import ( @@ -60,32 +60,32 @@ def test_comms_init_no_p2p(cluster): def func_test_collective(func, sessionId, root): - handle = local_handle(sessionId) + handle = local_handle(sessionId, dask_worker=get_worker()) return func(handle, root) def func_test_send_recv(sessionId, n_trials): - handle = local_handle(sessionId) + handle = local_handle(sessionId, dask_worker=get_worker()) return perform_test_comms_send_recv(handle, n_trials) def func_test_device_send_or_recv(sessionId, n_trials): - handle = local_handle(sessionId) + handle = local_handle(sessionId, dask_worker=get_worker()) return perform_test_comms_device_send_or_recv(handle, n_trials) def func_test_device_sendrecv(sessionId, n_trials): - handle = local_handle(sessionId) + handle = local_handle(sessionId, dask_worker=get_worker()) return perform_test_comms_device_sendrecv(handle, n_trials) def func_test_device_multicast_sendrecv(sessionId, n_trials): - handle = local_handle(sessionId) + handle = local_handle(sessionId, dask_worker=get_worker()) return perform_test_comms_device_multicast_sendrecv(handle, n_trials) def func_test_comm_split(sessionId, n_trials): - handle = local_handle(sessionId) + handle = local_handle(sessionId, dask_worker=get_worker()) return perform_test_comm_split(handle, n_trials) @@ -125,7 +125,7 @@ def test_handles(cluster): client = Client(cluster) def _has_handle(sessionId): - return local_handle(sessionId) is not None + return local_handle(sessionId, dask_worker=get_worker()) is not None try: cb = Comms(verbose=True) From a8a9bcd98a1ec3e645ec3745903ab5a5cd40e619 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Wed, 22 Mar 2023 10:57:56 -0700 Subject: [PATCH 5/5] Fix linting and missing copyright --- ci/wheel_smoke_test_raft_dask.py | 15 +++++++++++++++ python/raft-dask/raft_dask/common/comms.py | 8 ++++++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/ci/wheel_smoke_test_raft_dask.py b/ci/wheel_smoke_test_raft_dask.py index d022a9d644..5709ac901c 100644 --- a/ci/wheel_smoke_test_raft_dask.py +++ b/ci/wheel_smoke_test_raft_dask.py @@ -1,3 +1,18 @@ +# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + from dask.distributed import Client, get_worker, wait from dask_cuda import LocalCUDACluster, initialize diff --git a/python/raft-dask/raft_dask/common/comms.py b/python/raft-dask/raft_dask/common/comms.py index 934a0ea89e..ebe9a8dc4f 100644 --- a/python/raft-dask/raft_dask/common/comms.py +++ b/python/raft-dask/raft_dask/common/comms.py @@ -469,7 +469,9 @@ async def _func_init_all( ) dask_worker.log_event(topic="info", msg=msg) - _func_build_handle_p2p(sessionId, streams_per_handle, verbose, dask_worker=dask_worker) + _func_build_handle_p2p( + sessionId, streams_per_handle, verbose, dask_worker=dask_worker + ) if verbose: dask_worker.log_event(topic="info", msg="Done building handle.") @@ -624,7 +626,9 @@ async def _func_ucp_create_endpoints(sessionId, worker_info, dask_worker): for k in worker_info: ip, port = parse_host_port(k) - ep = await get_ucx(dask_worker=dask_worker).get_endpoint(ip, worker_info[k]["port"]) + ep = await get_ucx(dask_worker=dask_worker).get_endpoint( + ip, worker_info[k]["port"] + ) eps[worker_info[k]["rank"]] = ep count += 1
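
Illustrative note (not part of the patches above): a minimal sketch of the pattern this series adopts, assuming a hypothetical `LocalCluster` setup and helper names. Functions executed with `client.run()` receive the worker through the `dask_worker` keyword argument that Dask supplies, while code running inside a task submitted with `client.submit()` can still call `get_worker()`, per the behavior change in dask/distributed#7580.

```python
# Sketch only; cluster setup and function names here are hypothetical,
# not part of the raft-dask patches above.
from dask.distributed import Client, LocalCluster, get_worker


def report_address(dask_worker=None):
    # Dask fills in `dask_worker` for functions executed via client.run();
    # get_worker() can no longer be used here after dask/distributed#7580.
    return dask_worker.address


def task_report_address():
    # Inside a task submitted with client.submit(), get_worker() still
    # returns the worker executing that task.
    return get_worker().address


if __name__ == "__main__":
    cluster = LocalCluster(n_workers=2)
    client = Client(cluster)

    print(client.run(report_address))                # one result per worker
    print(client.submit(task_report_address).result())

    client.close()
    cluster.close()
```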