Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use pynvjitlink for CUDA 12+ MVC #13650

Merged
merged 18 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions python/cudf/cudf/tests/test_mvc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
import subprocess
import sys

import pytest

IS_CUDA_11 = False
IS_CUDA_12 = False
try:
from ptxcompiler.patch import safe_get_versions
except ModuleNotFoundError:
from cudf.utils._ptxcompiler import safe_get_versions

# do not test cuda 12 if pynvjitlink isn't present
HAVE_PYNVJITLINK = False
try:
import pynvjitlink # noqa: F401

HAVE_PYNVJITLINK = True
except ModuleNotFoundError:
pass


versions = safe_get_versions()
driver_version, runtime_version = versions

if (11, 0) <= driver_version < (12, 0):
IS_CUDA_11 = True
if (12, 0) <= driver_version < (13, 0):
IS_CUDA_12 = True


TEST_BODY = """
@numba.cuda.jit
def test_kernel(x):
id = numba.cuda.grid(1)
if id < len(x):
x[id] += 1

s = cudf.Series([1, 2, 3])
with _CUDFNumbaConfig():
test_kernel.forall(len(s))(s)
"""

CUDA_11_TEST = (
"""
import numba.cuda
import cudf
from cudf.utils._numba import _CUDFNumbaConfig, patch_numba_linker_cuda_11


patch_numba_linker_cuda_11()
"""
+ TEST_BODY
)


CUDA_12_TEST = (
"""
import numba.cuda
import cudf
from cudf.utils._numba import _CUDFNumbaConfig
from pynvjitlink.patch import (
patch_numba_linker as patch_numba_linker_pynvjitlink,
)

patch_numba_linker_pynvjitlink()
"""
+ TEST_BODY
)


@pytest.mark.parametrize(
"test",
[
pytest.param(
CUDA_11_TEST,
marks=pytest.mark.skipif(
not IS_CUDA_11,
reason="Minor Version Compatibility test for CUDA 11",
),
),
pytest.param(
CUDA_12_TEST,
marks=pytest.mark.skipif(
not IS_CUDA_12 or not HAVE_PYNVJITLINK,
reason="Minor Version Compatibility test for CUDA 12",
),
),
],
)
def test_numba_mvc(test):
cp = subprocess.run(
[sys.executable, "-c", test],
capture_output=True,
cwd="/",
)

assert cp.returncode == 0
48 changes: 0 additions & 48 deletions python/cudf/cudf/tests/test_numba_import.py

This file was deleted.

53 changes: 29 additions & 24 deletions python/cudf/cudf/utils/_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,19 @@

from numba import config as numba_config

try:
from pynvjitlink.patch import (
patch_numba_linker as patch_numba_linker_pynvjitlink,
)
except ImportError:

def patch_numba_linker_pynvjitlink():
warnings.warn(
"CUDA Toolkit is newer than CUDA driver. "
"Numba features will not work in this configuration. "
)


CC_60_PTX_FILE = os.path.join(
os.path.dirname(__file__), "../core/udf/shim_60.ptx"
)
Expand Down Expand Up @@ -65,7 +78,7 @@ def _get_ptx_file(path, prefix):
return regular_result[1]


def _patch_numba_mvc():
def patch_numba_linker_cuda_11():
brandon-b-miller marked this conversation as resolved.
Show resolved Hide resolved
# Enable the config option for minor version compatibility
numba_config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY = 1

Expand Down Expand Up @@ -106,29 +119,19 @@ def _setup_numba():
versions = safe_get_versions()
if versions != NO_DRIVER:
driver_version, runtime_version = versions
if driver_version >= (12, 0) and runtime_version > driver_version:
warnings.warn(
f"Using CUDA toolkit version {runtime_version} with CUDA "
f"driver version {driver_version} requires minor version "
"compatibility, which is not yet supported for CUDA "
"driver versions 12.0 and above. It is likely that many "
"cuDF operations will not work in this state. Please "
f"install CUDA toolkit version {driver_version} to "
"continue using cuDF."
)
else:
# Support MVC for all CUDA versions in the 11.x range
ptx_toolkit_version = _get_cuda_version_from_ptx_file(
CC_60_PTX_FILE
)
# Numba thinks cubinlinker is only needed if the driver is older
# than the CUDA runtime, but when PTX files are present, it might
# also need to patch because those PTX files may be compiled by
# a CUDA version that is newer than the driver as well
if (driver_version < ptx_toolkit_version) or (
driver_version < runtime_version
):
_patch_numba_mvc()
ptx_toolkit_version = _get_cuda_version_from_ptx_file(CC_60_PTX_FILE)

# MVC is required whenever any PTX is newer than the driver
# This could be the shipped PTX file or the PTX emitted by
# the version of NVVM on the user system, the latter aligning
# with the runtime version
if (driver_version < ptx_toolkit_version) or (
driver_version < runtime_version
):
if driver_version < (12, 0):
patch_numba_linker_cuda_11()
else:
patch_numba_linker_pynvjitlink()


def _get_cuda_version_from_ptx_file(path):
Expand Down Expand Up @@ -171,6 +174,8 @@ def _get_cuda_version_from_ptx_file(path):
"7.8": (11, 8),
"8.0": (12, 0),
"8.1": (12, 1),
"8.2": (12, 2),
brandon-b-miller marked this conversation as resolved.
Show resolved Hide resolved
"8.3": (12, 3),
}

cuda_ver = ver_map.get(version)
Expand Down