From db6b0542c7ca13edcbad0bb90687bf680a685844 Mon Sep 17 00:00:00 2001 From: v-qjqs Date: Sat, 1 May 2021 16:28:25 +0800 Subject: [PATCH] [Feature]: Support corner_pool related custom operators for onnxruntime in mmcv (#997) * supports for onnxruntime custom op `mmcv::MMCVTopPool` * supports for onnxruntime custom op `mmcv::MMCVCornerPool`, involving TopPool, BottomPool, LeftPool and RightPool * add unittest for corner_pool * supports mmcv::CornerPool without memcpy * add docs for mmcv::CornerPool * re-add docs for mmcv::CornerPool * fix output dtype doc * reformat * format with pre-commit * format * fix lint error, by using google clang-format style for c/c++ --- docs/onnxruntime_custom_ops.md | 36 ++++++ docs/onnxruntime_op.md | 1 + mmcv/ops/corner_pool.py | 26 ++++ mmcv/ops/csrc/onnxruntime/corner_pool.h | 45 +++++++ mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp | 122 ++++++++++++++++++ .../onnxruntime/cpu/onnxruntime_register.cpp | 7 + tests/test_ops/test_onnx.py | 46 +++++++ 7 files changed, 283 insertions(+) create mode 100644 mmcv/ops/csrc/onnxruntime/corner_pool.h create mode 100644 mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp diff --git a/docs/onnxruntime_custom_ops.md b/docs/onnxruntime_custom_ops.md index e42032d23d..837d947184 100644 --- a/docs/onnxruntime_custom_ops.md +++ b/docs/onnxruntime_custom_ops.md @@ -27,6 +27,12 @@ - [Inputs](#inputs-3) - [Outputs](#outputs-3) - [Type Constraints](#type-constraints-3) + - [CornerPool](#cornerpool) + - [Description](#description-4) + - [Parameters](#parameters-4) + - [Inputs](#inputs-4) + - [Outputs](#outputs-4) + - [Type Constraints](#type-constraints-4) @@ -171,3 +177,33 @@ Perform sample from `input` with pixel locations from `grid`. ### Type Constraints - T:tensor(float32, Linear) + +## CornerPool + +### Description + +Perform CornerPool on `input` features. Read [CornerNet -- Detecting Objects as Paired Keypoints](https://arxiv.org/abs/1808.01244) for more details. 
+ +### Parameters + +| Type | Parameter | Description | +| ------- | --------------- | ---------------------------------------------------------------- | +| `int` | `mode` | corner pool mode, (0: `top`, 1: `bottom`, 2: `left`, 3: `right`) | + +### Inputs + +
+
input: T
+
Input features. 4-D tensor of shape (N, C, H, W), where N is the batch size, C is the number of channels, and H and W are the height and width of the feature map.
+
+ +### Outputs + +
+
output: T
+
The pooled features. 4-D tensor of shape (N, C, H, W), identical to the shape of `input`.
+
+ +### Type Constraints + +- T:tensor(float32) diff --git a/docs/onnxruntime_op.md b/docs/onnxruntime_op.md index 9324524e39..0e2f62adb4 100644 --- a/docs/onnxruntime_op.md +++ b/docs/onnxruntime_op.md @@ -21,6 +21,7 @@ | [RoIAlign](onnxruntime_custom_ops.md#roialign) | Y | N | 1.2.5 | | [NMS](onnxruntime_custom_ops.md#nms) | Y | N | 1.2.7 | | [grid_sampler](onnxruntime_custom_ops.md#grid_sampler) | Y | N | master | +| [CornerPool](onnxruntime_custom_ops.md#cornerpool) | Y | N | master | ## How to build custom operators for ONNX Runtime diff --git a/mmcv/ops/corner_pool.py b/mmcv/ops/corner_pool.py index 6b0d871933..189506e6aa 100644 --- a/mmcv/ops/corner_pool.py +++ b/mmcv/ops/corner_pool.py @@ -10,9 +10,17 @@ 'right_pool_forward', 'right_pool_backward' ]) +_mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3} + class TopPoolFunction(Function): + @staticmethod + def symbolic(g, input): + output = g.op( + 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['top'])) + return output + @staticmethod def forward(ctx, input): output = ext_module.top_pool_forward(input) @@ -28,6 +36,12 @@ def backward(ctx, grad_output): class BottomPoolFunction(Function): + @staticmethod + def symbolic(g, input): + output = g.op( + 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['bottom'])) + return output + @staticmethod def forward(ctx, input): output = ext_module.bottom_pool_forward(input) @@ -43,6 +57,12 @@ def backward(ctx, grad_output): class LeftPoolFunction(Function): + @staticmethod + def symbolic(g, input): + output = g.op( + 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['left'])) + return output + @staticmethod def forward(ctx, input): output = ext_module.left_pool_forward(input) @@ -58,6 +78,12 @@ def backward(ctx, grad_output): class RightPoolFunction(Function): + @staticmethod + def symbolic(g, input): + output = g.op( + 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['right'])) + return output + @staticmethod def forward(ctx, input): output = 
ext_module.right_pool_forward(input) diff --git a/mmcv/ops/csrc/onnxruntime/corner_pool.h b/mmcv/ops/csrc/onnxruntime/corner_pool.h new file mode 100644 index 0000000000..4edca2cb8f --- /dev/null +++ b/mmcv/ops/csrc/onnxruntime/corner_pool.h @@ -0,0 +1,45 @@ +#ifndef ONNXRUNTIME_CORNER_POOL_H +#define ONNXRUNTIME_CORNER_POOL_H + +#include +#include + +struct MMCVCornerPoolKernel { + public: + MMCVCornerPoolKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info) + : ort_(ort) { + mode_ = ort_.KernelInfoGetAttribute(info, "mode"); + } + + void Compute(OrtKernelContext* context); + + private: + Ort::CustomOpApi ort_; + + int64_t mode_; +}; + +struct MMCVCornerPoolCustomOp + : Ort::CustomOpBase { + void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) { + return new MMCVCornerPoolKernel(api, info); + } + + const char* GetName() const { return "MMCVCornerPool"; } + + size_t GetInputTypeCount() const { return 1; } + ONNXTensorElementDataType GetInputType(size_t) const { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + } + + size_t GetOutputTypeCount() const { return 1; } + ONNXTensorElementDataType GetOutputType(size_t) const { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + } + + // force cpu + const char* GetExecutionProviderType() const { + return "CPUExecutionProvider"; + } +}; +#endif // ONNXRUNTIME_CORNER_POOL_H diff --git a/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp b/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp new file mode 100644 index 0000000000..d9d4dc3aad --- /dev/null +++ b/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp @@ -0,0 +1,122 @@ +#include "corner_pool.h" + +#include "../ort_mmcv_utils.h" + +void TopPoolForwardCPU(const float *input, float *output, const int batch_size, + const int channels, const int height, const int width) { + for (int n = 0; n < batch_size; n++) { + int index_n = n * channels * width * height; + for (int c = 0; c < channels; c++) { + int index_n_c = index_n + c * width * height; + for (int w = 0; w < width; w++) 
{ + // directly copy the most bottom value from input to output + output[index_n_c + (height - 1) * width + w] = + input[index_n_c + (height - 1) * width + w]; + // do top_pool + for (int h = height - 2; h >= 0; h--) { + output[index_n_c + h * width + w] = + std::max(output[index_n_c + (h + 1) * width + w], + input[index_n_c + h * width + w]); + } // for h + } // for w + } // for c + } // for n +} + +void BottomPoolForwardCPU(const float *input, float *output, + const int batch_size, const int channels, + const int height, const int width) { + for (int n = 0; n < batch_size; n++) { + int index_n = n * channels * width * height; + for (int c = 0; c < channels; c++) { + int index_n_c = index_n + c * width * height; + for (int w = 0; w < width; w++) { + // directly copy the most top value from input to output + output[index_n_c + w] = input[index_n_c + w]; + // do top_pool + for (int h = 1; h < height; h++) { + output[index_n_c + h * width + w] = + std::max(output[index_n_c + (h - 1) * width + w], + input[index_n_c + h * width + w]); + } // for h + } // for w + } // for c + } // for n +} + +void LeftPoolForwardCPU(const float *input, float *output, const int batch_size, + const int channels, const int height, const int width) { + for (int n = 0; n < batch_size; n++) { + int index_n = n * channels * width * height; + for (int c = 0; c < channels; c++) { + int index_n_c = index_n + c * width * height; + for (int h = 0; h < height; h++) { + // directly copy the most right value from input to output + output[index_n_c + h * width + width - 1] = + input[index_n_c + h * width + width - 1]; + // do left_pool + for (int w = width - 2; w >= 0; w--) { + output[index_n_c + h * width + w] = + std::max(output[index_n_c + h * width + w + 1], + input[index_n_c + h * width + w]); + } // for w + } // for h + } // for c + } // for n +} + +void RightPoolForwardCPU(const float *input, float *output, + const int batch_size, const int channels, + const int height, const int width) { + for 
(int n = 0; n < batch_size; n++) { + int index_n = n * channels * width * height; + for (int c = 0; c < channels; c++) { + int index_n_c = index_n + c * width * height; + for (int h = 0; h < height; h++) { + // directly copy the most left value from input to output + output[index_n_c + h * width] = input[index_n_c + h * width]; + // do right_pool + for (int w = 1; w < width; w++) { + output[index_n_c + h * width + w] = + std::max(output[index_n_c + h * width + w - 1], + input[index_n_c + h * width + w]); + } // for w + } // for h + } // for c + } // for n +} + +void MMCVCornerPoolKernel::Compute(OrtKernelContext *context) { + const int mode = int(mode_); + typedef float T; + const OrtValue *input = ort_.KernelContext_GetInput(context, 0); + const T *input_data = + reinterpret_cast(ort_.GetTensorData(input)); + + // get output memory + OrtTensorDimensions out_dimensions(ort_, input); + OrtValue *output = ort_.KernelContext_GetOutput( + context, 0, out_dimensions.data(), out_dimensions.size()); + T *output_data = ort_.GetTensorMutableData(output); + + // 'top': 0, 'bottom': 1, 'left': 2, 'right':3 + assert(mode == 0 || mode == 1 || mode == 2 || mode == 3); + + // do corner_pool + int batch_size = out_dimensions.data()[0]; + int input_channels = out_dimensions.data()[1]; + int input_height = out_dimensions.data()[2]; + int input_width = out_dimensions.data()[3]; + if (mode == 0) + TopPoolForwardCPU(input_data, output_data, batch_size, input_channels, + input_height, input_width); + else if (mode == 1) + BottomPoolForwardCPU(input_data, output_data, batch_size, input_channels, + input_height, input_width); + else if (mode == 2) + LeftPoolForwardCPU(input_data, output_data, batch_size, input_channels, + input_height, input_width); + else + RightPoolForwardCPU(input_data, output_data, batch_size, input_channels, + input_height, input_width); +} diff --git a/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp b/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp 
index a46e5b6215..b55114b188 100644 --- a/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp +++ b/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp @@ -1,5 +1,6 @@ #include "onnxruntime_register.h" +#include "corner_pool.h" #include "grid_sample.h" #include "nms.h" #include "ort_mmcv_utils.h" @@ -13,6 +14,7 @@ NmsOp c_NmsOp; MMCVRoiAlignCustomOp c_MMCVRoiAlignCustomOp; MMCVRoIAlignRotatedCustomOp c_MMCVRoIAlignRotatedCustomOp; GridSampleOp c_GridSampleOp; +MMCVCornerPoolCustomOp c_MMCVCornerPoolCustomOp; OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options, const OrtApiBase *api) { @@ -45,5 +47,10 @@ OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options, return status; } + if (auto status = + ortApi->CustomOpDomain_Add(domain, &c_MMCVCornerPoolCustomOp)) { + return status; + } + return ortApi->AddCustomOpDomain(options, domain); } diff --git a/tests/test_ops/test_onnx.py b/tests/test_ops/test_onnx.py index 0e50f5403f..62859c32f5 100644 --- a/tests/test_ops/test_onnx.py +++ b/tests/test_ops/test_onnx.py @@ -448,3 +448,49 @@ def func(feat, scale_factor=2): if os.path.exists(onnx_file): os.remove(onnx_file) assert np.allclose(pytorch_result, onnx_result, atol=1e-3) + + +@pytest.mark.parametrize('mode', ['top', 'bottom', 'left', 'right']) +def test_corner_pool(mode, opset=11): + if torch.__version__ == 'parrots': + pytest.skip('onnx is not supported in parrots directly') + + from mmcv.ops import get_onnxruntime_op_path + ort_custom_op_path = get_onnxruntime_op_path() + if not os.path.exists(ort_custom_op_path): + pytest.skip('custom ops for onnxruntime are not compiled.') + + from mmcv.ops.corner_pool import CornerPool + + def corner_pool_func(input): + corner_pool_module = CornerPool(mode) + return corner_pool_module.corner_pool.apply(input) + + wrapped_model = WrapFunction(corner_pool_func).eval() + + input = torch.rand((2, 3, 9, 12)) # (n,c,h,w) + + with torch.no_grad(): + torch.onnx.export( + wrapped_model, + input, + 
onnx_file, + export_params=True, + keep_initializers_as_inputs=True, + input_names=['input'], + output_names=['output'], + opset_version=opset) + + onnx_model = onnx.load(onnx_file) + input_all = [node.name for node in onnx_model.graph.input] + input_initializer = [node.name for node in onnx_model.graph.initializer] + net_feed_input = list(set(input_all) - set(input_initializer)) + assert (len(net_feed_input) == 1) + + session_options = rt.SessionOptions() + session_options.register_custom_ops_library(ort_custom_op_path) + sess = rt.InferenceSession(onnx_file, session_options) + ort_result = sess.run(None, {'input': input.detach().numpy()}) + pytorch_results = wrapped_model(input.clone()) + os.remove(onnx_file) + assert np.allclose(pytorch_results, ort_result, atol=1e-5)