forked from open-mmlab/mmdeploy
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add rv1126 yolov3 support to sdk (open-mmlab#1280)
* add yolov3 head to SDK * add yolov5 head to SDK * fix export-info and lint, add reverse check * fix lint * fix export info for yolo heads * add output_names to partition_config * fix typo * config * normalize config * fix * refactor config * fix lint and doc * c++ form * resolve comments * fix CI * fix CI * fix CI * float strides anchors * refine pipeline of rknn-int8 * config * rename func * refactor * rknn wrapper dict and fix typo * rknn wrapper output update, mmcls use end2end type * fix typo
- Loading branch information
Showing
35 changed files
with
579 additions
and
147 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,6 @@ | ||
backend_config = dict( | ||
type='rknn', | ||
common_config=dict( | ||
mean_values=None, # [[103.53, 116.28, 123.675]], | ||
std_values=None, # [[57.375, 57.12, 58.395]], | ||
target_platform='rv1126', # 'rk3588' | ||
optimization_level=1), | ||
quantization_config=dict(do_quantization=False, dataset=None)) | ||
quantization_config=dict(do_quantization=True, dataset=None)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Static 224x224 classification deployment on RKNN with quantization disabled.
_base_ = ['./classification_static.py', '../_base_/backends/rknn.py']

# Spatial size of the exported ONNX input.
onnx_config = {'input_shape': [224, 224]}

codebase_config = {'model_type': 'end2end'}

# One entry per model input; quantization is switched off for this variant.
backend_config = {
    'input_size_list': [[3, 224, 224]],
    'quantization_config': {'do_quantization': False},
}
2 changes: 1 addition & 1 deletion
2
...cls/classification_rknn_static-224x224.py → ...lassification_rknn-int8_static-224x224.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
_base_ = ['./classification_static.py', '../_base_/backends/rknn.py'] | ||
|
||
onnx_config = dict(input_shape=[224, 224]) | ||
codebase_config = dict(model_type='rknn') | ||
codebase_config = dict(model_type='end2end') | ||
backend_config = dict(input_size_list=[[3, 224, 224]]) |
34 changes: 34 additions & 0 deletions
34
configs/mmdet/detection/detection_rknn-fp16_static-320x320.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# Static 320x320 detection deployment on RKNN with quantization disabled.
_base_ = ['../_base_/base_static.py', '../../_base_/backends/rknn.py']

# Spatial size of the exported ONNX input.
onnx_config = {'input_shape': [320, 320]}

codebase_config = {'model_type': 'rknn'}

# One entry per model input; quantization is switched off for this variant.
backend_config = {
    'input_size_list': [[3, 320, 320]],
    'quantization_config': {'do_quantization': False},
}
|
||
# # yolov3, yolox for rknn-toolkit and rknn-toolkit2 | ||
# partition_config = dict( | ||
# type='rknn', # the partition policy name | ||
# apply_marks=True, # should always be set to True | ||
# partition_cfg=[ | ||
# dict( | ||
# save_file='model.onnx', # name to save the partitioned onnx | ||
# start=['detector_forward:input'], # [mark_name:input, ...] | ||
# end=['yolo_head:input'], # [mark_name:output, ...] | ||
# output_names=[f'pred_maps.{i}' for i in range(3)]) # out names | ||
# ]) | ||
|
||
# # retinanet, ssd, fsaf for rknn-toolkit2 | ||
# partition_config = dict( | ||
# type='rknn', # the partition policy name | ||
# apply_marks=True, | ||
# partition_cfg=[ | ||
# dict( | ||
# save_file='model.onnx', | ||
# start='detector_forward:input', | ||
# end=['BaseDenseHead:output'], | ||
# output_names=[f'BaseDenseHead.cls.{i}' for i in range(5)] + | ||
# [f'BaseDenseHead.loc.{i}' for i in range(5)]) | ||
# ]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Static 320x320 segmentation deployment on RKNN with quantization disabled.
_base_ = ['./segmentation_static.py', '../_base_/backends/rknn.py']

# Spatial size of the exported ONNX input.
onnx_config = {'input_shape': [320, 320]}

codebase_config = {'model_type': 'rknn'}

# One entry per model input; quantization is switched off for this variant.
backend_config = {
    'input_size_list': [[3, 320, 320]],
    'quantization_config': {'do_quantization': False},
}
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,228 @@ | ||
// Copyright (c) OpenMMLab. All rights reserved. | ||
#include "yolo_head.h" | ||
|
||
#include <math.h> | ||
|
||
#include <algorithm> | ||
#include <numeric> | ||
|
||
#include "mmdeploy/core/model.h" | ||
#include "mmdeploy/core/utils/device_utils.h" | ||
#include "mmdeploy/core/utils/formatter.h" | ||
#include "utils.h" | ||
|
||
namespace mmdeploy::mmdet { | ||
|
||
// Reads the post-processing parameters of a YOLO head from the SDK pipeline config.
//
// Everything is optional: when cfg has no "params" entry the in-class member defaults stay in
// effect, and anchors/strides are only loaded when "params" carries an "anchor_generator".
YOLOHead::YOLOHead(const Value& cfg) : MMDetection(cfg) {
  // The model handle is not used below; the lookup is kept so that a config without a usable
  // cfg["context"]["model"] fails here exactly as it did before.
  auto model = cfg["context"]["model"].get<Model>();

  if (!cfg.contains("params")) {
    return;
  }
  const auto& params = cfg["params"];

  nms_pre_ = params.value("nms_pre", -1);
  score_thr_ = params.value("score_thr", 0.02f);
  min_bbox_size_ = params.value("min_bbox_size", 0);

  // The IoU threshold lives one level deeper, under the optional "nms" entry.
  iou_threshold_ = 0.45f;
  if (params.contains("nms")) {
    iou_threshold_ = params["nms"].value("iou_threshold", 0.45f);
  }

  if (params.contains("anchor_generator")) {
    const auto& generator = params["anchor_generator"];
    from_value(generator["base_sizes"], anchors_);
    from_value(generator["strides"], strides_);
  }
}
|
||
// Post-processes the raw network outputs of one image into detections.
//
// prep_res:  preprocessing result; its "img_metas" entry is forwarded to GetBBoxes().
// infer_res: iterable of output tensors, one prediction map per detection scale.
// Returns the detections converted with to_value(), or eFail if anything throws.
Result<Value> YOLOHead::operator()(const Value& prep_res, const Value& infer_res) {
  MMDEPLOY_DEBUG("prep_res: {}\ninfer_res: {}", prep_res, infer_res);
  try {
    const Device kHost{0, 0};
    std::vector<Tensor> pred_maps;
    for (auto iter = infer_res.begin(); iter != infer_res.end(); ++iter) {
      auto pred_map = iter->get<Tensor>();
      OUTCOME_TRY(auto host_map, MakeAvailableOnDevice(pred_map, kHost, stream()));
      pred_maps.push_back(std::move(host_map));
    }
    // The copies above are queued on the stream; wait before touching the data on the host.
    OUTCOME_TRY(stream().Wait());

    // Reorder pred_maps according to strides and anchors, mainly for rknpu yolov3.
    // A larger stride must pair with a narrower map, so if strides and map widths grow in the
    // same direction the outputs arrived in the opposite order and are flipped.
    // Both containers need two entries before their first two elements may be compared;
    // strides_ stays empty when the config carries no anchor_generator.
    if (pred_maps.size() > 1 && strides_.size() > 1) {
      const bool strides_ascending = strides_[0] < strides_[1];
      const bool widths_ascending = pred_maps[0].shape(3) < pred_maps[1].shape(3);
      if (strides_ascending == widths_ascending) {
        std::reverse(pred_maps.begin(), pred_maps.end());
      }
    }

    OUTCOME_TRY(auto result, GetBBoxes(prep_res["img_metas"], pred_maps));
    return to_value(result);
  } catch (...) {
    return Status(eFail);
  }
}
|
||
// Limits val to the closed range [min, max] and truncates the result to int.
// Any value that is not strictly greater than min (NaN included) yields min.
inline static int clamp(float val, int min, int max) {
  float limited = min;
  if (val > min) {
    limited = val < max ? val : max;
  }
  return static_cast<int>(limited);
}
|
||
// Logistic function 1 / (1 + e^-x), evaluated entirely in single precision.
// Float literals avoid the silent promotion to double that `1.0` would cause.
static float sigmoid(float x) { return 1.0f / (1.0f + expf(-x)); }
|
||
// Inverse of sigmoid (the logit): returns x such that sigmoid(x) == y.
// y == 1 gives +inf and y == 0 gives -inf. Evaluated entirely in single precision.
static float unsigmoid(float y) { return -logf(1.0f / y - 1.0f); }
|
||
int YOLOHead::YOLOFeatDecode(const Tensor& feat_map, const std::vector<std::vector<float>>& anchor, | ||
int grid_h, int grid_w, int height, int width, int stride, | ||
std::vector<float>& boxes, std::vector<float>& obj_probs, | ||
std::vector<int>& class_id, float threshold) const { | ||
auto input = const_cast<float*>(feat_map.data<float>()); | ||
auto prop_box_size = feat_map.shape(1) / anchor.size(); | ||
const int kClasses = prop_box_size - 5; | ||
int valid_count = 0; | ||
int grid_len = grid_h * grid_w; | ||
float thres = unsigmoid(threshold); | ||
for (int a = 0; a < anchor.size(); a++) { | ||
for (int i = 0; i < grid_h; i++) { | ||
for (int j = 0; j < grid_w; j++) { | ||
float box_confidence = input[(prop_box_size * a + 4) * grid_len + i * grid_w + j]; | ||
if (box_confidence >= thres) { | ||
int offset = (prop_box_size * a) * grid_len + i * grid_w + j; | ||
float* in_ptr = input + offset; | ||
|
||
float box_x = sigmoid(*in_ptr); | ||
float box_y = sigmoid(in_ptr[grid_len]); | ||
float box_w = in_ptr[2 * grid_len]; | ||
float box_h = in_ptr[3 * grid_len]; | ||
auto box = yolo_decode(box_x, box_y, box_w, box_h, stride, anchor, j, i, a); | ||
|
||
box_x = box[0]; | ||
box_y = box[1]; | ||
box_w = box[2]; | ||
box_h = box[3]; | ||
|
||
box_x -= (box_w / 2.0); | ||
box_y -= (box_h / 2.0); | ||
boxes.push_back(box_x); | ||
boxes.push_back(box_y); | ||
boxes.push_back(box_x + box_w); | ||
boxes.push_back(box_y + box_h); | ||
|
||
float max_class_probs = in_ptr[5 * grid_len]; | ||
int max_class_id = 0; | ||
for (int k = 1; k < kClasses; ++k) { | ||
float prob = in_ptr[(5 + k) * grid_len]; | ||
if (prob > max_class_probs) { | ||
max_class_id = k; | ||
max_class_probs = prob; | ||
} | ||
} | ||
obj_probs.push_back(sigmoid(max_class_probs) * sigmoid(box_confidence)); | ||
class_id.push_back(max_class_id); | ||
valid_count++; | ||
} | ||
} | ||
} | ||
} | ||
return valid_count; | ||
} | ||
|
||
// Turns the prediction maps of one image into final detections.
//
// Steps: decode every map with its stride and anchors, sort the candidates with Sort(),
// suppress overlaps with NMS(), then clamp the survivors to the model input size, map them back
// to the original image and drop boxes smaller than min_bbox_size_.
//
// prep_res:  image meta. Reads "img_shape" and "ori_shape" (index 1 is used as height, index 2
//            as width) and the optional "scale_factor" (defaults to all ones).
// pred_maps: host tensors; pred_maps[i] is decoded with strides_[i] and anchors_[i].
// Returns the detections, or eFail when there are more maps than configured strides/anchors
// or a configured stride is not positive.
// Note: nms_pre_ is not consulted here.
Result<Detections> YOLOHead::GetBBoxes(const Value& prep_res,
                                       const std::vector<Tensor>& pred_maps) const {
  // strides_ and anchors_ come from the config and may be shorter than the model's output list.
  if (pred_maps.size() > strides_.size() || pred_maps.size() > anchors_.size()) {
    MMDEPLOY_ERROR("got {} prediction maps but only {} strides and {} anchor groups",
                   pred_maps.size(), strides_.size(), anchors_.size());
    return Status(eFail);
  }

  std::vector<float> filter_boxes;
  std::vector<float> obj_probs;
  std::vector<int> class_id;

  const int model_in_h = prep_res["img_shape"][1].get<int>();
  const int model_in_w = prep_res["img_shape"][2].get<int>();

  for (size_t i = 0; i < pred_maps.size(); ++i) {
    const int stride = static_cast<int>(strides_[i]);
    if (stride <= 0) {
      // Guards the integer divisions below.
      MMDEPLOY_ERROR("invalid stride {} for prediction map {}", strides_[i], i);
      return Status(eFail);
    }
    const int grid_h = model_in_h / stride;
    const int grid_w = model_in_w / stride;
    YOLOFeatDecode(pred_maps[i], anchors_[i], grid_h, grid_w, model_in_h, model_in_w, stride,
                   filter_boxes, obj_probs, class_id, score_thr_);
  }

  // Nothing passed the score threshold: skip sorting, NMS and the zero-sized tensor.
  if (obj_probs.empty()) {
    return Detections{};
  }

  // indexArray[k] starts as k; Sort() reorders it together with the scores and labels, so
  // afterwards position k still knows which row of filter_boxes it belongs to.
  std::vector<int> indexArray(obj_probs.size());
  std::iota(indexArray.begin(), indexArray.end(), 0);
  Sort(obj_probs, class_id, indexArray);

  Tensor dets(TensorDesc{Device{0, 0}, DataType::kFLOAT,
                         TensorShape{static_cast<int>(filter_boxes.size() / 4), 4}, "dets"});
  std::copy(filter_boxes.begin(), filter_boxes.end(), dets.data<float>());
  NMS(dets, iou_threshold_, indexArray);

  std::vector<float> scale_factor;
  if (prep_res.contains("scale_factor")) {
    from_value(prep_res["scale_factor"], scale_factor);
  } else {
    scale_factor = {1.f, 1.f, 1.f, 1.f};
  }
  const int ori_width = prep_res["ori_shape"][2].get<int>();
  const int ori_height = prep_res["ori_shape"][1].get<int>();

  Detections objs;
  const auto* det_ptr = dets.data<float>();
  const int num_candidates = static_cast<int>(indexArray.size());
  for (int i = 0; i < num_candidates; ++i) {
    // Entries equal to -1 are skipped (treated as suppressed by NMS).
    const int j = indexArray[i];
    if (j == -1) {
      continue;
    }
    // Boxes are addressed through j (original order); scores and labels through i (sorted).
    auto x1 = clamp(det_ptr[j * 4 + 0], 0, model_in_w);
    auto y1 = clamp(det_ptr[j * 4 + 1], 0, model_in_h);
    auto x2 = clamp(det_ptr[j * 4 + 2], 0, model_in_w);
    auto y2 = clamp(det_ptr[j * 4 + 3], 0, model_in_h);
    const int label_id = class_id[i];
    const float score = obj_probs[i];

    MMDEPLOY_DEBUG("{}-th box: ({}, {}, {}, {}), {}, {}", i, x1, y1, x2, y2, label_id, score);

    auto rect = MapToOriginImage(x1, y1, x2, y2, scale_factor.data(), 0, 0, ori_width, ori_height);
    if (rect[2] - rect[0] < min_bbox_size_ || rect[3] - rect[1] < min_bbox_size_) {
      MMDEPLOY_DEBUG("ignore small bbox with width '{}' and height '{}'", rect[2] - rect[0],
                     rect[3] - rect[1]);
      continue;
    }
    Detection det{};
    det.index = i;
    det.label_id = label_id;
    det.score = score;
    det.bbox = rect;
    objs.push_back(std::move(det));
  }

  return objs;
}
|
||
// YOLOv3 adds nothing to the shared pipeline; only yolo_decode() differs.
// The call is qualified so it always runs the YOLOHead implementation.
Result<Value> YOLOV3Head::operator()(const Value& prep_res, const Value& infer_res) {
  auto detections = YOLOHead::operator()(prep_res, infer_res);
  return detections;
}
|
||
std::array<float, 4> YOLOV3Head::yolo_decode(float box_x, float box_y, float box_w, float box_h, | ||
float stride, | ||
const std::vector<std::vector<float>>& anchor, int j, | ||
int i, int a) const { | ||
box_x = (box_x + j) * stride; | ||
box_y = (box_y + i) * stride; | ||
box_w = expf(box_w) * anchor[a][0]; | ||
box_h = expf(box_h) * anchor[a][1]; | ||
return std::array<float, 4>{box_x, box_y, box_w, box_h}; | ||
} | ||
|
||
// YOLOv5 adds nothing to the shared pipeline; only yolo_decode() differs.
// The call is qualified so it always runs the YOLOHead implementation.
Result<Value> YOLOV5Head::operator()(const Value& prep_res, const Value& infer_res) {
  auto detections = YOLOHead::operator()(prep_res, infer_res);
  return detections;
}
|
||
std::array<float, 4> YOLOV5Head::yolo_decode(float box_x, float box_y, float box_w, float box_h, | ||
float stride, | ||
const std::vector<std::vector<float>>& anchor, int j, | ||
int i, int a) const { | ||
box_x = box_x * 2 - 0.5; | ||
box_y = box_y * 2 - 0.5; | ||
box_w = box_w * 2 - 0.5; | ||
box_h = box_h * 2 - 0.5; | ||
box_x = (box_x + j) * stride; | ||
box_y = (box_y + i) * stride; | ||
box_w = box_w * box_w * anchor[a][0]; | ||
box_h = box_h * box_h * anchor[a][1]; | ||
return std::array<float, 4>{box_x, box_y, box_w, box_h}; | ||
} | ||
|
||
MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMDetection, YOLOV3Head); | ||
MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMDetection, YOLOV5Head); | ||
|
||
} // namespace mmdeploy::mmdet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
// Copyright (c) OpenMMLab. All rights reserved. | ||
#ifndef MMDEPLOY_CODEBASE_MMDET_YOLO_HEAD_H_ | ||
#define MMDEPLOY_CODEBASE_MMDET_YOLO_HEAD_H_ | ||
|
||
#include "mmdeploy/codebase/mmdet/mmdet.h" | ||
#include "mmdeploy/core/tensor.h" | ||
|
||
namespace mmdeploy::mmdet { | ||
|
||
class YOLOHead : public MMDetection { | ||
public: | ||
explicit YOLOHead(const Value& cfg); | ||
Result<Value> operator()(const Value& prep_res, const Value& infer_res); | ||
int YOLOFeatDecode(const Tensor& feat_map, const std::vector<std::vector<float>>& anchor, | ||
int grid_h, int grid_w, int height, int width, int stride, | ||
std::vector<float>& boxes, std::vector<float>& obj_probs, | ||
std::vector<int>& class_id, float threshold) const; | ||
Result<Detections> GetBBoxes(const Value& prep_res, const std::vector<Tensor>& pred_maps) const; | ||
virtual std::array<float, 4> yolo_decode(float box_x, float box_y, float box_w, float box_h, | ||
float stride, | ||
const std::vector<std::vector<float>>& anchor, int j, | ||
int i, int a) const = 0; | ||
|
||
private: | ||
float score_thr_{0.4f}; | ||
int nms_pre_{1000}; | ||
float iou_threshold_{0.45f}; | ||
int min_bbox_size_{0}; | ||
std::vector<std::vector<std::vector<float>>> anchors_; | ||
std::vector<float> strides_; | ||
}; | ||
|
||
class YOLOV3Head : public YOLOHead { | ||
public: | ||
using YOLOHead::YOLOHead; | ||
Result<Value> operator()(const Value& prep_res, const Value& infer_res); | ||
std::array<float, 4> yolo_decode(float box_x, float box_y, float box_w, float box_h, float stride, | ||
const std::vector<std::vector<float>>& anchor, int j, int i, | ||
int a) const override; | ||
}; | ||
|
||
class YOLOV5Head : public YOLOHead { | ||
public: | ||
using YOLOHead::YOLOHead; | ||
Result<Value> operator()(const Value& prep_res, const Value& infer_res); | ||
std::array<float, 4> yolo_decode(float box_x, float box_y, float box_w, float box_h, float stride, | ||
const std::vector<std::vector<float>>& anchor, int j, int i, | ||
int a) const override; | ||
}; | ||
|
||
} // namespace mmdeploy::mmdet | ||
|
||
#endif // MMDEPLOY_CODEBASE_MMDET_YOLO_HEAD_H_ |
Oops, something went wrong.