Skip to content

Commit

Permalink
Rewrite ZLUDA installer.
Browse files Browse the repository at this point in the history
  • Loading branch information
lshqqytiger committed Apr 30, 2024
1 parent 65588fc commit 89e60b8
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 68 deletions.
24 changes: 14 additions & 10 deletions modules/launch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,19 +462,29 @@ def prepare_environment():
"TORCH_COMMAND",
f"pip install torch==2.3.0 torchvision --index-url {torch_index_url}",
)
error = None
from modules import zluda_installer
try:
from modules import zluda_installer
if args.use_zluda_dnn:
if zluda_installer.check_dnn_dependency():
zluda_installer.enable_dnn()
else:
print("Couldn't find the required dependency of ZLUDA DNN.")
zluda_installer.install()
zluda_installer.resolve_path()
torch_command = os.environ.get('TORCH_COMMAND', 'pip install torch==2.3.0 torchvision --index-url https://download.pytorch.org/whl/cu118')
print(f'Using ZLUDA in {zluda_installer.ZLUDA_PATH}')
zluda_path = zluda_installer.find()
zluda_installer.make_copy(zluda_path)
except Exception as e:
error = e
print(f'Failed to install ZLUDA: {e}')
if error is None:
try:
zluda_installer.load(zluda_path)
torch_command = os.environ.get('TORCH_COMMAND', 'pip install torch==2.3.0 torchvision --index-url https://download.pytorch.org/whl/cu118')
print(f'Using ZLUDA in {zluda_path}')
except Exception as e:
error = e
print(f'Failed to load ZLUDA: {e}')
if error is not None:
print('Using CPU-only torch')
torch_command = os.environ.get('TORCH_COMMAND', 'pip install torch torchvision')
elif args.use_ipex:
Expand Down Expand Up @@ -575,12 +585,6 @@ def prepare_environment():
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
startup_timer.record("install torch")
if args.use_zluda:
try:
from modules.zluda_installer import patch as patch_torch
patch_torch()
except Exception as e:
print(f'ZLUDA: failed to automatically patch torch: {e}')

if args.use_ipex or args.use_directml or args.use_cpu_torch:
args.skip_torch_cuda_test = True
Expand Down
100 changes: 42 additions & 58 deletions modules/zluda_installer.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,23 @@
import os
import ctypes
import shutil
import zipfile
import tarfile
import platform
import urllib.request


RELEASE = 'rel.2804604c29b5fa36deca9ece219d3970b61d4c27'
TARGETS = {
RELEASE = f"rel.{os.environ.get('ZLUDA_HASH', '2804604c29b5fa36deca9ece219d3970b61d4c27')}"
DLL_MAPPING = {
'cublas.dll': 'cublas64_11.dll',
'cusparse.dll': 'cusparse64_11.dll',
'nvrtc.dll': 'nvrtc64_112_0.dll',
}
ZLUDA_PATH = None
TORCHLIB_PATH = None
HIP_TARGETS = ('rocblas.dll', 'rocsolver.dll', 'hiprtc0507.dll',)
ZLUDA_TARGETS = ('nvcuda.dll', 'nvml.dll',)


def find_zluda_path():
zluda_path = os.environ.get('ZLUDA', None)
if zluda_path is None:
paths = os.environ.get('PATH', '').split(';')
for path in paths:
if os.path.exists(os.path.join(path, 'zluda_redirect.dll')):
zluda_path = path
break
return zluda_path


def find_venv_dir():
python_dir = os.path.dirname(shutil.which('python'))
if shutil.which('conda') is None:
python_dir = os.path.dirname(python_dir)
return os.environ.get('VENV_DIR', python_dir)


def reset_torch():
for dll in TARGETS.values():
path = os.path.join(TORCHLIB_PATH, dll)
if os.path.exists(path):
os.remove(path)


def is_patched():
for dll in TARGETS.values():
if not os.path.islink(os.path.join(TORCHLIB_PATH, dll)):
return False
return True
def find():
return os.path.abspath(os.environ.get('ZLUDA', '.zluda'))


def check_dnn_dependency():
Expand All @@ -59,36 +31,48 @@ def check_dnn_dependency():

def enable_dnn():
global RELEASE # pylint: disable=global-statement
TARGETS['cudnn.dll'] = 'cudnn64_8.dll'
DLL_MAPPING['cudnn.dll'] = 'cudnn64_8.dll'
RELEASE = 'v3.8-pre2-dnn'


def install():
global ZLUDA_PATH, TORCHLIB_PATH # pylint: disable=global-statement
ZLUDA_PATH = find_zluda_path()
TORCHLIB_PATH = os.path.join(find_venv_dir(), 'Lib', 'site-packages', 'torch', 'lib')
zluda_path = find()

if ZLUDA_PATH is not None:
if os.path.exists(zluda_path):
return

is_windows = platform.system() == 'Windows'
archive_type = zipfile.ZipFile if is_windows else tarfile.TarFile
urllib.request.urlretrieve(f'https://github.com/lshqqytiger/ZLUDA/releases/download/{RELEASE}/ZLUDA-{platform.system().lower()}-amd64.{"zip" if is_windows else "tar.gz"}', '_zluda')
with archive_type('_zluda', 'r') as f:
f.extractall('.zluda')
ZLUDA_PATH = os.path.abspath('./.zluda')
os.remove('_zluda')

if platform.system() != 'Windows': # TODO
return

def resolve_path():
paths = os.environ.get('PATH', '.')
if ZLUDA_PATH not in paths:
os.environ['PATH'] = paths + ';' + ZLUDA_PATH
urllib.request.urlretrieve(f'https://github.com/lshqqytiger/ZLUDA/releases/download/{RELEASE}/ZLUDA-windows-amd64.zip', '_zluda')
with zipfile.ZipFile('_zluda', 'r') as archive:
infos = archive.infolist()
for info in infos:
if not info.is_dir():
info.filename = os.path.basename(info.filename)
archive.extract(info, '.zluda')
os.remove('_zluda')


def patch():
if is_patched():
return
reset_torch()
for k, v in TARGETS.items():
os.symlink(os.path.join(ZLUDA_PATH, k), os.path.join(TORCHLIB_PATH, v))
def make_copy(zluda_path: os.PathLike):
for k, v in DLL_MAPPING.items():
if not os.path.exists(os.path.join(zluda_path, v)):
try:
os.link(os.path.join(zluda_path, k), os.path.join(zluda_path, v))
except Exception:
shutil.copyfile(os.path.join(zluda_path, k), os.path.join(zluda_path, v))


def load(zluda_path: os.PathLike):
hip_path_default = r'C:\Program Files\AMD\ROCm\5.7'
if not os.path.exists(hip_path_default):
hip_path_default = None
hip_path = os.environ.get('HIP_PATH', hip_path_default)
if hip_path is None:
raise RuntimeError('Could not find %HIP_PATH%. Please install AMD HIP SDK.')
for v in HIP_TARGETS:
ctypes.windll.LoadLibrary(os.path.join(hip_path, 'bin', v))
for v in ZLUDA_TARGETS:
ctypes.windll.LoadLibrary(os.path.join(zluda_path, v))
for v in DLL_MAPPING.values():
ctypes.windll.LoadLibrary(os.path.join(zluda_path, v))

0 comments on commit 89e60b8

Please sign in to comment.