
Commit
Split per feature level.
datumbox committed Oct 18, 2020
1 parent 1001612 commit f131bbe
Showing 2 changed files with 55 additions and 27 deletions.
Binary file modified test/expect/ModelTester.test_retinanet_resnet50_fpn_expect.pkl
Binary file not shown.
82 changes: 55 additions & 27 deletions torchvision/models/detection/retinanet.py
@@ -408,44 +408,56 @@ def compute_loss(self, targets, head_outputs, anchors):
         return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs)
 
     def postprocess_detections(self, head_outputs, anchors, image_shapes):
-        # type: (Dict[str, Tensor], List[Tensor], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
-        # TODO: confirm that RetinaNet can't have other outputs like masks
+        # type: (Dict[str, List[Tensor]], List[List[Tensor]], List[Tuple[int, int]]) -> List[Dict[str, Tensor]]
         class_logits = head_outputs['cls_logits']
         box_regression = head_outputs['bbox_regression']
 
-        num_classes = class_logits.shape[-1]
-
-        scores = torch.sigmoid(class_logits)
+        num_images = len(image_shapes)
 
         detections = torch.jit.annotate(List[Dict[str, Tensor]], [])
 
-        for index, (box_regression_per_image, scores_per_image, anchors_per_image, image_shape) in \
-                enumerate(zip(box_regression, scores, anchors, image_shapes)):
-            # remove low scoring boxes
-            scores_per_image = scores_per_image.flatten()
-            keep_idxs = scores_per_image > self.score_thresh
-            scores_per_image = scores_per_image[keep_idxs]
-            topk_idxs = torch.where(keep_idxs)[0]
+        for index in range(num_images):
+            box_regression_per_image = [br[index] for br in box_regression]
+            logits_per_image = [cl[index] for cl in class_logits]
+            anchors_per_image, image_shape = anchors[index], image_shapes[index]
+
+            image_boxes = []
+            image_scores = []
+            image_labels = []
+
+            for box_regression_per_level, logits_per_level, anchors_per_level in \
+                    zip(box_regression_per_image, logits_per_image, anchors_per_image):
+                num_classes = logits_per_level.shape[-1]
 
-            # keep only topk scoring predictions
-            num_topk = min(self.detections_per_img, topk_idxs.size(0))
-            scores_per_image, idxs = scores_per_image.topk(num_topk)
-            topk_idxs = topk_idxs[idxs]
+                # remove low scoring boxes
+                scores_per_level = torch.sigmoid(logits_per_level).flatten()
+                keep_idxs = scores_per_level > self.score_thresh
+                scores_per_level = scores_per_level[keep_idxs]
+                topk_idxs = torch.where(keep_idxs)[0]
 
-            anchor_idxs = topk_idxs // num_classes
-            labels_per_image = topk_idxs % num_classes
+                # keep only topk scoring predictions
+                num_topk = min(self.detections_per_img, topk_idxs.size(0))
+                scores_per_level, idxs = scores_per_level.topk(num_topk)
+                topk_idxs = topk_idxs[idxs]
 
-            boxes_per_image = self.box_coder.decode_single(box_regression_per_image[anchor_idxs],
-                                                           anchors_per_image[anchor_idxs])
-            boxes_per_image = box_ops.clip_boxes_to_image(boxes_per_image, image_shape)
+                anchor_idxs = topk_idxs // num_classes
+                labels_per_level = topk_idxs % num_classes
 
-            # non-maximum suppression
-            keep = box_ops.batched_nms(boxes_per_image, scores_per_image, labels_per_image, self.nms_thresh)
+                boxes_per_level = self.box_coder.decode_single(box_regression_per_level[anchor_idxs],
+                                                               anchors_per_level[anchor_idxs])
+                boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape)
+
+                # non-maximum suppression
+                keep = box_ops.batched_nms(boxes_per_level, scores_per_level, labels_per_level, self.nms_thresh)
+
+                image_boxes.append(boxes_per_level[keep])
+                image_scores.append(scores_per_level[keep])
+                image_labels.append(labels_per_level[keep])
 
             detections.append({
-                'boxes': boxes_per_image[keep],
-                'scores': scores_per_image[keep],
-                'labels': labels_per_image[keep],
+                'boxes': torch.cat(image_boxes, dim=0),
+                'scores': torch.cat(image_scores, dim=0),
+                'labels': torch.cat(image_labels, dim=0),
             })
 
         return detections
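
For readers following the hunk above: candidate selection now happens independently per FPN level (sigmoid, score threshold, per-level top-k, box decoding, per-level batched NMS), and the per-level survivors are concatenated into one result per image. Below is a minimal, self-contained sketch of that flow on dummy tensors; the level sizes, thresholds, and the stand-in box decoding are illustrative assumptions, not values from this commit.

import torch
from torchvision.ops import batched_nms

# Illustrative constants (assumptions, not taken from the commit).
score_thresh = 0.05
detections_per_img = 300
nms_thresh = 0.5
num_classes = 91

torch.manual_seed(0)
image_boxes, image_scores, image_labels = [], [], []
for num_anchors in (2000, 500):  # two fake FPN levels
    logits_per_level = torch.randn(num_anchors, num_classes)  # (HWA_level, K)
    boxes_per_level = torch.rand(num_anchors, 4) * 100        # stand-in for decoded boxes
    boxes_per_level[:, 2:] += boxes_per_level[:, :2]          # ensure x2 > x1 and y2 > y1

    # remove low scoring boxes
    scores_per_level = torch.sigmoid(logits_per_level).flatten()
    keep_idxs = scores_per_level > score_thresh
    scores_per_level = scores_per_level[keep_idxs]
    topk_idxs = torch.where(keep_idxs)[0]

    # keep only topk scoring predictions within the level
    num_topk = min(detections_per_img, topk_idxs.size(0))
    scores_per_level, idxs = scores_per_level.topk(num_topk)
    topk_idxs = topk_idxs[idxs]

    # a flat index over (anchor, class) decomposes into both parts
    anchor_idxs = topk_idxs // num_classes
    labels_per_level = topk_idxs % num_classes
    boxes_per_level = boxes_per_level[anchor_idxs]

    # class-aware NMS within the level, mirroring the diff
    keep = batched_nms(boxes_per_level, scores_per_level, labels_per_level, nms_thresh)
    image_boxes.append(boxes_per_level[keep])
    image_scores.append(scores_per_level[keep])
    image_labels.append(labels_per_level[keep])

detections = {
    'boxes': torch.cat(image_boxes, dim=0),
    'scores': torch.cat(image_scores, dim=0),
    'labels': torch.cat(image_labels, dim=0),
}
print({k: tuple(v.shape) for k, v in detections.items()})

One consequence of this layout, visible in the hunk, is that the score threshold and top-k cap now apply per level rather than once across all levels, so more candidates can reach NMS than before.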
@@ -526,8 +538,24 @@ def forward(self, images, targets=None):
             # compute the losses
             losses = self.compute_loss(targets, head_outputs, anchors)
         else:
+            # recover level sizes
+            feature_sizes_per_level = [x.size(2) * x.size(3) for x in features]
+            HW = 0
+            for v in feature_sizes_per_level:
+                HW += v
+            HWA = head_outputs['cls_logits'].size(1)
+            A = HWA // HW
+            feature_sizes_per_level = [hw * A for hw in feature_sizes_per_level]
+
+            # split outputs per level
+            split_head_outputs: Dict[str, List[Tensor]] = {}
+            for k in head_outputs:
+                split_head_outputs[k] = [x.permute(1, 0, 2) for x in
+                                         head_outputs[k].permute(1, 0, 2).split_with_sizes(feature_sizes_per_level)]
+            split_anchors = [list(a.split_with_sizes(feature_sizes_per_level)) for a in anchors]
+
             # compute the detections
-            detections = self.postprocess_detections(head_outputs, anchors, images.image_sizes)
+            detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes)
             detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
 
         if torch.jit.is_scripting():
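
A note on the arithmetic in the "recover level sizes" block above: the head concatenates its per-level outputs along dimension 1, so cls_logits has shape (N, sum(H_l * W_l) * A, K), where A is the number of anchors per spatial location. Dividing the concatenated length HWA by HW = sum(H_l * W_l) therefore recovers A, which in turn gives the split sizes per level. A small sketch with made-up shapes (the sizes below are assumptions for illustration):

import torch

# Made-up setup: batch of 2 images, A = 9 anchors per location, K = 91
# classes, and two FPN levels of 32x32 and 16x16.
N, A, K = 2, 9, 91
feature_sizes_per_level = [32 * 32, 16 * 16]            # H*W per level
HW = sum(feature_sizes_per_level)                       # 1280
cls_logits = torch.zeros(N, HW * A, K)                  # concatenated head output

HWA = cls_logits.size(1)                                # 11520
A_recovered = HWA // HW                                 # 9, matches A
sizes = [hw * A_recovered for hw in feature_sizes_per_level]  # [9216, 2304]

# split back into per-level tensors, as the diff does
split = [x.permute(1, 0, 2)
         for x in cls_logits.permute(1, 0, 2).split_with_sizes(sizes)]
print([tuple(t.shape) for t in split])  # [(2, 9216, 91), (2, 2304, 91)]

postprocess_detections then indexes each per-level tensor by image, which is what its new Dict[str, List[Tensor]] signature expects.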
