Skip to content

Commit

Permalink
CodeCamp #101: Support MMDetection 3.x RTMDet model deployment on RV1…
Browse files Browse the repository at this point in the history
…126 (#1551)

* * partition rtmdet

* * add rtmdet deploy config

* * add rtmdet deploy config

* * modify rtmdet pipline anchor_generator's info dump
* support rtmdet infer in sdk

* fix a bug

* * fix a bug in csrc/mmdeploy/preprocess/transform/normalize.cpp

* * fix a bug

* * update docs

* * fix lint

* * update several urls in docs
  • Loading branch information
Qingrenn authored Jan 6, 2023
1 parent db8de7e commit 71fc8e3
Show file tree
Hide file tree
Showing 13 changed files with 304 additions and 7 deletions.
19 changes: 19 additions & 0 deletions configs/mmdet/detection/detection_rknn-int8_static-640x640.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
_base_ = ['../_base_/base_static.py', '../../_base_/backends/rknn.py']

onnx_config = dict(input_shape=[640, 640])

codebase_config = dict(model_type='rknn')

backend_config = dict(input_size_list=[[3, 640, 640]])

# rtmdet for rknn-toolkit and rknn-toolkit2
# partition_config = dict(
# type='rknn', # the partition policy name
# apply_marks=True, # should always be set to True
# partition_cfg=[
# dict(
# save_file='model.onnx', # name to save the partitioned onnx
# start=['detector_forward:input'], # [mark_name:input, ...]
# end=['rtmdet_head:output'], # [mark_name:output, ...]
# output_names=[f'pred_maps.{i}' for i in range(6)]) # output names
# ])
194 changes: 194 additions & 0 deletions csrc/mmdeploy/codebase/mmdet/rtmdet_head.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
// Copyright (c) OpenMMLab. All rights reserved.
#include "rtmdet_head.h"

#include <math.h>

#include <algorithm>
#include <numeric>

#include "mmdeploy/core/model.h"
#include "mmdeploy/core/utils/device_utils.h"
#include "mmdeploy/core/utils/formatter.h"
#include "utils.h"

namespace mmdeploy::mmdet {

RTMDetSepBNHead::RTMDetSepBNHead(const Value& cfg) : MMDetection(cfg) {
auto init = [&]() -> Result<void> {
auto model = cfg["context"]["model"].get<Model>();
if (cfg.contains("params")) {
nms_pre_ = cfg["params"].value("nms_pre", -1);
score_thr_ = cfg["params"].value("score_thr", 0.02f);
min_bbox_size_ = cfg["params"].value("min_bbox_size", 0);
max_per_img_ = cfg["params"].value("max_per_img", 100);
iou_threshold_ = cfg["params"].contains("nms")
? cfg["params"]["nms"].value("iou_threshold", 0.45f)
: 0.45f;
if (cfg["params"].contains("anchor_generator")) {
offset_ = cfg["params"]["anchor_generator"].value("offset", 0);
from_value(cfg["params"]["anchor_generator"]["strides"], strides_);
}
}
return success();
};
init().value();
}

Result<Value> RTMDetSepBNHead::operator()(const Value& prep_res, const Value& infer_res) {
MMDEPLOY_DEBUG("prep_res: {}\ninfer_res: {}", prep_res, infer_res);
try {
std::vector<Tensor> cls_scores;
std::vector<Tensor> bbox_preds;
const Device kHost{0, 0};
int i = 0;
int divisor = infer_res.size() / 2;
for (auto iter = infer_res.begin(); iter != infer_res.end(); iter++) {
auto pred_map = iter->get<Tensor>();
OUTCOME_TRY(auto _pred_map, MakeAvailableOnDevice(pred_map, kHost, stream()));
if (i < divisor)
cls_scores.push_back(_pred_map);
else
bbox_preds.push_back(_pred_map);
i++;
}
OUTCOME_TRY(stream().Wait());
OUTCOME_TRY(auto result, GetBBoxes(prep_res["img_metas"], bbox_preds, cls_scores));
return to_value(result);
} catch (...) {
return Status(eFail);
}
}

static float sigmoid(float x) { return 1.0 / (1.0 + expf(-x)); }

Result<Detections> RTMDetSepBNHead::GetBBoxes(const Value& prep_res,
const std::vector<Tensor>& bbox_preds,
const std::vector<Tensor>& cls_scores) const {
MMDEPLOY_DEBUG("bbox_pred: {}, {}", bbox_preds[0].shape(), dets[0].data_type());

This comment has been minimized.

Copy link
@jev001

jev001 Jul 3, 2024

开启DEBUG日志的情况下, 编译提示参数未定义 经排查
此处参数名错误.
MMDEPLOY_DEBUG("bbox_pred: {}, {}", bbox_preds[0].shape(), dets[0].data_type());

MMDEPLOY_DEBUG("cls_score: {}, {}", scores[0].shape(), scores[0].data_type());

std::vector<float> filter_boxes;
std::vector<float> obj_probs;
std::vector<int> class_ids;

for (int i = 0; i < bbox_preds.size(); i++) {
RTMDetFeatDeocde(bbox_preds[i], cls_scores[i], strides_[i], offset_, filter_boxes, obj_probs,
class_ids);
}

std::vector<int> indexArray;
for (int i = 0; i < obj_probs.size(); ++i) {
indexArray.push_back(i);
}
Sort(obj_probs, class_ids, indexArray);

Tensor dets(TensorDesc{Device{0, 0}, DataType::kFLOAT,
TensorShape{int(filter_boxes.size() / 4), 4}, "dets"});
std::copy(filter_boxes.begin(), filter_boxes.end(), dets.data<float>());
NMS(dets, iou_threshold_, indexArray);

Detections objs;
std::vector<float> scale_factor;
if (prep_res.contains("scale_factor")) {
from_value(prep_res["scale_factor"], scale_factor);
} else {
scale_factor = {1.f, 1.f, 1.f, 1.f};
}
int ori_width = prep_res["ori_shape"][2].get<int>();
int ori_height = prep_res["ori_shape"][1].get<int>();
auto det_ptr = dets.data<float>();
for (int i = 0; i < indexArray.size(); ++i) {
if (indexArray[i] == -1) {
continue;
}
int j = indexArray[i];
auto x1 = det_ptr[j * 4 + 0];
auto y1 = det_ptr[j * 4 + 1];
auto x2 = det_ptr[j * 4 + 2];
auto y2 = det_ptr[j * 4 + 3];
int label_id = class_ids[i];
float score = obj_probs[i];

MMDEPLOY_DEBUG("{}-th box: ({}, {}, {}, {}), {}, {}", i, x1, y1, x2, y2, label_id, score);

auto rect =
MapToOriginImage(x1, y1, x2, y2, scale_factor.data(), 0, 0, ori_width, ori_height, 0, 0);
if (rect[2] - rect[0] < min_bbox_size_ || rect[3] - rect[1] < min_bbox_size_) {
MMDEPLOY_DEBUG("ignore small bbox with width '{}' and height '{}", rect[2] - rect[0],
rect[3] - rect[1]);
continue;
}
Detection det{};
det.index = i;
det.label_id = label_id;
det.score = score;
det.bbox = rect;
objs.push_back(std::move(det));
}

return objs;
}

int RTMDetSepBNHead::RTMDetFeatDeocde(const Tensor& bbox_pred, const Tensor& cls_score,
const float stride, const float offset,
std::vector<float>& filter_boxes,
std::vector<float>& obj_probs,
std::vector<int>& class_ids) const {
int cls_param_num = cls_score.shape(1);
int feat_h = bbox_pred.shape(2);
int feat_w = bbox_pred.shape(3);
int feat_size = feat_h * feat_w;
auto bbox_ptr = bbox_pred.data<float>();
auto score_ptr = cls_score.data<float>(); // (b, c, h, w)
int valid_count = 0;
for (int i = 0; i < feat_h; i++) {
for (int j = 0; j < feat_w; j++) {
float max_score = score_ptr[i * feat_w + j];
int class_id = 0;
for (int k = 0; k < cls_param_num; k++) {
float score = score_ptr[k * feat_size + i * feat_w + j];
if (score > max_score) {
max_score = score;
class_id = k;
}
}
max_score = sigmoid(max_score);
if (max_score < score_thr_) continue;

obj_probs.push_back(max_score);
class_ids.push_back(class_id);

float tl_x = bbox_ptr[0 * feat_size + i * feat_w + j];
float tl_y = bbox_ptr[1 * feat_size + i * feat_w + j];
float br_x = bbox_ptr[2 * feat_size + i * feat_w + j];
float br_y = bbox_ptr[3 * feat_size + i * feat_w + j];

auto box = RTMDetdecode(tl_x, tl_y, br_x, br_y, stride, offset, j, i);

tl_x = box[0];
tl_y = box[1];
br_x = box[2];
br_y = box[3];

filter_boxes.push_back(tl_x);
filter_boxes.push_back(tl_y);
filter_boxes.push_back(br_x);
filter_boxes.push_back(br_y);
valid_count++;
}
}
return valid_count;
}

std::array<float, 4> RTMDetSepBNHead::RTMDetdecode(float tl_x, float tl_y, float br_x, float br_y,
float stride, float offset, int j, int i) const {
tl_x = (offset + j) * stride - tl_x;
tl_y = (offset + i) * stride - tl_y;
br_x = (offset + j) * stride + br_x;
br_y = (offset + i) * stride + br_y;
return std::array<float, 4>{tl_x, tl_y, br_x, br_y};
}

MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMDetection, RTMDetSepBNHead);

} // namespace mmdeploy::mmdet
34 changes: 34 additions & 0 deletions csrc/mmdeploy/codebase/mmdet/rtmdet_head.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef MMDEPLOY_CODEBASE_MMDET_RTMDET_HEAD_H_
#define MMDEPLOY_CODEBASE_MMDET_RTMDET_HEAD_H_

