Skip to content

Commit

Permalink
[Fix] Fix unittest in pt1.9 (#1146)
Browse files Browse the repository at this point in the history
* fix test.txt

* fix unittest in pt1.9

* fix checkpoint filename error

* add comment

* fix unittest

* fix onnxruntime version
  • Loading branch information
zhouzaida authored Jul 3, 2021
1 parent 6c63621 commit 4a9f834
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
19 changes: 15 additions & 4 deletions tests/test_load_model_zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
_load_checkpoint,
get_deprecated_model_names,
get_external_models)
from mmcv.utils import TORCH_VERSION


@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
Expand Down Expand Up @@ -77,13 +78,23 @@ def load(filepath, map_location=None):
def test_load_external_url():
# test modelzoo://
url = _load_checkpoint('modelzoo://resnet50')
assert url == 'url:https://download.pytorch.org/models/resnet50-19c8e357' \
'.pth'
if TORCH_VERSION < '1.9.0':
assert url == ('url:https://download.pytorch.org/models/resnet50-19c8e'
'357.pth')
else:
# filename of checkpoint is renamed in torch1.9.0
assert url == ('url:https://download.pytorch.org/models/resnet50-0676b'
'a61.pth')

# test torchvision://
url = _load_checkpoint('torchvision://resnet50')
assert url == 'url:https://download.pytorch.org/models/resnet50-19c8e357' \
'.pth'
if TORCH_VERSION < '1.9.0':
assert url == ('url:https://download.pytorch.org/models/resnet50-19c8e'
'357.pth')
else:
# filename of checkpoint is renamed in torch1.9.0
assert url == ('url:https://download.pytorch.org/models/resnet50-0676b'
'a61.pth')

# test open-mmlab:// with default MMCV_HOME
os.environ.pop(ENV_MMCV_HOME, None)
Expand Down
9 changes: 8 additions & 1 deletion tests/test_parallel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from unittest.mock import MagicMock, patch

import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel

Expand All @@ -15,7 +16,7 @@ def mock(*args, **kwargs):

@patch('torch.distributed._broadcast_coalesced', mock)
@patch('torch.distributed.broadcast', mock)
@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', MagicMock)
@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock)
def test_is_module_wrapper():

class Model(nn.Module):
Expand All @@ -27,6 +28,12 @@ def __init__(self):
def forward(self, x):
return self.conv(x)

# _verify_model_across_ranks is added in torch1.9.0 so we should check
# wether _verify_model_across_ranks is the member of torch.distributed
# before mocking
if hasattr(torch.distributed, '_verify_model_across_ranks'):
torch.distributed._verify_model_across_ranks = mock

model = Model()
assert not is_module_wrapper(model)

Expand Down

0 comments on commit 4a9f834

Please sign in to comment.