Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add mmrazor support #220

Merged
merged 10 commits into from
Apr 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion mmdeploy/codebase/base/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import torch
from torch.utils.data import DataLoader, Dataset

from mmdeploy.utils import get_backend_config, get_codebase, get_root_logger
from mmdeploy.utils import (get_backend_config, get_codebase,
get_codebase_config, get_root_logger)
from mmdeploy.utils.dataset import is_can_sort_dataset, sort_dataset


Expand Down Expand Up @@ -284,3 +285,21 @@ def get_model_name(self) -> str:
str: the name of the model.
"""
pass

@property
def from_mmrazor(self) -> bool:
"""Whether the codebase from mmrazor.

Returns:
bool: From mmrazor or not.

Raises:
TypeError: An error when type of `from_mmrazor` is not boolean.
"""
codebase_config = get_codebase_config(self.deploy_cfg)
from_mmrazor = codebase_config.get('from_mmrazor', False)
if not isinstance(from_mmrazor, bool):
raise TypeError('`from_mmrazor` attribute must be boolean type! '
f'but got: {from_mmrazor}')

return from_mmrazor
6 changes: 5 additions & 1 deletion mmdeploy/codebase/mmcls/deploy/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,11 @@ def init_pytorch_model(self,
nn.Module: An initialized torch model generated by OpenMMLab
codebases.
"""
from mmcls.apis import init_model
if self.from_mmrazor:
from mmrazor.apis import init_mmcls_model as init_model
else:
from mmcls.apis import init_model

model = init_model(self.model_cfg, model_checkpoint, self.device,
cfg_options)

