From 9b2f102d572b273669e47b703f840b58513ffeab Mon Sep 17 00:00:00 2001 From: Kenny Date: Tue, 29 Jun 2021 16:36:27 +0800 Subject: [PATCH] use LooseVersion for version checking --- mmcv/cnn/bricks/activation.py | 5 ++++- mmcv/ops/saconv.py | 8 ++++++-- mmcv/parallel/distributed.py | 10 ++++++---- mmcv/parallel/distributed_deprecated.py | 4 +++- mmcv/runner/dist_utils.py | 3 ++- mmcv/runner/fp16_utils.py | 10 +++++++--- mmcv/runner/hooks/logger/tensorboard.py | 4 +++- mmcv/runner/hooks/optimizer.py | 4 +++- tests/test_ops/test_deform_conv.py | 5 ++++- tests/test_ops/test_modulated_deform_conv.py | 4 +++- 10 files changed, 41 insertions(+), 16 deletions(-) diff --git a/mmcv/cnn/bricks/activation.py b/mmcv/cnn/bricks/activation.py index f50241b192..89d54980e8 100644 --- a/mmcv/cnn/bricks/activation.py +++ b/mmcv/cnn/bricks/activation.py @@ -1,3 +1,5 @@ +from distutils.version import LooseVersion + import torch import torch.nn as nn import torch.nn.functional as F @@ -70,7 +72,8 @@ def forward(self, input): return F.gelu(input) -if TORCH_VERSION == 'parrots' or TORCH_VERSION < '1.4': +if (TORCH_VERSION == 'parrots' + or LooseVersion(TORCH_VERSION) < LooseVersion('1.4')): ACTIVATION_LAYERS.register_module(module=GELU) else: ACTIVATION_LAYERS.register_module(module=nn.GELU) diff --git a/mmcv/ops/saconv.py b/mmcv/ops/saconv.py index cd7eea122f..5694cf344e 100644 --- a/mmcv/ops/saconv.py +++ b/mmcv/ops/saconv.py @@ -1,3 +1,5 @@ +from distutils.version import LooseVersion + import torch import torch.nn as nn import torch.nn.functional as F @@ -103,7 +105,8 @@ def forward(self, x): out_s = deform_conv2d(x, offset, weight, self.stride, self.padding, self.dilation, self.groups, 1) else: - if TORCH_VERSION < '1.5.0' or TORCH_VERSION == 'parrots': + if (LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0') + or TORCH_VERSION == 'parrots'): out_s = super().conv2d_forward(x, weight) else: out_s = super()._conv_forward(x, weight) @@ -117,7 +120,8 @@ def forward(self, x): out_l = deform_conv2d(x, offset, weight, self.stride, self.padding, self.dilation, self.groups, 1) else: - if TORCH_VERSION < '1.5.0' or TORCH_VERSION == 'parrots': + if (LooseVersion(TORCH_VERSION) < LooseVersion('1.5.0') + or TORCH_VERSION == 'parrots'): out_l = super().conv2d_forward(x, weight) else: out_l = super()._conv_forward(x, weight) diff --git a/mmcv/parallel/distributed.py b/mmcv/parallel/distributed.py index 767c4f9dd2..2882cf35d4 100644 --- a/mmcv/parallel/distributed.py +++ b/mmcv/parallel/distributed.py @@ -1,4 +1,6 @@ # Copyright (c) Open-MMLab. All rights reserved. +from distutils.version import LooseVersion + import torch from torch.nn.parallel.distributed import (DistributedDataParallel, _find_tensors) @@ -37,7 +39,7 @@ def train_step(self, *inputs, **kwargs): # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the # end of backward to the beginning of forward. - if (TORCH_VERSION >= '1.7' and 'parrots' + if (LooseVersion(TORCH_VERSION) >= LooseVersion('1.7') and 'parrots' not in TORCH_VERSION) and self.reducer._rebuild_buckets(): print_log( 'Reducer buckets have been rebuilt in this iteration.', @@ -63,7 +65,7 @@ def train_step(self, *inputs, **kwargs): else: self.reducer.prepare_for_backward([]) else: - if TORCH_VERSION > '1.2': + if LooseVersion(TORCH_VERSION) > LooseVersion('1.2'): self.require_forward_param_sync = False return output @@ -77,7 +79,7 @@ def val_step(self, *inputs, **kwargs): """ # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the # end of backward to the beginning of forward. - if (TORCH_VERSION >= '1.7' and 'parrots' + if (LooseVersion(TORCH_VERSION) >= LooseVersion('1.7') and 'parrots' not in TORCH_VERSION) and self.reducer._rebuild_buckets(): print_log( 'Reducer buckets have been rebuilt in this iteration.', @@ -103,6 +105,6 @@ def val_step(self, *inputs, **kwargs): else: self.reducer.prepare_for_backward([]) else: - if TORCH_VERSION > '1.2': + if LooseVersion(TORCH_VERSION) > LooseVersion('1.2'): self.require_forward_param_sync = False return output diff --git a/mmcv/parallel/distributed_deprecated.py b/mmcv/parallel/distributed_deprecated.py index 2a49fa9e3f..45443db995 100644 --- a/mmcv/parallel/distributed_deprecated.py +++ b/mmcv/parallel/distributed_deprecated.py @@ -1,4 +1,6 @@ # Copyright (c) Open-MMLab. All rights reserved. +from distutils.version import LooseVersion + import torch import torch.distributed as dist import torch.nn as nn @@ -40,7 +42,7 @@ def _sync_params(self): self._dist_broadcast_coalesced(module_states, self.broadcast_bucket_size) if self.broadcast_buffers: - if TORCH_VERSION < '1.0': + if LooseVersion(TORCH_VERSION) < LooseVersion('1.0'): buffers = [b.data for b in self.module._all_buffers()] else: buffers = [b.data for b in self.module.buffers()] diff --git a/mmcv/runner/dist_utils.py b/mmcv/runner/dist_utils.py index 0a9ccf35af..6221554b62 100644 --- a/mmcv/runner/dist_utils.py +++ b/mmcv/runner/dist_utils.py @@ -3,6 +3,7 @@ import os import subprocess from collections import OrderedDict +from distutils.version import LooseVersion import torch import torch.multiprocessing as mp @@ -78,7 +79,7 @@ def _init_dist_slurm(backend, port=None): def get_dist_info(): - if TORCH_VERSION < '1.0': + if LooseVersion(TORCH_VERSION) < LooseVersion('1.0'): initialized = dist._initialized else: if dist.is_available(): diff --git a/mmcv/runner/fp16_utils.py b/mmcv/runner/fp16_utils.py index f2f0ac0ee1..c5d562512e 100644 --- a/mmcv/runner/fp16_utils.py +++ b/mmcv/runner/fp16_utils.py @@ -1,6 +1,7 @@ import functools import warnings from collections import abc +from distutils.version import LooseVersion from inspect import getfullargspec import numpy as np @@ -121,7 +122,8 @@ def new_func(*args, **kwargs): else: new_kwargs[arg_name] = arg_value # apply converted arguments to the decorated method - if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0': + if (TORCH_VERSION != 'parrots' + and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')): with autocast(enabled=True): output = old_func(*new_args, **new_kwargs) else: @@ -206,7 +208,8 @@ def new_func(*args, **kwargs): else: new_kwargs[arg_name] = arg_value # apply converted arguments to the decorated method - if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0': + if (TORCH_VERSION != 'parrots' + and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')): with autocast(enabled=False): output = old_func(*new_args, **new_kwargs) else: @@ -245,7 +248,8 @@ def wrap_fp16_model(model): Args: model (nn.Module): Model in FP32. """ - if TORCH_VERSION == 'parrots' or TORCH_VERSION < '1.6.0': + if (TORCH_VERSION == 'parrots' + or LooseVersion(TORCH_VERSION) < LooseVersion('1.6.0')): # convert model to fp16 model.half() # patch the normalization layers to make it work in fp32 mode diff --git a/mmcv/runner/hooks/logger/tensorboard.py b/mmcv/runner/hooks/logger/tensorboard.py index f973047976..475d4b5408 100644 --- a/mmcv/runner/hooks/logger/tensorboard.py +++ b/mmcv/runner/hooks/logger/tensorboard.py @@ -1,5 +1,6 @@ # Copyright (c) Open-MMLab. All rights reserved. import os.path as osp +from distutils.version import LooseVersion from mmcv.utils import TORCH_VERSION from ...dist_utils import master_only @@ -23,7 +24,8 @@ def __init__(self, @master_only def before_run(self, runner): super(TensorboardLoggerHook, self).before_run(runner) - if TORCH_VERSION < '1.1' or TORCH_VERSION == 'parrots': + if (LooseVersion(TORCH_VERSION) < LooseVersion('1.1') + or TORCH_VERSION == 'parrots'): try: from tensorboardX import SummaryWriter except ImportError: diff --git a/mmcv/runner/hooks/optimizer.py b/mmcv/runner/hooks/optimizer.py index bb97504667..a2f8114a7b 100644 --- a/mmcv/runner/hooks/optimizer.py +++ b/mmcv/runner/hooks/optimizer.py @@ -1,6 +1,7 @@ # Copyright (c) Open-MMLab. All rights reserved. import copy from collections import defaultdict +from distutils.version import LooseVersion from itertools import chain from torch.nn.utils import clip_grad @@ -42,7 +43,8 @@ def after_train_iter(self, runner): runner.optimizer.step() -if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0': +if (TORCH_VERSION != 'parrots' + and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')): @HOOKS.register_module() class Fp16OptimizerHook(OptimizerHook): diff --git a/tests/test_ops/test_deform_conv.py b/tests/test_ops/test_deform_conv.py index c49d47980a..ea6e429d2e 100644 --- a/tests/test_ops/test_deform_conv.py +++ b/tests/test_ops/test_deform_conv.py @@ -1,3 +1,5 @@ +from distutils.version import LooseVersion + import numpy as np import pytest import torch @@ -141,7 +143,8 @@ def test_deformconv(self): # test amp when torch version >= '1.6.0', the type of # input data for deformconv might be torch.float or torch.half - if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0': + if (TORCH_VERSION != 'parrots' + and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')): with autocast(enabled=True): self._test_amp_deformconv(torch.float, 1e-1) self._test_amp_deformconv(torch.half, 1e-1) diff --git a/tests/test_ops/test_modulated_deform_conv.py b/tests/test_ops/test_modulated_deform_conv.py index 83c6f8a405..73032f0a45 100644 --- a/tests/test_ops/test_modulated_deform_conv.py +++ b/tests/test_ops/test_modulated_deform_conv.py @@ -1,4 +1,5 @@ import os +from distutils.version import LooseVersion import numpy import torch @@ -112,7 +113,8 @@ def test_mdconv(self): # test amp when torch version >= '1.6.0', the type of # input data for mdconv might be torch.float or torch.half - if TORCH_VERSION != 'parrots' and TORCH_VERSION >= '1.6.0': + if (TORCH_VERSION != 'parrots' + and LooseVersion(TORCH_VERSION) >= LooseVersion('1.6.0')): with autocast(enabled=True): self._test_amp_mdconv(torch.float) self._test_amp_mdconv(torch.half)