Use pynvjitlink for MVC #23

Closed
Changes from 18 commits
Commits
58 commits
521be20
off the ground
brandon-b-miller Jul 12, 2024
0f9bc4a
cleanup
brandon-b-miller Jul 15, 2024
b25db1f
Merge remote-tracking branch 'upstream/develop' into develop
brandon-b-miller Jul 16, 2024
1c3517f
enough to launch a kernel
brandon-b-miller Jul 29, 2024
cbcbbab
pass through kwargs
brandon-b-miller Jul 29, 2024
4406809
patch_cuda once
brandon-b-miller Jul 29, 2024
dc887b6
refactor
brandon-b-miller Jul 30, 2024
7d17759
Merge remote-tracking branch 'upstream/develop' into develop
brandon-b-miller Aug 4, 2024
b9898ec
merge latest/resolve conflicts
brandon-b-miller Aug 4, 2024
60d4ca7
style and other fixes
brandon-b-miller Aug 4, 2024
db32cfa
Merge remote-tracking branch 'upstream/develop' into develop
brandon-b-miller Aug 5, 2024
bc424ae
merge latest/resolve conflict
brandon-b-miller Aug 5, 2024
56db9c8
cleanup
brandon-b-miller Aug 8, 2024
c57053c
bifurcate error messages
brandon-b-miller Aug 8, 2024
363b86d
partially address reviews
brandon-b-miller Aug 12, 2024
32164e9
move add_file_guess_ext logic to Linker base class
brandon-b-miller Aug 12, 2024
c3b9084
refactor __new__ logic
brandon-b-miller Aug 19, 2024
2c940ee
address reviews
brandon-b-miller Aug 19, 2024
a8c38b6
refactor config logic
brandon-b-miller Aug 21, 2024
421fdfb
continue addressing reviews
brandon-b-miller Aug 22, 2024
16314a7
rename errors
brandon-b-miller Aug 22, 2024
41d85a9
minor cleanup
brandon-b-miller Aug 22, 2024
f7939b6
Apply suggestions from code review
brandon-b-miller Aug 22, 2024
91f06a8
address reviews
brandon-b-miller Aug 22, 2024
0541dcf
bug fixes and map ltoir to CU_JIT_INPUT_NVVM
brandon-b-miller Aug 22, 2024
710f8cb
CU_JIT_INPUT_LTO_IR -> CU_JIT_INPUT_NVVM
brandon-b-miller Sep 4, 2024
a8cb6c2
only use cuda if CUDA_USE_NVIDIA_BINDING
brandon-b-miller Sep 4, 2024
b8b671f
fixes
brandon-b-miller Sep 4, 2024
7c384a3
tests
brandon-b-miller Sep 26, 2024
6327ec2
fix bug
brandon-b-miller Sep 30, 2024
c97767c
add a new ci job for testing with pynvjitlink
brandon-b-miller Sep 30, 2024
aa3aaf7
fixes
brandon-b-miller Sep 30, 2024
f01c0d6
more small fixes
brandon-b-miller Oct 2, 2024
f512bed
Merge remote-tracking branch 'upstream/develop' into develop
brandon-b-miller Oct 3, 2024
15e16a6
merge/resolve
brandon-b-miller Oct 3, 2024
519f0c1
update matrix filter
brandon-b-miller Oct 3, 2024
1201f1f
simple filter
brandon-b-miller Oct 3, 2024
883d817
clean
brandon-b-miller Oct 3, 2024
f41b931
.1
brandon-b-miller Oct 3, 2024
e5aa41e
revert
brandon-b-miller Oct 3, 2024
a0ea97d
refactor
brandon-b-miller Oct 3, 2024
b979054
try and fix conda workflow
brandon-b-miller Oct 3, 2024
beb3301
readenv boolify string values
brandon-b-miller Oct 3, 2024
3f5a865
fix imports
brandon-b-miller Oct 3, 2024
4770c40
update
brandon-b-miller Oct 4, 2024
2820ee6
use local workflow matrix
brandon-b-miller Oct 4, 2024
dc20cce
cu12 suffix
brandon-b-miller Oct 4, 2024
ff18c5c
Update ci/test_conda.sh
brandon-b-miller Oct 7, 2024
b2f4245
small updates
brandon-b-miller Oct 8, 2024
f40d3ed
fix logic :)
brandon-b-miller Oct 8, 2024
c24ec67
try hardcoding ENABLE_PYNVJITLINK
brandon-b-miller Oct 8, 2024
dccd6db
fix passing of ENABLE_PYNVJITLINK
brandon-b-miller Oct 9, 2024
9aaa21f
ship makefile, find and build tests
brandon-b-miller Oct 9, 2024
d3ca53c
minor fix
brandon-b-miller Oct 9, 2024
4ce95a7
bifurcate pynvjitlink test scripts
brandon-b-miller Oct 9, 2024
d65f80d
pass test bin dir as an env var
brandon-b-miller Oct 9, 2024
5097bcf
more minor fixes
brandon-b-miller Oct 9, 2024
e29744c
Retry installation of pynvjitlink
gmarkall Oct 10, 2024
280 changes: 230 additions & 50 deletions numba_cuda/numba/cuda/cudadrv/driver.py
@@ -19,6 +19,7 @@
import warnings
import logging
import threading
import traceback
import asyncio
import pathlib
from itertools import product
@@ -35,6 +36,8 @@
from .error import CudaSupportError, CudaDriverError
from .drvapi import API_PROTOTYPES
from .drvapi import cu_occupancy_b2d_size, cu_stream_callback_pyobj, cu_uuid
from .mappings import FILE_EXTENSION_MAP
from .linkable_code import LinkableCode
from numba.cuda.cudadrv import enums, drvapi, nvrtc

USE_NV_BINDING = config.CUDA_USE_NVIDIA_BINDING
@@ -55,6 +58,43 @@
_py_decref.argtypes = [ctypes.py_object]
_py_incref.argtypes = [ctypes.py_object]

pynvjitlink_import_err = None
try:
from pynvjitlink.api import NvJitLinker, NvJitLinkError
except ImportError as err:
pynvjitlink_import_err = err


def _readenv(name, ctor, default):
value = os.environ.get(name)
if value is None:
return default() if callable(default) else default
try:
return ctor(value)
except Exception:
warnings.warn(
f"Environment variable '{name}' is defined but its associated "
f"value '{value}' could not be parsed.\n"
"The parse failed with exception:\n"
f"{traceback.format_exc()}",
RuntimeWarning
)
return default


if _readenv("ENABLE_PYNVJITLINK", bool, False):
config.ENABLE_PYNVJITLINK = True
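
A note on the `bool` constructor passed to `_readenv` above: `bool()` of any non-empty string is `True`, so a value such as `ENABLE_PYNVJITLINK=0` would still enable the feature. A minimal sketch of a stricter parser follows. The helper name `parse_bool_env` and the accepted spellings are assumptions for illustration, not part of this PR.

```python
def parse_bool_env(value):
    """Hypothetical helper: map common truthy/falsy strings to a bool."""
    normalized = value.strip().lower()
    if normalized in {"1", "true", "yes", "on"}:
        return True
    if normalized in {"0", "false", "no", "off", ""}:
        return False
    # Raising lets a caller such as _readenv warn and fall back to a default.
    raise ValueError(f"cannot interpret {value!r} as a boolean")


# bool() only checks for emptiness, so "0" counts as enabled.
naive = bool("0")
strict = parse_bool_env("0")
```

Passing such a helper as the `ctor` argument would let `_readenv` keep its existing warn-and-default behaviour for unparseable values.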


_MVC_ERROR_MESSAGE_CU11 = (
"Minor version compatibility requires ptxcompiler and cubinlinker packages "
"to be available"
)

_MVC_ERROR_MESSAGE_CU12 = (
"Using pynvjitlink requires the pynvjitlink package to be available"
)


def make_logger():
logger = logging.getLogger(__name__)
@@ -432,7 +472,7 @@ def get_active_context(self):

def get_version(self):
"""
Returns the CUDA Runtime version as a tuple (major, minor).
Returns the CUDA Driver version as a tuple (major, minor).
"""
if USE_NV_BINDING:
version = driver.cuDriverGetVersion()
@@ -2546,38 +2586,58 @@ def launch_kernel(cufunc_handle,
extra)


if USE_NV_BINDING:
jitty = binding.CUjitInputType
FILE_EXTENSION_MAP = {
'o': jitty.CU_JIT_INPUT_OBJECT,
'ptx': jitty.CU_JIT_INPUT_PTX,
'a': jitty.CU_JIT_INPUT_LIBRARY,
'lib': jitty.CU_JIT_INPUT_LIBRARY,
'cubin': jitty.CU_JIT_INPUT_CUBIN,
'fatbin': jitty.CU_JIT_INPUT_FATBINARY,
}
else:
FILE_EXTENSION_MAP = {
'o': enums.CU_JIT_INPUT_OBJECT,
'ptx': enums.CU_JIT_INPUT_PTX,
'a': enums.CU_JIT_INPUT_LIBRARY,
'lib': enums.CU_JIT_INPUT_LIBRARY,
'cubin': enums.CU_JIT_INPUT_CUBIN,
'fatbin': enums.CU_JIT_INPUT_FATBINARY,
}


class Linker(metaclass=ABCMeta):
"""Abstract base class for linkers"""

@classmethod
def new(cls, max_registers=0, lineinfo=False, cc=None):
if config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY:
return MVCLinker(max_registers, lineinfo, cc)
elif USE_NV_BINDING:
return CudaPythonLinker(max_registers, lineinfo, cc)
def new(cls,
max_registers=0,
lineinfo=False,
cc=None,
lto=None,
additional_flags=None
):

driver_ver = driver.get_version()
if (
config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY
and driver_ver >= (12, 0)
):
raise ValueError(
"Use ENABLE_PYNVJITLINK for CUDA >= 12.0 MVC"
)
if config.ENABLE_PYNVJITLINK and driver_ver < (12, 0):
raise ValueError(
"Use CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY "
"for CUDA < 12.0 MVC"
)
if (
config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY
and config.ENABLE_PYNVJITLINK
):
raise ValueError(
"can't set both config.ENABLE_PYNVJITLINK "
"and config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY "
"at the same time"
)

if config.ENABLE_PYNVJITLINK:
linker = PyNvJitLinker

elif config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY:
linker = MVCLinker
else:
if USE_NV_BINDING:
linker = CudaPythonLinker
else:
linker = CtypesLinker

if linker is PyNvJitLinker:
return linker(max_registers, lineinfo, cc, lto, additional_flags)
elif additional_flags or lto:
raise ValueError("LTO and additional flags require PyNvJitLinker")
else:
return CtypesLinker(max_registers, lineinfo, cc)
return linker(max_registers, lineinfo, cc)
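
The selection rules in `Linker.new` can be restated as a pure function, which makes the branches easier to check in isolation. This is a sketch only: `choose_linker_name` and its parameters are hypothetical, plain values stand in for `config` and `driver.get_version()`, and class names are returned as strings. One observation about the checks as written above: when both options are set, one of the two version checks always fires first, so the "can't set both" branch appears unreachable. The sketch moves that check to the front.

```python
def choose_linker_name(driver_ver, mvc, pynvjitlink, use_nv_binding,
                       lto=None, additional_flags=None):
    """Hypothetical restatement of Linker.new's selection logic."""
    # Checked first so that it is reachable (a reordering, not the PR's order).
    if mvc and pynvjitlink:
        raise ValueError("can't enable both MVC and pynvjitlink")
    # Version tuples compare element-wise, e.g. (11, 8) < (12, 0).
    if mvc and driver_ver >= (12, 0):
        raise ValueError("Use ENABLE_PYNVJITLINK for CUDA >= 12.0 MVC")
    if pynvjitlink and driver_ver < (12, 0):
        raise ValueError(
            "Use CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY for CUDA < 12.0 MVC"
        )

    if pynvjitlink:
        return "PyNvJitLinker"
    # Only PyNvJitLinker accepts LTO and extra flags.
    if lto or additional_flags:
        raise ValueError("LTO and additional flags require PyNvJitLinker")
    if mvc:
        return "MVCLinker"
    return "CudaPythonLinker" if use_nv_binding else "CtypesLinker"
```

For example, `choose_linker_name((12, 4), mvc=False, pynvjitlink=True, use_nv_binding=True)` selects `"PyNvJitLinker"`, while the same call with `driver_ver=(11, 8)` raises `ValueError`.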

@abstractmethod
def __init__(self, max_registers, lineinfo, cc):
@@ -2626,19 +2686,38 @@ def add_cu_file(self, path):
cu = f.read()
self.add_cu(cu, os.path.basename(path))

def add_file_guess_ext(self, path):
def add_file_guess_ext(self, path_or_code):
"""Add a file to the link, guessing its type from its extension."""
ext = os.path.splitext(path)[1][1:]
if ext == '':
raise RuntimeError("Don't know how to link file with no extension")
elif ext == 'cu':
self.add_cu_file(path)
else:
kind = FILE_EXTENSION_MAP.get(ext, None)
if kind is None:
raise RuntimeError("Don't know how to link file with extension "
f".{ext}")
self.add_file(path, kind)
if isinstance(path_or_code, str):
ext = pathlib.Path(path_or_code).suffix
Collaborator

Are we missing a bit of logic for handling .ltoir here? Or has it gone somewhere else? The original I'm looking at is https://github.com/rapidsai/pynvjitlink/blob/a2f23b7c3c237f2cdde3093c845e0453572503eb/pynvjitlink/patch.py#L170-L171

Collaborator

I realise now the handling for this is taken care of by having LTOIR in the file extension map. See comments below (I'll have to link after posting the review, because the links don't exist before I post it).

if ext == '':
raise RuntimeError(
"Don't know how to link file with no extension"
)
elif ext == '.cu':
self.add_cu_file(path_or_code)
else:
kind = FILE_EXTENSION_MAP.get(ext, None)
if kind is None:
raise RuntimeError(
"Don't know how to link file with extension "
f".{ext}"
)
self.add_file(path_or_code, kind)
return
else:
# Otherwise, we should have been given a LinkableCode object
if not isinstance(path_or_code, LinkableCode):
raise TypeError(
"Expected path to file or a LinkableCode object"
)

if path_or_code.kind == "cu":
self.add_cu(path_or_code.data, path_or_code.name)
else:
self.add_data(
path_or_code.data, path_or_code.kind, path_or_code.name
)
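
One detail worth checking in the new `add_file_guess_ext`: `pathlib.Path(...).suffix` keeps the leading dot, whereas the removed `os.path.splitext(path)[1][1:]` dropped it. If `FILE_EXTENSION_MAP` in `.mappings` is keyed without the dot, as the removed in-file map was (an assumption, since `mappings.py` is not shown in this diff), the lookup with `ext` would miss and the error message `f".{ext}"` would render a doubled dot. The sketch below shows the difference; `normalized_ext` is a hypothetical helper.

```python
import os
import pathlib

path = "kernels/add.ptx"

# Removed code: slice off the leading dot.
old_ext = os.path.splitext(path)[1][1:]

# New code: the suffix includes the leading dot.
new_ext = pathlib.Path(path).suffix


def normalized_ext(path):
    """Hypothetical helper: extension without the dot, '' if there is none."""
    return pathlib.Path(path).suffix.lstrip(".")
```

With a helper like this, both the `'cu'` comparison and a dot-free map lookup would use the same form of the extension.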

@abstractmethod
def complete(self):
@@ -2649,12 +2728,6 @@ def complete(self):
"""


_MVC_ERROR_MESSAGE = (
"Minor version compatibility requires ptxcompiler and cubinlinker packages "
"to be available"
)


class MVCLinker(Linker):
"""
Linker supporting Minor Version Compatibility, backed by the cubinlinker
@@ -2664,7 +2737,7 @@ def __init__(self, max_registers=None, lineinfo=False, cc=None):
try:
from cubinlinker import CubinLinker
except ImportError as err:
raise ImportError(_MVC_ERROR_MESSAGE) from err
raise ImportError(_MVC_ERROR_MESSAGE_CU11) from err

if cc is None:
raise RuntimeError("MVCLinker requires Compute Capability to be "
@@ -2696,7 +2769,7 @@ def add_ptx(self, ptx, name='<cudapy-ptx>'):
from ptxcompiler import compile_ptx
from cubinlinker import CubinLinkerError
except ImportError as err:
raise ImportError(_MVC_ERROR_MESSAGE) from err
raise ImportError(_MVC_ERROR_MESSAGE_CU11) from err
compile_result = compile_ptx(ptx.decode(), self.ptx_compile_options)
try:
self._linker.add_cubin(compile_result.compiled_program, name)
@@ -2707,7 +2780,7 @@ def add_file(self, path, kind):
try:
from cubinlinker import CubinLinkerError
except ImportError as err:
raise ImportError(_MVC_ERROR_MESSAGE) from err
raise ImportError(_MVC_ERROR_MESSAGE_CU11) from err

try:
with open(path, 'rb') as f:
@@ -2736,7 +2809,7 @@ def complete(self):
try:
from cubinlinker import CubinLinkerError
except ImportError as err:
raise ImportError(_MVC_ERROR_MESSAGE) from err
raise ImportError(_MVC_ERROR_MESSAGE_CU11) from err

try:
return self._linker.complete()
@@ -2930,6 +3003,113 @@ def complete(self):
return bytes(np.ctypeslib.as_array(cubin_ptr, shape=(size,)))


class PyNvJitLinker(Linker):
def __init__(
self,
max_registers=None,
lineinfo=False,
cc=None,
lto=False,
additional_flags=None,
):
if pynvjitlink_import_err is not None:
raise ImportError(_MVC_ERROR_MESSAGE_CU12)
if cc is None:
raise RuntimeError("PyNvJitLinker requires CC to be specified")
if not any(isinstance(cc, t) for t in [list, tuple]):
raise TypeError("`cc` must be a list or tuple of length 2")

sm_ver = f"{cc[0] * 10 + cc[1]}"
arch = f"-arch=sm_{sm_ver}"
options = [arch]
if max_registers:
options.append(f"-maxrregcount={max_registers}")
if lineinfo:
options.append("-lineinfo")
if lto:
options.append("-lto")
if additional_flags is not None:
options.extend(additional_flags)

self._linker = NvJitLinker(*options)
self.lto = lto
self.options = options
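
The option list built in `PyNvJitLinker.__init__` can be exercised on its own, without pynvjitlink installed. The following is a sketch under stated assumptions: `build_nvjitlink_options` is a hypothetical free function mirroring the constructor's logic, and the explicit length check on `cc` is an addition that the code above does not perform.

```python
def build_nvjitlink_options(cc, max_registers=None, lineinfo=False,
                            lto=False, additional_flags=None):
    """Hypothetical mirror of the option building in PyNvJitLinker.__init__."""
    # isinstance accepts a tuple of types; the length check is an addition.
    if not isinstance(cc, (list, tuple)) or len(cc) != 2:
        raise TypeError("`cc` must be a list or tuple of length 2")

    major, minor = cc
    options = [f"-arch=sm_{major * 10 + minor}"]
    if max_registers:
        options.append(f"-maxrregcount={max_registers}")
    if lineinfo:
        options.append("-lineinfo")
    if lto:
        options.append("-lto")
    if additional_flags is not None:
        options.extend(additional_flags)
    return options
```

For compute capability `(8, 0)` with no other arguments this yields `["-arch=sm_80"]`.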

@property
def info_log(self):
return self._linker.info_log

@property
def error_log(self):
return self._linker.error_log

def add_ptx(self, ptx, name="<cudapy-ptx>"):
self._linker.add_ptx(ptx, name)

def add_fatbin(self, fatbin, name="<external-fatbin>"):
self._linker.add_fatbin(fatbin, name)

def add_ltoir(self, ltoir, name="<external-ltoir>"):
self._linker.add_ltoir(ltoir, name)

def add_object(self, obj, name="<external-object>"):
self._linker.add_object(obj, name)

def add_file(self, path, kind):
try:
with open(path, "rb") as f:
data = f.read()
except FileNotFoundError:
raise LinkerError(f"{path} not found")

name = pathlib.Path(path).name
self.add_data(data, kind, name)

def add_data(self, data, kind, name):
if kind == FILE_EXTENSION_MAP["cubin"]:
fn = self._linker.add_cubin
elif kind == FILE_EXTENSION_MAP["fatbin"]:
fn = self._linker.add_fatbin
elif kind == FILE_EXTENSION_MAP["a"]:
fn = self._linker.add_library
elif kind == FILE_EXTENSION_MAP["ptx"]:
return self.add_ptx(data, name)
elif kind == FILE_EXTENSION_MAP["o"]:
fn = self._linker.add_object
elif kind == "ltoir":
fn = self._linker.add_ltoir
else:
raise LinkerError(f"Don't know how to link {kind}")

try:
fn(data, name)
except NvJitLinkError as e:
raise LinkerError from e
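
The `if`/`elif` chain in `add_data` maps an input kind to a linker method. One alternative shape is a lookup table, sketched below. Everything here is illustrative: `RecordingLinker` is a hypothetical stand-in for `NvJitLinker`, string kinds stand in for the `FILE_EXTENSION_MAP` values, and `ValueError` replaces `LinkerError` so the block is self-contained.

```python
class RecordingLinker:
    """Hypothetical stand-in for NvJitLinker that records which method ran."""

    def __init__(self):
        self.calls = []

    def add_cubin(self, data, name):
        self.calls.append(("add_cubin", name))

    def add_fatbin(self, data, name):
        self.calls.append(("add_fatbin", name))

    def add_ltoir(self, data, name):
        self.calls.append(("add_ltoir", name))


def add_data(linker, data, kind, name):
    """Dispatch on kind through a table instead of an if/elif chain."""
    handlers = {
        "cubin": linker.add_cubin,
        "fatbin": linker.add_fatbin,
        "ltoir": linker.add_ltoir,
    }
    handler = handlers.get(kind)
    if handler is None:
        raise ValueError(f"Don't know how to link {kind}")
    handler(data, name)
```

A table keeps every kind in one place, which may make it easier to see that `"ltoir"` is compared as a string while the other kinds are compared against map values.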

def add_cu(self, cu, name):
with driver.get_active_context() as ac:
dev = driver.get_device(ac.devnum)
cc = dev.compute_capability

ptx, log = nvrtc.compile(cu, name, cc)

if config.DUMP_ASSEMBLY:
print(("ASSEMBLY %s" % name).center(80, "-"))
print(ptx)
print("=" * 80)

# Link the program's PTX using the normal linker mechanism
ptx_name = os.path.splitext(name)[0] + ".ptx"
self.add_ptx(ptx.encode(), ptx_name)

def complete(self):
try:
cubin = self._linker.get_linked_cubin()
self._linker._complete = True
Collaborator

I just noticed, whilst looking at #48, that I think this is superfluous, because pynvjitlink sets it itself when you get the linked cubin: https://github.com/rapidsai/pynvjitlink/blob/main/pynvjitlink/api.py#L85

return cubin
except NvJitLinkError as e:
raise LinkerError from e

# -----------------------------------------------------------------------------


5 changes: 4 additions & 1 deletion numba_cuda/numba/cuda/cudadrv/enums.py
@@ -309,7 +309,10 @@
# Applicable options: PTX compiler options, ::CU_JIT_FALLBACK_STRATEGY
CU_JIT_INPUT_LIBRARY = 4

CU_JIT_NUM_INPUT_TYPES = 6
# LTO IR
CU_JIT_INPUT_LTO_IR = 5

CU_JIT_NUM_INPUT_TYPES = 7


# Online compiler and linker options