forked from NVIDIA/apex
-
Notifications
You must be signed in to change notification settings - Fork 0
/
setup.py
144 lines (119 loc) · 5.23 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import torch.cuda
import ctypes.util
import os
import re
import subprocess
from setuptools import setup, find_packages
from distutils.command.clean import clean
from torch.utils.cpp_extension import CUDAExtension, CUDA_HOME
# TODO: multiple modules, so we don't have to route all interfaces through
# the same interface.cpp file?
# A missing GPU is only worth a warning: the build may be a cross-compile.
if not torch.cuda.is_available():
    print("Warning: Torch did not find available GPUs on this system.\n",
          "If your intention is to cross-compile, this is not an error.")

print("torch.__version__ = ", torch.__version__)

# Only the first two dot-separated fields of the version string are used.
_torch_version_fields = torch.__version__.split('.')
TORCH_MAJOR = int(_torch_version_fields[0])
TORCH_MINOR = int(_torch_version_fields[1])

# Anything older than 0.4 is rejected outright.
if (TORCH_MAJOR, TORCH_MINOR) < (0, 4):
    raise RuntimeError(
        "APEx requires Pytorch 0.4 or newer.\n"
        "The latest stable release can be obtained from https://pytorch.org/")

# Exactly 0.4 gets an extra preprocessor define passed to both compilers.
version_le_04 = ['-DVERSION_LE_04'] if (TORCH_MAJOR, TORCH_MINOR) == (0, 4) else []
def find(path, regex_func, collect=False):
    """Search the directory tree under ``path`` for files with a matching name.

    Args:
        path: Root directory handed to ``os.walk``.
        regex_func: Callable applied to each bare file name (not the full
            path); a truthy result counts as a match, e.g. the ``search``
            method of a compiled regular expression.
        collect: If true, gather every match; otherwise stop at the first one.

    Returns:
        With ``collect`` false: the joined path of the first match in
        ``os.walk`` order, or ``None`` when nothing matches.
        With ``collect`` true: a de-duplicated list of all matching paths in
        no particular order (empty when nothing matches).
    """
    matches = set()
    for root, _dirs, files in os.walk(path):
        for name in files:
            if not regex_func(name):
                continue
            full_path = os.path.join(root, name)
            if not collect:
                return full_path
            matches.add(full_path)
    # In first-match mode, reaching this point means nothing matched. Report
    # that as None so callers can test the result for truthiness.
    return list(matches) if collect else None
# Due to https://github.com/pytorch/pytorch/issues/8223, for Pytorch <= 0.4
# torch.utils.cpp_extension's check for CUDA_HOME fails if there are no GPUs
# available on the system, which prevents cross-compiling and building via Dockerfiles.
# Workaround: manually search for CUDA_HOME if Pytorch <= 0.4.
def find_cuda_home():
    """Locate the CUDA toolkit base directory without relying on PyTorch.

    Candidates are tried in this order:

    1. The ``CUDA_HOME`` environment variable, defaulting to
       ``/usr/local/cuda``, if that path exists.
    2. The directory two levels above the ``cudart`` library reported by
       ``ctypes.util.find_library``, when the reported value carries a
       directory component.
    3. The directory two levels above an ``nvcc`` file found under
       ``/usr/local/``, provided it contains ``lib64`` and ``include``.

    Returns:
        The CUDA base directory as a string.

    Raises:
        RuntimeError: If no ``nvcc`` can be found, or if one is found but the
            directory above its ``bin`` folder lacks ``lib64`` or ``include``.
    """
    cuda_home = os.getenv('CUDA_HOME', '/usr/local/cuda')
    if os.path.exists(cuda_home):
        return cuda_home

    # We use nvcc path on Linux and cudart path on macOS
    cudart_path = ctypes.util.find_library('cudart')
    if cudart_path is not None:
        # The library sits one level below the base directory. A bare file
        # name has no directory part, so this yields '' and we fall through.
        cuda_home = os.path.dirname(os.path.dirname(cudart_path))
        if cuda_home:
            return cuda_home

    # Collect all matches and pick the lexicographically first one so the
    # choice is deterministic. The anchored pattern skips files that merely
    # start with "nvcc" (e.g. "nvcc.profile").
    nvcc_candidates = find('/usr/local/', re.compile("nvcc$").search, True)
    if not nvcc_candidates:
        raise RuntimeError("Error: Could not find cuda on this system. " +
                           "Please set your CUDA_HOME enviornment variable "
                           "to the CUDA base directory.")
    nvcc_path = min(nvcc_candidates)

    # nvcc is expected at <base>/bin/nvcc, so the base is two levels up.
    cuda_home = os.path.dirname(os.path.dirname(nvcc_path))
    if (not os.path.exists(os.path.join(cuda_home, "lib64"))
            or not os.path.exists(os.path.join(cuda_home, "include"))):
        raise RuntimeError("Error: found NVCC at " + nvcc_path +
                           " but could not locate CUDA libraries" +
                           " or include directories.")
    return cuda_home
