Share one default list of CUDA compute capabilities across op_builder and setup.py.

Adds DEFAULT_COMPUTE_CAPABILITIES and get_default_compute_capatabilities() (which appends
";8.0" when the installed CUDA major version is >= 11), splits installed_cuda_version() out
of assert_no_cuda_mismatch(), and makes setup.py and
CUDAOpBuilder.compute_capability_args() use the shared default.

NOTE(review): the setup.py warning text still lists "6.0, 6.1, 6.2" although the value it
now installs is "6.0;6.1;7.0" (plus ";8.0" on CUDA >= 11).
NOTE(review): the builder.py "index" line carries the blob ids of the patch as first
written; its post-image id predates the changes made to the added lines here.

diff --git a/op_builder/__init__.py b/op_builder/__init__.py
index 40ebebc5685f..aceced8cedef 100644
--- a/op_builder/__init__.py
+++ b/op_builder/__init__.py
@@ -5,6 +5,7 @@
 from .transformer import TransformerBuilder
 from .stochastic_transformer import StochasticTransformerBuilder
 from .utils import UtilsBuilder
+from .builder import get_default_compute_capatabilities
 
 # TODO: infer this list instead of hard coded
 # List of all available ops
diff --git a/op_builder/builder.py b/op_builder/builder.py
index c1116fad007a..13b5a4b046e7 100644
--- a/op_builder/builder.py
+++ b/op_builder/builder.py
@@ -11,9 +11,11 @@
 WARNING = f"{YELLOW} [WARNING] {END}"
 
 DEFAULT_TORCH_EXTENSION_PATH = "/tmp/torch_extensions"
+DEFAULT_COMPUTE_CAPABILITIES = "6.0;6.1;7.0"
 
 
-def assert_no_cuda_mismatch():
+def installed_cuda_version():
+    """Return the installed CUDA version as a ``(major, minor)`` tuple of ints."""
     import torch.utils.cpp_extension
     cuda_home = torch.utils.cpp_extension.CUDA_HOME
     assert cuda_home is not None, "CUDA_HOME does not exist, unable to compile CUDA op(s)"
@@ -25,12 +27,31 @@ def assert_no_cuda_mismatch():
     release_idx = output_split.index("release")
     release = output_split[release_idx + 1].replace(',', '').split(".")
     # Ignore patch versions, only look at major + minor
-    installed_cuda_version = ".".join(release[:2])
+    cuda_major, cuda_minor = release[:2]
+    return int(cuda_major), int(cuda_minor)
+
+
+def get_default_compute_capatabilities():
+    """Return the default ';'-separated compute capabilities to build for."""
+    compute_caps = DEFAULT_COMPUTE_CAPABILITIES
+    # 8.0 is only appended when the installed CUDA major version is >= 11
+    if installed_cuda_version()[0] >= 11:
+        compute_caps += ";8.0"
+    return compute_caps
+
+
+def assert_no_cuda_mismatch():
+    """Raise if the installed CUDA version differs from the one torch was compiled with."""
+    # `torch` used to be bound here by the local import that now lives in
+    # installed_cuda_version(); import it explicitly so torch.version resolves.
+    import torch
+    cuda_major, cuda_minor = installed_cuda_version()
+    sys_cuda_version = f'{cuda_major}.{cuda_minor}'
     torch_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
     # This is a show-stopping error, should probably not proceed past this
-    if installed_cuda_version != torch_cuda_version:
+    if sys_cuda_version != torch_cuda_version:
         raise Exception(
-            f"Installed CUDA version {installed_cuda_version} does not match the "
+            f"Installed CUDA version {sys_cuda_version} does not match the "
             f"version torch was compiled with {torch.version.cuda}, unable to compile "
             "cuda/cpp extensions without a matching cuda version.")
 
@@ -197,7 +218,19 @@ def jit_load(self, verbose=True):
 
 
 class CUDAOpBuilder(OpBuilder):
-    def compute_capability_args(self, cross_compile_archs=['60', '61', '70']):
+    def compute_capability_args(self, cross_compile_archs=None):
+        """Build the ``-gencode`` compiler arguments.
+
+        The cross-compile branch reads ``cross_compile_archs``, which may be a
+        ';'-separated string such as "6.0;6.1;7.0" or a list such as ['60', '61', '70'].
+        When it is None, get_default_compute_capatabilities() supplies the value.
+        """
+        if cross_compile_archs is None:
+            cross_compile_archs = get_default_compute_capatabilities()
+        if isinstance(cross_compile_archs, str):
+            # Lists are still accepted, as with the previous list-valued default
+            cross_compile_archs = cross_compile_archs.split(';')
+
         args = []
         if self.jit_mode:
             # Compile for underlying architecture since we know it at runtime
@@ -208,7 +241,8 @@ def compute_capability_args(self, cross_compile_archs=['60', '61', '70']):
                 f'arch=compute_{compute_capability},code=compute_{compute_capability}')
         else:
             # Cross-compile mode, compile for various architectures
             for compute_capability in cross_compile_archs:
+                compute_capability = compute_capability.replace('.', '')
                 args.append('-gencode')
                 args.append(
                     f'arch=compute_{compute_capability},code=compute_{compute_capability}'
diff --git a/setup.py b/setup.py
index b6be8b14e370..bf2ff9813537 100755
--- a/setup.py
+++ b/setup.py
@@ -21,7 +21,7 @@
     raise ImportError('Unable to import torch, please visit https://pytorch.org/ '
                       'to see how to properly install torch on your system.')
 
-import op_builder
+from op_builder import ALL_OPS, get_default_compute_capatabilities
 
 
 def fetch_requirements(path):
@@ -64,12 +64,10 @@ def fetch_requirements(path):
         "you can ignore this message. Adding compute capability for Pascal, Volta, and Turing "
         "(compute capabilities 6.0, 6.1, 6.2)")
     if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
-        os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
+        os.environ["TORCH_CUDA_ARCH_LIST"] = get_default_compute_capatabilities()
 
 ext_modules = []
 
-from op_builder import ALL_OPS
-
 # Default to pre-install kernels to false so we rely on JIT
 BUILD_OP_DEFAULT = int(os.environ.get('DS_BUILD_OPS', 0))
 print(f"DS_BUILD_OPS={BUILD_OP_DEFAULT}")