Skip to content

Commit

Permalink
filter torch cuda arch < 6.0 (#955)
Browse files Browse the repository at this point in the history
* filter arch < 6.0

* remove unused codes
  • Loading branch information
CSY-ModelCloud authored Dec 23, 2024
1 parent 748a9c7 commit 6152c90
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@

TORCH_CUDA_ARCH_LIST = os.environ.get("TORCH_CUDA_ARCH_LIST")

if TORCH_CUDA_ARCH_LIST:
arch_list = [arch for arch in TORCH_CUDA_ARCH_LIST.split() if float(arch.split('+')[0]) >= 6.0]
os.environ["TORCH_CUDA_ARCH_LIST"] = " ".join(arch_list)

version_vars = {}
exec("exec(open('gptqmodel/version.py').read()); version=__version__", {}, version_vars)
gptqmodel_version = version_vars['version']
Expand Down Expand Up @@ -109,6 +113,7 @@ def get_version_tag(is_cuda_release: bool = True) -> str:
if got_cuda_between_v6_and_v8:
FORCE_BUILD = True


if BUILD_CUDA_EXT:
if CUDA_RELEASE == "1":
common_setup_kwargs["version"] += f"+{get_version_tag(True)}"
Expand Down

0 comments on commit 6152c90

Please sign in to comment.