#include "mmdeploy/codebase/mmdet/mmdet.h"
#include "mmdeploy/core/tensor.h"

namespace mmdeploy::mmdet {

class RTMDetSepBNHead : public MMDetection {
public:
explicit RTMDetSepBNHead(const Value& cfg);
Result<Value> operator()(const Value& prep_res, const Value& infer_res);
Result<Detections> GetBBoxes(const Value& prep_res, const std::vector<Tensor>& bbox_preds,
const std::vector<Tensor>& cls_scores) const;
int RTMDetFeatDeocde(const Tensor& bbox_pred, const Tensor& cls_score, const float stride,
const float offset, std::vector<float>& filter_boxes,
std::vector<float>& obj_probs, std::vector<int>& class_ids) const;
std::array<float, 4> RTMDetdecode(float tl_x, float tl_y, float br_x, float br_y, float stride,
float offset, int j, int i) const;

private:
float score_thr_{0.4f};
int nms_pre_{1000};
float iou_threshold_{0.45f};
int min_bbox_size_{0};
int max_per_img_{100};
float offset_{0.0f};
std::vector<float> strides_;
};

} // namespace mmdeploy::mmdet

#endif // MMDEPLOY_CODEBASE_MMDET_RTMDET_HEAD_H_
3 changes: 2 additions & 1 deletion csrc/mmdeploy/preprocess/transform/normalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,14 @@ class Normalize : public Transform {
Tensor dst;
if (to_float_) {
OUTCOME_TRY(normalize_.Apply(tensor, dst));
data[key] = std::move(dst);
} else if (to_rgb_) {
auto src_mat = to_mat(tensor, PixelFormat::kBGR);
Mat dst_mat;
OUTCOME_TRY(cvt_color_.Apply(src_mat, dst_mat, PixelFormat::kBGR));
dst = to_tensor(src_mat);
data[key] = std::move(dst);
}
data[key] = std::move(dst);

for (auto& v : mean_) {
data["img_norm_cfg"]["mean"].push_back(v);
Expand Down
16 changes: 16 additions & 0 deletions docs/en/01-how-to-build/rockchip.md
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,22 @@ label: 65, score: 0.95
])
```

RTMDet: you may paste the following partition configuration into [detection_rknn-int8_static-640x640.py](https://github.com/open-mmlab/mmdeploy/blob/dev-1.x/configs/mmdet/detection/detection_rknn-int8_static-640x640.py):

```python
# rtmdet for rknn-toolkit and rknn-toolkit2
partition_config = dict(
type='rknn', # the partition policy name
apply_marks=True, # should always be set to True
partition_cfg=[
dict(
save_file='model.onnx', # name to save the partitioned onnx
start=['detector_forward:input'], # [mark_name:input, ...]
end=['rtmdet_head:output'], # [mark_name:output, ...]
output_names=[f'pred_maps.{i}' for i in range(6)]) # output names
])
```

RetinaNet & SSD & FSAF with rknn-toolkit2, you may paste the following partition configuration into [detection_rknn_static-320x320.py](https://github.com/open-mmlab/mmdeploy/tree/1.x/configs/mmdet/detection/detection_rknn_static-320x320.py). Users with rknn-toolkit can directly use default config.

```python
Expand Down
2 changes: 1 addition & 1 deletion docs/en/02-how-to-run/convert_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Notes:
### Prerequisite

