Adding optional CUDA DLLs when installing onnxruntime_gpu #22506

Status: Open
Wants to merge 49 commits into base: main
Showing changes from 38 commits

Commits (49)
65b0f6b
Adding optional CUDA DLLs when installing onnxruntime_gpu
jchen351 Oct 18, 2024
a335dc6
Lintrunner -a
jchen351 Oct 18, 2024
87c51fb
Merge branch 'main' into Cjian/cuda_pip
jchen351 Oct 21, 2024
c95dbce
Update python code
jchen351 Oct 21, 2024
990752e
update python lint to 3.12
jchen351 Oct 21, 2024
e15e3e0
update python lint to 3.12
jchen351 Oct 21, 2024
5e70dd0
Revert lint python to 3.10
jchen351 Oct 21, 2024
3bf6817
Merge branch 'main' into Cjian/cuda_pip
jchen351 Nov 6, 2024
faa6e3a
Use the regex to match .so , .so.nn where nn is a digital number
jchen351 Nov 6, 2024
2efad16
Import missing re
jchen351 Nov 6, 2024
f4f9d35
Adding nvidia-curand and nvidia-cuda-runtime
jchen351 Nov 6, 2024
c025719
Fix typo
jchen351 Nov 6, 2024
c7d0951
lintrunner -a
jchen351 Nov 6, 2024
5fea4d4
Adding 2 second output to os.walk()
jchen351 Nov 7, 2024
ca752b6
Merge branch 'main' into Cjian/cuda_pip
jchen351 Nov 11, 2024
f27a566
Try to install onnxruntime-gpu from local wheel with [cuda_dlls]
jchen351 Nov 11, 2024
120ddf9
Try to install onnxruntime-gpu from local wheel with [cuda_dlls]
jchen351 Nov 11, 2024
644b52e
Update onnxruntime/python/onnxruntime_cuda_temp_env.py
jchen351 Nov 11, 2024
f557c1e
Using os.add_dll_directory and ctypes.CDLL to load dynamic library
jchen351 Nov 11, 2024
f170664
Merge remote-tracking branch 'origin/Cjian/cuda_pip' into Cjian/cuda_pip
jchen351 Nov 11, 2024
5a0e3fb
linrunner -a
jchen351 Nov 11, 2024
cddc500
Move nvidia dll loading to __init__.py
jchen351 Nov 11, 2024
81ef596
rolling back accidentally added yml files
jchen351 Nov 11, 2024
05e8441
Merge branch 'main' into Cjian/cuda_pip
jchen351 Nov 11, 2024
513522f
rolling back accidentally added yml files
jchen351 Nov 11, 2024
77ac10a
add __package__ == "onnxruntime-gpu" condition
jchen351 Nov 11, 2024
fa785d7
preload both windows and linux dlibs
jchen351 Nov 11, 2024
cf7ba65
Move load nvidia dll to an other __init__.py
jchen351 Nov 11, 2024
d24c96c
Move load nvidia dll to an other __init__.py
jchen351 Nov 12, 2024
e9b913f
Move load nvidia dll to an other __init__.py
jchen351 Nov 12, 2024
144f066
Move load nvidia dll to an other __init__.py
jchen351 Nov 12, 2024
0560de9
remove print statements
jchen351 Nov 12, 2024
053d0a3
Merge remote-tracking branch 'origin/main' into Cjian/cuda_pip
jchen351 Nov 14, 2024
457f4a2
Refactor CUDA library regex patterns for Windows environments to sear…
jchen351 Nov 15, 2024
014833f
Refactor CUDA library regex patterns to search nvidia libraries.
jchen351 Nov 15, 2024
d2cbf27
Refactor CUDA library regex patterns to search nvidia libraries.
jchen351 Nov 15, 2024
02c73c6
Use site.getsitepackages()[-1] for both Windows and Linux.
jchen351 Dec 11, 2024
cd02bdb
Logging exception error and add "libcublasLt.so" to the list of libar…
jchen351 Dec 12, 2024
b342c18
Change the regex to exact match of lib files, ignore the case
jchen351 Dec 12, 2024
11b2604
Lintrunner -a
jchen351 Dec 12, 2024
74b8a91
change log level to debug
jchen351 Dec 13, 2024
5ddfc0f
change log level to debug
jchen351 Dec 13, 2024
287bd46
Change logging error to debugs and update messages
jchen351 Dec 16, 2024
dff876c
Update "libnvrtc.so.11", toi "libnvrtc.so.11.2"
jchen351 Dec 16, 2024
5923d1b
#This is not a mistake, it links to more specific version like libnvr…
jchen351 Dec 16, 2024
cf612fc
lintrunner -a
jchen351 Dec 17, 2024
d0ffbaf
change cuda_version() to cuda version
jchen351 Dec 19, 2024
ad0cf6b
Update save_build_and_package_info to allow to be used in non trainin…
jchen351 Dec 19, 2024
fad62cb
check linux before calling find_cudart_versions. Also remove if has_o…
jchen351 Dec 22, 2024
77 changes: 77 additions & 0 deletions onnxruntime/__init__.py
@@ -76,3 +76,80 @@
__version__ = version

onnxruntime_validation.check_distro_info()

# Load nvidia libraries from site-packages/nvidia if the package is onnxruntime-gpu
if (
    __package__ == "onnxruntime-gpu"
    # Just in case we rename the package name in the future
    or __package__ == "onnxruntime-cuda"
    or __package__ == "onnxruntime_gpu"
    or __package__ == "onnxruntime_cuda"
):
    import ctypes
    import os
    import platform
    import re
    import site
    import logging
    # Get the site-packages path where nvidia packages are installed
    site_packages_path = site.getsitepackages()[-1]
    nvidia_path = os.path.join(site_packages_path, "nvidia")
    # Traverse the directory and subdirectories
    if platform.system() == "Windows":
        # Define the list of DLL patterns, curand and nvJitLink are not included for Windows
        cuda_libs = (
            "cublas",
            "cublasLt",
            "cudnn",
            "cudart",
            "cufft",
            # "curand",
Review comment (Contributor):
Add some comments why curand and nvJitLink are not included for Windows, but included in Linux.

Reply (Member):
It needs investigation.

            # "nvJitLink",
        )
        # Construct a regex pattern for each library name with optional parts
        # Pattern explanation:
        # - `libname`: Match the base library name (e.g., "cudart")
        # - `(?:64)?`: Optionally match "64"
        # - `(?:_\d+)*`: Match zero or more occurrences of "_n" where "n" is one or more digits
        # - `\.dll$`: End with ".dll" (the match ignores case)
        lib_pattern = {lib: re.compile(rf"{lib}(?:64)?(?:_\d+)*\.dll$", re.IGNORECASE) for lib in cuda_libs}
        # Walk every directory under site-packages/nvidia and load the matching .dll files (for Windows)
        for root, _, files in os.walk(nvidia_path):
            # Add the current directory to the DLL search path
            with os.add_dll_directory(root):
                # Find all .dll files in the current directory
                for file in files:
                    for pattern in lib_pattern.values():
                        if pattern.match(file):
                            dll_path = os.path.join(root, file)
                            try:
                                _ = ctypes.CDLL(dll_path)
                            except Exception as e:
                                logging.error(f"Failed to load {dll_path}: {e}")
    elif platform.system() == "Linux":
        # Define the patterns with optional version number and case-insensitivity
        cuda_libs = (
            "libcublas.so",
            "libcublasLt.so",
            "libcudnn.so",
            "libcudart.so",
            "libnvrtc.so",
            "libcufft.so",
            "libcurand.so",
            "libnvJitLink.so",
        )

        # Regular expression to match .so files with optional versioning (e.g., .so, .so.1, .so.2.3)
        lib_pattern = {pattern: re.compile(rf"{re.escape(pattern)}(\.\d+)*$", re.IGNORECASE) for pattern in cuda_libs}

        # Traverse the directory and subdirectories
        for root, _, files in os.walk(nvidia_path):
            for file in files:
                # Check if the file matches the .so pattern
                for regex in lib_pattern.values():
                    if regex.match(file):
                        so_path = os.path.join(root, file)
                        _ = ctypes.CDLL(so_path)
    else:
        pass
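
The filename matching in this diff can be tried in isolation. The sketch below rebuilds the two regular expressions from the diff and runs them against a few made-up file names. The helper names (`windows_pattern`, `linux_pattern`, `matches`) and the sample file names are assumptions for illustration only; they are not part of the PR, and real NVIDIA wheels may ship differently named files.

```python
import re


# Hypothetical helper: same expression as the Windows branch of the diff.
def windows_pattern(lib: str) -> re.Pattern:
    return re.compile(rf"{lib}(?:64)?(?:_\d+)*\.dll$", re.IGNORECASE)


# Hypothetical helper: same expression as the Linux branch of the diff.
def linux_pattern(lib: str) -> re.Pattern:
    return re.compile(rf"{re.escape(lib)}(\.\d+)*$", re.IGNORECASE)


def matches(pattern: re.Pattern, names: list) -> list:
    # re.match anchors at the start of the name; the trailing "$" anchors the end.
    return [name for name in names if pattern.match(name)]


# Illustrative file names only (assumed, not taken from an actual wheel).
win_files = ["cublas64_12.dll", "cublasLt64_12.dll", "CUDART64_12.DLL", "cudnn_ops64_9.dll"]
linux_files = ["libcublas.so.12", "libcublasLt.so.12", "libcudart.so", "libcudnn_ops.so.9"]

print(matches(windows_pattern("cublas"), win_files))
print(matches(windows_pattern("cudart"), win_files))
print(matches(windows_pattern("cudnn"), win_files))
print(matches(linux_pattern("libcublas.so"), linux_files))
print(matches(linux_pattern("libcudart.so"), linux_files))
```

Because each pattern is anchored at both ends, `cublas` does not also pick up `cublasLt`, and a name such as `cudnn_ops64_9.dll` is not matched by the `cudnn` pattern. Whether such companion libraries need to be preloaded is a question for the PR author; the sketch only shows what the expressions accept.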
21 changes: 20 additions & 1 deletion setup.py
@@ -7,6 +7,7 @@
import datetime
import logging
import platform
import re
import shlex
import subprocess
import sys
@@ -54,6 +55,7 @@ def parse_arg_remove_string(argv, arg_name_equal):
wheel_name_suffix = parse_arg_remove_string(sys.argv, "--wheel_name_suffix=")

cuda_version = None
cuda_version_major = None
rocm_version = None
is_migraphx = False
is_rocm = False
@@ -63,6 +65,11 @@ def parse_arg_remove_string(argv, arg_name_equal):
if wheel_name_suffix == "gpu":
    # TODO: how to support multiple CUDA versions?
    cuda_version = parse_arg_remove_string(sys.argv, "--cuda_version=")
    if cuda_version is not None:
        if not bool(re.match(r"^\d+\.\d+(\.\d+)?$", cuda_version)):
            logger.error("CUDA version must be in format 'x.y' or 'x.y.z'")
            sys.exit(1)
        cuda_version_major = cuda_version.split(".")[0]
elif parse_arg_remove_boolean(sys.argv, "--use_rocm"):
    is_rocm = True
    rocm_version = parse_arg_remove_string(sys.argv, "--rocm_version=")
@@ -705,11 +712,22 @@ def reformat_run_count(count_str):
    version_number = version_number + local_version
    if is_rocm and enable_rocm_profiling:
        version_number = version_number + ".profiling"

extras_require = {}
if wheel_name_suffix:
    if not (enable_training and wheel_name_suffix == "gpu"):
        # for training packages, local version is used to indicate device types
        package_name = f"{package_name}-{wheel_name_suffix}"
    if wheel_name_suffix == "gpu" and cuda_version_major is not None:
        extras_require = {
            # Optional 'cuda_dlls' dependencies
            "cuda_dlls": [
                f"nvidia-cuda-nvrtc-cu{cuda_version_major}",
                f"nvidia-cuda-runtime-cu{cuda_version_major}",
                f"nvidia-cudnn-cu{cuda_version_major}",
                f"nvidia-cufft-cu{cuda_version_major}",
                f"nvidia-curand-cu{cuda_version_major}",
            ]
        }

cmd_classes = {}
if bdist_wheel is not None:
@@ -783,4 +801,5 @@ def save_build_and_package_info(package_name, version_number, cuda_version, rocm
        ]
    },
    classifiers=classifiers,
    extras_require=extras_require,
)
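
The setup.py change boils down to two steps: validate `--cuda_version`, then derive the optional dependency list from its major component. The sketch below isolates that logic so it can be run on its own. The function names `cuda_major` and `build_extras` are hypothetical, and raising `ValueError` stands in for the `logger.error` plus `sys.exit(1)` used in the diff.

```python
import re

# Same expression as in the diff: "x.y" or "x.y.z", digits only.
_CUDA_VERSION_RE = re.compile(r"^\d+\.\d+(\.\d+)?$")


def cuda_major(cuda_version: str) -> str:
    """Return the major component of a CUDA version string (hypothetical helper)."""
    if _CUDA_VERSION_RE.match(cuda_version) is None:
        # The diff logs an error and exits here; an exception is easier to test.
        raise ValueError("CUDA version must be in format 'x.y' or 'x.y.z'")
    return cuda_version.split(".")[0]


def build_extras(major: str) -> dict:
    """Build the 'cuda_dlls' extra for the given CUDA major version (hypothetical helper)."""
    # Package name stems copied from the diff.
    stems = ("cuda-nvrtc", "cuda-runtime", "cudnn", "cufft", "curand")
    return {"cuda_dlls": [f"nvidia-{stem}-cu{major}" for stem in stems]}


extras = build_extras(cuda_major("12.4"))
print(extras["cuda_dlls"])
```

With an extra declared this way, the optional packages would be requested through pip's extras syntax, for example `pip install onnxruntime-gpu[cuda_dlls]`, assuming the published wheel carries the extra.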