Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RetinaNet object detection (take 2) #2784

Merged
merged 41 commits into from
Oct 13, 2020
Merged
Show file tree
Hide file tree
Changes from 39 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
50f822c
Add rough implementation of RetinaNet.
hgaiser Dec 20, 2019
022f8e1
Move AnchorGenerator to a seperate file.
May 22, 2020
8e0804d
Move box similarity to Matcher.
hgaiser Jan 24, 2020
ad53194
Expose extra blocks in FPN.
hgaiser Jan 24, 2020
2a5a5be
Expose retinanet in __init__.py.
hgaiser Jan 24, 2020
49e990c
Use P6 and P7 in FPN for retinanet.
hgaiser Jan 24, 2020
b5966eb
Use parameters from retinanet for anchor generation.
hgaiser Jan 24, 2020
aab1b28
General fixes for retinanet model.
hgaiser Jan 24, 2020
c078114
Implement loss for retinanet heads.
hgaiser Jan 31, 2020
eae4ee5
Output reshaped outputs from retinanet heads.
hgaiser Feb 2, 2020
3dac477
Add postprocessing of detections.
hgaiser Feb 7, 2020
9981a3c
Small fixes.
hgaiser Mar 20, 2020
5571dfe
Remove unused argument.
hgaiser Apr 3, 2020
fc7751b
Remove python2 invocation of super.
hgaiser Apr 4, 2020
b942648
Add postprocessing for additional outputs.
hgaiser Apr 4, 2020
b619936
Add missing import of ImageList.
hgaiser Apr 17, 2020
8c86588
Remove redundant import.
hgaiser Apr 17, 2020
2934f0d
Simplify class correction.
hgaiser Apr 17, 2020
32b8e77
Fix pylint warnings.
hgaiser Apr 17, 2020
437bfe9
Remove the label adjustment for background class.
Apr 17, 2020
9e810d6
Set default score threshold to 0.05.
Apr 17, 2020
f7d8c2e
Add weight initialization for regression layer.
Apr 24, 2020
d86c437
Allow training on images with no annotations.
Apr 27, 2020
72e46f2
Use smooth_l1_loss with beta value.
Apr 27, 2020
41c90fa
Add more typehints for TorchScript conversions.
hgaiser May 15, 2020
b9daa86
Fix linting issues.
hgaiser May 15, 2020
97d63b6
Fix type hints in postprocess_detections.
hgaiser May 15, 2020
eba7e16
Fix type annotations for TorchScript.
May 18, 2020
9545059
Fix inconsistency with matched_idxs.
May 22, 2020
4865952
Add retinanet model test.
May 26, 2020
6e065be
Add missing JIT annotations.
Sep 25, 2020
7dc4c6b
Remove redundant model construction
fmassa Oct 10, 2020
640e59b
Fix bugs during training on newer PyTorch and unused params in DDP
fmassa Oct 11, 2020
23cabe3
Cleanup resnet_fpn_backbone
fmassa Oct 11, 2020
214ead7
Use L1 loss for regression
fmassa Oct 13, 2020
44a1333
Disable support for images with no annotations
fmassa Oct 13, 2020
c2f6334
Merge branch 'master' of github.com:pytorch/vision into retinanet
fmassa Oct 13, 2020
e560039
Fix retinanet tests
fmassa Oct 13, 2020
0647732
Fix Lint
fmassa Oct 13, 2020
aa6364f
Add pretrained model
fmassa Oct 13, 2020
9b62169
Add training info for retinanet
fmassa Oct 13, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
1 change: 1 addition & 0 deletions torchvision/models/detection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .faster_rcnn import *
from .mask_rcnn import *
from .keypoint_rcnn import *
from .retinanet import *
159 changes: 159 additions & 0 deletions torchvision/models/detection/anchor_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from torch import nn

from torch.jit.annotations import List, Optional, Dict
from .image_list import ImageList


class AnchorGenerator(nn.Module):
"""
Module that generates anchors for a set of feature maps and
image sizes.

The module support computing anchors at multiple sizes and aspect ratios
per feature map. This module assumes aspect ratio = height / width for
each anchor.

sizes and aspect_ratios should have the same number of elements, and it should
correspond to the number of feature maps.

sizes[i] and aspect_ratios[i] can have an arbitrary number of elements,
and AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors
per spatial location for feature map i.

Arguments:
sizes (Tuple[Tuple[int]]):
aspect_ratios (Tuple[Tuple[float]]):
"""

__annotations__ = {
"cell_anchors": Optional[List[torch.Tensor]],
"_cache": Dict[str, List[torch.Tensor]]
}

def __init__(
self,
sizes=((128, 256, 512),),
aspect_ratios=((0.5, 1.0, 2.0),),
):
super(AnchorGenerator, self).__init__()

if not isinstance(sizes[0], (list, tuple)):
# TODO change this
sizes = tuple((s,) for s in sizes)
if not isinstance(aspect_ratios[0], (list, tuple)):
aspect_ratios = (aspect_ratios,) * len(sizes)

assert len(sizes) == len(aspect_ratios)

self.sizes = sizes
self.aspect_ratios = aspect_ratios
self.cell_anchors = None
self._cache = {}

# TODO: https://github.com/pytorch/pytorch/issues/26792
# For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
# (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
# This method assumes aspect ratio = height / width for an anchor.
def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"):
# type: (List[int], List[float], int, Device) -> Tensor # noqa: F821
scales = torch.as_tensor(scales, dtype=dtype, device=device)
aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
h_ratios = torch.sqrt(aspect_ratios)
w_ratios = 1 / h_ratios

ws = (w_ratios[:, None] * scales[None, :]).view(-1)
hs = (h_ratios[:, None] * scales[None, :]).view(-1)

