diff --git a/dask_cuda/initialize.py b/dask_cuda/initialize.py
index b5fb81ef..8f159cab 100644
--- a/dask_cuda/initialize.py
+++ b/dask_cuda/initialize.py
@@ -6,6 +6,7 @@
 import numba.cuda
 
 import dask
+import distributed.comm.ucx
 
 from .utils import get_ucx_config, has_cuda_context
 
@@ -14,11 +15,15 @@
 
 def _create_cuda_context():
     try:
+        # Added here to ensure the parent `LocalCUDACluster` process creates the CUDA
+        # context directly from the UCX module, thus avoiding a similar warning there.
+        distributed.comm.ucx.init_once()
+
         cuda_visible_device = int(
             os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
         )
         ctx = has_cuda_context()
-        if ctx is not False:
+        if ctx is not False and distributed.comm.ucx.cuda_context_created is False:
             warnings.warn(
                 f"A CUDA context for device {ctx} already exists on process ID "
                 f"{os.getpid()}. This is often the result of a CUDA-enabled library "
@@ -29,16 +34,18 @@ def _create_cuda_context():
 
         numba.cuda.current_context()
 
-        ctx = has_cuda_context()
-        if ctx is not False and ctx != cuda_visible_device:
-            warnings.warn(
-                f"Worker with process ID {os.getpid()} should have a CUDA context "
-                f"assigned to device {cuda_visible_device}, but instead the CUDA "
-                f"context is on device {ctx}. This is often the result of a "
-                "CUDA-enabled library calling a CUDA runtime function before Dask-CUDA "
-                "can spawn worker processes. Please make sure any such function calls "
-                "don't happen at import time or in the global scope of a program."
-            )
+        if distributed.comm.ucx.cuda_context_created is False:
+            ctx = has_cuda_context()
+            if ctx is not False and ctx != cuda_visible_device:
+                warnings.warn(
+                    f"Worker with process ID {os.getpid()} should have a CUDA context "
+                    f"assigned to device {cuda_visible_device}, but instead the CUDA "
+                    f"context is on device {ctx}. This is often the result of a "
+                    "CUDA-enabled library calling a CUDA runtime function before "
+                    "Dask-CUDA can spawn worker processes. Please make sure any such "
+                    "function calls don't happen at import time or in the global scope "
+                    "of a program."
+                )
 
     except Exception:
         logger.error("Unable to start CUDA Context", exc_info=True)
 
@@ -110,10 +117,6 @@ def initialize(
         it is callable. Can be an integer or ``None`` if ``net_devices`` is not
         callable.
     """
-
-    if create_cuda_context:
-        _create_cuda_context()
-
     ucx_config = get_ucx_config(
         enable_tcp_over_ucx=enable_tcp_over_ucx,
         enable_infiniband=enable_infiniband,
@@ -124,6 +127,9 @@ def initialize(
     )
     dask.config.set({"distributed.comm.ucx": ucx_config})
 
+    if create_cuda_context:
+        _create_cuda_context()
+
 
 @click.command()
 @click.option(