From f607f1965b3e16258cf21cbb77bc1d6e4315b2e5 Mon Sep 17 00:00:00 2001 From: AllentDan <41138331+AllentDan@users.noreply.github.com> Date: Tue, 3 Aug 2021 17:12:44 +0800 Subject: [PATCH] [Feature] Align datasets (#29) * add test tool and re-orgnize apis.utils * handle topk and refine codes * add cls export and test support * fix lint * move ort into wrapper * resolve conflicts * resolve comments * resolve conflicts * resolve comments and padding mrcnn * resolve comments --- configs/mmcls/mmcls_tensorrt.py | 2 +- mmdeploy/apis/__init__.py | 10 +- mmdeploy/apis/test.py | 142 +++++++++ mmdeploy/apis/utils.py | 278 +++++++++--------- mmdeploy/mmcls/export/model_wrappers.py | 2 +- mmdeploy/mmcls/models/__init__.py | 1 + mmdeploy/mmcls/models/heads/__init__.py | 13 + mmdeploy/mmcls/models/heads/cls_head.py | 13 + mmdeploy/mmcls/models/heads/linear_head.py | 14 + .../mmcls/models/heads/multi_label_head.py | 12 + .../models/heads/multi_label_linear_head.py | 14 + mmdeploy/mmcls/models/heads/stacked_head.py | 16 + .../models/heads/vision_transformer_head.py | 14 + mmdeploy/mmdet/export/__init__.py | 4 +- mmdeploy/mmdet/export/model_wrappers.py | 2 +- mmdeploy/mmdet/export/tensorrt_helper.py | 18 ++ .../mmdet/models/dense_heads/anchor_head.py | 11 +- .../mmdet/models/dense_heads/fcos_head.py | 13 +- mmdeploy/mmdet/models/dense_heads/rpn_head.py | 27 +- tools/deploy.py | 7 +- tools/test.py | 101 +++++++ 21 files changed, 535 insertions(+), 179 deletions(-) create mode 100644 mmdeploy/apis/test.py create mode 100644 mmdeploy/mmcls/models/heads/__init__.py create mode 100644 mmdeploy/mmcls/models/heads/cls_head.py create mode 100644 mmdeploy/mmcls/models/heads/linear_head.py create mode 100644 mmdeploy/mmcls/models/heads/multi_label_head.py create mode 100644 mmdeploy/mmcls/models/heads/multi_label_linear_head.py create mode 100644 mmdeploy/mmcls/models/heads/stacked_head.py create mode 100644 mmdeploy/mmcls/models/heads/vision_transformer_head.py create mode 100644 mmdeploy/mmdet/export/tensorrt_helper.py create mode 100644 tools/test.py diff --git a/configs/mmcls/mmcls_tensorrt.py b/configs/mmcls/mmcls_tensorrt.py index b2dfd6c814..b2cb5ca4f0 100644 --- a/configs/mmcls/mmcls_tensorrt.py +++ b/configs/mmcls/mmcls_tensorrt.py @@ -3,6 +3,6 @@ dict( save_file='end2end.engine', opt_shape_dict=dict( - input=[[1, 3, 224, 224], [4, 3, 224, 224], [32, 3, 224, 224]]), + input=[[1, 3, 224, 224], [4, 3, 224, 224], [64, 3, 224, 224]]), max_workspace_size=1 << 30) ]) diff --git a/mmdeploy/apis/__init__.py b/mmdeploy/apis/__init__.py index 699703c40f..586a7b22fe 100644 --- a/mmdeploy/apis/__init__.py +++ b/mmdeploy/apis/__init__.py @@ -1,5 +1,13 @@ from .extract_model import extract_model from .inference import inference_model from .pytorch2onnx import torch2onnx, torch2onnx_impl +from .test import post_process_outputs, prepare_data_loader, single_gpu_test +from .utils import (assert_cfg_valid, assert_module_exist, + get_classes_from_config, init_backend_model) -__all__ = ['torch2onnx_impl', 'torch2onnx', 'extract_model', 'inference_model'] +__all__ = [ + 'torch2onnx_impl', 'torch2onnx', 'extract_model', 'inference_model', + 'prepare_data_loader', 'assert_module_exist', 'assert_cfg_valid', + 'init_backend_model', 'get_classes_from_config', 'single_gpu_test', + 'post_process_outputs' +] diff --git a/mmdeploy/apis/test.py b/mmdeploy/apis/test.py new file mode 100644 index 0000000000..8350900d8e --- /dev/null +++ b/mmdeploy/apis/test.py @@ -0,0 +1,142 @@ +import warnings +from typing import Any, Union + +import mmcv +import numpy as np +from torch import nn +from torch.utils.data import DataLoader + +from mmdeploy.apis.utils import assert_module_exist + + +def prepare_data_loader(codebase: str, model_cfg: Union[str, mmcv.Config]): + # load model_cfg if necessary + if isinstance(model_cfg, str): + model_cfg = mmcv.Config.fromfile(model_cfg) + + if codebase == 'mmcls': + from mmcls.datasets import (build_dataloader, build_dataset) + assert_module_exist(codebase) + # build dataset and dataloader + dataset = build_dataset(model_cfg.data.test) + data_loader = build_dataloader( + dataset, + samples_per_gpu=model_cfg.data.samples_per_gpu, + workers_per_gpu=model_cfg.data.workers_per_gpu, + shuffle=False, + round_up=False) + + elif codebase == 'mmdet': + assert_module_exist(codebase) + from mmdet.datasets import (build_dataloader, build_dataset, + replace_ImageToTensor) + # in case the test dataset is concatenated + samples_per_gpu = 1 + if isinstance(model_cfg.data.test, dict): + model_cfg.data.test.test_mode = True + samples_per_gpu = model_cfg.data.test.pop('samples_per_gpu', 1) + if samples_per_gpu > 1: + # Replace 'ImageToTensor' to 'DefaultFormatBundle' + model_cfg.data.test.pipeline = replace_ImageToTensor( + model_cfg.data.test.pipeline) + elif isinstance(model_cfg.data.test, list): + for ds_cfg in model_cfg.data.test: + ds_cfg.test_mode = True + samples_per_gpu = max([ + ds_cfg.pop('samples_per_gpu', 1) + for ds_cfg in model_cfg.data.test + ]) + if samples_per_gpu > 1: + for ds_cfg in model_cfg.data.test: + ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline) + + # build the dataloader + dataset = build_dataset(model_cfg.data.test) + data_loader = build_dataloader( + dataset, + samples_per_gpu=samples_per_gpu, + workers_per_gpu=model_cfg.data.workers_per_gpu, + dist=False, + shuffle=False) + + else: + raise NotImplementedError(f'Unknown codebase type: {codebase}') + + return dataset, data_loader + + +def single_gpu_test(codebase: str, + model: nn.Module, + data_loader: DataLoader, + show: bool = False, + out_dir: Any = None, + show_score_thr: float = 0.3): + if codebase == 'mmcls': + assert_module_exist(codebase) + from mmcls.apis import single_gpu_test + outputs = single_gpu_test(model, data_loader, show, out_dir) + elif codebase == 'mmdet': + assert_module_exist(codebase) + from mmdet.apis import single_gpu_test + outputs = single_gpu_test(model, data_loader, show, out_dir, + show_score_thr) + + else: + raise NotImplementedError(f'Unknown codebase type: {codebase}') + return outputs + + +def post_process_outputs(outputs, + dataset, + model_cfg: mmcv.Config, + codebase: str, + metrics: str = None, + out: str = None, + metric_options: dict = None, + format_only: bool = False): + if codebase == 'mmcls': + if metrics: + results = dataset.evaluate(outputs, metrics, metric_options) + for k, v in results.items(): + print(f'\n{k} : {v:.2f}') + else: + warnings.warn('Evaluation metrics are not specified.') + scores = np.vstack(outputs) + pred_score = np.max(scores, axis=1) + pred_label = np.argmax(scores, axis=1) + pred_class = [dataset.CLASSES[lb] for lb in pred_label] + results = { + 'pred_score': pred_score, + 'pred_label': pred_label, + 'pred_class': pred_class + } + if not out: + print('\nthe predicted result for the first element is ' + f'pred_score = {pred_score[0]:.2f}, ' + f'pred_label = {pred_label[0]} ' + f'and pred_class = {pred_class[0]}. ' + 'Specify --out to save all results to files.') + if out: + print(f'\nwriting results to {out}') + mmcv.dump(results, out) + + elif codebase == 'mmdet': + if out: + print(f'\nwriting results to {out}') + mmcv.dump(outputs, out) + kwargs = {} if metric_options is None else metric_options + if format_only: + dataset.format_results(outputs, **kwargs) + if metrics: + eval_kwargs = model_cfg.get('evaluation', {}).copy() + # hard-code way to remove EvalHook args + for key in [ + 'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best', + 'rule' + ]: + eval_kwargs.pop(key, None) + eval_kwargs.update(dict(metric=metrics, **kwargs)) + print(dataset.evaluate(outputs, **eval_kwargs)) + + else: + raise NotImplementedError(f'Unknown codebase type: {codebase}') diff --git a/mmdeploy/apis/utils.py b/mmdeploy/apis/utils.py index 87860290d3..22f7ad0572 100644 --- a/mmdeploy/apis/utils.py +++ b/mmdeploy/apis/utils.py @@ -6,8 +6,24 @@ import torch -def module_exist(module_name: str): - return importlib.util.find_spec(module_name) is not None +def assert_cfg_valid(cfg: Union[str, mmcv.Config, mmcv.ConfigDict], *args): + """Check config validation.""" + + def _assert_cfg_valid_(cfg): + if isinstance(cfg, str): + cfg = mmcv.Config.fromfile(cfg) + if not isinstance(cfg, (mmcv.Config, mmcv.ConfigDict)): + raise TypeError('deploy_cfg must be a filename or Config object, ' + f'but got {type(cfg)}') + + _assert_cfg_valid_(cfg) + for cfg in args: + _assert_cfg_valid_(cfg) + + +def assert_module_exist(module_name: str): + if importlib.util.find_spec(module_name) is None: + raise ImportError(f'Can not import module: {module_name}') def init_model(codebase: str, @@ -17,25 +33,20 @@ def init_model(codebase: str, cfg_options: Optional[Dict] = None): # mmcls if codebase == 'mmcls': - if module_exist(codebase): - from mmcls.apis import init_model - model = init_model(model_cfg, model_checkpoint, device, - cfg_options) - else: - raise ImportError(f'Can not import module: {codebase}') + assert_module_exist(codebase) + from mmcls.apis import init_model + model = init_model(model_cfg, model_checkpoint, device, cfg_options) + elif codebase == 'mmdet': - if module_exist(codebase): - from mmdet.apis import init_detector - model = init_detector(model_cfg, model_checkpoint, device, - cfg_options) - else: - raise ImportError(f'Can not import module: {codebase}') + assert_module_exist(codebase) + from mmdet.apis import init_detector + model = init_detector(model_cfg, model_checkpoint, device, cfg_options) + elif codebase == 'mmseg': - if module_exist(codebase): - from mmseg.apis import init_segmentor - model = init_segmentor(model_cfg, model_checkpoint, device) - else: - raise ImportError(f'Can not import module: {codebase}') + assert_module_exist(codebase) + from mmseg.apis import init_segmentor + model = init_segmentor(model_cfg, model_checkpoint, device) + else: raise NotImplementedError(f'Unknown codebase type: {codebase}') @@ -54,17 +65,15 @@ def create_input(codebase: str, cfg = model_cfg.copy() if codebase == 'mmcls': - if module_exist(codebase): - from mmdeploy.mmcls.export import create_input - return create_input(cfg, imgs, device) - else: - raise ImportError(f'Can not import module: {codebase}') + assert_module_exist(codebase) + from mmdeploy.mmcls.export import create_input + return create_input(cfg, imgs, device) + elif codebase == 'mmdet': - if module_exist(codebase): - from mmdeploy.mmdet.export import create_input - return create_input(cfg, imgs, device) - else: - raise ImportError(f'Can not import module: {codebase}') + assert_module_exist(codebase) + from mmdeploy.mmdet.export import create_input + return create_input(cfg, imgs, device) + else: raise NotImplementedError(f'Unknown codebase type: {codebase}') @@ -86,45 +95,33 @@ def init_backend_model(model_files: Sequence[str], class_names: Sequence[str], device_id: int = 0): if codebase == 'mmcls': - if module_exist(codebase): - if backend == 'onnxruntime': - from mmdeploy.mmcls.export import ONNXRuntimeClassifier - backend_model = ONNXRuntimeClassifier( - model_files[0], - class_names=class_names, - device_id=device_id) - elif backend == 'tensorrt': - from mmdeploy.mmcls.export import TensorRTClassifier - backend_model = TensorRTClassifier( - model_files[0], - class_names=class_names, - device_id=device_id) - else: - raise NotImplementedError( - f'Unsupported backend type: {backend}') - return backend_model + assert_module_exist(codebase) + if backend == 'onnxruntime': + from mmdeploy.mmcls.export import ONNXRuntimeClassifier + backend_model = ONNXRuntimeClassifier( + model_files[0], class_names=class_names, device_id=device_id) + elif backend == 'tensorrt': + from mmdeploy.mmcls.export import TensorRTClassifier + backend_model = TensorRTClassifier( + model_files[0], class_names=class_names, device_id=device_id) else: - raise ImportError(f'Can not import module: {codebase}') + raise NotImplementedError(f'Unsupported backend type: {backend}') + return backend_model + elif codebase == 'mmdet': - if module_exist(codebase): - if backend == 'onnxruntime': - from mmdeploy.mmdet.export import ONNXRuntimeDetector - backend_model = ONNXRuntimeDetector( - model_files[0], - class_names=class_names, - device_id=device_id) - elif backend == 'tensorrt': - from mmdeploy.mmdet.export import TensorRTDetector - backend_model = TensorRTDetector( - model_files[0], - class_names=class_names, - device_id=device_id) - else: - raise NotImplementedError( - f'Unsupported backend type: {backend}') - return backend_model + assert_module_exist(codebase) + if backend == 'onnxruntime': + from mmdeploy.mmdet.export import ONNXRuntimeDetector + backend_model = ONNXRuntimeDetector( + model_files[0], class_names=class_names, device_id=device_id) + elif backend == 'tensorrt': + from mmdeploy.mmdet.export import TensorRTDetector + backend_model = TensorRTDetector( + model_files[0], class_names=class_names, device_id=device_id) else: - raise ImportError(f'Can not import module: {codebase}') + raise NotImplementedError(f'Unsupported backend type: {backend}') + return backend_model + else: raise NotImplementedError(f'Unknown codebase type: {codebase}') @@ -132,56 +129,51 @@ def init_backend_model(model_files: Sequence[str], def get_classes_from_config(codebase: str, model_cfg: Union[str, mmcv.Config]): model_cfg_str = model_cfg if codebase == 'mmcls': - if module_exist(codebase): - if isinstance(model_cfg, str): - model_cfg = mmcv.Config.fromfile(model_cfg) - elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)): - raise TypeError('config must be a filename or Config object, ' - f'but got {type(model_cfg)}') - - from mmcls.datasets import DATASETS - module_dict = DATASETS.module_dict - data_cfg = model_cfg.data - - if 'train' in data_cfg: - module = module_dict[data_cfg.train.type] - elif 'val' in data_cfg: - module = module_dict[data_cfg.val.type] - elif 'test' in data_cfg: - module = module_dict[data_cfg.test.type] - else: - raise RuntimeError( - f'No dataset config found in: {model_cfg_str}') - - return module.CLASSES + assert_module_exist(codebase) + if isinstance(model_cfg, str): + model_cfg = mmcv.Config.fromfile(model_cfg) + elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)): + raise TypeError('config must be a filename or Config object, ' + f'but got {type(model_cfg)}') + + from mmcls.datasets import DATASETS + module_dict = DATASETS.module_dict + data_cfg = model_cfg.data + + if 'train' in data_cfg: + module = module_dict[data_cfg.train.type] + elif 'val' in data_cfg: + module = module_dict[data_cfg.val.type] + elif 'test' in data_cfg: + module = module_dict[data_cfg.test.type] else: - raise ImportError(f'Can not import module: {codebase}') + raise RuntimeError(f'No dataset config found in: {model_cfg_str}') + + return module.CLASSES if codebase == 'mmdet': - if module_exist(codebase): - if isinstance(model_cfg, str): - model_cfg = mmcv.Config.fromfile(model_cfg) - elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)): - raise TypeError('config must be a filename or Config object, ' - f'but got {type(model_cfg)}') - - from mmdet.datasets import DATASETS - module_dict = DATASETS.module_dict - data_cfg = model_cfg.data - - if 'train' in data_cfg: - module = module_dict[data_cfg.train.type] - elif 'val' in data_cfg: - module = module_dict[data_cfg.val.type] - elif 'test' in data_cfg: - module = module_dict[data_cfg.test.type] - else: - raise RuntimeError( - f'No dataset config found in: {model_cfg_str}') - - return module.CLASSES + assert_module_exist(codebase) + if isinstance(model_cfg, str): + model_cfg = mmcv.Config.fromfile(model_cfg) + elif not isinstance(model_cfg, (mmcv.Config, mmcv.ConfigDict)): + raise TypeError('config must be a filename or Config object, ' + f'but got {type(model_cfg)}') + + from mmdet.datasets import DATASETS + module_dict = DATASETS.module_dict + data_cfg = model_cfg.data + + if 'train' in data_cfg: + module = module_dict[data_cfg.train.type] + elif 'val' in data_cfg: + module = module_dict[data_cfg.val.type] + elif 'test' in data_cfg: + module = module_dict[data_cfg.test.type] else: - raise ImportError(f'Can not import module: {codebase}') + raise RuntimeError(f'No dataset config found in: {model_cfg_str}') + + return module.CLASSES + else: raise NotImplementedError(f'Unknown codebase type: {codebase}') @@ -195,41 +187,37 @@ def check_model_outputs(codebase: str, show_result=False): show_img = mmcv.imread(image) if isinstance(image, str) else image if codebase == 'mmcls': - if module_exist(codebase): - output_file = None if show_result else output_file - with torch.no_grad(): - scores = model(**model_inputs, return_loss=False)[0] - pred_score = np.max(scores, axis=0) - pred_label = np.argmax(scores, axis=0) - result = { - 'pred_label': pred_label, - 'pred_score': float(pred_score) - } - result['pred_class'] = model.CLASSES[result['pred_label']] - model.show_result( - show_img, - result, - show=True, - win_name=backend, - out_file=output_file) - else: - raise ImportError(f'Can not import module: {codebase}') + assert_module_exist(codebase) + output_file = None if show_result else output_file + with torch.no_grad(): + scores = model(**model_inputs, return_loss=False)[0] + pred_score = np.max(scores, axis=0) + pred_label = np.argmax(scores, axis=0) + result = { + 'pred_label': pred_label, + 'pred_score': float(pred_score) + } + result['pred_class'] = model.CLASSES[result['pred_label']] + model.show_result( + show_img, + result, + show=True, + win_name=backend, + out_file=output_file) + elif codebase == 'mmdet': - if module_exist(codebase): - output_file = None if show_result else output_file - score_thr = 0.3 - with torch.no_grad(): - results = model( - **model_inputs, return_loss=False, rescale=True)[0] - model.show_result( - show_img, - results, - score_thr=score_thr, - show=True, - win_name=backend, - out_file=output_file) + assert_module_exist(codebase) + output_file = None if show_result else output_file + score_thr = 0.3 + with torch.no_grad(): + results = model(**model_inputs, return_loss=False, rescale=True)[0] + model.show_result( + show_img, + results, + score_thr=score_thr, + show=True, + win_name=backend, + out_file=output_file) - else: - raise ImportError(f'Can not import module: {codebase}') else: raise NotImplementedError(f'Unknown codebase type: {codebase}') diff --git a/mmdeploy/mmcls/export/model_wrappers.py b/mmdeploy/mmcls/export/model_wrappers.py index 30b8f74542..8ea589d47a 100644 --- a/mmdeploy/mmcls/export/model_wrappers.py +++ b/mmdeploy/mmcls/export/model_wrappers.py @@ -1,7 +1,6 @@ import warnings import numpy as np -import onnxruntime as ort import torch from mmcls.models import BaseClassifier @@ -11,6 +10,7 @@ class ONNXRuntimeClassifier(BaseClassifier): def __init__(self, onnx_file, class_names, device_id): super(ONNXRuntimeClassifier, self).__init__() + import onnxruntime as ort sess = ort.InferenceSession(onnx_file) providers = ['CPUExecutionProvider'] diff --git a/mmdeploy/mmcls/models/__init__.py b/mmdeploy/mmcls/models/__init__.py index e4a4d3b9f7..fcfdc09f08 100644 --- a/mmdeploy/mmcls/models/__init__.py +++ b/mmdeploy/mmcls/models/__init__.py @@ -1 +1,2 @@ from .classifiers import * # noqa: F401,F403 +from .heads import * # noqa: F401,F403 diff --git a/mmdeploy/mmcls/models/heads/__init__.py b/mmdeploy/mmcls/models/heads/__init__.py new file mode 100644 index 0000000000..3aa125738e --- /dev/null +++ b/mmdeploy/mmcls/models/heads/__init__.py @@ -0,0 +1,13 @@ +from .cls_head import simple_test_of_cls_head +from .linear_head import simple_test_of_linear_head +from .multi_label_head import simple_test_of_multi_label_head +from .multi_label_linear_head import simple_test_of_multi_label_linear_head +from .stacked_head import simple_test_of_stacked_head +from .vision_transformer_head import simple_test_of_vision_transformer_head + +__all__ = [ + 'simple_test_of_multi_label_linear_head', + 'simple_test_of_multi_label_head', 'simple_test_of_cls_head', + 'simple_test_of_linear_head', 'simple_test_of_stacked_head', + 'simple_test_of_vision_transformer_head' +] diff --git a/mmdeploy/mmcls/models/heads/cls_head.py b/mmdeploy/mmcls/models/heads/cls_head.py new file mode 100644 index 0000000000..164bfb206e --- /dev/null +++ b/mmdeploy/mmcls/models/heads/cls_head.py @@ -0,0 +1,13 @@ +import torch.nn.functional as F + +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmcls.models.heads.ClsHead.simple_test') +def simple_test_of_cls_head(ctx, self, cls_score, **kwargs): + """Test without augmentation.""" + if isinstance(cls_score, list): + cls_score = sum(cls_score) / float(len(cls_score)) + pred = F.softmax(cls_score, dim=1) if cls_score is not None else None + return pred diff --git a/mmdeploy/mmcls/models/heads/linear_head.py b/mmdeploy/mmcls/models/heads/linear_head.py new file mode 100644 index 0000000000..b52d7f9e96 --- /dev/null +++ b/mmdeploy/mmcls/models/heads/linear_head.py @@ -0,0 +1,14 @@ +import torch.nn.functional as F + +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmcls.models.heads.LinearClsHead.simple_test') +def simple_test_of_linear_head(ctx, self, img, **kwargs): + """Test without augmentation.""" + cls_score = self.fc(img) + if isinstance(cls_score, list): + cls_score = sum(cls_score) / float(len(cls_score)) + pred = F.softmax(cls_score, dim=1) if cls_score is not None else None + return pred diff --git a/mmdeploy/mmcls/models/heads/multi_label_head.py b/mmdeploy/mmcls/models/heads/multi_label_head.py new file mode 100644 index 0000000000..b58160abcd --- /dev/null +++ b/mmdeploy/mmcls/models/heads/multi_label_head.py @@ -0,0 +1,12 @@ +import torch.nn.functional as F + +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmcls.models.heads.MultiLabelClsHead.simple_test') +def simple_test_of_multi_label_head(ctx, self, cls_score, **kwargs): + if isinstance(cls_score, list): + cls_score = sum(cls_score) / float(len(cls_score)) + pred = F.sigmoid(cls_score) if cls_score is not None else None + return pred diff --git a/mmdeploy/mmcls/models/heads/multi_label_linear_head.py b/mmdeploy/mmcls/models/heads/multi_label_linear_head.py new file mode 100644 index 0000000000..1d5e210e3a --- /dev/null +++ b/mmdeploy/mmcls/models/heads/multi_label_linear_head.py @@ -0,0 +1,14 @@ +import torch.nn.functional as F + +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmcls.models.heads.MultiLabelLinearClsHead.simple_test') +def simple_test_of_multi_label_linear_head(ctx, self, img, **kwargs): + """Test without augmentation.""" + cls_score = self.fc(img) + if isinstance(cls_score, list): + cls_score = sum(cls_score) / float(len(cls_score)) + pred = F.sigmoid(cls_score) if cls_score is not None else None + return pred diff --git a/mmdeploy/mmcls/models/heads/stacked_head.py b/mmdeploy/mmcls/models/heads/stacked_head.py new file mode 100644 index 0000000000..993cffdf77 --- /dev/null +++ b/mmdeploy/mmcls/models/heads/stacked_head.py @@ -0,0 +1,16 @@ +import torch.nn.functional as F + +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmcls.models.heads.StackedLinearClsHead.simple_test') +def simple_test_of_stacked_head(ctx, self, img, **kwargs): + """Test without augmentation.""" + cls_score = img + for layer in self.layers: + cls_score = layer(cls_score) + if isinstance(cls_score, list): + cls_score = sum(cls_score) / float(len(cls_score)) + pred = F.softmax(cls_score, dim=1) if cls_score is not None else None + return pred diff --git a/mmdeploy/mmcls/models/heads/vision_transformer_head.py b/mmdeploy/mmcls/models/heads/vision_transformer_head.py new file mode 100644 index 0000000000..1fd66868fc --- /dev/null +++ b/mmdeploy/mmcls/models/heads/vision_transformer_head.py @@ -0,0 +1,14 @@ +import torch.nn.functional as F + +from mmdeploy.core import FUNCTION_REWRITER + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmcls.models.heads.VisionTransformerClsHead.simple_test') +def simple_test_of_vision_transformer_head(ctx, self, img, **kwargs): + """Test without augmentation.""" + cls_score = self.layers(img) + if isinstance(cls_score, list): + cls_score = sum(cls_score) / float(len(cls_score)) + pred = F.softmax(cls_score, dim=1) if cls_score is not None else None + return pred diff --git a/mmdeploy/mmdet/export/__init__.py b/mmdeploy/mmdet/export/__init__.py index 5eed5c7634..70cf93525a 100644 --- a/mmdeploy/mmdet/export/__init__.py +++ b/mmdeploy/mmdet/export/__init__.py @@ -1,7 +1,9 @@ from .model_wrappers import ONNXRuntimeDetector, TensorRTDetector from .onnx_helper import clip_bboxes from .prepare_input import create_input +from .tensorrt_helper import pad_with_value __all__ = [ - 'clip_bboxes', 'TensorRTDetector', 'create_input', 'ONNXRuntimeDetector' + 'clip_bboxes', 'TensorRTDetector', 'create_input', 'ONNXRuntimeDetector', + 'pad_with_value' ] diff --git a/mmdeploy/mmdet/export/model_wrappers.py b/mmdeploy/mmdet/export/model_wrappers.py index 943ccb663e..9723e604ed 100644 --- a/mmdeploy/mmdet/export/model_wrappers.py +++ b/mmdeploy/mmdet/export/model_wrappers.py @@ -74,7 +74,7 @@ def forward(self, img, img_metas, *args, **kwargs): img_h, img_w = img_metas[i]['img_shape'][:2] ori_h, ori_w = img_metas[i]['ori_shape'][:2] masks = masks[:, :img_h, :img_w] - if rescale: + if rescale and batch_masks.shape[1] > 0: masks = masks.astype(np.float32) masks = torch.from_numpy(masks) masks = torch.nn.functional.interpolate( diff --git a/mmdeploy/mmdet/export/tensorrt_helper.py b/mmdeploy/mmdet/export/tensorrt_helper.py new file mode 100644 index 0000000000..63bd4149c1 --- /dev/null +++ b/mmdeploy/mmdet/export/tensorrt_helper.py @@ -0,0 +1,18 @@ +import torch + + +def pad_with_value(x, pad_dim, pad_size, pad_value=None): + num_dims = len(x.shape) + pad_slice = (slice(None, None, None), ) * num_dims + pad_slice = pad_slice[:pad_dim] + (slice(0, 1, + 1), ) + pad_slice[pad_dim + 1:] + repeat_size = [1] * num_dims + repeat_size[pad_dim] = pad_size + + x_pad = x.__getitem__(pad_slice) + if pad_value is not None: + x_pad = x_pad * 0 + pad_value + + x_pad = x_pad.repeat(*repeat_size) + x = torch.cat([x, x_pad], dim=pad_dim) + return x diff --git a/mmdeploy/mmdet/models/dense_heads/anchor_head.py b/mmdeploy/mmdet/models/dense_heads/anchor_head.py index 886501077c..155bf57e2c 100644 --- a/mmdeploy/mmdet/models/dense_heads/anchor_head.py +++ b/mmdeploy/mmdet/models/dense_heads/anchor_head.py @@ -2,6 +2,7 @@ from mmdeploy.core import FUNCTION_REWRITER from mmdeploy.mmdet.core import multiclass_nms +from mmdeploy.mmdet.export import pad_with_value from mmdeploy.utils import is_dynamic_shape @@ -55,15 +56,15 @@ def get_bboxes_of_anchor_head(ctx, anchors = anchors.expand_as(bbox_pred) - enable_nms_pre = True backend = deploy_cfg['backend'] # topk in tensorrt does not support shape 0 and enable_nms_pre: + if pre_topk > 0: # Get maximum scores for foreground classes. if self.use_sigmoid_cls: max_scores, _ = scores.max(-1) diff --git a/mmdeploy/mmdet/models/dense_heads/fcos_head.py b/mmdeploy/mmdet/models/dense_heads/fcos_head.py index 6c56e6b657..c4825f5436 100644 --- a/mmdeploy/mmdet/models/dense_heads/fcos_head.py +++ b/mmdeploy/mmdet/models/dense_heads/fcos_head.py @@ -2,6 +2,7 @@ from mmdeploy.core import FUNCTION_REWRITER from mmdeploy.mmdet.core import distance2bbox, multiclass_nms +from mmdeploy.mmdet.export import pad_with_value from mmdeploy.utils import is_dynamic_shape @@ -59,14 +60,16 @@ def get_bboxes_of_fcos_head(ctx, points = points.expand(batch_size, -1, 2) - enable_nms_pre = True backend = deploy_cfg['backend'] # topk in tensorrt does not support shape 0 and enable_nms_pre: + if pre_topk > 0: max_scores, _ = (scores * centerness).max(-1) _, topk_inds = max_scores.topk(pre_topk) batch_inds = torch.arange(batch_size).view(-1, @@ -92,7 +95,7 @@ def get_bboxes_of_fcos_head(ctx, if not with_nms: return batch_mlvl_bboxes, batch_mlvl_scores, batch_mlvl_centerness - batch_mlvl_scores = batch_mlvl_scores * (batch_mlvl_centerness) + batch_mlvl_scores = batch_mlvl_scores * batch_mlvl_centerness post_params = deploy_cfg.post_processing max_output_boxes_per_class = post_params.max_output_boxes_per_class iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold) diff --git a/mmdeploy/mmdet/models/dense_heads/rpn_head.py b/mmdeploy/mmdet/models/dense_heads/rpn_head.py index e438bf270b..9ea4a58601 100644 --- a/mmdeploy/mmdet/models/dense_heads/rpn_head.py +++ b/mmdeploy/mmdet/models/dense_heads/rpn_head.py @@ -2,6 +2,7 @@ from mmdeploy.core import FUNCTION_REWRITER from mmdeploy.mmdet.core import multiclass_nms +from mmdeploy.mmdet.export import pad_with_value from mmdeploy.utils import is_dynamic_shape @@ -50,6 +51,7 @@ def get_bboxes_of_rpn_head(ctx, # be consistent with other head since mmdet v2.0. In mmdet v2.0 # to v2.4 we keep BG label as 0 and FG label as 1 in rpn head. scores = cls_score.softmax(-1)[..., 0] + scores = scores.reshape(batch_size, -1, 1) bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(batch_size, -1, 4) # use static anchor if input shape is static @@ -58,32 +60,27 @@ def get_bboxes_of_rpn_head(ctx, anchors = anchors.expand_as(bbox_pred) - enable_nms_pre = True backend = deploy_cfg['backend'] # topk in tensorrt does not support shape 0 and enable_nms_pre: - _, topk_inds = scores.topk(pre_topk) + if pre_topk > 0: + _, topk_inds = scores.squeeze(2).topk(pre_topk) batch_inds = torch.arange( batch_size, device=device).view(-1, 1).expand_as(topk_inds) - # Avoid onnx2tensorrt issue in https://github.com/NVIDIA/TensorRT/issues/1134 # noqa: E501 - transformed_inds = scores.shape[1] * batch_inds + topk_inds - scores = scores.reshape(-1, 1)[transformed_inds].reshape( - batch_size, -1) - bbox_pred = bbox_pred.reshape(-1, 4)[transformed_inds, :].reshape( - batch_size, -1, 4) - anchors = anchors.reshape(-1, 4)[transformed_inds, :].reshape( - batch_size, -1, 4) + anchors = anchors[batch_inds, topk_inds, :] + bbox_pred = bbox_pred[batch_inds, topk_inds, :] + scores = scores[batch_inds, topk_inds, :] mlvl_valid_bboxes.append(bbox_pred) mlvl_scores.append(scores) mlvl_valid_anchors.append(anchors) batch_mlvl_bboxes = torch.cat(mlvl_valid_bboxes, dim=1) - batch_mlvl_scores = torch.cat(mlvl_scores, dim=1).unsqueeze(2) + batch_mlvl_scores = torch.cat(mlvl_scores, dim=1) batch_mlvl_anchors = torch.cat(mlvl_valid_anchors, dim=1) batch_mlvl_bboxes = self.bbox_coder.decode( batch_mlvl_anchors, diff --git a/tools/deploy.py b/tools/deploy.py index 176c1dc1ac..03a3403ae4 100644 --- a/tools/deploy.py +++ b/tools/deploy.py @@ -7,7 +7,8 @@ import torch.multiprocessing as mp from torch.multiprocessing import Process, set_start_method -from mmdeploy.apis import extract_model, inference_model, torch2onnx +from mmdeploy.apis import (assert_cfg_valid, extract_model, inference_model, + torch2onnx) def parse_args(): @@ -70,9 +71,7 @@ def main(): # load deploy_cfg deploy_cfg = mmcv.Config.fromfile(deploy_cfg_path) - if not isinstance(deploy_cfg, (mmcv.Config, mmcv.ConfigDict)): - raise TypeError('deploy_cfg must be a filename or Config object, ' - f'but got {type(deploy_cfg)}') + assert_cfg_valid(deploy_cfg, model_cfg_path) # create work_dir if not mmcv.mkdir_or_exist(osp.abspath(args.work_dir)) diff --git a/tools/test.py b/tools/test.py new file mode 100644 index 0000000000..afb69f07e5 --- /dev/null +++ b/tools/test.py @@ -0,0 +1,101 @@ +import argparse + +import mmcv +from mmcv import DictAction +from mmcv.parallel import MMDataParallel + +from mmdeploy.apis import (init_backend_model, post_process_outputs, + prepare_data_loader, single_gpu_test) +from mmdeploy.apis.utils import assert_cfg_valid, get_classes_from_config + + +def parse_args(): + parser = argparse.ArgumentParser( + description='MMDeploy test (and eval) a backend.') + parser.add_argument('deploy_cfg', help='Deploy config path') + parser.add_argument('model_cfg', help='Model config path') + parser.add_argument('model', help='Input model file.') + parser.add_argument('--out', help='output result file in pickle format') + parser.add_argument( + '--format-only', + action='store_true', + help='Format the output results without perform evaluation. It is' + 'useful when you want to format the result to a specific format and ' + 'submit it to the test server') + parser.add_argument( + '--metrics', + type=str, + nargs='+', + help='evaluation metrics, which depends on the codebase and the ' + 'dataset, e.g., "bbox", "segm", "proposal" for COCO, and "mAP", ' + '"recall" for PASCAL VOC in mmdet; "accuracy", "precision", "recall", ' + '"f1_score", "support" for single label dataset, and "mAP", "CP", "CR"' + ', "CF1", "OP", "OR", "OF1" for multi-label dataset in mmcls') + parser.add_argument('--show', action='store_true', help='show results') + parser.add_argument( + '--show-dir', help='directory where painted images will be saved') + parser.add_argument( + '--show-score-thr', + type=float, + default=0.3, + help='score threshold (default: 0.3)') + parser.add_argument( + '--device', help='device used for conversion', default='cpu') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--metric-options', + nargs='+', + action=DictAction, + help='custom options for evaluation, the key-value pair in xxx=yyy ' + 'format will be kwargs for dataset.evaluate() function') + args = parser.parse_args() + + return args + + +def main(): + args = parse_args() + if args.out is not None and not args.out.endswith(('.pkl', '.pickle')): + raise ValueError('The output file must be a pkl file.') + + deploy_cfg_path = args.deploy_cfg + model_cfg_path = args.model_cfg + + # load deploy_cfg + deploy_cfg = mmcv.Config.fromfile(deploy_cfg_path) + model_cfg = mmcv.Config.fromfile(model_cfg_path) + assert_cfg_valid(deploy_cfg, model_cfg) + + # prepare the dataset loader + codebase = deploy_cfg['codebase'] + dataset, data_loader = prepare_data_loader(codebase, model_cfg) + + # load the model of the backend + device_id = -1 if args.device == 'cpu' else 0 + backend = deploy_cfg.get('backend', 'default') + model = init_backend_model([args.model], + codebase=codebase, + backend=backend, + class_names=get_classes_from_config( + codebase, model_cfg), + device_id=device_id) + + model = MMDataParallel(model, device_ids=[0]) + outputs = single_gpu_test(codebase, model, data_loader, args.show, + args.show_dir, args.show_score_thr) + + post_process_outputs(outputs, dataset, model_cfg, codebase, args.metrics, + args.out, args.metric_options, args.format_only) + + +if __name__ == '__main__': + main()