Skip to content

Commit

Permalink
Add YOLOv5 support for RV1126 device. (#1321)
Browse files Browse the repository at this point in the history
* support pose simcc

* fix yolov5 create_input

* add yolov5 post process and update mmcls.yml

* add letter resize

* add yolov5 export info

* fix

* add pose face config

* pick 4dd4d48

* fix yolov5 head

* fix ut

* refactor mmpose config

* pass output_names outside

* rknn batch size

* lint

* add input names to wrapper

* update according to open-mmlab/mmyolo#305

* add pre_compile option

* update doc and fix typo

* fix padding

* fix typo

* use throw_exception
  • Loading branch information
AllentDan authored and lvhan028 committed Mar 1, 2023
1 parent e2c397e commit 83a2e85
Show file tree
Hide file tree
Showing 44 changed files with 853 additions and 133 deletions.
8 changes: 5 additions & 3 deletions configs/_base_/backends/rknn.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
backend_config = dict(
type='rknn',
common_config=dict(
mean_values=None, # [[103.53, 116.28, 123.675]],
std_values=None, # [[57.375, 57.12, 58.395]],
target_platform='rv1126', # 'rk3588'
optimization_level=1),
quantization_config=dict(do_quantization=False, dataset=None))
quantization_config=dict(
do_quantization=True,
dataset=None,
pre_compile=False,
rknn_batch_size=-1))
7 changes: 7 additions & 0 deletions configs/mmcls/classification_rknn-fp16_static-224x224.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
_base_ = ['./classification_static.py', '../_base_/backends/rknn.py']

onnx_config = dict(input_shape=[224, 224])
codebase_config = dict(model_type='end2end')
backend_config = dict(
input_size_list=[[3, 224, 224]],
quantization_config=dict(do_quantization=False))
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
_base_ = ['./classification_static.py', '../_base_/backends/rknn.py']

onnx_config = dict(input_shape=[224, 224])
codebase_config = dict(model_type='rknn')
codebase_config = dict(model_type='end2end')
backend_config = dict(input_size_list=[[3, 224, 224]])
34 changes: 34 additions & 0 deletions configs/mmdet/detection/detection_rknn-fp16_static-320x320.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
_base_ = ['../_base_/base_static.py', '../../_base_/backends/rknn.py']

onnx_config = dict(input_shape=[320, 320])

codebase_config = dict(model_type='rknn')

backend_config = dict(
input_size_list=[[3, 320, 320]],
quantization_config=dict(do_quantization=False))

# # yolov3, yolox for rknn-toolkit and rknn-toolkit2
# partition_config = dict(
# type='rknn', # the partition policy name
# apply_marks=True, # should always be set to True
# partition_cfg=[
# dict(
# save_file='model.onnx', # name to save the partitioned onnx
# start=['detector_forward:input'], # [mark_name:input, ...]
# end=['yolo_head:input'], # [mark_name:output, ...]
# output_names=[f'pred_maps.{i}' for i in range(3)]) # out names
# ])

# # retinanet, ssd, fsaf for rknn-toolkit2
# partition_config = dict(
# type='rknn', # the partition policy name
# apply_marks=True,
# partition_cfg=[
# dict(
# save_file='model.onnx',
# start='detector_forward:input',
# end=['BaseDenseHead:output'],
# output_names=[f'BaseDenseHead.cls.{i}' for i in range(5)] +
# [f'BaseDenseHead.loc.{i}' for i in range(5)])
# ])
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,27 @@

backend_config = dict(input_size_list=[[3, 320, 320]])

# yolov3, yolox
# # yolov3, yolox for rknn-toolkit and rknn-toolkit2
# partition_config = dict(
# type='rknn', # the partition policy name
# apply_marks=True, # should always be set to True
# partition_cfg=[
# dict(
# save_file='model.onnx', # name to save the partitioned onnx
# start=['detector_forward:input'], # [mark_name:input, ...]
# end=['yolo_head:input']) # [mark_name:output, ...]
# end=['yolo_head:input'], # [mark_name:output, ...]
# output_names=[f'pred_maps.{i}' for i in range(3)]) # out names
# ])

