[Refactor] Refactor unittest (#321)
* Refactor the unit tests folder structure.

* Remove the label-smooth and ViT tests in `test_classifiers.py`.

* Rename `test_utils` in the dataset tests to `test_dataset_utils`.

* Split `test_models/test_utils/test_utils.py` into multiple sub-files.

* Add unit tests for classifiers and heads.

* Use the `patch` context manager (see the sketch below).

* Add a unit test for `is_tracing`, and warn in `is_tracing` if the torch
  version is older than 1.6.0.
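
A note on the `patch` bullet: `unittest.mock.patch` works both as a decorator and as a context manager that restores the patched attribute when the block exits, so one test's patch cannot leak into another. A minimal sketch of the context-manager form (the `FakeDataset` class is invented for illustration):

from unittest.mock import patch

class FakeDataset:  # invented stand-in for a real dataset class
    CLASSES = ('foo', 'bar')

# Inside the with-block the attribute is replaced; on exit it is restored.
with patch.object(FakeDataset, 'CLASSES', ('bus', 'car')):
    assert FakeDataset.CLASSES == ('bus', 'car')
assert FakeDataset.CLASSES == ('foo', 'bar')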
mzr1996 authored Jul 8, 2021
1 parent 71621a5 commit 1a7cebe
Showing 34 changed files with 502 additions and 379 deletions.
11 changes: 9 additions & 2 deletions mmcls/models/utils/helpers.py
@@ -1,19 +1,26 @@
 import collections.abc
+import warnings
+from distutils.version import LooseVersion
 from itertools import repeat
 
 import torch
 
 
 def is_tracing() -> bool:
-    if hasattr(torch.jit, 'is_tracing'):
+    if LooseVersion(torch.__version__) >= LooseVersion('1.6.0'):
         on_trace = torch.jit.is_tracing()
         # In PyTorch 1.6, torch.jit.is_tracing has a bug.
         # Refers to https://github.com/pytorch/pytorch/issues/42448
         if isinstance(on_trace, bool):
             return on_trace
         else:
             return torch._C._is_tracing()
-    return False
+    else:
+        warnings.warn(
+            'torch.jit.is_tracing is only supported after v1.6.0. '
+            'Therefore is_tracing returns False automatically. Please '
+            'set on_trace manually if you are using trace.', UserWarning)
+        return False
 
 
 # From PyTorch internals
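For reference, `torch.jit.is_tracing` reports whether execution is currently being recorded by `torch.jit.trace`; the guard above only trusts it on PyTorch >= 1.6.0, falling back with a warning otherwise. A minimal sketch of the behavior (assumes a PyTorch >= 1.6 install; the `double` function is invented for illustration):

from distutils.version import LooseVersion

import torch

# LooseVersion compares version components numerically: as plain strings
# '1.10.0' >= '1.6.0' is False, but as versions it is True -- which is why
# the helper compares LooseVersion objects rather than raw strings.
assert LooseVersion('1.10.0') >= LooseVersion('1.6.0')

def double(x):
    # True while torch.jit.trace is recording this call, False in eager mode.
    print('tracing:', torch.jit.is_tracing())
    return x * 2

double(torch.ones(2))                               # prints: tracing: False
traced = torch.jit.trace(double, (torch.ones(2),))  # prints: tracing: True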
189 changes: 0 additions & 189 deletions tests/test_backbones/test_utils.py

This file was deleted.

@@ -1,18 +1,11 @@
-import bisect
-import math
-import random
-import string
 import tempfile
-from collections import defaultdict
 from unittest.mock import MagicMock, patch
 
 import numpy as np
 import pytest
 import torch
 
-from mmcls.datasets import (DATASETS, BaseDataset, ClassBalancedDataset,
-                            ConcatDataset, MultiLabelDataset, RepeatDataset)
-from mmcls.datasets.utils import check_integrity, rm_suffix
+from mmcls.datasets import DATASETS, BaseDataset, MultiLabelDataset
 
 
 @pytest.mark.parametrize(
@@ -45,7 +38,6 @@ def test_datasets_override_default(dataset_name):
         pipeline=[],
         classes=('bus', 'car'),
         test_mode=True)
-    assert dataset.CLASSES != original_classes
     assert dataset.CLASSES == ('bus', 'car')
 
     # Test setting classes as a list
@@ -54,7 +46,6 @@ def test_datasets_override_default(dataset_name):
         pipeline=[],
         classes=['bus', 'car'],
         test_mode=True)
-    assert dataset.CLASSES != original_classes
     assert dataset.CLASSES == ['bus', 'car']
 
     # Test setting classes through a file
@@ -68,7 +59,6 @@ def test_datasets_override_default(dataset_name):
         test_mode=True)
     tmp_file.close()
 
-    assert dataset.CLASSES != original_classes
     assert dataset.CLASSES == ['bus', 'car']
 
     # Test overriding not a subset
@@ -77,7 +67,6 @@ def test_datasets_override_default(dataset_name):
         pipeline=[],
         classes=['foo'],
         test_mode=True)
-    assert dataset.CLASSES != original_classes
     assert dataset.CLASSES == ['foo']
 
     # Test default behavior
@@ -258,92 +247,3 @@ def test_dataset_evaluation():
     assert 'CR' in eval_results.keys()
     assert 'OF1' in eval_results.keys()
     assert 'CF1' not in eval_results.keys()
-
-
-@patch.multiple(BaseDataset, __abstractmethods__=set())
-def test_dataset_wrapper():
-    BaseDataset.CLASSES = ('foo', 'bar')
-    BaseDataset.__getitem__ = MagicMock(side_effect=lambda idx: idx)
-    dataset_a = BaseDataset(data_prefix='', pipeline=[], test_mode=True)
-    len_a = 10
-    cat_ids_list_a = [
-        np.random.randint(0, 80, num).tolist()
-        for num in np.random.randint(1, 20, len_a)
-    ]
-    dataset_a.data_infos = MagicMock()
-    dataset_a.data_infos.__len__.return_value = len_a
-    dataset_a.get_cat_ids = MagicMock(
-        side_effect=lambda idx: cat_ids_list_a[idx])
-    dataset_b = BaseDataset(data_prefix='', pipeline=[], test_mode=True)
-    len_b = 20
-    cat_ids_list_b = [
-        np.random.randint(0, 80, num).tolist()
-        for num in np.random.randint(1, 20, len_b)
-    ]
-    dataset_b.data_infos = MagicMock()
-    dataset_b.data_infos.__len__.return_value = len_b
-    dataset_b.get_cat_ids = MagicMock(
-        side_effect=lambda idx: cat_ids_list_b[idx])
-
-    concat_dataset = ConcatDataset([dataset_a, dataset_b])
-    assert concat_dataset[5] == 5
-    assert concat_dataset[25] == 15
-    assert concat_dataset.get_cat_ids(5) == cat_ids_list_a[5]
-    assert concat_dataset.get_cat_ids(25) == cat_ids_list_b[15]
-    assert len(concat_dataset) == len(dataset_a) + len(dataset_b)
-    assert concat_dataset.CLASSES == BaseDataset.CLASSES
-
-    repeat_dataset = RepeatDataset(dataset_a, 10)
-    assert repeat_dataset[5] == 5
-    assert repeat_dataset[15] == 5
-    assert repeat_dataset[27] == 7
-    assert repeat_dataset.get_cat_ids(5) == cat_ids_list_a[5]
-    assert repeat_dataset.get_cat_ids(15) == cat_ids_list_a[5]
-    assert repeat_dataset.get_cat_ids(27) == cat_ids_list_a[7]
-    assert len(repeat_dataset) == 10 * len(dataset_a)
-    assert repeat_dataset.CLASSES == BaseDataset.CLASSES
-
-    category_freq = defaultdict(int)
-    for cat_ids in cat_ids_list_a:
-        cat_ids = set(cat_ids)
-        for cat_id in cat_ids:
-            category_freq[cat_id] += 1
-    for k, v in category_freq.items():
-        category_freq[k] = v / len(cat_ids_list_a)
-
-    mean_freq = np.mean(list(category_freq.values()))
-    repeat_thr = mean_freq
-
-    category_repeat = {
-        cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
-        for cat_id, cat_freq in category_freq.items()
-    }
-
-    repeat_factors = []
-    for cat_ids in cat_ids_list_a:
-        cat_ids = set(cat_ids)
-        repeat_factor = max({category_repeat[cat_id] for cat_id in cat_ids})
-        repeat_factors.append(math.ceil(repeat_factor))
-    repeat_factors_cumsum = np.cumsum(repeat_factors)
-    repeat_factor_dataset = ClassBalancedDataset(dataset_a, repeat_thr)
-    assert repeat_factor_dataset.CLASSES == BaseDataset.CLASSES
-    assert len(repeat_factor_dataset) == repeat_factors_cumsum[-1]
-    for idx in np.random.randint(0, len(repeat_factor_dataset), 3):
-        assert repeat_factor_dataset[idx] == bisect.bisect_right(
-            repeat_factors_cumsum, idx)
-
-
-def test_dataset_utils():
-    # test rm_suffix
-    assert rm_suffix('a.jpg') == 'a'
-    assert rm_suffix('a.bak.jpg') == 'a.bak'
-    assert rm_suffix('a.bak.jpg', suffix='.jpg') == 'a.bak'
-    assert rm_suffix('a.bak.jpg', suffix='.bak.jpg') == 'a'
-
-    # test check_integrity
-    rand_file = ''.join(random.sample(string.ascii_letters, 10))
-    assert not check_integrity(rand_file, md5=None)
-    assert not check_integrity(rand_file, md5=2333)
-    tmp_file = tempfile.NamedTemporaryFile()
-    assert check_integrity(tmp_file.name, md5=None)
-    assert not check_integrity(tmp_file.name, md5=2333)
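
The deleted `test_dataset_wrapper` re-derives the oversampling rule behind `ClassBalancedDataset`: a category's frequency is the fraction of images containing it, and each image is repeated by the maximum, over its categories, of max(1, sqrt(repeat_thr / freq)). A self-contained sketch of that arithmetic (the four-image sample is invented for illustration):

import math
from collections import defaultdict

# Invented sample: category ids present in each of four images.
cat_ids_per_image = [[0], [0, 1], [1], [0, 2]]

# Category frequency = fraction of images that contain the category.
category_freq = defaultdict(int)
for cat_ids in cat_ids_per_image:
    for cat_id in set(cat_ids):
        category_freq[cat_id] += 1
for k, v in category_freq.items():
    category_freq[k] = v / len(cat_ids_per_image)

repeat_thr = 0.5  # categories rarer than this threshold get oversampled
# Per-image repeat factor: max over its categories of max(1, sqrt(thr / freq)).
repeat_factors = [
    max(max(1.0, math.sqrt(repeat_thr / category_freq[c])) for c in set(cat_ids))
    for cat_ids in cat_ids_per_image
]
print(repeat_factors)  # [1.0, 1.0, 1.0, 1.414...]; the image with rare category 2 repeats most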
21 changes: 21 additions & 0 deletions tests/test_data/test_datasets/test_dataset_utils.py
@@ -0,0 +1,21 @@
+import random
+import string
+import tempfile
+
+from mmcls.datasets.utils import check_integrity, rm_suffix
+
+
+def test_dataset_utils():
+    # test rm_suffix
+    assert rm_suffix('a.jpg') == 'a'
+    assert rm_suffix('a.bak.jpg') == 'a.bak'
+    assert rm_suffix('a.bak.jpg', suffix='.jpg') == 'a.bak'
+    assert rm_suffix('a.bak.jpg', suffix='.bak.jpg') == 'a'
+
+    # test check_integrity
+    rand_file = ''.join(random.sample(string.ascii_letters, 10))
+    assert not check_integrity(rand_file, md5=None)
+    assert not check_integrity(rand_file, md5=2333)
+    tmp_file = tempfile.NamedTemporaryFile()
+    assert check_integrity(tmp_file.name, md5=None)
+    assert not check_integrity(tmp_file.name, md5=2333)
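
The new test pins down the contract of `check_integrity`: a missing file fails, an existing file passes when `md5=None`, and a wrong md5 fails. A plain-Python sketch of that contract using only the standard library (an illustration, not the mmcls implementation):

import hashlib
import os

def file_ok(path, md5=None):
    # Missing file -> False; md5=None -> existence is enough; else compare digests.
    if not os.path.isfile(path):
        return False
    if md5 is None:
        return True
    with open(path, 'rb') as f:
        return hashlib.md5(f.read()).hexdigest() == str(md5)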