diff --git a/numba_cuda/numba/cuda/cudadrv/driver.py b/numba_cuda/numba/cuda/cudadrv/driver.py index 465d33d..cd5bd85 100644 --- a/numba_cuda/numba/cuda/cudadrv/driver.py +++ b/numba_cuda/numba/cuda/cudadrv/driver.py @@ -42,6 +42,11 @@ from .linkable_code import LinkableCode, LTOIR, Fatbin, Object from numba.cuda.cudadrv import enums, drvapi, nvrtc +try: + from pynvjitlink.api import NvJitLinker, NvJitLinkError +except ImportError: + NvJitLinker, NvJitLinkError = None, None + USE_NV_BINDING = config.CUDA_USE_NVIDIA_BINDING if USE_NV_BINDING: @@ -92,20 +97,6 @@ def _readenv(name, ctor, default): if not hasattr(config, "CUDA_ENABLE_PYNVJITLINK"): config.CUDA_ENABLE_PYNVJITLINK = ENABLE_PYNVJITLINK -if ENABLE_PYNVJITLINK: - try: - from pynvjitlink.api import NvJitLinker, NvJitLinkError - except ImportError: - raise ImportError( - "Using pynvjitlink requires the pynvjitlink package to be available" - ) - - if config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY: - raise ValueError( - "Can't set CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY and " - "CUDA_ENABLE_PYNVJITLINK at the same time" - ) - def make_logger(): logger = logging.getLogger(__name__) @@ -3061,6 +3052,17 @@ def __init__( lto=False, additional_flags=None, ): + if NvJitLinker is None: + raise ImportError( + "Using pynvjitlink requires the pynvjitlink package to be " + "available" + ) + + if config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY: + raise ValueError( + "Can't set CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY and " + "CUDA_ENABLE_PYNVJITLINK at the same time" + ) if cc is None: raise RuntimeError("PyNvJitLinker requires CC to be specified") diff --git a/numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py b/numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py index 044895c..dd6ba61 100644 --- a/numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +++ b/numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py @@ -2,6 +2,11 @@ from numba.cuda.testing import skip_on_cudasim from numba.cuda.testing import CUDATestCase from numba.cuda.cudadrv.driver import PyNvJitLinker +from numba.cuda import get_current_device + +from numba import cuda +from numba import config +from numba.tests.support import run_in_subprocess, override_config import itertools import os @@ -9,9 +14,6 @@ import contextlib import warnings -from numba.cuda import get_current_device -from numba import cuda -from numba import config TEST_BIN_DIR = os.getenv("NUMBA_CUDA_TEST_BIN_DIR") if TEST_BIN_DIR: @@ -251,5 +253,61 @@ def kernel(): pass +class TestLinkerUsage(CUDATestCase): + """Test that whether pynvjitlink can be enabled by both environment variable + and modification of config at runtime. + """ + def test_linker_enabled_envvar(self): + # Linkable code is only supported via pynvjitlink + src = """if 1: + import os + from numba import cuda + + TEST_BIN_DIR = os.getenv("NUMBA_CUDA_TEST_BIN_DIR") + if TEST_BIN_DIR: + test_device_functions_cubin = os.path.join( + TEST_BIN_DIR, "test_device_functions.cubin" + ) + print(TEST_BIN_DIR) + files = ( + test_device_functions_cubin, + ) + for lto in [True, False]: + for file in files: + sig = "uint32(uint32, uint32)" + add_from_numba = cuda.declare_device("add_from_numba", sig) + + @cuda.jit(link=[file], lto=lto) + def kernel(result): + result[0] = add_from_numba(1, 2) + + result = cuda.device_array(1) + kernel[1, 1](result) + assert result[0] == 3 + """ + env = os.environ.copy() + env['NUMBA_CUDA_ENABLE_PYNVJITLINK'] = "1" + print(env['NUMBA_CUDA_TEST_BIN_DIR']) + run_in_subprocess(src, env=env) + + def test_linker_enabled_config(self): + with override_config("CUDA_ENABLE_PYNVJITLINK", True): + files = ( + test_device_functions_cubin, + ) + for lto in [True, False]: + for file in files: + sig = "uint32(uint32, uint32)" + add_from_numba = cuda.declare_device("add_from_numba", sig) + + @cuda.jit(link=[file], lto=lto) + def kernel(result): + result[0] = add_from_numba(1, 2) + + result = cuda.device_array(1) + kernel[1, 1](result) + assert result[0] == 3 + + if __name__ == "__main__": unittest.main()