From 660af62b945342276a6c813250a6fa4fcf5ff330 Mon Sep 17 00:00:00 2001 From: Danil Date: Thu, 30 Nov 2023 11:27:58 +0400 Subject: [PATCH] Added the ability to build a project with PyTorch 2.0. (#2553) * Added the ability to build a project with PyTorch 2.0. Namely, I added the flag -std=c++17 to extra_compile_args depending on the version of Torch. * Lost the condition for the presence of nvcc * Lost the condition for the presence of nvcc * Add parse_version * fix lint --------- Co-authored-by: Xin Chen --- setup.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index ddd853648d..bd6cf1dc42 100644 --- a/setup.py +++ b/setup.py @@ -1,9 +1,11 @@ import os +from pkg_resources import parse_version from setuptools import find_packages, setup EXT_TYPE = '' try: + import torch from torch.utils.cpp_extension import BuildExtension cmd_class = {'build_ext': BuildExtension} EXT_TYPE = 'torch' @@ -139,7 +141,10 @@ def get_extensions(): # to compile those cpp files, so there is no need to add the # argument if platform.system() != 'Windows': - extra_compile_args['cxx'] = ['-std=c++14'] + if parse_version(torch.__version__) <= parse_version('1.12.1'): + extra_compile_args['cxx'] = ['-std=c++14'] + else: + extra_compile_args['cxx'] = ['-std=c++17'] include_dirs = [] @@ -159,7 +164,10 @@ def get_extensions(): # to compile those cpp files, so there is no need to add the # argument if 'nvcc' in extra_compile_args and platform.system() != 'Windows': - extra_compile_args['nvcc'] += ['-std=c++14'] + if parse_version(torch.__version__) <= parse_version('1.12.1'): + extra_compile_args['nvcc'] += ['-std=c++14'] + else: + extra_compile_args['nvcc'] += ['-std=c++17'] ext_ops = extension( name=ext_name,