diff --git a/configs/_base_/backends/rknn.py b/configs/_base_/backends/rknn.py index 640a982655..39b62bd8f6 100644 --- a/configs/_base_/backends/rknn.py +++ b/configs/_base_/backends/rknn.py @@ -1,8 +1,6 @@ backend_config = dict( type='rknn', common_config=dict( - mean_values=None, # [[103.53, 116.28, 123.675]], - std_values=None, # [[57.375, 57.12, 58.395]], target_platform='rv1126', # 'rk3588' optimization_level=1), - quantization_config=dict(do_quantization=False, dataset=None)) + quantization_config=dict(do_quantization=True, dataset=None)) diff --git a/configs/mmcls/classification_rknn-fp16_static-224x224.py b/configs/mmcls/classification_rknn-fp16_static-224x224.py new file mode 100644 index 0000000000..a6a599e51e --- /dev/null +++ b/configs/mmcls/classification_rknn-fp16_static-224x224.py @@ -0,0 +1,7 @@ +_base_ = ['./classification_static.py', '../_base_/backends/rknn.py'] + +onnx_config = dict(input_shape=[224, 224]) +codebase_config = dict(model_type='end2end') +backend_config = dict( + input_size_list=[[3, 224, 224]], + quantization_config=dict(do_quantization=False)) diff --git a/configs/mmcls/classification_rknn_static-224x224.py b/configs/mmcls/classification_rknn-int8_static-224x224.py similarity index 78% rename from configs/mmcls/classification_rknn_static-224x224.py rename to configs/mmcls/classification_rknn-int8_static-224x224.py index 74f9b4e76a..02f772e496 100644 --- a/configs/mmcls/classification_rknn_static-224x224.py +++ b/configs/mmcls/classification_rknn-int8_static-224x224.py @@ -1,5 +1,5 @@ _base_ = ['./classification_static.py', '../_base_/backends/rknn.py'] onnx_config = dict(input_shape=[224, 224]) -codebase_config = dict(model_type='rknn') +codebase_config = dict(model_type='end2end') backend_config = dict(input_size_list=[[3, 224, 224]]) diff --git a/configs/mmdet/detection/detection_rknn-fp16_static-320x320.py b/configs/mmdet/detection/detection_rknn-fp16_static-320x320.py new file mode 100644 index 0000000000..f5b0134640 --- /dev/null +++ b/configs/mmdet/detection/detection_rknn-fp16_static-320x320.py @@ -0,0 +1,34 @@ +_base_ = ['../_base_/base_static.py', '../../_base_/backends/rknn.py'] + +onnx_config = dict(input_shape=[320, 320]) + +codebase_config = dict(model_type='rknn') + +backend_config = dict( + input_size_list=[[3, 320, 320]], + quantization_config=dict(do_quantization=False)) + +# # yolov3, yolox for rknn-toolkit and rknn-toolkit2 +# partition_config = dict( +# type='rknn', # the partition policy name +# apply_marks=True, # should always be set to True +# partition_cfg=[ +# dict( +# save_file='model.onnx', # name to save the partitioned onnx +# start=['detector_forward:input'], # [mark_name:input, ...] +# end=['yolo_head:input'], # [mark_name:output, ...] 
+# output_names=[f'pred_maps.{i}' for i in range(3)]) # out names +# ]) + +# # retinanet, ssd, fsaf for rknn-toolkit2 +# partition_config = dict( +# type='rknn', # the partition policy name +# apply_marks=True, +# partition_cfg=[ +# dict( +# save_file='model.onnx', +# start='detector_forward:input', +# end=['BaseDenseHead:output'], +# output_names=[f'BaseDenseHead.cls.{i}' for i in range(5)] + +# [f'BaseDenseHead.loc.{i}' for i in range(5)]) +# ]) diff --git a/configs/mmdet/detection/detection_rknn_static-320x320.py b/configs/mmdet/detection/detection_rknn-int8_static-320x320.py similarity index 65% rename from configs/mmdet/detection/detection_rknn_static-320x320.py rename to configs/mmdet/detection/detection_rknn-int8_static-320x320.py index 5507d5856c..eafb6351f0 100644 --- a/configs/mmdet/detection/detection_rknn_static-320x320.py +++ b/configs/mmdet/detection/detection_rknn-int8_static-320x320.py @@ -6,7 +6,7 @@ backend_config = dict(input_size_list=[[3, 320, 320]]) -# # yolov3, yolox +# # yolov3, yolox for rknn-toolkit and rknn-toolkit2 # partition_config = dict( # type='rknn', # the partition policy name # apply_marks=True, # should always be set to True @@ -14,10 +14,11 @@ # dict( # save_file='model.onnx', # name to save the partitioned onnx # start=['detector_forward:input'], # [mark_name:input, ...] -# end=['yolo_head:input']) # [mark_name:output, ...] +# end=['yolo_head:input'], # [mark_name:output, ...] +# output_names=[f'pred_maps.{i}' for i in range(3)]) # out names # ]) -# # retinanet, ssd, fsaf +# # retinanet, ssd, fsaf for rknn-toolkit2 # partition_config = dict( # type='rknn', # the partition policy name # apply_marks=True, @@ -25,5 +26,7 @@ # dict( # save_file='model.onnx', # start='detector_forward:input', -# end=['BaseDenseHead:output']) +# end=['BaseDenseHead:output'], +# output_names=[f'BaseDenseHead.cls.{i}' for i in range(5)] + +# [f'BaseDenseHead.loc.{i}' for i in range(5)]) # ]) diff --git a/configs/mmdet/detection/yolov3_partition_onnxruntime_static.py b/configs/mmdet/detection/yolov3_partition_onnxruntime_static.py index 20e10a2562..09791236ea 100644 --- a/configs/mmdet/detection/yolov3_partition_onnxruntime_static.py +++ b/configs/mmdet/detection/yolov3_partition_onnxruntime_static.py @@ -8,5 +8,6 @@ dict( save_file='yolov3.onnx', start=['detector_forward:input'], - end=['yolo_head:input']) + end=['yolo_head:input'], + output_names=[f'pred_maps.{i}' for i in range(3)]) ]) diff --git a/configs/mmseg/segmentation_rknn-fp16_static-320x320.py b/configs/mmseg/segmentation_rknn-fp16_static-320x320.py new file mode 100644 index 0000000000..31a8edf718 --- /dev/null +++ b/configs/mmseg/segmentation_rknn-fp16_static-320x320.py @@ -0,0 +1,9 @@ +_base_ = ['./segmentation_static.py', '../_base_/backends/rknn.py'] + +onnx_config = dict(input_shape=[320, 320]) + +codebase_config = dict(model_type='rknn') + +backend_config = dict( + input_size_list=[[3, 320, 320]], + quantization_config=dict(do_quantization=False)) diff --git a/configs/mmseg/segmentation_rknn_static-320x320.py b/configs/mmseg/segmentation_rknn-int8_static-320x320.py similarity index 100% rename from configs/mmseg/segmentation_rknn_static-320x320.py rename to configs/mmseg/segmentation_rknn-int8_static-320x320.py diff --git a/csrc/mmdeploy/codebase/mmdet/yolo_head.cpp b/csrc/mmdeploy/codebase/mmdet/yolo_head.cpp new file mode 100644 index 0000000000..69aab7ebcb --- /dev/null +++ b/csrc/mmdeploy/codebase/mmdet/yolo_head.cpp @@ -0,0 +1,228 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+#include "yolo_head.h" + +#include + +#include +#include + +#include "mmdeploy/core/model.h" +#include "mmdeploy/core/utils/device_utils.h" +#include "mmdeploy/core/utils/formatter.h" +#include "utils.h" + +namespace mmdeploy::mmdet { + +YOLOHead::YOLOHead(const Value& cfg) : MMDetection(cfg) { + auto init = [&]() -> Result { + auto model = cfg["context"]["model"].get(); + if (cfg.contains("params")) { + nms_pre_ = cfg["params"].value("nms_pre", -1); + score_thr_ = cfg["params"].value("score_thr", 0.02f); + min_bbox_size_ = cfg["params"].value("min_bbox_size", 0); + iou_threshold_ = cfg["params"].contains("nms") + ? cfg["params"]["nms"].value("iou_threshold", 0.45f) + : 0.45f; + if (cfg["params"].contains("anchor_generator")) { + from_value(cfg["params"]["anchor_generator"]["base_sizes"], anchors_); + from_value(cfg["params"]["anchor_generator"]["strides"], strides_); + } + } + return success(); + }; + init().value(); +} + +Result YOLOHead::operator()(const Value& prep_res, const Value& infer_res) { + MMDEPLOY_DEBUG("prep_res: {}\ninfer_res: {}", prep_res, infer_res); + try { + const Device kHost{0, 0}; + std::vector pred_maps; + for (auto iter = infer_res.begin(); iter != infer_res.end(); iter++) { + auto pred_map = iter->get(); + OUTCOME_TRY(auto _pred_map, MakeAvailableOnDevice(pred_map, kHost, stream())); + pred_maps.push_back(_pred_map); + } + OUTCOME_TRY(stream().Wait()); + // reorder pred_maps according to strides and anchors, mainly for rknpu yolov3 + if ((pred_maps.size() > 1) && + !((strides_[0] < strides_[1]) ^ (pred_maps[0].shape(3) < pred_maps[1].shape(3)))) { + std::reverse(pred_maps.begin(), pred_maps.end()); + } + OUTCOME_TRY(auto result, GetBBoxes(prep_res["img_metas"], pred_maps)); + return to_value(result); + } catch (...) { + return Status(eFail); + } +} + +inline static int clamp(float val, int min, int max) { + return val > min ? (val < max ? 
val : max) : min; +} + +static float sigmoid(float x) { return 1.0 / (1.0 + expf(-x)); } + +static float unsigmoid(float y) { return -1.0 * logf((1.0 / y) - 1.0); } + +int YOLOHead::YOLOFeatDecode(const Tensor& feat_map, const std::vector<std::vector<int>>& anchor, + int grid_h, int grid_w, int height, int width, int stride, + std::vector<float>& boxes, std::vector<float>& obj_probs, + std::vector<int>& class_id, float threshold) const { + auto input = const_cast<float*>(feat_map.data<float>()); + auto prop_box_size = feat_map.shape(1) / anchor.size(); + const int kClasses = prop_box_size - 5; + int valid_count = 0; + int grid_len = grid_h * grid_w; + float thres = unsigmoid(threshold); + for (int a = 0; a < anchor.size(); a++) { + for (int i = 0; i < grid_h; i++) { + for (int j = 0; j < grid_w; j++) { + float box_confidence = input[(prop_box_size * a + 4) * grid_len + i * grid_w + j]; + if (box_confidence >= thres) { + int offset = (prop_box_size * a) * grid_len + i * grid_w + j; + float* in_ptr = input + offset; + + float box_x = sigmoid(*in_ptr); + float box_y = sigmoid(in_ptr[grid_len]); + float box_w = in_ptr[2 * grid_len]; + float box_h = in_ptr[3 * grid_len]; + auto box = yolo_decode(box_x, box_y, box_w, box_h, stride, anchor, j, i, a); + + box_x = box[0]; + box_y = box[1]; + box_w = box[2]; + box_h = box[3]; + + box_x -= (box_w / 2.0); + box_y -= (box_h / 2.0); + boxes.push_back(box_x); + boxes.push_back(box_y); + boxes.push_back(box_x + box_w); + boxes.push_back(box_y + box_h); + + float max_class_probs = in_ptr[5 * grid_len]; + int max_class_id = 0; + for (int k = 1; k < kClasses; ++k) { + float prob = in_ptr[(5 + k) * grid_len]; + if (prob > max_class_probs) { + max_class_id = k; + max_class_probs = prob; + } + } + obj_probs.push_back(sigmoid(max_class_probs) * sigmoid(box_confidence)); + class_id.push_back(max_class_id); + valid_count++; + } + } + } + } + return valid_count; +} + +Result<Detections> YOLOHead::GetBBoxes(const Value& prep_res, + const std::vector<Tensor>& pred_maps) const { + std::vector<float> filter_boxes; + std::vector<float> obj_probs; + std::vector<int> class_id; + + int model_in_h = prep_res["img_shape"][1].get<int>(); + int model_in_w = prep_res["img_shape"][2].get<int>(); + + for (int i = 0; i < pred_maps.size(); i++) { + int stride = strides_[i]; + int grid_h = model_in_h / stride; + int grid_w = model_in_w / stride; + YOLOFeatDecode(pred_maps[i], anchors_[i], grid_h, grid_w, model_in_h, model_in_w, stride, + filter_boxes, obj_probs, class_id, score_thr_); + } + + std::vector<int> indexArray; + for (int i = 0; i < obj_probs.size(); ++i) { + indexArray.push_back(i); + } + Sort(obj_probs, class_id, indexArray); + + Tensor dets(TensorDesc{Device{0, 0}, DataType::kFLOAT, + TensorShape{int(filter_boxes.size() / 4), 4}, "dets"}); + std::copy(filter_boxes.begin(), filter_boxes.end(), dets.data<float>()); + NMS(dets, iou_threshold_, indexArray); + + Detections objs; + std::vector<float> scale_factor; + if (prep_res.contains("scale_factor")) { + from_value(prep_res["scale_factor"], scale_factor); + } else { + scale_factor = {1.f, 1.f, 1.f, 1.f}; + } + int ori_width = prep_res["ori_shape"][2].get<int>(); + int ori_height = prep_res["ori_shape"][1].get<int>(); + auto det_ptr = dets.data<float>(); + for (int i = 0; i < indexArray.size(); ++i) { + if (indexArray[i] == -1) { + continue; + } + int j = indexArray[i]; + auto x1 = clamp(det_ptr[j * 4 + 0], 0, model_in_w); + auto y1 = clamp(det_ptr[j * 4 + 1], 0, model_in_h); + auto x2 = clamp(det_ptr[j * 4 + 2], 0, model_in_w); + auto y2 = clamp(det_ptr[j * 4 + 3], 0, model_in_h); + int label_id = class_id[i]; + float score = obj_probs[i]; + +
MMDEPLOY_DEBUG("{}-th box: ({}, {}, {}, {}), {}, {}", i, x1, y1, x2, y2, label_id, score); + + auto rect = MapToOriginImage(x1, y1, x2, y2, scale_factor.data(), 0, 0, ori_width, ori_height); + if (rect[2] - rect[0] < min_bbox_size_ || rect[3] - rect[1] < min_bbox_size_) { + MMDEPLOY_DEBUG("ignore small bbox with width '{}' and height '{}", rect[2] - rect[0], + rect[3] - rect[1]); + continue; + } + Detection det{}; + det.index = i; + det.label_id = label_id; + det.score = score; + det.bbox = rect; + objs.push_back(std::move(det)); + } + + return objs; +} + +Result YOLOV3Head::operator()(const Value& prep_res, const Value& infer_res) { + return YOLOHead::operator()(prep_res, infer_res); +} + +std::array YOLOV3Head::yolo_decode(float box_x, float box_y, float box_w, float box_h, + float stride, + const std::vector>& anchor, int j, + int i, int a) const { + box_x = (box_x + j) * stride; + box_y = (box_y + i) * stride; + box_w = expf(box_w) * anchor[a][0]; + box_h = expf(box_h) * anchor[a][1]; + return std::array{box_x, box_y, box_w, box_h}; +} + +Result YOLOV5Head::operator()(const Value& prep_res, const Value& infer_res) { + return YOLOHead::operator()(prep_res, infer_res); +} + +std::array YOLOV5Head::yolo_decode(float box_x, float box_y, float box_w, float box_h, + float stride, + const std::vector>& anchor, int j, + int i, int a) const { + box_x = box_x * 2 - 0.5; + box_y = box_y * 2 - 0.5; + box_w = box_w * 2 - 0.5; + box_h = box_h * 2 - 0.5; + box_x = (box_x + j) * stride; + box_y = (box_y + i) * stride; + box_w = box_w * box_w * anchor[a][0]; + box_h = box_h * box_h * anchor[a][1]; + return std::array{box_x, box_y, box_w, box_h}; +} + +MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMDetection, YOLOV3Head); +MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMDetection, YOLOV5Head); + +} // namespace mmdeploy::mmdet diff --git a/csrc/mmdeploy/codebase/mmdet/yolo_head.h b/csrc/mmdeploy/codebase/mmdet/yolo_head.h new file mode 100644 index 0000000000..08421b3f68 --- /dev/null +++ b/csrc/mmdeploy/codebase/mmdet/yolo_head.h @@ -0,0 +1,53 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+#ifndef MMDEPLOY_CODEBASE_MMDET_YOLO_HEAD_H_ +#define MMDEPLOY_CODEBASE_MMDET_YOLO_HEAD_H_ + +#include "mmdeploy/codebase/mmdet/mmdet.h" +#include "mmdeploy/core/tensor.h" + +namespace mmdeploy::mmdet { + +class YOLOHead : public MMDetection { + public: + explicit YOLOHead(const Value& cfg); + Result<Value> operator()(const Value& prep_res, const Value& infer_res); + int YOLOFeatDecode(const Tensor& feat_map, const std::vector<std::vector<int>>& anchor, + int grid_h, int grid_w, int height, int width, int stride, + std::vector<float>& boxes, std::vector<float>& obj_probs, + std::vector<int>& class_id, float threshold) const; + Result<Detections> GetBBoxes(const Value& prep_res, const std::vector<Tensor>& pred_maps) const; + virtual std::array<float, 4> yolo_decode(float box_x, float box_y, float box_w, float box_h, + float stride, + const std::vector<std::vector<int>>& anchor, int j, + int i, int a) const = 0; + + private: + float score_thr_{0.4f}; + int nms_pre_{1000}; + float iou_threshold_{0.45f}; + int min_bbox_size_{0}; + std::vector<std::vector<std::vector<int>>> anchors_; + std::vector<int> strides_; +}; + +class YOLOV3Head : public YOLOHead { + public: + using YOLOHead::YOLOHead; + Result<Value> operator()(const Value& prep_res, const Value& infer_res); + std::array<float, 4> yolo_decode(float box_x, float box_y, float box_w, float box_h, float stride, + const std::vector<std::vector<int>>& anchor, int j, int i, + int a) const override; +}; + +class YOLOV5Head : public YOLOHead { + public: + using YOLOHead::YOLOHead; + Result<Value> operator()(const Value& prep_res, const Value& infer_res); + std::array<float, 4> yolo_decode(float box_x, float box_y, float box_w, float box_h, float stride, + const std::vector<std::vector<int>>& anchor, int j, int i, + int a) const override; +}; + +} // namespace mmdeploy::mmdet + +#endif // MMDEPLOY_CODEBASE_MMDET_YOLO_HEAD_H_ diff --git a/docs/en/01-how-to-build/rockchip.md b/docs/en/01-how-to-build/rockchip.md index 6c22a2fbdb..c31102391e 100644 --- a/docs/en/01-how-to-build/rockchip.md +++ b/docs/en/01-how-to-build/rockchip.md @@ -138,30 +138,12 @@ label: 65, score: 0.95 ## Troubleshooting -- Quantization fails. - - Empirically, RKNN require the inputs not normalized if `do_quantization` is set to `True`. Please modify the settings of `Normalize` in the `model_cfg` from - - ```python - img_norm_cfg = dict( - mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) - ``` - - to - - ```python - img_norm_cfg = dict( - mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True) - ``` - - Besides, the `mean_values` and `std_values` of deploy_cfg should be replaced with original normalization settings of `model_cfg`. Let `mean_values=[[103.53, 116.28, 123.675]]` and `std_values=[[57.375, 57.12, 58.395]]`. - - MMDet models. YOLOV3 & YOLOX: you may paste the following partition configuration into [detection_rknn_static.py](https://github.com/open-mmlab/mmdeploy/blob/master/configs/mmdet/detection/detection_rknn_static.py): ```python - # yolov3, yolox + # yolov3, yolox for rknn-toolkit and rknn-toolkit2 partition_config = dict( type='rknn', # the partition policy name apply_marks=True, # should always be set to True @@ -169,7 +151,8 @@ label: 65, score: 0.95 dict( save_file='model.onnx', # name to save the partitioned onnx start=['detector_forward:input'], # [mark_name:input, ...] - end=['yolo_head:input']) # [mark_name:output, ...] + end=['yolo_head:input'], # [mark_name:output, ...]
+ output_names=[f'pred_maps.{i}' for i in range(3)]) # output names ]) ``` @@ -184,7 +167,9 @@ label: 65, score: 0.95 dict( save_file='model.onnx', start='detector_forward:input', - end=['BaseDenseHead:output']) + end=['BaseDenseHead:output'], + output_names=[f'BaseDenseHead.cls.{i}' for i in range(5)] + + [f'BaseDenseHead.loc.{i}' for i in range(5)]) ]) ``` diff --git a/docs/en/07-developer-guide/partition_model.md b/docs/en/07-developer-guide/partition_model.md index 7c981af364..1e482288ca 100644 --- a/docs/en/07-developer-guide/partition_model.md +++ b/docs/en/07-developer-guide/partition_model.md @@ -66,7 +66,8 @@ partition_config = dict( dict( save_file='yolov3.onnx', # filename to save the partitioned onnx model start=['detector_forward:input'], # [mark_name:input/output, ...] - end=['yolo_head:input']) # [mark_name:input/output, ...] + end=['yolo_head:input'], # [mark_name:input/output, ...] + output_names=[f'pred_maps.{i}' for i in range(3)]) # output names ]) ``` diff --git a/docs/zh_cn/01-how-to-build/rockchip.md b/docs/zh_cn/01-how-to-build/rockchip.md index 111745c31a..2b7722c504 100644 --- a/docs/zh_cn/01-how-to-build/rockchip.md +++ b/docs/zh_cn/01-how-to-build/rockchip.md @@ -105,7 +105,7 @@ python tools/deploy.py \ 将下面的模型拆分配置写入到 [detection_rknn_static.py](https://github.com/open-mmlab/mmdeploy/blob/master/configs/mmdet/detection/detection_rknn_static.py) ```python -# yolov3, yolox +# yolov3, yolox for rknn-toolkit and rknn-toolkit2 partition_config = dict( type='rknn', # the partition policy name apply_marks=True, # should always be set to True @@ -113,7 +113,8 @@ partition_config = dict( dict( save_file='model.onnx', # name to save the partitioned onnx start=['detector_forward:input'], # [mark_name:input, ...] - end=['yolo_head:input']) # [mark_name:output, ...] + end=['yolo_head:input'], # [mark_name:output, ...] + output_names=[f'pred_maps.{i}' for i in range(3)]) # output names ]) ``` @@ -143,7 +144,9 @@ partition_config = dict( dict( save_file='model.onnx', start='detector_forward:input', - end=['BaseDenseHead:output']) + end=['BaseDenseHead:output'], + output_names=[f'BaseDenseHead.cls.{i}' for i in range(5)] + + [f'BaseDenseHead.loc.{i}' for i in range(5)]) ]) ``` @@ -168,24 +171,6 @@ backend_config = dict( ### 问题说明 -- 量化失败. - - 经验来说, 如果 `do_quantization` 被设置为 `True`,RKNN 需要的输入没有被归一化过。请修改 `Normalize` 在 `model_cfg` 的设置,如将 - - ```python - img_norm_cfg = dict( - mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) - ``` - - 改为 - - ```python - img_norm_cfg = dict( - mean=[0, 0, 0], std=[1, 1, 1], to_rgb=True) - ``` - - 此外, deploy_cfg 的 `mean_values` 和 `std_values` 应该被设置为 `model_cfg` 中归一化的设置. 使 `mean_values=[[103.53, 116.28, 123.675]]`, `std_values=[[57.375, 57.12, 58.395]]`。 - - SDK 只支持 int8 的 rknn 模型,这需要在转换模型时设置 `do_quantization=True`。 ## 模型推理 diff --git a/docs/zh_cn/07-developer-guide/partition_model.md b/docs/zh_cn/07-developer-guide/partition_model.md index a9104173ad..81e6ac7f1e 100644 --- a/docs/zh_cn/07-developer-guide/partition_model.md +++ b/docs/zh_cn/07-developer-guide/partition_model.md @@ -64,7 +64,8 @@ partition_config = dict( dict( save_file='yolov3.onnx', # filename to save the partitioned onnx model start=['detector_forward:input'], # [mark_name:input/output, ...] - end=['yolo_head:input']) # [mark_name:input/output, ...] + end=['yolo_head:input'], # [mark_name:input/output, ...] 
+ output_names=[f'pred_maps.{i}' for i in range(3)]) # output names ]) ``` diff --git a/mmdeploy/apis/visualize.py b/mmdeploy/apis/visualize.py index 4602fc13b4..3f8f30ec72 100644 --- a/mmdeploy/apis/visualize.py +++ b/mmdeploy/apis/visualize.py @@ -64,10 +64,12 @@ def visualize_model(model_cfg: Union[str, mmcv.Config], if backend == Backend.PYTORCH: model = task_processor.init_pytorch_model(model[0]) + model_inputs, _ = task_processor.create_input(img, input_shape) else: model = task_processor.init_backend_model(model, **kwargs) + model_inputs, _ = task_processor.create_input( + img, input_shape, task_processor.update_test_pipeline) - model_inputs, _ = task_processor.create_input(img, input_shape) with torch.no_grad(): result = task_processor.run_inference(model, model_inputs)[0] diff --git a/mmdeploy/backend/rknn/onnx2rknn.py b/mmdeploy/backend/rknn/onnx2rknn.py index e7ede81a0c..88d62152b2 100644 --- a/mmdeploy/backend/rknn/onnx2rknn.py +++ b/mmdeploy/backend/rknn/onnx2rknn.py @@ -4,15 +4,17 @@ import mmcv from rknn.api import RKNN -from mmdeploy.utils import (get_common_config, get_onnx_config, +from mmdeploy.utils import (get_backend_config, get_common_config, + get_normalization, get_onnx_config, get_partition_config, get_quantization_config, - get_root_logger, load_config) -from mmdeploy.utils.config_utils import get_backend_config + get_rknn_quantization, get_root_logger, + load_config) def onnx2rknn(onnx_model: str, output_path: str, deploy_cfg: Union[str, mmcv.Config], + model_cfg: Union[str, mmcv.Config], dataset_file: Optional[str] = None, **kwargs): """Convert ONNX to RKNN. @@ -40,6 +42,14 @@ def onnx2rknn(onnx_model: str, output_names = onnx_params.get('output_names', None) input_size_list = get_backend_config(deploy_cfg).get( 'input_size_list', None) + # update norm value + if get_rknn_quantization(deploy_cfg) is True: + transform = get_normalization(model_cfg) + common_params.update( + dict( + mean_values=[transform['mean']], + std_values=[transform['std']])) + # update output_names for partition models if get_partition_config(deploy_cfg) is not None: import onnx @@ -62,7 +72,7 @@ def onnx2rknn(onnx_model: str, if dataset_cfg is None and dataset_file is None: do_quantization = False logger.warning('no dataset passed in, quantization is skipped') - if dataset_file is None: + if dataset_cfg is not None: dataset_file = dataset_cfg ret = rknn.build(do_quantization=do_quantization, dataset=dataset_file) if ret != 0: diff --git a/mmdeploy/backend/rknn/wrapper.py b/mmdeploy/backend/rknn/wrapper.py index 3502c55304..d84e3d3a9d 100644 --- a/mmdeploy/backend/rknn/wrapper.py +++ b/mmdeploy/backend/rknn/wrapper.py @@ -48,7 +48,7 @@ def __init__(self, super().__init__(output_names) def forward(self, inputs: Dict[str, - torch.Tensor]) -> Sequence[torch.Tensor]: + torch.Tensor]) -> Dict[str, torch.Tensor]: """Run forward inference. Note that the shape of the input tensor is NxCxHxW while RKNN only accepts the numpy inputs of NxHxWxC. There is a permute operation outside RKNN inference. @@ -57,11 +57,14 @@ def forward(self, inputs: Dict[str, inputs (Dict[str, torch.Tensor]): Input name and tensor pairs. Return: - Sequence[torch.Tensor]: The output tensors. + Dict[str, torch.Tensor]: The output tensors. 
""" rknn_out = self.__rknnnn_execute( [i.permute(0, 2, 3, 1).cpu().numpy() for i in inputs.values()]) - return [torch.from_numpy(out) for out in rknn_out] + rknn_out = [torch.from_numpy(out) for out in rknn_out] + if self.output_names is not None: + return dict(zip(self.output_names, rknn_out)) + return {'#' + str(i): x for i, x in enumerate(rknn_out)} @TimeCounter.count_time(Backend.RKNN.value) def __rknnnn_execute(self, inputs: Sequence[np.array]): diff --git a/mmdeploy/backend/sdk/export_info.py b/mmdeploy/backend/sdk/export_info.py index 70e54afab3..fa1329b1a8 100644 --- a/mmdeploy/backend/sdk/export_info.py +++ b/mmdeploy/backend/sdk/export_info.py @@ -7,7 +7,8 @@ from mmdeploy.apis import build_task_processor from mmdeploy.utils import (Backend, Task, get_backend, get_codebase, - get_common_config, get_ir_config, get_root_logger, + get_common_config, get_ir_config, + get_partition_config, get_root_logger, get_task_type, is_dynamic_batch, load_config) from mmdeploy.utils.constants import SDK_TASK_MAP as task_map from .tracer import add_transform_tag, get_transform_static @@ -94,6 +95,9 @@ def get_models(deploy_cfg: Union[str, mmcv.Config], name, _ = get_model_name_customs(deploy_cfg, model_cfg, work_dir, device) precision = 'FP32' ir_name = get_ir_config(deploy_cfg)['save_file'] + if get_partition_config(deploy_cfg) is not None: + ir_name = get_partition_config( + deploy_cfg)['partition_cfg'][0]['save_file'] net = ir_name weights = '' backend = get_backend(deploy_cfg=deploy_cfg) @@ -185,6 +189,9 @@ def get_inference_info(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config, backend = get_backend(deploy_cfg=deploy_cfg) if backend in (Backend.TORCHSCRIPT, Backend.RKNN): output_names = ir_config.get('output_names', None) + if get_partition_config(deploy_cfg) is not None: + output_names = get_partition_config( + deploy_cfg)['partition_cfg'][0]['output_names'] input_map = dict(img='#0') output_map = {name: f'#{i}' for i, name in enumerate(output_names)} else: @@ -258,6 +265,8 @@ def get_preprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config, for transform in transforms: if transform['type'] == 'Normalize': transform['to_float'] = False + transform['mean'] = [0, 0, 0] + transform['std'] = [1, 1, 1] if transforms[0]['type'] != 'Lift': assert transforms[0]['type'] == 'LoadImageFromFile', \ @@ -299,14 +308,15 @@ def get_postprocess(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config, task = Task.INSTANCE_SEGMENTATION component = task_map[task]['component'] - if get_backend(deploy_cfg) == Backend.RKNN: - if 'YOLO' in task_processor.model_cfg.model.type: - bbox_head = task_processor.model_cfg.model.bbox_head - component = bbox_head.type - params['anchor_generator'] = bbox_head.get('anchor_generator', - None) - else: # default using base_dense_head - component = 'BaseDenseHead' + if task == Task.OBJECT_DETECTION: + if get_backend(deploy_cfg) == Backend.RKNN: + if 'YOLO' in task_processor.model_cfg.model.type: + bbox_head = task_processor.model_cfg.model.bbox_head + component = bbox_head.type + params['anchor_generator'] = bbox_head.get( + 'anchor_generator', None) + else: # default using base_dense_head + component = 'BaseDenseHead' if task != Task.SUPER_RESOLUTION and task != Task.SEGMENTATION: if 'type' in params: diff --git a/mmdeploy/codebase/base/task.py b/mmdeploy/codebase/base/task.py index 1e27c35f50..c2cd8e8a7e 100644 --- a/mmdeploy/codebase/base/task.py +++ b/mmdeploy/codebase/base/task.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from abc import ABCMeta, abstractmethod -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union import mmcv import numpy as np @@ -8,7 +8,8 @@ from torch.utils.data import DataLoader, Dataset from mmdeploy.utils import (get_backend_config, get_codebase, - get_codebase_config, get_root_logger) + get_codebase_config, get_rknn_quantization, + get_root_logger) from mmdeploy.utils.dataset import is_can_sort_dataset, sort_dataset @@ -139,18 +140,49 @@ def single_gpu_test(self, return self.codebase_class.single_gpu_test(model, data_loader, show, out_dir, **kwargs) + @staticmethod + def update_test_pipeline(deploy_cfg: mmcv.Config, model_cfg: mmcv.Config): + """Update preprocess pipeline. + + Args: + deploy_cfg (str | mmcv.Config): Deployment config file. + model_cfg (str | mmcv.Config): Model config file. + + Returns: + cfg (mmcv.Config): Updated model_cfg. + """ + cfg = model_cfg.deepcopy() + if get_rknn_quantization(deploy_cfg): + pipelines = cfg.data.test.pipeline + for i, pipeline in enumerate(pipelines): + if pipeline['type'] == 'MultiScaleFlipAug': + assert 'transforms' in pipeline + for trans in pipeline['transforms']: + if trans['type'] == 'Normalize': + trans['mean'] = [0, 0, 0] + trans['std'] = [1, 1, 1] + else: + if pipeline['type'] == 'Normalize': + pipeline['mean'] = [0, 0, 0] + pipeline['std'] = [1, 1, 1] + cfg.data.test.pipeline = pipelines + return cfg + @abstractmethod def create_input(self, imgs: Union[str, np.ndarray, Sequence], - input_shape: Sequence[int] = None, + input_shape: Optional[Sequence[int]] = None, + pipeline_updater: Optional[Callable] = None, **kwargs) -> Tuple[Dict, torch.Tensor]: """Create input for model. Args: imgs (str | np.ndarray | Sequence): Input image(s), accepted data types are `str`, `np.ndarray`. - input_shape (list[int]): Input shape of image in (width, height) - format, defaults to `None`. + input_shape (Sequence[int] | None): Input shape of image in + (width, height) format, defaults to `None`. + pipeline_updater (function | None): A function to get a new + pipeline. Returns: tuple: (data, img), meta information for the input image and input diff --git a/mmdeploy/codebase/mmaction/deploy/video_recognition.py b/mmdeploy/codebase/mmaction/deploy/video_recognition.py index fb1a4ed7b7..ab0ec7739e 100644 --- a/mmdeploy/codebase/mmaction/deploy/video_recognition.py +++ b/mmdeploy/codebase/mmaction/deploy/video_recognition.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union import mmcv import numpy as np @@ -147,15 +147,18 @@ def init_pytorch_model(self, def create_input(self, imgs: Union[str, np.ndarray, Sequence], - input_shape: Sequence[int] = None) \ - -> Tuple[Dict, torch.Tensor]: + input_shape: Optional[Sequence[int]] = None, + pipeline_updater: Optional[Callable] = None, + **kwargs) -> Tuple[Dict, torch.Tensor]: """Create input for recognizer. Args: imgs (Any): Input image(s), accepted data type are `str`, `np.ndarray`, `torch.Tensor`. - input_shape (list[int]): A list of two integer in (width, height) - format specifying input shape. Defaults to `None`. + input_shape (Sequence[int] | None): Input shape of image in + (width, height) format, defaults to `None`. + pipeline_updater (function | None): A function to get a new + pipeline. Returns: tuple: (data, img), meta information for the input image and input.
diff --git a/mmdeploy/codebase/mmcls/deploy/classification.py b/mmdeploy/codebase/mmcls/deploy/classification.py index f887ffac66..698f17a362 100644 --- a/mmdeploy/codebase/mmcls/deploy/classification.py +++ b/mmdeploy/codebase/mmcls/deploy/classification.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import logging -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union import mmcv import numpy as np @@ -8,8 +8,7 @@ from torch.utils.data import Dataset from mmdeploy.codebase.base import BaseTask -from mmdeploy.utils import Task, get_root_logger -from mmdeploy.utils.config_utils import get_input_shape +from mmdeploy.utils import Task, get_input_shape, get_root_logger from .mmclassification import MMCLS_TASK @@ -35,6 +34,7 @@ def process_model_config(model_cfg: mmcv.Config, else: if cfg.data.test.pipeline[0]['type'] == 'LoadImageFromFile': cfg.data.test.pipeline.pop(0) + # check whether input_shape is valid if input_shape is not None: if 'crop_size' in cfg.data.test.pipeline[2]: @@ -111,15 +111,18 @@ def init_pytorch_model(self, def create_input(self, imgs: Union[str, np.ndarray, Sequence], - input_shape: Optional[Sequence[int]] = None) \ + input_shape: Optional[Sequence[int]] = None, + pipeline_updater: Optional[Callable] = None, **kwargs) \ -> Tuple[Dict, torch.Tensor]: """Create input for classifier. Args: imgs (Union[str, np.ndarray, Sequence]): Input image(s), accepted data type are `str`, `np.ndarray`, Sequence. - input_shape (list[int]): A list of two integer in (width, height) - format specifying input shape. Default: None. + input_shape (Sequence[int] | None): Input shape of image in + (width, height) format, defaults to `None`. + pipeline_updater (function | None): A function to get a new + pipeline. Returns: tuple: (data, img), meta information for the input image and input. @@ -128,7 +131,10 @@ def create_input(self, from mmcv.parallel import collate, scatter if isinstance(imgs, (str, np.ndarray)): imgs = [imgs] - cfg = process_model_config(self.model_cfg, imgs, input_shape) + model_cfg = self.model_cfg + if pipeline_updater is not None: + model_cfg = pipeline_updater(self.deploy_cfg, model_cfg) + cfg = process_model_config(model_cfg, imgs, input_shape) data_list = [] test_pipeline = Compose(cfg.data.test.pipeline) for img in imgs: @@ -276,7 +282,8 @@ def get_preprocess(self) -> Dict: dict: Composed of the preprocess information. """ input_shape = get_input_shape(self.deploy_cfg) - cfg = process_model_config(self.model_cfg, [''], input_shape) + cfg = self.update_test_pipeline(self.deploy_cfg, self.model_cfg) + cfg = process_model_config(cfg, [''], input_shape) preprocess = cfg.data.test.pipeline return preprocess diff --git a/mmdeploy/codebase/mmcls/deploy/classification_model.py b/mmdeploy/codebase/mmcls/deploy/classification_model.py index eec7a25c00..6ebd8d5200 100644 --- a/mmdeploy/codebase/mmcls/deploy/classification_model.py +++ b/mmdeploy/codebase/mmcls/deploy/classification_model.py @@ -144,25 +144,6 @@ def forward(self, img: List[torch.Tensor], *args, **kwargs) -> list: return pred[np.argsort(pred[:, 0])][np.newaxis, :, 1] -@__BACKEND_MODEL.register_module('rknn') -class RKNNEnd2EndModel(End2EndModel): - """RKNN inference class, converts RKNN output to mmcls format.""" - - def forward_test(self, imgs: torch.Tensor, *args, **kwargs) -> \ - List[np.ndarray]: - """The interface for forward test. - - Args: - imgs (torch.Tensor): Input image(s) in [N x C x H x W] format. 
- - Returns: - List[np.ndarray]: A list of classification prediction. - """ - outputs = self.wrapper({self.input_name: imgs}) - outputs = [out.numpy() for out in outputs] - return outputs - - def get_classes_from_config(model_cfg: Union[str, mmcv.Config]): """Get class name from config. diff --git a/mmdeploy/codebase/mmdet/deploy/object_detection.py b/mmdeploy/codebase/mmdet/deploy/object_detection.py index acdaf8fe68..7da5c9c3f8 100644 --- a/mmdeploy/codebase/mmdet/deploy/object_detection.py +++ b/mmdeploy/codebase/mmdet/deploy/object_detection.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union import mmcv import numpy as np @@ -7,8 +7,7 @@ from mmcv.parallel import DataContainer from torch.utils.data import Dataset -from mmdeploy.utils import Task -from mmdeploy.utils.config_utils import get_input_shape, is_dynamic_shape +from mmdeploy.utils import Task, get_input_shape, is_dynamic_shape from ...base import BaseTask from .mmdetection import MMDET_TASK @@ -103,15 +102,18 @@ def init_pytorch_model(self, def create_input(self, imgs: Union[str, np.ndarray, Sequence], - input_shape: Sequence[int] = None) \ + input_shape: Optional[Sequence[int]] = None, + pipeline_updater: Optional[Callable] = None, **kwargs) \ -> Tuple[Dict, torch.Tensor]: """Create input for detector. Args: imgs (str|np.ndarray): Input image(s), accepted data type are `str`, `np.ndarray`. - input_shape (list[int]): A list of two integer in (width, height) - format specifying input shape. Defaults to `None`. + input_shape (Sequence[int] | None): Input shape of image in + (width, height) format, defaults to `None`. + pipeline_updater (function | None): A function to get a new + pipeline. Returns: tuple: (data, img), meta information for the input image and input. @@ -121,7 +123,10 @@ def create_input(self, if isinstance(imgs, (str, np.ndarray)): imgs = [imgs] dynamic_flag = is_dynamic_shape(self.deploy_cfg) - cfg = process_model_config(self.model_cfg, imgs, input_shape) + model_cfg = self.model_cfg + if pipeline_updater is not None: + model_cfg = pipeline_updater(self.deploy_cfg, model_cfg) + cfg = process_model_config(model_cfg, imgs, input_shape) # Drop pad_to_square when static shape. Because static shape should # ensure the shape before input image. if not dynamic_flag: @@ -291,7 +296,8 @@ def get_preprocess(self) -> Dict: dict: Composed of the preprocess information. """ input_shape = get_input_shape(self.deploy_cfg) - model_cfg = process_model_config(self.model_cfg, [''], input_shape) + cfg = self.update_test_pipeline(self.deploy_cfg, self.model_cfg) + model_cfg = process_model_config(cfg, [''], input_shape) preprocess = model_cfg.data.test.pipeline return preprocess diff --git a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py index baf0c31fcd..8d00375729 100644 --- a/mmdeploy/codebase/mmdet/deploy/object_detection_model.py +++ b/mmdeploy/codebase/mmdet/deploy/object_detection_model.py @@ -722,6 +722,7 @@ def forward_test(self, imgs: torch.Tensor, img_metas: Sequence[dict], class labels of shape [N, num_det].
""" outputs = self.wrapper({self.input_name: imgs}) + outputs = [i for i in outputs.values()] ret = self._get_bboxes(outputs, img_metas) return ret diff --git a/mmdeploy/codebase/mmdet3d/deploy/monocular_detection.py b/mmdeploy/codebase/mmdet3d/deploy/monocular_detection.py index 7f238f9cb0..4ad5bb98cd 100644 --- a/mmdeploy/codebase/mmdet3d/deploy/monocular_detection.py +++ b/mmdeploy/codebase/mmdet3d/deploy/monocular_detection.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from copy import deepcopy from os import path as osp -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import mmcv import numpy as np @@ -65,12 +65,19 @@ def init_pytorch_model(self, def create_input(self, imgs: Union[str, np.ndarray], - input_shape: Sequence[int] = None) \ + input_shape: Optional[Sequence[int]] = None, + pipeline_updater: Optional[Callable] = None, **kwargs) \ -> Tuple[Dict, torch.Tensor]: """Create input for detector. Args: - pcd (str): Input pcd file path. + input_shape (Sequence[int] | None): Input shape of image in + (width, height) format, defaults to `None`. + input_shape (Sequence[int] | None): Input shape of image in + (width, height) format, defaults to `None`. + pipeline_updater (function | None): A function to get a new + pipeline. + Returns: tuple: (data, img), meta information for the input image and input. """ diff --git a/mmdeploy/codebase/mmedit/deploy/super_resolution.py b/mmdeploy/codebase/mmedit/deploy/super_resolution.py index dcc58f542e..58cc509154 100644 --- a/mmdeploy/codebase/mmedit/deploy/super_resolution.py +++ b/mmdeploy/codebase/mmedit/deploy/super_resolution.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import warnings -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union import mmcv import numpy as np @@ -117,13 +117,16 @@ def init_pytorch_model(self, def create_input(self, imgs: Union[str, np.ndarray], input_shape: Optional[Sequence[int]] = None, + pipeline_updater: Optional[Callable] = None, **kwargs) -> Tuple[Dict, torch.Tensor]: """Create input for editing processor. Args: imgs (str | np.ndarray): Input image(s). - input_shape (Sequence[int] | None): A list of two integer in - (width, height) format specifying input shape. Defaults to `None`. + input_shape (Sequence[int] | None): Input shape of image in + (width, height) format, defaults to `None`. + pipeline_updater (function | None): A function to get a new + pipeline. Returns: tuple: (data, img), meta information for the input image and input. diff --git a/mmdeploy/codebase/mmocr/deploy/text_detection.py b/mmdeploy/codebase/mmocr/deploy/text_detection.py index 2f775153e0..d05db78f37 100644 --- a/mmdeploy/codebase/mmocr/deploy/text_detection.py +++ b/mmdeploy/codebase/mmocr/deploy/text_detection.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union import mmcv import numpy as np @@ -108,15 +108,18 @@ def init_pytorch_model(self, def create_input(self, imgs: Union[str, np.ndarray, Sequence], - input_shape: Sequence[int] = None) \ + input_shape: Optional[Sequence[int]] = None, + pipeline_updater: Optional[Callable] = None, **kwargs) \ -> Tuple[Dict, torch.Tensor]: """Create input for segmentor. 
Args: imgs (str | np.ndarray): Input image(s), accepted data type are `str`, `np.ndarray`. - input_shape (list[int]): A list of two integer in (width, height) - format specifying input shape. Defaults to `None`. + input_shape (Sequence[int] | None): Input shape of image in + (width, height) format, defaults to `None`. + pipeline_updater (function | None): A function to get a new + pipeline. Returns: tuple: (data, img), meta information for the input image and input. diff --git a/mmdeploy/codebase/mmocr/deploy/text_recognition.py b/mmdeploy/codebase/mmocr/deploy/text_recognition.py index af8287c719..c8470bca62 100644 --- a/mmdeploy/codebase/mmocr/deploy/text_recognition.py +++ b/mmdeploy/codebase/mmocr/deploy/text_recognition.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union import mmcv import numpy as np @@ -114,15 +114,18 @@ def init_pytorch_model(self, def create_input(self, imgs: Union[str, np.ndarray], - input_shape: Sequence[int] = None) \ + input_shape: Optional[Sequence[int]] = None, + pipeline_updater: Optional[Callable] = None, **kwargs) \ -> Tuple[Dict, torch.Tensor]: """Create input for segmentor. Args: imgs (str | np.ndarray): Input image(s), accepted data type are `str`, `np.ndarray`. - input_shape (list[int]): A list of two integer in (width, height) - format specifying input shape. Defaults to `None`. + input_shape (Sequence[int] | None): Input shape of image in + (width, height) format, defaults to `None`. + pipeline_updater (function | None): A function to get a new + pipeline. Returns: tuple: (data, img), meta information for the input image and input. diff --git a/mmdeploy/codebase/mmpose/deploy/pose_detection.py b/mmdeploy/codebase/mmpose/deploy/pose_detection.py index c22ffa9e7e..96be22a5dc 100644 --- a/mmdeploy/codebase/mmpose/deploy/pose_detection.py +++ b/mmdeploy/codebase/mmpose/deploy/pose_detection.py @@ -3,7 +3,7 @@ import copy import logging import os -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union import mmcv import numpy as np @@ -130,15 +130,18 @@ def init_pytorch_model(self, def create_input(self, imgs: Union[str, np.ndarray, Sequence], - input_shape: Sequence[int] = None, + input_shape: Optional[Sequence[int]] = None, + pipeline_updater: Optional[Callable] = None, **kwargs) -> Tuple[Dict, torch.Tensor]: """Create input for pose detection. Args: imgs (Any): Input image(s), accepted data type are ``str``, ``np.ndarray``. - input_shape (list[int]): A list of two integer in (width, height) - format specifying input shape. Defaults to ``None``. + input_shape (Sequence[int] | None): Input shape of image in + (width, height) format, defaults to `None`. + pipeline_updater (function | None): A function to get a new + pipeline. Returns: tuple: (data, img), meta information for the input image and input. diff --git a/mmdeploy/codebase/mmrotate/deploy/rotated_detection.py b/mmdeploy/codebase/mmrotate/deploy/rotated_detection.py index 92b7f73ac2..186c2dcc17 100644 --- a/mmdeploy/codebase/mmrotate/deploy/rotated_detection.py +++ b/mmdeploy/codebase/mmrotate/deploy/rotated_detection.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import copy -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union import mmcv import numpy as np @@ -150,15 +150,18 @@ def init_pytorch_model(self, def create_input(self, imgs: Union[str, np.ndarray], - input_shape: Sequence[int] = None) \ + input_shape: Optional[Sequence[int]] = None, + pipeline_updater: Optional[Callable] = None, **kwargs) \ -> Tuple[Dict, torch.Tensor]: """Create input for rotated object detection. Args: imgs (str | np.ndarray): Input image(s), accepted data type are `str`, `np.ndarray`. - input_shape (list[int]): A list of two integer in (width, height) - format specifying input shape. Defaults to `None`. + input_shape (Sequence[int] | None): Input shape of image in + (width, height) format, defaults to `None`. + pipeline_updater (function | None): A function to get a new + pipeline. Returns: tuple: (data, img), meta information for the input image and input. diff --git a/mmdeploy/codebase/mmseg/deploy/segmentation.py b/mmdeploy/codebase/mmseg/deploy/segmentation.py index 2651d8fe2d..6d800172e2 100644 --- a/mmdeploy/codebase/mmseg/deploy/segmentation.py +++ b/mmdeploy/codebase/mmseg/deploy/segmentation.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union import mmcv import numpy as np @@ -33,6 +33,7 @@ def process_model_config(model_cfg: mmcv.Config, if isinstance(imgs[0], np.ndarray): # set loading pipeline type cfg.data.test.pipeline[0] = LoadImage() + # for static exporting if input_shape is not None: for pipeline in cfg.data.test.pipeline[1:]: @@ -107,15 +108,18 @@ def init_pytorch_model(self, def create_input(self, imgs: Union[str, np.ndarray, Sequence], - input_shape: Sequence[int] = None) \ + input_shape: Optional[Sequence[int]] = None, + pipeline_updater: Optional[Callable] = None, **kwargs) \ -> Tuple[Dict, torch.Tensor]: """Create input for segmentor. Args: imgs (Any): Input image(s), accepted data type are `str`, `np.ndarray`, `torch.Tensor`. - input_shape (list[int]): A list of two integer in (width, height) - format specifying input shape. Defaults to `None`. + input_shape (Sequence[int] | None): Input shape of image in + (width, height) format, defaults to `None`. + pipeline_updater (function | None): A function to get a new + pipeline. Returns: tuple: (data, img), meta information for the input image and input. 
@@ -125,7 +129,10 @@ def create_input(self, if isinstance(imgs, (str, np.ndarray)): imgs = [imgs] imgs = [mmcv.imread(_) for _ in imgs] - cfg = process_model_config(self.model_cfg, imgs, input_shape) + model_cfg = self.model_cfg + if pipeline_updater is not None: + model_cfg = pipeline_updater(self.deploy_cfg, model_cfg) + cfg = process_model_config(model_cfg, imgs, input_shape) test_pipeline = Compose(cfg.data.test.pipeline) data_list = [] for img in imgs: @@ -260,7 +267,8 @@ def get_preprocess(self) -> Dict: """ input_shape = get_input_shape(self.deploy_cfg) load_from_file = self.model_cfg.data.test.pipeline[0] - model_cfg = process_model_config(self.model_cfg, [''], input_shape) + cfg = self.update_test_pipeline(self.deploy_cfg, self.model_cfg) + model_cfg = process_model_config(cfg, [''], input_shape) preprocess = model_cfg.data.test.pipeline preprocess[0] = load_from_file return preprocess diff --git a/mmdeploy/codebase/mmseg/deploy/segmentation_model.py b/mmdeploy/codebase/mmseg/deploy/segmentation_model.py index a135e1aa05..de7c486b3b 100644 --- a/mmdeploy/codebase/mmseg/deploy/segmentation_model.py +++ b/mmdeploy/codebase/mmseg/deploy/segmentation_model.py @@ -162,7 +162,9 @@ def forward_test(self, imgs: torch.Tensor, *args, **kwargs) -> \ List[np.ndarray]: A list of segmentation map. """ outputs = self.wrapper({self.input_name: imgs}) - outputs = [output.argmax(dim=1, keepdim=True) for output in outputs] + outputs = [ + output.argmax(dim=1, keepdim=True) for output in outputs.values() + ] outputs = [out.detach().cpu().numpy() for out in outputs] return outputs diff --git a/mmdeploy/utils/__init__.py b/mmdeploy/utils/__init__.py index b3c44a805f..d52dfb0e24 100644 --- a/mmdeploy/utils/__init__.py +++ b/mmdeploy/utils/__init__.py @@ -21,8 +21,9 @@ get_codebase_config, get_common_config, get_dynamic_axes, get_input_shape, get_ir_config, get_model_inputs, - get_onnx_config, get_partition_config, - get_quantization_config, get_task_type, + get_normalization, get_onnx_config, + get_partition_config, get_quantization_config, + get_rknn_quantization, get_task_type, is_dynamic_batch, is_dynamic_shape, load_config) # yapf: enable @@ -33,5 +34,6 @@ 'get_codebase_config', 'get_common_config', 'get_dynamic_axes', 'get_input_shape', 'get_ir_config', 'get_model_inputs', 'get_onnx_config', 'get_partition_config', 'get_quantization_config', - 'get_task_type', 'is_dynamic_batch', 'is_dynamic_shape', 'load_config' + 'get_task_type', 'is_dynamic_batch', 'is_dynamic_shape', 'load_config', + 'get_rknn_quantization', 'get_normalization' ] diff --git a/mmdeploy/utils/config_utils.py b/mmdeploy/utils/config_utils.py index 212df1ef57..e8737996aa 100644 --- a/mmdeploy/utils/config_utils.py +++ b/mmdeploy/utils/config_utils.py @@ -393,3 +393,40 @@ def get_dynamic_axes( raise KeyError('No names were found to define dynamic axes.') dynamic_axes = dict(zip(axes_names, dynamic_axes)) return dynamic_axes + + +def get_normalization(model_cfg: Union[str, mmcv.Config]): + """Get the Normalize transform from model config. + + Args: + model_cfg (mmcv.Config): The content of config. + + Returns: + dict: The Normalize transform. 
+ """ + model_cfg = load_config(model_cfg)[0] + pipelines = model_cfg.data.test.pipeline + for i, pipeline in enumerate(pipelines): + if pipeline['type'] == 'MultiScaleFlipAug': + assert 'transforms' in pipeline + for trans in pipeline['transforms']: + if trans['type'] == 'Normalize': + return trans + else: + if pipeline['type'] == 'Normalize': + return pipeline + + +def get_rknn_quantization(deploy_cfg: mmcv.Config): + """Get the flag of `do_quantization` for rknn backend. + + Args: + deploy_cfg (mmcv.Config): The content of config. + + Returns: + bool: Do quantization or not. + """ + if get_backend(deploy_cfg) == Backend.RKNN: + return get_backend_config( + deploy_cfg)['quantization_config']['do_quantization'] + return False diff --git a/tools/deploy.py b/tools/deploy.py index ca9df6014e..a73ce1947c 100644 --- a/tools/deploy.py +++ b/tools/deploy.py @@ -372,6 +372,7 @@ def main(): onnx_path, output_path, deploy_cfg_path, + model_cfg_path, dataset_file=dataset_file) backend_files.append(output_path)