From eb620b07cb8760321dfdf9c30efe882ae1d5db2a Mon Sep 17 00:00:00 2001 From: wanghaox Date: Sun, 28 Jan 2018 13:54:56 +0800 Subject: [PATCH] add target_confidence_assign_op --- .../operators/target_confidence_assign_op.cc | 135 ++++++++++++++ .../operators/target_confidence_assign_op.h | 100 +++++++++++ .../tests/test_target_confidence_assign_op.py | 168 ++++++++++++++++++ 3 files changed, 403 insertions(+) create mode 100755 paddle/operators/target_confidence_assign_op.cc create mode 100644 paddle/operators/target_confidence_assign_op.h create mode 100755 python/paddle/v2/fluid/tests/test_target_confidence_assign_op.py diff --git a/paddle/operators/target_confidence_assign_op.cc b/paddle/operators/target_confidence_assign_op.cc new file mode 100755 index 0000000000000..a21b623ca6e34 --- /dev/null +++ b/paddle/operators/target_confidence_assign_op.cc @@ -0,0 +1,135 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/target_confidence_assign_op.h" + +namespace paddle { +namespace operators { + +class TargetConfidenceAssignOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE( + ctx->HasInput("Conf"), + "Input(Conf) of TargetConfidenceAssignOp should not be null"); + PADDLE_ENFORCE( + ctx->HasInput("GTLabels"), + "Input(GTLabels) of TargetConfidenceAssignOp should not be null"); + PADDLE_ENFORCE(ctx->HasInput("MatchIndices"), + "Input(MatchIndices) of TargetConfidenceAssignOp should " + "not be null"); + PADDLE_ENFORCE( + ctx->HasInput("NegIndices"), + "Input(NegIndices) of TargetConfidenceAssignOp should not be null"); + + PADDLE_ENFORCE( + ctx->HasOutput("ConfGT"), + "Output(ConfGT) of TargetConfidenceAssignOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("ConfPred"), + "Output(ConfPred) of TargetConfidenceAssignOp should not be null."); + + auto conf_dims = ctx->GetInputDim("Conf"); + auto gt_dims = ctx->GetInputDim("GTLabels"); + auto mi_dims = ctx->GetInputDim("MatchIndices"); + auto neg_dims = ctx->GetInputDim("NegIndices"); + PADDLE_ENFORCE_EQ(conf_dims.size(), 3UL, + "The rank of Input(Conf) must be 3, the shape is " + "[batch_size, prior_box_num, class_num]."); + PADDLE_ENFORCE_EQ(gt_dims.size(), 2UL, + "The rank of Input(GTLabels) must be 2, the shape is " + "[N, 1]."); + PADDLE_ENFORCE_EQ(mi_dims.size(), 2UL, + "The rank of Input(MatchIndices) must be 2, the shape is " + "[batch_size, prior_box_num]."); + PADDLE_ENFORCE_EQ(neg_dims.size(), 2UL, + "The rank of Input(NegIndices) must be 2, the shape is " + "[N, 1]."); + + PADDLE_ENFORCE_EQ(conf_dims[0], mi_dims[0], + "The batch_size of Input(Conf) and " + "Input(MatchIndices) must be the same."); + + PADDLE_ENFORCE_EQ(conf_dims[1], mi_dims[1], + "The prior_box_num of Input(Loc) and " + "Input(MatchIndices) must be the same."); + PADDLE_ENFORCE_EQ(gt_dims[1], 1UL, + "The shape of Input(GTLabels) is [N, 1]."); + PADDLE_ENFORCE_EQ(neg_dims[1], 1UL, + "The shape of Input(NegIndices) is [Nneg, 1]."); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("Conf")->type()), + ctx.device_context()); + } +}; + +class TargetConfidenceAssignOpMaker : public framework::OpProtoAndCheckerMaker { + public: + TargetConfidenceAssignOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Conf", + "(Tensor, default Tensor), The input confidence " + "predictions."); + AddInput( + "GTLabels", + "(LoDTensor, default LoDTensor), The input ground-truth labels."); + AddInput("MatchIndices", + "(LoDTensor, default LoDTensor), The input matched indices, " + "When it's equal to -1, it doesn't match any entity."); + AddInput("NegIndices", + "(LoDTensor, default LoDTensor), The input negative example " + "indics."); + AddOutput("ConfGT", + "(LoDTensor), The output ground-truth labels filtered by " + "MatchIndices and append NegIndices examples."); + AddOutput("ConfPred", + "(LoDTensor), The output confidence predictions filtered by " + "MatchIndices and append NegIndices examples."); + AddAttr("background_label_id", + "(int, default 0), Label id for background class.") + .SetDefault(0); + AddComment(R"DOC( +TargetConfidenceAssign operator + +Filter ground-truth labels when the corresponding MatchIndices is not -1, + and append negative examples with label background_label_id, + it produces the output ConfGT. + Filter confidence predictions when the corresponding MatchIndices is not -1, + and append negative examples' confidence prediction. + it produces the output ConfPred. + + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(target_confidence_assign, + ops::TargetConfidenceAssignOp, + ops::TargetConfidenceAssignOpMaker); +REGISTER_OP_CPU_KERNEL( + target_confidence_assign, + ops::TargetConfidenceAssignOpKernel, + ops::TargetConfidenceAssignOpKernel); diff --git a/paddle/operators/target_confidence_assign_op.h b/paddle/operators/target_confidence_assign_op.h new file mode 100644 index 0000000000000..c1f6d85d74159 --- /dev/null +++ b/paddle/operators/target_confidence_assign_op.h @@ -0,0 +1,100 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class TargetConfidenceAssignOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in_conf = ctx.Input("Conf"); + auto* in_gt_labels = ctx.Input("GTLabels"); + auto* in_match_indices = ctx.Input("MatchIndices"); + auto* in_neg_indices = ctx.Input("NegIndices"); + + auto* out_conf_gt = ctx.Output("ConfGT"); + auto* out_conf_pred = ctx.Output("ConfPred"); + int background_label_id = ctx.Attr("background_label_id"); + + auto in_conf_dim = in_conf->dims(); + auto gt_lod = in_gt_labels->lod(); + auto neg_indices_lod = in_neg_indices->lod(); + int batch_size = in_conf_dim[0]; + int prior_num = in_conf_dim[1]; + int class_num = in_conf_dim[2]; + + auto conf = framework::EigenTensor::From(*in_conf); + auto gt_labels = framework::EigenTensor::From(*in_gt_labels); + auto match_indices = + framework::EigenTensor::From(*in_match_indices); + auto neg_indices = framework::EigenTensor::From(*in_neg_indices); + + int match_num = 0; + int neg_num = in_neg_indices->dims()[0]; + for (int n = 0; n < batch_size; ++n) { + for (int p = 0; p < prior_num; ++p) { + if (match_indices(n, p) != -1) match_num++; + } + } + + framework::LoD out_lod; + out_lod.resize(1); + out_lod[0].push_back(0); + out_conf_gt->mutable_data( + framework::make_ddim({match_num + neg_num, 1}), ctx.GetPlace()); + out_conf_pred->mutable_data( + framework::make_ddim({match_num + neg_num, class_num}), ctx.GetPlace()); + + auto conf_gt = framework::EigenTensor::From(*out_conf_gt); + auto conf_pred = framework::EigenTensor::From(*out_conf_pred); + + int count = 0; + for (int n = 0; n < batch_size; ++n) { + for (int p = 0; p < prior_num; ++p) { + int idx = match_indices(n, p); + if (idx == -1) continue; + int gt_start = gt_lod[0][n]; + int gt_offset = gt_start + idx; + int label = gt_labels(gt_offset); + conf_gt(count) = label; + for (int c = 0; c < class_num; ++c) { + conf_pred(count, c) = conf(n, p, c); + } + count += 1; + } + + int neg_start = neg_indices_lod[0][n]; + int neg_end = neg_indices_lod[0][n + 1]; + for (int ne = neg_start; ne < neg_end; ++ne) { + int idx = neg_indices(ne); + conf_gt(count) = background_label_id; + for (int c = 0; c < class_num; ++c) { + conf_pred(count, c) = conf(n, idx, c); + } + count += 1; + } + out_lod[0].push_back(count); + } + out_conf_gt->set_lod(out_lod); + out_conf_pred->set_lod(out_lod); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_target_confidence_assign_op.py b/python/paddle/v2/fluid/tests/test_target_confidence_assign_op.py new file mode 100755 index 0000000000000..23937c3042b4e --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_target_confidence_assign_op.py @@ -0,0 +1,168 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import math +import sys +import random +from op_test import OpTest + + +class TestTargetConfidenceAssginOp(OpTest): + def set_data(self): + self.init_test_case() + + self.inputs = { + 'Conf': self.conf_data, + 'GTLabels': (self.gt_labels, self.gt_labels_lod), + 'MatchIndices': self.match_indices, + 'NegIndices': (self.neg_indices, self.neg_indices_lod) + } + + self.attrs = {'background_label_id': self.background_label_id} + + self.outputs = { + 'ConfGT': self.conf_gt_data, + 'ConfPred': self.conf_pred_data + } + + def init_test_case(self): + self.batch_size = 5 + self.prior_num = 32 + self.cls_num = 1 + self.gt_labels_num = 10 + self.gt_labels_lod = [[0, 2, 5, 7, 9, 10]] + self.neg_indices_num = 10 + self.neg_indices_lod = [[0, 2, 5, 7, 9, 10]] + + # Only support do_neg_mining=True and loss_type=softmax + self.do_neg_mining = True + self.loss_type = 'softmax' # softmax or logistic + self.background_label_id = 0 + + self.init_input_data() + self.conf_gt_data, self.conf_pred_data = self.calc_confidence_assign() + + def init_input_data(self): + # [batch_size, prior_num, cls_num] + self.conf_data = np.random.random( + (self.batch_size, self.prior_num, self.cls_num)).astype('float32') + + # [gt_labels_num, 1] + self.gt_labels = np.random.random_integers( + 0, high=self.cls_num - 1, + size=(self.gt_labels_num, 1)).astype('int32') + + # match_indices[n, p] = gt_box_index + self.match_indices = np.zeros( + (self.batch_size, self.prior_num)).astype('int32') + + self.neg_indices = np.zeros((self.neg_indices_num, 1)).astype('int32') + + for n in range(self.batch_size): + gt_start = self.gt_labels_lod[0][n] + gt_end = self.gt_labels_lod[0][n + 1] + gt_num = gt_end - gt_start + for p in range(self.prior_num): + self.match_indices[n, p] = random.randint(-1, gt_num - 1) + + neg_start = self.neg_indices_lod[0][n] + neg_end = self.neg_indices_lod[0][n + 1] + for i in range(neg_start, neg_end): + self.neg_indices[i] = random.randint(0, self.prior_num - 1) + + def calc_confidence_assign(self, + do_neg_mining=True, + conf_loss_type='softmax'): + background_label_id = self.background_label_id + target_lod = [0] + count = 0 + + num_matches = 0 + num_negs = self.neg_indices_num + for i in range(self.batch_size): + for j in range(self.prior_num): + if self.match_indices[i, j] != -1: + num_matches += 1 + neg_start = self.neg_indices_lod[0][i] + neg_end = self.neg_indices_lod[0][i + 1] + + if do_neg_mining: + num_conf = num_matches + num_negs + else: + num_conf = self.batch_size * self.prior_num + + if conf_loss_type == 'softmax': + conf_gt_data = np.zeros((num_conf, 1)).astype('int32') + conf_pred_data = np.zeros( + (num_conf, self.cls_num)).astype('float32') + elif conf_loss_type == 'logistic': + conf_gt_data = np.zeros((num_conf, self.cls_num)).astype('int32') + conf_pred_data = np.zeros( + (num_conf, self.cls_num)).astype('float32') + + for i in range(self.batch_size): + for j in range(self.prior_num): + gt_idx = self.match_indices[i, j] + if gt_idx == -1: continue + gt_start = self.gt_labels_lod[0][i] + gt_idx = gt_idx + gt_start + gt_label = self.gt_labels[gt_idx] + + if do_neg_mining: + idx = count + else: + idx = j + + if conf_loss_type == 'softmax': + conf_gt_data[idx] = gt_label + elif conf_loss_type == 'logistic': + conf_gt_data[idx, gt_label] = 1 + + if do_neg_mining: + conf_pred_data[idx, :] = self.conf_data[i, j, :] + count += 1 + + # Go to next image. + if do_neg_mining: + neg_start = self.neg_indices_lod[0][i] + neg_end = self.neg_indices_lod[0][i + 1] + for ne in range(neg_start, neg_end): + idx = self.neg_indices[ne] + conf_pred_data[count, :] = self.conf_data[i, idx, :] + if conf_loss_type == 'softmax': + conf_gt_data[count, 0] = background_label_id + elif conf_loss_type == 'logistic': + conf_gt_data[count, background_label_id] = 1 + count += 1 + + if do_neg_mining: + target_lod.append(count) + + if do_neg_mining: + return (conf_gt_data, [target_lod]), (conf_pred_data, [target_lod]) + else: + return conf_gt_data, conf_pred_data + + def setUp(self): + self.op_type = "target_confidence_assign" + self.set_data() + + def test_check_output(self): + self.check_output() + + +if __name__ == '__main__': + unittest.main()