-
Notifications
You must be signed in to change notification settings - Fork 7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
RetinaNet object detection (take 2) (#2784)
* Add rough implementation of RetinaNet. * Move AnchorGenerator to a seperate file. * Move box similarity to Matcher. * Expose extra blocks in FPN. * Expose retinanet in __init__.py. * Use P6 and P7 in FPN for retinanet. * Use parameters from retinanet for anchor generation. * General fixes for retinanet model. * Implement loss for retinanet heads. * Output reshaped outputs from retinanet heads. * Add postprocessing of detections. * Small fixes. * Remove unused argument. * Remove python2 invocation of super. * Add postprocessing for additional outputs. * Add missing import of ImageList. * Remove redundant import. * Simplify class correction. * Fix pylint warnings. * Remove the label adjustment for background class. * Set default score threshold to 0.05. * Add weight initialization for regression layer. * Allow training on images with no annotations. * Use smooth_l1_loss with beta value. * Add more typehints for TorchScript conversions. * Fix linting issues. * Fix type hints in postprocess_detections. * Fix type annotations for TorchScript. * Fix inconsistency with matched_idxs. * Add retinanet model test. * Add missing JIT annotations. * Remove redundant model construction Make tests pass * Fix bugs during training on newer PyTorch and unused params in DDP Needs cleanup and to add back support for images with no annotations * Cleanup resnet_fpn_backbone * Use L1 loss for regression Gives 1mAP improvement over smooth l1 * Disable support for images with no annotations Need to fix distributed first * Fix retinanet tests Need to deduplicate those box checks * Fix Lint * Add pretrained model * Add training info for retinanet Co-authored-by: Hans Gaiser <[email protected]> Co-authored-by: Hans Gaiser <[email protected]> Co-authored-by: Hans Gaiser <[email protected]>
- Loading branch information
1 parent
42e7f1f
commit 5bb81c8
Showing
13 changed files
with
884 additions
and
169 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .faster_rcnn import * | ||
from .mask_rcnn import * | ||
from .keypoint_rcnn import * | ||
from .retinanet import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | ||
import torch | ||
from torch import nn | ||
|
||
from torch.jit.annotations import List, Optional, Dict | ||
from .image_list import ImageList | ||
|
||
|
||
class AnchorGenerator(nn.Module): | ||
""" | ||
Module that generates anchors for a set of feature maps and | ||
image sizes. | ||
The module support computing anchors at multiple sizes and aspect ratios | ||
per feature map. This module assumes aspect ratio = height / width for | ||
each anchor. | ||
sizes and aspect_ratios should have the same number of elements, and it should | ||
correspond to the number of feature maps. | ||
sizes[i] and aspect_ratios[i] can have an arbitrary number of elements, | ||
and AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors | ||
per spatial location for feature map i. | ||
Arguments: | ||
sizes (Tuple[Tuple[int]]): | ||
aspect_ratios (Tuple[Tuple[float]]): | ||
""" | ||
|
||
__annotations__ = { | ||
"cell_anchors": Optional[List[torch.Tensor]], | ||
"_cache": Dict[str, List[torch.Tensor]] | ||
} | ||
|
||
def __init__( | ||
self, | ||
sizes=((128, 256, 512),), | ||
aspect_ratios=((0.5, 1.0, 2.0),), | ||
): | ||
super(AnchorGenerator, self).__init__() | ||
|
||
if not isinstance(sizes[0], (list, tuple)): | ||
# TODO change this | ||
sizes = tuple((s,) for s in sizes) | ||
if not isinstance(aspect_ratios[0], (list, tuple)): | ||
aspect_ratios = (aspect_ratios,) * len(sizes) | ||
|
||
assert len(sizes) == len(aspect_ratios) | ||
|
||
self.sizes = sizes | ||
self.aspect_ratios = aspect_ratios | ||
self.cell_anchors = None | ||
self._cache = {} | ||
|
||
# TODO: https://github.com/pytorch/pytorch/issues/26792 | ||
# For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values. | ||
# (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios) | ||
# This method assumes aspect ratio = height / width for an anchor. | ||
def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"): | ||
# type: (List[int], List[float], int, Device) -> Tensor # noqa: F821 | ||
scales = torch.as_tensor(scales, dtype=dtype, device=device) | ||
aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device) | ||
h_ratios = torch.sqrt(aspect_ratios) | ||
w_ratios = 1 / h_ratios | ||
|
||
ws = (w_ratios[:, None] * scales[None, :]).view(-1) | ||
hs = (h_ratios[:, None] * scales[None, :]).view(-1) | ||
|
||
base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2 | ||
return base_anchors.round() | ||
|
||
def set_cell_anchors(self, dtype, device): | ||
# type: (int, Device) -> None # noqa: F821 | ||
if self.cell_anchors is not None: | ||
cell_anchors = self.cell_anchors | ||
assert cell_anchors is not None | ||
# suppose that all anchors have the same device | ||
# which is a valid assumption in the current state of the codebase | ||
if cell_anchors[0].device == device: | ||
return | ||
|
||
cell_anchors = [ | ||
self.generate_anchors( | ||
sizes, | ||
aspect_ratios, | ||
dtype, | ||
device | ||
) | ||
for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios) | ||
] | ||
self.cell_anchors = cell_anchors | ||
|
||
def num_anchors_per_location(self): | ||
return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)] | ||
|
||
# For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2), | ||
# output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a. | ||
def grid_anchors(self, grid_sizes, strides): | ||
# type: (List[List[int]], List[List[Tensor]]) -> List[Tensor] | ||
anchors = [] | ||
cell_anchors = self.cell_anchors | ||
assert cell_anchors is not None | ||
assert len(grid_sizes) == len(strides) == len(cell_anchors) | ||
|
||
for size, stride, base_anchors in zip( | ||
grid_sizes, strides, cell_anchors | ||
): | ||
grid_height, grid_width = size | ||
stride_height, stride_width = stride | ||
device = base_anchors.device | ||
|
||
# For output anchor, compute [x_center, y_center, x_center, y_center] | ||
shifts_x = torch.arange( | ||
0, grid_width, dtype=torch.float32, device=device | ||
) * stride_width | ||
shifts_y = torch.arange( | ||
0, grid_height, dtype=torch.float32, device=device | ||
) * stride_height | ||
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) | ||
shift_x = shift_x.reshape(-1) | ||
shift_y = shift_y.reshape(-1) | ||
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1) | ||
|
||
# For every (base anchor, output anchor) pair, | ||
# offset each zero-centered base anchor by the center of the output anchor. | ||
anchors.append( | ||
(shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4) | ||
) | ||
|
||
return anchors | ||
|
||
def cached_grid_anchors(self, grid_sizes, strides): | ||
# type: (List[List[int]], List[List[Tensor]]) -> List[Tensor] | ||
key = str(grid_sizes) + str(strides) | ||
if key in self._cache: | ||
return self._cache[key] | ||
anchors = self.grid_anchors(grid_sizes, strides) | ||
self._cache[key] = anchors | ||
return anchors | ||
|
||
def forward(self, image_list, feature_maps): | ||
# type: (ImageList, List[Tensor]) -> List[Tensor] | ||
grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps]) | ||
image_size = image_list.tensors.shape[-2:] | ||
dtype, device = feature_maps[0].dtype, feature_maps[0].device | ||
strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device), | ||
torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes] | ||
self.set_cell_anchors(dtype, device) | ||
anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides) | ||
anchors = torch.jit.annotate(List[List[torch.Tensor]], []) | ||
for i, (image_height, image_width) in enumerate(image_list.image_sizes): | ||
anchors_in_image = [] | ||
for anchors_per_feature_map in anchors_over_all_feature_maps: | ||
anchors_in_image.append(anchors_per_feature_map) | ||
anchors.append(anchors_in_image) | ||
anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors] | ||
# Clear the cache in case that memory leaks. | ||
self._cache.clear() | ||
return anchors |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.