Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pynvjitlink as a dependency #14763

Merged
merged 19 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ci/build_wheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ fi
if [[ $PACKAGE_CUDA_SUFFIX == "-cu12" ]]; then
sed -i "s/cuda-python[<=>\.,0-9a]*/cuda-python>=12.0,<13.0a0/g" ${pyproject_file}
sed -i "s/cupy-cuda11x/cupy-cuda12x/g" ${pyproject_file}
sed -i "/ptxcompiler/d" ${pyproject_file}
sed -i "s/ptxcompiler/pynvjitlink/g" ${pyproject_file}
sed -i "/cubinlinker/d" ${pyproject_file}
fi

Expand Down
1 change: 1 addition & 0 deletions conda/environments/all_cuda-120_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ dependencies:
- protobuf>=4.21,<5
- pyarrow==14.0.1.*
- pydata-sphinx-theme!=0.14.2
- pynvjitlink
- pytest
- pytest-benchmark
- pytest-cases<3.8.2
Expand Down
3 changes: 2 additions & 1 deletion conda/recipes/cudf/meta.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2018-2023, NVIDIA CORPORATION.
# Copyright (c) 2018-2024, NVIDIA CORPORATION.

{% set version = environ['RAPIDS_PACKAGE_VERSION'].lstrip('v') %}
{% set minor_version = version.split('.')[0] + '.' + version.split('.')[1] %}
Expand Down Expand Up @@ -98,6 +98,7 @@ requirements:
# xref: https://github.com/rapidsai/cudf/issues/12822
- cuda-nvrtc
- cuda-python >=12.0,<13.0a0
- pynvjitlink
{% endif %}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once we have a clearer idea on intended compatibility ( rapidsai/pynvjitlink#48 ), we may want to add some version constraints here

This could be done in a separate PR though

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is reasonable. John proposed pynvjitlink >=0.1.11,<0.2.0a0 offline, which seems appropriate to me.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah though let's discuss in the issue and we can do this as follow up (after this PR is merged)

- {{ pin_compatible('cuda-version', max_pin='x', min_pin='x') }}
- nvtx >=0.2.1
Expand Down
5 changes: 4 additions & 1 deletion dependencies.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -534,16 +534,19 @@ dependencies:
- {matrix: null, packages: *run_cudf_packages_all_cu11}
- output_types: conda
matrices:
- matrix: {cuda: "12.*"}
packages:
- pynvjitlink
- matrix: {cuda: "11.*"}
packages:
- cubinlinker
- ptxcompiler
- {matrix: null, packages: null}
- output_types: [requirements, pyproject]
matrices:
- matrix: {cuda: "12.*"}
packages:
- rmm-cu12==24.2.*
- pynvjitlink-cu12
- matrix: {cuda: "11.*"}
packages:
- rmm-cu11==24.2.*
Expand Down
29 changes: 10 additions & 19 deletions python/cudf/cudf/utils/_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,10 @@
import glob
import os
import sys
import warnings
from functools import lru_cache

from numba import config as numba_config

try:
from pynvjitlink.patch import (
patch_numba_linker as patch_numba_linker_pynvjitlink,
)
except ImportError:

def patch_numba_linker_pynvjitlink():
warnings.warn(
"CUDA Toolkit is newer than CUDA driver. "
"Numba features will not work in this configuration. "
)


# Use an lru_cache with a single value to allow a delayed import of
# strings_udf. This is the easiest way to break an otherwise circular import
Expand Down Expand Up @@ -117,11 +104,13 @@ def _setup_numba():
version of the CUDA Toolkit used to build the PTX files shipped
with the user cuDF package.
"""
# ptxcompiler is a requirement for cuda 11.x packages but not
# cuda 12.x packages. However its version checking machinery
# is still necessary. If a user happens to have ptxcompiler
# in a cuda 12 environment, it's use for the purposes of
# checking the driver and runtime versions is harmless

# Either ptxcompiler, or our vendored version (_ptxcompiler.py)
# is needed to determine the driver and runtime CUDA versions in
# the environment. In a CUDA 11.x environment, ptxcompiler is used
# to provide MVC directly, whereas for CUDA 12.x this is provided
# through pynvjitlink. The presence of either package does not
# perturb cuDF's operation in situations where they aren't used.
try:
from ptxcompiler.patch import NO_DRIVER, safe_get_versions
except ModuleNotFoundError:
Expand All @@ -145,7 +134,9 @@ def _setup_numba():
if driver_version < (12, 0):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we update the comment above to mention pynvjitlink and the corresponding role of that package? This comment:

    # ptxcompiler is a requirement for cuda 11.x packages but not
    # cuda 12.x packages. However its version checking machinery
    # is still necessary. If a user happens to have ptxcompiler
    # in a cuda 12 environment, it's use for the purposes of
    # checking the driver and runtime versions is harmless

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@brandon-b-miller I would generally advocate reviewing this entire file and any other files that relate to ptxcompiler/pynvjitlink to make sure things are named sensibly, etc. in a way that will support both CUDA 11 and CUDA 12+. I want the code comments and docs to reflect the implemented design going forward.

Keep in mind that we don't want to name things "CUDA 12" in the code if we can avoid it if it is likely that later versions will act in the same way.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about something like 7dbf9f2 ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In a CUDA 12.x environment, ptxcompiler provides version checking, but not MVC directly

Is this true? We don't use ptxcompiler in CUDA 12 environments. No environment should have both ptxcompiler and pynvjitlink installed at the same time.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's technically _ptxcompiler.py in this case - our slimmed down, vendored version of the few functions we need.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ooooo. But I don't know how to distinguish ptxcompiler the package (only used when on CUDA 11) from _ptxcompiler.py the internal helper file (always active) from the text of this comment. Documenting that kind of thing clearly is what I want to achieve before merging this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

some reworking in e8a90b9

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much clearer! Thanks for iterating on this.

patch_numba_linker_cuda_11()
else:
patch_numba_linker_pynvjitlink()
from pynvjitlink.patch import patch_numba_linker

patch_numba_linker()


def _get_cuda_version_from_ptx_file(path):
Expand Down
Loading