Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding optional CUDA DLLs when installing onnxruntime_gpu #22506

Open
wants to merge 50 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
65b0f6b
Adding optional CUDA DLLs when installing onnxruntime_gpu
jchen351 Oct 18, 2024
a335dc6
Lintrunner -a
jchen351 Oct 18, 2024
87c51fb
Merge branch 'main' into Cjian/cuda_pip
jchen351 Oct 21, 2024
c95dbce
Update python code
jchen351 Oct 21, 2024
990752e
update python lint to 3.12
jchen351 Oct 21, 2024
e15e3e0
update python lint to 3.12
jchen351 Oct 21, 2024
5e70dd0
Revert lint python to 3.10
jchen351 Oct 21, 2024
3bf6817
Merge branch 'main' into Cjian/cuda_pip
jchen351 Nov 6, 2024
faa6e3a
Use a regex to match .so and .so.nn, where nn is a number
jchen351 Nov 6, 2024
2efad16
Import missing re
jchen351 Nov 6, 2024
f4f9d35
Adding nvidia-curand and nvidia-cuda-runtime
jchen351 Nov 6, 2024
c025719
Fix typo
jchen351 Nov 6, 2024
c7d0951
lintrunner -a
jchen351 Nov 6, 2024
5fea4d4
Adding 2 second output to os.walk()
jchen351 Nov 7, 2024
ca752b6
Merge branch 'main' into Cjian/cuda_pip
jchen351 Nov 11, 2024
f27a566
Try to install onnxruntime-gpu from local wheel with [cuda_dlls]
jchen351 Nov 11, 2024
120ddf9
Try to install onnxruntime-gpu from local wheel with [cuda_dlls]
jchen351 Nov 11, 2024
644b52e
Update onnxruntime/python/onnxruntime_cuda_temp_env.py
jchen351 Nov 11, 2024
f557c1e
Using os.add_dll_directory and ctypes.CDLL to load dynamic library
jchen351 Nov 11, 2024
f170664
Merge remote-tracking branch 'origin/Cjian/cuda_pip' into Cjian/cuda_pip
jchen351 Nov 11, 2024
5a0e3fb
lintrunner -a
jchen351 Nov 11, 2024
cddc500
Move nvidia dll loading to __init__.py
jchen351 Nov 11, 2024
81ef596
rolling back accidentally added yml files
jchen351 Nov 11, 2024
05e8441
Merge branch 'main' into Cjian/cuda_pip
jchen351 Nov 11, 2024
513522f
rolling back accidentally added yml files
jchen351 Nov 11, 2024
77ac10a
add __package__ == "onnxruntime-gpu" condition
jchen351 Nov 11, 2024
fa785d7
preload both windows and linux dlibs
jchen351 Nov 11, 2024
cf7ba65
Move load nvidia dll to an other __init__.py
jchen351 Nov 11, 2024
d24c96c
Move load nvidia dll to an other __init__.py
jchen351 Nov 12, 2024
e9b913f
Move load nvidia dll to an other __init__.py
jchen351 Nov 12, 2024
144f066
Move load nvidia dll to an other __init__.py
jchen351 Nov 12, 2024
0560de9
remove print statements
jchen351 Nov 12, 2024
053d0a3
Merge remote-tracking branch 'origin/main' into Cjian/cuda_pip
jchen351 Nov 14, 2024
457f4a2
Refactor CUDA library regex patterns for Windows environments to sear…
jchen351 Nov 15, 2024
014833f
Refactor CUDA library regex patterns to search nvidia libraries.
jchen351 Nov 15, 2024
d2cbf27
Refactor CUDA library regex patterns to search nvidia libraries.
jchen351 Nov 15, 2024
02c73c6
Use site.getsitepackages()[-1] for both Windows and Linux.
jchen351 Dec 11, 2024
cd02bdb
Logging exception error and add "libcublasLt.so" to the list of libar…
jchen351 Dec 12, 2024
b342c18
Change the regex to exact match of lib files, ignore the case
jchen351 Dec 12, 2024
11b2604
Lintrunner -a
jchen351 Dec 12, 2024
74b8a91
change log level to debug
jchen351 Dec 13, 2024
5ddfc0f
change log level to debug
jchen351 Dec 13, 2024
287bd46
Change logging error to debugs and update messages
jchen351 Dec 16, 2024
dff876c
Update "libnvrtc.so.11" to "libnvrtc.so.11.2"
jchen351 Dec 16, 2024
5923d1b
#This is not a mistake, it links to more specific version like libnvr…
jchen351 Dec 16, 2024
cf612fc
lintrunner -a
jchen351 Dec 17, 2024
d0ffbaf
change cuda_version() to cuda version
jchen351 Dec 19, 2024
ad0cf6b
Update save_build_and_package_info to allow to be used in non trainin…
jchen351 Dec 19, 2024
fad62cb
check linux before calling find_cudart_versions. Also remove if has_o…
jchen351 Dec 22, 2024
f2aa262
Merge branch 'main' into Cjian/cuda_pip
jchen351 Jan 8, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions onnxruntime/python/onnxruntime_cuda_temp_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed
import platform
import re
import site


