Skip to content

Commit

Permalink
fix shufflenetv2 with trt (#645)
Browse files Browse the repository at this point in the history
* fix shufflenetv2 and pspnet

* fix ci

* remove print
  • Loading branch information
RunningLeon authored Jun 27, 2022
1 parent ae47e9d commit fa034e0
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 23 deletions.
8 changes: 6 additions & 2 deletions mmdeploy/apis/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from mmcv.parallel import MMDataParallel

from mmdeploy.core import patch_model
from mmdeploy.utils import cfg_apply_marks, load_config
from mmdeploy.utils import (IR, cfg_apply_marks, get_backend, get_ir_config,
load_config)
from .core import PIPELINE_MANAGER, no_mp
from .utils import create_calib_input_data as create_calib_input_data_impl

Expand Down Expand Up @@ -61,7 +62,10 @@ def create_calib_input_data(calib_file: str,
dataset = task_processor.build_dataset(dataset_cfg, dataset_type)

# patch model
patched_model = patch_model(model, cfg=deploy_cfg)
backend = get_backend(deploy_cfg)
ir = IR.get(get_ir_config(deploy_cfg)['type'])
patched_model = patch_model(
model, cfg=deploy_cfg, backend=backend, ir=ir)

dataloader = task_processor.build_dataloader(
dataset, 1, 1, dist=False, shuffle=False)
Expand Down
7 changes: 4 additions & 3 deletions mmdeploy/apis/onnx/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from mmdeploy.apis.core import PIPELINE_MANAGER
from mmdeploy.core import RewriterContext, patch_model
from mmdeploy.utils import Backend, get_root_logger
from mmdeploy.utils import IR, Backend, get_ir_config, get_root_logger
from .optimizer import * # noqa
from .passes import optimize_onnx

Expand Down Expand Up @@ -91,20 +91,21 @@ def _add_or_update(cfg: dict, key: str, val: Any):
verbose=verbose,
keep_initializers_as_inputs=keep_initializers_as_inputs)
_add_or_update(deploy_cfg, 'ir_config', ir_config)

ir = IR.get(get_ir_config(deploy_cfg)['type'])
if isinstance(backend, Backend):
backend = backend.value
backend_config = dict(type=backend)
_add_or_update(deploy_cfg, 'backend_config', backend_config)

context_info['cfg'] = deploy_cfg
context_info['ir'] = ir
if 'backend' not in context_info:
context_info['backend'] = backend
if 'opset' not in context_info:
context_info['opset'] = opset_version

# patch model
patched_model = patch_model(model, cfg=deploy_cfg, backend=backend)
patched_model = patch_model(model, cfg=deploy_cfg, backend=backend, ir=ir)

if 'onnx_custom_passes' not in context_info:
onnx_custom_passes = optimize_onnx if optimize else None
Expand Down
5 changes: 3 additions & 2 deletions mmdeploy/apis/torch_jit/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from packaging.version import parse as version_parse

from mmdeploy.core import RewriterContext, patch_model
from mmdeploy.utils import IR, Backend, get_root_logger
from mmdeploy.utils import IR, Backend, get_ir_config, get_root_logger
from ..core import PIPELINE_MANAGER


Expand Down Expand Up @@ -87,7 +87,8 @@ def _add_or_update(cfg: dict, key: str, val: Any):

# patch model
if isinstance(func, torch.nn.Module):
func = patch_model(func, cfg=deploy_cfg, backend=backend)
ir = IR.get(get_ir_config(deploy_cfg)['type'])
func = patch_model(func, cfg=deploy_cfg, backend=backend, ir=ir)

with RewriterContext(**context_info), torch.no_grad():
# for exporting models with weight that depends on inputs
Expand Down
4 changes: 2 additions & 2 deletions mmdeploy/codebase/mmcls/models/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .shufflenet_v2 import shufflenetv2_backbone__forward__ncnn
from .shufflenet_v2 import shufflenetv2_backbone__forward__default
from .vision_transformer import visiontransformer__forward__ncnn

__all__ = [
'shufflenetv2_backbone__forward__ncnn',
'shufflenetv2_backbone__forward__default',
'visiontransformer__forward__ncnn',
]
17 changes: 5 additions & 12 deletions mmdeploy/codebase/mmcls/models/backbones/shufflenet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,16 @@
import torch

from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import Backend


# torch.chunk will export dynamic shape slice, which will lead integer input
# on ncnn backend. So the model needs to rewrite.
@FUNCTION_REWRITER.register_rewriter(
func_name='mmcls.models.backbones.shufflenet_v2.InvertedResidual.forward',
backend=Backend.NCNN.value)
@FUNCTION_REWRITER.register_rewriter(
func_name='mmcls.models.backbones.shufflenet_v2.InvertedResidual.forward',
backend=Backend.TORCHSCRIPT.value)
def shufflenetv2_backbone__forward__ncnn(ctx, self, x):
"""Rewrite `forward` of InvertedResidual used in shufflenet_v2 for ncnn
backend.
func_name='mmcls.models.backbones.shufflenet_v2.InvertedResidual.forward')
def shufflenetv2_backbone__forward__default(ctx, self, x):
"""Rewrite `forward` of InvertedResidual used in shufflenet_v2.
The chunk in original InvertedResidual.forward will convert to dynamic
`Slice` operator in ONNX, which will raise error in ncnn.
`Slice` operator in ONNX, which will raise error in ncnn, torchscript
and tensorrt. Here we replace `chunk` with `split`.
Args:
ctx (ContextCaller): The context with additional information.
Expand Down
4 changes: 2 additions & 2 deletions mmdeploy/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ def load_config(*args) -> List[mmcv.Config]:
args (str | Sequence[str]): The path to the config file(s).
Returns:
List[mmcv.Config]: The content of config.
List[mmcv.Config | dict]: The content of config.
"""

def _load_config(cfg):
if isinstance(cfg, str):
cfg = mmcv.Config.fromfile(cfg)
if not isinstance(cfg, (mmcv.Config, mmcv.ConfigDict)):
if not isinstance(cfg, (mmcv.Config, mmcv.ConfigDict, dict)):
raise TypeError('deploy_cfg must be a filename or Config object, '
f'but got {type(cfg)}')
return cfg
Expand Down

0 comments on commit fa034e0

Please sign in to comment.