[Feature] support mmyolo deployment (open-mmlab#79)
* support mmyolo deployment

* mv deploy place

* remove unused configs

* add deploy code

* fix new register

* fix comments

* fix dependent codebase register

* remove unused initialize

* refact deploy config

* credit return to triplemu

* Add yolov5 head rewrite

* refactor deploy

* refactor deploy

* Add yolov5 head rewrite

* fix configs

* refact config

* fix comment

* sync name after mmdeploy 1088

* fix mmyolo

* fix yapf

* fix deploy config

* try to fix flake8 importlib-metadata

* add mmyolo models ut

* add deploy uts

* add deploy uts

* fix trt dynamic error

* fix multi-batch for dynamic batch value

* fix mode

* fix lint

* sync model.py

* add ci for deploy test

* fix ci

* fix ci

* fix ci

* extract script to command for fixing CI

* fix cmake for CI

* sudo ln

* move ort position

* remove unused sdk compile

* cd mmdeploy

* simplify build

* add missing make

* change order

* add -v

* add setuptools

* get locate

* get locate

* upgrade torch

* change torchvision  version

* fix config

* fix ci

* fix ci

* fix lint

Co-authored-by: tripleMu <[email protected]>
Co-authored-by: RunningLeon <[email protected]>
3 people authored and hhaAndroid committed Nov 3, 2022
1 parent ce73b03 commit 5874e41
Showing 20 changed files with 901 additions and 2 deletions.
22 changes: 20 additions & 2 deletions .circleci/test.yml
@@ -56,6 +56,12 @@ jobs:
command: |
python -V
pip install torch==<< parameters.torch >>+cpu torchvision==<< parameters.torchvision >>+cpu -f https://download.pytorch.org/whl/torch_stable.html
- run:
name: Install ONNXRuntime
command: |
pip install onnxruntime==1.8.1
wget https://github.com/microsoft/onnxruntime/releases/download/v1.8.1/onnxruntime-linux-x64-1.8.1.tgz
tar xvf onnxruntime-linux-x64-1.8.1.tgz
- run:
name: Install mmyolo dependencies
command: |
@@ -65,13 +71,25 @@
pip install git+https://github.com/open-mmlab/[email protected]
pip install -r requirements/albu.txt
pip install -r requirements/tests.txt
- run:
name: Install mmdeploy
command: |
pip install setuptools
git clone -b dev-1.x --depth 1 https://github.com/open-mmlab/mmdeploy.git mmdeploy --recurse-submodules
wget https://github.com/Kitware/CMake/releases/download/v3.20.0/cmake-3.20.0-linux-x86_64.tar.gz
tar -xzvf cmake-3.20.0-linux-x86_64.tar.gz
sudo ln -sf $(pwd)/cmake-3.20.0-linux-x86_64/bin/* /usr/bin/
cd mmdeploy && mkdir build && cd build && cmake .. -DMMDEPLOY_TARGET_BACKENDS=ort -DONNXRUNTIME_DIR=/home/circleci/project/onnxruntime-linux-x64-1.8.1 && make -j8 && make install
export LD_LIBRARY_PATH=/home/circleci/project/onnxruntime-linux-x64-1.8.1/lib:${LD_LIBRARY_PATH}
cd /home/circleci/project/mmdeploy && python -m pip install -v -e .
- run:
name: Build and install
command: |
pip install -e .
- run:
name: Run unittests
command: |
export LD_LIBRARY_PATH=/home/circleci/project/onnxruntime-linux-x64-1.8.1/lib:${LD_LIBRARY_PATH}
coverage run --branch --source mmyolo -m pytest tests/
coverage xml
coverage report -m
@@ -144,8 +162,8 @@ workflows:
- main
- build_cpu:
name: minimum_version_cpu
-   torch: 1.7.0
-   torchvision: 0.8.1
+   torch: 1.8.0
+   torchvision: 0.9.0
python: 3.8.0 # The lowest python 3.8.x version available on CircleCI images
requires:
- lint
17 changes: 17 additions & 0 deletions configs/deploy/base_dynamic.py
@@ -0,0 +1,17 @@
_base_ = ['./base_static.py']
onnx_config = dict(
dynamic_axes={
'input': {
0: 'batch',
2: 'height',
3: 'width'
},
'dets': {
0: 'batch',
1: 'num_dets'
},
'labels': {
0: 'batch',
1: 'num_dets'
}
})
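The dynamic_axes mapping above mirrors the dynamic_axes argument of torch.onnx.export: each listed dimension of 'input', 'dets' and 'labels' is exported as a symbolic axis, so a single end2end.onnx can serve different batch sizes and input resolutions. A minimal, self-contained sketch of the same idea (the toy model and file name are placeholders, not part of this commit):

import torch
import torch.nn as nn

# Stand-in for a detector; only the dynamic_axes usage matters here.
toy_model = nn.Conv2d(3, 8, kernel_size=3, padding=1)
dummy_input = torch.randn(1, 3, 640, 640)

torch.onnx.export(
    toy_model,
    dummy_input,
    'toy_dynamic.onnx',                                  # placeholder file name
    opset_version=11,
    input_names=['input'],
    output_names=['feat'],
    dynamic_axes={
        'input': {0: 'batch', 2: 'height', 3: 'width'},  # same keys as base_dynamic.py
        'feat': {0: 'batch', 2: 'height', 3: 'width'},
    })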
23 changes: 23 additions & 0 deletions configs/deploy/base_static.py
@@ -0,0 +1,23 @@
onnx_config = dict(
type='onnx',
export_params=True,
keep_initializers_as_inputs=False,
opset_version=11,
save_file='end2end.onnx',
input_names=['input'],
output_names=['dets', 'labels'],
input_shape=None,
optimize=True)
codebase_config = dict(
type='mmyolo',
task='ObjectDetection',
model_type='end2end',
post_processing=dict(
score_threshold=0.05,
confidence_threshold=0.005,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=5000,
keep_top_k=100,
background_label_id=-1),
module=['mmyolo.deploy'])
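Deploy configs are ordinary mmengine configs, so they can be loaded and inspected like any model config. A small sketch (paths relative to the repository root):

from mmengine.config import Config

deploy_cfg = Config.fromfile('configs/deploy/base_static.py')
print(deploy_cfg.onnx_config.opset_version)   # 11
print(deploy_cfg.codebase_config.type)        # 'mmyolo'
print(deploy_cfg.codebase_config.module)      # ['mmyolo.deploy']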
15 changes: 15 additions & 0 deletions configs/deploy/detection_onnxruntime_dynamic.py
@@ -0,0 +1,15 @@
_base_ = ['./base_dynamic.py']
codebase_config = dict(
type='mmyolo',
task='ObjectDetection',
model_type='end2end',
post_processing=dict(
score_threshold=0.05,
confidence_threshold=0.005,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=5000,
keep_top_k=100,
background_label_id=-1),
module=['mmyolo.deploy'])
backend_config = dict(type='onnxruntime')
15 changes: 15 additions & 0 deletions configs/deploy/detection_onnxruntime_static.py
@@ -0,0 +1,15 @@
_base_ = ['./base_static.py']
codebase_config = dict(
type='mmyolo',
task='ObjectDetection',
model_type='end2end',
post_processing=dict(
score_threshold=0.05,
confidence_threshold=0.005,
iou_threshold=0.5,
max_output_boxes_per_class=200,
pre_top_k=5000,
keep_top_k=100,
background_label_id=-1),
module=['mmyolo.deploy'])
backend_config = dict(type='onnxruntime')
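With one of the ONNX Runtime configs above, conversion is driven through mmdeploy. A hedged sketch, assuming mmdeploy's dev-1.x Python API exposes mmdeploy.apis.torch2onnx with the signature from its documentation; the image, model config and checkpoint paths are placeholders:

from mmdeploy.apis import torch2onnx  # assumed API; check the installed mmdeploy version

torch2onnx(
    img='demo/demo.jpg',                                            # placeholder test image
    work_dir='work_dirs/yolov5_ort',
    save_file='end2end.onnx',
    deploy_cfg='configs/deploy/detection_onnxruntime_dynamic.py',
    model_cfg='path/to/yolov5_model_config.py',                     # placeholder
    model_checkpoint='path/to/yolov5_checkpoint.pth',               # placeholder
    device='cpu')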
12 changes: 12 additions & 0 deletions configs/deploy/detection_tensorrt-fp16_dynamic-320x320-640x640.py
@@ -0,0 +1,12 @@
_base_ = ['./base_dynamic.py']
backend_config = dict(
type='tensorrt',
common_config=dict(fp16_mode=True, max_workspace_size=1 << 30),
model_inputs=[
dict(
input_shapes=dict(
input=dict(
min_shape=[1, 3, 320, 320],
opt_shape=[1, 3, 640, 640],
max_shape=[1, 3, 640, 640])))
])
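The min/opt/max shapes above define a TensorRT optimization profile: an engine built from this config accepts any input whose dimensions lie inside [min_shape, max_shape] and is tuned for opt_shape. A self-contained illustration of that range check (hypothetical helper, not mmdeploy code):

def shape_in_profile(shape, min_shape, max_shape):
    """Return True if every dimension of `shape` lies within the profile range."""
    return all(lo <= dim <= hi for dim, lo, hi in zip(shape, min_shape, max_shape))

profile = dict(min_shape=[1, 3, 320, 320],
               opt_shape=[1, 3, 640, 640],
               max_shape=[1, 3, 640, 640])

print(shape_in_profile([1, 3, 480, 512], profile['min_shape'], profile['max_shape']))  # True
print(shape_in_profile([4, 3, 640, 640], profile['min_shape'], profile['max_shape']))  # False: batch is fixed to 1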
13 changes: 13 additions & 0 deletions configs/deploy/detection_tensorrt-fp16_static-640x640.py
@@ -0,0 +1,13 @@
_base_ = ['./base_static.py']
onnx_config = dict(input_shape=(640, 640))
backend_config = dict(
type='tensorrt',
common_config=dict(fp16_mode=True, max_workspace_size=1 << 30),
model_inputs=[
dict(
input_shapes=dict(
input=dict(
min_shape=[1, 3, 640, 640],
opt_shape=[1, 3, 640, 640],
max_shape=[1, 3, 640, 640])))
])
14 changes: 14 additions & 0 deletions configs/deploy/detection_tensorrt-int8_dynamic-320x320-640x640.py
@@ -0,0 +1,14 @@
_base_ = ['./base_dynamic.py']
backend_config = dict(
type='tensorrt',
common_config=dict(
fp16_mode=True, max_workspace_size=1 << 30, int8_mode=True),
model_inputs=[
dict(
input_shapes=dict(
input=dict(
min_shape=[1, 3, 320, 320],
opt_shape=[1, 3, 640, 640],
max_shape=[1, 3, 640, 640])))
],
calib_config=dict(create_calib=True, calib_file='calib_data.h5'))
15 changes: 15 additions & 0 deletions configs/deploy/detection_tensorrt-int8_static-640x640.py
@@ -0,0 +1,15 @@
_base_ = ['./base_static.py']
onnx_config = dict(input_shape=(640, 640))
backend_config = dict(
type='tensorrt',
common_config=dict(
fp16_mode=True, max_workspace_size=1 << 30, int8_mode=True),
model_inputs=[
dict(
input_shapes=dict(
input=dict(
min_shape=[1, 3, 640, 640],
opt_shape=[1, 3, 640, 640],
max_shape=[1, 3, 640, 640])))
],
calib_config=dict(create_calib=True, calib_file='calib_data.h5'))
12 changes: 12 additions & 0 deletions configs/deploy/detection_tensorrt_dynamic-320x320-640x640.py
@@ -0,0 +1,12 @@
_base_ = ['./base_dynamic.py']
backend_config = dict(
type='tensorrt',
common_config=dict(fp16_mode=False, max_workspace_size=1 << 30),
model_inputs=[
dict(
input_shapes=dict(
input=dict(
min_shape=[1, 3, 320, 320],
opt_shape=[1, 3, 640, 640],
max_shape=[1, 3, 640, 640])))
])
13 changes: 13 additions & 0 deletions configs/deploy/detection_tensorrt_static-640x640.py
@@ -0,0 +1,13 @@
_base_ = ['./base_static.py']
onnx_config = dict(input_shape=(640, 640))
backend_config = dict(
type='tensorrt',
common_config=dict(fp16_mode=False, max_workspace_size=1 << 30),
model_inputs=[
dict(
input_shapes=dict(
input=dict(
min_shape=[1, 3, 640, 640],
opt_shape=[1, 3, 640, 640],
max_shape=[1, 3, 640, 640])))
])
7 changes: 7 additions & 0 deletions mmyolo/deploy/__init__.py
@@ -0,0 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdeploy.codebase.base import MMCodebase

from .models import * # noqa: F401,F403
from .object_detection import MMYOLO, YOLOObjectDetection

__all__ = ['MMCodebase', 'MMYOLO', 'YOLOObjectDetection']
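The module=['mmyolo.deploy'] entry in the deploy configs points mmdeploy at this package; importing it is what registers the MMYOLO codebase and pulls in the head rewriters below. Roughly equivalent by hand (sketch only):

import importlib

# Same value as codebase_config['module'] in the deploy configs.
for mod_name in ['mmyolo.deploy']:
    importlib.import_module(mod_name)  # registers MMYOLO, YOLOObjectDetection and the YOLOv5 head rewriter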
2 changes: 2 additions & 0 deletions mmyolo/deploy/models/__init__.py
@@ -0,0 +1,2 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import dense_heads # noqa: F401,F403
4 changes: 4 additions & 0 deletions mmyolo/deploy/models/dense_heads/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from . import yolov5_head # noqa: F401,F403

__all__ = ['yolov5_head']
108 changes: 108 additions & 0 deletions mmyolo/deploy/models/dense_heads/yolov5_head.py
@@ -0,0 +1,108 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import List, Optional, Tuple

import torch
from mmdeploy.codebase.mmdet import get_post_processing_params
from mmdeploy.codebase.mmdet.models.layers import multiclass_nms
from mmdeploy.core import FUNCTION_REWRITER
from mmengine.config import ConfigDict
from mmengine.structures import InstanceData
from torch import Tensor


@FUNCTION_REWRITER.register_rewriter(
func_name='mmyolo.models.dense_heads.yolov5_head.'
'YOLOv5Head.predict_by_feat')
def yolov5_head__predict_by_feat(ctx,
self,
cls_scores: List[Tensor],
bbox_preds: List[Tensor],
objectnesses: Optional[List[Tensor]],
batch_img_metas: Optional[List[dict]] = None,
cfg: Optional[ConfigDict] = None,
rescale: bool = False,
with_nms: bool = True) -> Tuple[InstanceData]:
"""Transform a batch of output features extracted by the head into
bbox results.
Args:
cls_scores (list[Tensor]): Classification scores for all
scale levels, each a 4D tensor with shape
(batch_size, num_priors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for all
scale levels, each a 4D tensor with shape
(batch_size, num_priors * 4, H, W).
objectnesses (list[Tensor], Optional): Score factors for
all scale levels, each a 4D tensor with shape
(batch_size, 1, H, W).
batch_img_metas (list[dict], Optional): Batch image meta info.
Defaults to None.
cfg (ConfigDict, optional): Test / postprocessing
configuration; if None, test_cfg is used.
Defaults to None.
rescale (bool): If True, return boxes in the original image space.
Defaults to False.
with_nms (bool): If True, apply NMS before returning boxes.
Defaults to True.
Returns:
tuple[Tensor, Tensor]: The first item is an (N, num_box, 5) tensor,
where 5 represents (tl_x, tl_y, br_x, br_y, score), N is the batch
size, and the score is between 0 and 1. The second tensor has
shape (N, num_box), and each element is the class label of the
corresponding box.
"""
assert len(cls_scores) == len(bbox_preds)
cfg = self.test_cfg if cfg is None else cfg
cfg = copy.deepcopy(cfg)

num_imgs = cls_scores[0].shape[0]
featmap_sizes = [cls_score.shape[2:] for cls_score in cls_scores]

mlvl_priors = self.prior_generator.grid_priors(
featmap_sizes, dtype=cls_scores[0].dtype, device=cls_scores[0].device)
flatten_priors = torch.cat(mlvl_priors)

mlvl_strides = [
flatten_priors.new_full(
(featmap_size[0] * featmap_size[1] * self.num_base_priors, ),
stride)
for featmap_size, stride in zip(featmap_sizes, self.featmap_strides)
]
flatten_stride = torch.cat(mlvl_strides)
# flatten cls_scores, bbox_preds and objectness
flatten_cls_scores = [
cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, self.num_classes)
for cls_score in cls_scores
]
flatten_bbox_preds = [
bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
for bbox_pred in bbox_preds
]

flatten_objectness = [
objectness.permute(0, 2, 3, 1).reshape(num_imgs, -1)
for objectness in objectnesses
]

cls_scores = torch.cat(flatten_cls_scores, dim=1).sigmoid()
flatten_bbox_preds = torch.cat(flatten_bbox_preds, dim=1)
flatten_objectness = torch.cat(flatten_objectness, dim=1).sigmoid()
bboxes = self.bbox_coder.decode(flatten_priors[None], flatten_bbox_preds,
flatten_stride)

# directly multiply score factor and feed to nms
scores = cls_scores * (flatten_objectness.unsqueeze(-1))

if not with_nms:
return bboxes, scores
deploy_cfg = ctx.cfg
post_params = get_post_processing_params(deploy_cfg)
max_output_boxes_per_class = post_params.max_output_boxes_per_class
iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
score_threshold = cfg.get('score_thr', post_params.score_threshold)
pre_top_k = post_params.pre_top_k
keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)

return multiclass_nms(bboxes, scores, max_output_boxes_per_class,
iou_threshold, score_threshold, pre_top_k,
keep_top_k)
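The permute/reshape bookkeeping in this rewrite turns per-level (batch, num_base_priors * num_classes, H, W) score maps into a single (batch, num_priors, num_classes) tensor before decoding and NMS. A self-contained sketch with dummy tensors (sizes chosen for a 640x640 input with strides 8/16/32; not part of the commit):

import torch

batch, num_classes, num_base_priors = 2, 80, 3
featmap_sizes = [(80, 80), (40, 40), (20, 20)]

cls_scores = [
    torch.randn(batch, num_base_priors * num_classes, h, w) for h, w in featmap_sizes
]
flattened = [
    score.permute(0, 2, 3, 1).reshape(batch, -1, num_classes) for score in cls_scores
]
scores = torch.cat(flattened, dim=1).sigmoid()
print(scores.shape)  # torch.Size([2, 25200, 80]); 25200 = 3 * (80*80 + 40*40 + 20*20)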