base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
return base_anchors.round()

def set_cell_anchors(self, dtype, device):
# type: (int, Device) -> None # noqa: F821
if self.cell_anchors is not None:
cell_anchors = self.cell_anchors
assert cell_anchors is not None
# suppose that all anchors have the same device
# which is a valid assumption in the current state of the codebase
if cell_anchors[0].device == device:
return

cell_anchors = [
self.generate_anchors(
sizes,
aspect_ratios,
dtype,
device
)
for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios)
]
self.cell_anchors = cell_anchors

def num_anchors_per_location(self):
return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]

# For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
# output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
def grid_anchors(self, grid_sizes, strides):
# type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
anchors = []
cell_anchors = self.cell_anchors
assert cell_anchors is not None
assert len(grid_sizes) == len(strides) == len(cell_anchors)

for size, stride, base_anchors in zip(
grid_sizes, strides, cell_anchors
):
grid_height, grid_width = size
stride_height, stride_width = stride
device = base_anchors.device

# For output anchor, compute [x_center, y_center, x_center, y_center]
shifts_x = torch.arange(
0, grid_width, dtype=torch.float32, device=device
) * stride_width
shifts_y = torch.arange(
0, grid_height, dtype=torch.float32, device=device
) * stride_height
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)

# For every (base anchor, output anchor) pair,
# offset each zero-centered base anchor by the center of the output anchor.
anchors.append(
(shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)
)

return anchors

def cached_grid_anchors(self, grid_sizes, strides):
# type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
key = str(grid_sizes) + str(strides)
if key in self._cache:
return self._cache[key]
anchors = self.grid_anchors(grid_sizes, strides)
self._cache[key] = anchors
return anchors

def forward(self, image_list, feature_maps):
# type: (ImageList, List[Tensor]) -> List[Tensor]
grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
image_size = image_list.tensors.shape[-2:]
dtype, device = feature_maps[0].dtype, feature_maps[0].device
strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
self.set_cell_anchors(dtype, device)
anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
anchors = torch.jit.annotate(List[List[torch.Tensor]], [])
for i, (image_height, image_width) in enumerate(image_list.image_sizes):
anchors_in_image = []
for anchors_per_feature_map in anchors_over_all_feature_maps:
anchors_in_image.append(anchors_per_feature_map)
anchors.append(anchors_in_image)
anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
# Clear the cache in case that memory leaks.
self._cache.clear()
return anchors
34 changes: 23 additions & 11 deletions torchvision/models/detection/backbone_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,17 @@ class BackboneWithFPN(nn.Module):
Attributes:
out_channels (int): the number of channels in the FPN
"""
def __init__(self, backbone, return_layers, in_channels_list, out_channels):
def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None):
super(BackboneWithFPN, self).__init__()

if extra_blocks is None:
extra_blocks = LastLevelMaxPool()

self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
self.fpn = FeaturePyramidNetwork(
in_channels_list=in_channels_list,
out_channels=out_channels,
extra_blocks=LastLevelMaxPool(),
extra_blocks=extra_blocks,
)
self.out_channels = out_channels

Expand All @@ -41,7 +45,14 @@ def forward(self, x):
return x


def resnet_fpn_backbone(backbone_name, pretrained, norm_layer=misc_nn_ops.FrozenBatchNorm2d, trainable_layers=3):
def resnet_fpn_backbone(
backbone_name,
pretrained,
norm_layer=misc_nn_ops.FrozenBatchNorm2d,
trainable_layers=3,
returned_layers=None,
extra_blocks=None
):
"""
Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.

Expand Down Expand Up @@ -82,14 +93,15 @@ def resnet_fpn_backbone(backbone_name, pretrained, norm_layer=misc_nn_ops.Frozen
if all([not name.startswith(layer) for layer in layers_to_train]):
parameter.requires_grad_(False)

return_layers = {'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'}
if extra_blocks is None:
extra_blocks = LastLevelMaxPool()

if returned_layers is None:
returned_layers = [1, 2, 3, 4]
assert min(returned_layers) > 0 and max(returned_layers) < 5
return_layers = {f'layer{k}': str(v) for v, k in enumerate(returned_layers)}

in_channels_stage2 = backbone.inplanes // 8
in_channels_list = [
in_channels_stage2,
in_channels_stage2 * 2,
in_channels_stage2 * 4,
in_channels_stage2 * 8,
]
in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
out_channels = 256
return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels)
return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
3 changes: 2 additions & 1 deletion torchvision/models/detection/faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@

from ..utils import load_state_dict_from_url

from .anchor_utils import AnchorGenerator
from .generalized_rcnn import GeneralizedRCNN
from .rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
from .rpn import RPNHead, RegionProposalNetwork
from .roi_heads import RoIHeads
from .transform import GeneralizedRCNNTransform
from .backbone_utils import resnet_fpn_backbone
Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/detection/keypoint_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class KeypointRCNN(FasterRCNN):
>>> import torch
>>> import torchvision
>>> from torchvision.models.detection import KeypointRCNN
>>> from torchvision.models.detection.rpn import AnchorGenerator
>>> from torchvision.models.detection.anchor_utils import AnchorGenerator
>>>
>>> # load a pre-trained model for classification and return
>>> # only the features
Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/detection/mask_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class MaskRCNN(FasterRCNN):
>>> import torch
>>> import torchvision
>>> from torchvision.models.detection import MaskRCNN
>>> from torchvision.models.detection.rpn import AnchorGenerator
>>> from torchvision.models.detection.anchor_utils import AnchorGenerator
>>>
>>> # load a pre-trained model for classification and return
>>> # only the features
Expand Down
Loading