Skip to content

Commit

Permalink
[Fix] Update digit_version (#778)
Browse files Browse the repository at this point in the history
* update digit_version

* add unittest

* fix import
  • Loading branch information
Junjun2016 authored Aug 12, 2021
1 parent 58f5dbc commit bfc3cdb
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 14 deletions.
53 changes: 42 additions & 11 deletions mmseg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,52 @@
import warnings

import mmcv
from packaging.version import parse

from .version import __version__, version_info

MMCV_MIN = '1.3.7'
MMCV_MAX = '1.4.0'


def digit_version(version_str):
digit_version = []
for x in version_str.split('.'):
if x.isdigit():
digit_version.append(int(x))
elif x.find('rc') != -1:
patch_version = x.split('rc')
digit_version.append(int(patch_version[0]) - 1)
digit_version.append(int(patch_version[1]))
return digit_version
def digit_version(version_str: str, length: int = 4):
"""Convert a version string into a tuple of integers.
This method is usually used for comparing two versions. For pre-release
versions: alpha < beta < rc.
Args:
version_str (str): The version string.
length (int): The maximum number of version levels. Default: 4.
Returns:
tuple[int]: The version info in digits (integers).
"""
version = parse(version_str)
assert version.release, f'failed to parse version {version_str}'
release = list(version.release)
release = release[:length]
if len(release) < length:
release = release + [0] * (length - len(release))
if version.is_prerelease:
mapping = {'a': -3, 'b': -2, 'rc': -1}
val = -4
# version.pre can be None
if version.pre:
if version.pre[0] not in mapping:
warnings.warn(f'unknown prerelease version {version.pre[0]}, '
'version checking may go wrong')
else:
val = mapping[version.pre[0]]
release.extend([val, version.pre[-1]])
else:
release.extend([val, 0])

elif version.is_postrelease:
release.extend([1, version.post])
else:
release.extend([0, 0])
return tuple(release)


mmcv_min_version = digit_version(MMCV_MIN)
Expand All @@ -27,4 +58,4 @@ def digit_version(version_str):
f'MMCV=={mmcv.__version__} is used but incompatible. ' \
f'Please install mmcv>={mmcv_min_version}, <={mmcv_max_version}.'

__all__ = ['__version__', 'version_info']
__all__ = ['__version__', 'version_info', 'digit_version']
4 changes: 2 additions & 2 deletions mmseg/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import Registry, build_from_cfg
from mmcv.utils import Registry, build_from_cfg, digit_version
from torch.utils.data import DataLoader, DistributedSampler

if platform.system() != 'Windows':
Expand Down Expand Up @@ -133,7 +133,7 @@ def build_dataloader(dataset,
worker_init_fn, num_workers=num_workers, rank=rank,
seed=seed) if seed is not None else None

if torch.__version__ >= '1.8.0':
if digit_version(torch.__version__) >= digit_version('1.8.0'):
data_loader = DataLoader(
dataset,
batch_size=batch_size,
Expand Down
1 change: 1 addition & 0 deletions requirements/runtime.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
matplotlib
numpy
packaging
prettytable
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = setuptools
known_first_party = mmseg
known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,prettytable,pytest,scipy,seaborn,torch,ts
known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,packaging,prettytable,pytest,scipy,seaborn,torch,ts
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
20 changes: 20 additions & 0 deletions tests/test_digit_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from mmseg import digit_version


def test_digit_version():
assert digit_version('0.2.16') == (0, 2, 16, 0, 0, 0)
assert digit_version('1.2.3') == (1, 2, 3, 0, 0, 0)
assert digit_version('1.2.3rc0') == (1, 2, 3, 0, -1, 0)
assert digit_version('1.2.3rc1') == (1, 2, 3, 0, -1, 1)
assert digit_version('1.0rc0') == (1, 0, 0, 0, -1, 0)
assert digit_version('1.0') == digit_version('1.0.0')
assert digit_version('1.5.0+cuda90_cudnn7.6.3_lms') == digit_version('1.5')
assert digit_version('1.0.0dev') < digit_version('1.0.0a')
assert digit_version('1.0.0a') < digit_version('1.0.0a1')
assert digit_version('1.0.0a') < digit_version('1.0.0b')
assert digit_version('1.0.0b') < digit_version('1.0.0rc')
assert digit_version('1.0.0rc1') < digit_version('1.0.0')
assert digit_version('1.0.0') < digit_version('1.0.0post')
assert digit_version('1.0.0post') < digit_version('1.0.0post1')
assert digit_version('v1') == (1, 0, 0, 0, 0, 0)
assert digit_version('v1.1.5') == (1, 1, 5, 0, 0, 0)

0 comments on commit bfc3cdb

Please sign in to comment.