diff --git a/mmcv/utils/parrots_wrapper.py b/mmcv/utils/parrots_wrapper.py index 99e1b8aa41..4cb4be98f2 100644 --- a/mmcv/utils/parrots_wrapper.py +++ b/mmcv/utils/parrots_wrapper.py @@ -1,12 +1,12 @@ -from distutils.version import StrictVersion from functools import partial +from pkg_resources import parse_version import torch TORCH_VERSION = torch.__version__ is_rocm_pytorch = False -if StrictVersion(TORCH_VERSION) >= StrictVersion('1.5'): +if parse_version(TORCH_VERSION) >= parse_version('1.5'): from torch.utils.cpp_extension import ROCM_HOME is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False diff --git a/setup.py b/setup.py index a770686938..4122c814d2 100644 --- a/setup.py +++ b/setup.py @@ -1,14 +1,12 @@ import glob import os import re -from distutils.version import StrictVersion -from pkg_resources import DistributionNotFound, get_distribution +from pkg_resources import DistributionNotFound, get_distribution, parse_version from setuptools import find_packages, setup -import torch - EXT_TYPE = '' try: + import torch if torch.__version__ == 'parrots': from parrots.utils.build_extension import BuildExtension EXT_TYPE = 'parrots' @@ -223,7 +221,7 @@ def get_extensions(): extra_compile_args = {'cxx': []} is_rocm_pytorch = False - if StrictVersion(torch.__version__) >= StrictVersion('1.5'): + if parse_version(torch.__version__) >= parse_version('1.5'): from torch.utils.cpp_extension import ROCM_HOME is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False diff --git a/tests/test_runner/test_utils.py b/tests/test_runner/test_utils.py index 596f7aeb81..88e0629c28 100644 --- a/tests/test_runner/test_utils.py +++ b/tests/test_runner/test_utils.py @@ -1,6 +1,6 @@ import os import random -from distutils.version import StrictVersion +from pkg_resources import parse_version import numpy as np import torch @@ -9,7 +9,7 @@ from mmcv.utils import TORCH_VERSION is_rocm_pytorch = False -if StrictVersion(TORCH_VERSION) >= StrictVersion('1.5'): +if parse_version(TORCH_VERSION) >= parse_version('1.5'): from torch.utils.cpp_extension import ROCM_HOME is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False