# # retinanet, ssd, fsaf
# # retinanet, ssd, fsaf for rknn-toolkit2
# partition_config = dict(
# type='rknn', # the partition policy name
# apply_marks=True,
# partition_cfg=[
# dict(
# save_file='model.onnx',
# start='detector_forward:input',
# end=['BaseDenseHead:output'])
# end=['BaseDenseHead:output'],
# output_names=[f'BaseDenseHead.cls.{i}' for i in range(5)] +
# [f'BaseDenseHead.loc.{i}' for i in range(5)])
# ])
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@
dict(
save_file='yolov3.onnx',
start=['detector_forward:input'],
end=['yolo_head:input'])
end=['yolo_head:input'],
output_names=[f'pred_maps.{i}' for i in range(3)])
])
10 changes: 10 additions & 0 deletions configs/mmpose/pose-detection_rknn-fp16_static-256x192.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
_base_ = ['./pose-detection_static.py', '../_base_/backends/rknn.py']

onnx_config = dict(input_shape=[192, 256])

codebase_config = dict(model_type='end2end')

backend_config = dict(
input_size_list=[[3, 256, 192]],
quantization_config=dict(do_quantization=False),
common_config=dict(target_platform='rv1126'))
10 changes: 10 additions & 0 deletions configs/mmpose/pose-detection_rknn-fp16_static-256x256.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
_base_ = ['./pose-detection_static.py', '../_base_/backends/rknn.py']

onnx_config = dict(input_shape=[256, 256])

codebase_config = dict(model_type='end2end')

backend_config = dict(
input_size_list=[[3, 256, 256]],
quantization_config=dict(do_quantization=False),
common_config=dict(target_platform='rv1126'))
9 changes: 9 additions & 0 deletions configs/mmpose/pose-detection_rknn-int8_static-256x192.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
_base_ = ['./pose-detection_static.py', '../_base_/backends/rknn.py']

onnx_config = dict(input_shape=[192, 256])

codebase_config = dict(model_type='end2end')

backend_config = dict(
input_size_list=[[3, 256, 192]],
common_config=dict(target_platform='rv1126'))
9 changes: 9 additions & 0 deletions configs/mmpose/pose-detection_rknn-int8_static-256x256.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
_base_ = ['./pose-detection_static.py', '../_base_/backends/rknn.py']

onnx_config = dict(input_shape=[256, 256])

codebase_config = dict(model_type='end2end')

backend_config = dict(
input_size_list=[[3, 256, 256]],
common_config=dict(target_platform='rv1126'))
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
_base_ = ['./pose-detection_static.py', '../_base_/backends/rknn.py']

onnx_config = dict(input_shape=[192, 256], output_names=['simcc_x', 'simcc_y'])

backend_config = dict(
input_size_list=[[3, 256, 192]],
quantization_config=dict(do_quantization=False))
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
_base_ = ['./pose-detection_static.py', '../_base_/backends/rknn.py']

onnx_config = dict(input_shape=[192, 256], output_names=['simcc_x', 'simcc_y'])

backend_config = dict(input_size_list=[[3, 256, 192]])
9 changes: 9 additions & 0 deletions configs/mmseg/segmentation_rknn-fp16_static-320x320.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
_base_ = ['./segmentation_static.py', '../_base_/backends/rknn.py']

onnx_config = dict(input_shape=[320, 320])

codebase_config = dict(model_type='rknn')