class TemporaryEnv:
    """Apply environment-variable overrides and restore the old environment later.

    The overrides are written to ``os.environ`` as soon as the object is
    constructed. Calling ``__exit__`` (directly, or by leaving a ``with``
    block) clears ``os.environ`` and restores the snapshot taken at
    construction time.

    Args:
        updates: Mapping of environment-variable names to the string values
            that should be set while this object is active.
    """

    def __init__(self, updates):
        # Snapshot before mutating so __exit__ can restore the exact prior state.
        self.original_env = os.environ.copy()
        os.environ.update(updates)

    def __enter__(self):
        """Return ``self`` so the object can be used in a ``with`` statement."""
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        """Restore the environment captured at construction time.

        The three arguments follow the context-manager protocol but are not
        used. They default to ``None`` so the method can also be called
        directly without arguments.
        """
        os.environ.clear()
        os.environ.update(self.original_env)


def get_nvidia_dll_paths() -> str:
    """Return the directories under ``site-packages/nvidia`` that contain DLLs.

    The first entry of ``site.getsitepackages()`` is used as the
    site-packages root.

    Returns:
        The matching directories joined with ``os.pathsep``, or an empty
        string when no directory contains a ``.dll`` file (including when the
        ``nvidia`` directory does not exist).
    """
    # Get the site-packages path where nvidia packages are installed
    site_packages_path = site.getsitepackages()[0]
    nvidia_path = os.path.join(site_packages_path, "nvidia")

    # os.walk yields (dirpath, dirnames, filenames) triples; all three must be
    # unpacked even though the sub-directory names are not needed here.
    dll_paths = [
        root
        for root, _dirs, files in os.walk(nvidia_path)
        if any(file.endswith(".dll") for file in files)
    ]
    return os.pathsep.join(dll_paths)


def get_nvidia_so_paths() -> str:
    """Return the directories under ``site-packages/nvidia`` that contain shared objects.

    The first entry of ``site.getsitepackages()`` is used as the
    site-packages root. A file counts as a shared object when its name ends
    in ``.so`` optionally followed by dotted numeric version components
    (``.so``, ``.so.12``, ``.so.11.2``).

    Returns:
        The matching directories joined with ``os.pathsep``, each listed
        once, or an empty string when nothing matches (including when the
        ``nvidia`` directory does not exist).
    """
    # Get the site-packages path where nvidia packages are installed
    site_packages_path = site.getsitepackages()[0]
    nvidia_path = os.path.join(site_packages_path, "nvidia")

    # `.so` followed by any number of `.<digits>` version components.
    pattern = re.compile(r"\.so(\.\d+)*$")

    # os.walk yields (dirpath, dirnames, filenames) triples. `any` records a
    # directory at most once, however many libraries it holds.
    so_paths = [
        root
        for root, _dirs, files in os.walk(nvidia_path)
        if any(pattern.search(file) for file in files)
    ]
    return os.pathsep.join(so_paths)


def setup_temp_env_for_ort_cuda():
    """Prepend the pip-installed NVIDIA library directories to the loader path.

    On Windows the directories from ``get_nvidia_dll_paths()`` are prepended
    to ``PATH``; on Linux the directories from ``get_nvidia_so_paths()`` are
    prepended to ``LD_LIBRARY_PATH``.

    Returns:
        A ``TemporaryEnv`` whose ``__exit__`` restores the previous
        environment, or ``None`` on platforms other than Windows and Linux.
    """
    # Determine platform and pick the matching variable / library directories
    system = platform.system()
    if system == "Windows":
        variable = "PATH"
        library_dirs = get_nvidia_dll_paths()
    elif system == "Linux":
        variable = "LD_LIBRARY_PATH"
        library_dirs = get_nvidia_so_paths()
    else:
        return None

    if not library_dirs:
        # Nothing to prepend: leave the environment untouched but still return
        # a restorable object so callers can treat both cases uniformly.
        return TemporaryEnv({})

    # The variable may be unset (os.environ.get returns None), which must not
    # be concatenated to a string or written into os.environ.
    current = os.environ.get(variable)
    updated = library_dirs + os.pathsep + current if current else library_dirs
    return TemporaryEnv({variable: updated})
