Skip to content

Commit

Permalink
Warn if CUDA context is created on incorrect device
Browse files Browse the repository at this point in the history
  • Loading branch information
pentschev committed Sep 9, 2021
1 parent d79645f commit bd6a63f
Showing 1 changed file with 36 additions and 9 deletions.
45 changes: 36 additions & 9 deletions dask_cuda/initialize.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,48 @@
import logging
import os
import warnings

import click
import numba.cuda

import dask

from .utils import get_ucx_config
from .utils import get_ucx_config, has_cuda_context

logger = logging.getLogger(__name__)


def _create_cuda_context():
    """Create this process's CUDA context and warn if it is on the wrong device.

    The expected device is taken from the first entry of the
    ``CUDA_VISIBLE_DEVICES`` environment variable (defaulting to ``"0"``).
    Two warnings may be emitted via :mod:`warnings`:

    * before the context is created, if ``has_cuda_context()`` reports that a
      context already exists in this process;
    * after the context is created, if the reported device differs from the
      expected one.

    Any exception raised while doing so is logged as an error and swallowed,
    so callers never see a failure from this function.

    NOTE(review): ``has_cuda_context()`` is assumed to return ``False`` when no
    context exists and a device index otherwise -- that is how its result is
    used here; confirm against ``dask_cuda.utils``.
    """
    try:
        first_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
        try:
            cuda_visible_device = int(first_visible)
        except ValueError:
            # The first entry is not an integer index (e.g. a GPU UUID or an
            # empty string). We cannot compare it against the device reported
            # by has_cuda_context(), but that must not prevent the context
            # from being created, so only the mismatch check is skipped.
            cuda_visible_device = None

        ctx = has_cuda_context()
        if ctx is not False:
            warnings.warn(
                f"A CUDA context for device {ctx} already exists on process ID "
                f"{os.getpid()}. This is often the result of a CUDA-enabled library "
                "calling a CUDA runtime function before Dask-CUDA can spawn worker "
                "processes. Please make sure any such function calls don't happen "
                "at import time or in the global scope of a program."
            )

        # Touching the current context forces Numba to create one if needed.
        numba.cuda.current_context()

        ctx = has_cuda_context()
        if (
            cuda_visible_device is not None
            and ctx is not False
            and ctx != cuda_visible_device
        ):
            warnings.warn(
                f"Worker with process ID {os.getpid()} should have a CUDA context "
                f"assigned to device {cuda_visible_device}, but instead the CUDA "
                f"context is on device {ctx}. This is often the result of a "
                "CUDA-enabled library calling a CUDA runtime function before Dask-CUDA "
                "can spawn worker processes. Please make sure any such function calls "
                "don't happen at import time or in the global scope of a program."
            )
    except Exception:
        # Best-effort: context creation failures are reported but not raised.
        logger.error("Unable to start CUDA Context", exc_info=True)


def initialize(
create_cuda_context=True,
enable_tcp_over_ucx=False,
Expand Down Expand Up @@ -79,10 +112,7 @@ def initialize(
"""

if create_cuda_context:
try:
numba.cuda.current_context()
except Exception:
logger.error("Unable to start CUDA Context", exc_info=True)
_create_cuda_context()

ucx_config = get_ucx_config(
enable_tcp_over_ucx=enable_tcp_over_ucx,
Expand Down Expand Up @@ -138,7 +168,4 @@ def dask_setup(
net_devices,
):
if create_cuda_context:
try:
numba.cuda.current_context()
except Exception:
logger.error("Unable to start CUDA Context", exc_info=True)
_create_cuda_context()

0 comments on commit bd6a63f

Please sign in to comment.