Skip to content

Commit

Permalink
RoIHeads: remove low-scoring boxes statically
Browse files Browse the repository at this point in the history
One of the postprocessing step of RoIHeads is to filter low-scoring (in
terms of confidence level) boxes from the output. This commit makes that
operation static for TPUs.

For now, it's needed to make it static manually as the experimental
support for dynamic nonzero/masked_select would raise the following
issue in this case:

RuntimeError: Internal: From /job:tpu_worker/replica:0/task:0:
RET_CHECK failure (learning/brain/google/xla/tpu_execute.cc:1066) size >= 0
	 [[{{node XRTExecute}}]]
  • Loading branch information
sprt committed Dec 11, 2019
1 parent 0c464f6 commit 9f7946a
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions torchvision/models/detection/roi_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,8 +595,13 @@ def postprocess_detections(self, class_logits, box_regression, proposals, image_
labels = labels.reshape(-1)

# remove low scoring boxes
inds = torch.nonzero(scores > self.score_thresh).squeeze(1)
boxes, scores, labels = boxes[inds], scores[inds], labels[inds]
score_mask = (scores > self.score_thresh)
score_mask_boxes = score_mask.reshape(boxes.shape[0], -1)
nan_boxes = torch.full((1, boxes.shape[1],), float('nan'), device=device)
boxes = torch.where(score_mask_boxes, boxes, nan_boxes)
# pad with zeros so that nms will consider those boxes last
scores = torch.where(score_mask, scores, torch.tensor(0.0, device=device))
labels = torch.where(score_mask, labels, torch.tensor(-1, device=device))

# remove empty boxes
keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
Expand Down

0 comments on commit 9f7946a

Please sign in to comment.