From 8ebfd2f5d5f1792ce2cf5a2329320f604530a68e Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 14 Jan 2021 10:45:54 +0000 Subject: [PATCH] Improve speed/accuracy of FasterRCNN by introducing a score threshold on RPN (#3205) * Introduce small score threshold on rpn * Adding docs and fixing keypoint and mask. * Making value 0.0 by default for BC. * Fixing for onnx. * Update threshold. * Removing non-default threshold from reference scripts. Co-authored-by: Francisco Massa --- references/detection/train.py | 7 +++++-- test/test_models_detection_negative_samples.py | 2 +- test/test_onnx.py | 7 ++++++- torchvision/models/detection/_utils.py | 8 ++++++-- torchvision/models/detection/faster_rcnn.py | 6 +++++- torchvision/models/detection/keypoint_rcnn.py | 4 ++++ torchvision/models/detection/mask_rcnn.py | 4 ++++ torchvision/models/detection/rpn.py | 17 +++++++++++++++-- 8 files changed, 46 insertions(+), 9 deletions(-) diff --git a/references/detection/train.py b/references/detection/train.py index d46e832a2ff..f3fe9bc9fff 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -92,8 +92,11 @@ def main(args): collate_fn=utils.collate_fn) print("Creating model") - model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, - pretrained=args.pretrained) + kwargs = {} + if "rcnn" in args.model: + kwargs["rpn_score_thresh"] = 0.0 + model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, pretrained=args.pretrained, + **kwargs) model.to(device) model_without_ddp = model diff --git a/test/test_models_detection_negative_samples.py b/test/test_models_detection_negative_samples.py index 6d767971f72..650a565cdea 100644 --- a/test/test_models_detection_negative_samples.py +++ b/test/test_models_detection_negative_samples.py @@ -44,7 +44,7 @@ def test_targets_to_anchors(self): rpn_anchor_generator, rpn_head, 0.5, 0.3, 256, 0.5, - 2000, 2000, 0.7) + 2000, 2000, 0.7, 0.05) labels, matched_gt_boxes = head.assign_targets_to_anchors(anchors, targets) diff --git a/test/test_onnx.py b/test/test_onnx.py index 975cea7a58f..b2a7624fc61 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -1,3 +1,4 @@ +from common_utils import set_rng_seed import io import torch from torchvision import ops @@ -197,12 +198,14 @@ def _init_test_rpn(self): rpn_pre_nms_top_n = dict(training=2000, testing=1000) rpn_post_nms_top_n = dict(training=2000, testing=1000) rpn_nms_thresh = 0.7 + rpn_score_thresh = 0.0 rpn = RegionProposalNetwork( rpn_anchor_generator, rpn_head, rpn_fg_iou_thresh, rpn_bg_iou_thresh, rpn_batch_size_per_image, rpn_positive_fraction, - rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh) + rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh, + score_thresh=rpn_score_thresh) return rpn def _init_test_roi_heads_faster_rcnn(self): @@ -255,6 +258,8 @@ def get_features(self, images): return features def test_rpn(self): + set_rng_seed(0) + class RPNModule(torch.nn.Module): def __init__(self_module): super(RPNModule, self_module).__init__() diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index b8948d95f82..a3299bcf301 100644 --- a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -173,10 +173,14 @@ def decode(self, rel_codes, boxes): box_sum = 0 for val in boxes_per_image: box_sum += val + if box_sum > 0: + rel_codes = rel_codes.reshape(box_sum, -1) pred_boxes = self.decode_single( - rel_codes.reshape(box_sum, -1), concat_boxes + rel_codes, concat_boxes ) - return pred_boxes.reshape(box_sum, -1, 4) + if box_sum > 0: + pred_boxes = pred_boxes.reshape(box_sum, -1, 4) + return pred_boxes def decode_single(self, rel_codes, boxes): """ diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index 7d896d5ec95..e42680d682d 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -79,6 +79,8 @@ class FasterRCNN(GeneralizedRCNN): for computing the loss rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training of the RPN + rpn_score_thresh (float): during inference, only return proposals with a classification score + greater than rpn_score_thresh box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in the locations indicated by the bounding boxes box_head (nn.Module): module that takes the cropped feature maps as input @@ -153,6 +155,7 @@ def __init__(self, backbone, num_classes=None, rpn_nms_thresh=0.7, rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3, rpn_batch_size_per_image=256, rpn_positive_fraction=0.5, + rpn_score_thresh=0.0, # Box parameters box_roi_pool=None, box_head=None, box_predictor=None, box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100, @@ -197,7 +200,8 @@ def __init__(self, backbone, num_classes=None, rpn_anchor_generator, rpn_head, rpn_fg_iou_thresh, rpn_bg_iou_thresh, rpn_batch_size_per_image, rpn_positive_fraction, - rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh) + rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh, + score_thresh=rpn_score_thresh) if box_roi_pool is None: box_roi_pool = MultiScaleRoIAlign( diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index 44df04819ff..0475994a5a0 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -74,6 +74,8 @@ class KeypointRCNN(FasterRCNN): for computing the loss rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training of the RPN + rpn_score_thresh (float): during inference, only return proposals with a classification score + greater than rpn_score_thresh box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in the locations indicated by the bounding boxes box_head (nn.Module): module that takes the cropped feature maps as input @@ -158,6 +160,7 @@ def __init__(self, backbone, num_classes=None, rpn_nms_thresh=0.7, rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3, rpn_batch_size_per_image=256, rpn_positive_fraction=0.5, + rpn_score_thresh=0.0, # Box parameters box_roi_pool=None, box_head=None, box_predictor=None, box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100, @@ -204,6 +207,7 @@ def __init__(self, backbone, num_classes=None, rpn_nms_thresh, rpn_fg_iou_thresh, rpn_bg_iou_thresh, rpn_batch_size_per_image, rpn_positive_fraction, + rpn_score_thresh, # Box parameters box_roi_pool, box_head, box_predictor, box_score_thresh, box_nms_thresh, box_detections_per_img, diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 565ef05f4cc..4f065f3f917 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -75,6 +75,8 @@ class MaskRCNN(FasterRCNN): for computing the loss rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training of the RPN + rpn_score_thresh (float): during inference, only return proposals with a classification score + greater than rpn_score_thresh box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in the locations indicated by the bounding boxes box_head (nn.Module): module that takes the cropped feature maps as input @@ -158,6 +160,7 @@ def __init__(self, backbone, num_classes=None, rpn_nms_thresh=0.7, rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3, rpn_batch_size_per_image=256, rpn_positive_fraction=0.5, + rpn_score_thresh=0.0, # Box parameters box_roi_pool=None, box_head=None, box_predictor=None, box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100, @@ -204,6 +207,7 @@ def __init__(self, backbone, num_classes=None, rpn_nms_thresh, rpn_fg_iou_thresh, rpn_bg_iou_thresh, rpn_batch_size_per_image, rpn_positive_fraction, + rpn_score_thresh, # Box parameters box_roi_pool, box_head, box_predictor, box_score_thresh, box_nms_thresh, box_detections_per_img, diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py index b50dc08839b..9ea05c94136 100644 --- a/torchvision/models/detection/rpn.py +++ b/torchvision/models/detection/rpn.py @@ -141,7 +141,7 @@ def __init__(self, fg_iou_thresh, bg_iou_thresh, batch_size_per_image, positive_fraction, # - pre_nms_top_n, post_nms_top_n, nms_thresh): + pre_nms_top_n, post_nms_top_n, nms_thresh, score_thresh=0.0): super(RegionProposalNetwork, self).__init__() self.anchor_generator = anchor_generator self.head = head @@ -163,6 +163,7 @@ def __init__(self, self._pre_nms_top_n = pre_nms_top_n self._post_nms_top_n = post_nms_top_n self.nms_thresh = nms_thresh + self.score_thresh = score_thresh self.min_size = 1e-3 def pre_nms_top_n(self): @@ -251,17 +252,29 @@ def filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_ levels = levels[batch_idx, top_n_idx] proposals = proposals[batch_idx, top_n_idx] + objectness_prob = F.sigmoid(objectness) + final_boxes = [] final_scores = [] - for boxes, scores, lvl, img_shape in zip(proposals, objectness, levels, image_shapes): + for boxes, scores, lvl, img_shape in zip(proposals, objectness_prob, levels, image_shapes): boxes = box_ops.clip_boxes_to_image(boxes, img_shape) + + # remove small boxes keep = box_ops.remove_small_boxes(boxes, self.min_size) boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep] + + # remove low scoring boxes + # use >= for Backwards compatibility + keep = torch.where(scores >= self.score_thresh)[0] + boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep] + # non-maximum suppression, independently done per level keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh) + # keep only topk scoring predictions keep = keep[:self.post_nms_top_n()] boxes, scores = boxes[keep], scores[keep] + final_boxes.append(boxes) final_scores.append(scores) return final_boxes, final_scores