From e82c5e0625f28dd98cf4620ff0c68d19f4312ca7 Mon Sep 17 00:00:00 2001 From: XUANBABY Date: Fri, 9 Jul 2021 10:03:58 +0800 Subject: [PATCH] fix bug for how to get torch._version_ at setup.py --- mmcv/utils/parrots_wrapper.py | 4 ++-- setup.py | 8 +++----- tests/test_runner/test_utils.py | 4 ++-- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/mmcv/utils/parrots_wrapper.py b/mmcv/utils/parrots_wrapper.py index 99e1b8aa41..4cb4be98f2 100644 --- a/mmcv/utils/parrots_wrapper.py +++ b/mmcv/utils/parrots_wrapper.py @@ -1,12 +1,12 @@ -from distutils.version import StrictVersion from functools import partial +from pkg_resources import parse_version import torch TORCH_VERSION = torch.__version__ is_rocm_pytorch = False -if StrictVersion(TORCH_VERSION) >= StrictVersion('1.5'): +if parse_version(TORCH_VERSION) >= parse_version('1.5'): from torch.utils.cpp_extension import ROCM_HOME is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False diff --git a/setup.py b/setup.py index a770686938..4122c814d2 100644 --- a/setup.py +++ b/setup.py @@ -1,14 +1,12 @@ import glob import os import re -from distutils.version import StrictVersion -from pkg_resources import DistributionNotFound, get_distribution +from pkg_resources import DistributionNotFound, get_distribution, parse_version from setuptools import find_packages, setup -import torch - EXT_TYPE = '' try: + import torch if torch.__version__ == 'parrots': from parrots.utils.build_extension import BuildExtension EXT_TYPE = 'parrots' @@ -223,7 +221,7 @@ def get_extensions(): extra_compile_args = {'cxx': []} is_rocm_pytorch = False - if StrictVersion(torch.__version__) >= StrictVersion('1.5'): + if parse_version(torch.__version__) >= parse_version('1.5'): from torch.utils.cpp_extension import ROCM_HOME is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False diff --git a/tests/test_runner/test_utils.py b/tests/test_runner/test_utils.py index 596f7aeb81..88e0629c28 100644 --- a/tests/test_runner/test_utils.py +++ b/tests/test_runner/test_utils.py @@ -1,6 +1,6 @@ import os import random -from distutils.version import StrictVersion +from pkg_resources import parse_version import numpy as np import torch @@ -9,7 +9,7 @@ from mmcv.utils import TORCH_VERSION is_rocm_pytorch = False -if StrictVersion(TORCH_VERSION) >= StrictVersion('1.5'): +if parse_version(TORCH_VERSION) >= parse_version('1.5'): from torch.utils.cpp_extension import ROCM_HOME is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False