Expand Down
6 changes: 5 additions & 1 deletion mmdeploy/codebase/mmdet/deploy/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ def init_pytorch_model(self,
nn.Module: An initialized torch model generated by other OpenMMLab
codebases.
"""
from mmdet.apis import init_detector
if self.from_mmrazor:
from mmrazor.apis import init_mmdet_model as init_detector
else:
from mmdet.apis import init_detector

model = init_detector(self.model_cfg, model_checkpoint, self.device,
cfg_options)
return model.eval()
Expand Down
6 changes: 5 additions & 1 deletion mmdeploy/codebase/mmseg/deploy/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,11 @@ def init_pytorch_model(self,
codebases.
"""
from mmcv.cnn.utils import revert_sync_batchnorm
from mmseg.apis import init_segmentor
if self.from_mmrazor:
from mmrazor.apis import init_mmseg_model as init_segmentor
else:
from mmseg.apis import init_segmentor

model = init_segmentor(self.model_cfg, model_checkpoint, self.device)
model = revert_sync_batchnorm(model)
return model.eval()
Expand Down
1 change: 1 addition & 0 deletions requirements/optional.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mmdet>=2.19.0,<=2.20.0
mmedit
mmocr>=0.3.0,<=0.4.1
mmpose>=0.24.0
mmrazor>=0.3.0
mmsegmentation
onnxruntime>=1.8.0
openvino-dev
31 changes: 31 additions & 0 deletions tests/test_codebase/test_mmcls/data/mmrazor_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright (c) OpenMMLab. All rights reserved.
_base_ = 'model.py'

norm_cfg = dict(type='BN')

mutator = dict(
type='OneShotMutator',
placeholder_mapping=dict(
all_blocks=dict(
type='OneShotOP',
choices=dict(
shuffle_3x3=dict(
type='ShuffleBlock', kernel_size=3, norm_cfg=norm_cfg),
shuffle_5x5=dict(
type='ShuffleBlock', kernel_size=5, norm_cfg=norm_cfg),
shuffle_7x7=dict(
type='ShuffleBlock', kernel_size=7, norm_cfg=norm_cfg),
shuffle_xception=dict(
type='ShuffleXception', norm_cfg=norm_cfg),
))))

algorithm = dict(
type='SPOS',
architecture=dict(
type='MMClsArchitecture',
model={{_base_.model}},
),
mutator=mutator,
distiller=None,
mutable_cfg='tests/test_codebase/test_mmcls/data/mmrazor_mutable_cfg.yaml',
retraining=True)
60 changes: 60 additions & 0 deletions tests/test_codebase/test_mmcls/data/mmrazor_mutable_cfg.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
stage_0_block_0:
chosen:
- shuffle_7x7
stage_0_block_1:
chosen:
- shuffle_5x5
stage_0_block_2:
chosen:
- shuffle_3x3
stage_0_block_3:
chosen:
- shuffle_5x5
stage_1_block_0:
chosen:
- shuffle_7x7
stage_1_block_1:
chosen:
- shuffle_3x3
stage_1_block_2:
chosen:
- shuffle_7x7
stage_1_block_3:
chosen:
- shuffle_3x3
stage_2_block_0:
chosen:
- shuffle_7x7
stage_2_block_1:
chosen:
- shuffle_3x3
stage_2_block_2:
chosen:
- shuffle_7x7
stage_2_block_3:
chosen:
- shuffle_xception
stage_2_block_4:
chosen:
- shuffle_3x3
stage_2_block_5:
chosen:
- shuffle_3x3
stage_2_block_6:
chosen:
- shuffle_3x3
stage_2_block_7:
chosen:
- shuffle_3x3
stage_3_block_0:
chosen:
- shuffle_xception
stage_3_block_1:
chosen:
- shuffle_7x7
stage_3_block_2:
chosen:
- shuffle_xception
stage_3_block_3:
chosen:
- shuffle_xception
30 changes: 28 additions & 2 deletions tests/test_codebase/test_mmcls/test_classification.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import Any

import mmcv
import numpy as np
Expand Down Expand Up @@ -37,9 +39,33 @@
img = np.random.rand(*img_shape, 3)


def test_init_pytorch_model():
@pytest.mark.parametrize('from_mmrazor', [True, False, '123', 0])
def test_init_pytorch_model(from_mmrazor: Any):
from mmcls.models.classifiers.base import BaseClassifier
model = task_processor.init_pytorch_model(None)
if from_mmrazor is False:
_task_processor = task_processor
else:
_model_cfg_path = 'tests/test_codebase/test_mmcls/data/' \
'mmrazor_model.py'
_model_cfg = load_config(_model_cfg_path)[0]
_model_cfg.algorithm.architecture.model.type = 'mmcls.ImageClassifier'
_model_cfg.algorithm.architecture.model.backbone = dict(
type='SearchableShuffleNetV2', widen_factor=1.0)
_deploy_cfg = copy.deepcopy(deploy_cfg)
_deploy_cfg.codebase_config['from_mmrazor'] = from_mmrazor
_task_processor = build_task_processor(_model_cfg, _deploy_cfg, 'cpu')

if not isinstance(from_mmrazor, bool):
with pytest.raises(
TypeError,
match='`from_mmrazor` attribute must be '
'boolean type! '
f'but got: {from_mmrazor}'):
_ = _task_processor.from_mmrazor
return
assert from_mmrazor == _task_processor.from_mmrazor

model = _task_processor.init_pytorch_model(None)
assert isinstance(model, BaseClassifier)


Expand Down
34 changes: 34 additions & 0 deletions tests/test_codebase/test_mmdet/data/mmrazor_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (c) OpenMMLab. All rights reserved.
_base_ = 'model.py'

norm_cfg = dict(type='BN', requires_grad=True)
mutator = dict(
type='OneShotMutator',
placeholder_mapping=dict(
all_blocks=dict(
type='OneShotOP',
choices=dict(
shuffle_3x3=dict(
type='ShuffleBlock', norm_cfg=norm_cfg, kernel_size=3),
shuffle_5x5=dict(
type='ShuffleBlock', norm_cfg=norm_cfg, kernel_size=5),
shuffle_7x7=dict(
type='ShuffleBlock', norm_cfg=norm_cfg, kernel_size=7),
shuffle_xception=dict(
type='ShuffleXception',
norm_cfg=norm_cfg,
),
))))

algorithm = dict(
type='DetNAS',
architecture=dict(
type='MMDetArchitecture',
model={{_base_.model}},
),
mutator=mutator,
pruner=None,
distiller=None,
retraining=True,
mutable_cfg='tests/test_codebase/test_mmdet/data/mmrazor_mutable_cfg.yaml',
)
60 changes: 60 additions & 0 deletions tests/test_codebase/test_mmdet/data/mmrazor_mutable_cfg.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
stage_0_block_0:
chosen:
- shuffle_7x7
stage_0_block_1:
chosen:
- shuffle_5x5
stage_0_block_2:
chosen:
- shuffle_7x7
stage_0_block_3:
chosen:
- shuffle_3x3
stage_1_block_0:
chosen:
- shuffle_7x7
stage_1_block_1:
chosen:
- shuffle_5x5
stage_1_block_2:
chosen:
- shuffle_5x5
stage_1_block_3:
chosen:
- shuffle_7x7
stage_2_block_0:
chosen:
- shuffle_xception
stage_2_block_1:
chosen:
- shuffle_xception
stage_2_block_2:
chosen:
- shuffle_5x5
stage_2_block_3:
chosen:
- shuffle_xception
stage_2_block_4:
chosen:
- shuffle_3x3
stage_2_block_5:
chosen:
- shuffle_3x3
stage_2_block_6:
chosen:
- shuffle_xception
stage_2_block_7:
chosen:
- shuffle_5x5
stage_3_block_0:
chosen:
- shuffle_xception
stage_3_block_1:
chosen:
- shuffle_5x5
stage_3_block_2:
chosen:
- shuffle_xception
stage_3_block_3:
chosen:
- shuffle_7x7
30 changes: 28 additions & 2 deletions tests/test_codebase/test_mmdet/test_object_detection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import Any

import mmcv
import numpy as np
Expand Down Expand Up @@ -48,9 +50,33 @@
img = np.random.rand(*img_shape, 3)


def test_init_pytorch_model():
@pytest.mark.parametrize('from_mmrazor', [True, False, '123', 0])
def test_init_pytorch_model(from_mmrazor: Any):
from mmdet.models import BaseDetector
model = task_processor.init_pytorch_model(None)
if from_mmrazor is False:
_task_processor = task_processor
else:
_model_cfg_path = 'tests/test_codebase/test_mmdet/data/' \
'mmrazor_model.py'
_model_cfg = load_config(_model_cfg_path)[0]
_model_cfg.algorithm.architecture.model.type = 'mmdet.YOLOV3'
_model_cfg.algorithm.architecture.model.backbone.type = \
'mmcls.SearchableShuffleNetV2'
_deploy_cfg = copy.deepcopy(deploy_cfg)
_deploy_cfg.codebase_config['from_mmrazor'] = from_mmrazor
_task_processor = build_task_processor(_model_cfg, _deploy_cfg, 'cpu')

if not isinstance(from_mmrazor, bool):
with pytest.raises(
TypeError,
match='`from_mmrazor` attribute must be '
'boolean type! '
f'but got: {from_mmrazor}'):
_ = _task_processor.from_mmrazor
return
assert from_mmrazor == _task_processor.from_mmrazor

model = _task_processor.init_pytorch_model(None)
assert isinstance(model, BaseDetector)


Expand Down
28 changes: 28 additions & 0 deletions tests/test_codebase/test_mmseg/data/mmrazor_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright (c) OpenMMLab. All rights reserved.
_base_ = 'model.py'

# algorithm setting
algorithm = dict(
type='GeneralDistill',
architecture=dict(
type='MMSegArchitecture',
model={{_base_.model}},
),
distiller=dict(
type='SingleTeacherDistiller',
teacher={{_base_.model}},
teacher_trainable=False,
components=[
dict(
student_module='decode_head.conv_seg',
teacher_module='decode_head.conv_seg',
losses=[
dict(
type='ChannelWiseDivergence',
name='loss_cwd_logits',
tau=1,
loss_weight=5,
)
])
]),
)
Loading