Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] use LooseVersion for version checking #1158

Merged
merged 1 commit into from
Jun 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mmcv/cnn/bricks/activation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from distutils.version import LooseVersion

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -70,7 +72,8 @@ def forward(self, input):
return F.gelu(input)


if TORCH_VERSION == 'parrots' or TORCH_VERSION < '1.4':
if (TORCH_VERSION == 'parrots'
or LooseVersion(TORCH_VERSION) < LooseVersion('1.4')):
ACTIVATION_LAYERS.register_module(module=GELU)
else:
ACTIVATION_LAYERS.register_module(module=nn.GELU)
Expand Down
8 changes: 6 additions & 2 deletions mmcv/ops/saconv.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from distutils.version import LooseVersion

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -103,7 +105,8 @@ def forward(self, x):
out_s = deform_conv2d(x, offset, weight, self.stride, self.padding,
self.dilation, self.groups, 1)
else:
if TORCH_VERSION < '1.5.0' or TORCH_VERSION == 'parrots':
if (LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0')
or TORCH_VERSION == 'parrots'):
out_s = super().conv2d_forward(x, weight)
else:
out_s = super()._conv_forward(x, weight)
Expand All @@ -117,7 +120,8 @@ def forward(self, x):
out_l = deform_conv2d(x, offset, weight, self.stride, self.padding,
self.dilation, self.groups, 1)
else:
if TORCH_VERSION < '1.5.0' or TORCH_VERSION == 'parrots':
if (LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0')
or TORCH_VERSION == 'parrots'):
out_l = super().conv2d_forward(x, weight)
else:
out_l = super()._conv_forward(x, weight)
Expand Down
10 changes: 6 additions & 4 deletions mmcv/parallel/distributed.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) Open-MMLab. All rights reserved.
from distutils.version import LooseVersion

import torch
from torch.nn.parallel.distributed import (DistributedDataParallel,
_find_tensors)
Expand Down Expand Up @@ -37,7 +39,7 @@ def train_step(self, *inputs, **kwargs):

