Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #39 by matching current Numba implementation of add_cu() #40

Merged
merged 4 commits into from
Jan 18, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ jobs:
- test-conda
- build-wheels
- test-wheels
- test-patch
secrets: inherit
uses: rapidsai/shared-workflows/.github/workflows/[email protected]
checks:
Expand Down Expand Up @@ -53,6 +54,16 @@ jobs:
build_type: pull-request
test_script: "ci/test_conda.sh"
matrix_filter: ${{ needs.compute-matrix.outputs.TEST_MATRIX }}
test-patch:
needs:
- build-conda
- compute-matrix
uses: rapidsai/shared-workflows/.github/workflows/[email protected]
with:
build_type: pull-request
test_script: "ci/test_patch.sh"
run_codecov: false
matrix_filter: ${{ needs.compute-matrix.outputs.TEST_MATRIX }}
build-wheels:
needs:
- compute-matrix
Expand Down
12 changes: 12 additions & 0 deletions ci/run_patched_numba_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/usr/bin/env python
# Copyright (c) 2024, NVIDIA CORPORATION

from pynvjitlink import patch

patch.patch_numba_linker()

if __name__ == "__main__":
from numba.testing._runtests import _main
import sys

sys.exit(0 if _main(sys.argv) else 1)
2 changes: 1 addition & 1 deletion ci/test_conda.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ rapids-mamba-retry create -n test \
cxx-compiler \
cuda-nvcc \
cuda-version=${RAPIDS_CUDA_VERSION%.*} \
"numba>=0.57" \
"numba>=0.58" \
make \
pytest \
python=${RAPIDS_PY_VERSION}
Expand Down
41 changes: 41 additions & 0 deletions ci/test_patch.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/bin/bash
# Copyright (c) 2024, NVIDIA CORPORATION

set -euo pipefail

. /opt/conda/etc/profile.d/conda.sh

rapids-logger "Generate testing dependencies"
# TODO: Replace with rapids-dependency-file-generator
rapids-mamba-retry create -n test \
cuda-nvcc-impl \
cuda-nvrtc \
cuda-version=${RAPIDS_CUDA_VERSION%.*} \
"numba>=0.58" \
python=${RAPIDS_PY_VERSION}

# Temporarily allow unbound variables for conda activation.
set +u
conda activate test
set -u

PYTHON_CHANNEL=$(rapids-download-conda-from-s3 python)

rapids-print-env

rapids-mamba-retry install \
--channel "${PYTHON_CHANNEL}" \
pynvjitlink

rapids-logger "Check GPU usage"
nvidia-smi

EXITCODE=0
trap "EXITCODE=1" ERR
set +e

rapids-logger "Test Numba with patch"
python ci/run_patched_numba_tests.py numba.cuda.tests -v -m

rapids-logger "Test script exiting with value: $EXITCODE"
exit ${EXITCODE}
2 changes: 1 addition & 1 deletion conda/recipes/pynvjitlink/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ requirements:
- scikit-build-core
run:
- python
- numba >=0.57
- numba >=0.58
- {{ pin_compatible('cuda-version', min_pin='x', max_pin='x.x') }}

about:
Expand Down
25 changes: 15 additions & 10 deletions pynvjitlink/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
_numba_version_ok = False
_numba_error = None

required_numba_ver = (0, 57)
required_numba_ver = (0, 58)

mvc_docs_url = (
"https://numba.readthedocs.io/en/stable/cuda/" "minor_version_compatibility.html"
Expand All @@ -29,12 +29,13 @@

if _numba_version_ok:
from numba.core import config
from numba.cuda.cudadrv.driver import FILE_EXTENSION_MAP, Linker, LinkerError

if ver < (0, 58):
from numba.cuda.cudadrv.driver import NvrtcProgram
else:
from numba.cuda.cudadrv.nvrtc import NvrtcProgram
from numba.cuda.cudadrv import nvrtc
from numba.cuda.cudadrv.driver import (
driver,
FILE_EXTENSION_MAP,
Linker,
LinkerError,
)
else:
# Prevent the definition of PatchedLinker failing if we have no Numba
# Linker - it won't be used anyway.
Expand Down Expand Up @@ -117,16 +118,20 @@ def add_file(self, path, kind):
raise LinkerError from e

def add_cu(self, cu, name):
program = NvrtcProgram(cu, name)
with driver.get_active_context() as ac:
dev = driver.get_device(ac.devnum)
cc = dev.compute_capability

ptx, log = nvrtc.compile(cu, name, cc)

if config.DUMP_ASSEMBLY:
print(("ASSEMBLY %s" % name).center(80, "-"))
print(program.ptx.decode())
print(ptx)
print("=" * 80)

# Link the program's PTX using the normal linker mechanism
ptx_name = os.path.splitext(name)[0] + ".ptx"
self.add_ptx(program.ptx.rstrip(b"\x00"), ptx_name)
self.add_ptx(ptx.encode(), ptx_name)

def complete(self):
try:
Expand Down