diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 894e31a97a9..e830d764a79 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -462,19 +462,29 @@ def prepare_environment(): "TORCH_COMMAND", f"pip install torch==2.3.0 torchvision --index-url {torch_index_url}", ) + error = None + from modules import zluda_installer try: - from modules import zluda_installer if args.use_zluda_dnn: if zluda_installer.check_dnn_dependency(): zluda_installer.enable_dnn() else: print("Couldn't find the required dependency of ZLUDA DNN.") zluda_installer.install() - zluda_installer.resolve_path() - torch_command = os.environ.get('TORCH_COMMAND', 'pip install torch==2.3.0 torchvision --index-url https://download.pytorch.org/whl/cu118') - print(f'Using ZLUDA in {zluda_installer.ZLUDA_PATH}') + zluda_path = zluda_installer.find() + zluda_installer.make_copy(zluda_path) except Exception as e: + error = e print(f'Failed to install ZLUDA: {e}') + if error is None: + try: + zluda_installer.load(zluda_path) + torch_command = os.environ.get('TORCH_COMMAND', 'pip install torch==2.3.0 torchvision --index-url https://download.pytorch.org/whl/cu118') + print(f'Using ZLUDA in {zluda_path}') + except Exception as e: + error = e + print(f'Failed to load ZLUDA: {e}') + if error is not None: print('Using CPU-only torch') torch_command = os.environ.get('TORCH_COMMAND', 'pip install torch torchvision') elif args.use_ipex: @@ -575,12 +585,6 @@ def prepare_environment(): if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"): run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True) startup_timer.record("install torch") - if args.use_zluda: - try: - from modules.zluda_installer import patch as patch_torch - patch_torch() - except Exception as e: - print(f'ZLUDA: failed to automatically patch torch: {e}') if args.use_ipex or args.use_directml or args.use_cpu_torch: args.skip_torch_cuda_test = True diff --git a/modules/zluda_installer.py b/modules/zluda_installer.py index 5da5588789b..0cbe5442d36 100644 --- a/modules/zluda_installer.py +++ b/modules/zluda_installer.py @@ -1,51 +1,23 @@ import os +import ctypes import shutil import zipfile -import tarfile import platform import urllib.request -RELEASE = 'rel.2804604c29b5fa36deca9ece219d3970b61d4c27' -TARGETS = { +RELEASE = f"rel.{os.environ.get('ZLUDA_HASH', '2804604c29b5fa36deca9ece219d3970b61d4c27')}" +DLL_MAPPING = { 'cublas.dll': 'cublas64_11.dll', 'cusparse.dll': 'cusparse64_11.dll', 'nvrtc.dll': 'nvrtc64_112_0.dll', } -ZLUDA_PATH = None -TORCHLIB_PATH = None +HIP_TARGETS = ('rocblas.dll', 'rocsolver.dll', 'hiprtc0507.dll',) +ZLUDA_TARGETS = ('nvcuda.dll', 'nvml.dll',) -def find_zluda_path(): - zluda_path = os.environ.get('ZLUDA', None) - if zluda_path is None: - paths = os.environ.get('PATH', '').split(';') - for path in paths: - if os.path.exists(os.path.join(path, 'zluda_redirect.dll')): - zluda_path = path - break - return zluda_path - - -def find_venv_dir(): - python_dir = os.path.dirname(shutil.which('python')) - if shutil.which('conda') is None: - python_dir = os.path.dirname(python_dir) - return os.environ.get('VENV_DIR', python_dir) - - -def reset_torch(): - for dll in TARGETS.values(): - path = os.path.join(TORCHLIB_PATH, dll) - if os.path.exists(path): - os.remove(path) - - -def is_patched(): - for dll in TARGETS.values(): - if not os.path.islink(os.path.join(TORCHLIB_PATH, dll)): - return False - return True +def find(): + return os.path.abspath(os.environ.get('ZLUDA', '.zluda')) def check_dnn_dependency(): @@ -59,36 +31,48 @@ def check_dnn_dependency(): def enable_dnn(): global RELEASE # pylint: disable=global-statement - TARGETS['cudnn.dll'] = 'cudnn64_8.dll' + DLL_MAPPING['cudnn.dll'] = 'cudnn64_8.dll' RELEASE = 'v3.8-pre2-dnn' def install(): - global ZLUDA_PATH, TORCHLIB_PATH # pylint: disable=global-statement - ZLUDA_PATH = find_zluda_path() - TORCHLIB_PATH = os.path.join(find_venv_dir(), 'Lib', 'site-packages', 'torch', 'lib') + zluda_path = find() - if ZLUDA_PATH is not None: + if os.path.exists(zluda_path): return - is_windows = platform.system() == 'Windows' - archive_type = zipfile.ZipFile if is_windows else tarfile.TarFile - urllib.request.urlretrieve(f'https://github.com/lshqqytiger/ZLUDA/releases/download/{RELEASE}/ZLUDA-{platform.system().lower()}-amd64.{"zip" if is_windows else "tar.gz"}', '_zluda') - with archive_type('_zluda', 'r') as f: - f.extractall('.zluda') - ZLUDA_PATH = os.path.abspath('./.zluda') - os.remove('_zluda') - + if platform.system() != 'Windows': # TODO + return -def resolve_path(): - paths = os.environ.get('PATH', '.') - if ZLUDA_PATH not in paths: - os.environ['PATH'] = paths + ';' + ZLUDA_PATH + urllib.request.urlretrieve(f'https://github.com/lshqqytiger/ZLUDA/releases/download/{RELEASE}/ZLUDA-windows-amd64.zip', '_zluda') + with zipfile.ZipFile('_zluda', 'r') as archive: + infos = archive.infolist() + for info in infos: + if not info.is_dir(): + info.filename = os.path.basename(info.filename) + archive.extract(info, '.zluda') + os.remove('_zluda') -def patch(): - if is_patched(): - return - reset_torch() - for k, v in TARGETS.items(): - os.symlink(os.path.join(ZLUDA_PATH, k), os.path.join(TORCHLIB_PATH, v)) +def make_copy(zluda_path: os.PathLike): + for k, v in DLL_MAPPING.items(): + if not os.path.exists(os.path.join(zluda_path, v)): + try: + os.link(os.path.join(zluda_path, k), os.path.join(zluda_path, v)) + except Exception: + shutil.copyfile(os.path.join(zluda_path, k), os.path.join(zluda_path, v)) + + +def load(zluda_path: os.PathLike): + hip_path_default = r'C:\Program Files\AMD\ROCm\5.7' + if not os.path.exists(hip_path_default): + hip_path_default = None + hip_path = os.environ.get('HIP_PATH', hip_path_default) + if hip_path is None: + raise RuntimeError('Could not find %HIP_PATH%. Please install AMD HIP SDK.') + for v in HIP_TARGETS: + ctypes.windll.LoadLibrary(os.path.join(hip_path, 'bin', v)) + for v in ZLUDA_TARGETS: + ctypes.windll.LoadLibrary(os.path.join(zluda_path, v)) + for v in DLL_MAPPING.values(): + ctypes.windll.LoadLibrary(os.path.join(zluda_path, v))