diff --git a/python/cugraph/comms/comms.py b/python/cugraph/comms/comms.py index 28ce2a3fc1e..642d99440e0 100644 --- a/python/cugraph/comms/comms.py +++ b/python/cugraph/comms/comms.py @@ -9,16 +9,33 @@ # Intialize Comms. If explicit Comms not provided as arg, # default Comms are initialized as per client information. -def initialize(arg=None): +def initialize(comms=None, p2p=False): + """ + Intitializes a communicator for multi-node multi-gpu communications. + It is expected to be called right after client initialization for running + mnmg algorithms. It wraps raft comms that manages underlying NCCL and UCX + comms handles across the workers of a Dask cluster. + It is recommended to also call `destroy()` when the comms are no longer + needed so the underlying resources can be cleaned up. + + Parameters + ---------- + comms : raft Comms + A pre-initialized raft communicator. If provided, this is used for mnmg + communications. + p2p : bool + Initialize UCX endpoints + """ + global __instance if __instance is None: global __default_handle __default_handle = None - if arg is None: - __instance = raftComms() + if comms is None: + __instance = raftComms(comms_p2p=p2p) __instance.init() else: - __instance = arg + __instance = comms else: raise Exception("Communicator is already initialized") @@ -47,6 +64,9 @@ def get_session_id(): # Destroy Comms def destroy(): + """ + Shuts down initialized comms and cleans up resources. + """ global __instance if is_initialized(): __instance.destroy()