Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NonMaxSuppression op to contribution ops #60

Merged
merged 8 commits into from
Dec 1, 2018
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions onnxruntime/contrib_ops/contrib_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,51 @@ with the exception that numpy default keepdims to False instead of True.)DOC")
"keepdims",
"Keep the reduced dimension or not, default 1 mean keep reduced dimension.",
AttributeProto::INT);

// Contributed NonMaxSuppression schema (kMSDomain, opset 1).
// Inputs (T1 = float): boxes [num_boxes, 4] and scores [num_boxes].
// Outputs (T2 = int32): selected_indices, plus the optional valid_outputs count.
ONNX_CONTRIB_OPERATOR_SCHEMA(NonMaxSuppression)
    .SetDomain(kMSDomain)
    .SinceVersion(1)
    .SetDoc(R"DOC(
Pruning away boxes that have high intersection-over-union (IOU) overlap with previously selected boxes.
Bounding boxes with score less than score_threshold are removed. Bounding boxes are supplied as [y1, x1, y2, x2],
where (y1, x1) and (y2, x2) are the coordinates of any diagonal pair of box corners and the coordinates can be provided
as normalized (i.e., lying in the interval [0, 1]) or absolute.
Note that this algorithm is agnostic to where the origin is in the coordinate system and more generally is invariant to
orthogonal transformations and translations of the coordinate system;
thus translating or reflections of the coordinate system result in the same boxes being selected by the algorithm.
The output of this operation is a set of integers indexing into the input collection of bounding boxes representing the selected boxes.
The bounding box coordinates corresponding to the selected indices can then be obtained using the gather operation.)DOC")
    .Input(0, "boxes", "An input tensor. 2D tensor with shape [num_boxes, 4]", "T1")
    .Input(1, "scores", "An input tensor. 1D tensor with shape [num_boxes]", "T1")
    .Output(0, "selected_indices", "selected indices from the boxes tensor.", "T2")
    .Output(
        1,
        "valid_outputs",
        "Optional. A 0-D integer tensor representing the number of valid elements in selected_indices, with the valid elements appearing first.",
        "T2",
        OpSchema::Optional)
    .TypeConstraint("T1", {"tensor(float)"}, "Constrain input type to float tensor.")
    .TypeConstraint("T2",
                    {"tensor(int32)"},
                    "Constrain output data type to 32-bit integer tensor.")
    .Attr(
        "max_output_size",
        "Integer representing the maximum number of boxes to be selected by non max suppression.",
        AttributeProto::INT)
    .Attr(
        "iou_threshold",
        "Float representing the threshold for deciding whether boxes overlap too much with respect to IOU. Value range [0, 1]. The default is 0.0",
        AttributeProto::FLOAT,
        static_cast<float>(0.0f))
    .Attr(
        "score_threshold",
        "Float tensor representing the threshold for deciding when to remove boxes based on score.",
        AttributeProto::FLOAT)
    .Attr(
        "pad_to_max_output_size",
        "Optional. 1(true) - the output selected_indices is padded to be of length max_output_size. Defaults to 0(false).",
        AttributeProto::INT,
        OPTIONAL);
}
Copy link
Member

@duli2012 duli2012 Nov 29, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to add shape inference here since the shape inference is enabled by default now? #Resolved


class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp);
Expand All @@ -366,6 +411,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, NonMaxSuppression);

void RegisterContribKernels(std::function<void(KernelCreateInfo&&)> fn) {
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp)>());
Expand All @@ -378,6 +424,7 @@ void RegisterContribKernels(std::function<void(KernelCreateInfo&&)> fn) {
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QuantizeLinear)>());
fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, NonMaxSuppression)>());
}
} // namespace contrib
} // namespace onnxruntime
149 changes: 149 additions & 0 deletions onnxruntime/contrib_ops/cpu/non_max_suppression.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/cpu/non_max_suppression.h"
#include <queue>

namespace onnxruntime {
namespace contrib {

// Register the CPU kernel for the contributed NonMaxSuppression op
// (MS domain, opset 1, float specialization). The type constraints mirror the
// schema: T1 = float (boxes/scores inputs), T2 = int32 (index outputs).
ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
NonMaxSuppression,
1,
float,
KernelDefBuilder()
.TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<int32_t>()),
NonMaxSuppression<float>);

