Skip to content

Commit

Permalink
use LooseVersion for version checking (#1158)
Browse files Browse the repository at this point in the history
  • Loading branch information
kennymckormick authored Jun 29, 2021
1 parent 21845db commit 1d5ee6e
Show file tree
Hide file tree
Showing 10 changed files with 41 additions and 16 deletions.
5 changes: 4 additions & 1 deletion mmcv/cnn/bricks/activation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from distutils.version import LooseVersion

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -70,7 +72,8 @@ def forward(self, input):
return F.gelu(input)


if TORCH_VERSION == 'parrots' or TORCH_VERSION < '1.4':
if (TORCH_VERSION == 'parrots'
or LooseVersion(TORCH_VERSION) < LooseVersion('1.4')):
ACTIVATION_LAYERS.register_module(module=GELU)
else:
ACTIVATION_LAYERS.register_module(module=nn.GELU)
Expand Down
8 changes: 6 additions & 2 deletions mmcv/ops/saconv.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from distutils.version import LooseVersion

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -103,7 +105,8 @@ def forward(self, x):
out_s = deform_conv2d(x, offset, weight, self.stride, self.padding,
self.dilation, self.groups, 1)
else:
if TORCH_VERSION < '1.5.0' or TORCH_VERSION == 'parrots':
if (LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0')
or TORCH_VERSION == 'parrots'):
out_s = super().conv2d_forward(x, weight)
else:
out_s = super()._conv_forward(x, weight)
Expand All @@ -117,7 +120,8 @@ def forward(self, x):
out_l = deform_conv2d(x, offset, weight, self.stride, self.padding,
self.dilation, self.groups, 1)
else:
if TORCH_VERSION < '1.5.0' or TORCH_VERSION == 'parrots':
if (LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0')
or TORCH_VERSION == 'parrots'):
out_l = super().conv2d_forward(x, weight)
else:
out_l = super()._conv_forward(x, weight)
Expand Down
10 changes: 6 additions & 4 deletions mmcv/parallel/distributed.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) Open-MMLab. All rights reserved.
from distutils.version import LooseVersion

import torch
from torch.nn.parallel.distributed import (DistributedDataParallel,
_find_tensors)
Expand Down Expand Up @@ -37,7 +39,7 @@ def train_step(self, *inputs, **kwargs):