8 changes: 8 additions & 0 deletions onnxruntime/python/onnxruntime_inference_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,9 @@ def __init__(
means execute a node using `CUDAExecutionProvider`
if capable, otherwise execute using `CPUExecutionProvider`.
"""
from .onnxruntime_cuda_temp_env import setup_temp_env_for_ort_cuda

self.env_manager = setup_temp_env_for_ort_cuda()
super().__init__()

if isinstance(path_or_bytes, (str, os.PathLike)):
Expand Down Expand Up @@ -577,6 +580,11 @@ def _register_ep_custom_ops(self, session_options, providers, provider_options,
):
C.register_tensorrt_plugins_as_custom_ops(session_options, providers[i][1])

def __exit__(self, exc_type, exc_value, traceback):
    """Restore the environment held by ``self.env_manager``, if there is one.

    Args:
        exc_type: Exception type passed by the context-manager protocol.
        exc_value: Exception instance passed by the context-manager protocol.
        traceback: Traceback passed by the context-manager protocol.

    Returns:
        ``False``, so an in-flight exception is never suppressed.
    """
    if self.env_manager is not None:
        # Forward the exception triple: the manager's __exit__ takes three
        # positional arguments, so calling it with none raises TypeError.
        self.env_manager.__exit__(exc_type, exc_value, traceback)
    return False


class IOBinding:
"""
Expand Down
21 changes: 20 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import datetime
import logging
import platform
import re
import shlex
import subprocess
import sys
Expand Down Expand Up @@ -54,6 +55,7 @@ def parse_arg_remove_string(argv, arg_name_equal):
wheel_name_suffix = parse_arg_remove_string(sys.argv, "--wheel_name_suffix=")

cuda_version = None
cuda_version_major = None
rocm_version = None
is_migraphx = False
is_rocm = False
Expand All @@ -63,6 +65,11 @@ def parse_arg_remove_string(argv, arg_name_equal):
if wheel_name_suffix == "gpu":
# TODO: how to support multiple CUDA versions?
cuda_version = parse_arg_remove_string(sys.argv, "--cuda_version=")
if cuda_version is not None:
if not bool(re.match(r"^\d+\.\d+(\.\d+)?$", cuda_version)):
logger.error("CUDA version must be in format 'x.y' or 'x.y.z'")
sys.exit(1)
cuda_version_major = cuda_version.split(".")[0]
elif parse_arg_remove_boolean(sys.argv, "--use_rocm"):
is_rocm = True
rocm_version = parse_arg_remove_string(sys.argv, "--rocm_version=")
Expand Down Expand Up @@ -705,11 +712,22 @@ def reformat_run_count(count_str):
version_number = version_number + local_version
if is_rocm and enable_rocm_profiling:
version_number = version_number + ".profiling"

extras_require = {}
if wheel_name_suffix:
if not (enable_training and wheel_name_suffix == "gpu"):
# for training packages, local version is used to indicate device types
package_name = f"{package_name}-{wheel_name_suffix}"
if wheel_name_suffix == "gpu" and cuda_version_major is not None:
extras_require = {
# Optional 'cuda_dlls' dependencies
"cuda_dlls": [
f"nvidia-cuda-nvrtc-cu{cuda_version_major}",
f"nvidia-cuda-runtime-cu{cuda_version_major}",
f"nvidia-cudnn-cu{cuda_version_major}",
tianleiwu marked this conversation as resolved.
Show resolved Hide resolved
f"nvidia-cufft-cu{cuda_version_major}",
f"nvidia-curand-cu{cuda_version_major}",
]
}

cmd_classes = {}
if bdist_wheel is not None:
Expand Down Expand Up @@ -783,4 +801,5 @@ def save_build_and_package_info(package_name, version_number, cuda_version, rocm
]
},
classifiers=classifiers,
extras_require=extras_require,
)
Loading