// Orders the pair {lhs, rhs}: the smaller value is written to `min`,
// the larger to `max`. Used to normalize box corner coordinates, which the
// schema allows to arrive as any diagonal pair of corners.
template <typename T>
void NonMaxSuppression<T>::MaxMin(const T& lhs, const T& rhs, T& min, T& max) const {
  const bool lhs_is_upper = (lhs >= rhs);
  min = lhs_is_upper ? rhs : lhs;
  max = lhs_is_upper ? lhs : rhs;
}

// Returns true when the intersection-over-union of the two boxes exceeds
// iou_threshold_, i.e. the candidate box (box_index2) should be suppressed.
// boxes_data holds [y1, x1, y2, x2] quadruples; each coordinate pair is
// normalized with MaxMin first since corners may come in either order.
// Degenerate boxes (zero or negative area) never trigger suppression.
template <typename T>
bool NonMaxSuppression<T>::SuppressByIOU(const T* boxes_data, int32_t box_index1, int32_t box_index2) const {
  const T* box1 = boxes_data + 4 * box_index1;
  const T* box2 = boxes_data + 4 * box_index2;

  T x1_min, y1_min, x1_max, y1_max;
  T x2_min, y2_min, x2_max, y2_max;
  MaxMin(box1[1], box1[3], x1_min, x1_max);
  MaxMin(box1[0], box1[2], y1_min, y1_max);
  MaxMin(box2[1], box2[3], x2_min, x2_max);
  MaxMin(box2[0], box2[2], y2_min, y2_max);

  const T zero = static_cast<T>(0.0);
  const T overlap_w = std::max(std::min(x1_max, x2_max) - std::max(x1_min, x2_min), zero);
  const T overlap_h = std::max(std::min(y1_max, y2_max) - std::max(y1_min, y2_min), zero);
  const T intersection_area = overlap_w * overlap_h;
  if (intersection_area <= zero) {
    return false;
  }

  const T area1 = (x1_max - x1_min) * (y1_max - y1_min);
  const T area2 = (x2_max - x2_min) * (y2_max - y2_min);
  const T union_area = area1 + area2 - intersection_area;
  if (area1 <= zero || area2 <= zero || union_area <= zero) {
    return false;
  }

  return (intersection_area / union_area) > iou_threshold_;
}

