Vectorize RetinaNet's postprocessing (pytorch#2828)
* Vectorize operations across all feature levels.

* Remove unnecessary other_outputs variable.

* Split per feature level.

* Perform batched_nms across feature levels.

* Add extra parameter for limiting detections before and after NMS.

* Restoring default threshold.

* Apply suggestions from code review

Co-authored-by: Francisco Massa <[email protected]>

* Renaming variable.

Co-authored-by: Francisco Massa <[email protected]>
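
The commits above boil down to one trick: instead of looping over classes, the per-level class logits are flattened so that thresholding and top-k selection happen in single tensor operations. A standalone sketch of that selection (shapes are illustrative; `score_thresh=0.05` and `topk_candidates=1000` are the defaults from the diff below):

```python
import torch

# One feature level with 6 anchors and 4 classes; shapes are illustrative.
logits_per_level = torch.randn(6, 4)
score_thresh, topk_candidates = 0.05, 1000
num_classes = logits_per_level.shape[-1]

# Flatten [num_anchors, num_classes] -> [num_anchors * num_classes] and
# drop low-scoring entries in one shot instead of looping per class.
scores_per_level = torch.sigmoid(logits_per_level).flatten()
keep_idxs = scores_per_level > score_thresh
scores_per_level = scores_per_level[keep_idxs]
topk_idxs = torch.where(keep_idxs)[0]

# Keep at most topk_candidates detections per level before NMS.
num_topk = min(topk_candidates, topk_idxs.size(0))
scores_per_level, idxs = scores_per_level.topk(num_topk)
topk_idxs = topk_idxs[idxs]

# The flattened index is anchor_index * num_classes + class_index, so
# integer division and modulo recover both components.
anchor_idxs = topk_idxs // num_classes
labels_per_level = topk_idxs % num_classes
```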
2 people authored and vfdev-5 committed Dec 4, 2020
1 parent e13c1ea commit a35c130
Showing 3 changed files with 59 additions and 54 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -22,5 +22,6 @@ htmlcov
 gen.yml
 .mypy_cache
 .vscode/
+.idea/
 *.orig
 *-checkpoint.ipynb
Binary file modified test/expect/ModelTester.test_retinanet_resnet50_fpn_expect.pkl (not shown)
112 changes: 58 additions & 54 deletions torchvision/models/detection/retinanet.py
@@ -291,6 +291,7 @@ class RetinaNet(nn.Module):
             considered as positive during training.
         bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
             considered as negative during training.
+        topk_candidates (int): Number of best detections to keep before NMS.
 
     Example:
@@ -339,7 +340,8 @@ def __init__(self, backbone, num_classes,
                  score_thresh=0.05,
                  nms_thresh=0.5,
                  detections_per_img=300,
-                 fg_iou_thresh=0.5, bg_iou_thresh=0.4):
+                 fg_iou_thresh=0.5, bg_iou_thresh=0.4,
+                 topk_candidates=1000):
         super().__init__()
 
         if not hasattr(backbone, "out_channels"):
@@ -382,6 +384,7 @@ def __init__(self, backbone, num_classes,
         self.score_thresh = score_thresh
         self.nms_thresh = nms_thresh
         self.detections_per_img = detections_per_img
+        self.topk_candidates = topk_candidates
 
         # used only on torchscript mode
         self._has_warned = False
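
A hedged usage sketch of the new knob: `topk_candidates` reaches the `RetinaNet` constructor through the `retinanet_resnet50_fpn` builder's keyword arguments, next to the existing postprocessing thresholds (weights are disabled and the input is random, purely for illustration):

```python
import torch
from torchvision.models.detection import retinanet_resnet50_fpn

# All postprocessing knobs are forwarded to RetinaNet.__init__ via **kwargs;
# the values below are the defaults shown in this diff.
model = retinanet_resnet50_fpn(pretrained=False, pretrained_backbone=False,
                               num_classes=91,
                               score_thresh=0.05, nms_thresh=0.5,
                               detections_per_img=300, topk_candidates=1000)
model.eval()
with torch.no_grad():
    # Returns one dict per image with 'boxes', 'scores' and 'labels'.
    detections = model([torch.rand(3, 320, 320)])
```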
@@ -408,77 +411,63 @@ def compute_loss(self, targets, head_outputs, anchors):
         return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)
 
     def postprocess_detections(self, head_outputs, anchors, image_shapes):
-        # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
-        # TODO: Merge this with roi_heads.RoIHeads.postprocess_detections ?
-
-        class_logits = head_outputs.pop('cls_logits')
-        box_regression = head_outputs.pop('bbox_regression')
-        other_outputs = head_outputs
-
-        device = class_logits.device
-        num_classes = class_logits.shape[-1]
-
-        scores = torch.sigmoid(class_logits)
-
-        # create labels for each score
-        labels = torch.arange(num_classes, device=device)
-        labels = labels.view(1, -1).expand_as(scores)
+        # type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
+        class_logits = head_outputs['cls_logits']
+        box_regression = head_outputs['bbox_regression']
+
+        num_images = len(image_shapes)
 
         detections = torch.jit.annotate(List[Dict[str, Tensor]], [])
 
-        for index, (box_regression_per_image, scores_per_image, labels_per_image, anchors_per_image, image_shape) in \
-                enumerate(zip(box_regression, scores, labels, anchors, image_shapes)):
-
-            boxes_per_image = self.box_coder.decode_single(box_regression_per_image, anchors_per_image)
-            boxes_per_image = box_ops.clip_boxes_to_image(boxes_per_image, image_shape)
-
-            other_outputs_per_image = [(k, v[index]) for k, v in other_outputs.items()]
+        for index in range(num_images):
+            box_regression_per_image = [br[index] for br in box_regression]
+            logits_per_image = [cl[index] for cl in class_logits]
+            anchors_per_image, image_shape = anchors[index], image_shapes[index]
 
             image_boxes = []
             image_scores = []
             image_labels = []
-            image_other_outputs = torch.jit.annotate(Dict[str, List[Tensor]], {})
 
-            for class_index in range(num_classes):
+            for box_regression_per_level, logits_per_level, anchors_per_level in \
+                    zip(box_regression_per_image, logits_per_image, anchors_per_image):
+                num_classes = logits_per_level.shape[-1]
+
                 # remove low scoring boxes
-                inds = torch.gt(scores_per_image[:, class_index], self.score_thresh)
-                boxes_per_class, scores_per_class, labels_per_class = \
-                    boxes_per_image[inds], scores_per_image[inds, class_index], labels_per_image[inds, class_index]
-                other_outputs_per_class = [(k, v[inds]) for k, v in other_outputs_per_image]
+                scores_per_level = torch.sigmoid(logits_per_level).flatten()
+                keep_idxs = scores_per_level > self.score_thresh
+                scores_per_level = scores_per_level[keep_idxs]
+                topk_idxs = torch.where(keep_idxs)[0]
 
-                # remove empty boxes
-                keep = box_ops.remove_small_boxes(boxes_per_class, min_size=1e-2)
-                boxes_per_class, scores_per_class, labels_per_class = \
-                    boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep]
-                other_outputs_per_class = [(k, v[keep]) for k, v in other_outputs_per_class]
+                # keep only topk scoring predictions
+                num_topk = min(self.topk_candidates, topk_idxs.size(0))
+                scores_per_level, idxs = scores_per_level.topk(num_topk)
+                topk_idxs = topk_idxs[idxs]
 
-                # non-maximum suppression, independently done per class
-                keep = box_ops.nms(boxes_per_class, scores_per_class, self.nms_thresh)
+                anchor_idxs = topk_idxs // num_classes
+                labels_per_level = topk_idxs % num_classes
 
-                # keep only topk scoring predictions
-                keep = keep[:self.detections_per_img]
-                boxes_per_class, scores_per_class, labels_per_class = \
-                    boxes_per_class[keep], scores_per_class[keep], labels_per_class[keep]
-                other_outputs_per_class = [(k, v[keep]) for k, v in other_outputs_per_class]
+                boxes_per_level = self.box_coder.decode_single(box_regression_per_level[anchor_idxs],
+                                                               anchors_per_level[anchor_idxs])
+                boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)
 
-                image_boxes.append(boxes_per_class)
-                image_scores.append(scores_per_class)
-                image_labels.append(labels_per_class)
+                image_boxes.append(boxes_per_level)
+                image_scores.append(scores_per_level)
+                image_labels.append(labels_per_level)
 
-                for k, v in other_outputs_per_class:
-                    if k not in image_other_outputs:
-                        image_other_outputs[k] = []
-                    image_other_outputs[k].append(v)
+            image_boxes = torch.cat(image_boxes, dim=0)
+            image_scores = torch.cat(image_scores, dim=0)
+            image_labels = torch.cat(image_labels, dim=0)
+
+            # non-maximum suppression
+            keep = box_ops.batched_nms(image_boxes, image_scores, image_labels, self.nms_thresh)
+            keep = keep[:self.detections_per_img]
 
             detections.append({
-                'boxes': torch.cat(image_boxes, dim=0),
-                'scores': torch.cat(image_scores, dim=0),
-                'labels': torch.cat(image_labels, dim=0),
+                'boxes': image_boxes[keep],
+                'scores': image_scores[keep],
+                'labels': image_labels[keep],
             })
 
-            for k, v in image_other_outputs.items():
-                detections[-1].update({k: torch.cat(v, dim=0)})
-
         return detections
 
     def forward(self, images, targets=None):
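
The per-class `box_ops.nms` loop removed above becomes a single `batched_nms` call; internally torchvision offsets box coordinates by label so detections of different classes never suppress each other. A minimal sketch with hand-made boxes (coordinates and scores are illustrative):

```python
import torch
from torchvision.ops import batched_nms

# Two heavily overlapping boxes of class 0 and one identical box of class 1,
# in (x1, y1, x2, y2) format.
boxes = torch.tensor([[0., 0., 10., 10.],
                      [1., 1., 11., 11.],
                      [0., 0., 10., 10.]])
scores = torch.tensor([0.9, 0.8, 0.7])
labels = torch.tensor([0, 0, 1])

# Box 1 is suppressed by box 0 (same class, IoU ~0.68 > 0.5); box 2 survives
# because it belongs to a different class. keep == tensor([0, 2]).
keep = batched_nms(boxes, scores, labels, iou_threshold=0.5)
keep = keep[:300]  # the detections_per_img cap applied in the diff
```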
@@ -557,8 +546,23 @@ def forward(self, images, targets=None):
             # compute the losses
             losses = self.compute_loss(targets, head_outputs, anchors)
         else:
+            # recover level sizes
+            num_anchors_per_level = [x.size(2) * x.size(3) for x in features]
+            HW = 0
+            for v in num_anchors_per_level:
+                HW += v
+            HWA = head_outputs['cls_logits'].size(1)
+            A = HWA // HW
+            num_anchors_per_level = [hw * A for hw in num_anchors_per_level]
+
+            # split outputs per level
+            split_head_outputs: Dict[str, List[Tensor]] = {}
+            for k in head_outputs:
+                split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1))
+            split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors]
+
             # compute the detections
-            detections = self.postprocess_detections(head_outputs, anchors, images.image_sizes)
+            detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes)
             detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
 
             if torch.jit.is_scripting():
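
The level-size recovery above works because the head concatenates its per-level predictions along dimension 1, so `HWA = sum(H_i * W_i) * A`. A small sketch with made-up feature sizes shows the recovery of `A` and the per-level split:

```python
import torch

# Two feature levels of 4x4 and 2x2 cells, A = 9 anchors per cell,
# batch of 2 images, 91 classes; all sizes are illustrative.
feature_hw = [(4, 4), (2, 2)]
A, N, K = 9, 2, 91
num_anchors_per_level = [h * w for h, w in feature_hw]

HW = sum(num_anchors_per_level)          # total cells across all levels
cls_logits = torch.randn(N, HW * A, K)   # stands in for head_outputs['cls_logits']
HWA = cls_logits.size(1)
A_recovered = HWA // HW                  # recovers A == 9
num_anchors_per_level = [hw * A_recovered for hw in num_anchors_per_level]

# Split the concatenated predictions back into one tensor per level.
per_level = list(cls_logits.split(num_anchors_per_level, dim=1))
assert [t.size(1) for t in per_level] == [144, 36]
```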
