Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check if CUDA context was created in distributed.comm.ucx #722

Merged
Merged
36 changes: 21 additions & 15 deletions dask_cuda/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numba.cuda

import dask
import distributed.comm.ucx

from .utils import get_ucx_config, has_cuda_context

Expand All @@ -14,11 +15,15 @@

def _create_cuda_context():
try:
# Added here to ensure the parent `LocalCUDACluster` process creates the CUDA
# context directly from the UCX module, thus avoiding a similar warning there.
distributed.comm.ucx.init_once()

cuda_visible_device = int(
os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
)
ctx = has_cuda_context()
if ctx is not False:
if ctx is not False and distributed.comm.ucx.cuda_context_created is False:
warnings.warn(
f"A CUDA context for device {ctx} already exists on process ID "
f"{os.getpid()}. This is often the result of a CUDA-enabled library "
Expand All @@ -29,16 +34,18 @@ def _create_cuda_context():

numba.cuda.current_context()

ctx = has_cuda_context()
if ctx is not False and ctx != cuda_visible_device:
warnings.warn(
f"Worker with process ID {os.getpid()} should have a CUDA context "
f"assigned to device {cuda_visible_device}, but instead the CUDA "
f"context is on device {ctx}. This is often the result of a "
"CUDA-enabled library calling a CUDA runtime function before Dask-CUDA "
"can spawn worker processes. Please make sure any such function calls "
"don't happen at import time or in the global scope of a program."
)
if distributed.comm.ucx.cuda_context_created is False:
ctx = has_cuda_context()
if ctx is not False and ctx != cuda_visible_device:
warnings.warn(
f"Worker with process ID {os.getpid()} should have a CUDA context "
f"assigned to device {cuda_visible_device}, but instead the CUDA "
f"context is on device {ctx}. This is often the result of a "
"CUDA-enabled library calling a CUDA runtime function before "
"Dask-CUDA can spawn worker processes. Please make sure any such "
"function calls don't happen at import time or in the global scope "
"of a program."
)
except Exception:
logger.error("Unable to start CUDA Context", exc_info=True)

Expand Down Expand Up @@ -110,10 +117,6 @@ def initialize(
it is callable. Can be an integer or ``None`` if ``net_devices`` is not
callable.
"""

if create_cuda_context:
_create_cuda_context()

ucx_config = get_ucx_config(
enable_tcp_over_ucx=enable_tcp_over_ucx,
enable_infiniband=enable_infiniband,
Expand All @@ -124,6 +127,9 @@ def initialize(
)
dask.config.set({"distributed.comm.ucx": ucx_config})

if create_cuda_context:
_create_cuda_context()


@click.command()
@click.option(
Expand Down