Skip to content

Commit

Permalink
fix bug for how to get torch._version_ at setup.py
Browse files Browse the repository at this point in the history
  • Loading branch information
XuanBaby committed Jul 9, 2021
1 parent 8c18ffc commit e82c5e0
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 9 deletions.
4 changes: 2 additions & 2 deletions mmcv/utils/parrots_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from distutils.version import StrictVersion
from functools import partial
from pkg_resources import parse_version

import torch

TORCH_VERSION = torch.__version__

is_rocm_pytorch = False
if StrictVersion(TORCH_VERSION) >= StrictVersion('1.5'):
if parse_version(TORCH_VERSION) >= parse_version('1.5'):
from torch.utils.cpp_extension import ROCM_HOME
is_rocm_pytorch = True if ((torch.version.hip is not None) and
(ROCM_HOME is not None)) else False
Expand Down
8 changes: 3 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import glob
import os
import re
from distutils.version import StrictVersion
from pkg_resources import DistributionNotFound, get_distribution
from pkg_resources import DistributionNotFound, get_distribution, parse_version
from setuptools import find_packages, setup

import torch

EXT_TYPE = ''
try:
import torch
if torch.__version__ == 'parrots':
from parrots.utils.build_extension import BuildExtension
EXT_TYPE = 'parrots'
Expand Down Expand Up @@ -223,7 +221,7 @@ def get_extensions():
extra_compile_args = {'cxx': []}

is_rocm_pytorch = False
if StrictVersion(torch.__version__) >= StrictVersion('1.5'):
if parse_version(torch.__version__) >= parse_version('1.5'):
from torch.utils.cpp_extension import ROCM_HOME
is_rocm_pytorch = True if ((torch.version.hip is not None) and
(ROCM_HOME is not None)) else False
Expand Down
4 changes: 2 additions & 2 deletions tests/test_runner/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import random
from distutils.version import StrictVersion
from pkg_resources import parse_version

import numpy as np
import torch
Expand All @@ -9,7 +9,7 @@
from mmcv.utils import TORCH_VERSION

is_rocm_pytorch = False
if StrictVersion(TORCH_VERSION) >= StrictVersion('1.5'):
if parse_version(TORCH_VERSION) >= parse_version('1.5'):
from torch.utils.cpp_extension import ROCM_HOME
is_rocm_pytorch = True if ((torch.version.hip is not None) and
(ROCM_HOME is not None)) else False
Expand Down

0 comments on commit e82c5e0

Please sign in to comment.