# In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
# end of backward to the beginning of forward.
if (TORCH_VERSION >= '1.7' and 'parrots'
if (LooseVersion(TORCH_VERSION) >= LooseVersion('1.7') and 'parrots'
not in TORCH_VERSION) and self.reducer._rebuild_buckets():
print_log(
'Reducer buckets have been rebuilt in this iteration.',
Expand All @@ -63,7 +65,7 @@ def train_step(self, *inputs, **kwargs):
else:
self.reducer.prepare_for_backward([])
else:
if TORCH_VERSION > '1.2':
if LooseVersion(TORCH_VERSION) > LooseVersion('1.2'):
self.require_forward_param_sync = False
return output

Expand All @@ -77,7 +79,7 @@ def val_step(self, *inputs, **kwargs):
"""
# In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the
# end of backward to the beginning of forward.
if (TORCH_VERSION >= '1.7' and 'parrots'
if (LooseVersion(TORCH_VERSION) >= LooseVersion('1.7') and 'parrots'
not in TORCH_VERSION) and self.reducer._rebuild_buckets():
print_log(
'Reducer buckets have been rebuilt in this iteration.',
Expand All @@ -103,6 +105,6 @@ def val_step(self, *inputs, **kwargs):
else:
self.reducer.prepare_for_backward([])
else:
if TORCH_VERSION > '1.2':
if LooseVersion(TORCH_VERSION) > LooseVersion('1.2'):
self.require_forward_param_sync = False
return output
4 changes: 3 additions & 1 deletion mmcv/parallel/distributed_deprecated.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) Open-MMLab. All rights reserved.
from distutils.version import LooseVersion

import torch
import torch.distributed as dist
import torch.nn as nn
Expand Down Expand Up @@ -40,7 +42,7 @@ def _sync_params(self):
self._dist_broadcast_coalesced(module_states,
self.broadcast_bucket_size)
if self.broadcast_buffers:
if TORCH_VERSION < '1.0':
if LooseVersion(TORCH_VERSION) < LooseVersion('1.0'):
buffers = [b.data for b in self.module._all_buffers()]
else:
buffers = [b.data for b in self.module.buffers()]
Expand Down
3 changes: 2 additions & 1 deletion mmcv/runner/dist_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import subprocess
from collections import OrderedDict
from distutils.version import LooseVersion

import torch
import torch.multiprocessing as mp
Expand Down Expand Up @@ -78,7 +79,7 @@ def _init_dist_slurm(backend, port=None):


def get_dist_info():
if TORCH_VERSION < '1.0':
if LooseVersion(TORCH_VERSION) < LooseVersion('1.0'):
initialized = dist._initialized
else:
if dist.is_available():
Expand Down
10 changes: 7 additions & 3 deletions mmcv/runner/fp16_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools
import warnings
from collections import abc
from distutils.version import LooseVersion
from inspect import getfullargspec

import numpy as np
Expand Down Expand Up @@ -121,7 +122,8 @@ def new_func(*args, **kwargs):
else:
new_kwargs[arg_name] = arg_value
# apply converted arguments to the decorated method
if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
if (TORCH_VERSION != 'parrots'
and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
with autocast(enabled=True):
output = old_func(*new_args, **new_kwargs)
else:
Expand Down Expand Up @@ -206,7 +208,8 @@ def new_func(*args, **kwargs):
else:
new_kwargs[arg_name] = arg_value
# apply converted arguments to the decorated method
if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
if (TORCH_VERSION != 'parrots'
and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
with autocast(enabled=False):
output = old_func(*new_args, **new_kwargs)
else:
Expand Down Expand Up @@ -245,7 +248,8 @@ def wrap_fp16_model(model):
Args:
model (nn.Module): Model in FP32.
"""
if TORCH_VERSION == 'parrots' or TORCH_VERSION < '1.6.0':
if (TORCH_VERSION == 'parrots'
or LooseVersion(TORCH_VERSION) < LooseVersion('1.6.0')):
# convert model to fp16
model.half()
# patch the normalization layers to make it work in fp32 mode
Expand Down
4 changes: 3 additions & 1 deletion mmcv/runner/hooks/logger/tensorboard.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
from distutils.version import LooseVersion

from mmcv.utils import TORCH_VERSION
from ...dist_utils import master_only
Expand All @@ -23,7 +24,8 @@ def __init__(self,
@master_only
def before_run(self, runner):
super(TensorboardLoggerHook, self).before_run(runner)
if TORCH_VERSION < '1.1' or TORCH_VERSION == 'parrots':
if (LooseVersion(TORCH_VERSION) < LooseVersion('1.1')
or TORCH_VERSION == 'parrots'):
try:
from tensorboardX import SummaryWriter
except ImportError:
Expand Down
4 changes: 3 additions & 1 deletion mmcv/runner/hooks/optimizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Open-MMLab. All rights reserved.
import copy
from collections import defaultdict
from distutils.version import LooseVersion
from itertools import chain

from torch.nn.utils import clip_grad
Expand Down Expand Up @@ -42,7 +43,8 @@ def after_train_iter(self, runner):
runner.optimizer.step()


if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
if (TORCH_VERSION != 'parrots'
and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):

@HOOKS.register_module()
class Fp16OptimizerHook(OptimizerHook):
Expand Down
5 changes: 4 additions & 1 deletion tests/test_ops/test_deform_conv.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from distutils.version import LooseVersion

import numpy as np
import pytest
import torch
Expand Down Expand Up @@ -141,7 +143,8 @@ def test_deformconv(self):

# test amp when torch version >= '1.6.0', the type of
# input data for deformconv might be torch.float or torch.half
if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
if (TORCH_VERSION != 'parrots'
and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
with autocast(enabled=True):
self._test_amp_deformconv(torch.float, 1e-1)
self._test_amp_deformconv(torch.half, 1e-1)
4 changes: 3 additions & 1 deletion tests/test_ops/test_modulated_deform_conv.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from distutils.version import LooseVersion

import numpy
import torch
Expand Down Expand Up @@ -112,7 +113,8 @@ def test_mdconv(self):

# test amp when torch version >= '1.6.0', the type of
# input data for mdconv might be torch.float or torch.half
if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0':
if (TORCH_VERSION != 'parrots'
and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')):
with autocast(enabled=True):
self._test_amp_mdconv(torch.float)
self._test_amp_mdconv(torch.half)

0 comments on commit 1d5ee6e

Please sign in to comment.