Skip to content

Commit

Permalink
Merge pull request #7953 from qingqing01/multiclass_nms_op
Browse files Browse the repository at this point in the history
 Add multi-class non-maximum suppression operator.
  • Loading branch information
qingqing01 authored Feb 2, 2018
2 parents 624d22d + a6f3846 commit c9ef69b
Show file tree
Hide file tree
Showing 4 changed files with 625 additions and 9 deletions.
18 changes: 12 additions & 6 deletions paddle/operators/bipartite_match_op.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,12 +28,18 @@ class BipartiteMatchOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("DistMat"),
"Input(DistMat) of BipartiteMatch should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("ColToRowMatchIndices"),
"Output(ColToRowMatchIndices) of BipartiteMatch should not be null.");
PADDLE_ENFORCE(
ctx->HasOutput("ColToRowMatchDist"),
"Output(ColToRowMatchDist) of BipartiteMatch should not be null.");

auto dims = ctx->GetInputDim("DistMat");
PADDLE_ENFORCE_EQ(dims.size(), 2, "The rank of Input(DistMat) must be 2.");

ctx->SetOutputDim("ColToRowMatchIndices", dims);
ctx->SetOutputDim("ColToRowMatchDis", dims);
ctx->SetOutputDim("ColToRowMatchDist", dims);
}
};

Expand Down Expand Up @@ -91,7 +97,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override {
auto* dist_mat = context.Input<LoDTensor>("DistMat");
auto* match_indices = context.Output<Tensor>("ColToRowMatchIndices");
auto* match_dist = context.Output<Tensor>("ColToRowMatchDis");
auto* match_dist = context.Output<Tensor>("ColToRowMatchDist");

auto& dev_ctx = context.device_context<platform::CPUDeviceContext>();

Expand Down Expand Up @@ -148,13 +154,13 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
"Otherwise, it means B[j] is matched to row "
"ColToRowMatchIndices[i][j] in i-th instance. The row number of "
"i-th instance is saved in ColToRowMatchIndices[i][j].");
AddOutput("ColToRowMatchDis",
AddOutput("ColToRowMatchDist",
"(Tensor) A 2-D Tensor with shape [N, M] in float type. "
"N is batch size. If ColToRowMatchIndices[i][j] is -1, "
"ColToRowMatchDis[i][j] is also -1.0. Otherwise, assumed "
"ColToRowMatchDist[i][j] is also -1.0. Otherwise, assumed "
"ColToRowMatchIndices[i][j] = d, and the row offsets of each "
"instance are called LoD. Then "
"ColToRowMatchDis[i][j] = DistMat[d+LoD[i]][j]");
"ColToRowMatchDist[i][j] = DistMat[d+LoD[i]][j]");
AddComment(R"DOC(
This operator is a greedy bipartite matching algorithm, which is used to
obtain the matching with the maximum distance based on the input
Expand Down
Loading

0 comments on commit c9ef69b

Please sign in to comment.