From e136a26b26300ee5c5bbdcba7738ee7076f8d598 Mon Sep 17 00:00:00 2001 From: Yue Zhou <592267829@qq.com> Date: Mon, 18 Apr 2022 17:14:47 +0800 Subject: [PATCH] Add nms_rotated ort op (#312) * fix pose demo and windows build (#307) * init * Update nms_rotated.cpp * add postprocessing_masks gpu version (#276) * add postprocessing_masks gpu version * default device cpu * pre-commit fix Co-authored-by: hadoop-basecv * fixed a bug causes text-recognizer to fail when (non-NULL) empty bboxes list is passed (#310) * [Fix] include missing for formatter.h (#313) * fix formatter * relax GCC version requirement * fix * fix lint * fix lint * [Fix] MMEditing cannot save results when testing (#336) * fix show * lint * remove redundant codes * resolve comment * type hint * docs(build): fix typo (#352) * docs(build): add missing build option * docs(build): add onnx install * style(doc): trim whitespace * docs(build): revert install onnx * docs(build): add ncnn LD_LIBRARY_PATH * docs(build): fix path error * fix openvino export tmp model, add binary flag (#353) * init circleci (#348) * fix wrong input mat type (#362) * fix wrong input mat type * fix lint * fix(docs): remove redundant doc tree (#360) * fix missing ncnn_DIR & InferenceEngine_DIR (#364) * update doc Co-authored-by: Chen Xin Co-authored-by: Shengxi Li <982783556@qq.com> Co-authored-by: hadoop-basecv Co-authored-by: lzhangzz Co-authored-by: Yifan Zhou Co-authored-by: tpoisonooo Co-authored-by: lvhan028 --- .../onnxruntime/nms_rotated/nms_rotated.cpp | 352 ++++++++++++++++++ .../onnxruntime/nms_rotated/nms_rotated.h | 48 +++ docs/en/backends/onnxruntime.md | 1 + docs/en/ops/onnxruntime.md | 39 ++ mmdeploy/mmcv/ops/__init__.py | 1 + mmdeploy/mmcv/ops/nms_rotated.py | 51 +++ tests/test_ops/test_ops.py | 31 ++ 7 files changed, 523 insertions(+) create mode 100644 csrc/backend_ops/onnxruntime/nms_rotated/nms_rotated.cpp create mode 100644 csrc/backend_ops/onnxruntime/nms_rotated/nms_rotated.h create mode 100644 
mmdeploy/mmcv/ops/nms_rotated.py diff --git a/csrc/backend_ops/onnxruntime/nms_rotated/nms_rotated.cpp b/csrc/backend_ops/onnxruntime/nms_rotated/nms_rotated.cpp new file mode 100644 index 0000000000..9858772abd --- /dev/null +++ b/csrc/backend_ops/onnxruntime/nms_rotated/nms_rotated.cpp @@ -0,0 +1,352 @@ +// Copyright (c) OpenMMLab. All rights reserved +#include "nms_rotated.h" + +#include + +#include +#include +#include +#include +#include +#include // std::iota +#include + +#include "ort_utils.h" + +namespace mmdeploy { + +namespace { +struct RotatedBox { + float x_ctr, y_ctr, w, h, a; +}; +struct Point { + float x, y; + Point(const float& px = 0, const float& py = 0) : x(px), y(py) {} + Point operator+(const Point& p) const { return Point(x + p.x, y + p.y); } + Point& operator+=(const Point& p) { + x += p.x; + y += p.y; + return *this; + } + Point operator-(const Point& p) const { return Point(x - p.x, y - p.y); } + Point operator*(const float coeff) const { return Point(x * coeff, y * coeff); } +}; + +float dot_2d(const Point& A, const Point& B) { return A.x * B.x + A.y * B.y; } + +float cross_2d(const Point& A, const Point& B) { return A.x * B.y - B.x * A.y; } +} // namespace + +void get_rotated_vertices(const RotatedBox& box, Point (&pts)[4]) { + // M_PI / 180. 
== 0.01745329251 + // double theta = box.a * 0.01745329251; + // MODIFIED + double theta = box.a; + float cosTheta2 = (float)cos(theta) * 0.5f; + float sinTheta2 = (float)sin(theta) * 0.5f; + + // y: top --> down; x: left --> right + pts[0].x = box.x_ctr - sinTheta2 * box.h - cosTheta2 * box.w; + pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w; + pts[1].x = box.x_ctr + sinTheta2 * box.h - cosTheta2 * box.w; + pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w; + pts[2].x = 2 * box.x_ctr - pts[0].x; + pts[2].y = 2 * box.y_ctr - pts[0].y; + pts[3].x = 2 * box.x_ctr - pts[1].x; + pts[3].y = 2 * box.y_ctr - pts[1].y; +} + +int get_intersection_points(const Point (&pts1)[4], const Point (&pts2)[4], + Point (&intersections)[24]) { + // Line vector + // A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1] + Point vec1[4], vec2[4]; + for (int i = 0; i < 4; i++) { + vec1[i] = pts1[(i + 1) % 4] - pts1[i]; + vec2[i] = pts2[(i + 1) % 4] - pts2[i]; + } + + // Line test - test all line combos for intersection + int num = 0; // number of intersections + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + // Solve for 2x2 Ax=b + float det = cross_2d(vec2[j], vec1[i]); + + // This takes care of parallel lines + if (fabs(det) <= 1e-14) { + continue; + } + + auto vec12 = pts2[j] - pts1[i]; + + float t1 = cross_2d(vec2[j], vec12) / det; + float t2 = cross_2d(vec1[i], vec12) / det; + + if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) { + intersections[num++] = pts1[i] + vec1[i] * t1; + } + } + } + + // Check for vertices of rect1 inside rect2 + { + const auto& AB = vec2[0]; + const auto& DA = vec2[3]; + auto ABdotAB = dot_2d(AB, AB); + auto ADdotAD = dot_2d(DA, DA); + for (int i = 0; i < 4; i++) { + // assume ABCD is the rectangle, and P is the point to be judged + // P is inside ABCD iff. 
P's projection on AB lies within AB + // and P's projection on AD lies within AD + + auto AP = pts1[i] - pts2[0]; + + auto APdotAB = dot_2d(AP, AB); + auto APdotAD = -dot_2d(AP, DA); + + if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && (APdotAD <= ADdotAD)) { + intersections[num++] = pts1[i]; + } + } + } + + // Reverse the check - check for vertices of rect2 inside rect1 + { + const auto& AB = vec1[0]; + const auto& DA = vec1[3]; + auto ABdotAB = dot_2d(AB, AB); + auto ADdotAD = dot_2d(DA, DA); + for (int i = 0; i < 4; i++) { + auto AP = pts2[i] - pts1[0]; + + auto APdotAB = dot_2d(AP, AB); + auto APdotAD = -dot_2d(AP, DA); + + if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && (APdotAD <= ADdotAD)) { + intersections[num++] = pts2[i]; + } + } + } + + return num; +} + +int convex_hull_graham(const Point (&p)[24], const int& num_in, Point (&q)[24], + bool shift_to_zero = false) { + assert(num_in >= 2); + + // Step 1: + // Find point with minimum y + // if more than 1 points have the same minimum y, + // pick the one with the minimum x. 
+ int t = 0; + for (int i = 1; i < num_in; i++) { + if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) { + t = i; + } + } + auto& start = p[t]; // starting point + + // Step 2: + // Subtract starting point from every points (for sorting in the next step) + for (int i = 0; i < num_in; i++) { + q[i] = p[i] - start; + } + + // Swap the starting point to position 0 + auto tmp = q[0]; + q[0] = q[t]; + q[t] = tmp; + + // Step 3: + // Sort point 1 ~ num_in according to their relative cross-product values + // (essentially sorting according to angles) + // If the angles are the same, sort according to their distance to origin + float dist[24]; + for (int i = 0; i < num_in; i++) { + dist[i] = dot_2d(q[i], q[i]); + } + + // CPU version + std::sort(q + 1, q + num_in, [](const Point& A, const Point& B) -> bool { + float temp = cross_2d(A, B); + if (fabs(temp) < 1e-6) { + return dot_2d(A, A) < dot_2d(B, B); + } else { + return temp > 0; + } + }); + // compute distance to origin after sort, since the points are now different. + for (int i = 0; i < num_in; i++) { + dist[i] = dot_2d(q[i], q[i]); + } + + // Step 4: + // Make sure there are at least 2 points (that don't overlap with each other) + // in the stack + int k; // index of the non-overlapped second point + for (k = 1; k < num_in; k++) { + if (dist[k] > 1e-8) { + break; + } + } + if (k == num_in) { + // We reach the end, which means the convex hull is just one point + q[0] = p[t]; + return 1; + } + q[1] = q[k]; + int m = 2; // 2 points in the stack + // Step 5: + // Finally we can start the scanning process. 
+ // When a non-convex relationship between the 3 points is found + // (either concave shape or duplicated points), + // we pop the previous point from the stack + // until the 3-point relationship is convex again, or + // until the stack only contains two points + for (int i = k + 1; i < num_in; i++) { + while (m > 1 && cross_2d(q[i] - q[m - 2], q[m - 1] - q[m - 2]) >= 0) { + m--; + } + q[m++] = q[i]; + } + + // Step 6 (Optional): + // In general sense we need the original coordinates, so we + // need to shift the points back (reverting Step 2) + // But if we're only interested in getting the area/perimeter of the shape + // We can simply return. + if (!shift_to_zero) { + for (int i = 0; i < m; i++) { + q[i] += start; + } + } + + return m; +} + +float polygon_area(const Point (&q)[24], const int& m) { + if (m <= 2) { + return 0; + } + + float area = 0; + for (int i = 1; i < m - 1; i++) { + area += fabs(cross_2d(q[i] - q[0], q[i + 1] - q[0])); + } + + return area / 2.0; +} + +float rotated_boxes_intersection(const RotatedBox& box1, const RotatedBox& box2) { + // There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned + // from rotated_rect_intersection_pts + Point intersectPts[24], orderedPts[24]; + + Point pts1[4]; + Point pts2[4]; + get_rotated_vertices(box1, pts1); + get_rotated_vertices(box2, pts2); + + int num = get_intersection_points(pts1, pts2, intersectPts); + + if (num <= 2) { + return 0.0; + } + + // Convex Hull to order the intersection points in clockwise order and find + // the contour area. 
+ int num_convex = convex_hull_graham(intersectPts, num, orderedPts, true); + return polygon_area(orderedPts, num_convex); +} + +NMSRotatedKernel::NMSRotatedKernel(OrtApi api, const OrtKernelInfo* info) + : api_(api), ort_(api_), info_(info) { + iou_threshold_ = ort_.KernelInfoGetAttribute(info, "iou_threshold"); + + // create allocator + allocator_ = Ort::AllocatorWithDefaultOptions(); +} + +void NMSRotatedKernel::Compute(OrtKernelContext* context) { + const float iou_threshold = iou_threshold_; + + const OrtValue* boxes = ort_.KernelContext_GetInput(context, 0); + const float* boxes_data = reinterpret_cast(ort_.GetTensorData(boxes)); + const OrtValue* scores = ort_.KernelContext_GetInput(context, 1); + const float* scores_data = reinterpret_cast(ort_.GetTensorData(scores)); + + OrtTensorDimensions boxes_dim(ort_, boxes); + OrtTensorDimensions scores_dim(ort_, scores); + + int64_t nboxes = boxes_dim[0]; + assert(boxes_dim[1] == 5); //(cx,cy,w,h,theta) + + // allocate tmp memory + float* tmp_boxes = (float*)allocator_.Alloc(sizeof(float) * nboxes * 5); + float* sc = (float*)allocator_.Alloc(sizeof(float) * nboxes); + bool* select = (bool*)allocator_.Alloc(sizeof(bool) * nboxes); + for (int64_t i = 0; i < nboxes; i++) { + select[i] = true; + } + + memcpy(tmp_boxes, boxes_data, sizeof(float) * nboxes * 5); + memcpy(sc, scores_data, sizeof(float) * nboxes); + + // sort scores + std::vector tmp_sc; + for (int i = 0; i < nboxes; i++) { + tmp_sc.push_back(sc[i]); + } + std::vector order(tmp_sc.size()); + std::iota(order.begin(), order.end(), 0); + std::sort(order.begin(), order.end(), + [&tmp_sc](int64_t id1, int64_t id2) { return tmp_sc[id1] > tmp_sc[id2]; }); + + for (int64_t _i = 0; _i < nboxes; _i++) { + if (select[_i] == false) continue; + auto i = order[_i]; + + for (int64_t _j = _i + 1; _j < nboxes; _j++) { + if (select[_j] == false) continue; + auto j = order[_j]; + RotatedBox box1, box2; + auto center_shift_x = (tmp_boxes[i * 5] + tmp_boxes[j * 5]) / 2.0; + auto 
center_shift_y = (tmp_boxes[i * 5 + 1] + tmp_boxes[j * 5 + 1]) / 2.0; + box1.x_ctr = tmp_boxes[i * 5] - center_shift_x; + box1.y_ctr = tmp_boxes[i * 5 + 1] - center_shift_y; + box1.w = tmp_boxes[i * 5 + 2]; + box1.h = tmp_boxes[i * 5 + 3]; + box1.a = tmp_boxes[i * 5 + 4]; + box2.x_ctr = tmp_boxes[j * 5] - center_shift_x; + box2.y_ctr = tmp_boxes[j * 5 + 1] - center_shift_y; + box2.w = tmp_boxes[j * 5 + 2]; + box2.h = tmp_boxes[j * 5 + 3]; + box2.a = tmp_boxes[j * 5 + 4]; + auto area1 = box1.w * box1.h; + auto area2 = box2.w * box2.h; + auto intersection = rotated_boxes_intersection(box1, box2); + float baseS = 1.0; + baseS = (area1 + area2 - intersection); + auto ovr = intersection / baseS; + if (ovr > iou_threshold) select[_j] = false; + } + } + std::vector res_order; + for (int i = 0; i < nboxes; i++) { + if (select[i]) { + res_order.push_back(order[i]); + } + } + + std::vector inds_dims({(int64_t)res_order.size()}); + + OrtValue* res = ort_.KernelContext_GetOutput(context, 0, inds_dims.data(), inds_dims.size()); + int64_t* res_data = ort_.GetTensorMutableData(res); + + memcpy(res_data, res_order.data(), sizeof(int64_t) * res_order.size()); +} + +REGISTER_ONNXRUNTIME_OPS(mmdeploy, NMSRotatedOp); +} // namespace mmdeploy diff --git a/csrc/backend_ops/onnxruntime/nms_rotated/nms_rotated.h b/csrc/backend_ops/onnxruntime/nms_rotated/nms_rotated.h new file mode 100644 index 0000000000..0c0f273a49 --- /dev/null +++ b/csrc/backend_ops/onnxruntime/nms_rotated/nms_rotated.h @@ -0,0 +1,48 @@ +// Copyright (c) OpenMMLab. All rights reserved. 
+#ifndef ONNXRUNTIME_NMS_ROTATED_H +#define ONNXRUNTIME_NMS_ROTATED_H + +#include +#include + +#include +#include +#include +#include + +namespace mmdeploy { +struct NMSRotatedKernel { + NMSRotatedKernel(OrtApi api, const OrtKernelInfo* info); + + void Compute(OrtKernelContext* context); + + private: + OrtApi api_; + Ort::CustomOpApi ort_; + const OrtKernelInfo* info_; + Ort::AllocatorWithDefaultOptions allocator_; + float iou_threshold_; +}; + +struct NMSRotatedOp : Ort::CustomOpBase { + void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const { + return new NMSRotatedKernel(api, info); + } + const char* GetName() const { return "NMSRotated"; } + + size_t GetInputTypeCount() const { return 2; } + ONNXTensorElementDataType GetInputType(size_t) const { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + } + + size_t GetOutputTypeCount() const { return 1; } + ONNXTensorElementDataType GetOutputType(size_t) const { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; + } + + // force cpu + const char* GetExecutionProviderType() const { return "CPUExecutionProvider"; } +}; +} // namespace mmdeploy + +#endif // ONNXRUNTIME_NMS_ROTATED_H diff --git a/docs/en/backends/onnxruntime.md b/docs/en/backends/onnxruntime.md index 181e655094..2b32fa6336 100644 --- a/docs/en/backends/onnxruntime.md +++ b/docs/en/backends/onnxruntime.md @@ -59,6 +59,7 @@ make -j$(nproc) | :--------------------------------------------------------------------------- | :---: | :---: | :---------------- | | [grid_sampler](../ops/onnxruntime.md#grid_sampler) | Y | N | master | | [MMCVModulatedDeformConv2d](../ops/onnxruntime.md#mmcvmodulateddeformconv2d) | Y | N | master | +| [NMSRotated](../ops/onnxruntime.md#nmsrotated) | Y | N | master | ### How to add a new custom op diff --git a/docs/en/ops/onnxruntime.md b/docs/en/ops/onnxruntime.md index 51791ebc9f..60c5f8bb52 100644 --- a/docs/en/ops/onnxruntime.md +++ b/docs/en/ops/onnxruntime.md @@ -15,6 +15,12 @@ - [Inputs](#inputs-1) - [Outputs](#outputs-1) - 
[Type Constraints](#type-constraints-1) +- [NMSRotated](#nmsrotated) + - [Description](#description-2) + - [Parameters](#parameters-2) + - [Inputs](#inputs-2) + - [Outputs](#outputs-2) + - [Type Constraints](#type-constraints-2) @@ -93,3 +99,36 @@ Perform Modulated Deformable Convolution on input feature, read [Deformable Conv #### Type Constraints - T:tensor(float32, Linear) + +### NMSRotated + +#### Description + +Non Max Suppression for rotated bboxes. + +#### Parameters + +| Type | Parameter | Description | +| -------------- | ------------------- | ------------------------------------------------------------------------------------- | +| `float` | `iou_threshold` | The IoU threshold for NMS. | + + +#### Inputs + +
+
inputs[0]: T
+
Input boxes; 2-D tensor of shape (N, 5), where N is the number of rotated bboxes and each row is (cx, cy, w, h, theta), with theta in radians.
+
inputs[1]: T
+
Input scores; 1-D tensor of shape (N, ), where N is the number of rotated bboxes.
+
+ +#### Outputs + +
+
outputs[0]: tensor(int64, Linear)
+
Indices of the kept bboxes, ordered by descending score; 1-D tensor of shape (K, ), where K is the number of kept bboxes.
+
+ +#### Type Constraints + +- T:tensor(float32, Linear) diff --git a/mmdeploy/mmcv/ops/__init__.py b/mmdeploy/mmcv/ops/__init__.py index f839e64b99..e76ef7780d 100644 --- a/mmdeploy/mmcv/ops/__init__.py +++ b/mmdeploy/mmcv/ops/__init__.py @@ -2,6 +2,7 @@ from .deform_conv import deform_conv_openvino from .modulated_deform_conv import modulated_deform_conv_default from .nms import * # noqa: F401,F403 +from .nms_rotated import * # noqa: F401,F403 from .roi_align import roi_align_default __all__ = [ diff --git a/mmdeploy/mmcv/ops/nms_rotated.py b/mmdeploy/mmcv/ops/nms_rotated.py new file mode 100644 index 0000000000..3ba763cdad --- /dev/null +++ b/mmdeploy/mmcv/ops/nms_rotated.py @@ -0,0 +1,51 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch import Tensor + + +class ONNXNMSRotatedOp(torch.autograd.Function): + """Create onnx::NMSRotated op.""" + + @staticmethod + def forward(ctx, boxes: Tensor, scores: Tensor, + iou_threshold: float) -> Tensor: + """Get NMS rotated output indices. + + Args: + ctx (Context): The context with meta information. + boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4]. + scores (Tensor): The detection scores of shape + [N, num_boxes, num_classes]. + iou_threshold (float): IOU threshold of nms. + + Returns: + Tensor: Selected indices of boxes. + """ + from mmcv.utils import ext_loader + ext_module = ext_loader.load_ext('_ext', ['nms_rotated']) + + _, order = scores.sort(0, descending=True) + dets_sorted = boxes.index_select(0, order) + keep_inds = ext_module.nms_rotated(boxes, scores, order, dets_sorted, + iou_threshold, 0) + return keep_inds + + @staticmethod + def symbolic(g, boxes: Tensor, scores: Tensor, iou_threshold: float): + """Symbolic function for onnx::NMSRotated. + + Args: + g (Graph): The traced onnx graph. + boxes (Tensor): The bounding boxes of shape [N, num_boxes, 4]. + scores (Tensor): The detection scores of shape + [N, num_boxes, num_classes]. 
+ iou_threshold (float): IOU threshold of nms. + + Returns: + NMSRotated op for onnx. + """ + return g.op( + 'mmdeploy::NMSRotated', + boxes, + scores, + iou_threshold_f=float(iou_threshold)) diff --git a/tests/test_ops/test_ops.py b/tests/test_ops/test_ops.py index 54ab2d7b12..c6786833fd 100644 --- a/tests/test_ops/test_ops.py +++ b/tests/test_ops/test_ops.py @@ -774,3 +774,34 @@ def expand_function(input, target): input_names=['input', 'shape'], output_names=['output'], save_dir=save_dir) + + +@pytest.mark.parametrize('backend', [TEST_ONNXRT]) +@pytest.mark.parametrize('iou_threshold', [0.1, 0.3]) +def test_nms_rotated(backend, iou_threshold, save_dir=None): + backend.check_env() + + boxes = torch.tensor( + [[60, 75, 20, 50, 0], [65, 80, 10, 40, 0], [30, 30, 40, 40, 0]], + dtype=torch.float32) + scores = torch.tensor([0.5, 0.6, 0.7], dtype=torch.float32) + + from mmdeploy.mmcv.ops import ONNXNMSRotatedOp + + def wrapped_function(torch_boxes, torch_scores): + return ONNXNMSRotatedOp.apply(torch_boxes, torch_scores, iou_threshold) + + wrapped_model = WrapFunction(wrapped_function).eval() + + with RewriterContext( + Config({'backend_config': { + 'type': backend.backend_name + }}), + backend=backend.backend_name, + opset=11): + backend.run_and_validate( + wrapped_model, [boxes, scores], + 'nms_rotated', + input_names=['boxes', 'scores'], + output_names=['keep_inds'], + save_dir=save_dir)