backend_config = dict(
input_size_list=[[3, 320, 320]],
quantization_config=dict(do_quantization=False))
3 changes: 2 additions & 1 deletion csrc/mmdeploy/codebase/mmdet/base_dense_head.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ Result<Detections> BaseDenseHead::GetBBoxes(const Value& prep_res, const Tensor&

MMDEPLOY_DEBUG("{}-th box: ({}, {}, {}, {}), {}, {}", i, x1, y1, x2, y2, label_id, score);

auto rect = MapToOriginImage(x1, y1, x2, y2, scale_factor.data(), 0, 0, ori_width, ori_height);
auto rect =
MapToOriginImage(x1, y1, x2, y2, scale_factor.data(), 0, 0, ori_width, ori_height, 0, 0);
if (rect[2] - rect[0] < min_bbox_size_ || rect[3] - rect[1] < min_bbox_size_) {
MMDEPLOY_DEBUG("ignore small bbox with width '{}' and height '{}", rect[2] - rect[0],
rect[3] - rect[1]);
Expand Down
19 changes: 9 additions & 10 deletions csrc/mmdeploy/codebase/mmdet/object_detection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "mmdeploy/core/utils/device_utils.h"
#include "mmdeploy/core/utils/formatter.h"
#include "mmdeploy/experimental/module_adapter.h"
#include "utils.h"

using namespace std;

Expand Down Expand Up @@ -127,6 +128,13 @@ Result<Detections> ResizeBBox::GetBBoxes(const Value& prep_res, const Tensor& de
scale_factor = {1.f, 1.f, 1.f, 1.f};
}

int top_padding = 0;
int left_padding = 0;
if (prep_res.contains("pad_param")) {
top_padding = prep_res["pad_param"][0].get<int>();
left_padding = prep_res["pad_param"][1].get<int>();
}

float w_offset = 0.f;
float h_offset = 0.f;
if (prep_res.contains("border")) {
Expand All @@ -153,7 +161,7 @@ Result<Detections> ResizeBBox::GetBBoxes(const Value& prep_res, const Tensor& de
MMDEPLOY_DEBUG("ori left {}, top {}, right {}, bottom {}, label {}", left, top, right, bottom,
*labels_ptr);
auto rect = MapToOriginImage(left, top, right, bottom, scale_factor.data(), w_offset, h_offset,
ori_width, ori_height);
ori_width, ori_height, top_padding, left_padding);
if (rect[2] - rect[0] < min_bbox_size_ || rect[3] - rect[1] < min_bbox_size_) {
MMDEPLOY_DEBUG("ignore small bbox with width '{}' and height '{}", rect[2] - rect[0],
rect[3] - rect[1]);
Expand All @@ -170,15 +178,6 @@ Result<Detections> ResizeBBox::GetBBoxes(const Value& prep_res, const Tensor& de
}
return objs;
}
std::array<float, 4> ResizeBBox::MapToOriginImage(float left, float top, float right, float bottom,
const float* scale_factor, float x_offset,
float y_offset, int ori_width, int ori_height) {
left = std::max(left / scale_factor[0] + x_offset, 0.f);
top = std::max(top / scale_factor[1] + y_offset, 0.f);
right = std::min(right / scale_factor[2] + x_offset, (float)ori_width - 1.f);
bottom = std::min(bottom / scale_factor[3] + y_offset, (float)ori_height - 1.f);
return {left, top, right, bottom};
}

MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMDetection, ResizeBBox);

Expand Down
4 changes: 0 additions & 4 deletions csrc/mmdeploy/codebase/mmdet/object_detection.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@ class ResizeBBox : public MMDetection {
template <typename T>
Result<Detections> GetBBoxes(const Value& prep_res, const Tensor& dets, const Tensor& labels);

std::array<float, 4> MapToOriginImage(float left, float top, float right, float bottom,
const float* scale_factor, float x_offset, float y_offset,
int ori_width, int ori_height);

std::vector<Tensor> GetDetsLabels(const Value& prep_res, const Value& infer_res);

protected:
Expand Down
11 changes: 6 additions & 5 deletions csrc/mmdeploy/codebase/mmdet/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ namespace mmdeploy::mmdet {

std::array<float, 4> MapToOriginImage(float left, float top, float right, float bottom,
const float* scale_factor, float x_offset, float y_offset,
int ori_width, int ori_height) {
left = std::max(left / scale_factor[0] + x_offset, 0.f);
top = std::max(top / scale_factor[1] + y_offset, 0.f);
right = std::min(right / scale_factor[2] + x_offset, (float)ori_width - 1.f);
bottom = std::min(bottom / scale_factor[3] + y_offset, (float)ori_height - 1.f);
int ori_width, int ori_height, int top_padding,
int left_padding) {
left = std::max((left - left_padding) / scale_factor[0] + x_offset, 0.f);
top = std::max((top - top_padding) / scale_factor[1] + y_offset, 0.f);
right = std::min((right - left_padding) / scale_factor[2] + x_offset, (float)ori_width - 1.f);
bottom = std::min((bottom - top_padding) / scale_factor[3] + y_offset, (float)ori_height - 1.f);
return {left, top, right, bottom};
}

Expand Down
3 changes: 2 additions & 1 deletion csrc/mmdeploy/codebase/mmdet/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
namespace mmdeploy::mmdet {
std::array<float, 4> MapToOriginImage(float left, float top, float right, float bottom,
const float* scale_factor, float x_offset, float y_offset,
int ori_width, int ori_height);
int ori_width, int ori_height, int top_padding,
int left_padding);
// @brief Filter results using score threshold and topk candidates.
// scores (Tensor): The scores, shape (num_bboxes, K).
// probs: The scores after being filtered
Expand Down
Loading

0 comments on commit 83a2e85

Please sign in to comment.