def __del__(self):
    """Release process-wide UCXX progress resources with the last instance.

    Decrements the module-level ``instances`` counter while holding
    ``instances_lock``. When the counter reaches zero, the UCXX notifier
    thread is stopped and the context's progress tasks are cleared.

    Python also runs ``__del__`` for objects whose ``__init__`` raised
    before the counter was incremented, so the counter is never allowed to
    drop below zero; otherwise it could never return to zero again and the
    shutdown branch below would stop firing.
    """
    # Only ``instances`` is rebound here; ``instances_lock`` is just read,
    # so it does not need a ``global`` declaration.
    global instances

    with instances_lock:
        if instances <= 0:
            # No live instance is accounted for (e.g. ``__init__`` failed
            # before incrementing): leave the counter untouched.
            return

        instances -= 1
        if instances == 0:
            # Last instance gone: stop the notifier thread first, then
            # drop the progress tasks registered on the context.
            ucxx.stop_notifier_thread()
            ucxx.core._get_ctx().progress_tasks.clear()