Skip to content

Commit

Permalink
Update save_build_and_package_info to allow it to be used in a non-training package
Browse files Browse the repository at this point in the history
  • Loading branch information
jchen351 committed Dec 19, 2024
1 parent d0ffbaf commit ad0cf6b
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 43 deletions.
28 changes: 12 additions & 16 deletions onnxruntime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def check_and_load_cuda_libs(root_directory, cuda_libs_):
# Convert the target library names to lowercase for case-insensitive comparison
# Convert the target library names to lowercase for case-insensitive comparison
if cuda_libs_ is None or len(cuda_libs_) == 0:
logging.debug("No CUDA libraries provided for loading.")
logging.info("No CUDA libraries provided for loading.")
return
cuda_libs_ = {lib.lower() for lib in cuda_libs_}
found_libs = {}
Expand All @@ -100,48 +100,43 @@ def check_and_load_cuda_libs(root_directory, cuda_libs_):
_ = ctypes.CDLL(full_path)
logging.info(f"Successfully loaded: {full_path}")
except OSError as e:
logging.debug(f"Failed to load {full_path}: {e}")
logging.info(f"Failed to load {full_path}: {e}")

# If all required libraries are found, stop the search
if set(found_libs.keys()) == cuda_libs_:
logging.info("All required CUDA libraries found and loaded.")
return
logging.debug(
logging.info(
f"Failed to load CUDA libraries from site-packages/nvidia directory: {cuda_libs_ - found_libs.keys()}. They might be loaded later from standard search paths for shared libraries."
)
return


# Load nvidia libraries from site-packages/nvidia if the package is onnxruntime-gpu
if (
__package__ == "onnxruntime-gpu"
# Just in case we rename the package name in the future
or __package__ == "onnxruntime-cuda"
or __package__ == "onnxruntime_gpu"
or __package__ == "onnxruntime_cuda"
):
if cuda_version is not None and cuda_version != "":
import ctypes
import logging
import os
import platform
import site

cuda_version_ = tuple(map(int, cuda_version.split(".")))
# Get the site-packages path where nvidia packages are installed
site_packages_path = site.getsitepackages()[-1]
nvidia_path = os.path.join(site_packages_path, "nvidia")
# Traverse the directory and subdirectories
cuda_libs = ()
if platform.system() == "Windows": #
# Define the list of DLL patterns, nvrtc, curand and nvJitLink are not included for Windows
if (11, 0) <= cuda_version < (12, 0):
if (11, 0) <= cuda_version_ < (12, 0):
cuda_libs = (
"cublaslt64_11.dll",
"cublas64_11.dll",
"cufft64_10.dll",
"cudart64_11.dll",
"cudnn64_8.dll",
)
elif (12, 0) <= cuda_version < (13, 0):
elif (12, 0) <= cuda_version_ < (13, 0):
cuda_libs = (
"cublaslt64_12.dll",
"cublas64_12.dll",
Expand All @@ -150,7 +145,7 @@ def check_and_load_cuda_libs(root_directory, cuda_libs_):
"cudnn64_9.dll",
)
elif platform.system() == "Linux":
if (11, 0) <= cuda_version < (12, 0):
if (11, 0) <= cuda_version_ < (12, 0):
# Define the patterns with optional version number and case-insensitivity
cuda_libs = (
"libcublaslt.so.11",
Expand All @@ -159,9 +154,10 @@ def check_and_load_cuda_libs(root_directory, cuda_libs_):
"libcufft.so.10",
"libcudart.so.11",
"libcudnn.so.8",
"libnvrtc.so.11.2", # This is not a mistake, it links to more specific version like libnvrtc.so.11.8.89 etc.
"libnvrtc.so.11.2",
# This is not a mistake, it links to more specific version like libnvrtc.so.11.8.89 etc.
)
elif (12, 0) <= cuda_version < (13, 0):
elif (12, 0) <= cuda_version_ < (13, 0):
cuda_libs = (
"libcublaslt.so.12",
"libcublas.so.12",
Expand All @@ -172,5 +168,5 @@ def check_and_load_cuda_libs(root_directory, cuda_libs_):
"libnvrtc.so.12",
)
else:
logging.debug(f"Unsupported platform: {platform.system()}")
logging.info(f"Unsupported platform: {platform.system()}")
check_and_load_cuda_libs(nvidia_path, cuda_libs)
55 changes: 28 additions & 27 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,7 @@ def reformat_run_count(count_str):
if not (enable_training and wheel_name_suffix == "gpu"):
# for training packages, local version is used to indicate device types
package_name = f"{package_name}-{wheel_name_suffix}"
if wheel_name_suffix == "gpu" and cuda_version_major is not None:
if (wheel_name_suffix == "gpu" or wheel_name_suffix == "cuda") and cuda_version_major is not None:
extras_require = {
# Optional 'cuda_dlls' dependencies
"cuda_dlls": [
Expand Down Expand Up @@ -745,39 +745,40 @@ def reformat_run_count(count_str):
install_requires = f.read().splitlines()


if enable_training:
def save_build_and_package_info(package_name, version_number, cuda_version, rocm_version):
sys.path.append(path.join(path.dirname(__file__), "onnxruntime", "python"))
from onnxruntime_collect_build_info import find_cudart_versions

def save_build_and_package_info(package_name, version_number, cuda_version, rocm_version):
sys.path.append(path.join(path.dirname(__file__), "onnxruntime", "python"))
from onnxruntime_collect_build_info import find_cudart_versions
version_path = path.join("onnxruntime", "capi", "build_and_package_info.py")
with open(version_path, "w") as f:
f.write(f"package_name = '{package_name}'\n")
f.write(f"__version__ = '{version_number}'\n")

version_path = path.join("onnxruntime", "capi", "build_and_package_info.py")
with open(version_path, "w") as f:
f.write(f"package_name = '{package_name}'\n")
f.write(f"__version__ = '{version_number}'\n")
if cuda_version:
f.write(f"cuda_version = '{cuda_version}'\n")

if cuda_version:
f.write(f"cuda_version = '{cuda_version}'\n")
# cudart_versions are integers
cudart_versions = find_cudart_versions(build_env=True)
if cudart_versions and len(cudart_versions) == 1:
f.write(f"cudart_version = {cudart_versions[0]}\n")
else:
print(
"Error getting cudart version. ",
(
"did not find any cudart library"
if not cudart_versions or len(cudart_versions) == 0
else "found multiple cudart libraries"
),
)
elif rocm_version:
f.write(f"rocm_version = '{rocm_version}'\n")

# cudart_versions are integers
cudart_versions = find_cudart_versions(build_env=True)
if cudart_versions and len(cudart_versions) == 1:
f.write(f"cudart_version = {cudart_versions[0]}\n")
else:
print(
"Error getting cudart version. ",
(
"did not find any cudart library"
if not cudart_versions or len(cudart_versions) == 0
else "found multiple cudart libraries"
),
)
elif rocm_version:
f.write(f"rocm_version = '{rocm_version}'\n")

if enable_training:
save_build_and_package_info(package_name, version_number, cuda_version, rocm_version)
else:
save_build_and_package_info(package_name, version_number, cuda_version, None)

# Setup
setup(
name=package_name,
version=version_number,
Expand Down

0 comments on commit ad0cf6b

Please sign in to comment.