Skip to content

Commit

Permalink
add sparsercnn (PaddlePaddle#3623)
Browse files Browse the repository at this point in the history
* add sparsercnn

* update sparsercnn
  • Loading branch information
FL77N authored Jul 7, 2021
1 parent bb84609 commit 841f2f4
Show file tree
Hide file tree
Showing 8 changed files with 1,010 additions and 2 deletions.
27 changes: 26 additions & 1 deletion ppdet/data/transform/batch_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

__all__ = [
'PadBatch', 'BatchRandomResize', 'Gt2YoloTarget', 'Gt2FCOSTarget',
'Gt2TTFTarget', 'Gt2Solov2Target'
'Gt2TTFTarget', 'Gt2Solov2Target', 'Gt2SparseRCNNTarget'
]


Expand Down Expand Up @@ -746,3 +746,28 @@ def __call__(self, samples, context=None):
data['grid_order{}'.format(idx)] = gt_grid_order

return samples


@register_op
class Gt2SparseRCNNTarget(BaseOperator):
'''
Generate SparseRCNN targets by groud truth data
'''

def __init__(self):
super(Gt2SparseRCNNTarget, self).__init__()

def __call__(self, samples, context=None):
for sample in samples:
im = sample["image"]
h, w = im.shape[1:3]
img_whwh = np.array([w, h, w, h], dtype=np.int32)
sample["img_whwh"] = img_whwh
if "scale_factor" in sample:
sample["scale_factor_wh"] = np.array([sample["scale_factor"][1], sample["scale_factor"][0]],
dtype=np.float32)
sample.pop("scale_factor")
else:
sample["scale_factor_wh"] = np.array([1.0, 1.0], dtype=np.float32)

return samples
2 changes: 2 additions & 0 deletions ppdet/modeling/architectures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from . import fairmot
from . import centernet
from . import detr
from . import sparse_rcnn

from .meta_arch import *
from .faster_rcnn import *
Expand All @@ -41,3 +42,4 @@
from .centernet import *
from .blazeface import *
from .detr import *
from .sparse_rcnn import *
99 changes: 99 additions & 0 deletions ppdet/modeling/architectures/sparse_rcnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from ppdet.core.workspace import register, create
from .meta_arch import BaseArch

__all__ = ["SparseRCNN"]


@register
class SparseRCNN(BaseArch):
__category__ = 'architecture'
__inject__ = ["postprocess"]

def __init__(self,
backbone,
neck,
head="SparsercnnHead",
postprocess="SparsePostProcess"):
super(SparseRCNN, self).__init__()
self.backbone = backbone
self.neck = neck
self.head = head
self.postprocess = postprocess

@classmethod
def from_config(cls, cfg, *args, **kwargs):
backbone = create(cfg['backbone'])

kwargs = {'input_shape': backbone.out_shape}
neck = create(cfg['neck'], **kwargs)

kwargs = {'roi_input_shape': neck.out_shape}
head = create(cfg['head'], **kwargs)

return {
'backbone': backbone,
'neck': neck,
"head": head,
}

def _forward(self):
body_feats = self.backbone(self.inputs)
fpn_feats = self.neck(body_feats)
head_outs = self.head(fpn_feats, self.inputs["img_whwh"])

if not self.training:
bboxes = self.postprocess(
head_outs["pred_logits"], head_outs["pred_boxes"],
self.inputs["scale_factor_wh"], self.inputs["img_whwh"])
return bboxes
else:
return head_outs

def get_loss(self):
batch_gt_class = self.inputs["gt_class"]
batch_gt_box = self.inputs["gt_bbox"]
batch_whwh = self.inputs["img_whwh"]
targets = []

for i in range(len(batch_gt_class)):
boxes = batch_gt_box[i]
labels = batch_gt_class[i].squeeze(-1)
img_whwh = batch_whwh[i]
img_whwh_tgt = img_whwh.unsqueeze(0).tile([int(boxes.shape[0]), 1])
targets.append({
"boxes": boxes,
"labels": labels,
"img_whwh": img_whwh,
"img_whwh_tgt": img_whwh_tgt
})

outputs = self._forward()
loss_dict = self.head.get_loss(outputs, targets)
acc = loss_dict["acc"]
loss_dict.pop("acc")
total_loss = sum(loss_dict.values())
loss_dict.update({"loss": total_loss, "acc": acc})
return loss_dict

def get_pred(self):
bbox_pred, bbox_num = self._forward()
output = {'bbox': bbox_pred, 'bbox_num': bbox_num}
return output
2 changes: 2 additions & 0 deletions ppdet/modeling/heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from . import keypoint_hrhrnet_head
from . import centernet_head
from . import detr_head
from . import sparsercnn_head

from .bbox_head import *
from .mask_head import *
Expand All @@ -41,3 +42,4 @@
from .keypoint_hrhrnet_head import *
from .centernet_head import *
from .detr_head import *
from .sparsercnn_head import *
Loading

0 comments on commit 841f2f4

Please sign in to comment.