[Feature]: Support corner_pool related custom operators for onnxruntime in mmcv (#997)

* support for onnxruntime custom op `mmcv::MMCVTopPool`

* support for onnxruntime custom op `mmcv::MMCVCornerPool`, covering TopPool, BottomPool, LeftPool and RightPool

* add unittest for corner_pool

* support mmcv::CornerPool without memcpy

* add docs for mmcv::CornerPool

* re-add docs for mmcv::CornerPool

* fix output dtype doc

* reformat

* format with pre-commit

* format

* fix lint error by using Google clang-format style for C/C++
v-qjqs authored May 1, 2021
1 parent 3f8e985 commit db6b054
Showing 7 changed files with 283 additions and 0 deletions.
36 changes: 36 additions & 0 deletions docs/onnxruntime_custom_ops.md
@@ -27,6 +27,12 @@
- [Inputs](#inputs-3)
- [Outputs](#outputs-3)
- [Type Constraints](#type-constraints-3)
- [CornerPool](#cornerpool)
- [Description](#description-4)
- [Parameters](#parameters-4)
- [Inputs](#inputs-4)
- [Outputs](#outputs-4)
- [Type Constraints](#type-constraints-4)

<!-- TOC -->

@@ -171,3 +177,33 @@ Perform sample from `input` with pixel locations from `grid`.
### Type Constraints

- T:tensor(float32, Linear)

## CornerPool

### Description

Perform CornerPool on `input` features. Read [CornerNet -- Detecting Objects as Paired Keypoints](https://arxiv.org/abs/1808.01244) for more details.

### Parameters

| Type | Parameter | Description |
| ------- | --------------- | ---------------------------------------------------------------- |
| `int`   | `mode`          | corner pool mode (0: `top`, 1: `bottom`, 2: `left`, 3: `right`)  |

### Inputs

<dl>
<dt><tt>input</tt>: T</dt>
<dd>Input features. 4-D tensor of shape (N, C, H, W). N is the batch size.</dd>
</dl>

### Outputs

<dl>
<dt><tt>output</tt>: T</dt>
<dd>The pooled features. 4-D tensor of shape (N, C, H, W).</dd>
</dl>

### Type Constraints

- T:tensor(float32)
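
For intuition, each mode is simply a running maximum along the height or width axis of the feature map. The NumPy sketch below (illustrative only; `corner_pool_reference` is not part of mmcv) mirrors the semantics described above:

```python
import numpy as np

def corner_pool_reference(x: np.ndarray, mode: int) -> np.ndarray:
    """Reference corner pooling on a (N, C, H, W) array.

    mode: 0 = top, 1 = bottom, 2 = left, 3 = right, matching the
    `mode` attribute of mmcv::MMCVCornerPool.
    """
    if mode == 0:  # top pool: running max from the bottom row upwards
        return np.maximum.accumulate(x[:, :, ::-1, :], axis=2)[:, :, ::-1, :]
    if mode == 1:  # bottom pool: running max from the top row downwards
        return np.maximum.accumulate(x, axis=2)
    if mode == 2:  # left pool: running max from the rightmost column leftwards
        return np.maximum.accumulate(x[:, :, :, ::-1], axis=3)[:, :, :, ::-1]
    if mode == 3:  # right pool: running max from the leftmost column rightwards
        return np.maximum.accumulate(x, axis=3)
    raise ValueError(f'invalid mode {mode}')
```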
1 change: 1 addition & 0 deletions docs/onnxruntime_op.md
@@ -21,6 +21,7 @@
| [RoIAlign](onnxruntime_custom_ops.md#roialign) | Y | N | 1.2.5 |
| [NMS](onnxruntime_custom_ops.md#nms) | Y | N | 1.2.7 |
| [grid_sampler](onnxruntime_custom_ops.md#grid_sampler) | Y | N | master |
| [CornerPool](onnxruntime_custom_ops.md#cornerpool) | Y | N | master |

## How to build custom operators for ONNX Runtime

26 changes: 26 additions & 0 deletions mmcv/ops/corner_pool.py
@@ -10,9 +10,17 @@
'right_pool_forward', 'right_pool_backward'
])

_mode_dict = {'top': 0, 'bottom': 1, 'left': 2, 'right': 3}


class TopPoolFunction(Function):

@staticmethod
def symbolic(g, input):
output = g.op(
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['top']))
return output

@staticmethod
def forward(ctx, input):
output = ext_module.top_pool_forward(input)
@@ -28,6 +36,12 @@ def backward(ctx, grad_output):

class BottomPoolFunction(Function):

@staticmethod
def symbolic(g, input):
output = g.op(
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['bottom']))
return output

@staticmethod
def forward(ctx, input):
output = ext_module.bottom_pool_forward(input)
@@ -43,6 +57,12 @@ def backward(ctx, grad_output):

class LeftPoolFunction(Function):

@staticmethod
def symbolic(g, input):
output = g.op(
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['left']))
return output

@staticmethod
def forward(ctx, input):
output = ext_module.left_pool_forward(input)
@@ -58,6 +78,12 @@ def backward(ctx, grad_output):

class RightPoolFunction(Function):

@staticmethod
def symbolic(g, input):
output = g.op(
'mmcv::MMCVCornerPool', input, mode_i=int(_mode_dict['right']))
return output

@staticmethod
def forward(ctx, input):
output = ext_module.right_pool_forward(input)
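
With the `symbolic` methods above, exporting any of these autograd Functions emits a single `mmcv::MMCVCornerPool` node carrying the corresponding `mode` attribute. A minimal export sketch (the wrapper module and output filename are hypothetical; the full round-trip check lives in the unit test further below):

```python
import torch
from mmcv.ops.corner_pool import TopPoolFunction

class TopPoolWrapper(torch.nn.Module):
    def forward(self, x):
        # During ONNX export this dispatches to TopPoolFunction.symbolic,
        # producing an mmcv::MMCVCornerPool node with mode=0.
        return TopPoolFunction.apply(x)

torch.onnx.export(TopPoolWrapper().eval(), torch.rand(1, 3, 8, 8),
                  'top_pool.onnx', opset_version=11)
```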
45 changes: 45 additions & 0 deletions mmcv/ops/csrc/onnxruntime/corner_pool.h
@@ -0,0 +1,45 @@
#ifndef ONNXRUNTIME_CORNER_POOL_H
#define ONNXRUNTIME_CORNER_POOL_H

#include <assert.h>
#include <onnxruntime_cxx_api.h>

struct MMCVCornerPoolKernel {
public:
MMCVCornerPoolKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info)
: ort_(ort) {
mode_ = ort_.KernelInfoGetAttribute<int64_t>(info, "mode");
}

void Compute(OrtKernelContext* context);

private:
Ort::CustomOpApi ort_;

int64_t mode_;
};

struct MMCVCornerPoolCustomOp
: Ort::CustomOpBase<MMCVCornerPoolCustomOp, MMCVCornerPoolKernel> {
void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) {
return new MMCVCornerPoolKernel(api, info);
}

const char* GetName() const { return "MMCVCornerPool"; }

size_t GetInputTypeCount() const { return 1; }
ONNXTensorElementDataType GetInputType(size_t) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}

size_t GetOutputTypeCount() const { return 1; }
ONNXTensorElementDataType GetOutputType(size_t) const {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
}

// force cpu
const char* GetExecutionProviderType() const {
return "CPUExecutionProvider";
}
};
#endif // ONNXRUNTIME_CORNER_POOL_H
122 changes: 122 additions & 0 deletions mmcv/ops/csrc/onnxruntime/cpu/corner_pool.cpp
@@ -0,0 +1,122 @@
#include "corner_pool.h"

#include "../ort_mmcv_utils.h"

void TopPoolForwardCPU(const float *input, float *output, const int batch_size,
const int channels, const int height, const int width) {
for (int n = 0; n < batch_size; n++) {
int index_n = n * channels * width * height;
for (int c = 0; c < channels; c++) {
int index_n_c = index_n + c * width * height;
for (int w = 0; w < width; w++) {
// directly copy the bottommost value from input to output
output[index_n_c + (height - 1) * width + w] =
input[index_n_c + (height - 1) * width + w];
// do top_pool
for (int h = height - 2; h >= 0; h--) {
output[index_n_c + h * width + w] =
std::max(output[index_n_c + (h + 1) * width + w],
input[index_n_c + h * width + w]);
} // for h
} // for w
} // for c
} // for n
}

void BottomPoolForwardCPU(const float *input, float *output,
const int batch_size, const int channels,
const int height, const int width) {
for (int n = 0; n < batch_size; n++) {
int index_n = n * channels * width * height;
for (int c = 0; c < channels; c++) {
int index_n_c = index_n + c * width * height;
for (int w = 0; w < width; w++) {
// directly copy the topmost value from input to output
output[index_n_c + w] = input[index_n_c + w];
// do bottom_pool
for (int h = 1; h < height; h++) {
output[index_n_c + h * width + w] =
std::max(output[index_n_c + (h - 1) * width + w],
input[index_n_c + h * width + w]);
} // for h
} // for w
} // for c
} // for n
}

void LeftPoolForwardCPU(const float *input, float *output, const int batch_size,
const int channels, const int height, const int width) {
for (int n = 0; n < batch_size; n++) {
int index_n = n * channels * width * height;
for (int c = 0; c < channels; c++) {
int index_n_c = index_n + c * width * height;
for (int h = 0; h < height; h++) {
// directly copy the rightmost value from input to output
output[index_n_c + h * width + width - 1] =
input[index_n_c + h * width + width - 1];
// do left_pool
for (int w = width - 2; w >= 0; w--) {
output[index_n_c + h * width + w] =
std::max(output[index_n_c + h * width + w + 1],
input[index_n_c + h * width + w]);
} // for w
} // for h
} // for c
} // for n
}

void RightPoolForwardCPU(const float *input, float *output,
const int batch_size, const int channels,
const int height, const int width) {
for (int n = 0; n < batch_size; n++) {
int index_n = n * channels * width * height;
for (int c = 0; c < channels; c++) {
int index_n_c = index_n + c * width * height;
for (int h = 0; h < height; h++) {
// directly copy the leftmost value from input to output
output[index_n_c + h * width] = input[index_n_c + h * width];
// do right_pool
for (int w = 1; w < width; w++) {
output[index_n_c + h * width + w] =
std::max(output[index_n_c + h * width + w - 1],
input[index_n_c + h * width + w]);
} // for w
} // for h
} // for c
} // for n
}

void MMCVCornerPoolKernel::Compute(OrtKernelContext *context) {
const int mode = int(mode_);
typedef float T;
const OrtValue *input = ort_.KernelContext_GetInput(context, 0);
const T *input_data =
reinterpret_cast<const float *>(ort_.GetTensorData<T>(input));

// get output memory
OrtTensorDimensions out_dimensions(ort_, input);
OrtValue *output = ort_.KernelContext_GetOutput(
context, 0, out_dimensions.data(), out_dimensions.size());
T *output_data = ort_.GetTensorMutableData<T>(output);

// 'top': 0, 'bottom': 1, 'left': 2, 'right': 3
assert(mode == 0 || mode == 1 || mode == 2 || mode == 3);

// do corner_pool
int batch_size = out_dimensions.data()[0];
int input_channels = out_dimensions.data()[1];
int input_height = out_dimensions.data()[2];
int input_width = out_dimensions.data()[3];
if (mode == 0)
TopPoolForwardCPU(input_data, output_data, batch_size, input_channels,
input_height, input_width);
else if (mode == 1)
BottomPoolForwardCPU(input_data, output_data, batch_size, input_channels,
input_height, input_width);
else if (mode == 2)
LeftPoolForwardCPU(input_data, output_data, batch_size, input_channels,
input_height, input_width);
else
RightPoolForwardCPU(input_data, output_data, batch_size, input_channels,
input_height, input_width);
}
7 changes: 7 additions & 0 deletions mmcv/ops/csrc/onnxruntime/cpu/onnxruntime_register.cpp
@@ -1,5 +1,6 @@
#include "onnxruntime_register.h"

#include "corner_pool.h"
#include "grid_sample.h"
#include "nms.h"
#include "ort_mmcv_utils.h"
@@ -13,6 +14,7 @@ NmsOp c_NmsOp;
MMCVRoiAlignCustomOp c_MMCVRoiAlignCustomOp;
MMCVRoIAlignRotatedCustomOp c_MMCVRoIAlignRotatedCustomOp;
GridSampleOp c_GridSampleOp;
MMCVCornerPoolCustomOp c_MMCVCornerPoolCustomOp;

OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
const OrtApiBase *api) {
@@ -45,5 +47,10 @@ OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
return status;
}

if (auto status =
ortApi->CustomOpDomain_Add(domain, &c_MMCVCornerPoolCustomOp)) {
return status;
}

return ortApi->AddCustomOpDomain(options, domain);
}
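
On the Python side this registration becomes visible once the compiled library is attached to a session. A minimal sketch mirroring the test below (the model filename is a placeholder):

```python
import onnxruntime as rt
from mmcv.ops import get_onnxruntime_op_path

session_options = rt.SessionOptions()
# Loading the shared library triggers RegisterCustomOps above, adding the
# mmcv custom-op domain (including MMCVCornerPool) to this session.
session_options.register_custom_ops_library(get_onnxruntime_op_path())
sess = rt.InferenceSession('model_with_corner_pool.onnx', session_options)
```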
46 changes: 46 additions & 0 deletions tests/test_ops/test_onnx.py
@@ -448,3 +448,49 @@ def func(feat, scale_factor=2):
if os.path.exists(onnx_file):
os.remove(onnx_file)
assert np.allclose(pytorch_result, onnx_result, atol=1e-3)


@pytest.mark.parametrize('mode', ['top', 'bottom', 'left', 'right'])
def test_corner_pool(mode, opset=11):
if torch.__version__ == 'parrots':
pytest.skip('onnx is not supported in parrots directly')

from mmcv.ops import get_onnxruntime_op_path
ort_custom_op_path = get_onnxruntime_op_path()
if not os.path.exists(ort_custom_op_path):
pytest.skip('custom ops for onnxruntime are not compiled.')

from mmcv.ops.corner_pool import CornerPool

def corner_pool_func(input):
corner_pool_module = CornerPool(mode)
return corner_pool_module.corner_pool.apply(input)

wrapped_model = WrapFunction(corner_pool_func).eval()

input = torch.rand((2, 3, 9, 12)) # (n,c,h,w)

with torch.no_grad():
torch.onnx.export(
wrapped_model,
input,
onnx_file,
export_params=True,
keep_initializers_as_inputs=True,
input_names=['input'],
output_names=['output'],
opset_version=opset)

onnx_model = onnx.load(onnx_file)
input_all = [node.name for node in onnx_model.graph.input]
input_initializer = [node.name for node in onnx_model.graph.initializer]
net_feed_input = list(set(input_all) - set(input_initializer))
assert (len(net_feed_input) == 1)

session_options = rt.SessionOptions()
session_options.register_custom_ops_library(ort_custom_op_path)
sess = rt.InferenceSession(onnx_file, session_options)
ort_result = sess.run(None, {'input': input.detach().numpy()})
pytorch_results = wrapped_model(input.clone())
os.remove(onnx_file)
assert np.allclose(pytorch_results, ort_result, atol=1e-5)
