Skip to content

Commit

Permalink
initial
Browse files Browse the repository at this point in the history
  • Loading branch information
isVoid committed Dec 18, 2024
1 parent 9479123 commit b8c03c4
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 17 deletions.
30 changes: 16 additions & 14 deletions numba_cuda/numba/cuda/cudadrv/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@
from .linkable_code import LinkableCode, LTOIR, Fatbin, Object
from numba.cuda.cudadrv import enums, drvapi, nvrtc

try:
from pynvjitlink.api import NvJitLinker, NvJitLinkError
except ImportError:
NvJitLinker, NvJitLinkError = None, None

USE_NV_BINDING = config.CUDA_USE_NVIDIA_BINDING

if USE_NV_BINDING:
Expand Down Expand Up @@ -92,20 +97,6 @@ def _readenv(name, ctor, default):
if not hasattr(config, "CUDA_ENABLE_PYNVJITLINK"):
config.CUDA_ENABLE_PYNVJITLINK = ENABLE_PYNVJITLINK

if ENABLE_PYNVJITLINK:
try:
from pynvjitlink.api import NvJitLinker, NvJitLinkError
except ImportError:
raise ImportError(
"Using pynvjitlink requires the pynvjitlink package to be available"
)

if config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY:
raise ValueError(
"Can't set CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY and "
"CUDA_ENABLE_PYNVJITLINK at the same time"
)


def make_logger():
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -3061,6 +3052,17 @@ def __init__(
lto=False,
additional_flags=None,
):
if NvJitLinker is None:
raise ImportError(
"Using pynvjitlink requires the pynvjitlink package to be "
"available"
)

if config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY:
raise ValueError(
"Can't set CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY and "
"CUDA_ENABLE_PYNVJITLINK at the same time"
)

if cc is None:
raise RuntimeError("PyNvJitLinker requires CC to be specified")
Expand Down
64 changes: 61 additions & 3 deletions numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@
from numba.cuda.testing import skip_on_cudasim
from numba.cuda.testing import CUDATestCase
from numba.cuda.cudadrv.driver import PyNvJitLinker
from numba.cuda import get_current_device

from numba import cuda
from numba import config
from numba.tests.support import run_in_subprocess, override_config

import itertools
import os
import io
import contextlib
import warnings

from numba.cuda import get_current_device
from numba import cuda
from numba import config

TEST_BIN_DIR = os.getenv("NUMBA_CUDA_TEST_BIN_DIR")
if TEST_BIN_DIR:
Expand Down Expand Up @@ -251,5 +253,61 @@ def kernel():
pass


class TestLinkerUsage(CUDATestCase):
"""Test that whether pynvjitlink can be enabled by both environment variable
and modification of config at runtime.
"""
def test_linker_enabled_envvar(self):
# Linkable code is only supported via pynvjitlink
src = """if 1:
import os
from numba import cuda
TEST_BIN_DIR = os.getenv("NUMBA_CUDA_TEST_BIN_DIR")
if TEST_BIN_DIR:
test_device_functions_cubin = os.path.join(
TEST_BIN_DIR, "test_device_functions.cubin"
)
print(TEST_BIN_DIR)
files = (
test_device_functions_cubin,
)
for lto in [True, False]:
for file in files:
sig = "uint32(uint32, uint32)"
add_from_numba = cuda.declare_device("add_from_numba", sig)
@cuda.jit(link=[file], lto=lto)
def kernel(result):
result[0] = add_from_numba(1, 2)
result = cuda.device_array(1)
kernel[1, 1](result)
assert result[0] == 3
"""
env = os.environ.copy()
env['NUMBA_CUDA_ENABLE_PYNVJITLINK'] = "1"
print(env['NUMBA_CUDA_TEST_BIN_DIR'])
run_in_subprocess(src, env=env)

def test_linker_enabled_config(self):
with override_config("CUDA_ENABLE_PYNVJITLINK", True):
files = (
test_device_functions_cubin,
)
for lto in [True, False]:
for file in files:
sig = "uint32(uint32, uint32)"
add_from_numba = cuda.declare_device("add_from_numba", sig)

@cuda.jit(link=[file], lto=lto)
def kernel(result):
result[0] = add_from_numba(1, 2)

result = cuda.device_array(1)
kernel[1, 1](result)
assert result[0] == 3


if __name__ == "__main__":
unittest.main()

0 comments on commit b8c03c4

Please sign in to comment.