-
Notifications
You must be signed in to change notification settings - Fork 915
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add pynvjitlink
as a dependency
#14763
Changes from 16 commits
532da9f
02ef286
c6ae80c
6c110b4
482be0f
d301cb9
8b8a12c
aefd0c0
c6b3982
c424796
687225a
92c6bb1
7dbf9f2
e8a90b9
698f1a2
011038d
b0972a9
016c237
d5c1efa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,23 +3,10 @@ | |
import glob | ||
import os | ||
import sys | ||
import warnings | ||
from functools import lru_cache | ||
|
||
from numba import config as numba_config | ||
|
||
try: | ||
from pynvjitlink.patch import ( | ||
patch_numba_linker as patch_numba_linker_pynvjitlink, | ||
) | ||
except ImportError: | ||
|
||
def patch_numba_linker_pynvjitlink(): | ||
warnings.warn( | ||
"CUDA Toolkit is newer than CUDA driver. " | ||
"Numba features will not work in this configuration. " | ||
) | ||
|
||
|
||
# Use an lru_cache with a single value to allow a delayed import of | ||
# strings_udf. This is the easiest way to break an otherwise circular import | ||
|
@@ -117,11 +104,13 @@ def _setup_numba(): | |
version of the CUDA Toolkit used to build the PTX files shipped | ||
with the user cuDF package. | ||
""" | ||
# ptxcompiler is a requirement for cuda 11.x packages but not | ||
# cuda 12.x packages. However its version checking machinery | ||
# is still necessary. If a user happens to have ptxcompiler | ||
# in a cuda 12 environment, it's use for the purposes of | ||
# checking the driver and runtime versions is harmless | ||
|
||
# Either ptxcompiler, or our vendored version (_ptxcompiler.py) | ||
# is needed to determine the driver and runtime CUDA versions in | ||
# the environment. In a CUDA 11.x environment, ptxcompiler is used | ||
# to provide MVC directly, whereas for CUDA 12.x this is provided | ||
# through pynvjitlink. The presence of either package does not | ||
# perturb cuDF's operation in situations where they aren't used. | ||
try: | ||
from ptxcompiler.patch import NO_DRIVER, safe_get_versions | ||
except ModuleNotFoundError: | ||
|
@@ -145,7 +134,9 @@ def _setup_numba(): | |
if driver_version < (12, 0): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we update the comment above to mention pynvjitlink and the corresponding role of that package? This comment:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @brandon-b-miller I would generally advocate reviewing this entire file and any other files that relate to ptxcompiler/pynvjitlink to make sure things are named sensibly, etc. in a way that will support both CUDA 11 and CUDA 12+. I want the code comments and docs to reflect the implemented design going forward. Keep in mind that we don't want to name things "CUDA 12" in the code if we can avoid it if it is likely that later versions will act in the same way. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how about something like 7dbf9f2 ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Is this true? We don't use ptxcompiler in CUDA 12 environments. No environment should have both ptxcompiler and pynvjitlink installed at the same time. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's technically There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ooooo. But I don't know how to distinguish There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. some reworking in e8a90b9 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Much clearer! Thanks for iterating on this. |
||
patch_numba_linker_cuda_11() | ||
else: | ||
patch_numba_linker_pynvjitlink() | ||
from pynvjitlink.patch import patch_numba_linker | ||
|
||
patch_numba_linker() | ||
|
||
|
||
def _get_cuda_version_from_ptx_file(path): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Once we have a clearer idea on intended compatibility ( rapidsai/pynvjitlink#48 ), we may want to add some version constraints here
This could be done in a separate PR though
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this is reasonable. John proposed
pynvjitlink >=0.1.11,<0.2.0a0
offline, which seems appropriate to me.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah though let's discuss in the issue and we can do this as follow up (after this PR is merged)