Skip to content

Commit

Permalink
Add compute capability 8.0 if on cuda 11+ (microsoft#572)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffra authored Dec 3, 2020
1 parent 2d1f7c0 commit be33bea
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 9 deletions.
1 change: 1 addition & 0 deletions op_builder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .transformer import TransformerBuilder
from .stochastic_transformer import StochasticTransformerBuilder
from .utils import UtilsBuilder
from .builder import get_default_compute_capatabilities

# TODO: infer this list instead of hard coded
# List of all available ops
Expand Down
29 changes: 24 additions & 5 deletions op_builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
WARNING = f"{YELLOW} [WARNING] {END}"

DEFAULT_TORCH_EXTENSION_PATH = "/tmp/torch_extensions"
DEFAULT_COMPUTE_CAPABILITIES = "6.0;6.1;7.0"


def assert_no_cuda_mismatch():
def installed_cuda_version():
import torch.utils.cpp_extension
cuda_home = torch.utils.cpp_extension.CUDA_HOME
assert cuda_home is not None, "CUDA_HOME does not exist, unable to compile CUDA op(s)"
Expand All @@ -25,12 +26,26 @@ def assert_no_cuda_mismatch():
release_idx = output_split.index("release")
release = output_split[release_idx + 1].replace(',', '').split(".")
# Ignore patch versions, only look at major + minor
cuda_major, cuda_minor = release[:2]
installed_cuda_version = ".".join(release[:2])
return int(cuda_major), int(cuda_minor)


def get_default_compute_capatabilities():
    """Return the default NVCC compute-capability list as a ';'-joined string.

    Starts from ``DEFAULT_COMPUTE_CAPABILITIES`` ("6.0;6.1;7.0") and appends
    Ampere ("8.0") when the locally installed CUDA toolkit reports major
    version 11 or newer, since 8.0 is only supported by CUDA 11+.
    """
    cuda_major, _ = installed_cuda_version()
    if cuda_major >= 11:
        return DEFAULT_COMPUTE_CAPABILITIES + ";8.0"
    return DEFAULT_COMPUTE_CAPABILITIES


def assert_no_cuda_mismatch():
    """Raise if the system CUDA toolkit and torch's CUDA version disagree.

    Compares the major.minor version reported by the installed CUDA toolkit
    (via ``installed_cuda_version()``) against the major.minor CUDA version
    torch was compiled with (``torch.version.cuda``). A mismatch makes
    compiled cuda/cpp extensions unusable, so this raises instead of
    proceeding.

    Raises:
        Exception: when the two major.minor versions differ.
    """
    cuda_major, cuda_minor = installed_cuda_version()
    sys_cuda_version = f'{cuda_major}.{cuda_minor}'
    # Patch versions are ignored on both sides; only major.minor must match.
    torch_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
    # This is a show-stopping error, should probably not proceed past this
    if sys_cuda_version != torch_cuda_version:
        raise Exception(
            f"Installed CUDA version {sys_cuda_version} does not match the "
            f"version torch was compiled with {torch.version.cuda}, unable to compile "
            "cuda/cpp extensions without a matching cuda version.")

Expand Down Expand Up @@ -197,7 +212,10 @@ def jit_load(self, verbose=True):


class CUDAOpBuilder(OpBuilder):
def compute_capability_args(self, cross_compile_archs=['60', '61', '70']):
def compute_capability_args(self, cross_compile_archs=None):
if cross_compile_archs is None:
cross_compile_archs = get_default_compute_capatabilities()

args = []
if self.jit_mode:
# Compile for underlying architecture since we know it at runtime
Expand All @@ -208,7 +226,8 @@ def compute_capability_args(self, cross_compile_archs=['60', '61', '70']):
f'arch=compute_{compute_capability},code=compute_{compute_capability}')
else:
# Cross-compile mode, compile for various architectures
for compute_capability in cross_compile_archs:
for compute_capability in cross_compile_archs.split(';'):
compute_capability = compute_capability.replace('.', '')
args.append('-gencode')
args.append(
f'arch=compute_{compute_capability},code=compute_{compute_capability}'
Expand Down
6 changes: 2 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
raise ImportError('Unable to import torch, please visit https://pytorch.org/ '
'to see how to properly install torch on your system.')

import op_builder
from op_builder import ALL_OPS, get_default_compute_capatabilities


def fetch_requirements(path):
Expand Down Expand Up @@ -64,12 +64,10 @@ def fetch_requirements(path):
"you can ignore this message. Adding compute capability for Pascal, Volta, and Turing "
"(compute capabilities 6.0, 6.1, 6.2)")
if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
os.environ["TORCH_CUDA_ARCH_LIST"] = get_default_compute_capatabilities()

ext_modules = []

from op_builder import ALL_OPS

# Default to pre-install kernels to false so we rely on JIT
BUILD_OP_DEFAULT = int(os.environ.get('DS_BUILD_OPS', 0))
print(f"DS_BUILD_OPS={BUILD_OP_DEFAULT}")
Expand Down

0 comments on commit be33bea

Please sign in to comment.