Skip to content

Commit

Permalink
Add pynvjitlink as a dependency (#14763)
Browse files Browse the repository at this point in the history
This PR adds `pynvjitlink` as a hard dependency for cuDF. This should allow for MVC when launching numba kernels across minor versions of CUDA 12 up to the version of `nvjitlink` statically shipped with `pynvjitlink`. 

cc @bdice

Authors:
  - https://github.com/brandon-b-miller
  - https://github.com/jakirkham
  - Bradley Dice (https://github.com/bdice)

Approvers:
  - Bradley Dice (https://github.com/bdice)
  - Ray Douglass (https://github.com/raydouglass)

URL: #14763
  • Loading branch information
brandon-b-miller authored Jan 19, 2024
1 parent 8fa2945 commit e0905ac
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 22 deletions.
2 changes: 1 addition & 1 deletion ci/build_wheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ fi
if [[ $PACKAGE_CUDA_SUFFIX == "-cu12" ]]; then
sed -i "s/cuda-python[<=>\.,0-9a]*/cuda-python>=12.0,<13.0a0/g" ${pyproject_file}
sed -i "s/cupy-cuda11x/cupy-cuda12x/g" ${pyproject_file}
sed -i "/ptxcompiler/d" ${pyproject_file}
sed -i "s/ptxcompiler/pynvjitlink/g" ${pyproject_file}
sed -i "/cubinlinker/d" ${pyproject_file}
fi

Expand Down
1 change: 1 addition & 0 deletions conda/environments/all_cuda-120_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ dependencies:
- protobuf>=4.21,<5
- pyarrow==14.0.1.*
- pydata-sphinx-theme!=0.14.2
- pynvjitlink
- pytest
- pytest-benchmark
- pytest-cases>=3.8.2
Expand Down
3 changes: 2 additions & 1 deletion conda/recipes/cudf/meta.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2018-2023, NVIDIA CORPORATION.
# Copyright (c) 2018-2024, NVIDIA CORPORATION.

{% set version = environ['RAPIDS_PACKAGE_VERSION'].lstrip('v') %}
{% set minor_version = version.split('.')[0] + '.' + version.split('.')[1] %}
Expand Down Expand Up @@ -98,6 +98,7 @@ requirements:
# xref: https://github.com/rapidsai/cudf/issues/12822
- cuda-nvrtc
- cuda-python >=12.0,<13.0a0
- pynvjitlink
{% endif %}
- {{ pin_compatible('cuda-version', max_pin='x', min_pin='x') }}
- nvtx >=0.2.1
Expand Down
5 changes: 4 additions & 1 deletion dependencies.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -534,16 +534,19 @@ dependencies:
- {matrix: null, packages: *run_cudf_packages_all_cu11}
- output_types: conda
matrices:
- matrix: {cuda: "12.*"}
packages:
- pynvjitlink
- matrix: {cuda: "11.*"}
packages:
- cubinlinker
- ptxcompiler
- {matrix: null, packages: null}
- output_types: [requirements, pyproject]
matrices:
- matrix: {cuda: "12.*"}
packages:
- rmm-cu12==24.2.*
- pynvjitlink-cu12
- matrix: {cuda: "11.*"}
packages:
- rmm-cu11==24.2.*
Expand Down
29 changes: 10 additions & 19 deletions python/cudf/cudf/utils/_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,10 @@
import glob
import os
import sys
import warnings
from functools import lru_cache

from numba import config as numba_config

try:
from pynvjitlink.patch import (
patch_numba_linker as patch_numba_linker_pynvjitlink,
)
except ImportError:

def patch_numba_linker_pynvjitlink():
warnings.warn(
"CUDA Toolkit is newer than CUDA driver. "
"Numba features will not work in this configuration. "
)


# Use an lru_cache with a single value to allow a delayed import of
# strings_udf. This is the easiest way to break an otherwise circular import
Expand Down Expand Up @@ -117,11 +104,13 @@ def _setup_numba():
version of the CUDA Toolkit used to build the PTX files shipped
with the user cuDF package.
"""
# ptxcompiler is a requirement for cuda 11.x packages but not
# cuda 12.x packages. However its version checking machinery
# is still necessary. If a user happens to have ptxcompiler
# in a cuda 12 environment, it's use for the purposes of
# checking the driver and runtime versions is harmless

# Either ptxcompiler, or our vendored version (_ptxcompiler.py)
# is needed to determine the driver and runtime CUDA versions in
# the environment. In a CUDA 11.x environment, ptxcompiler is used
# to provide MVC directly, whereas for CUDA 12.x this is provided
# through pynvjitlink. The presence of either package does not
# perturb cuDF's operation in situations where they aren't used.
try:
from ptxcompiler.patch import NO_DRIVER, safe_get_versions
except ModuleNotFoundError:
Expand All @@ -145,7 +134,9 @@ def _setup_numba():
if driver_version < (12, 0):
patch_numba_linker_cuda_11()
else:
patch_numba_linker_pynvjitlink()
from pynvjitlink.patch import patch_numba_linker

patch_numba_linker()


def _get_cuda_version_from_ptx_file(path):
Expand Down

0 comments on commit e0905ac

Please sign in to comment.