Skip to content

Commit

Permalink
[build] fix compute capability arch flags, add PTX, handle PTX (micr…
Browse files Browse the repository at this point in the history
…osoft#591)

* fix arch flags, add PTX

* bug fix

Co-authored-by: Jeff Rasley <[email protected]>
  • Loading branch information
stas00 and jeffra authored Dec 11, 2020
1 parent 0518252 commit 8a184b6
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions op_builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def compute_capability_args(self, cross_compile_archs=None):
1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`.
2. If neither is set, default compute capabilities will be used.
3. Under `jit_mode` compute capabilities of all visible cards will be used.
3. Under `jit_mode` compute capabilities of all visible cards will be used plus PTX
Format:
Expand All @@ -243,6 +243,7 @@ def compute_capability_args(self, cross_compile_archs=None):
if cc not in ccs:
ccs.append(cc)
ccs = sorted(ccs)
ccs[-1] += '+PTX'
else:
# Cross-compile mode, compile for various architectures
# env override takes priority
Expand All @@ -260,8 +261,10 @@ def compute_capability_args(self, cross_compile_archs=None):

args = []
for cc in ccs:
cc = cc.replace('.', '')
args.append(f'-gencode=arch=compute_{cc},code=compute_{cc}')
num = cc[0] + cc[2]
args.append(f'-gencode=arch=compute_{num},code=sm_{num}')
if cc.endswith('+PTX'):
args.append(f'-gencode=arch=compute_{num},code=compute_{num}')

return args

Expand Down

0 comments on commit 8a184b6

Please sign in to comment.