diff --git a/python/raft/dask/common/ucx.py b/python/raft/dask/common/ucx.py index 7e44f3bc43..f61479a0eb 100644 --- a/python/raft/dask/common/ucx.py +++ b/python/raft/dask/common/ucx.py @@ -17,7 +17,7 @@ async def _connection_func(ep): - return 0 + UCX.get().add_server_endpoint(ep) class UCX: @@ -35,6 +35,7 @@ def __init__(self, listener_callback): self._create_listener() self._endpoints = {} + self._server_endpoints = [] assert UCX.__instance is None @@ -60,6 +61,9 @@ async def _create_endpoint(self, ip, port): self._endpoints[(ip, port)] = ep return ep + def add_server_endpoint(self, ep): + self._server_endpoints.append(ep) + async def get_endpoint(self, ip, port): if (ip, port) not in self._endpoints: ep = await self._create_endpoint(ip, port) @@ -72,10 +76,18 @@ async def close_endpoints(self): for k, ep in self._endpoints.items(): await ep.close() + for ep in self._server_endpoints: + ep.close() + def __del__(self): for ip_port, ep in self._endpoints.items(): if not ep.closed(): ep.abort() del ep + for ep in self._server_endpoints: + if not ep.closed(): + ep.abort() + del ep + self._listener.close()