Skip to content

Commit

Permalink
RetinaNet object detection (take 2) (#2784)
Browse files Browse the repository at this point in the history
* Add rough implementation of RetinaNet.

* Move AnchorGenerator to a seperate file.

* Move box similarity to Matcher.

* Expose extra blocks in FPN.

* Expose retinanet in __init__.py.

* Use P6 and P7 in FPN for retinanet.

* Use parameters from retinanet for anchor generation.

* General fixes for retinanet model.

* Implement loss for retinanet heads.

* Output reshaped outputs from retinanet heads.

* Add postprocessing of detections.

* Small fixes.

* Remove unused argument.

* Remove python2 invocation of super.

* Add postprocessing for additional outputs.

* Add missing import of ImageList.

* Remove redundant import.

* Simplify class correction.

* Fix pylint warnings.

* Remove the label adjustment for background class.

* Set default score threshold to 0.05.

* Add weight initialization for regression layer.

* Allow training on images with no annotations.

* Use smooth_l1_loss with beta value.

* Add more typehints for TorchScript conversions.

* Fix linting issues.

* Fix type hints in postprocess_detections.

* Fix type annotations for TorchScript.

* Fix inconsistency with matched_idxs.

* Add retinanet model test.

* Add missing JIT annotations.

* Remove redundant model construction

Make tests pass

* Fix bugs during training on newer PyTorch and unused params in DDP

Needs cleanup and to add back support for images with no annotations

* Cleanup resnet_fpn_backbone

* Use L1 loss for regression

Gives 1mAP improvement over smooth l1

* Disable support for images with no annotations

Need to fix distributed first

* Fix retinanet tests

Need to deduplicate those box checks

* Fix Lint

* Add pretrained model

* Add training info for retinanet

Co-authored-by: Hans Gaiser <[email protected]>
Co-authored-by: Hans Gaiser <[email protected]>
Co-authored-by: Hans Gaiser <[email protected]>
  • Loading branch information
4 people authored Oct 13, 2020
1 parent 42e7f1f commit 5bb81c8
Show file tree
Hide file tree
Showing 13 changed files with 884 additions and 169 deletions.
8 changes: 8 additions & 0 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ the instances set of COCO train2017 and evaluated on COCO val2017.
Network box AP mask AP keypoint AP
================================ ======= ======== ===========
Faster R-CNN ResNet-50 FPN 37.0 - -
RetinaNet ResNet-50 FPN 36.4 - -
Mask R-CNN ResNet-50 FPN 37.9 34.6 -
================================ ======= ======== ===========

Expand Down Expand Up @@ -405,6 +406,7 @@ precision-recall.
Network train time (s / it) test time (s / it) memory (GB)
============================== =================== ================== ===========
Faster R-CNN ResNet-50 FPN 0.2288 0.0590 5.2
RetinaNet ResNet-50 FPN 0.2514 0.0939 4.1
Mask R-CNN ResNet-50 FPN 0.2728 0.0903 5.4
Keypoint R-CNN ResNet-50 FPN 0.3789 0.1242 6.8
============================== =================== ================== ===========
Expand All @@ -416,6 +418,12 @@ Faster R-CNN
.. autofunction:: torchvision.models.detection.fasterrcnn_resnet50_fpn


RetinaNet
------------

.. autofunction:: torchvision.models.detection.retinanet_resnet50_fpn


Mask R-CNN
----------

Expand Down
7 changes: 7 additions & 0 deletions references/detection/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--lr-steps 16 22 --aspect-ratio-group-factor 3
```

### RetinaNet
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--dataset coco --model retinanet_resnet50_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01
```


### Mask R-CNN
```
Expand Down
Binary file not shown.
1 change: 1 addition & 0 deletions torchvision/models/detection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .faster_rcnn import *
from .mask_rcnn import *
from .keypoint_rcnn import *
from .retinanet import *
159 changes: 159 additions & 0 deletions torchvision/models/detection/anchor_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from torch import nn

from torch.jit.annotations import List, Optional, Dict
from .image_list import ImageList


class AnchorGenerator(nn.Module):
"""
Module that generates anchors for a set of feature maps and
image sizes.
The module support computing anchors at multiple sizes and aspect ratios
per feature map. This module assumes aspect ratio = height / width for
each anchor.
sizes and aspect_ratios should have the same number of elements, and it should
correspond to the number of feature maps.
sizes[i] and aspect_ratios[i] can have an arbitrary number of elements,
and AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors
per spatial location for feature map i.
Arguments:
sizes (Tuple[Tuple[int]]):
aspect_ratios (Tuple[Tuple[float]]):
"""

__annotations__ = {
"cell_anchors": Optional[List[torch.Tensor]],
"_cache": Dict[str, List[torch.Tensor]]
}

def __init__(
self,
sizes=((128, 256, 512),),
aspect_ratios=((0.5, 1.0, 2.0),),
):
super(AnchorGenerator, self).__init__()

if not isinstance(sizes[0], (list, tuple)):
# TODO change this
sizes = tuple((s,) for s in sizes)
if not isinstance(aspect_ratios[0], (list, tuple)):
aspect_ratios = (aspect_ratios,) * len(sizes)

assert len(sizes) == len(aspect_ratios)

self.sizes = sizes
self.aspect_ratios = aspect_ratios
self.cell_anchors = None
self._cache = {}

# TODO: https://github.com/pytorch/pytorch/issues/26792
# For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
# (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
# This method assumes aspect ratio = height / width for an anchor.
def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"):
# type: (List[int], List[float], int, Device) -> Tensor # noqa: F821
scales = torch.as_tensor(scales, dtype=dtype, device=device)
aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
h_ratios = torch.sqrt(aspect_ratios)
w_ratios = 1 / h_ratios

ws = (w_ratios[:, None] * scales[None, :]).view(-1)
hs = (h_ratios[:, None] * scales[None, :]).view(-1)

base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
return base_anchors.round()

def set_cell_anchors(self, dtype, device):
# type: (int, Device) -> None # noqa: F821
if self.cell_anchors is not None:
cell_anchors = self.cell_anchors
assert cell_anchors is not None
# suppose that all anchors have the same device
# which is a valid assumption in the current state of the codebase
if cell_anchors[0].device == device:
return

cell_anchors = [
self.generate_anchors(
sizes,
aspect_ratios,
dtype,
device
)
for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios)
]
self.cell_anchors = cell_anchors

def num_anchors_per_location(self):
return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]

# For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
# output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
def grid_anchors(self, grid_sizes, strides):
# type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
anchors = []
cell_anchors = self.cell_anchors
assert cell_anchors is not None
assert len(grid_sizes) == len(strides) == len(cell_anchors)

for size, stride, base_anchors in zip(
grid_sizes, strides, cell_anchors
):
grid_height, grid_width = size
stride_height, stride_width = stride
device = base_anchors.device

# For output anchor, compute [x_center, y_center, x_center, y_center]
shifts_x = torch.arange(
0, grid_width, dtype=torch.float32, device=device
) * stride_width
shifts_y = torch.arange(
0, grid_height, dtype=torch.float32, device=device
) * stride_height
shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
shift_x = shift_x.reshape(-1)
shift_y = shift_y.reshape(-1)
shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)

# For every (base anchor, output anchor) pair,
# offset each zero-centered base anchor by the center of the output anchor.
anchors.append(
(shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)
)

return anchors

def cached_grid_anchors(self, grid_sizes, strides):
# type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
key = str(grid_sizes) + str(strides)
if key in self._cache:
return self._cache[key]
anchors = self.grid_anchors(grid_sizes, strides)
self._cache[key] = anchors
return anchors

def forward(self, image_list, feature_maps):
# type: (ImageList, List[Tensor]) -> List[Tensor]
grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
image_size = image_list.tensors.shape[-2:]
dtype, device = feature_maps[0].dtype, feature_maps[0].device
strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
self.set_cell_anchors(dtype, device)
anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
anchors = torch.jit.annotate(List[List[torch.Tensor]], [])
for i, (image_height, image_width) in enumerate(image_list.image_sizes):
anchors_in_image = []
for anchors_per_feature_map in anchors_over_all_feature_maps:
anchors_in_image.append(anchors_per_feature_map)
anchors.append(anchors_in_image)
anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
# Clear the cache in case that memory leaks.
self._cache.clear()
return anchors
34 changes: 23 additions & 11 deletions torchvision/models/detection/backbone_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,17 @@ class BackboneWithFPN(nn.Module):
Attributes:
out_channels (int): the number of channels in the FPN
"""
def __init__(self, backbone, return_layers, in_channels_list, out_channels):
def __init__(self, backbone, return_layers, in_channels_list, out_channels, extra_blocks=None):
super(BackboneWithFPN, self).__init__()

if extra_blocks is None:
extra_blocks = LastLevelMaxPool()

self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
self.fpn = FeaturePyramidNetwork(
in_channels_list=in_channels_list,
out_channels=out_channels,
extra_blocks=LastLevelMaxPool(),
extra_blocks=extra_blocks,
)
self.out_channels = out_channels

Expand All @@ -41,7 +45,14 @@ def forward(self, x):
return x


def resnet_fpn_backbone(backbone_name, pretrained, norm_layer=misc_nn_ops.FrozenBatchNorm2d, trainable_layers=3):
def resnet_fpn_backbone(
backbone_name,
pretrained,
norm_layer=misc_nn_ops.FrozenBatchNorm2d,
trainable_layers=3,
returned_layers=None,
extra_blocks=None
):
"""
Constructs a specified ResNet backbone with FPN on top. Freezes the specified number of layers in the backbone.
Expand Down Expand Up @@ -82,14 +93,15 @@ def resnet_fpn_backbone(backbone_name, pretrained, norm_layer=misc_nn_ops.Frozen
if all([not name.startswith(layer) for layer in layers_to_train]):
parameter.requires_grad_(False)

return_layers = {'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'}
if extra_blocks is None:
extra_blocks = LastLevelMaxPool()

if returned_layers is None:
returned_layers = [1, 2, 3, 4]
assert min(returned_layers) > 0 and max(returned_layers) < 5
return_layers = {f'layer{k}': str(v) for v, k in enumerate(returned_layers)}

in_channels_stage2 = backbone.inplanes // 8
in_channels_list = [
in_channels_stage2,
in_channels_stage2 * 2,
in_channels_stage2 * 4,
in_channels_stage2 * 8,
]
in_channels_list = [in_channels_stage2 * 2 ** (i - 1) for i in returned_layers]
out_channels = 256
return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels)
return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
3 changes: 2 additions & 1 deletion torchvision/models/detection/faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@

from ..utils import load_state_dict_from_url

from .anchor_utils import AnchorGenerator
from .generalized_rcnn import GeneralizedRCNN
from .rpn import AnchorGenerator, RPNHead, RegionProposalNetwork
from .rpn import RPNHead, RegionProposalNetwork
from .roi_heads import RoIHeads
from .transform import GeneralizedRCNNTransform
from .backbone_utils import resnet_fpn_backbone
Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/detection/keypoint_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ class KeypointRCNN(FasterRCNN):
>>> import torch
>>> import torchvision
>>> from torchvision.models.detection import KeypointRCNN
>>> from torchvision.models.detection.rpn import AnchorGenerator
>>> from torchvision.models.detection.anchor_utils import AnchorGenerator
>>>
>>> # load a pre-trained model for classification and return
>>> # only the features
Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/detection/mask_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class MaskRCNN(FasterRCNN):
>>> import torch
>>> import torchvision
>>> from torchvision.models.detection import MaskRCNN
>>> from torchvision.models.detection.rpn import AnchorGenerator
>>> from torchvision.models.detection.anchor_utils import AnchorGenerator
>>>
>>> # load a pre-trained model for classification and return
>>> # only the features
Expand Down
Loading

0 comments on commit 5bb81c8

Please sign in to comment.