Skip to content

Commit

Permalink
Revert pretrained argument, and remove it's deprecation sign.
Browse files Browse the repository at this point in the history
  • Loading branch information
mzr1996 committed Jan 28, 2022
1 parent d65301a commit bd9bbce
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
3 changes: 3 additions & 0 deletions mmcls/models/classifiers/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@ def __init__(self,
backbone,
neck=None,
head=None,
pretrained=None,
train_cfg=None,
init_cfg=None):
super(ImageClassifier, self).__init__(init_cfg)

if pretrained is not None:
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
self.backbone = build_backbone(backbone)

if neck is not None:
Expand Down
9 changes: 9 additions & 0 deletions tests/test_models/test_classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from copy import deepcopy

import numpy as np
import pytest
import torch
from mmcv import ConfigDict

Expand Down Expand Up @@ -84,6 +85,14 @@ def test_image_classifier():
assert soft_pred.shape == (16, 10)
torch.testing.assert_allclose(soft_pred, torch.softmax(pred, dim=1))

# test pretrained
with pytest.warns(UserWarning):
model_cfg_ = deepcopy(model_cfg)
model_cfg_['pretrained'] = 'checkpoint'
model = CLASSIFIERS.build(model_cfg_)
assert model.init_cfg == dict(
type='Pretrained', checkpoint='checkpoint')

# test show_result
img = np.random.randint(0, 256, (224, 224, 3)).astype(np.uint8)
result = dict(pred_class='cat', pred_label=0, pred_score=0.9)
Expand Down

0 comments on commit bd9bbce

Please sign in to comment.