Skip to content

Commit

Permalink
Test Numba iwth pynvjitlink patch in CI
Browse files Browse the repository at this point in the history
  • Loading branch information
gmarkall committed Jan 18, 2024
1 parent 59179fb commit f07467d
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 2 deletions.
11 changes: 11 additions & 0 deletions .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ jobs:
- test-conda
- build-wheels
- test-wheels
- test-patch
secrets: inherit
uses: rapidsai/shared-workflows/.github/workflows/[email protected]
checks:
Expand Down Expand Up @@ -53,6 +54,16 @@ jobs:
build_type: pull-request
test_script: "ci/test_conda.sh"
matrix_filter: ${{ needs.compute-matrix.outputs.TEST_MATRIX }}
test-patch:
needs:
- build-conda
- compute-matrix
uses: rapidsai/shared-workflows/.github/workflows/[email protected]
with:
build_type: pull-request
test_script: "ci/test_patch.sh"
run_codecov: false
matrix_filter: ${{ needs.compute-matrix.outputs.TEST_MATRIX }}
build-wheels:
needs:
- compute-matrix
Expand Down
12 changes: 12 additions & 0 deletions ci/run_patched_numba_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#!/usr/bin/env python
# Copyright (c) 2024, NVIDIA CORPORATION

from pynvjitlink import patch

patch.patch_numba_linker()

if __name__ == "__main__":
from numba.testing._runtests import _main
import sys

sys.exit(0 if _main(sys.argv) else 1)
41 changes: 41 additions & 0 deletions ci/test_patch.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/bin/bash
# Copyright (c) 2024, NVIDIA CORPORATION

set -euo pipefail

. /opt/conda/etc/profile.d/conda.sh

rapids-logger "Generate testing dependencies"
# TODO: Replace with rapids-dependency-file-generator
rapids-mamba-retry create -n test \
cuda-nvcc-impl \
cuda-nvrtc \
cuda-version=${RAPIDS_CUDA_VERSION%.*} \
"numba>=0.58" \
python=${RAPIDS_PY_VERSION}

# Temporarily allow unbound variables for conda activation.
set +u
conda activate test
set -u

PYTHON_CHANNEL=$(rapids-download-conda-from-s3 python)

rapids-print-env

rapids-mamba-retry install \
--channel "${PYTHON_CHANNEL}" \
pynvjitlink

rapids-logger "Check GPU usage"
nvidia-smi

EXITCODE=0
trap "EXITCODE=1" ERR
set +e

rapids-logger "Test Numba with patch"
python ci/run_patched_numba_tests.py numba.cuda.tests -v -m

rapids-logger "Test script exiting with value: $EXITCODE"
exit ${EXITCODE}
8 changes: 6 additions & 2 deletions pynvjitlink/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@
if _numba_version_ok:
from numba.core import config
from numba.cuda.cudadrv import nvrtc
from numba.cuda.cudadrv.driver import (driver, FILE_EXTENSION_MAP, Linker,
LinkerError)
from numba.cuda.cudadrv.driver import (
driver,
FILE_EXTENSION_MAP,
Linker,
LinkerError,
)
else:
# Prevent the definition of PatchedLinker failing if we have no Numba
# Linker - it won't be used anyway.
Expand Down

0 comments on commit f07467d

Please sign in to comment.