Skip to content

Commit

Permalink
Fix #39 by matching current Numba implementation of add_cu() (#40)
Browse files Browse the repository at this point in the history
Numba 0.58 implements `add_cu()` differently to Numba 0.57 and below.
These changes bring the implementation in the patch in line with Numba
0.58 onwards, and makes the minimum supported Numba version 0.58, so
that we don't have to support multiple different implementations. There
is presently no widespread production use (or perhaps no production use
at all) of pynvjitlink with Numba 0.57, so keeping Numba 0.57
compatibility is not considered.

Testing of the patch against the current Numba version is added to the
CI configuration.

The version number is bumped to 0.1.12 to reflect the nature of this
change.

Fixes #39
  • Loading branch information
gmarkall authored Jan 18, 2024
1 parent d62cc1d commit 90bc3f2
Show file tree
Hide file tree
Showing 9 changed files with 84 additions and 15 deletions.
11 changes: 11 additions & 0 deletions .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ jobs:
- test-conda
- build-wheels
- test-wheels
- test-patch
secrets: inherit
uses: rapidsai/shared-workflows/.github/workflows/[email protected]
checks:
Expand Down Expand Up @@ -53,6 +54,16 @@ jobs:
build_type: pull-request
test_script: "ci/test_conda.sh"
matrix_filter: ${{ needs.compute-matrix.outputs.TEST_MATRIX }}
test-patch:
needs:
- build-conda
- compute-matrix
uses: rapidsai/shared-workflows/.github/workflows/[email protected]
with:
build_type: pull-request
test_script: "ci/test_patch.sh"
run_codecov: false
matrix_filter: ${{ needs.compute-matrix.outputs.TEST_MATRIX }}
build-wheels:
needs:
- compute-matrix
Expand Down
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.26.4 FATAL_ERROR)

project(
pynvjitlink
VERSION 0.1.11
VERSION 0.1.12
LANGUAGES CXX CUDA
)

Expand Down
12 changes: 12 additions & 0 deletions ci/run_patched_numba_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/usr/bin/env python
# Copyright (c) 2024, NVIDIA CORPORATION

from pynvjitlink import patch

patch.patch_numba_linker()

if __name__ == "__main__":
from numba.testing._runtests import _main
import sys

sys.exit(0 if _main(sys.argv) else 1)
2 changes: 1 addition & 1 deletion ci/test_conda.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ rapids-mamba-retry create -n test \
cxx-compiler \
cuda-nvcc \
cuda-version=${RAPIDS_CUDA_VERSION%.*} \
"numba>=0.57" \
"numba>=0.58" \
make \
pytest \
python=${RAPIDS_PY_VERSION}
Expand Down
41 changes: 41 additions & 0 deletions ci/test_patch.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/bin/bash
# Copyright (c) 2024, NVIDIA CORPORATION

set -euo pipefail

. /opt/conda/etc/profile.d/conda.sh

rapids-logger "Generate testing dependencies"
# TODO: Replace with rapids-dependency-file-generator
rapids-mamba-retry create -n test \
cuda-nvcc-impl \
cuda-nvrtc \
cuda-version=${RAPIDS_CUDA_VERSION%.*} \
"numba>=0.58" \
python=${RAPIDS_PY_VERSION}

# Temporarily allow unbound variables for conda activation.
set +u
conda activate test
set -u

PYTHON_CHANNEL=$(rapids-download-conda-from-s3 python)

rapids-print-env

rapids-mamba-retry install \
--channel "${PYTHON_CHANNEL}" \
pynvjitlink

rapids-logger "Check GPU usage"
nvidia-smi

EXITCODE=0
trap "EXITCODE=1" ERR
set +e

rapids-logger "Test Numba with patch"
python ci/run_patched_numba_tests.py numba.cuda.tests -v -m

rapids-logger "Test script exiting with value: $EXITCODE"
exit ${EXITCODE}
2 changes: 1 addition & 1 deletion conda/recipes/pynvjitlink/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ requirements:
- scikit-build-core
run:
- python
- numba >=0.57
- numba >=0.58
- {{ pin_compatible('cuda-version', min_pin='x', max_pin='x.x') }}

about:
Expand Down
2 changes: 1 addition & 1 deletion pynvjitlink/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.1.11
0.1.12
25 changes: 15 additions & 10 deletions pynvjitlink/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
_numba_version_ok = False
_numba_error = None

required_numba_ver = (0, 57)
required_numba_ver = (0, 58)

mvc_docs_url = (
"https://numba.readthedocs.io/en/stable/cuda/" "minor_version_compatibility.html"
Expand All @@ -29,12 +29,13 @@

if _numba_version_ok:
from numba.core import config
from numba.cuda.cudadrv.driver import FILE_EXTENSION_MAP, Linker, LinkerError

if ver < (0, 58):
from numba.cuda.cudadrv.driver import NvrtcProgram
else:
from numba.cuda.cudadrv.nvrtc import NvrtcProgram
from numba.cuda.cudadrv import nvrtc
from numba.cuda.cudadrv.driver import (
driver,
FILE_EXTENSION_MAP,
Linker,
LinkerError,
)
else:
# Prevent the definition of PatchedLinker failing if we have no Numba
# Linker - it won't be used anyway.
Expand Down Expand Up @@ -117,16 +118,20 @@ def add_file(self, path, kind):
raise LinkerError from e

def add_cu(self, cu, name):
program = NvrtcProgram(cu, name)
with driver.get_active_context() as ac:
dev = driver.get_device(ac.devnum)
cc = dev.compute_capability

ptx, log = nvrtc.compile(cu, name, cc)

if config.DUMP_ASSEMBLY:
print(("ASSEMBLY %s" % name).center(80, "-"))
print(program.ptx.decode())
print(ptx)
print("=" * 80)

# Link the program's PTX using the normal linker mechanism
ptx_name = os.path.splitext(name)[0] + ".ptx"
self.add_ptx(program.ptx.rstrip(b"\x00"), ptx_name)
self.add_ptx(ptx.encode(), ptx_name)

def complete(self):
try:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ build-backend = "scikit_build_core.build"

[project]
name = "pynvjitlink"
version = "0.1.11"
version = "0.1.12"
description = "nvJitLink Python binding"
readme = { file = "README.md", content-type = "text/markdown" }
authors = [
Expand Down

0 comments on commit 90bc3f2

Please sign in to comment.