From d0ffbaf6096d2df01c440479d7041e576b070255 Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Wed, 18 Dec 2024 18:58:13 -0800 Subject: [PATCH] change cuda_version() to cuda version --- onnxruntime/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index 3daaf415bdf32..da2f3e14c194d 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -133,7 +133,7 @@ def check_and_load_cuda_libs(root_directory, cuda_libs_): cuda_libs = () if platform.system() == "Windows": # # Define the list of DLL patterns, nvrtc, curand and nvJitLink are not included for Windows - if (11, 0) <= cuda_version() < (12, 0): + if (11, 0) <= cuda_version < (12, 0): cuda_libs = ( "cublaslt64_11.dll", "cublas64_11.dll", @@ -141,7 +141,7 @@ def check_and_load_cuda_libs(root_directory, cuda_libs_): "cudart64_11.dll", "cudnn64_8.dll", ) - elif (12, 0) <= cuda_version() < (13, 0): + elif (12, 0) <= cuda_version < (13, 0): cuda_libs = ( "cublaslt64_12.dll", "cublas64_12.dll", @@ -150,7 +150,7 @@ def check_and_load_cuda_libs(root_directory, cuda_libs_): "cudnn64_9.dll", ) elif platform.system() == "Linux": - if (11, 0) <= cuda_version() < (12, 0): + if (11, 0) <= cuda_version < (12, 0): # Define the patterns with optional version number and case-insensitivity cuda_libs = ( "libcublaslt.so.11", @@ -161,7 +161,7 @@ def check_and_load_cuda_libs(root_directory, cuda_libs_): "libcudnn.so.8", "libnvrtc.so.11.2", # This is not a mistake, it links to more specific version like libnvrtc.so.11.8.89 etc. ) - elif (12, 0) <= cuda_version() < (13, 0): + elif (12, 0) <= cuda_version < (13, 0): cuda_libs = ( "libcublaslt.so.12", "libcublas.so.12",