1. Install and build your target backend. You could refer to [ONNXRuntime-install](../05-supported-backends/onnxruntime.md), [TensorRT-install](../05-supported-backends/tensorrt.md), [ncnn-install](../05-supported-backends/ncnn.md), [PPLNN-install](../05-supported-backends/pplnn.md), [OpenVINO-install](../05-supported-backends/openvino.md) for more information.
2. Install and build your target codebase. You could refer to [MMClassification-install](https://github.com/open-mmlab/mmclassification/blob/1.x/docs/en/get_started.md#installation), [MMDetection-install](https://github.com/open-mmlab/mmdetection/blob/3.x/docs/en/get_started.md), [MMSegmentation-install](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/en/get_started.md#installation), [MMOCR-install](https://github.com/open-mmlab/mmocr/blob/1.x/docs/en/get_started/install.md), [MMEditing-install](https://github.com/open-mmlab/mmediting/blob/1.x/docs/en/2_get_started.md#installation).
2. Install and build your target codebase. You could refer to [MMClassification-install](https://github.com/open-mmlab/mmclassification/blob/1.x/docs/en/get_started.md#installation), [MMDetection-install](https://github.com/open-mmlab/mmdetection/blob/3.x/docs/en/get_started.md), [MMSegmentation-install](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/en/get_started.md#installation), [MMOCR-install](https://github.com/open-mmlab/mmocr/blob/1.x/docs/en/get_started/install.md), [MMEditing-install](https://github.com/open-mmlab/mmediting/blob/1.x/docs/en/get_started/install.md).

### Usage

Expand Down
2 changes: 1 addition & 1 deletion docs/en/02-how-to-run/write_config.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,4 +177,4 @@ detection_tensorrt-int8_dynamic-320x320-1344x1344.py

## 6. How to write model config

According to model's codebase, write the model config file. Model's config file is used to initialize the model, referring to [MMClassification](https://github.com/open-mmlab/mmclassification/blob/1.x/docs/en/user_guides/config.md), [MMDetection](https://github.com/open-mmlab/mmdetection/blob/3.x/docs/en/user_guides/config.md), [MMSegmentation](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/en/user_guides/1_config.md), [MMOCR](https://github.com/open-mmlab/mmocr/blob/1.x/docs/en/user_guides/config.md), [MMEditing](https://github.com/open-mmlab/mmediting/blob/1.x/docs/en/user_guides/1_config.md).
According to model's codebase, write the model config file. Model's config file is used to initialize the model, referring to [MMClassification](https://github.com/open-mmlab/mmclassification/blob/1.x/docs/en/user_guides/config.md), [MMDetection](https://github.com/open-mmlab/mmdetection/blob/3.x/docs/en/user_guides/config.md), [MMSegmentation](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/en/user_guides/1_config.md), [MMOCR](https://github.com/open-mmlab/mmocr/blob/1.x/docs/en/user_guides/config.md), [MMEditing](https://github.com/open-mmlab/mmediting/blob/1.x/docs/en/user_guides/config.md).
18 changes: 18 additions & 0 deletions docs/zh_cn/01-how-to-build/rockchip.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,24 @@ python tools/deploy.py \

```

- RTMDet

将下面的模型拆分配置写入到 [detection_rknn-int8_static-640x640.py](https://github.com/open-mmlab/mmdeploy/blob/dev-1.x/configs/mmdet/detection/detection_rknn-int8_static-640x640.py)

```python
# rtmdet for rknn-toolkit and rknn-toolkit2
partition_config = dict(
type='rknn', # the partition policy name
apply_marks=True, # should always be set to True
partition_cfg=[
dict(
save_file='model.onnx', # name to save the partitioned onnx
start=['detector_forward:input'], # [mark_name:input, ...]
end=['rtmdet_head:output'], # [mark_name:output, ...]
output_names=[f'pred_maps.{i}' for i in range(6)]) # output names
])
```

- RetinaNet & SSD & FSAF with rknn-toolkit2

将下面的模型拆分配置写入到 [detection_rknn_static.py](https://github.com/open-mmlab/mmdeploy/blob/1.x/configs/mmdet/detection/detection_rknn_static-320x320.py)。使用 rknn-toolkit 的用户则不用。
Expand Down
2 changes: 1 addition & 1 deletion docs/zh_cn/02-how-to-run/convert_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
### 准备工作

1. 安装您的目标后端。 您可以参考 [ONNXRuntime-install](../05-supported-backends/onnxruntime.md)[TensorRT-install](../05-supported-backends/tensorrt.md)[ncnn-install](../05-supported-backends/ncnn.md)[PPLNN-install](../05-supported-backends/pplnn.md), [OpenVINO-install](../05-supported-backends/openvino.md)
2. 安装您的目标代码库。 您可以参考 [MMClassification-install](https://github.com/open-mmlab/mmclassification/blob/1.x/docs/zh_CN/get_started.md#%E5%AE%89%E8%A3%85)[MMDetection-install](https://github.com/open-mmlab/mmdetection/blob/3.x/docs/zh_cn/get_started.md)[MMSegmentation-install](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/zh_cn/get_started.md#installation)[MMOCR-install](https://github.com/open-mmlab/mmocr/blob/1.x/docs/zh_cn/get_started/install.md)[MMEditing-install](https://github.com/open-mmlab/mmediting/blob/1.x/docs/en/2_get_started.md#installation)
2. 安装您的目标代码库。 您可以参考 [MMClassification-install](https://github.com/open-mmlab/mmclassification/blob/1.x/docs/zh_CN/get_started.md#%E5%AE%89%E8%A3%85)[MMDetection-install](https://github.com/open-mmlab/mmdetection/blob/3.x/docs/zh_cn/get_started.md)[MMSegmentation-install](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/zh_cn/get_started.md#installation)[MMOCR-install](https://github.com/open-mmlab/mmocr/blob/1.x/docs/zh_cn/get_started/install.md)[MMEditing-install](https://github.com/open-mmlab/mmediting/blob/1.x/docs/en/get_started/install.md)

### 使用方法

Expand Down
2 changes: 1 addition & 1 deletion docs/zh_cn/02-how-to-run/write_config.md
Original file line number Diff line number Diff line change
Expand Up @@ -187,4 +187,4 @@ detection_tensorrt-int8_dynamic-320x320-1344x1344.py

## 6. 如何编写模型配置文件

请根据模型具体任务的代码库,编写模型配置文件。 模型配置文件用于初始化模型,详情请参考[MMClassification](https://github.com/open-mmlab/mmclassification/blob/1.x/docs/zh_CN/user_guides/config.md)[MMDetection](https://github.com/open-mmlab/mmdetection/blob/3.x/docs/zh_cn/user_guides/config.md)[MMSegmentation](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/zh_cn/user_guides/1_config.md)[MMOCR](https://github.com/open-mmlab/mmocr/blob/1.x/docs/en/user_guides/config.md)[MMEditing](https://github.com/open-mmlab/mmediting/blob/1.x/docs/en/user_guides/1_config.md)
请根据模型具体任务的代码库,编写模型配置文件。 模型配置文件用于初始化模型,详情请参考[MMClassification](https://github.com/open-mmlab/mmclassification/blob/1.x/docs/zh_CN/user_guides/config.md)[MMDetection](https://github.com/open-mmlab/mmdetection/blob/3.x/docs/zh_cn/user_guides/config.md)[MMSegmentation](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/zh_cn/user_guides/1_config.md)[MMOCR](https://github.com/open-mmlab/mmocr/blob/1.x/docs/en/user_guides/config.md)[MMEditing](https://github.com/open-mmlab/mmediting/blob/1.x/docs/en/user_guides/config.md)
3 changes: 2 additions & 1 deletion mmdeploy/codebase/mmdet/deploy/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,8 @@ def get_postprocess(self, *args, **kwargs) -> Dict:
params['mask_thr_binary'] = params['rcnn']['mask_thr_binary']
type = 'ResizeInstanceMask' # for instance-seg
if get_backend(self.deploy_cfg) == Backend.RKNN:
if 'YOLO' in self.model_cfg.model.type:
if 'YOLO' in self.model_cfg.model.type or \
'RTMDet' in self.model_cfg.model.type:
bbox_head = self.model_cfg.model.bbox_head
type = bbox_head.type
params['anchor_generator'] = bbox_head.get(
Expand Down
8 changes: 8 additions & 0 deletions mmdeploy/codebase/mmdet/deploy/object_detection_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,6 +755,14 @@ def _get_bboxes(self, outputs: List[Tensor], metainfos: Any):
batch_img_metas=metainfos,
cfg=self.model_cfg._cfg_dict.model.test_cfg,
rescale=True)
elif head_cfg.type == 'RTMDetSepBNHead':
divisor = round(len(outputs) / 2)
ret = head.predict_by_feat(
outputs[:divisor],
outputs[divisor:],
batch_img_metas=metainfos,
cfg=self.model_cfg._cfg_dict.model.test_cfg,
rescale=True)
elif head_cfg.type in ('RetinaHead', 'SSDHead', 'FSAFHead'):
partition_cfgs = get_partition_config(self.deploy_cfg)
if partition_cfgs is None: # bbox decoding done in rknn model
Expand Down
8 changes: 7 additions & 1 deletion mmdeploy/codebase/mmdet/models/dense_heads/rtmdet_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch import Tensor

from mmdeploy.codebase.mmdet import get_post_processing_params
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.core import FUNCTION_REWRITER, mark
from mmdeploy.mmcv.ops import multiclass_nms


Expand Down Expand Up @@ -51,6 +51,12 @@ def rtmdet_head__predict_by_feat(self,
tensor in the tuple is (N, num_box), and each element
represents the class label of the corresponding box.
"""

@mark('rtmdet_head', inputs=['cls_scores', 'bbox_preds'])
def __mark_pred_maps(cls_scores, bbox_preds):
return cls_scores, bbox_preds

cls_scores, bbox_preds = __mark_pred_maps(cls_scores, bbox_preds)
ctx = FUNCTION_REWRITER.get_context()
assert len(cls_scores) == len(bbox_preds)
device = cls_scores[0].device
Expand Down

0 comments on commit 71fc8e3

Please sign in to comment.