Skip to content

Commit

Permalink
change cuda_version() to cuda version
Browse files Browse the repository at this point in the history
  • Loading branch information
jchen351 committed Dec 19, 2024
1 parent cf612fc commit d0ffbaf
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions onnxruntime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,15 @@ def check_and_load_cuda_libs(root_directory, cuda_libs_):
cuda_libs = ()
if platform.system() == "Windows": #
# Define the list of DLL patterns, nvrtc, curand and nvJitLink are not included for Windows
if (11, 0) <= cuda_version() < (12, 0):
if (11, 0) <= cuda_version < (12, 0):
cuda_libs = (
"cublaslt64_11.dll",
"cublas64_11.dll",
"cufft64_10.dll",
"cudart64_11.dll",
"cudnn64_8.dll",
)
elif (12, 0) <= cuda_version() < (13, 0):
elif (12, 0) <= cuda_version < (13, 0):
cuda_libs = (
"cublaslt64_12.dll",
"cublas64_12.dll",
Expand All @@ -150,7 +150,7 @@ def check_and_load_cuda_libs(root_directory, cuda_libs_):
"cudnn64_9.dll",
)
elif platform.system() == "Linux":
if (11, 0) <= cuda_version() < (12, 0):
if (11, 0) <= cuda_version < (12, 0):
# Define the patterns with optional version number and case-insensitivity
cuda_libs = (
"libcublaslt.so.11",
Expand All @@ -161,7 +161,7 @@ def check_and_load_cuda_libs(root_directory, cuda_libs_):
"libcudnn.so.8",
"libnvrtc.so.11.2", # This is not a mistake, it links to more specific version like libnvrtc.so.11.8.89 etc.
)
elif (12, 0) <= cuda_version() < (13, 0):
elif (12, 0) <= cuda_version < (13, 0):
cuda_libs = (
"libcublaslt.so.12",
"libcublas.so.12",
Expand Down

0 comments on commit d0ffbaf

Please sign in to comment.