From 5da5ea42bb3d6e6cd01171f74a08941b75ba93df Mon Sep 17 00:00:00 2001
From: Peter Andreas Entschev
Date: Thu, 17 Oct 2024 13:31:42 -0700
Subject: [PATCH] Add UCXX proxy backend

---
 distributed/comm/ucx.py | 90 +++++++++++++++++++++++++++++++++++++++--
 distributed/worker.py   |  2 +-
 2 files changed, 88 insertions(+), 4 deletions(-)

diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py
index 54b14fec44..969636ebd6 100644
--- a/distributed/comm/ucx.py
+++ b/distributed/comm/ucx.py
@@ -12,6 +12,7 @@
 import logging
 import os
 import struct
+import warnings
 import weakref
 from collections.abc import Awaitable, Callable, Collection
 from typing import TYPE_CHECKING, Any
@@ -491,7 +492,7 @@ def __init__(
     ):
         super().__init__()
         if not address.startswith("ucx"):
-            address = "ucx://" + address
+            address = self.prefix + address
         self.ip, self._input_port = parse_host_port(address, default_port=0)
         self.comm_handler = comm_handler
         self.deserialize = deserialize
@@ -506,7 +507,7 @@ def port(self):
 
     @property
     def address(self):
-        return "ucx://" + self.ip + ":" + str(self.port)
+        return self.prefix + self.ip + ":" + str(self.port)
 
     async def start(self):
         async def serve_forever(client_ep):
@@ -559,6 +560,20 @@ def get_connector(self):
         return UCXConnector()
 
     def get_listener(self, loc, handle_comm, deserialize, **connection_args):
+        # Only the UCX-Py fallback reaches this method: ``UCXBackendOld``
+        # overrides it, so protocol='ucx-old' stays silent.
+        warnings.warn(
+            "you have requested protocol='ucx', which now defaults to UCXX but "
+            "the package distributed-ucxx is not installed. In the current version "
+            "of Distributed this will fallback to UCX-Py which is now deprecated "
+            "and will be removed in a future release. For now protocol='ucx' will "
+            "fallback to the old UCX-Py library, but for continued use of UCX as "
+            "a Distributed communication backend, please ensure you switch to the "
+            "new distributed-ucxx package. To keep on using UCX-Py for now and "
+            "disable this warning, specify protocol='ucx-old'.",
+            # Shown by default, unlike DeprecationWarning outside ``__main__``.
+            FutureWarning,
+        )
         return UCXListener(loc, handle_comm, deserialize, **connection_args)
 
     # Address handling
@@ -584,7 +599,76 @@ def get_local_address_for(self, loc):
         return unparse_host_port(local_host, None)
 
 
-backends["ucx"] = UCXBackend()
+class UCXConnectorOld(UCXConnector):
+    # Legacy UCX-Py connector, selected with protocol="ucx-old".
+    prefix = "ucx-old://"
+
+
+class UCXListenerOld(UCXListener):
+    # Legacy UCX-Py listener; shares the ``ucx-old://`` address prefix.
+    prefix = UCXConnectorOld.prefix
+
+
+class UCXBackendOld(UCXBackend):
+    # Legacy UCX-Py backend registered as ``backends["ucx-old"]``. Overriding
+    # ``get_listener`` also bypasses the fallback warning of ``UCXBackend``.
+    def get_connector(self):
+        return UCXConnectorOld()
+
+    def get_listener(self, loc, handle_comm, deserialize, **connection_args):
+        return UCXListenerOld(loc, handle_comm, deserialize, **connection_args)
+
+
+def _rewrite_ucxx_backend():
+    """Return the backend class to register for ``protocol="ucx"``.
+
+    When distributed-ucxx is importable, return a subclass of its backend
+    whose comm, connector and listener use the ``ucx://`` prefix. Otherwise
+    return ``UCXBackend`` (UCX-Py). A class is returned in both cases, so
+    callers instantiate the result.
+    """
+    try:
+        from distributed_ucxx.ucxx import UCXX, UCXXBackend, UCXXConnector, UCXXListener
+    except ImportError:
+        return UCXBackend
+
+    class UCXXPrefixRewrite(UCXX):
+        prefix = "ucx://"
+
+    class UCXXConnectorPrefixRewrite(UCXXConnector):
+        prefix = "ucx://"
+        comm_class = UCXXPrefixRewrite
+
+    class UCXXListenerPrefixRewrite(UCXXListener):
+        prefix = UCXXConnectorPrefixRewrite.prefix
+        comm_class = UCXXConnectorPrefixRewrite.comm_class
+        encrypted = UCXXConnectorPrefixRewrite.encrypted
+
+    class UCXXBackendPrefixRewrite(UCXXBackend):
+        def get_connector(self):
+            return UCXXConnectorPrefixRewrite()
+
+        def get_listener(self, loc, handle_comm, deserialize, **connection_args):
+            return UCXXListenerPrefixRewrite(
+                loc, handle_comm, deserialize, **connection_args
+            )
+
+    return UCXXBackendPrefixRewrite
+
+
+try:
+    # It's necessary to `try`/`except` `import distributed_ucxx` first, then in
+    # the `finally` block `from distributed_ucxx.ucxx import UCXXBackend` should
+    # succeed if distributed-ucxx is installed. This requirement is probably due
+    # to distributed-ucxx registering `backends["ucxx"]` as an entry point.
+    # This entire block is temporary (along with this entire file) until UCXX
+    # becomes the default and only backend and UCX-Py is ultimately archived.
+    import distributed_ucxx  # noqa: F401
+except ImportError:
+    pass
+finally:
+    backends["ucx"] = _rewrite_ucxx_backend()()
+backends["ucx-old"] = UCXBackendOld()
 
 
 def _prepare_ucx_config():
diff --git a/distributed/worker.py b/distributed/worker.py
index 7e3fecb9b2..380d115fe6 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -1604,7 +1604,7 @@ async def close(  # type: ignore
         # before closing self.batched_stream, otherwise the local endpoint
         # may be closed too early and errors be raised on the scheduler when
         # trying to send closing message.
-        if self._protocol == "ucx":  # pragma: no cover
+        if self._protocol.startswith("ucx"):  # pragma: no cover
             await asyncio.sleep(0.2)
 
         self.batched_send({"op": "close-stream"})