diff --git a/docs/onnxruntime_custom_ops.md b/docs/onnxruntime_custom_ops.md
index e42032d23d..837d947184 100644
--- a/docs/onnxruntime_custom_ops.md
+++ b/docs/onnxruntime_custom_ops.md
@@ -27,6 +27,12 @@
- [Inputs](#inputs-3)
- [Outputs](#outputs-3)
- [Type Constraints](#type-constraints-3)
+ - [CornerPool](#cornerpool)
+ - [Description](#description-4)
+ - [Parameters](#parameters-4)
+ - [Inputs](#inputs-4)
+ - [Outputs](#outputs-4)
+ - [Type Constraints](#type-constraints-4)
@@ -171,3 +177,33 @@ Perform sample from `input` with pixel locations from `grid`.
### Type Constraints
- T:tensor(float32, Linear)
+## CornerPool
+### Description
+Perform CornerPool on `input` features. Read [CornerNet -- Detecting Objects as Paired Keypoints](https://arxiv.org/abs/1808.01244) for more details.
+### Parameters
+| Type | Parameter | Description |
+| ------- | --------------- | ---------------------------------------------------------------- |
+| `int` | `mode` | corner pool mode (0: `top`, 1: `bottom`, 2: `left`, 3: `right`)  |
+### Inputs
+- input: T
+- Input features. 4-D tensor of shape (N, C, H, W). N is the batch size.
+### Outputs
+- output: T
+- Output the pooled features. 4-D tensor of shape (N, C, H, W).
+### Type Constraints
+- T:tensor(float32)
diff --git a/docs/onnxruntime_op.md b/docs/onnxruntime_op.md
index 9324524e39..0e2f62adb4 100644
--- a/docs/onnxruntime_op.md
+++ b/docs/onnxruntime_op.md
@@ -21,6 +21,7 @@
| [RoIAlign](onnxruntime_custom_ops.md#roialign) | Y | N | 1.2.5 |
| [NMS](onnxruntime_custom_ops.md#nms) | Y | N | 1.2.7 |
| [grid_sampler](onnxruntime_custom_ops.md#grid_sampler) | Y | N | master |
+| [CornerPool](onnxruntime_custom_ops.md#cornerpool) | Y | N | master |
## How to build custom operators for ONNX Runtime
diff --git a/mmcv/ops/corner_pool.py b/mmcv/ops/corner_pool.py
index 6b0d871933..189506e6aa 100644
--- a/mmcv/ops/corner_pool.py
+++ b/mmcv/ops/corner_pool.py
@@ -10,9 +10,17 @@
'right_pool_forward', 'right_pool_backward'
+_mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3}  # integer codes passed as the `mode` attribute of mmcv::MMCVCornerPool
class TopPoolFunction(Function):
+ @staticmethod
+ def symbolic(g, input):  # emits a single mmcv::MMCVCornerPool graph op for `input`
+ output = g.op(
+ 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['top']))  # mode attribute = code for 'top' (0)
+ return output
def forward(ctx, input):
output = ext_module.top_pool_forward(input)
@@ -28,6 +36,12 @@ def backward(ctx, grad_output):
class BottomPoolFunction(Function):
+ @staticmethod
+ def symbolic(g, input):  # emits a single mmcv::MMCVCornerPool graph op for `input`
+ output = g.op(
+ 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['bottom']))  # mode attribute = code for 'bottom' (1)
+ return output
def forward(ctx, input):
output = ext_module.bottom_pool_forward(input)
@@ -43,6 +57,12 @@ def backward(ctx, grad_output):
class LeftPoolFunction(Function):
+ @staticmethod
+ def symbolic(g, input):  # emits a single mmcv::MMCVCornerPool graph op for `input`
+ output = g.op(
+ 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['left']))  # mode attribute = code for 'left' (2)
+ return output
def forward(ctx, input):
output = ext_module.left_pool_forward(input)
@@ -58,6 +78,12 @@ def backward(ctx, grad_output):
class RightPoolFunction(Function):
+ @staticmethod
+ def symbolic(g, input):  # emits a single mmcv::MMCVCornerPool graph op for `input`
+ output = g.op(
+ 'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['right']))  # mode attribute = code for 'right' (3)
+ return output
def forward(ctx, input):
output = ext_module.right_pool_forward(input)
diff --git a/mmcv/ops/csrc/onnxruntime/corner_pool.h b/mmcv/ops/csrc/onnxruntime/corner_pool.h
new file mode 100644
index 0000000000..4edca2cb8f
--- /dev/null
+++ b/mmcv/ops/csrc/onnxruntime/corner_pool.h
@@ -0,0 +1,45 @@
+struct MMCVCornerPoolKernel {  // per-node kernel state for the MMCVCornerPool custom op
+ public:
+ MMCVCornerPoolKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info)
+ : ort_(ort) {
+ mode_ = ort_.KernelInfoGetAttribute(info, "mode");  // the "mode" attribute is read once, at construction
+ }
+ void Compute(OrtKernelContext* context);  // defined in cpu/corner_pool.cpp
+ private:
+ Ort::CustomOpApi ort_;  // API handle used for attribute and tensor access
+ int64_t mode_;  // pooling direction code; Compute documents 0: top, 1: bottom, 2: left, 3: right
+struct MMCVCornerPoolCustomOp  // custom-op descriptor: name, input/output counts and types, kernel factory
+ : Ort::CustomOpBase {
+ void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) {
+ return new MMCVCornerPoolKernel(api, info);  // NOTE(review): raw `new`; presumably freed by the framework's kernel-destroy hook — confirm
+ }
+ const char* GetName() const { return "MMCVCornerPool"; }  // same op name the Python symbolic() emits under the mmcv domain
+ size_t GetInputTypeCount() const { return 1; }  // one input tensor
+ ONNXTensorElementDataType GetInputType(size_t) const {  // NOTE(review): no return statement is visible in this body — confirm the real header returns the float element type
+ }
+ size_t GetOutputTypeCount() const { return 1; }  // one output tensor
+ ONNXTensorElementDataType GetOutputType(size_t) const {  // NOTE(review): no return statement is visible in this body — confirm the real header returns the float element type
+ }
+ // force cpu: the op reports the CPU execution provider as its provider type
+ const char* GetExecutionProviderType() const {
+ return "CPUExecutionProvider";
+ }
diff --git a/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp b/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp
new file mode 100644
index 0000000000..d9d4dc3aad
--- /dev/null
+++ b/mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp
@@ -0,0 +1,122 @@
+#include "corner_pool.h"
+#include "../ort_mmcv_utils.h"
+void TopPoolForwardCPU(const float *input, float *output, const int batch_size,
+ const int channels, const int height, const int width) {  // top pool: for every (n, c, w) column, output[h] = max of input rows h..height-1
+ for (int n = 0; n < batch_size; n++) {
+ int index_n = n * channels * width * height;  // flat offset of batch item n (int arithmetic; NOTE(review): could overflow for very large tensors)
+ for (int c = 0; c < channels; c++) {
+ int index_n_c = index_n + c * width * height;  // flat offset of the (n, c) plane, indexed as h * width + w
+ for (int w = 0; w < width; w++) {
+ // seed the running max: copy the bottom-most value of this column from input to output
+ output[index_n_c + (height - 1) * width + w] =
+ input[index_n_c + (height - 1) * width + w];  // NOTE(review): when height == 0 and width > 0 this index is out of range — confirm empty H never reaches here
+ // do top_pool: sweep upwards, carrying the max of everything below
+ for (int h = height - 2; h >= 0; h--) {
+ output[index_n_c + h * width + w] =
+ std::max(output[index_n_c + (h + 1) * width + w],
+ input[index_n_c + h * width + w]);
+ } // for h
+ } // for w
+ } // for c
+ } // for n
+void BottomPoolForwardCPU(const float *input, float *output,
+ const int batch_size, const int channels,
+ const int height, const int width) {  // bottom pool: for every (n, c, w) column, output[h] = max of input rows 0..h
+ for (int n = 0; n < batch_size; n++) {
+ int index_n = n * channels * width * height;  // flat offset of batch item n (int arithmetic; NOTE(review): could overflow for very large tensors)
+ for (int c = 0; c < channels; c++) {
+ int index_n_c = index_n + c * width * height;  // flat offset of the (n, c) plane, indexed as h * width + w
+ for (int w = 0; w < width; w++) {
+ // seed the running max: copy the top-most value of this column from input to output
+ output[index_n_c + w] = input[index_n_c + w];  // NOTE(review): touches row 0 even when height == 0 — confirm empty H never reaches here
+ // do bottom_pool: sweep downwards, carrying the max of everything above
+ for (int h = 1; h < height; h++) {
+ output[index_n_c + h * width + w] =
+ std::max(output[index_n_c + (h - 1) * width + w],
+ input[index_n_c + h * width + w]);
+ } // for h
+ } // for w
+ } // for c
+ } // for n
+void LeftPoolForwardCPU(const float *input, float *output, const int batch_size,
+ const int channels, const int height, const int width) {  // left pool: for every (n, c, h) row, output[w] = max of input columns w..width-1
+ for (int n = 0; n < batch_size; n++) {
+ int index_n = n * channels * width * height;  // flat offset of batch item n (int arithmetic; NOTE(review): could overflow for very large tensors)
+ for (int c = 0; c < channels; c++) {
+ int index_n_c = index_n + c * width * height;  // flat offset of the (n, c) plane, indexed as h * width + w
+ for (int h = 0; h < height; h++) {
+ // seed the running max: copy the right-most value of this row from input to output
+ output[index_n_c + h * width + width - 1] =
+ input[index_n_c + h * width + width - 1];  // NOTE(review): when width == 0 and height > 0 this index is out of range — confirm empty W never reaches here
+ // do left_pool: sweep leftwards, carrying the max of everything to the right
+ for (int w = width - 2; w >= 0; w--) {
+ output[index_n_c + h * width + w] =
+ std::max(output[index_n_c + h * width + w + 1],
+ input[index_n_c + h * width + w]);
+ } // for w
+ } // for h
+ } // for c
+ } // for n
+void RightPoolForwardCPU(const float *input, float *output,
+ const int batch_size, const int channels,
+ const int height, const int width) {  // right pool: for every (n, c, h) row, output[w] = max of input columns 0..w
+ for (int n = 0; n < batch_size; n++) {
+ int index_n = n * channels * width * height;  // flat offset of batch item n (int arithmetic; NOTE(review): could overflow for very large tensors)
+ for (int c = 0; c < channels; c++) {
+ int index_n_c = index_n + c * width * height;  // flat offset of the (n, c) plane, indexed as h * width + w
+ for (int h = 0; h < height; h++) {
+ // seed the running max: copy the left-most value of this row from input to output
+ output[index_n_c + h * width] = input[index_n_c + h * width];  // NOTE(review): touches column 0 even when width == 0 — confirm empty W never reaches here
+ // do right_pool: sweep rightwards, carrying the max of everything to the left
+ for (int w = 1; w < width; w++) {
+ output[index_n_c + h * width + w] =
+ std::max(output[index_n_c + h * width + w - 1],
+ input[index_n_c + h * width + w]);
+ } // for w
+ } // for h
+ } // for c
+ } // for n
+void MMCVCornerPoolKernel::Compute(OrtKernelContext *context) {  // pools input 0 into output 0 in the direction selected by mode_
+ const int mode = int(mode_);
+ typedef float T;  // element type used for both input and output buffers
+ const OrtValue *input = ort_.KernelContext_GetInput(context, 0);
+ const T *input_data =
+ reinterpret_cast(ort_.GetTensorData(input));
+ // get output memory; the dimensions object is built from `input`, so the output is presumably requested with the input's shape
+ OrtTensorDimensions out_dimensions(ort_, input);
+ OrtValue *output = ort_.KernelContext_GetOutput(
+ context, 0, out_dimensions.data(), out_dimensions.size());
+ T *output_data = ort_.GetTensorMutableData(output);
+ // 'top': 0, 'bottom': 1, 'left': 2, 'right':3
+ assert(mode == 0 || mode == 1 || mode == 2 || mode == 3);  // NOTE(review): compiled out under NDEBUG; any other value then falls into the final else (right pool)
+ // do corner_pool; NOTE(review): dims [0..3] are read without a rank check — assumes a 4-D (N, C, H, W) input
+ int batch_size = out_dimensions.data()[0];
+ int input_channels = out_dimensions.data()[1];
+ int input_height = out_dimensions.data()[2];
+ int input_width = out_dimensions.data()[3];
+ if (mode == 0)
+ TopPoolForwardCPU(input_data, output_data, batch_size, input_channels,
+ input_height, input_width);
+ else if (mode == 1)
+ BottomPoolForwardCPU(input_data, output_data, batch_size, input_channels,
+ input_height, input_width);
+ else if (mode == 2)
+ LeftPoolForwardCPU(input_data, output_data, batch_size, input_channels,
+ input_height, input_width);
+ else
+ RightPoolForwardCPU(input_data, output_data, batch_size, input_channels,
+ input_height, input_width);
diff --git a/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp b/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp
index a46e5b6215..b55114b188 100644
--- a/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp
+++ b/mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp
@@ -1,5 +1,6 @@
#include "onnxruntime_register.h"
+#include "corner_pool.h"
#include "grid_sample.h"
#include "nms.h"
#include "ort_mmcv_utils.h"
@@ -13,6 +14,7 @@ NmsOp c_NmsOp;
MMCVRoiAlignCustomOp c_MMCVRoiAlignCustomOp;
MMCVRoIAlignRotatedCustomOp c_MMCVRoIAlignRotatedCustomOp;
GridSampleOp c_GridSampleOp;
+MMCVCornerPoolCustomOp c_MMCVCornerPoolCustomOp;  // file-scope instance whose address is added to the custom-op domain in RegisterCustomOps
OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
const OrtApiBase *api) {
@@ -45,5 +47,10 @@ OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
return status;
+ if (auto status =
+ ortApi->CustomOpDomain_Add(domain, &c_MMCVCornerPoolCustomOp)) {
+ return status;
+ }
return ortApi->AddCustomOpDomain(options, domain);
diff --git a/tests/test_ops/test_onnx.py b/tests/test_ops/test_onnx.py
index 0e50f5403f..62859c32f5 100644
--- a/tests/test_ops/test_onnx.py
+++ b/tests/test_ops/test_onnx.py
@@ -448,3 +448,49 @@ def func(feat, scale_factor=2):
if os.path.exists(onnx_file):
assert np.allclose(pytorch_result, onnx_result, atol=1e-3)
+@pytest.mark.parametrize('mode', ['top', 'bottom', 'left', 'right'])
+def test_corner_pool(mode, opset=11):  # exports CornerPool(mode) to ONNX and compares the ONNX Runtime result with PyTorch
+ if torch.__version__ == 'parrots':
+ pytest.skip('onnx is not supported in parrots directly')
+ from mmcv.ops import get_onnxruntime_op_path
+ ort_custom_op_path = get_onnxruntime_op_path()
+ if not os.path.exists(ort_custom_op_path):  # the custom-op library is optional; skip when it was not built
+ pytest.skip('custom ops for onnxruntime are not compiled.')
+ from mmcv.ops.corner_pool import CornerPool
+ def corner_pool_func(input):
+ corner_pool_module = CornerPool(mode)
+ return corner_pool_module.corner_pool.apply(input)  # calls .apply on the module's corner_pool attribute (presumably the *PoolFunction for this mode — confirm)
+ wrapped_model = WrapFunction(corner_pool_func).eval()
+ input = torch.rand((2, 3, 9, 12)) # (n,c,h,w)
+ with torch.no_grad():
+ torch.onnx.export(
+ wrapped_model,
+ input,
+ onnx_file,
+ export_params=True,
+ keep_initializers_as_inputs=True,
+ input_names=['input'],
+ output_names=['output'],
+ opset_version=opset)
+ onnx_model = onnx.load(onnx_file)
+ input_all = [node.name for node in onnx_model.graph.input]
+ input_initializer = [node.name for node in onnx_model.graph.initializer]
+ net_feed_input = list(set(input_all) - set(input_initializer))  # graph inputs that are not initializers, i.e. what the caller must feed
+ assert (len(net_feed_input) == 1)
+ session_options = rt.SessionOptions()
+ session_options.register_custom_ops_library(ort_custom_op_path)  # makes the mmcv custom ops available to this session
+ sess = rt.InferenceSession(onnx_file, session_options)
+ ort_result = sess.run(None, {'input': input.detach().numpy()})  # presumably a list holding the single 'output' array — confirm against the onnxruntime API
+ pytorch_results = wrapped_model(input.clone())
+ os.remove(onnx_file)  # removed before the comparison, so the file is cleaned up even if the assert below fails
+ assert np.allclose(pytorch_results, ort_result, atol=1e-5)