Adding optional CUDA DLLs when installing onnxruntime_gpu #22506

Open: wants to merge 49 commits into main.

Commits (49)
65b0f6b  Adding optional CUDA DLLs when installing onnxruntime_gpu (jchen351, Oct 18, 2024)
a335dc6  Lintrunner -a (jchen351, Oct 18, 2024)
87c51fb  Merge branch 'main' into Cjian/cuda_pip (jchen351, Oct 21, 2024)
c95dbce  Update python code (jchen351, Oct 21, 2024)
990752e  update python lint to 3.12 (jchen351, Oct 21, 2024)
e15e3e0  update python lint to 3.12 (jchen351, Oct 21, 2024)
5e70dd0  Revert lint python to 3.10 (jchen351, Oct 21, 2024)
3bf6817  Merge branch 'main' into Cjian/cuda_pip (jchen351, Nov 6, 2024)
faa6e3a  Use the regex to match .so , .so.nn where nn is a digital number (jchen351, Nov 6, 2024)
2efad16  Import missing re (jchen351, Nov 6, 2024)
f4f9d35  Adding nvidia-curand and nvidia-cuda-runtime (jchen351, Nov 6, 2024)
c025719  Fix typo (jchen351, Nov 6, 2024)
c7d0951  lintrunner -a (jchen351, Nov 6, 2024)
5fea4d4  Adding 2 second output to os.walk() (jchen351, Nov 7, 2024)
ca752b6  Merge branch 'main' into Cjian/cuda_pip (jchen351, Nov 11, 2024)
f27a566  Try to install onnxruntime-gpu from local wheel with [cuda_dlls] (jchen351, Nov 11, 2024)
120ddf9  Try to install onnxruntime-gpu from local wheel with [cuda_dlls] (jchen351, Nov 11, 2024)
644b52e  Update onnxruntime/python/onnxruntime_cuda_temp_env.py (jchen351, Nov 11, 2024)
f557c1e  Using os.add_dll_directory and ctypes.CDLL to load dynamic library (jchen351, Nov 11, 2024)
f170664  Merge remote-tracking branch 'origin/Cjian/cuda_pip' into Cjian/cuda_pip (jchen351, Nov 11, 2024)
5a0e3fb  linrunner -a (jchen351, Nov 11, 2024)
cddc500  Move nvidia dll loading to __init__.py (jchen351, Nov 11, 2024)
81ef596  rolling back accidentally added yml files (jchen351, Nov 11, 2024)
05e8441  Merge branch 'main' into Cjian/cuda_pip (jchen351, Nov 11, 2024)
513522f  rolling back accidentally added yml files (jchen351, Nov 11, 2024)
77ac10a  add __package__ == "onnxruntime-gpu" condition (jchen351, Nov 11, 2024)
fa785d7  preload both windows and linux dlibs (jchen351, Nov 11, 2024)
cf7ba65  Move load nvidia dll to an other __init__.py (jchen351, Nov 11, 2024)
d24c96c  Move load nvidia dll to an other __init__.py (jchen351, Nov 12, 2024)
e9b913f  Move load nvidia dll to an other __init__.py (jchen351, Nov 12, 2024)
144f066  Move load nvidia dll to an other __init__.py (jchen351, Nov 12, 2024)
0560de9  remove print statements (jchen351, Nov 12, 2024)
053d0a3  Merge remote-tracking branch 'origin/main' into Cjian/cuda_pip (jchen351, Nov 14, 2024)
457f4a2  Refactor CUDA library regex patterns for Windows environments to sear… (jchen351, Nov 15, 2024)
014833f  Refactor CUDA library regex patterns to search nvidia libraries. (jchen351, Nov 15, 2024)
d2cbf27  Refactor CUDA library regex patterns to search nvidia libraries. (jchen351, Nov 15, 2024)
02c73c6  Use site.getsitepackages()[-1] for both Windows and Linux. (jchen351, Dec 11, 2024)
cd02bdb  Logging exception error and add "libcublasLt.so" to the list of libar… (jchen351, Dec 12, 2024)
b342c18  Change the regex to exact match of lib files, ignore the case (jchen351, Dec 12, 2024)
11b2604  Lintrunner -a (jchen351, Dec 12, 2024)
74b8a91  change log level to debug (jchen351, Dec 13, 2024)
5ddfc0f  change log level to debug (jchen351, Dec 13, 2024)
287bd46  Change logging error to debugs and update messages (jchen351, Dec 16, 2024)
dff876c  Update "libnvrtc.so.11", toi "libnvrtc.so.11.2" (jchen351, Dec 16, 2024)
5923d1b  #This is not a mistake, it links to more specific version like libnvr… (jchen351, Dec 16, 2024)
cf612fc  lintrunner -a (jchen351, Dec 17, 2024)
d0ffbaf  change cuda_version() to cuda version (jchen351, Dec 19, 2024)
ad0cf6b  Update save_build_and_package_info to allow to be used in non trainin… (jchen351, Dec 19, 2024)
fad62cb  check linux before calling find_cudart_versions. Also remove if has_o… (jchen351, Dec 22, 2024)
94 changes: 94 additions & 0 deletions onnxruntime/__init__.py
@@ -76,3 +76,97 @@
__version__ = version

onnxruntime_validation.check_distro_info()


def check_and_load_cuda_libs(root_directory, cuda_libs_):
    # Convert the target library names to lowercase for case-insensitive comparison
    if cuda_libs_ is None or len(cuda_libs_) == 0:
        logging.info("No CUDA libraries provided for loading.")
        return
    cuda_libs_ = {lib.lower() for lib in cuda_libs_}
    found_libs = {}
    for dirpath, _, filenames in os.walk(root_directory):
        # Convert filenames in the current directory to lowercase for comparison
        files_in_dir = {file.lower(): file for file in filenames}  # Map lowercase to original
        # Find common libraries in the current directory
        matched_libs = cuda_libs_.intersection(files_in_dir.keys())
        for lib in matched_libs:
            # Store the full path of the found DLL
            full_path = os.path.join(dirpath, files_in_dir[lib])
            found_libs[lib] = full_path
            try:
                # Load the DLL using ctypes
                _ = ctypes.CDLL(full_path)
                logging.info(f"Successfully loaded: {full_path}")
            except OSError as e:
                logging.info(f"Failed to load {full_path}: {e}")

        # If all required libraries are found, stop the search
        if set(found_libs.keys()) == cuda_libs_:
            logging.info("All required CUDA libraries found and loaded.")
            return
    logging.info(
        f"Failed to load CUDA libraries from site-packages/nvidia directory: {cuda_libs_ - found_libs.keys()}. They might be loaded later from standard search paths for shared libraries."
    )
    return

Review comment from @tianleiwu (Contributor), Dec 20, 2024, on the "if cuda_version is not None" line below:
In my test, cuda_version is still an empty string. It is imported from onnxruntime.capi.onnxruntime_validation in line 73. That module only sets cuda_version for training, as below:

    cuda_version = ""
    if has_ortmodule:

We can remove the "if has_ortmodule" line there.

# Load nvidia libraries from site-packages/nvidia if the package is onnxruntime-gpu
if cuda_version is not None and cuda_version != "":

    import ctypes
    import logging
    import os
    import platform
    import site

    cuda_version_ = tuple(map(int, cuda_version.split(".")))
    # Get the site-packages path where nvidia packages are installed
    site_packages_path = site.getsitepackages()[-1]
    nvidia_path = os.path.join(site_packages_path, "nvidia")
    # Traverse the directory and subdirectories
    cuda_libs = ()
    if platform.system() == "Windows":
        # Define the list of DLLs; nvrtc, curand and nvJitLink are not included for Windows
        if (11, 0) <= cuda_version_ < (12, 0):
            cuda_libs = (
                "cublaslt64_11.dll",
                "cublas64_11.dll",
                "cufft64_10.dll",
                "cudart64_11.dll",
                "cudnn64_8.dll",
            )
        elif (12, 0) <= cuda_version_ < (13, 0):
            cuda_libs = (
                "cublaslt64_12.dll",
                "cublas64_12.dll",
                "cufft64_11.dll",
                "cudart64_12.dll",
                "cudnn64_9.dll",
            )
    elif platform.system() == "Linux":
        if (11, 0) <= cuda_version_ < (12, 0):
            # Define the patterns with optional version number and case-insensitivity
            cuda_libs = (
                "libcublaslt.so.11",
                "libcublas.so.11",
                "libcurand.so.10",
                "libcufft.so.10",
                "libcudart.so.11",
                "libcudnn.so.8",
                "libnvrtc.so.11.2",
                # This is not a mistake; it links to a more specific version like libnvrtc.so.11.8.89 etc.
            )
        elif (12, 0) <= cuda_version_ < (13, 0):
            cuda_libs = (
                "libcublaslt.so.12",
                "libcublas.so.12",
                "libcurand.so.10",
                "libcufft.so.11",
                "libcudart.so.12",
                "libcudnn.so.9",
                "libnvrtc.so.12",
            )
    else:
        logging.info(f"Unsupported platform: {platform.system()}")
    check_and_load_cuda_libs(nvidia_path, cuda_libs)
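
For readers skimming the diff, here is a minimal standalone sketch of the same preloading idea (not part of the PR): walk site-packages/nvidia and load a couple of CUDA 12 libraries with ctypes before importing onnxruntime. The helper name and the two library file names are illustrative assumptions taken from the tuples above.

import ctypes
import os
import platform
import site


def preload_cuda_libs(lib_names):
    # Walk site-packages/nvidia and load any matching shared libraries into the process.
    nvidia_root = os.path.join(site.getsitepackages()[-1], "nvidia")
    wanted = {name.lower() for name in lib_names}
    for dirpath, _, filenames in os.walk(nvidia_root):
        for filename in filenames:
            if filename.lower() in wanted:
                try:
                    ctypes.CDLL(os.path.join(dirpath, filename))
                except OSError:
                    pass  # fall back to the loader's standard search paths


if platform.system() == "Linux":
    preload_cuda_libs(["libcudart.so.12", "libcublas.so.12"])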
59 changes: 30 additions & 29 deletions onnxruntime/python/onnxruntime_validation.py
@@ -99,32 +99,33 @@
version = ""
cuda_version = ""

if has_ortmodule:
try:
# collect onnxruntime package name, version, and cuda version
from .build_and_package_info import __version__ as version
from .build_and_package_info import package_name
try:
# collect onnxruntime package name, version, and cuda version
from .build_and_package_info import __version__ as version
from .build_and_package_info import package_name

try: # noqa: SIM105
from .build_and_package_info import cuda_version
except Exception:

Code scanning notice (CodeQL), Empty except: 'except' clause does nothing but pass and there is no explanatory comment.
pass

try: # noqa: SIM105
from .build_and_package_info import cuda_version
if cuda_version:
# collect cuda library build info. the library info may not be available
# when the build environment has none or multiple libraries installed
try:
from .build_and_package_info import cudart_version
except Exception:
pass

if cuda_version:
# collect cuda library build info. the library info may not be available
# when the build environment has none or multiple libraries installed
try:
from .build_and_package_info import cudart_version
except Exception:
warnings.warn("WARNING: failed to get cudart_version from onnxruntime build info.")
cudart_version = None

def print_build_package_info():
warnings.warn(f"onnxruntime training package info: package_name: {package_name}")
warnings.warn(f"onnxruntime training package info: __version__: {version}")
warnings.warn(f"onnxruntime training package info: cuda_version: {cuda_version}")
warnings.warn(f"onnxruntime build info: cudart_version: {cudart_version}")
warnings.warn("WARNING: failed to get cudart_version from onnxruntime build info.")
cudart_version = None

def print_build_package_info():
warnings.warn(f"onnxruntime training package info: package_name: {package_name}")
warnings.warn(f"onnxruntime training package info: __version__: {version}")
warnings.warn(f"onnxruntime training package info: cuda_version: {cuda_version}")
warnings.warn(f"onnxruntime build info: cudart_version: {cudart_version}")

# Cudart only available on Linux
if platform.system().lower() == "linux":
# collection cuda library info from current environment.
from onnxruntime.capi.onnxruntime_collect_build_info import find_cudart_versions

@@ -133,13 +134,13 @@
print_build_package_info()
warnings.warn("WARNING: failed to find cudart version that matches onnxruntime build info")
warnings.warn(f"WARNING: found cudart versions: {local_cudart_versions}")
else:
# TODO: rcom
pass
else:
# TODO: rcom
pass

except Exception as e:
warnings.warn("WARNING: failed to collect onnxruntime version and build info")
print(e)
except Exception as e:
warnings.warn("WARNING: failed to collect onnxruntime version and build info")
print(e)

if import_ortmodule_exception:
raise import_ortmodule_exception
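
As an aside on the CodeQL "Empty except" notice above: one common way to make that intent explicit is contextlib.suppress. This is only a sketch of that alternative, not a change proposed in this PR, and the absolute module path is an assumption about where build_and_package_info lands in the installed package.

import contextlib

cuda_version = ""
# Same behavior as "try: ... except Exception: pass", but self-documenting,
# which avoids the "empty except" style warning.
with contextlib.suppress(Exception):
    from onnxruntime.capi.build_and_package_info import cuda_version  # assumed install path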
53 changes: 37 additions & 16 deletions setup.py
@@ -7,6 +7,7 @@
import datetime
import logging
import platform
import re
import shlex
import subprocess
import sys
@@ -54,6 +55,7 @@ def parse_arg_remove_string(argv, arg_name_equal):
wheel_name_suffix = parse_arg_remove_string(sys.argv, "--wheel_name_suffix=")

cuda_version = None
cuda_version_major = None
rocm_version = None
is_migraphx = False
is_rocm = False
@@ -63,6 +65,11 @@ def parse_arg_remove_string(argv, arg_name_equal):
if wheel_name_suffix == "gpu":
# TODO: how to support multiple CUDA versions?
cuda_version = parse_arg_remove_string(sys.argv, "--cuda_version=")
if cuda_version is not None:
if not bool(re.match(r"^\d+\.\d+(\.\d+)?$", cuda_version)):
logger.error("CUDA version must be in format 'x.y' or 'x.y.z'")
sys.exit(1)
cuda_version_major = cuda_version.split(".")[0]
elif parse_arg_remove_boolean(sys.argv, "--use_rocm"):
is_rocm = True
rocm_version = parse_arg_remove_string(sys.argv, "--rocm_version=")
@@ -705,11 +712,22 @@ def reformat_run_count(count_str):
version_number = version_number + local_version
if is_rocm and enable_rocm_profiling:
version_number = version_number + ".profiling"

extras_require = {}
if wheel_name_suffix:
if not (enable_training and wheel_name_suffix == "gpu"):
# for training packages, local version is used to indicate device types
package_name = f"{package_name}-{wheel_name_suffix}"
if (wheel_name_suffix == "gpu" or wheel_name_suffix == "cuda") and cuda_version_major is not None:
extras_require = {
# Optional 'cuda_dlls' dependencies
"cuda_dlls": [
f"nvidia-cuda-nvrtc-cu{cuda_version_major}",
f"nvidia-cuda-runtime-cu{cuda_version_major}",
f"nvidia-cudnn-cu{cuda_version_major}",
f"nvidia-cufft-cu{cuda_version_major}",
f"nvidia-curand-cu{cuda_version_major}",
]
}

cmd_classes = {}
if bdist_wheel is not None:
@@ -727,21 +745,20 @@ def reformat_run_count(count_str):
install_requires = f.read().splitlines()


if enable_training:

def save_build_and_package_info(package_name, version_number, cuda_version, rocm_version):
sys.path.append(path.join(path.dirname(__file__), "onnxruntime", "python"))
from onnxruntime_collect_build_info import find_cudart_versions
def save_build_and_package_info(package_name, version_number, cuda_version, rocm_version):
sys.path.append(path.join(path.dirname(__file__), "onnxruntime", "python"))
from onnxruntime_collect_build_info import find_cudart_versions

version_path = path.join("onnxruntime", "capi", "build_and_package_info.py")
with open(version_path, "w") as f:
f.write(f"package_name = '{package_name}'\n")
f.write(f"__version__ = '{version_number}'\n")
version_path = path.join("onnxruntime", "capi", "build_and_package_info.py")
with open(version_path, "w") as f:
f.write(f"package_name = '{package_name}'\n")
f.write(f"__version__ = '{version_number}'\n")

if cuda_version:
f.write(f"cuda_version = '{cuda_version}'\n")
if cuda_version:
f.write(f"cuda_version = '{cuda_version}'\n")

# cudart_versions are integers
# cudart_versions are integers
if platform.system().lower() == "linux":
cudart_versions = find_cudart_versions(build_env=True)
if cudart_versions and len(cudart_versions) == 1:
f.write(f"cudart_version = {cudart_versions[0]}\n")
@@ -754,12 +771,15 @@ def save_build_and_package_info(package_name, version_number, cuda_version, rocm
else "found multiple cudart libraries"
),
)
elif rocm_version:
f.write(f"rocm_version = '{rocm_version}'\n")
elif rocm_version:
f.write(f"rocm_version = '{rocm_version}'\n")


if enable_training:
save_build_and_package_info(package_name, version_number, cuda_version, rocm_version)
else:
save_build_and_package_info(package_name, version_number, cuda_version, None)

# Setup
setup(
name=package_name,
version=version_number,
@@ -783,4 +803,5 @@ def save_build_and_package_info(package_name, version_number, cuda_version, rocm
]
},
classifiers=classifiers,
extras_require=extras_require,
)
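
To close the loop on the cuda_dlls extra declared above: below is a minimal sketch of how such an optional dependency group is wired up with setuptools, using placeholder package names rather than the exact requirements generated by this setup.py. Users would then opt in with pip install onnxruntime-gpu[cuda_dlls] (or a local wheel path plus [cuda_dlls], as exercised in the commit history).

from setuptools import setup

setup(
    name="example-gpu-package",  # placeholder name, not the real onnxruntime-gpu setup
    version="0.0.1",
    extras_require={
        # Installed only when the user asks for the extra, e.g.
        #   pip install "example-gpu-package[cuda_dlls]"
        "cuda_dlls": [
            "nvidia-cuda-runtime-cu12",
            "nvidia-cudnn-cu12",
        ],
    },
)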