Skip to content

Commit

Permalink
Enable patching linker for LTO, and test LTO-IR linkable code (#73)
Browse files Browse the repository at this point in the history
#72 added support for linking LTO-IR files from memory instead of disk.
However, it could not be tested because:

- We had no way to generate an LTO-IR container for the test input -
NVCC doesn't support emitting an LTO-IR container, only an LTO-IR
wrapped in a fatbin wrapped in a host object.
- The Numba patch needs to patch Numba for LTO for linking LTO-IR to
work.

This PR implements both of the above, with these changes:

- Addition of a new test binary generator, `generate_raw_ltoir.py`. This
uses NVRTC to generate an LTO-IR container. It uses the cuda-python
bindings, so `cuda-python` is added to the test environments.
- The fixtures had ended up with some duplication and inconsistency
between `conftest.py` and other test files, so these are deduplicated,
and now kept in `conftest.py`. Appropriate updates to all tests using
the fixtures are made.
- The `lto` kwarg is added to the `patch_numba_linker()` function. It is
disabled by default, as enabling LTO will not presently work in all use
cases.
- Added a test for linking LTO-IR from memory.
  • Loading branch information
gmarkall authored Apr 25, 2024
1 parent edb96ef commit 687d81b
Show file tree
Hide file tree
Showing 10 changed files with 278 additions and 136 deletions.
1 change: 1 addition & 0 deletions ci/test_conda.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ rapids-mamba-retry create -n test \
cxx-compiler \
cuda-nvcc \
cuda-nvrtc \
cuda-python \
cuda-version=${RAPIDS_CUDA_VERSION%.*} \
"numba>=0.58" \
make \
Expand Down
1 change: 1 addition & 0 deletions ci/test_patch.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ rapids-logger "Install testing dependencies"
rapids-mamba-retry create -n test \
cuda-nvcc \
cuda-nvrtc \
cuda-python \
cuda-version=${RAPIDS_CUDA_VERSION%.*} \
"numba>=0.58" \
psutil \
Expand Down
1 change: 1 addition & 0 deletions ci/test_wheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ rapids-logger "Install testing dependencies"
python -m pip install \
"numba>=0.58" \
psutil \
cuda-python \
pytest

rapids-logger "Download Wheel"
Expand Down
6 changes: 4 additions & 2 deletions pynvjitlink/patch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
from functools import partial
from pynvjitlink.api import NvJitLinker, NvJitLinkError

import os
Expand Down Expand Up @@ -132,6 +133,7 @@ def __init__(
options.extend(additional_flags)

self._linker = NvJitLinker(*options)
self.lto = lto
self.options = options

@property
Expand Down Expand Up @@ -250,14 +252,14 @@ def new_patched_linker(
)


def patch_numba_linker():
def patch_numba_linker(*, lto=False):
if not _numba_version_ok:
msg = f"Cannot patch Numba: {_numba_error}"
raise RuntimeError(msg)

# Replace the built-in linker that uses the Driver API with our linker that
# uses nvJitLink
Linker.new = new_patched_linker
Linker.new = partial(new_patched_linker, lto=lto)

# Add linkable code objects to Numba's top-level API
cuda.Archive = Archive
Expand Down
83 changes: 39 additions & 44 deletions pynvjitlink/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,100 +57,95 @@ def absent_gpu_arch_flag(absent_gpu_compute_capability):
return f"-arch=sm_{major}{minor}"


@pytest.fixture(scope="session")
def device_functions_archive():
def read_test_file(filename):
test_dir = os.path.dirname(os.path.abspath(__file__))
path = os.path.join(test_dir, "test_device_functions.a")
path = os.path.join(test_dir, filename)
with open(path, "rb") as f:
return f.read()
return filename, f.read()


@pytest.fixture(scope="session")
def device_functions_cubin():
test_dir = os.path.dirname(os.path.abspath(__file__))
path = os.path.join(test_dir, "test_device_functions.cubin")
with open(path, "rb") as f:
return f.read()
def device_functions_cusource():
return read_test_file("test_device_functions.cu")


@pytest.fixture(scope="session")
def device_functions_cusource():
test_dir = os.path.dirname(os.path.abspath(__file__))
path = os.path.join(test_dir, "test_device_functions.cu")
with open(path, "r") as f:
return f.read()
def device_functions_cubin():
return read_test_file("test_device_functions.cubin")


@pytest.fixture(scope="session")
def device_functions_fatbin():
test_dir = os.path.dirname(os.path.abspath(__file__))
path = os.path.join(test_dir, "test_device_functions.fatbin")
with open(path, "rb") as f:
return f.read()
return read_test_file("test_device_functions.fatbin")


@pytest.fixture(scope="session")
def device_functions_ltoir():
return read_test_file("test_device_functions.ltoir")


@pytest.fixture(scope="session")
def device_functions_ltoir_object():
return read_test_file("test_device_functions.ltoir.o")


@pytest.fixture(scope="session")
def device_functions_object():
test_dir = os.path.dirname(os.path.abspath(__file__))
path = os.path.join(test_dir, "test_device_functions.o")
with open(path, "rb") as f:
return f.read()
return read_test_file("test_device_functions.o")


@pytest.fixture(scope="session")
def device_functions_ptx():
test_dir = os.path.dirname(os.path.abspath(__file__))
path = os.path.join(test_dir, "test_device_functions.ptx")
with open(path, "rb") as f:
return f.read()
def device_functions_archive():
return read_test_file("test_device_functions.a")


@pytest.fixture(scope="session")
def undefined_extern_cubin():
test_dir = os.path.dirname(os.path.abspath(__file__))
fatbin_path = os.path.join(test_dir, "undefined_extern.cubin")
with open(fatbin_path, "rb") as f:
return f.read()
def device_functions_ptx():
return read_test_file("test_device_functions.ptx")


@pytest.fixture(scope="session")
def device_functions_ltoir():
test_dir = os.path.dirname(os.path.abspath(__file__))
path = os.path.join(test_dir, "test_device_functions.ltoir")
with open(path, "rb") as f:
return f.read()
def undefined_extern_cubin():
return read_test_file("undefined_extern.cubin")


@pytest.fixture(scope="session")
def linkable_code_archive(device_functions_archive):
return Archive(device_functions_archive)
name, data = device_functions_archive
return Archive(data, name=name)


@pytest.fixture(scope="session")
def linkable_code_cubin(device_functions_cubin):
return Cubin(device_functions_cubin)
name, data = device_functions_cubin
return Cubin(data, name=name)


@pytest.fixture(scope="session")
def linkable_code_cusource(device_functions_cusource):
return CUSource(device_functions_cusource)
name, data = device_functions_cusource
return CUSource(data, name=name)


@pytest.fixture(scope="session")
def linkable_code_fatbin(device_functions_fatbin):
return Fatbin(device_functions_fatbin)
name, data = device_functions_fatbin
return Fatbin(data, name=name)


@pytest.fixture(scope="session")
def linkable_code_object(device_functions_object):
return Object(device_functions_object)
name, data = device_functions_object
return Object(data, name=name)


@pytest.fixture(scope="session")
def linkable_code_ptx(device_functions_ptx):
return PTXSource(device_functions_ptx)
name, data = device_functions_ptx
return PTXSource(data, name=name)


@pytest.fixture(scope="session")
def linkable_code_ltoir(device_functions_ltoir):
return LTOIR(device_functions_ltoir)
name, data = device_functions_ltoir
return LTOIR(data, name=name)
35 changes: 28 additions & 7 deletions pynvjitlink/tests/test_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_numba_patching():
from numba.cuda.cudadrv.driver import Linker

patch_numba_linker()
assert Linker.new is new_patched_linker
assert Linker.new.func is new_patched_linker


def test_create():
Expand Down Expand Up @@ -133,12 +133,6 @@ def test_add_file_guess_ext_invalid_input(
"linkable_code_fatbin",
"linkable_code_object",
"linkable_code_ptx",
pytest.param(
"linkable_code_ltoir",
marks=pytest.mark.xfail(
reason=".ltoir file is actually an object and lto=True missing"
),
),
),
)
def test_jit_with_linkable_code(file, request):
Expand All @@ -157,6 +151,33 @@ def kernel(result):
assert result[0] == 3


@pytest.fixture
def numba_linking_with_lto():
"""
Patch the linker for LTO for the duration of the test.
Afterwards, restore the linker to whatever it was before.
"""
from numba.cuda.cudadrv.driver import Linker

old_new = Linker.new
patch_numba_linker(lto=True)
yield
Linker.new = old_new


def test_jit_with_linkable_code_lto(linkable_code_ltoir, numba_linking_with_lto):
sig = "uint32(uint32, uint32)"
add_from_numba = cuda.declare_device("add_from_numba", sig)

@cuda.jit(link=[linkable_code_ltoir])
def kernel(result):
result[0] = add_from_numba(1, 2)

result = cuda.device_array(1)
kernel[1, 1](result)
assert result[0] == 3


@pytest.mark.skipif(
not _numba_version_ok,
reason=f"Requires Numba == {required_numba_ver[0]}.{required_numba_ver[1]}",
Expand Down
73 changes: 16 additions & 57 deletions pynvjitlink/tests/test_pynvjitlink.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,12 @@
# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.

import os
import pytest

import pynvjitlink
from pynvjitlink import _nvjitlinklib
from pynvjitlink.api import InputType


def read_test_file(filename):
test_dir = os.path.dirname(os.path.abspath(__file__))
path = os.path.join(test_dir, filename)
with open(path, "rb") as f:
return filename, f.read()


@pytest.fixture
def device_functions_cubin():
return read_test_file("test_device_functions.cubin")


@pytest.fixture
def device_functions_fatbin():
return read_test_file("test_device_functions.fatbin")


@pytest.fixture
def device_functions_ltoir():
return read_test_file("test_device_functions.ltoir")


@pytest.fixture
def device_functions_object():
return read_test_file("test_device_functions.o")


@pytest.fixture
def device_functions_archive():
return read_test_file("test_device_functions.a")


@pytest.fixture
def device_functions_ptx():
return read_test_file("test_device_functions.ptx")


@pytest.fixture
def undefined_extern_cubin():
return read_test_file("undefined_extern.cubin")


def test_create_no_arch_error():
# nvjitlink expects at least the architecture to be specified.
with pytest.raises(RuntimeError, match="NVJITLINK_ERROR_MISSING_ARCH error"):
Expand Down Expand Up @@ -105,8 +62,8 @@ def test_add_file(input_file, input_type, gpu_arch_flag, request):
# We test the LTO input case separately as it requires the `-lto` flag. The
# OBJECT input type is used because the LTO-IR container is packaged in an ELF
# object when produced by NVCC.
def test_add_file_lto(device_functions_ltoir, gpu_arch_flag):
filename, data = device_functions_ltoir
def test_add_file_lto(device_functions_ltoir_object, gpu_arch_flag):
filename, data = device_functions_ltoir_object

handle = _nvjitlinklib.create(gpu_arch_flag, "-lto")
_nvjitlinklib.add_data(handle, InputType.OBJECT.value, data, filename)
Expand Down Expand Up @@ -165,11 +122,11 @@ def test_get_linked_cubin_link_not_complete_error(
_nvjitlinklib.destroy(handle)


def test_get_linked_cubin_from_lto(device_functions_ltoir, gpu_arch_flag):
filename, data = device_functions_ltoir
# device_functions_ltoir is a host object containing a fatbin containing an
# LTOIR container, because that is what NVCC produces when LTO is
# requested. So we need to use the OBJECT input type, and the linker
def test_get_linked_cubin_from_lto(device_functions_ltoir_object, gpu_arch_flag):
filename, data = device_functions_ltoir_object
# device_functions_ltoir_object is a host object containing a fatbin
# containing an LTOIR container, because that is what NVCC produces when
# LTO is requested. So we need to use the OBJECT input type, and the linker
# retrieves the LTO IR from it because we passed the -lto flag.
input_type = InputType.OBJECT.value
handle = _nvjitlinklib.create(gpu_arch_flag, "-lto")
Expand All @@ -182,11 +139,11 @@ def test_get_linked_cubin_from_lto(device_functions_ltoir, gpu_arch_flag):
assert cubin[:4] == b"\x7fELF"


def test_get_linked_ptx_from_lto(device_functions_ltoir, gpu_arch_flag):
filename, data = device_functions_ltoir
# device_functions_ltoir is a host object containing a fatbin containing an
# LTOIR container, because that is what NVCC produces when LTO is
# requested. So we need to use the OBJECT input type, and the linker
def test_get_linked_ptx_from_lto(device_functions_ltoir_object, gpu_arch_flag):
filename, data = device_functions_ltoir_object
# device_functions_ltoir_object is a host object containing a fatbin
# containing an LTOIR container, because that is what NVCC produces when
# LTO is requested. So we need to use the OBJECT input type, and the linker
# retrieves the LTO IR from it because we passed the -lto flag.
input_type = InputType.OBJECT.value
handle = _nvjitlinklib.create(gpu_arch_flag, "-lto", "-ptx")
Expand All @@ -196,9 +153,11 @@ def test_get_linked_ptx_from_lto(device_functions_ltoir, gpu_arch_flag):
_nvjitlinklib.destroy(handle)


def test_get_linked_ptx_link_not_complete_error(device_functions_ltoir, gpu_arch_flag):
def test_get_linked_ptx_link_not_complete_error(
device_functions_ltoir_object, gpu_arch_flag
):
handle = _nvjitlinklib.create(gpu_arch_flag, "-lto", "-ptx")
filename, data = device_functions_ltoir
filename, data = device_functions_ltoir_object
input_type = InputType.OBJECT.value
_nvjitlinklib.add_data(handle, input_type, data, filename)
with pytest.raises(RuntimeError, match="NVJITLINK_ERROR_INTERNAL error"):
Expand Down
Loading

0 comments on commit 687d81b

Please sign in to comment.