# On PyTorch 0.4, fall back to the manual search when cpp_extension did not
# supply a CUDA_HOME.
if (TORCH_MAJOR, TORCH_MINOR) == (0, 4) and CUDA_HOME is None:
    CUDA_HOME = find_cuda_home()
    # Patch cpp_extension's view of CUDA_HOME:
    torch.utils.cpp_extension.CUDA_HOME = CUDA_HOME
def get_cuda_version():
    """Return the major version of the CUDA toolkit rooted at ``CUDA_HOME``.

    Looks for a file named ``nvcc`` below ``CUDA_HOME/bin``, runs it with
    ``--version`` and parses the ``V<major>.<minor>.<patch>`` tag from its
    output. The values found are printed along the way.

    Returns:
        The CUDA major version as an int.

    Raises:
        RuntimeError: If nvcc cannot be found, its output contains no
            version tag, or the major version is below 8.
        subprocess.CalledProcessError: If nvcc exits with a non-zero status.
    """
    NVCC = find(os.path.join(CUDA_HOME, "bin"),
                re.compile('nvcc$').search)
    if not NVCC:
        raise RuntimeError("Could not find nvcc under " +
                           os.path.join(CUDA_HOME, "bin"))
    print("Found NVCC = ", NVCC)

    # Parse output of nvcc to get cuda major version
    nvcc_output = subprocess.check_output([NVCC, '--version']).decode("utf-8")
    # Raw string: "\." is an invalid escape sequence in a plain literal.
    release = re.search(r', V([0-9]+\.[0-9]+\.[0-9]+)', nvcc_output)
    if release is None:
        raise RuntimeError("Could not parse the CUDA version from the " +
                           "output of " + NVCC + " --version")
    CUDA_LIB = release.group(1)
    print("Found CUDA_LIB = ", CUDA_LIB)

    CUDA_MAJOR = int(CUDA_LIB.split('.')[0])
    print("Found CUDA_MAJOR = ", CUDA_MAJOR)

    if CUDA_MAJOR < 8:
        raise RuntimeError("APex requires CUDA 8.0 or newer")

    return CUDA_MAJOR
# Without a CUDA install there is nothing to compile, so bail out first.
if CUDA_HOME is None:
    raise RuntimeError("Could not find Cuda install directory")

print("Found CUDA_HOME = ", CUDA_HOME)
CUDA_MAJOR = get_cuda_version()

# -gencode flag pairs that are always passed to nvcc.
gencodes = [
    '-gencode', 'arch=compute_52,code=sm_52',
    '-gencode', 'arch=compute_60,code=sm_60',
    '-gencode', 'arch=compute_61,code=sm_61',
]
# Toolkits newer than CUDA 8 also get the two compute_70 entries.
if CUDA_MAJOR > 8:
    gencodes.extend([
        '-gencode', 'arch=compute_70,code=sm_70',
        '-gencode', 'arch=compute_70,code=compute_70',
    ])

# The single extension module: all C++/CUDA sources are built into apex_C.
extension = CUDAExtension(
    'apex_C',
    [
        'csrc/interface.cpp',
        'csrc/weight_norm_fwd_cuda.cu',
        'csrc/weight_norm_bwd_cuda.cu',
        'csrc/scale_cuda.cu',
    ],
    extra_compile_args={
        'cxx': ['-g'] + version_le_04,
        'nvcc': ['-O3'] + version_le_04 + gencodes,
    },
)
ext_modules = [extension]
# Register the package; BuildExtension from cpp_extension drives build_ext.
setup(
    name='apex',
    version='0.1',
    description='PyTorch Extensions written by NVIDIA',
    # Directory names listed here are left out of the discovered packages.
    packages=find_packages(exclude=(
        'build', 'csrc', 'include', 'tests', 'dist',
        'docs', 'examples', 'apex.egg-info',
    )),
    ext_modules=ext_modules,
    cmdclass={'build_ext': torch.utils.cpp_extension.BuildExtension},
)