Commit
Improve speed/accuracy of FasterRCNN by introducing a score threshold on RPN (#3205)

* Introduce small score threshold on rpn

* Adding docs and fixing keypoint and mask.

* Making value 0.0 by default for backwards compatibility (BC).

* Fixing for onnx.

* Update threshold.

* Removing non-default threshold from reference scripts.

Co-authored-by: Francisco Massa <[email protected]>
datumbox and fmassa authored Jan 14, 2021
1 parent d0063f3 commit 8ebfd2f
Showing 8 changed files with 46 additions and 9 deletions.
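
For context, a minimal usage sketch of the new knob from the user's side (the 0.05 value and the pretrained flag are illustrative, not part of this commit); with the default of 0.0 the behaviour is unchanged:

import torchvision

# Illustrative only: rpn_score_thresh is forwarded to FasterRCNN's constructor,
# where the RPN drops proposals whose objectness probability falls below it.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    pretrained=True,
    rpn_score_thresh=0.05,  # default is 0.0, which keeps the previous behaviour
)
model.eval()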
references/detection/train.py (7 changes: 5 additions & 2 deletions)

@@ -92,8 +92,11 @@ def main(args):
         collate_fn=utils.collate_fn)
 
     print("Creating model")
-    model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes,
-                                                               pretrained=args.pretrained)
+    kwargs = {}
+    if "rcnn" in args.model:
+        kwargs["rpn_score_thresh"] = 0.0
+    model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, pretrained=args.pretrained,
+                                                               **kwargs)
     model.to(device)
 
     model_without_ddp = model
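
The "rcnn" check guards against detection models that have no RPN (for example RetinaNet), which would reject the keyword. A hedged sketch of the same pattern outside the reference script (model name chosen for illustration):

import torchvision

model_name = "retinanet_resnet50_fpn"  # illustrative; this model has no RPN
kwargs = {}
if "rcnn" in model_name:
    kwargs["rpn_score_thresh"] = 0.0  # only the R-CNN family accepts this argument
model = torchvision.models.detection.__dict__[model_name](num_classes=91, pretrained=False, **kwargs)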
test/test_models_detection_negative_samples.py (2 changes: 1 addition & 1 deletion)

@@ -44,7 +44,7 @@ def test_targets_to_anchors(self):
             rpn_anchor_generator, rpn_head,
             0.5, 0.3,
             256, 0.5,
-            2000, 2000, 0.7)
+            2000, 2000, 0.7, 0.05)
 
         labels, matched_gt_boxes = head.assign_targets_to_anchors(anchors, targets)
test/test_onnx.py (7 changes: 6 additions & 1 deletion)

@@ -1,3 +1,4 @@
+from common_utils import set_rng_seed
 import io
 import torch
 from torchvision import ops

@@ -197,12 +198,14 @@ def _init_test_rpn(self):
         rpn_pre_nms_top_n = dict(training=2000, testing=1000)
         rpn_post_nms_top_n = dict(training=2000, testing=1000)
         rpn_nms_thresh = 0.7
+        rpn_score_thresh = 0.0
 
         rpn = RegionProposalNetwork(
             rpn_anchor_generator, rpn_head,
             rpn_fg_iou_thresh, rpn_bg_iou_thresh,
             rpn_batch_size_per_image, rpn_positive_fraction,
-            rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh)
+            rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh,
+            score_thresh=rpn_score_thresh)
         return rpn
 
     def _init_test_roi_heads_faster_rcnn(self):

@@ -255,6 +258,8 @@ def get_features(self, images):
         return features
 
     def test_rpn(self):
+        set_rng_seed(0)
+
         class RPNModule(torch.nn.Module):
             def __init__(self_module):
                 super(RPNModule, self_module).__init__()
torchvision/models/detection/_utils.py (8 changes: 6 additions & 2 deletions)

@@ -173,10 +173,14 @@ def decode(self, rel_codes, boxes):
         box_sum = 0
         for val in boxes_per_image:
             box_sum += val
+        if box_sum > 0:
+            rel_codes = rel_codes.reshape(box_sum, -1)
         pred_boxes = self.decode_single(
-            rel_codes.reshape(box_sum, -1), concat_boxes
+            rel_codes, concat_boxes
         )
-        return pred_boxes.reshape(box_sum, -1, 4)
+        if box_sum > 0:
+            pred_boxes = pred_boxes.reshape(box_sum, -1, 4)
+        return pred_boxes
 
     def decode_single(self, rel_codes, boxes):
         """
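
The new box_sum > 0 guard matters for the empty-proposal path (negative samples, ONNX export): reshaping an empty tensor with a -1 dimension is ambiguous and raises. A small sketch of the failure mode this avoids (shapes are illustrative):

import torch

# Sketch of the failure the guard avoids: with no proposals, box_sum is 0 and a
# reshape with -1 is ambiguous for an empty tensor, so PyTorch raises an error.
rel_codes = torch.empty(0, 4)
box_sum = 0
if box_sum > 0:  # mirrors the guarded path in BoxCoder.decode
    rel_codes = rel_codes.reshape(box_sum, -1)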
torchvision/models/detection/faster_rcnn.py (6 changes: 5 additions & 1 deletion)

@@ -79,6 +79,8 @@ class FasterRCNN(GeneralizedRCNN):
             for computing the loss
         rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
             of the RPN
+        rpn_score_thresh (float): during inference, only return proposals with a classification score
+            greater than rpn_score_thresh
         box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
             the locations indicated by the bounding boxes
         box_head (nn.Module): module that takes the cropped feature maps as input

@@ -153,6 +155,7 @@ def __init__(self, backbone, num_classes=None,
                 rpn_nms_thresh=0.7,
                 rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
                 rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
+                rpn_score_thresh=0.0,
                 # Box parameters
                 box_roi_pool=None, box_head=None, box_predictor=None,
                 box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,

@@ -197,7 +200,8 @@ def __init__(self, backbone, num_classes=None,
             rpn_anchor_generator, rpn_head,
             rpn_fg_iou_thresh, rpn_bg_iou_thresh,
             rpn_batch_size_per_image, rpn_positive_fraction,
-            rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh)
+            rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh,
+            score_thresh=rpn_score_thresh)
 
         if box_roi_pool is None:
             box_roi_pool = MultiScaleRoIAlign(
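
A hedged sketch of passing the new argument when assembling FasterRCNN from a custom backbone, following the pattern already shown in the class docstring (backbone choice and anchor settings are illustrative):

import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator

backbone = torchvision.models.mobilenet_v2(pretrained=True).features
backbone.out_channels = 1280  # FasterRCNN needs the backbone's output channel count

anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
                                   aspect_ratios=((0.5, 1.0, 2.0),))

model = FasterRCNN(backbone,
                   num_classes=2,
                   rpn_anchor_generator=anchor_generator,
                   rpn_score_thresh=0.0)  # new argument; 0.0 preserves the old behaviour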
torchvision/models/detection/keypoint_rcnn.py (4 changes: 4 additions & 0 deletions)

@@ -74,6 +74,8 @@ class KeypointRCNN(FasterRCNN):
             for computing the loss
         rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
             of the RPN
+        rpn_score_thresh (float): during inference, only return proposals with a classification score
+            greater than rpn_score_thresh
         box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
             the locations indicated by the bounding boxes
         box_head (nn.Module): module that takes the cropped feature maps as input

@@ -158,6 +160,7 @@ def __init__(self, backbone, num_classes=None,
                 rpn_nms_thresh=0.7,
                 rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
                 rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
+                rpn_score_thresh=0.0,
                 # Box parameters
                 box_roi_pool=None, box_head=None, box_predictor=None,
                 box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,

@@ -204,6 +207,7 @@ def __init__(self, backbone, num_classes=None,
             rpn_nms_thresh,
             rpn_fg_iou_thresh, rpn_bg_iou_thresh,
             rpn_batch_size_per_image, rpn_positive_fraction,
+            rpn_score_thresh,
             # Box parameters
             box_roi_pool, box_head, box_predictor,
             box_score_thresh, box_nms_thresh, box_detections_per_img,
torchvision/models/detection/mask_rcnn.py (4 changes: 4 additions & 0 deletions)

@@ -75,6 +75,8 @@ class MaskRCNN(FasterRCNN):
             for computing the loss
         rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
             of the RPN
+        rpn_score_thresh (float): during inference, only return proposals with a classification score
+            greater than rpn_score_thresh
         box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
             the locations indicated by the bounding boxes
         box_head (nn.Module): module that takes the cropped feature maps as input

@@ -158,6 +160,7 @@ def __init__(self, backbone, num_classes=None,
                 rpn_nms_thresh=0.7,
                 rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3,
                 rpn_batch_size_per_image=256, rpn_positive_fraction=0.5,
+                rpn_score_thresh=0.0,
                 # Box parameters
                 box_roi_pool=None, box_head=None, box_predictor=None,
                 box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100,

@@ -204,6 +207,7 @@ def __init__(self, backbone, num_classes=None,
             rpn_nms_thresh,
             rpn_fg_iou_thresh, rpn_bg_iou_thresh,
             rpn_batch_size_per_image, rpn_positive_fraction,
+            rpn_score_thresh,
             # Box parameters
             box_roi_pool, box_head, box_predictor,
             box_score_thresh, box_nms_thresh, box_detections_per_img,
torchvision/models/detection/rpn.py (17 changes: 15 additions & 2 deletions)

@@ -141,7 +141,7 @@ def __init__(self,
                 fg_iou_thresh, bg_iou_thresh,
                 batch_size_per_image, positive_fraction,
                 #
-                pre_nms_top_n, post_nms_top_n, nms_thresh):
+                pre_nms_top_n, post_nms_top_n, nms_thresh, score_thresh=0.0):
         super(RegionProposalNetwork, self).__init__()
         self.anchor_generator = anchor_generator
         self.head = head

@@ -163,6 +163,7 @@ def __init__(self,
         self._pre_nms_top_n = pre_nms_top_n
         self._post_nms_top_n = post_nms_top_n
         self.nms_thresh = nms_thresh
+        self.score_thresh = score_thresh
         self.min_size = 1e-3
 
     def pre_nms_top_n(self):

@@ -251,17 +252,29 @@ def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_
         levels = levels[batch_idx, top_n_idx]
         proposals = proposals[batch_idx, top_n_idx]
 
+        objectness_prob = F.sigmoid(objectness)
+
         final_boxes = []
         final_scores = []
-        for boxes, scores, lvl, img_shape in zip(proposals, objectness, levels, image_shapes):
+        for boxes, scores, lvl, img_shape in zip(proposals, objectness_prob, levels, image_shapes):
             boxes = box_ops.clip_boxes_to_image(boxes, img_shape)
+
+            # remove small boxes
             keep = box_ops.remove_small_boxes(boxes, self.min_size)
             boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]
+
+            # remove low scoring boxes
+            # use >= for Backwards compatibility
+            keep = torch.where(scores >= self.score_thresh)[0]
+            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]
+
             # non-maximum suppression, independently done per level
             keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)
+
             # keep only topk scoring predictions
             keep = keep[:self.post_nms_top_n()]
             boxes, scores = boxes[keep], scores[keep]
+
             final_boxes.append(boxes)
             final_scores.append(scores)
         return final_boxes, final_scores
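
For intuition, a standalone sketch of the new filtering step on made-up logits (values and box coordinates are illustrative):

import torch

score_thresh = 0.05
objectness = torch.tensor([4.0, -1.0, 0.2, -6.0])  # raw per-proposal RPN logits
boxes = torch.rand(4, 4)                            # placeholder proposal boxes

scores = torch.sigmoid(objectness)              # logits -> probabilities
keep = torch.where(scores >= score_thresh)[0]   # >= so the 0.0 default keeps everything
boxes, scores = boxes[keep], scores[keep]
print(keep)  # tensor([0, 1, 2]): only the near-zero-probability proposal is dropped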
