From c6fbe6935458e3f0001694ad84b95d664e24bd0c Mon Sep 17 00:00:00 2001 From: Devin Robison Date: Tue, 12 Jan 2021 14:16:46 -0700 Subject: [PATCH 01/16] Initial fixes --- python/raft/dask/common/comms.py | 85 ++++++++++++++++++++++++++++++-- python/raft/dask/common/nccl.pyx | 36 +++++++------- 2 files changed, 101 insertions(+), 20 deletions(-) diff --git a/python/raft/dask/common/comms.py b/python/raft/dask/common/comms.py index 4278783968..34c667fc6f 100644 --- a/python/raft/dask/common/comms.py +++ b/python/raft/dask/common/comms.py @@ -77,8 +77,10 @@ def _use_comms(sessionId): cluster.close() """ + valid_nccl_placements = ('client', 'worker', 'scheduler') + def __init__(self, comms_p2p=False, client=None, verbose=False, - streams_per_handle=0): + streams_per_handle=0, nccl_root_location="scheduler"): """ Construct a new CommsContext instance @@ -90,10 +92,18 @@ def __init__(self, comms_p2p=False, client=None, verbose=False, Dask client to use verbose : bool Print verbose logging + nccl_root_location : string + Indicates where the NCCL's root node should be located. ['client', 'worker', 'scheduler' (default)] + """ self.client = client if client is not None else default_client() + self.comms_p2p = comms_p2p + if (nccl_root_location.lower() not in Comms.valid_nccl_placements): + raise ValueError(f"nccl_root_location must be one of: {Comms.valid_nccl_placements}") + self.nccl_root_location = nccl_root_location.lower() + self.streams_per_handle = streams_per_handle self.sessionId = uuid.uuid4().bytes @@ -149,7 +159,12 @@ def init(self, workers=None): worker_info = self.worker_info(self.worker_addresses) worker_info = {w: worker_info[w] for w in self.worker_addresses} - self.uniqueId = nccl.get_unique_id() + if (self.nccl_root_location == 'client'): + self.uniqueId = nccl.get_unique_id() + elif (self.nccl_root_location == 'worker'): + self.uniqueId = self.client.run(get_unique_id_on_worker, sessionId=self.sessionId) + else: + self.uniqueId = self.client.run_on_scheduler(_func_set_scheduler_as_nccl_root, sessionId=self.sessionId) self.client.run(_func_init_all, self.sessionId, @@ -182,6 +197,10 @@ def destroy(self): wait=True, workers=self.worker_addresses) + if (self.nccl_root_location == 'scheduler'): + self.client.run_on_scheduler(_func_destroy_scheduler_session, + self.sessionId) + if self.verbose: print("Destroying comms.") @@ -207,6 +226,27 @@ def local_handle(sessionId): return state["handle"] if "handle" in state else None +def scheduler_state(sessionId, dask_scheduler): + """ + Retrieves cuML comms state on the scheduler node, for the given sessionId, creating + a new session if it does not exist. If no session id is given, returns the state dict for + all sessions. 
+ :param sessionId: SessionId value to retrieve from the dask_scheduler instances + :param dask_scheduler: Dask Scheduler object + :return: session state associated with sessionId + """ + + if (not hasattr(dask_scheduler, "_raft_comm_state")): + dask_scheduler._raft_comm_state = {} + + if (sessionId is not None and sessionId not in dask_scheduler._raft_comm_state): + dask_scheduler._raft_comm_state[sessionId] = { "ts": time.time() } + + return dask_scheduler._raft_comm_state[sessionId] + + return dask_scheduler._raft_comm_state + + def worker_state(sessionId=None): """ Retrieves cuML comms state on local worker for the given @@ -240,6 +280,38 @@ def get_ucx(): worker_state("ucp")["ucx"] = UCX.get() return worker_state("ucp")["ucx"] +def _func_destroy_scheduler_session(sessionId, dask_scheduler): + if (sessionId is not None and sessionId in dask_scheduler._raft_comm_state): + del dask_scheduler._raft_comm_state[sessionId] + + return 0 + +def _func_set_scheduler_as_nccl_root(sessionId, dask_scheduler): + """ + Creates a persistent nccl uniqueId on the scheduler node. + + Note: dask_scheduler should be passed by the scheduler, it does not need to be supplied to the run_on_scheduler + call. + + :param sessionId: Associated session to attach the unique ID to. + :param dask_scheduler: dask scheduler object, populated by the client/scheduler call + :return: + """ + if (sessionId is None): + raise ValueError("sessionId cannot be None.") + + session_state = scheduler_state(sessionId=sessionId, dask_scheduler=dask_scheduler) + if ('nccl_uid' not in session_state): + session_state['nccl_uid'] = nccl.get_unique_id() + + return session_state['nccl_uid'] + +# TODO +def _func_set_worker_as_nccl_root(sessionId, workerId): + pass + +def _func_destroy_worker_session(sessionId, workerId): + pass def _func_ucp_listener_port(): return get_ucx().listener_port() @@ -254,6 +326,7 @@ async def _func_init_all(sessionId, uniqueId, comms_p2p, session_state["nworkers"] = len(worker_info) if verbose: + # TODO: prints should be replaced with logging calls. print("Initializing NCCL") start = time.time() @@ -261,10 +334,12 @@ async def _func_init_all(sessionId, uniqueId, comms_p2p, if verbose: elapsed = time.time() - start + # TODO: prints should be replaced with logging calls. print("NCCL Initialization took: %f seconds." % elapsed) if comms_p2p: if verbose: + # TODO: prints should be replaced with logging calls. print("Initializing UCX Endpoints") if verbose: @@ -273,6 +348,7 @@ async def _func_init_all(sessionId, uniqueId, comms_p2p, if verbose: elapsed = time.time() - start + # TODO: prints should be replaced with logging calls. print("Done initializing UCX endpoints. Took: %f seconds." % elapsed) print("Building handle") @@ -280,6 +356,7 @@ async def _func_init_all(sessionId, uniqueId, comms_p2p, _func_build_handle_p2p(sessionId, streams_per_handle, verbose) if verbose: + # TODO: prints should be replaced with logging calls. print("Done building handle.") else: @@ -306,8 +383,10 @@ def _func_init_nccl(sessionId, uniqueId): n = nccl() n.init(nWorkers, uniqueId, wid) worker_state(sessionId)["nccl"] = n - except Exception: + except Exception as e: + # TODO: prints should be replaced with logging calls. 
print("An error occurred initializing NCCL!") + raise def _func_build_handle_p2p(sessionId, streams_per_handle, verbose): diff --git a/python/raft/dask/common/nccl.pyx b/python/raft/dask/common/nccl.pyx index d55a0e4c42..7fc813b515 100644 --- a/python/raft/dask/common/nccl.pyx +++ b/python/raft/dask/common/nccl.pyx @@ -140,18 +140,15 @@ cdef class nccl: cdef int r = rank cdef ncclResult_t result - import time - - start = time.time() with nogil: result = ncclCommInitRank(comm_, nr, deref(ident), r) - end = time.time() if result != ncclSuccess: with nogil: err_str = ncclGetErrorString(result) - print("NCCL_ERROR: %s" % err_str) + + raise RuntimeError("NCCL_ERROR: %s" % err_str) def destroy(self): """ @@ -164,13 +161,14 @@ cdef class nccl: with nogil: result = ncclCommDestroy(deref(comm_)) + free(self.comm) + self.comm = NULL + if result != ncclSuccess: with nogil: err_str = ncclGetErrorString(result) - print("NCCL_ERROR: %s" % err_str) - free(self.comm) - self.comm = NULL + raise RuntimeError("NCCL_ERROR: %s" % err_str) def abort(self): """ @@ -182,12 +180,13 @@ cdef class nccl: with nogil: result = ncclCommAbort(deref(comm_)) + free(comm_) + self.comm = NULL + if result != ncclSuccess: with nogil: err_str = ncclGetErrorString(result) - print("NCCL_ERROR: %s" % err_str) - free(comm_) - self.comm = NULL + raise RuntimeError("NCCL_ERROR: %s" % err_str) def cu_device(self): """ @@ -204,13 +203,15 @@ cdef class nccl: with nogil: result = ncclCommCuDevice(deref(comm_), dev) + ret = dev[0] + free(dev) + if result != ncclSuccess: with nogil: err_str = ncclGetErrorString(result) - print("NCCL_ERROR: %s" % err_str) - ret = dev[0] - free(dev) + raise RuntimeError("NCCL_ERROR: %s" % err_str) + return ret def user_rank(self): @@ -230,13 +231,14 @@ cdef class nccl: with nogil: result = ncclCommUserRank(deref(comm_), rank) + ret = rank[0] + free(rank) + if result != ncclSuccess: with nogil: err_str = ncclGetErrorString(result) - print("NCCL_ERROR: %s" % err_str) + raise RuntimeError("NCCL_ERROR: %s" % err_str) - ret = rank[0] - free(rank) return ret def get_comm(self): From 79172059082e3d7dd0f5d22aa9dd3cba8a3ed004 Mon Sep 17 00:00:00 2001 From: Devin Robison Date: Tue, 12 Jan 2021 17:07:57 -0700 Subject: [PATCH 02/16] Add worker root node ability, and unit tests --- python/raft/dask/common/comms.py | 20 +++--- python/raft/test/test_comms.py | 101 ++++++++++++++++++++++++++----- 2 files changed, 100 insertions(+), 21 deletions(-) diff --git a/python/raft/dask/common/comms.py b/python/raft/dask/common/comms.py index 34c667fc6f..04f2211aad 100644 --- a/python/raft/dask/common/comms.py +++ b/python/raft/dask/common/comms.py @@ -162,7 +162,10 @@ def init(self, workers=None): if (self.nccl_root_location == 'client'): self.uniqueId = nccl.get_unique_id() elif (self.nccl_root_location == 'worker'): - self.uniqueId = self.client.run(get_unique_id_on_worker, sessionId=self.sessionId) + self.uniqueId = self.client.run(_func_set_worker_as_nccl_root, + self.sessionId, + workers=[self.worker_addresses[0]], + wait=True)[self.worker_addresses[0]] else: self.uniqueId = self.client.run_on_scheduler(_func_set_scheduler_as_nccl_root, sessionId=self.sessionId) @@ -240,7 +243,7 @@ def scheduler_state(sessionId, dask_scheduler): dask_scheduler._raft_comm_state = {} if (sessionId is not None and sessionId not in dask_scheduler._raft_comm_state): - dask_scheduler._raft_comm_state[sessionId] = { "ts": time.time() } + dask_scheduler._raft_comm_state[sessionId] = {"ts": time.time()} return dask_scheduler._raft_comm_state[sessionId] 
@@ -306,12 +309,15 @@ def _func_set_scheduler_as_nccl_root(sessionId, dask_scheduler): return session_state['nccl_uid'] -# TODO -def _func_set_worker_as_nccl_root(sessionId, workerId): - pass +def _func_set_worker_as_nccl_root(sessionId, workerId=0): + if (sessionId is None): + raise ValueError("sessionId cannot be None.") -def _func_destroy_worker_session(sessionId, workerId): - pass + session_state = worker_state(sessionId) + if ('nccl_uid' not in session_state): + session_state['nccl_uid'] = nccl.get_unique_id() + + return session_state['nccl_uid'] def _func_ucp_listener_port(): return get_ucx().listener_port() diff --git a/python/raft/test/test_comms.py b/python/raft/test/test_comms.py index 5dfe2243c0..feab00fee2 100644 --- a/python/raft/test/test_comms.py +++ b/python/raft/test/test_comms.py @@ -15,11 +15,14 @@ import pytest +from collections import OrderedDict + from dask.distributed import Client from dask.distributed import wait try: from raft.dask import Comms + from raft.dask.common import nccl from raft.dask.common import local_handle from raft.dask.common import perform_test_comms_send_recv from raft.dask.common import perform_test_comms_allreduce @@ -64,6 +67,45 @@ def func_test_comm_split(sessionId, n_trials): handle = local_handle(sessionId) return perform_test_comm_split(handle, n_trials) +def func_chk_uid_on_scheduler(sessionId, uniqueId, dask_scheduler): + if (not hasattr(dask_scheduler, '_raft_comm_state')): + return 1 + + state_object = dask_scheduler._raft_comm_state + if (sessionId not in state_object): + return 2 + + session_state = state_object[sessionId] + if ('nccl_uid' not in dask_scheduler._raft_comm_state[sessionId]): + return 3 + + nccl_uid = session_state['nccl_uid'] + if (nccl_uid != uniqueId): + return 4 + + return 0 + +def func_chk_uid_on_worker(sessionId, uniqueId): + from dask.distributed import get_worker + + worker_state = get_worker() + if (not hasattr(worker_state, '_raft_comm_state')): + return 1 + + state_object = worker_state._raft_comm_state + if (sessionId not in state_object): + return 2 + + session_state = state_object[sessionId] + if ('nccl_uid' not in session_state): + return 3 + + nccl_uid = session_state['nccl_uid'] + if (nccl_uid != uniqueId): + return 4 + + return 0 + def test_handles(cluster): @@ -100,26 +142,57 @@ def _has_handle(sessionId): functions = [None] +@pytest.mark.parametrize("root_location", ['client', 'worker', 'scheduler']) +def test_nccl_root_placement(client, root_location): + + cb = None + try: + cb = Comms(verbose=True, client=client, nccl_root_location=root_location) + cb.init() + + worker_addresses = list(OrderedDict.fromkeys( + client.scheduler_info()["workers"].keys())) + + if (root_location in ('worker',)): + result = client.run(func_chk_uid_on_worker, + cb.sessionId, + cb.uniqueId, + workers=[worker_addresses[0]])[worker_addresses[0]] + elif (root_location in ('scheduler',)): + result = client.run_on_scheduler(func_chk_uid_on_scheduler, cb.sessionId, cb.uniqueId) + else: + result = int(cb.uniqueId == None) + + assert (result == 0) + + finally: + if (cb): + cb.destroy() + @pytest.mark.parametrize("func", functions) +@pytest.mark.parametrize("root_location", ['client', 'worker', 'scheduler']) @pytest.mark.nccl -def test_collectives(client, func): - - cb = Comms(verbose=True) - cb.init() +def test_collectives(client, func, root_location): - for k, v in cb.worker_info(cb.worker_addresses).items(): + try: + cb = Comms(verbose=True, client=client, nccl_root_location=root_location) + cb.init() - dfs = 
[client.submit(func_test_collective, - func, - cb.sessionId, - v["rank"], - pure=False, - workers=[w]) - for w in cb.worker_addresses] - wait(dfs, timeout=5) + for k, v in cb.worker_info(cb.worker_addresses).items(): - assert all([x.result() for x in dfs]) + dfs = [client.submit(func_test_collective, + func, + cb.sessionId, + v["rank"], + pure=False, + workers=[w]) + for w in cb.worker_addresses] + wait(dfs, timeout=5) + assert all([x.result() for x in dfs]) + finally: + if (cb): + cb.destroy() @pytest.mark.nccl def test_comm_split(client): From 7e9458e9c1957b1bdce6c413e231dba259886cc0 Mon Sep 17 00:00:00 2001 From: Devin Robison Date: Wed, 13 Jan 2021 13:24:38 -0700 Subject: [PATCH 03/16] Add logging and improve verbose usage --- python/raft/dask/common/comms.py | 91 +++++++++++++++++++++++--------- python/raft/test/test_comms.py | 10 ++-- 2 files changed, 73 insertions(+), 28 deletions(-) diff --git a/python/raft/dask/common/comms.py b/python/raft/dask/common/comms.py index 04f2211aad..59c802da50 100644 --- a/python/raft/dask/common/comms.py +++ b/python/raft/dask/common/comms.py @@ -26,10 +26,13 @@ import warnings +import logging import time import uuid from collections import OrderedDict +logger = logging.getLogger(__name__) + class Comms: @@ -153,7 +156,8 @@ def init(self, workers=None): if workers is None else workers)) if self.nccl_initialized or self.ucx_initialized: - warnings.warn("Comms have already been initialized.") + msg = "Comms have already been initialized." + warnings.warn(msg) return worker_info = self.worker_info(self.worker_addresses) @@ -163,11 +167,14 @@ def init(self, workers=None): self.uniqueId = nccl.get_unique_id() elif (self.nccl_root_location == 'worker'): self.uniqueId = self.client.run(_func_set_worker_as_nccl_root, - self.sessionId, + sessionId=self.sessionId, + verbose=self.verbose, workers=[self.worker_addresses[0]], wait=True)[self.worker_addresses[0]] else: - self.uniqueId = self.client.run_on_scheduler(_func_set_scheduler_as_nccl_root, sessionId=self.sessionId) + self.uniqueId = self.client.run_on_scheduler(_func_set_scheduler_as_nccl_root, + sessionId=self.sessionId, + verbose=self.verbose) self.client.run(_func_init_all, self.sessionId, @@ -185,7 +192,7 @@ def init(self, workers=None): self.ucx_initialized = True if self.verbose: - print("Initialization complete.") + print("Initialization Complete") def destroy(self): """ @@ -205,7 +212,7 @@ def destroy(self): self.sessionId) if self.verbose: - print("Destroying comms.") + print("Destroying Comms.") self.nccl_initialized = False self.ucx_initialized = False @@ -289,7 +296,7 @@ def _func_destroy_scheduler_session(sessionId, dask_scheduler): return 0 -def _func_set_scheduler_as_nccl_root(sessionId, dask_scheduler): +def _func_set_scheduler_as_nccl_root(sessionId, verbose, dask_scheduler): """ Creates a persistent nccl uniqueId on the scheduler node. 
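Patch 03 threads the verbose flag through the same root-selection path that patch 02 introduced. For orientation, a minimal end-to-end sketch of the feature these patches build up — the LocalCUDACluster setup is an assumption for illustration; any Dask cluster backed by GPU workers should behave the same:

    from dask.distributed import Client
    from dask_cuda import LocalCUDACluster
    from raft.dask import Comms

    cluster = LocalCUDACluster()
    client = Client(cluster)

    # The NCCL uniqueId is created at the chosen root ("scheduler" here;
    # "client" and "worker" are also accepted) and fetched by every
    # worker during init().
    cb = Comms(client=client, nccl_root_location="scheduler", verbose=True)
    cb.init()
    try:
        pass  # collective work keyed by cb.sessionId goes here
    finally:
        cb.destroy()  # with a scheduler root, also clears the session entry
        client.close()
        cluster.close()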
@@ -300,6 +307,9 @@ def _func_set_scheduler_as_nccl_root(sessionId, dask_scheduler): :param dask_scheduler: dask scheduler object, populated by the client/scheduler call :return: """ + if(verbose): + logger.info(msg=f"Setting scheduler as NCCL root for sessionId, '{sessionId}'") + if (sessionId is None): raise ValueError("sessionId cannot be None.") @@ -307,9 +317,16 @@ def _func_set_scheduler_as_nccl_root(sessionId, dask_scheduler): if ('nccl_uid' not in session_state): session_state['nccl_uid'] = nccl.get_unique_id() + if(verbose): + logger.info(f"Done setting scheduler as NCCL root.") + return session_state['nccl_uid'] -def _func_set_worker_as_nccl_root(sessionId, workerId=0): +def _func_set_worker_as_nccl_root(sessionId, verbose, workerId=0): + if(verbose): + get_worker().log_event(topic="info", + msg=f"Setting worker, '{workerId}', as NCCL root for session, '{sessionId}'") + if (sessionId is None): raise ValueError("sessionId cannot be None.") @@ -317,6 +334,10 @@ def _func_set_worker_as_nccl_root(sessionId, workerId=0): if ('nccl_uid' not in session_state): session_state['nccl_uid'] = nccl.get_unique_id() + if(verbose): + get_worker().log_event(topic="info", + msg=f"Done setting scheduler as NCCL root.") + return session_state['nccl_uid'] def _func_ucp_listener_port(): @@ -332,21 +353,18 @@ async def _func_init_all(sessionId, uniqueId, comms_p2p, session_state["nworkers"] = len(worker_info) if verbose: - # TODO: prints should be replaced with logging calls. - print("Initializing NCCL") + get_worker().log_event(topic="info", msg="Initializing NCCL.") start = time.time() _func_init_nccl(sessionId, uniqueId) if verbose: elapsed = time.time() - start - # TODO: prints should be replaced with logging calls. - print("NCCL Initialization took: %f seconds." % elapsed) + get_worker().log_event(topic="info", msg=f"NCCL Initialization took: {elapsed} seconds.") if comms_p2p: if verbose: - # TODO: prints should be replaced with logging calls. - print("Initializing UCX Endpoints") + get_worker().log_event(topic="info", msg="Initializing UCX Endpoints") if verbose: start = time.time() @@ -354,16 +372,13 @@ async def _func_init_all(sessionId, uniqueId, comms_p2p, if verbose: elapsed = time.time() - start - # TODO: prints should be replaced with logging calls. - print("Done initializing UCX endpoints. Took: %f seconds." % - elapsed) - print("Building handle") + msg = f"Done initializing UCX endpoints. Took: {elapsed} seconds.\nBuilding handle." + get_worker().log_event(topic="info", msg=msg) _func_build_handle_p2p(sessionId, streams_per_handle, verbose) if verbose: - # TODO: prints should be replaced with logging calls. - print("Done building handle.") + get_worker().log_event(topic="info", msg="Done building handle.") else: _func_build_handle(sessionId, streams_per_handle, verbose) @@ -390,8 +405,7 @@ def _func_init_nccl(sessionId, uniqueId): n.init(nWorkers, uniqueId, wid) worker_state(sessionId)["nccl"] = n except Exception as e: - # TODO: prints should be replaced with logging calls. 
- print("An error occurred initializing NCCL!") + get_worker().log_event(topic="error", msg="An error occurred initializing NCCL!.") raise @@ -405,6 +419,9 @@ def _func_build_handle_p2p(sessionId, streams_per_handle, verbose): streams_per_handle : int number of internal streams to create verbose : bool print verbose logging output """ + if (verbose): + get_worker().log_event(topic="info", msg="Building p2p handle.") + ucp_worker = get_ucx().get_worker() session_state = worker_state(sessionId) @@ -414,9 +431,14 @@ def _func_build_handle_p2p(sessionId, streams_per_handle, verbose): nWorkers = session_state["nworkers"] workerId = session_state["wid"] + if (verbose): + get_worker().log_event(topic="info", msg="Injecting comms on handle.") inject_comms_on_handle(handle, nccl_comm, ucp_worker, eps, nWorkers, workerId, verbose) + if (verbose): + get_worker().log_event(topic="info", msg="Finished injecting comms on handle.") + worker_state(sessionId)["handle"] = handle @@ -430,6 +452,9 @@ def _func_build_handle(sessionId, streams_per_handle, verbose): streams_per_handle : int number of internal streams to create verbose : bool print verbose logging output """ + if (verbose): + get_worker().log_event(topic="info", msg="Finished injecting comms on handle.") + handle = Handle(streams_per_handle) session_state = worker_state(sessionId) @@ -444,6 +469,7 @@ def _func_build_handle(sessionId, streams_per_handle, verbose): def _func_store_initial_state(nworkers, sessionId, uniqueId, wid): + # TODO: We don't ever remove wid or nworkers... could cause problems? Maybe we should just blow away whole session session_state = worker_state(sessionId) session_state["nccl_uid"] = uniqueId session_state["wid"] = wid @@ -476,9 +502,26 @@ async def _func_ucp_create_endpoints(sessionId, worker_info): async def _func_destroy_all(sessionId, comms_p2p, verbose=False): - worker_state(sessionId)["nccl"].destroy() - del worker_state(sessionId)["nccl"] - del worker_state(sessionId)["handle"] + if(verbose): + get_worker().log_event(topic="info", msg="Destroying NCCL session state.") + session_state = worker_state(sessionId) + if ('nccl' in session_state): + session_state["nccl"].destroy() + del session_state["nccl"] + if (verbose): + get_worker().log_event(topic="info", msg="NCCL session state destroyed.") + else: + if (verbose): + get_worker().log_event(topic="info", msg=f"{sessionId} does not contain") + + if (verbose): + get_worker().log_event(topic="info", msg=f"Destroy CUDA handle for sessionId, '{sessionId}.'") + if ('handle' in session_state): + del session_state["handle"] + else: + if (verbose): + #TODO add logging for unexpected worker state + pass def _func_ucp_ports(client, workers): diff --git a/python/raft/test/test_comms.py b/python/raft/test/test_comms.py index feab00fee2..1eba64098a 100644 --- a/python/raft/test/test_comms.py +++ b/python/raft/test/test_comms.py @@ -67,7 +67,7 @@ def func_test_comm_split(sessionId, n_trials): handle = local_handle(sessionId) return perform_test_comm_split(handle, n_trials) -def func_chk_uid_on_scheduler(sessionId, uniqueId, dask_scheduler): +def func_check_uid_on_scheduler(sessionId, uniqueId, dask_scheduler): if (not hasattr(dask_scheduler, '_raft_comm_state')): return 1 @@ -85,7 +85,7 @@ def func_chk_uid_on_scheduler(sessionId, uniqueId, dask_scheduler): return 0 -def func_chk_uid_on_worker(sessionId, uniqueId): +def func_check_uid_on_worker(sessionId, uniqueId): from dask.distributed import get_worker worker_state = get_worker() @@ -154,12 +154,12 @@ def 
test_nccl_root_placement(client, root_location):
         client.scheduler_info()["workers"].keys()))
 
     if (root_location in ('worker',)):
-        result = client.run(func_chk_uid_on_worker,
+        result = client.run(func_check_uid_on_worker,
                             cb.sessionId,
                             cb.uniqueId,
                             workers=[worker_addresses[0]])[worker_addresses[0]]
     elif (root_location in ('scheduler',)):
-        result = client.run_on_scheduler(func_chk_uid_on_scheduler, cb.sessionId, cb.uniqueId)
+        result = client.run_on_scheduler(func_check_uid_on_scheduler, cb.sessionId, cb.uniqueId)
     else:
         result = int(cb.uniqueId == None)
 
@@ -169,6 +169,8 @@ def test_nccl_root_placement(client, root_location):
     if (cb):
         cb.destroy()
 
+# TODO: Add negative case test for bad location type, and Comm init failure so we check cleanup routines.
+
 @pytest.mark.parametrize("func", functions)
 @pytest.mark.parametrize("root_location", ['client', 'worker', 'scheduler'])
 @pytest.mark.nccl

From 94c9d3b29f6ddcc4db26ab3121d2c4408fb264ad Mon Sep 17 00:00:00 2001
From: Devin Robison
Date: Wed, 13 Jan 2021 13:47:18 -0700
Subject: [PATCH 04/16] Fix typos, add better docs info

---
 python/raft/dask/common/comms.py | 67 +++++++++++++++++++++++++-------
 1 file changed, 52 insertions(+), 15 deletions(-)

diff --git a/python/raft/dask/common/comms.py b/python/raft/dask/common/comms.py
index 59c802da50..8ff9795f1e 100644
--- a/python/raft/dask/common/comms.py
+++ b/python/raft/dask/common/comms.py
@@ -156,8 +156,7 @@ def init(self, workers=None):
                     if workers is None else workers))
 
         if self.nccl_initialized or self.ucx_initialized:
-            msg = "Comms have already been initialized."
-            warnings.warn(msg)
+            warnings.warn("Comms have already been initialized.")
             return
 
         worker_info = self.worker_info(self.worker_addresses)
@@ -192,7 +191,7 @@ def init(self, workers=None):
         self.ucx_initialized = True
 
         if self.verbose:
-            print("Initialization Complete")
+            print("Initialization complete.")
 
     def destroy(self):
         """
@@ -212,14 +211,15 @@ def destroy(self):
                                      self.sessionId)
 
         if self.verbose:
-            print("Destroying Comms.")
+            print("Destroying comms.")
 
         self.nccl_initialized = False
         self.ucx_initialized = False
 
 
 def local_handle(sessionId):
-    """Simple helper function for retrieving the local handle_t instance
+    """
+    Simple helper function for retrieving the local handle_t instance
     for a comms session on a worker.
 
     Parameters
@@ -241,9 +241,17 @@ def scheduler_state(sessionId, dask_scheduler):
     Retrieves cuML comms state on the scheduler node, for the given sessionId, creating
     a new session if it does not exist. If no session id is given, returns the state dict for
     all sessions.
-    :param sessionId: SessionId value to retrieve from the dask_scheduler instances
-    :param dask_scheduler: Dask Scheduler object
-    :return: session state associated with sessionId
+
+    Parameters
+    ----------
+    sessionId : SessionId value to retrieve from the dask_scheduler instances
+    dask_scheduler : Dask Scheduler object
+
+    Returns
+    -------
+
+    session state : dict
+        session state associated with sessionId
     """
 
     if (not hasattr(dask_scheduler, "_raft_comm_state")):
@@ -291,8 +299,18 @@ def get_ucx():
     return worker_state("ucp")["ucx"]
 
 def _func_destroy_scheduler_session(sessionId, dask_scheduler):
+    """
+    Remove session data from _raft_comm_state, associated with sessionId
+
+    Parameters
+    ----------
+    sessionId : session Id to be destroyed.
+    dask_scheduler : dask_scheduler object (Note: this is supplied by DASK, not the client)
+    """
     if (sessionId is not None and sessionId in dask_scheduler._raft_comm_state):
         del dask_scheduler._raft_comm_state[sessionId]
+    else:
+        return 1
 
     return 0
 
@@ -300,12 +318,17 @@ def _func_set_scheduler_as_nccl_root(sessionId, verbose, dask_scheduler):
     """
     Creates a persistent nccl uniqueId on the scheduler node.
 
-    Note: dask_scheduler should be passed by the scheduler, it does not need to be supplied to the run_on_scheduler
-    call.
 
-    :param sessionId: Associated session to attach the unique ID to.
-    :param dask_scheduler: dask scheduler object, populated by the client/scheduler call
-    :return:
+    Parameters
+    ----------
+    sessionId : Associated session to attach the unique ID to.
+    verbose : Indicates whether or not to emit additional information
+    dask_scheduler : dask scheduler object, (Note: this is supplied by DASK, not the client)
+
+    Return
+    ------
+    uniqueId : byte str
+        NCCL uniqueId, associating the DASK scheduler as its root node.
     """
     if(verbose):
         logger.info(msg=f"Setting scheduler as NCCL root for sessionId, '{sessionId}'")
 
@@ -322,10 +345,24 @@ def _func_set_scheduler_as_nccl_root(sessionId, verbose, dask_scheduler):
 
     return session_state['nccl_uid']
 
-def _func_set_worker_as_nccl_root(sessionId, verbose, workerId=0):
+def _func_set_worker_as_nccl_root(sessionId, verbose):
+    """
+    Creates a persistent nccl uniqueId on the worker node.
+
+
+    Parameters
+    ----------
+    sessionId : Associated session to attach the unique ID to.
+    verbose : Indicates whether or not to emit additional information
+
+    Return
+    ------
+    uniqueId : byte str
+        NCCL uniqueId, associating this DASK worker as its root node.
+    """
     if(verbose):
         get_worker().log_event(topic="info",
-            msg=f"Setting worker, '{workerId}', as NCCL root for session, '{sessionId}'")
+            msg=f"Setting worker as NCCL root for session, '{sessionId}'")
 
     if (sessionId is None):
         raise ValueError("sessionId cannot be None.")

From 1cf6cc5cfe20437e4315274027d0cf370b3e20c4 Mon Sep 17 00:00:00 2001
From: Devin Robison
Date: Wed, 13 Jan 2021 13:52:30 -0700
Subject: [PATCH 05/16] More logging updates

---
 python/raft/dask/common/comms.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/python/raft/dask/common/comms.py b/python/raft/dask/common/comms.py
index 8ff9795f1e..0ba6fccdc5 100644
--- a/python/raft/dask/common/comms.py
+++ b/python/raft/dask/common/comms.py
@@ -549,7 +549,8 @@ async def _func_destroy_all(sessionId, comms_p2p, verbose=False):
             get_worker().log_event(topic="info", msg="NCCL session state destroyed.")
     else:
         if (verbose):
-            get_worker().log_event(topic="info", msg=f"{sessionId} does not contain")
+            get_worker().log_event(topic="warning",
+                msg=f"Session state for, '{sessionId}', does not contain expected 'nccl' element")
 
     if (verbose):
         get_worker().log_event(topic="info", msg=f"Destroy CUDA handle for sessionId, '{sessionId}.'")
@@ -557,8 +558,8 @@ async def _func_destroy_all(sessionId, comms_p2p, verbose=False):
         del session_state["handle"]
     else:
         if (verbose):
-            #TODO add logging for unexpected worker state
-            pass
+            get_worker().log_event(topic="warning",
+                msg=f"Session state for, '{sessionId}', does not contain expected 'handle' element")

From 28e4de1caef19f346a21f705ea990fa118ec31b2 Mon Sep 17 00:00:00 2001
From: Devin Robison
Date: Wed, 13 Jan 2021 14:21:29 -0700
Subject: [PATCH 06/16] Update changelog

---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git
a/CHANGELOG.md b/CHANGELOG.md index 7d5cf682b9..4747bca4e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ - PR #103: Epsilon parameter for Cholesky rank one update - PR #100: Add divyegala as codeowner - PR #111: Cleanup gpuCI scripts +- PR #120: Update NCCL init process to support root node placement. ## Bug Fixes - PR #106: Specify dependency branches to avoid pip resolver failure From 3e7122cadfbc47ee6967ac814935b86415c3ff95 Mon Sep 17 00:00:00 2001 From: Devin Robison Date: Wed, 13 Jan 2021 14:53:36 -0700 Subject: [PATCH 07/16] Fix test_comms.py style issues --- python/raft/test/test_comms.py | 141 +++++++++++++++++++-------------- 1 file changed, 82 insertions(+), 59 deletions(-) diff --git a/python/raft/test/test_comms.py b/python/raft/test/test_comms.py index 1eba64098a..8c0daf359c 100644 --- a/python/raft/test/test_comms.py +++ b/python/raft/test/test_comms.py @@ -22,7 +22,6 @@ try: from raft.dask import Comms - from raft.dask.common import nccl from raft.dask.common import local_handle from raft.dask.common import perform_test_comms_send_recv from raft.dask.common import perform_test_comms_allreduce @@ -31,6 +30,7 @@ from raft.dask.common import perform_test_comms_allgather from raft.dask.common import perform_test_comms_reducescatter from raft.dask.common import perform_test_comm_split + pytestmark = pytest.mark.mg except ImportError: pytestmark = pytest.mark.skip @@ -67,41 +67,43 @@ def func_test_comm_split(sessionId, n_trials): handle = local_handle(sessionId) return perform_test_comm_split(handle, n_trials) + def func_check_uid_on_scheduler(sessionId, uniqueId, dask_scheduler): - if (not hasattr(dask_scheduler, '_raft_comm_state')): + if not hasattr(dask_scheduler, "_raft_comm_state"): return 1 state_object = dask_scheduler._raft_comm_state - if (sessionId not in state_object): + if sessionId not in state_object: return 2 session_state = state_object[sessionId] - if ('nccl_uid' not in dask_scheduler._raft_comm_state[sessionId]): + if "nccl_uid" not in dask_scheduler._raft_comm_state[sessionId]: return 3 - nccl_uid = session_state['nccl_uid'] - if (nccl_uid != uniqueId): + nccl_uid = session_state["nccl_uid"] + if nccl_uid != uniqueId: return 4 return 0 + def func_check_uid_on_worker(sessionId, uniqueId): from dask.distributed import get_worker worker_state = get_worker() - if (not hasattr(worker_state, '_raft_comm_state')): + if not hasattr(worker_state, "_raft_comm_state"): return 1 state_object = worker_state._raft_comm_state - if (sessionId not in state_object): + if sessionId not in state_object: return 2 session_state = state_object[sessionId] - if ('nccl_uid' not in session_state): + if "nccl_uid" not in session_state: return 3 - nccl_uid = session_state['nccl_uid'] - if (nccl_uid != uniqueId): + nccl_uid = session_state["nccl_uid"] + if nccl_uid != uniqueId: return 4 return 0 @@ -118,11 +120,10 @@ def _has_handle(sessionId): cb = Comms(verbose=True) cb.init() - dfs = [client.submit(_has_handle, - cb.sessionId, - pure=False, - workers=[w]) - for w in cb.worker_addresses] + dfs = [ + client.submit(_has_handle, cb.sessionId, pure=False, workers=[w]) + for w in cb.worker_addresses + ] wait(dfs, timeout=5) assert all(client.compute(dfs, sync=True)) @@ -132,82 +133,100 @@ def _has_handle(sessionId): client.close() -if pytestmark.markname != 'skip': - functions = [perform_test_comms_allgather, - perform_test_comms_allreduce, - perform_test_comms_bcast, - perform_test_comms_reduce, - perform_test_comms_reducescatter] +if pytestmark.markname != "skip": + 
functions = [ + perform_test_comms_allgather, + perform_test_comms_allreduce, + perform_test_comms_bcast, + perform_test_comms_reduce, + perform_test_comms_reducescatter, + ] else: functions = [None] -@pytest.mark.parametrize("root_location", ['client', 'worker', 'scheduler']) +@pytest.mark.parametrize("root_location", ["client", "worker", "scheduler"]) def test_nccl_root_placement(client, root_location): cb = None try: - cb = Comms(verbose=True, client=client, nccl_root_location=root_location) + cb = Comms( + verbose=True, client=client, nccl_root_location=root_location + ) cb.init() - worker_addresses = list(OrderedDict.fromkeys( - client.scheduler_info()["workers"].keys())) - - if (root_location in ('worker',)): - result = client.run(func_check_uid_on_worker, - cb.sessionId, - cb.uniqueId, - workers=[worker_addresses[0]])[worker_addresses[0]] - elif (root_location in ('scheduler',)): - result = client.run_on_scheduler(func_check_uid_on_scheduler, cb.sessionId, cb.uniqueId) + worker_addresses = list( + OrderedDict.fromkeys(client.scheduler_info()["workers"].keys()) + ) + + if root_location in ("worker",): + result = client.run( + func_check_uid_on_worker, + cb.sessionId, + cb.uniqueId, + workers=[worker_addresses[0]], + )[worker_addresses[0]] + elif root_location in ("scheduler",): + result = client.run_on_scheduler( + func_check_uid_on_scheduler, cb.sessionId, cb.uniqueId + ) else: result = int(cb.uniqueId == None) - assert (result == 0) + assert result == 0 finally: - if (cb): + if cb: cb.destroy() + # TODO: Add negative case test for bad location type, and Comm init failure so we check cleanup routines. + @pytest.mark.parametrize("func", functions) -@pytest.mark.parametrize("root_location", ['client', 'worker', 'scheduler']) +@pytest.mark.parametrize("root_location", ["client", "worker", "scheduler"]) @pytest.mark.nccl def test_collectives(client, func, root_location): try: - cb = Comms(verbose=True, client=client, nccl_root_location=root_location) + cb = Comms( + verbose=True, client=client, nccl_root_location=root_location + ) cb.init() for k, v in cb.worker_info(cb.worker_addresses).items(): - dfs = [client.submit(func_test_collective, - func, - cb.sessionId, - v["rank"], - pure=False, - workers=[w]) - for w in cb.worker_addresses] + dfs = [ + client.submit( + func_test_collective, + func, + cb.sessionId, + v["rank"], + pure=False, + workers=[w], + ) + for w in cb.worker_addresses + ] wait(dfs, timeout=5) assert all([x.result() for x in dfs]) finally: - if (cb): + if cb: cb.destroy() + @pytest.mark.nccl def test_comm_split(client): cb = Comms(comms_p2p=True, verbose=True) cb.init() - dfs = [client.submit(func_test_comm_split, - cb.sessionId, - 3, - pure=False, - workers=[w]) - for w in cb.worker_addresses] + dfs = [ + client.submit( + func_test_comm_split, cb.sessionId, 3, pure=False, workers=[w] + ) + for w in cb.worker_addresses + ] wait(dfs, timeout=5) @@ -221,13 +240,17 @@ def test_send_recv(n_trials, client): cb = Comms(comms_p2p=True, verbose=True) cb.init() - dfs = [client.submit(func_test_send_recv, - cb.sessionId, - n_trials, - pure=False, - workers=[w]) - for w in cb.worker_addresses] + dfs = [ + client.submit( + func_test_send_recv, + cb.sessionId, + n_trials, + pure=False, + workers=[w], + ) + for w in cb.worker_addresses + ] wait(dfs, timeout=5) - assert(list(map(lambda x: x.result(), dfs))) + assert list(map(lambda x: x.result(), dfs)) From bfc1b715124452c5bac0da5a2a542b324b309d50 Mon Sep 17 00:00:00 2001 From: Devin Robison Date: Wed, 13 Jan 2021 15:03:59 -0700 
Subject: [PATCH 08/16] Style updates --- python/raft/dask/common/comms.py | 81 ++++++++++++++++++++------------ python/raft/test/test_comms.py | 4 +- 2 files changed, 51 insertions(+), 34 deletions(-) diff --git a/python/raft/dask/common/comms.py b/python/raft/dask/common/comms.py index 0ba6fccdc5..54e5926b83 100644 --- a/python/raft/dask/common/comms.py +++ b/python/raft/dask/common/comms.py @@ -96,7 +96,8 @@ def __init__(self, comms_p2p=False, client=None, verbose=False, verbose : bool Print verbose logging nccl_root_location : string - Indicates where the NCCL's root node should be located. ['client', 'worker', 'scheduler' (default)] + Indicates where the NCCL's root node should be located. + ['client', 'worker', 'scheduler' (default)] """ self.client = client if client is not None else default_client() @@ -104,7 +105,8 @@ def __init__(self, comms_p2p=False, client=None, verbose=False, self.comms_p2p = comms_p2p if (nccl_root_location.lower() not in Comms.valid_nccl_placements): - raise ValueError(f"nccl_root_location must be one of: {Comms.valid_nccl_placements}") + raise ValueError(f"nccl_root_location must be one of: " + f"{Comms.valid_nccl_placements}") self.nccl_root_location = nccl_root_location.lower() self.streams_per_handle = streams_per_handle @@ -166,14 +168,15 @@ def init(self, workers=None): self.uniqueId = nccl.get_unique_id() elif (self.nccl_root_location == 'worker'): self.uniqueId = self.client.run(_func_set_worker_as_nccl_root, - sessionId=self.sessionId, - verbose=self.verbose, - workers=[self.worker_addresses[0]], - wait=True)[self.worker_addresses[0]] + sessionId=self.sessionId, + verbose=self.verbose, + workers=[self.worker_addresses[0]], + wait=True)[self.worker_addresses[0]] else: - self.uniqueId = self.client.run_on_scheduler(_func_set_scheduler_as_nccl_root, - sessionId=self.sessionId, - verbose=self.verbose) + self.uniqueId = self.client.run_on_scheduler( + _func_set_scheduler_as_nccl_root, + sessionId=self.sessionId, + verbose=self.verbose) self.client.run(_func_init_all, self.sessionId, @@ -238,9 +241,9 @@ def local_handle(sessionId): def scheduler_state(sessionId, dask_scheduler): """ - Retrieves cuML comms state on the scheduler node, for the given sessionId, creating - a new session if it does not exist. If no session id is given, returns the state dict for - all sessions. + Retrieves cuML comms state on the scheduler node, for the given sessionId, + creating a new session if it does not exist. If no session id is given, + returns the state dict for all sessions. Parameters ---------- @@ -257,7 +260,8 @@ def scheduler_state(sessionId, dask_scheduler): if (not hasattr(dask_scheduler, "_raft_comm_state")): dask_scheduler._raft_comm_state = {} - if (sessionId is not None and sessionId not in dask_scheduler._raft_comm_state): + if (sessionId is not None + and sessionId not in dask_scheduler._raft_comm_state): dask_scheduler._raft_comm_state[sessionId] = {"ts": time.time()} return dask_scheduler._raft_comm_state[sessionId] @@ -305,9 +309,11 @@ def _func_destroy_scheduler_session(sessionId, dask_scheduler): Parameters ---------- sessionId : session Id to be destroyed. 
- dask_scheduler : dask_scheduler object (Note: this is supplied by DASK, not the client) + dask_scheduler : dask_scheduler object + (Note: this is supplied by DASK, not the client) """ - if (sessionId is not None and sessionId in dask_scheduler._raft_comm_state): + if (sessionId is not None + and sessionId in dask_scheduler._raft_comm_state): del dask_scheduler._raft_comm_state[sessionId] else: return 1 @@ -323,7 +329,8 @@ def _func_set_scheduler_as_nccl_root(sessionId, verbose, dask_scheduler): ---------- sessionId : Associated session to attach the unique ID to. verbose : Indicates whether or not to emit additional information - dask_scheduler : dask scheduler object, (Note: this is supplied by DASK, not the client) + dask_scheduler : dask scheduler object, + (Note: this is supplied by DASK, not the client) Return ------ @@ -331,12 +338,14 @@ def _func_set_scheduler_as_nccl_root(sessionId, verbose, dask_scheduler): NCCL uniqueId, associating the DASK scheduler as its root node. """ if(verbose): - logger.info(msg=f"Setting scheduler as NCCL root for sessionId, '{sessionId}'") + logger.info(msg=f"Setting scheduler as NCCL " + f"root for sessionId, '{sessionId}'") if (sessionId is None): raise ValueError("sessionId cannot be None.") - session_state = scheduler_state(sessionId=sessionId, dask_scheduler=dask_scheduler) + session_state = scheduler_state(sessionId=sessionId, + dask_scheduler=dask_scheduler) if ('nccl_uid' not in session_state): session_state['nccl_uid'] = nccl.get_unique_id() @@ -362,7 +371,7 @@ def _func_set_worker_as_nccl_root(sessionId, verbose): """ if(verbose): get_worker().log_event(topic="info", - msg=f"Setting worker as NCCL root for session, '{sessionId}'") + msg=f"Setting worker as NCCL root for session, '{sessionId}'") if (sessionId is None): raise ValueError("sessionId cannot be None.") @@ -397,11 +406,13 @@ async def _func_init_all(sessionId, uniqueId, comms_p2p, if verbose: elapsed = time.time() - start - get_worker().log_event(topic="info", msg=f"NCCL Initialization took: {elapsed} seconds.") + get_worker().log_event(topic="info", + msg=f"NCCL Initialization took: {elapsed} seconds.") if comms_p2p: if verbose: - get_worker().log_event(topic="info", msg="Initializing UCX Endpoints") + get_worker().log_event(topic="info", + msg="Initializing UCX Endpoints") if verbose: start = time.time() @@ -409,7 +420,8 @@ async def _func_init_all(sessionId, uniqueId, comms_p2p, if verbose: elapsed = time.time() - start - msg = f"Done initializing UCX endpoints. Took: {elapsed} seconds.\nBuilding handle." + msg = f"Done initializing UCX endpoints." \ + f"Took: {elapsed} seconds.\nBuilding handle." 
get_worker().log_event(topic="info", msg=msg) _func_build_handle_p2p(sessionId, streams_per_handle, verbose) @@ -442,7 +454,8 @@ def _func_init_nccl(sessionId, uniqueId): n.init(nWorkers, uniqueId, wid) worker_state(sessionId)["nccl"] = n except Exception as e: - get_worker().log_event(topic="error", msg="An error occurred initializing NCCL!.") + get_worker().log_event(topic="error", + msg="An error occurred initializing NCCL!.") raise @@ -474,7 +487,8 @@ def _func_build_handle_p2p(sessionId, streams_per_handle, verbose): nWorkers, workerId, verbose) if (verbose): - get_worker().log_event(topic="info", msg="Finished injecting comms on handle.") + get_worker().log_event(topic="info", + msg="Finished injecting comms on handle.") worker_state(sessionId)["handle"] = handle @@ -490,7 +504,8 @@ def _func_build_handle(sessionId, streams_per_handle, verbose): verbose : bool print verbose logging output """ if (verbose): - get_worker().log_event(topic="info", msg="Finished injecting comms on handle.") + get_worker().log_event(topic="info", + msg="Finished injecting comms on handle.") handle = Handle(streams_per_handle) @@ -506,7 +521,6 @@ def _func_build_handle(sessionId, streams_per_handle, verbose): def _func_store_initial_state(nworkers, sessionId, uniqueId, wid): - # TODO: We don't ever remove wid or nworkers... could cause problems? Maybe we should just blow away whole session session_state = worker_state(sessionId) session_state["nccl_uid"] = uniqueId session_state["wid"] = wid @@ -540,26 +554,31 @@ async def _func_ucp_create_endpoints(sessionId, worker_info): async def _func_destroy_all(sessionId, comms_p2p, verbose=False): if(verbose): - get_worker().log_event(topic="info", msg="Destroying NCCL session state.") + get_worker().log_event(topic="info", + msg="Destroying NCCL session state.") session_state = worker_state(sessionId) if ('nccl' in session_state): session_state["nccl"].destroy() del session_state["nccl"] if (verbose): - get_worker().log_event(topic="info", msg="NCCL session state destroyed.") + get_worker().log_event(topic="info", + msg="NCCL session state destroyed.") else: if (verbose): get_worker().log_event(topic="warning", - msg=f"Session state for, '{sessionId}', does not contain expected 'nccl' element") + msg=f"Session state for, '{sessionId}', " + f"does not contain expected 'nccl' element") if (verbose): - get_worker().log_event(topic="info", msg=f"Destroy CUDA handle for sessionId, '{sessionId}.'") + get_worker().log_event(topic="info", + msg=f"Destroy CUDA handle for sessionId, '{sessionId}.'") if ('handle' in session_state): del session_state["handle"] else: if (verbose): get_worker().log_event(topic="warning", - msg=f"Session state for, '{sessionId}', does not contain expected 'handle' element") + msg=f"Session state for, '{sessionId}', " + f"does not contain expected 'handle' element") def _func_ucp_ports(client, workers): diff --git a/python/raft/test/test_comms.py b/python/raft/test/test_comms.py index 8c0daf359c..e65976c49d 100644 --- a/python/raft/test/test_comms.py +++ b/python/raft/test/test_comms.py @@ -171,7 +171,7 @@ def test_nccl_root_placement(client, root_location): func_check_uid_on_scheduler, cb.sessionId, cb.uniqueId ) else: - result = int(cb.uniqueId == None) + result = int(cb.uniqueId is None) assert result == 0 @@ -180,8 +180,6 @@ def test_nccl_root_placement(client, root_location): cb.destroy() -# TODO: Add negative case test for bad location type, and Comm init failure so we check cleanup routines. 
- @pytest.mark.parametrize("func", functions) @pytest.mark.parametrize("root_location", ["client", "worker", "scheduler"]) From f71827d5bb60bd72a0476d5fc9e63be23c47f931 Mon Sep 17 00:00:00 2001 From: Devin Robison Date: Wed, 13 Jan 2021 15:08:27 -0700 Subject: [PATCH 09/16] Style updates --- python/raft/dask/common/comms.py | 267 ++++++++++++++++++------------- 1 file changed, 158 insertions(+), 109 deletions(-) diff --git a/python/raft/dask/common/comms.py b/python/raft/dask/common/comms.py index 54e5926b83..795bd5f6df 100644 --- a/python/raft/dask/common/comms.py +++ b/python/raft/dask/common/comms.py @@ -80,10 +80,16 @@ def _use_comms(sessionId): cluster.close() """ - valid_nccl_placements = ('client', 'worker', 'scheduler') - - def __init__(self, comms_p2p=False, client=None, verbose=False, - streams_per_handle=0, nccl_root_location="scheduler"): + valid_nccl_placements = ("client", "worker", "scheduler") + + def __init__( + self, + comms_p2p=False, + client=None, + verbose=False, + streams_per_handle=0, + nccl_root_location="scheduler", + ): """ Construct a new CommsContext instance @@ -104,9 +110,11 @@ def __init__(self, comms_p2p=False, client=None, verbose=False, self.comms_p2p = comms_p2p - if (nccl_root_location.lower() not in Comms.valid_nccl_placements): - raise ValueError(f"nccl_root_location must be one of: " - f"{Comms.valid_nccl_placements}") + if nccl_root_location.lower() not in Comms.valid_nccl_placements: + raise ValueError( + f"nccl_root_location must be one of: " + f"{Comms.valid_nccl_placements}" + ) self.nccl_root_location = nccl_root_location.lower() self.streams_per_handle = streams_per_handle @@ -131,8 +139,9 @@ def worker_info(self, workers): (worker_rank, worker_port ) } """ ranks = _func_worker_ranks(workers) - ports = _func_ucp_ports(self.client, workers) \ - if self.comms_p2p else None + ports = ( + _func_ucp_ports(self.client, workers) if self.comms_p2p else None + ) output = {} for k in ranks.keys(): @@ -153,9 +162,13 @@ def init(self, workers=None): Unique collection of workers for initializing comms. 
""" - self.worker_addresses = list(OrderedDict.fromkeys( - self.client.scheduler_info()["workers"].keys() - if workers is None else workers)) + self.worker_addresses = list( + OrderedDict.fromkeys( + self.client.scheduler_info()["workers"].keys() + if workers is None + else workers + ) + ) if self.nccl_initialized or self.ucx_initialized: warnings.warn("Comms have already been initialized.") @@ -164,29 +177,34 @@ def init(self, workers=None): worker_info = self.worker_info(self.worker_addresses) worker_info = {w: worker_info[w] for w in self.worker_addresses} - if (self.nccl_root_location == 'client'): + if self.nccl_root_location == "client": self.uniqueId = nccl.get_unique_id() - elif (self.nccl_root_location == 'worker'): - self.uniqueId = self.client.run(_func_set_worker_as_nccl_root, - sessionId=self.sessionId, - verbose=self.verbose, - workers=[self.worker_addresses[0]], - wait=True)[self.worker_addresses[0]] + elif self.nccl_root_location == "worker": + self.uniqueId = self.client.run( + _func_set_worker_as_nccl_root, + sessionId=self.sessionId, + verbose=self.verbose, + workers=[self.worker_addresses[0]], + wait=True, + )[self.worker_addresses[0]] else: self.uniqueId = self.client.run_on_scheduler( _func_set_scheduler_as_nccl_root, sessionId=self.sessionId, - verbose=self.verbose) - - self.client.run(_func_init_all, - self.sessionId, - self.uniqueId, - self.comms_p2p, - worker_info, - self.verbose, - self.streams_per_handle, - workers=self.worker_addresses, - wait=True) + verbose=self.verbose, + ) + + self.client.run( + _func_init_all, + self.sessionId, + self.uniqueId, + self.comms_p2p, + worker_info, + self.verbose, + self.streams_per_handle, + workers=self.worker_addresses, + wait=True, + ) self.nccl_initialized = True @@ -202,16 +220,19 @@ def destroy(self): be called automatically by the Comms destructor, but may be called earlier to save resources. 
""" - self.client.run(_func_destroy_all, - self.sessionId, - self.comms_p2p, - self.verbose, - wait=True, - workers=self.worker_addresses) - - if (self.nccl_root_location == 'scheduler'): - self.client.run_on_scheduler(_func_destroy_scheduler_session, - self.sessionId) + self.client.run( + _func_destroy_all, + self.sessionId, + self.comms_p2p, + self.verbose, + wait=True, + workers=self.worker_addresses, + ) + + if self.nccl_root_location == "scheduler": + self.client.run_on_scheduler( + _func_destroy_scheduler_session, self.sessionId + ) if self.verbose: print("Destroying comms.") @@ -257,11 +278,13 @@ def scheduler_state(sessionId, dask_scheduler): session state associated with sessionId """ - if (not hasattr(dask_scheduler, "_raft_comm_state")): + if not hasattr(dask_scheduler, "_raft_comm_state"): dask_scheduler._raft_comm_state = {} - if (sessionId is not None - and sessionId not in dask_scheduler._raft_comm_state): + if ( + sessionId is not None + and sessionId not in dask_scheduler._raft_comm_state + ): dask_scheduler._raft_comm_state[sessionId] = {"ts": time.time()} return dask_scheduler._raft_comm_state[sessionId] @@ -302,6 +325,7 @@ def get_ucx(): worker_state("ucp")["ucx"] = UCX.get() return worker_state("ucp")["ucx"] + def _func_destroy_scheduler_session(sessionId, dask_scheduler): """ Remove session date from _raft_comm_state, associated with sessionId @@ -312,14 +336,14 @@ def _func_destroy_scheduler_session(sessionId, dask_scheduler): dask_scheduler : dask_scheduler object (Note: this is supplied by DASK, not the client) """ - if (sessionId is not None - and sessionId in dask_scheduler._raft_comm_state): + if sessionId is not None and sessionId in dask_scheduler._raft_comm_state: del dask_scheduler._raft_comm_state[sessionId] else: return 1 return 0 + def _func_set_scheduler_as_nccl_root(sessionId, verbose, dask_scheduler): """ Creates a persistent nccl uniqueId on the scheduler node. @@ -337,22 +361,26 @@ def _func_set_scheduler_as_nccl_root(sessionId, verbose, dask_scheduler): uniqueId : byte str NCCL uniqueId, associating the DASK scheduler as its root node. """ - if(verbose): - logger.info(msg=f"Setting scheduler as NCCL " - f"root for sessionId, '{sessionId}'") + if verbose: + logger.info( + msg=f"Setting scheduler as NCCL " + f"root for sessionId, '{sessionId}'" + ) - if (sessionId is None): + if sessionId is None: raise ValueError("sessionId cannot be None.") - session_state = scheduler_state(sessionId=sessionId, - dask_scheduler=dask_scheduler) - if ('nccl_uid' not in session_state): - session_state['nccl_uid'] = nccl.get_unique_id() + session_state = scheduler_state( + sessionId=sessionId, dask_scheduler=dask_scheduler + ) + if "nccl_uid" not in session_state: + session_state["nccl_uid"] = nccl.get_unique_id() - if(verbose): + if verbose: logger.info(f"Done setting scheduler as NCCL root.") - return session_state['nccl_uid'] + return session_state["nccl_uid"] + def _func_set_worker_as_nccl_root(sessionId, verbose): """ @@ -369,29 +397,34 @@ def _func_set_worker_as_nccl_root(sessionId, verbose): uniqueId : byte str NCCL uniqueId, associating this DASK worker as its root node. 
""" - if(verbose): - get_worker().log_event(topic="info", - msg=f"Setting worker as NCCL root for session, '{sessionId}'") + if verbose: + get_worker().log_event( + topic="info", + msg=f"Setting worker as NCCL root for session, '{sessionId}'", + ) - if (sessionId is None): + if sessionId is None: raise ValueError("sessionId cannot be None.") session_state = worker_state(sessionId) - if ('nccl_uid' not in session_state): - session_state['nccl_uid'] = nccl.get_unique_id() + if "nccl_uid" not in session_state: + session_state["nccl_uid"] = nccl.get_unique_id() + + if verbose: + get_worker().log_event( + topic="info", msg=f"Done setting scheduler as NCCL root." + ) - if(verbose): - get_worker().log_event(topic="info", - msg=f"Done setting scheduler as NCCL root.") + return session_state["nccl_uid"] - return session_state['nccl_uid'] def _func_ucp_listener_port(): return get_ucx().listener_port() -async def _func_init_all(sessionId, uniqueId, comms_p2p, - worker_info, verbose, streams_per_handle): +async def _func_init_all( + sessionId, uniqueId, comms_p2p, worker_info, verbose, streams_per_handle +): session_state = worker_state(sessionId) session_state["nccl_uid"] = uniqueId @@ -406,13 +439,15 @@ async def _func_init_all(sessionId, uniqueId, comms_p2p, if verbose: elapsed = time.time() - start - get_worker().log_event(topic="info", - msg=f"NCCL Initialization took: {elapsed} seconds.") + get_worker().log_event( + topic="info", msg=f"NCCL Initialization took: {elapsed} seconds." + ) if comms_p2p: if verbose: - get_worker().log_event(topic="info", - msg="Initializing UCX Endpoints") + get_worker().log_event( + topic="info", msg="Initializing UCX Endpoints" + ) if verbose: start = time.time() @@ -420,8 +455,10 @@ async def _func_init_all(sessionId, uniqueId, comms_p2p, if verbose: elapsed = time.time() - start - msg = f"Done initializing UCX endpoints." \ - f"Took: {elapsed} seconds.\nBuilding handle." + msg = ( + f"Done initializing UCX endpoints." + f"Took: {elapsed} seconds.\nBuilding handle." + ) get_worker().log_event(topic="info", msg=msg) _func_build_handle_p2p(sessionId, streams_per_handle, verbose) @@ -454,8 +491,9 @@ def _func_init_nccl(sessionId, uniqueId): n.init(nWorkers, uniqueId, wid) worker_state(sessionId)["nccl"] = n except Exception as e: - get_worker().log_event(topic="error", - msg="An error occurred initializing NCCL!.") + get_worker().log_event( + topic="error", msg="An error occurred initializing NCCL!." + ) raise @@ -469,7 +507,7 @@ def _func_build_handle_p2p(sessionId, streams_per_handle, verbose): streams_per_handle : int number of internal streams to create verbose : bool print verbose logging output """ - if (verbose): + if verbose: get_worker().log_event(topic="info", msg="Building p2p handle.") ucp_worker = get_ucx().get_worker() @@ -481,14 +519,16 @@ def _func_build_handle_p2p(sessionId, streams_per_handle, verbose): nWorkers = session_state["nworkers"] workerId = session_state["wid"] - if (verbose): + if verbose: get_worker().log_event(topic="info", msg="Injecting comms on handle.") - inject_comms_on_handle(handle, nccl_comm, ucp_worker, eps, - nWorkers, workerId, verbose) + inject_comms_on_handle( + handle, nccl_comm, ucp_worker, eps, nWorkers, workerId, verbose + ) - if (verbose): - get_worker().log_event(topic="info", - msg="Finished injecting comms on handle.") + if verbose: + get_worker().log_event( + topic="info", msg="Finished injecting comms on handle." 
+ ) worker_state(sessionId)["handle"] = handle @@ -503,9 +543,10 @@ def _func_build_handle(sessionId, streams_per_handle, verbose): streams_per_handle : int number of internal streams to create verbose : bool print verbose logging output """ - if (verbose): - get_worker().log_event(topic="info", - msg="Finished injecting comms on handle.") + if verbose: + get_worker().log_event( + topic="info", msg="Finished injecting comms on handle." + ) handle = Handle(streams_per_handle) @@ -515,8 +556,9 @@ def _func_build_handle(sessionId, streams_per_handle, verbose): nWorkers = session_state["nworkers"] nccl_comm = session_state["nccl"] - inject_comms_on_handle_coll_only(handle, nccl_comm, nWorkers, - workerId, verbose) + inject_comms_on_handle_coll_only( + handle, nccl_comm, nWorkers, workerId, verbose + ) session_state["handle"] = handle @@ -553,37 +595,44 @@ async def _func_ucp_create_endpoints(sessionId, worker_info): async def _func_destroy_all(sessionId, comms_p2p, verbose=False): - if(verbose): - get_worker().log_event(topic="info", - msg="Destroying NCCL session state.") + if verbose: + get_worker().log_event( + topic="info", msg="Destroying NCCL session state." + ) session_state = worker_state(sessionId) - if ('nccl' in session_state): + if "nccl" in session_state: session_state["nccl"].destroy() del session_state["nccl"] - if (verbose): - get_worker().log_event(topic="info", - msg="NCCL session state destroyed.") + if verbose: + get_worker().log_event( + topic="info", msg="NCCL session state destroyed." + ) else: - if (verbose): - get_worker().log_event(topic="warning", - msg=f"Session state for, '{sessionId}', " - f"does not contain expected 'nccl' element") - - if (verbose): - get_worker().log_event(topic="info", - msg=f"Destroy CUDA handle for sessionId, '{sessionId}.'") - if ('handle' in session_state): + if verbose: + get_worker().log_event( + topic="warning", + msg=f"Session state for, '{sessionId}', " + f"does not contain expected 'nccl' element", + ) + + if verbose: + get_worker().log_event( + topic="info", + msg=f"Destroy CUDA handle for sessionId, '{sessionId}.'", + ) + if "handle" in session_state: del session_state["handle"] else: - if (verbose): - get_worker().log_event(topic="warning", - msg=f"Session state for, '{sessionId}', " - f"does not contain expected 'handle' element") + if verbose: + get_worker().log_event( + topic="warning", + msg=f"Session state for, '{sessionId}', " + f"does not contain expected 'handle' element", + ) def _func_ucp_ports(client, workers): - return client.run(_func_ucp_listener_port, - workers=workers) + return client.run(_func_ucp_listener_port, workers=workers) def _func_worker_ranks(workers): From 712a169528422177b9723f90640ffe0ec04986ed Mon Sep 17 00:00:00 2001 From: Devin Robison Date: Wed, 13 Jan 2021 15:10:19 -0700 Subject: [PATCH 10/16] Style updates --- python/raft/dask/common/comms.py | 6 +++--- python/raft/test/test_comms.py | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/python/raft/dask/common/comms.py b/python/raft/dask/common/comms.py index 795bd5f6df..34d2872902 100644 --- a/python/raft/dask/common/comms.py +++ b/python/raft/dask/common/comms.py @@ -377,7 +377,7 @@ def _func_set_scheduler_as_nccl_root(sessionId, verbose, dask_scheduler): session_state["nccl_uid"] = nccl.get_unique_id() if verbose: - logger.info(f"Done setting scheduler as NCCL root.") + logger.info("Done setting scheduler as NCCL root.") return session_state["nccl_uid"] @@ -412,7 +412,7 @@ def _func_set_worker_as_nccl_root(sessionId, verbose): if 
verbose: get_worker().log_event( - topic="info", msg=f"Done setting scheduler as NCCL root." + topic="info", msg="Done setting scheduler as NCCL root." ) return session_state["nccl_uid"] @@ -492,7 +492,7 @@ def _func_init_nccl(sessionId, uniqueId): worker_state(sessionId)["nccl"] = n except Exception as e: get_worker().log_event( - topic="error", msg="An error occurred initializing NCCL!." + topic="error", msg=f"An error occurred initializing NCCL: {e}." ) raise diff --git a/python/raft/test/test_comms.py b/python/raft/test/test_comms.py index e65976c49d..0fb988f1dd 100644 --- a/python/raft/test/test_comms.py +++ b/python/raft/test/test_comms.py @@ -180,7 +180,6 @@ def test_nccl_root_placement(client, root_location): cb.destroy() - @pytest.mark.parametrize("func", functions) @pytest.mark.parametrize("root_location", ["client", "worker", "scheduler"]) @pytest.mark.nccl From dd8a9f4550277a004e5c039b2ab0c75fb7c52576 Mon Sep 17 00:00:00 2001 From: Devin Robison Date: Thu, 14 Jan 2021 15:34:58 -0700 Subject: [PATCH 11/16] Feedback updates --- python/raft/dask/common/comms.py | 39 +++++++++++++++++--------------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/python/raft/dask/common/comms.py b/python/raft/dask/common/comms.py index 34d2872902..fb22964acf 100644 --- a/python/raft/dask/common/comms.py +++ b/python/raft/dask/common/comms.py @@ -110,12 +110,12 @@ def __init__( self.comms_p2p = comms_p2p - if nccl_root_location.lower() not in Comms.valid_nccl_placements: + self.nccl_root_location = nccl_root_location.lower() + if self.nccl_root_location not in Comms.valid_nccl_placements: raise ValueError( f"nccl_root_location must be one of: " f"{Comms.valid_nccl_placements}" ) - self.nccl_root_location = nccl_root_location.lower() self.streams_per_handle = streams_per_handle @@ -150,6 +150,24 @@ def worker_info(self, workers): output[k]["port"] = ports[k] return output + def create_nccl_uniqueid(self): + if self.nccl_root_location == "client": + self.uniqueId = nccl.get_unique_id() + elif self.nccl_root_location == "worker": + self.uniqueId = self.client.run( + _func_set_worker_as_nccl_root, + sessionId=self.sessionId, + verbose=self.verbose, + workers=[self.worker_addresses[0]], + wait=True, + )[self.worker_addresses[0]] + else: + self.uniqueId = self.client.run_on_scheduler( + _func_set_scheduler_as_nccl_root, + sessionId=self.sessionId, + verbose=self.verbose, + ) + def init(self, workers=None): """ Initializes the underlying comms. 
NCCL is required but @@ -177,22 +195,7 @@ def init(self, workers=None): worker_info = self.worker_info(self.worker_addresses) worker_info = {w: worker_info[w] for w in self.worker_addresses} - if self.nccl_root_location == "client": - self.uniqueId = nccl.get_unique_id() - elif self.nccl_root_location == "worker": - self.uniqueId = self.client.run( - _func_set_worker_as_nccl_root, - sessionId=self.sessionId, - verbose=self.verbose, - workers=[self.worker_addresses[0]], - wait=True, - )[self.worker_addresses[0]] - else: - self.uniqueId = self.client.run_on_scheduler( - _func_set_scheduler_as_nccl_root, - sessionId=self.sessionId, - verbose=self.verbose, - ) + self.create_nccl_uniqueid() self.client.run( _func_init_all, From 0a97a497fc07f6903ff95736f455f94e10ef7003 Mon Sep 17 00:00:00 2001 From: Devin Robison Date: Thu, 14 Jan 2021 16:07:34 -0700 Subject: [PATCH 12/16] Feedback updates - consolidate code, merge worker_state and scheduler_state functions --- python/raft/dask/common/comms.py | 228 +++++++++++++++---------------- 1 file changed, 114 insertions(+), 114 deletions(-) diff --git a/python/raft/dask/common/comms.py b/python/raft/dask/common/comms.py index fb22964acf..e806788e11 100644 --- a/python/raft/dask/common/comms.py +++ b/python/raft/dask/common/comms.py @@ -133,23 +133,6 @@ def __del__(self): if self.nccl_initialized or self.ucx_initialized: self.destroy() - def worker_info(self, workers): - """ - Builds a dictionary of { (worker_address, worker_port) : - (worker_rank, worker_port ) } - """ - ranks = _func_worker_ranks(workers) - ports = ( - _func_ucp_ports(self.client, workers) if self.comms_p2p else None - ) - - output = {} - for k in ranks.keys(): - output[k] = {"rank": ranks[k]} - if self.comms_p2p: - output[k]["port"] = ports[k] - return output - def create_nccl_uniqueid(self): if self.nccl_root_location == "client": self.uniqueId = nccl.get_unique_id() @@ -168,6 +151,23 @@ def create_nccl_uniqueid(self): verbose=self.verbose, ) + def worker_info(self, workers): + """ + Builds a dictionary of { (worker_address, worker_port) : + (worker_rank, worker_port ) } + """ + ranks = _func_worker_ranks(workers) + ports = ( + _func_ucp_ports(self.client, workers) if self.comms_p2p else None + ) + + output = {} + for k in ranks.keys(): + output[k] = {"rank": ranks[k]} + if self.comms_p2p: + output[k]["port"] = ports[k] + return output + def init(self, workers=None): """ Initializes the underlying comms. NCCL is required but @@ -259,11 +259,11 @@ def local_handle(sessionId): handle : raft.Handle or None """ - state = worker_state(sessionId) + state = get_raft_comm_state(sessionId, get_worker()) return state["handle"] if "handle" in state else None -def scheduler_state(sessionId, dask_scheduler): +def get_raft_comm_state(sessionId, state_object): """ Retrieves cuML comms state on the scheduler node, for the given sessionId, creating a new session if it does not exist. 
If no session id is given,
@@ -272,7 +272,8 @@ def scheduler_state(sessionId, dask_scheduler):
     Parameters
     ----------
     sessionId : SessionId value to retrieve from the dask_scheduler instances
-    dask_scheduler : Dask Scheduler object
+    state_object : Object (either a Worker or Scheduler) on which the raft
+        comm state will be retrieved (or created)
 
     Returns
     -------
@@ -281,42 +282,33 @@ def scheduler_state(sessionId, dask_scheduler):
         session state associated with sessionId
     """
 
-    if not hasattr(dask_scheduler, "_raft_comm_state"):
-        dask_scheduler._raft_comm_state = {}
+    if not hasattr(state_object, "_raft_comm_state"):
+        state_object._raft_comm_state = {}
 
     if (
-        sessionId is not None
-        and sessionId not in dask_scheduler._raft_comm_state
+            sessionId is not None
+            and sessionId not in state_object._raft_comm_state
     ):
-        dask_scheduler._raft_comm_state[sessionId] = {"ts": time.time()}
+        state_object._raft_comm_state[sessionId] = {"ts": time.time()}
 
-    return dask_scheduler._raft_comm_state[sessionId]
+    if (sessionId is not None):
+        return state_object._raft_comm_state[sessionId]
 
-    return dask_scheduler._raft_comm_state
+    return state_object._raft_comm_state
 
 
-def worker_state(sessionId=None):
-    """
-    Retrieves cuML comms state on local worker for the given
-    sessionId, creating a new session if it does not exist.
-    If no session id is given, returns the state dict for all
-    sessions.
+def set_nccl_root(sessionId, state_object):
+    if sessionId is None:
+        raise ValueError("sessionId cannot be None.")
 
-    Parameters
-    ----------
-    sessionId : str
-        session identifier from initialized comms instance
-    """
-    worker = get_worker()
-    if not hasattr(worker, "_raft_comm_state"):
-        worker._raft_comm_state = {}
 
-    if sessionId is not None and sessionId not in worker._raft_comm_state:
-        # Build state for new session and mark session creation time
-        worker._raft_comm_state[sessionId] = {"ts": time.time()}
+    raft_comm_state = get_raft_comm_state(
+        sessionId=sessionId, state_object=state_object
+    )
 
-    if sessionId is not None:
-        return worker._raft_comm_state[sessionId]
-    return worker._raft_comm_state
+    if "nccl_uid" not in raft_comm_state:
+        raft_comm_state["nccl_uid"] = nccl.get_unique_id()
+
+    return raft_comm_state["nccl_uid"]
 
 
 def get_ucx():
@@ -324,9 +316,12 @@ def get_ucx():
     A simple convenience wrapper to make sure UCP listener and endpoints
     are only ever assigned once per worker.
""" - if "ucx" not in worker_state("ucp"): - worker_state("ucp")["ucx"] = UCX.get() - return worker_state("ucp")["ucx"] + raft_comm_state = get_raft_comm_state(sessionId="ucp", + state_object=get_worker()) + if "ucx" not in raft_comm_state: + raft_comm_state["ucx"] = UCX.get() + + return raft_comm_state["ucx"] def _func_destroy_scheduler_session(sessionId, dask_scheduler): @@ -370,19 +365,12 @@ def _func_set_scheduler_as_nccl_root(sessionId, verbose, dask_scheduler): f"root for sessionId, '{sessionId}'" ) - if sessionId is None: - raise ValueError("sessionId cannot be None.") - - session_state = scheduler_state( - sessionId=sessionId, dask_scheduler=dask_scheduler - ) - if "nccl_uid" not in session_state: - session_state["nccl_uid"] = nccl.get_unique_id() + nccl_uid = set_nccl_root(sessionId=sessionId, state_object=dask_scheduler) if verbose: logger.info("Done setting scheduler as NCCL root.") - return session_state["nccl_uid"] + return nccl_uid def _func_set_worker_as_nccl_root(sessionId, verbose): @@ -400,25 +388,21 @@ def _func_set_worker_as_nccl_root(sessionId, verbose): uniqueId : byte str NCCL uniqueId, associating this DASK worker as its root node. """ + worker = get_worker() if verbose: - get_worker().log_event( + worker.log_event( topic="info", msg=f"Setting worker as NCCL root for session, '{sessionId}'", ) - if sessionId is None: - raise ValueError("sessionId cannot be None.") - - session_state = worker_state(sessionId) - if "nccl_uid" not in session_state: - session_state["nccl_uid"] = nccl.get_unique_id() + nccl_uid = set_nccl_root(sessionId=sessionId, state_object=worker) if verbose: - get_worker().log_event( + worker.log_event( topic="info", msg="Done setting scheduler as NCCL root." ) - return session_state["nccl_uid"] + return nccl_uid def _func_ucp_listener_port(): @@ -428,27 +412,28 @@ def _func_ucp_listener_port(): async def _func_init_all( sessionId, uniqueId, comms_p2p, worker_info, verbose, streams_per_handle ): - - session_state = worker_state(sessionId) - session_state["nccl_uid"] = uniqueId - session_state["wid"] = worker_info[get_worker().address]["rank"] - session_state["nworkers"] = len(worker_info) + worker = get_worker() + raft_comm_state = get_raft_comm_state(sessionId=sessionId, + state_object=worker) + raft_comm_state["nccl_uid"] = uniqueId + raft_comm_state["wid"] = worker_info[get_worker().address]["rank"] + raft_comm_state["nworkers"] = len(worker_info) if verbose: - get_worker().log_event(topic="info", msg="Initializing NCCL.") + worker.log_event(topic="info", msg="Initializing NCCL.") start = time.time() _func_init_nccl(sessionId, uniqueId) if verbose: elapsed = time.time() - start - get_worker().log_event( + worker.log_event( topic="info", msg=f"NCCL Initialization took: {elapsed} seconds." ) if comms_p2p: if verbose: - get_worker().log_event( + worker.log_event( topic="info", msg="Initializing UCX Endpoints" ) @@ -462,12 +447,12 @@ async def _func_init_all( f"Done initializing UCX endpoints." f"Took: {elapsed} seconds.\nBuilding handle." ) - get_worker().log_event(topic="info", msg=msg) + worker.log_event(topic="info", msg=msg) _func_build_handle_p2p(sessionId, streams_per_handle, verbose) if verbose: - get_worker().log_event(topic="info", msg="Done building handle.") + worker.log_event(topic="info", msg="Done building handle.") else: _func_build_handle(sessionId, streams_per_handle, verbose) @@ -486,15 +471,18 @@ def _func_init_nccl(sessionId, uniqueId): client. 
""" - wid = worker_state(sessionId)["wid"] - nWorkers = worker_state(sessionId)["nworkers"] + worker = get_worker() + raft_comm_state = get_raft_comm_state(sessionId=sessionId, + state_object=get_worker()) + wid = raft_comm_state["wid"] + nWorkers = raft_comm_state["nworkers"] try: n = nccl() n.init(nWorkers, uniqueId, wid) - worker_state(sessionId)["nccl"] = n + raft_comm_state["nccl"] = n except Exception as e: - get_worker().log_event( + worker.log_event( topic="error", msg=f"An error occurred initializing NCCL: {e}." ) raise @@ -510,30 +498,33 @@ def _func_build_handle_p2p(sessionId, streams_per_handle, verbose): streams_per_handle : int number of internal streams to create verbose : bool print verbose logging output """ + worker = get_worker() if verbose: - get_worker().log_event(topic="info", msg="Building p2p handle.") + worker.log_event(topic="info", msg="Building p2p handle.") ucp_worker = get_ucx().get_worker() - session_state = worker_state(sessionId) + raft_comm_state = get_raft_comm_state(sessionId=sessionId, + state_object=worker) handle = Handle(streams_per_handle) - nccl_comm = session_state["nccl"] - eps = session_state["ucp_eps"] - nWorkers = session_state["nworkers"] - workerId = session_state["wid"] + nccl_comm = raft_comm_state["nccl"] + eps = raft_comm_state["ucp_eps"] + nWorkers = raft_comm_state["nworkers"] + workerId = raft_comm_state["wid"] if verbose: - get_worker().log_event(topic="info", msg="Injecting comms on handle.") + worker.log_event(topic="info", msg="Injecting comms on handle.") + inject_comms_on_handle( handle, nccl_comm, ucp_worker, eps, nWorkers, workerId, verbose ) if verbose: - get_worker().log_event( + worker.log_event( topic="info", msg="Finished injecting comms on handle." ) - worker_state(sessionId)["handle"] = handle + raft_comm_state["handle"] = handle def _func_build_handle(sessionId, streams_per_handle, verbose): @@ -546,30 +537,33 @@ def _func_build_handle(sessionId, streams_per_handle, verbose): streams_per_handle : int number of internal streams to create verbose : bool print verbose logging output """ + worker = get_worker() if verbose: - get_worker().log_event( + worker.log_event( topic="info", msg="Finished injecting comms on handle." 
) handle = Handle(streams_per_handle) - session_state = worker_state(sessionId) + raft_comm_state = get_raft_comm_state(sessionId=sessionId, + state_object=worker) - workerId = session_state["wid"] - nWorkers = session_state["nworkers"] + workerId = raft_comm_state["wid"] + nWorkers = raft_comm_state["nworkers"] - nccl_comm = session_state["nccl"] + nccl_comm = raft_comm_state["nccl"] inject_comms_on_handle_coll_only( handle, nccl_comm, nWorkers, workerId, verbose ) - session_state["handle"] = handle + raft_comm_state["handle"] = handle def _func_store_initial_state(nworkers, sessionId, uniqueId, wid): - session_state = worker_state(sessionId) - session_state["nccl_uid"] = uniqueId - session_state["wid"] = wid - session_state["nworkers"] = nworkers + raft_comm_state = get_raft_comm_state(sessionId=sessionId, + state_object=get_worker()) + raft_comm_state["nccl_uid"] = uniqueId + raft_comm_state["wid"] = wid + raft_comm_state["nworkers"] = nworkers async def _func_ucp_create_endpoints(sessionId, worker_info): @@ -594,40 +588,46 @@ async def _func_ucp_create_endpoints(sessionId, worker_info): eps[worker_info[k]["rank"]] = ep count += 1 - worker_state(sessionId)["ucp_eps"] = eps + raft_comm_state = get_raft_comm_state(sessionId=sessionId, + state_object=get_worker()) + raft_comm_state["ucp_eps"] = eps async def _func_destroy_all(sessionId, comms_p2p, verbose=False): + worker = get_worker() if verbose: - get_worker().log_event( + worker.log_event( topic="info", msg="Destroying NCCL session state." ) - session_state = worker_state(sessionId) - if "nccl" in session_state: - session_state["nccl"].destroy() - del session_state["nccl"] + + raft_comm_state = get_raft_comm_state(sessionId=sessionId, + state_object=worker) + if "nccl" in raft_comm_state: + raft_comm_state["nccl"].destroy() + del raft_comm_state["nccl"] if verbose: - get_worker().log_event( + worker.log_event( topic="info", msg="NCCL session state destroyed." 
) else: if verbose: - get_worker().log_event( + worker.log_event( topic="warning", msg=f"Session state for, '{sessionId}', " f"does not contain expected 'nccl' element", ) if verbose: - get_worker().log_event( + worker.log_event( topic="info", - msg=f"Destroy CUDA handle for sessionId, '{sessionId}.'", + msg=f"Destroying CUDA handle for sessionId, '{sessionId}.'", ) - if "handle" in session_state: - del session_state["handle"] + + if "handle" in raft_comm_state: + del raft_comm_state["handle"] else: if verbose: - get_worker().log_event( + worker.log_event( topic="warning", msg=f"Session state for, '{sessionId}', " f"does not contain expected 'handle' element", From 223a4a7fbc3810a551db634707ec5f300a04f605 Mon Sep 17 00:00:00 2001 From: Devin Robison Date: Thu, 14 Jan 2021 16:24:47 -0700 Subject: [PATCH 13/16] Consolidate test code paths --- python/raft/test/test_comms.py | 38 ++++++++++++++-------------------- 1 file changed, 15 insertions(+), 23 deletions(-) diff --git a/python/raft/test/test_comms.py b/python/raft/test/test_comms.py index 0fb988f1dd..381cf9e4a7 100644 --- a/python/raft/test/test_comms.py +++ b/python/raft/test/test_comms.py @@ -68,16 +68,16 @@ def func_test_comm_split(sessionId, n_trials): return perform_test_comm_split(handle, n_trials) -def func_check_uid_on_scheduler(sessionId, uniqueId, dask_scheduler): - if not hasattr(dask_scheduler, "_raft_comm_state"): +def func_check_uid(sessionId, uniqueId, state_object): + if not hasattr(state_object, "_raft_comm_state"): return 1 - state_object = dask_scheduler._raft_comm_state - if sessionId not in state_object: + state = state_object._raft_comm_state + if sessionId not in state: return 2 - session_state = state_object[sessionId] - if "nccl_uid" not in dask_scheduler._raft_comm_state[sessionId]: + session_state = state[sessionId] + if "nccl_uid" not in session_state: return 3 nccl_uid = session_state["nccl_uid"] @@ -87,26 +87,18 @@ def func_check_uid_on_scheduler(sessionId, uniqueId, dask_scheduler): return 0 -def func_check_uid_on_worker(sessionId, uniqueId): - from dask.distributed import get_worker - - worker_state = get_worker() - if not hasattr(worker_state, "_raft_comm_state"): - return 1 +def func_check_uid_on_scheduler(sessionId, uniqueId, dask_scheduler): + return func_check_uid(sessionId=sessionId, + uniqueId=uniqueId, + state_object=dask_scheduler) - state_object = worker_state._raft_comm_state - if sessionId not in state_object: - return 2 - session_state = state_object[sessionId] - if "nccl_uid" not in session_state: - return 3 - - nccl_uid = session_state["nccl_uid"] - if nccl_uid != uniqueId: - return 4 +def func_check_uid_on_worker(sessionId, uniqueId): + from dask.distributed import get_worker - return 0 + return func_check_uid(sessionId=sessionId, + uniqueId=uniqueId, + state_object=get_worker()) def test_handles(cluster): From f688d917c71ef75cf75404fe92ed756710484b30 Mon Sep 17 00:00:00 2001 From: Devin Robison Date: Tue, 19 Jan 2021 10:06:11 -0700 Subject: [PATCH 14/16] Style updates. 
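
The consolidation above leaves a single checker, func_check_uid, with thin
wrappers for the worker and scheduler placements. A minimal sketch of how a
test can drive both wrappers, assuming an initialized Comms instance `cb`,
a connected dask `client`, and a hypothetical worker address `root_worker`
picked from client.scheduler_info()["workers"]:

    # Scheduler placement: Client.run_on_scheduler injects the Scheduler
    # object into the function's dask_scheduler argument.
    ret = client.run_on_scheduler(
        func_check_uid_on_scheduler, cb.sessionId, cb.uniqueId
    )
    assert ret == 0  # 0 == session present and nccl_uid matches

    # Worker placement: the wrapper resolves its own state object via
    # get_worker(), so only the session and uid are passed through.
    # client.run returns a dict keyed by worker address.
    ret = client.run(
        func_check_uid_on_worker,
        cb.sessionId,
        cb.uniqueId,
        workers=[root_worker],
    )
    assert ret[root_worker] == 0
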
--- python/raft/dask/common/comms.py | 58 +++++++++++++++++--------------- python/raft/test/test_comms.py | 12 +++---- 2 files changed, 36 insertions(+), 34 deletions(-) diff --git a/python/raft/dask/common/comms.py b/python/raft/dask/common/comms.py index e806788e11..339ea85777 100644 --- a/python/raft/dask/common/comms.py +++ b/python/raft/dask/common/comms.py @@ -286,12 +286,12 @@ def get_raft_comm_state(sessionId, state_object): state_object._raft_comm_state = {} if ( - sessionId is not None - and sessionId not in state_object._raft_comm_state + sessionId is not None + and sessionId not in state_object._raft_comm_state ): state_object._raft_comm_state[sessionId] = {"ts": time.time()} - if (sessionId is not None): + if sessionId is not None: return state_object._raft_comm_state[sessionId] return state_object._raft_comm_state @@ -316,8 +316,9 @@ def get_ucx(): A simple convenience wrapper to make sure UCP listener and endpoints are only ever assigned once per worker. """ - raft_comm_state = get_raft_comm_state(sessionId="ucp", - state_object=get_worker()) + raft_comm_state = get_raft_comm_state( + sessionId="ucp", state_object=get_worker() + ) if "ucx" not in raft_comm_state: raft_comm_state["ucx"] = UCX.get() @@ -413,8 +414,9 @@ async def _func_init_all( sessionId, uniqueId, comms_p2p, worker_info, verbose, streams_per_handle ): worker = get_worker() - raft_comm_state = get_raft_comm_state(sessionId=sessionId, - state_object=worker) + raft_comm_state = get_raft_comm_state( + sessionId=sessionId, state_object=worker + ) raft_comm_state["nccl_uid"] = uniqueId raft_comm_state["wid"] = worker_info[get_worker().address]["rank"] raft_comm_state["nworkers"] = len(worker_info) @@ -433,9 +435,7 @@ async def _func_init_all( if comms_p2p: if verbose: - worker.log_event( - topic="info", msg="Initializing UCX Endpoints" - ) + worker.log_event(topic="info", msg="Initializing UCX Endpoints") if verbose: start = time.time() @@ -472,8 +472,9 @@ def _func_init_nccl(sessionId, uniqueId): """ worker = get_worker() - raft_comm_state = get_raft_comm_state(sessionId=sessionId, - state_object=get_worker()) + raft_comm_state = get_raft_comm_state( + sessionId=sessionId, state_object=get_worker() + ) wid = raft_comm_state["wid"] nWorkers = raft_comm_state["nworkers"] @@ -503,8 +504,9 @@ def _func_build_handle_p2p(sessionId, streams_per_handle, verbose): worker.log_event(topic="info", msg="Building p2p handle.") ucp_worker = get_ucx().get_worker() - raft_comm_state = get_raft_comm_state(sessionId=sessionId, - state_object=worker) + raft_comm_state = get_raft_comm_state( + sessionId=sessionId, state_object=worker + ) handle = Handle(streams_per_handle) nccl_comm = raft_comm_state["nccl"] @@ -545,8 +547,9 @@ def _func_build_handle(sessionId, streams_per_handle, verbose): handle = Handle(streams_per_handle) - raft_comm_state = get_raft_comm_state(sessionId=sessionId, - state_object=worker) + raft_comm_state = get_raft_comm_state( + sessionId=sessionId, state_object=worker + ) workerId = raft_comm_state["wid"] nWorkers = raft_comm_state["nworkers"] @@ -559,8 +562,9 @@ def _func_build_handle(sessionId, streams_per_handle, verbose): def _func_store_initial_state(nworkers, sessionId, uniqueId, wid): - raft_comm_state = get_raft_comm_state(sessionId=sessionId, - state_object=get_worker()) + raft_comm_state = get_raft_comm_state( + sessionId=sessionId, state_object=get_worker() + ) raft_comm_state["nccl_uid"] = uniqueId raft_comm_state["wid"] = wid raft_comm_state["nworkers"] = nworkers @@ -588,27 +592,25 @@ async def 
_func_ucp_create_endpoints(sessionId, worker_info): eps[worker_info[k]["rank"]] = ep count += 1 - raft_comm_state = get_raft_comm_state(sessionId=sessionId, - state_object=get_worker()) + raft_comm_state = get_raft_comm_state( + sessionId=sessionId, state_object=get_worker() + ) raft_comm_state["ucp_eps"] = eps async def _func_destroy_all(sessionId, comms_p2p, verbose=False): worker = get_worker() if verbose: - worker.log_event( - topic="info", msg="Destroying NCCL session state." - ) + worker.log_event(topic="info", msg="Destroying NCCL session state.") - raft_comm_state = get_raft_comm_state(sessionId=sessionId, - state_object=worker) + raft_comm_state = get_raft_comm_state( + sessionId=sessionId, state_object=worker + ) if "nccl" in raft_comm_state: raft_comm_state["nccl"].destroy() del raft_comm_state["nccl"] if verbose: - worker.log_event( - topic="info", msg="NCCL session state destroyed." - ) + worker.log_event(topic="info", msg="NCCL session state destroyed.") else: if verbose: worker.log_event( diff --git a/python/raft/test/test_comms.py b/python/raft/test/test_comms.py index 381cf9e4a7..7dccb7bbae 100644 --- a/python/raft/test/test_comms.py +++ b/python/raft/test/test_comms.py @@ -88,17 +88,17 @@ def func_check_uid(sessionId, uniqueId, state_object): def func_check_uid_on_scheduler(sessionId, uniqueId, dask_scheduler): - return func_check_uid(sessionId=sessionId, - uniqueId=uniqueId, - state_object=dask_scheduler) + return func_check_uid( + sessionId=sessionId, uniqueId=uniqueId, state_object=dask_scheduler + ) def func_check_uid_on_worker(sessionId, uniqueId): from dask.distributed import get_worker - return func_check_uid(sessionId=sessionId, - uniqueId=uniqueId, - state_object=get_worker()) + return func_check_uid( + sessionId=sessionId, uniqueId=uniqueId, state_object=get_worker() + ) def test_handles(cluster): From e20a16c012e1fa24f2804ae5ac599aa2a9b0493c Mon Sep 17 00:00:00 2001 From: Devin Robison Date: Tue, 19 Jan 2021 14:58:05 -0700 Subject: [PATCH 15/16] Add default state object to get_raft_comm_state to improve cuml compatibility --- python/raft/dask/common/comms.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/raft/dask/common/comms.py b/python/raft/dask/common/comms.py index 339ea85777..d3b3d3de53 100644 --- a/python/raft/dask/common/comms.py +++ b/python/raft/dask/common/comms.py @@ -262,8 +262,7 @@ def local_handle(sessionId): state = get_raft_comm_state(sessionId, get_worker()) return state["handle"] if "handle" in state else None - -def get_raft_comm_state(sessionId, state_object): +def get_raft_comm_state(sessionId, state_object=None): """ Retrieves cuML comms state on the scheduler node, for the given sessionId, creating a new session if it does not exist. 
If no session id is given, @@ -281,6 +280,7 @@ def get_raft_comm_state(sessionId, state_object): session state : str session state associated with sessionId """ + state_object = state_object if state_object is not None else get_worker() if not hasattr(state_object, "_raft_comm_state"): state_object._raft_comm_state = {} From eb7a89736f9d0a7617c2c25b7ee57c26d2a428e4 Mon Sep 17 00:00:00 2001 From: Devin Robison Date: Tue, 19 Jan 2021 15:00:33 -0700 Subject: [PATCH 16/16] Style fix --- python/raft/dask/common/comms.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/raft/dask/common/comms.py b/python/raft/dask/common/comms.py index d3b3d3de53..27533dfb9a 100644 --- a/python/raft/dask/common/comms.py +++ b/python/raft/dask/common/comms.py @@ -262,6 +262,7 @@ def local_handle(sessionId): state = get_raft_comm_state(sessionId, get_worker()) return state["handle"] if "handle" in state else None + def get_raft_comm_state(sessionId, state_object=None): """ Retrieves cuML comms state on the scheduler node, for the given sessionId,
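
With the default argument introduced above, worker-side call sites may omit
state_object entirely. A minimal sketch of the resulting equivalence,
assuming it runs inside a task on a dask worker and that `sessionId` comes
from an initialized Comms instance:

    from dask.distributed import get_worker

    implicit = get_raft_comm_state(sessionId)                # default path
    explicit = get_raft_comm_state(sessionId, get_worker())  # explicit path
    assert implicit is explicit  # same per-session dict on the Worker

Scheduler-side callers such as _func_set_scheduler_as_nccl_root are
unaffected: they continue to receive and pass the Scheduler object
explicitly.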