# In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
# end of backward to the beginning of forward.
if (TORCH_VERSION >= '1.7' and 'parrots'
if (LooseVersion(TORCH_VERSION) >= LooseVersion('1.7') and 'parrots'
not in TORCH_VERSION) and self.reducer._rebuild_buckets():
print_log(
'Reducer buckets have been rebuilt in this iteration.',
Expand All @@ -63,7 +65,7 @@ def train_step(self, *inputs, **kwargs):
else:
self.reducer.prepare_for_backward([])
else:
if TORCH_VERSION > '1.2':
if LooseVersion(TORCH_VERSION) > LooseVersion('1.2'):
self.require_forward_param_sync = False
return output

Expand All @@ -77,7 +79,7 @@ def val_step(self, *inputs, **kwargs):
"""
# In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
# end of backward to the beginning of forward.
if (TORCH_VERSION >= '1.7' and 'parrots'
if (LooseVersion(TORCH_VERSION) >= LooseVersion('1.7') and 'parrots'
not in TORCH_VERSION) and self.reducer._rebuild_buckets():
print_log(
'Reducer buckets have been rebuilt in this iteration.',
Expand All @@ -103,6 +105,6 @@ def val_step(self, *inputs, **kwargs):
else:
self.reducer.prepare_for_backward([])
else:
if TORCH_VERSION > '1.2':
if LooseVersion(TORCH_VERSION) > LooseVersion('1.2'):
self.require_forward_param_sync = False
return output
4 changes: 3 additions & 1 deletion mmcv/parallel/distributed_deprecated.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) Open-MMLab. All rights reserved.
from distutils.version import LooseVersion

import torch
import torch.distributed as dist
import torch.nn as nn
Expand Down Expand Up @@ -40,7 +42,7 @@ def _sync_params(self):
self._dist_broadcast_coalesced(module_states,
self.broadcast_bucket_size)
if self.broadcast_buffers:
if TORCH_VERSION < '1.0':
if LooseVersion(TORCH_VERSION) < LooseVersion('1.0'):
buffers = [b.data for b in self.module._all_buffers()]
else:
buffers = [b.data for b in self.module.buffers()]
Expand Down
3 changes: 2 additions & 1 deletion mmcv/runner/dist_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import subprocess
from collections import OrderedDict
from distutils.version import LooseVersion

import torch
import torch.multiprocessing as mp
Expand Down Expand Up @@ -78,7 +79,7 @@ def _init_dist_slurm(backend, port=None):


def get_dist_info():
if TORCH_VERSION < '1.0':
if LooseVersion(TORCH_VERSION) < LooseVersion('1.0'):
initialized = dist._initialized
else:
if dist.is_available():
Expand Down
10 changes: 7 additions & 3 deletions mmcv/runner/fp16_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools
import warnings
from collections import abc
from distutils.version import LooseVersion
from inspect import getfullargspec

import numpy as np
Expand Down Expand Up @@ -121,7 +122,8 @@ def new_func(*args, **kwargs):
else:
new_kwargs[arg_name] = arg_value
# apply converted arguments to the decorated method
if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
if (TORCH_VERSION != 'parrots'
and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
with autocast(enabled=True):
output = old_func(*new_args, **new_kwargs)
else:
Expand Down Expand Up @@ -206,7 +208,8 @@ def new_func(*args, **kwargs):
else:
new_kwargs[arg_name] = arg_value
# apply converted arguments to the decorated method
if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
if (TORCH_VERSION != 'parrots'
and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
with autocast(enabled=False):
output = old_func(*new_args, **new_kwargs)
else:
Expand Down Expand Up @@ -245,7 +248,8 @@ def wrap_fp16_model(model):
Args:
model (nn.Module): Model in FP32.
"""
if TORCH_VERSION == 'parrots' or TORCH_VERSION < '1.6.0':
if (TORCH_VERSION == 'parrots'
or LooseVersion(TORCH_VERSION) < LooseVersion('1.6.0')):
# convert model to fp16
model.half()
# patch the normalization layers to make it work in fp32 mode
Expand Down
4 changes: 3 additions & 1 deletion mmcv/runner/hooks/logger/tensorboard.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
from distutils.version import LooseVersion

from mmcv.utils import TORCH_VERSION
from ...dist_utils import master_only
Expand All @@ -23,7 +24,8 @@ def __init__(self,
@master_only
def before_run(self, runner):
super(TensorboardLoggerHook, self).before_run(runner)
if TORCH_VERSION < '1.1' or TORCH_VERSION == 'parrots':
if (LooseVersion(TORCH_VERSION) < LooseVersion('1.1')
or TORCH_VERSION == 'parrots'):
try:
from tensorboardX import SummaryWriter
except ImportError:
Expand Down
4 changes: 3 additions & 1 deletion mmcv/runner/hooks/optimizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Open-MMLab. All rights reserved.
import copy
from collections import defaultdict
from distutils.version import LooseVersion
from itertools import chain

from torch.nn.utils import clip_grad
Expand Down Expand Up @@ -42,7 +43,8 @@ def after_train_iter(self, runner):
runner.optimizer.step()


if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
if (TORCH_VERSION != 'parrots'
and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):

@HOOKS.register_module()
class Fp16OptimizerHook(OptimizerHook):
Expand Down
5 changes: 4 additions & 1 deletion tests/test_ops/test_deform_conv.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from distutils.version import LooseVersion

import numpy as np
import pytest
import torch
Expand Down Expand Up @@ -141,7 +143,8 @@ def test_deformconv(self):

# test amp when torch version >= '1.6.0', the type of
# input data for deformconv might be torch.float or torch.half
if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
if (TORCH_VERSION != 'parrots'
and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
with autocast(enabled=True):
self._test_amp_deformconv(torch.float, 1e-1)
self._test_amp_deformconv(torch.half, 1e-1)
4 changes: 3 additions & 1 deletion tests/test_ops/test_modulated_deform_conv.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from distutils.version import LooseVersion

import numpy
import torch
Expand Down Expand Up @@ -112,7 +113,8 @@ def test_mdconv(self):

# test amp when torch version >= '1.6.0', the type of
# input data for mdconv might be torch.float or torch.half
if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
if (TORCH_VERSION != 'parrots'
and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
with autocast(enabled=True):
self._test_amp_mdconv(torch.float)
self._test_amp_mdconv(torch.half)