// Greedy non-max suppression:
//  1. validate input shapes (boxes: [num_boxes, 4], scores: [num_boxes]),
//  2. keep only boxes whose score exceeds score_threshold_,
//  3. repeatedly take the highest-scoring remaining box and select it unless
//     its IOU with an already-selected box exceeds iou_threshold_,
// stopping once max_output_size_ boxes are selected or no candidates remain.
// Output 0 holds the selected indices (zero-padded to max_output_size_ when
// pad_to_max_output_size_ == 1); optional output 1 holds the valid count.
template <typename T>
Status NonMaxSuppression<T>::Compute(OpKernelContext* ctx) const {
  const Tensor* boxes = ctx->Input<Tensor>(0);
  ONNXRUNTIME_ENFORCE(boxes);
  const Tensor* scores = ctx->Input<Tensor>(1);
  ONNXRUNTIME_ENFORCE(scores);

  const TensorShape& boxes_shape = boxes->Shape();
  auto boxes_dims = boxes_shape.GetDims();
  ONNXRUNTIME_RETURN_IF_NOT(boxes_shape.NumDimensions() == 2, "boxes must be a 2D tensor.");
  int64_t num_boxes = boxes_dims[0];
  ONNXRUNTIME_RETURN_IF_NOT(boxes_dims[1] == 4, "boxes shape must be a 2D tensor with shape [num_boxes, 4].");

  const TensorShape& scores_shape = scores->Shape();
  // NOTE: the message previously said "boxes must be a 1D tensor." — it is the
  // scores tensor being checked here.
  ONNXRUNTIME_RETURN_IF_NOT(scores_shape.NumDimensions() == 1, "scores must be a 1D tensor.");
  ONNXRUNTIME_RETURN_IF_NOT(scores_shape.GetDims()[0] == num_boxes, "scores and boxes should have same num_boxes.");

  if (max_output_size_ <= 0 || num_boxes == 0) {
    // Nothing can be selected: emit an empty selected_indices tensor and, if
    // the optional second output was requested, a valid count of zero.
    ctx->Output(0, TensorShape({0}));
    Tensor* valid_outputs = ctx->Output(1, TensorShape({1}));
    if (valid_outputs != nullptr) {
      valid_outputs->MutableData<int32_t>()[0] = 0;
    }
    return Status::OK();
  }

  const T* boxes_data = boxes->Data<T>();
  const T* scores_data = scores->Data<T>();

  struct ScoreIndexPair {
    T score;
    int32_t index;
  };

  auto LessCompare = [](const ScoreIndexPair& lhs, const ScoreIndexPair& rhs) {
    return lhs.score < rhs.score;
  };

  // Max-heap of candidates ordered by score, pre-filtered by score_threshold_.
  std::priority_queue<ScoreIndexPair, std::deque<ScoreIndexPair>, decltype(LessCompare)> sorted_scores_with_index(LessCompare);
  for (int64_t i = 0; i < num_boxes; ++i) {
    if (static_cast<float>(scores_data[i]) > score_threshold_) {
      sorted_scores_with_index.emplace(ScoreIndexPair{scores_data[i], static_cast<int32_t>(i)});
    }
  }

  int num_of_selected = 0;
  // Zero-initialized so the tail already holds 0 when padding is requested.
  std::vector<int32_t> selected_index(max_output_size_, 0);

  // Pop the next box with the top score; suppress it if it overlaps any
  // previously selected box beyond the IOU (Intersection Over Union) threshold.
  while (num_of_selected < max_output_size_ && !sorted_scores_with_index.empty()) {
    const ScoreIndexPair next_top_score = sorted_scores_with_index.top();
    sorted_scores_with_index.pop();

    bool selected = true;
    for (int i = num_of_selected - 1; i >= 0; --i) {
      if (SuppressByIOU(boxes_data, selected_index[i], next_top_score.index)) {
        selected = false;
        break;
      }
    }

    if (selected) {
      selected_index[num_of_selected] = next_top_score.index;
      ++num_of_selected;
    }
  }

  const int64_t num_to_copy = pad_to_max_output_size_ == 1 ? max_output_size_ : num_of_selected;
  Tensor* selected_indices = ctx->Output(0, TensorShape({num_to_copy}));
  auto output_data = selected_indices->MutableData<int32_t>();
  memcpy(output_data, selected_index.data(), num_to_copy * sizeof(int32_t));

  // Optional second output: the number of valid (non-padding) indices.
  Tensor* valid_outputs = ctx->Output(1, TensorShape({1}));
  if (valid_outputs) {
    valid_outputs->MutableData<int32_t>()[0] = num_of_selected;
  }

  return Status::OK();
}

} // namespace contrib
} // namespace onnxruntime
37 changes: 37 additions & 0 deletions onnxruntime/contrib_ops/cpu/non_max_suppression.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/common/common.h"
#include "core/framework/op_kernel.h"
//#include "core/util/math_cpuonly.h"

namespace onnxruntime {
namespace contrib {

// Contributed NonMaxSuppression CPU kernel: greedily selects up to
// max_output_size boxes in decreasing score order, suppressing any box whose
// IOU with an already-selected box exceeds iou_threshold. See the schema
// registration in contrib_ops.cc for the full operator contract.
template <typename T>
class NonMaxSuppression final : public OpKernel {
 public:
  NonMaxSuppression(const OpKernelInfo& info) : OpKernel(info),
                                                pad_to_max_output_size_(info.GetAttrOrDefault<int64_t>("pad_to_max_output_size", 0)) {
    ONNXRUNTIME_ENFORCE(info.GetAttr("max_output_size", &max_output_size_).IsOK());
    ONNXRUNTIME_ENFORCE(info.GetAttr("iou_threshold", &iou_threshold_).IsOK());
    ONNXRUNTIME_ENFORCE(iou_threshold_ >= 0 && iou_threshold_ <= 1, "iou_threshold must be in range [0, 1]");
    ONNXRUNTIME_ENFORCE(info.GetAttr("score_threshold", &score_threshold_).IsOK());
  }

  Status Compute(OpKernelContext* context) const override;

 private:
  // Returns true when the IOU of the two indexed boxes exceeds iou_threshold_.
  bool SuppressByIOU(const T* boxes_data, int32_t box_index1, int32_t box_index2) const;
  // Orders a coordinate pair so that min <= max.
  void MaxMin(const T& lhs, const T& rhs, T& min, T& max) const;

  int64_t max_output_size_;          // required "max_output_size" attribute
  float iou_threshold_;              // required "iou_threshold" attribute, validated to [0, 1]
  float score_threshold_;            // required "score_threshold" attribute
  int64_t pad_to_max_output_size_;   // optional attribute, defaults to 0 (no padding)
};
} // namespace contrib
} // namespace onnxruntime
Loading