Add multi-class non-maximum suppression operator. #7953
9ad769d to 2731fd9
auto score_dims = ctx->GetInputDim("Scores");

PADDLE_ENFORCE_EQ(box_dims.size(), 2,
                  "The rank of Input(Bboxes) must be 3.");

The rank of Input(Bboxes) must be 2.

Done.

                  "The rank of Input(Bboxes) must be 3.");
PADDLE_ENFORCE_EQ(score_dims.size(), 3,
                  "The rank of Input(Scores) must be 3.");
PADDLE_ENFORCE_EQ(box_dims[1], 4);

The shape of Input(Bboxes) must be [N, 4].

Done.
PADDLE_ENFORCE_EQ(score_dims.size(), 3,
                  "The rank of Input(Scores) must be 3.");
PADDLE_ENFORCE_EQ(box_dims[1], 4);
PADDLE_ENFORCE_EQ(box_dims[0], score_dims[2]);

The predictions number of Input(Bboxes) and Input(Scores) must be the same.

Done. Thanks!
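
For reference, the three shape checks requested in the comments above can be sketched outside of Paddle. This is only an illustration, not the operator's code: `CheckNmsInputShapes` is a hypothetical helper, and throwing `std::invalid_argument` is an assumption standing in for `PADDLE_ENFORCE_EQ`. The messages follow the reviewer's wording, with the box shape written as [M, 4] to match the input description elsewhere in this review.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for the InferShape checks (not a Paddle API).
// box_dims is expected to be [M, 4]; score_dims is expected to be [N, C, M].
inline void CheckNmsInputShapes(const std::vector<int64_t>& box_dims,
                                const std::vector<int64_t>& score_dims) {
  if (box_dims.size() != 2) {
    throw std::invalid_argument("The rank of Input(BBoxes) must be 2.");
  }
  if (score_dims.size() != 3) {
    throw std::invalid_argument("The rank of Input(Scores) must be 3.");
  }
  if (box_dims[1] != 4) {
    throw std::invalid_argument("The shape of Input(BBoxes) must be [M, 4].");
  }
  if (box_dims[0] != score_dims[2]) {
    throw std::invalid_argument(
        "The predictions number of Input(BBoxes) and Input(Scores) "
        "must be the same.");
  }
}
```

Giving every check its own message means a failure names the offending input directly instead of reporting a bare comparison.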
T BBoxArea(const T* box, const bool normalized) {
  if (box[2] < box[0] || box[3] < box[1]) {
    // If bbox is invalid (e.g. xmax < xmin or ymax < ymin), return 0.
    return T(0.);

Use static_cast<T>(0.) or brace initialization?

Done. Thanks!
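
A sketch of how the helper might look with the suggested `static_cast`. The invalid-box check and the normalized branch come from the quoted diff; the `+ 1` in the non-normalized branch is an assumption (pixel coordinates treated as inclusive), since that part of the function is not shown in this review.

```cpp
// Area of a box laid out as [xmin, ymin, xmax, ymax].
template <class T>
inline T BBoxArea(const T* box, const bool normalized) {
  if (box[2] < box[0] || box[3] < box[1]) {
    // Invalid box (xmax < xmin or ymax < ymin): area is zero.
    return static_cast<T>(0.);
  }
  const T w = box[2] - box[0];
  const T h = box[3] - box[1];
  if (normalized) {
    // Coordinates are within [0, 1].
    return w * h;
  }
  // Assumption: coordinates are inclusive pixel indices, so each side
  // spans one more unit than the coordinate difference.
  return (w + 1) * (h + 1);
}
```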
for (const auto& it : selected_indices) {
  int label = it.first;
  const T* sdata = scores_data + label * predict_dim;
  std::vector<int> indices = it.second;

std::vector<int>& indices = it.second;

Done.
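
To illustrate the point of this comment: binding a reference avoids copying each per-class index vector on every iteration. Because the loop variable is `const auto&`, the reference has to be `const` as well. In this sketch `CountSelected` is a hypothetical helper, and the container type `std::map<int, std::vector<int>>` is an assumption inferred from how `it.first` and `it.second` are used in the diff.

```cpp
#include <cstddef>
#include <map>
#include <vector>

// Counts the kept detections across all labels without copying any vector.
inline std::size_t CountSelected(
    const std::map<int, std::vector<int>>& selected_indices) {
  std::size_t total = 0;
  for (const auto& it : selected_indices) {
    // Binds to the stored vector; no copy is made.
    const std::vector<int>& indices = it.second;
    total += indices.size();
  }
  return total;
}
```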
@@ -0,0 +1,375 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Please fix the year.

Done.
void InferShape(framework::InferShapeContext* ctx) const override {
  PADDLE_ENFORCE(ctx->HasInput("Bboxes"),
                 "Input(Bboxes) of MulticlassNMS should not be null.");

Bboxes or BBoxes?
MulticlassNMS --> MultiClassNMSOp

Done. Changed to BBoxes.
PADDLE_ENFORCE(ctx->HasInput("Bboxes"),
               "Input(Bboxes) of MulticlassNMS should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Scores"),
               "Input(Scores) of MulticlassNMS should not be null.");

MulticlassNMS --> MulticlassNMSOp

Done.
constexpr int64_t kOutputDim = 6;
constexpr int64_t kBBoxSize = 4;

class MulticlassNMSOp : public framework::OperatorWithKernel {

MulticlassNMSOp --> MultiClassNMSOp

Done.
auto score_dims = ctx->GetInputDim("Scores");

PADDLE_ENFORCE_EQ(box_dims.size(), 2,
                  "The rank of Input(Bboxes) must be 3.");

must be 3 --> must be 2

Done.
AddInput("Bboxes",
         "(Tensor) A 2-D Tensor with shape [M, 4] represents the location "
         "predictions with M bboxes. 4 is the number of "
         "each location coordinates.");

4 is the number of each location coordinates --> Each bounding box has four coordinate values and the layout is [xmin, ymin, xmax, ymax]

Done. Thanks!

         "each location coordinates.");
AddInput("Scores",
         "(Tensor) A 3-D Tensor with shape [N, C, M] represents the "
         "confidence predictions. N is the batch size, C is the class "

confidence predictions --> predicted confidence scores.

Done. Thanks!
AddInput("Scores",
         "(Tensor) A 3-D Tensor with shape [N, C, M] represents the "
         "confidence predictions. N is the batch size, C is the class "
         "number, M is number of predictions for each class, which is "

, M is number of predictions for each class --> and M is the number of bounding boxes. For each category there are M scores in total, which correspond to the M bounding boxes. Please note, M is equal to the 1st dimension of Bboxes.

Done. Thanks!
AddAttr<int>(
    "background_label",
    "(int64_t, defalut: 0) "
    "The index of background label, the background label will be ignored.")

Please add: "If set to -1, then all categories will be considered."

Done.
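
The behaviour the requested sentence documents can be sketched as a simple class loop. `ClassesToProcess` is a hypothetical helper written for this illustration and does not appear in the PR; it only shows why -1 disables the skip.

```cpp
#include <vector>

// Returns the class indices that should go through NMS.
// A valid class index is never negative, so background_label == -1
// matches no class and every category is kept.
inline std::vector<int> ClassesToProcess(int class_num, int background_label) {
  std::vector<int> classes;
  for (int c = 0; c < class_num; ++c) {
    if (c == background_label) continue;  // skip the background class
    classes.push_back(c);
  }
  return classes;
}
```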
    .SetDefault(0);
AddAttr<float>("score_threshold",
               "(float) "
               "Only consider detections whose confidences are larger than "

Threshold to filter out bounding boxes with low confidence score.

Done.
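
A minimal sketch of the filtering this attribute describes, assuming a strict "larger than" comparison as the original attribute text words it. `FilterByScore` is a hypothetical name used only for this example.

```cpp
#include <vector>

// Returns the indices of the scores strictly greater than score_threshold.
inline std::vector<int> FilterByScore(const float* scores, int num,
                                      float score_threshold) {
  std::vector<int> kept;
  for (int i = 0; i < num; ++i) {
    if (scores[i] > score_threshold) {
      kept.push_back(i);
    }
  }
  return kept;
}
```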
: OpProtoAndCheckerMaker(proto, op_checker) {
  AddInput("Bboxes",
           "(Tensor) A 2-D Tensor with shape [M, 4] represents the location "
           "predictions with M bboxes. 4 is the number of "

Done. Thanks!
AddInput("Bboxes",
         "(Tensor) A 2-D Tensor with shape [M, 4] represents the location "
         "predictions with M bboxes. 4 is the number of "
         "each location coordinates.");

Done. Thanks!

         "each location coordinates.");
AddInput("Scores",
         "(Tensor) A 3-D Tensor with shape [N, C, M] represents the "
         "confidence predictions. N is the batch size, C is the class "

Done. Thanks!
AddInput("Scores",
         "(Tensor) A 3-D Tensor with shape [N, C, M] represents the "
         "confidence predictions. N is the batch size, C is the class "
         "number, M is number of predictions for each class, which is "

Done. Thanks!
@@ -0,0 +1,375 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Done.
const T* bdata = bboxes_data + idx * kBBoxSize;
odata[count * kOutputDim] = label;           // label
odata[count * kOutputDim + 1] = sdata[idx];  // score
odata[count * kOutputDim + 2] = bdata[0];    // xmin

Done, use std::memcpy
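
One way the `std::memcpy` variant could look, as a sketch only. `WriteDetection` is a hypothetical helper and the final code in the PR may be organised differently; the row layout [label, score, xmin, ymin, xmax, ymax] and the two constants are taken from the quoted diff.

```cpp
#include <cstdint>
#include <cstring>

constexpr int64_t kOutputDim = 6;
constexpr int64_t kBBoxSize = 4;

// Writes one output row: [label, score, xmin, ymin, xmax, ymax].
template <class T>
inline void WriteDetection(T* odata, int64_t count, int label, T score,
                           const T* bdata) {
  T* row = odata + count * kOutputDim;
  row[0] = static_cast<T>(label);  // label
  row[1] = score;                  // score
  // Copy the four box coordinates in one call instead of four assignments.
  std::memcpy(row + 2, bdata, kBBoxSize * sizeof(T));
}
```

The copy is safe here because `T` is a plain floating-point type and the source and destination buffers do not overlap.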
std::vector<size_t> batch_starts = {0};
for (int64_t i = 0; i < batch_size; ++i) {
  Tensor ins_score = scores->Slice(i, i + 1);
  ins_score.Resize({class_num, predict_dim});

The Slice in Tensor doesn't reduce the dimension. The shape of ins_score is [1, C, M], so it is resized to [C, M] here.
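
The resize only relabels the shape; the data stays where it is. A sketch with raw pointers, assuming a contiguous row-major [N, C, M] buffer (that layout is an assumption of this example, and `ClassScores` is a hypothetical helper, not a Paddle function):

```cpp
#include <cstdint>

// Returns a pointer to the M scores of class c for batch item i in a
// contiguous row-major [N, C, M] buffer.
inline const float* ClassScores(const float* scores, int64_t class_num,
                                int64_t predict_dim, int64_t i, int64_t c) {
  // Slice(i, i + 1): a [1, C, M] block starting at item i.
  const float* ins_score = scores + i * class_num * predict_dim;
  // Viewing that block as [C, M] needs no data movement; row c starts here.
  return ins_score + c * predict_dim;
}
```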
if (normalized) {
  return w * h;
} else {
  // If bbox is not within range [0, 1].

Done.
AddAttr<int>(
    "background_label",
    "(int64_t, defalut: 0) "
    "The index of background label, the background label will be ignored.")

Done.
    .SetDefault(0);
AddAttr<float>("score_threshold",
               "(float) "
               "Only consider detections whose confidences are larger than "

Done.
3b56ccd to 1d9a7e1
limitations under the License. */

#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"

Is this header necessary?

Done.
}

template <class T>
T BBoxArea(const T* box, const bool normalized) {

add inline?

Done.
LGTM
std::vector<size_t> batch_starts = {0};
for (int64_t i = 0; i < batch_size; ++i) {
  Tensor ins_score = scores->Slice(i, i + 1);
  ins_score.Resize({class_num, predict_dim});

I see, thx.

@wanghaox Thanks!
limitations under the License. */

#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"

Done.
}

template <class T>
T BBoxArea(const T* box, const bool normalized) {

Done.
Fix #7773