Commit: Add rough implementation of RetinaNet.
Showing 1 changed file with 371 additions and 0 deletions.
@@ -0,0 +1,371 @@
from collections import OrderedDict
import warnings

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple

from ..utils import load_state_dict_from_url

from . import _utils as det_utils
from .rpn import AnchorGenerator
from .transform import GeneralizedRCNNTransform
from .backbone_utils import resnet_fpn_backbone

__all__ = [
    "RetinaNet", "retinanet_resnet50_fpn",
]

class RetinaNetHead(nn.Module):
    """
    A regression and classification head for use in RetinaNet.

    Arguments:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        num_classes (int): number of classes to be predicted
    """

    def __init__(self, in_channels, num_anchors, num_classes):
        super(RetinaNetHead, self).__init__()
        self.classification_head = RetinaNetClassificationHead(in_channels, num_anchors, num_classes)
        self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors)

    def compute_loss(self, outputs, labels, matched_gt_boxes):
        return {
            'classification': self.classification_head.compute_loss(outputs, labels, matched_gt_boxes),
            'regression': self.regression_head.compute_loss(outputs, labels, matched_gt_boxes),
        }

    def forward(self, x):
        # the sub-heads iterate over the list of feature maps themselves,
        # so the whole list is passed through once
        logits = self.classification_head(x)
        bbox_reg = self.regression_head(x)
        return dict(logits=logits, bbox_reg=bbox_reg)

class RetinaNetClassificationHead(nn.Module):
    """
    A classification head for use in RetinaNet.

    Arguments:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        num_classes (int): number of classes to be predicted
    """

    def __init__(self, in_channels, num_anchors, num_classes):
        super(RetinaNetClassificationHead, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        # padding=1 so the logits keep the spatial size of the feature map,
        # staying aligned with the anchor grid
        self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)

        for l in self.children():
            torch.nn.init.normal_(l.weight, std=0.01)
            torch.nn.init.constant_(l.bias, 0)

    def compute_loss(self, outputs, labels, matched_gt_boxes):
        # TODO Implement focal loss, is there an existing function for this?
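        # A minimal sketch of the sigmoid focal loss from the RetinaNet paper
        # (Lin et al., 2017), assuming per-anchor logits and one-hot targets
        # with ignored anchors already filtered out; the names `gamma`,
        # `alpha` and `num_foreground` are hypothetical, not part of this API:
        #
        #     p = torch.sigmoid(logits)
        #     ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        #     p_t = p * targets + (1 - p) * (1 - targets)
        #     alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        #     loss = (alpha_t * (1 - p_t) ** gamma * ce).sum() / max(1, num_foreground)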
        return 0

    def forward(self, x):
        logits = []
        for feature in x:
            t = F.relu(self.conv1(feature))
            t = F.relu(self.conv2(t))
            t = F.relu(self.conv3(t))
            t = F.relu(self.conv4(t))
            logits.append(self.cls_logits(t))
        return logits

class RetinaNetRegressionHead(nn.Module):
    """
    A regression head for use in RetinaNet.

    Arguments:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
    """

    def __init__(self, in_channels, num_anchors):
        super(RetinaNetRegressionHead, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        # padding=1 keeps the regression map aligned with the anchor grid
        self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)

        for l in self.children():
            torch.nn.init.normal_(l.weight, std=0.01)
            torch.nn.init.constant_(l.bias, 0)

    def compute_loss(self, outputs, labels, matched_gt_boxes):
        # TODO Use SmoothL1 loss for regression, or just L1 like in rpn.py ?
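        # A minimal sketch, assuming the regression targets have already been
        # encoded against the anchors and foreground anchors selected; the
        # names `regression_targets`, `fg_idxs` and `num_foreground` are
        # hypothetical, mirroring the structure of rpn.py:
        #
        #     loss = F.smooth_l1_loss(outputs[fg_idxs], regression_targets[fg_idxs],
        #                             reduction='sum') / max(1, num_foreground)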
        return 0

    def forward(self, x):
        bbox_reg = []
        for feature in x:
            t = F.relu(self.conv1(feature))
            t = F.relu(self.conv2(t))
            t = F.relu(self.conv3(t))
            t = F.relu(self.conv4(t))
            bbox_reg.append(self.bbox_reg(t))
        return bbox_reg

class RetinaNet(nn.Module):
    """
    Implements RetinaNet.

    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
    image, and should be in 0-1 range. Different images can have different sizes.

    The behavior of the model changes depending on whether it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionaries),
    containing:
        - boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format, with values
          between 0 and H and 0 and W
        - labels (Int64Tensor[N]): the class label for each ground-truth box

    The model returns a Dict[Tensor] during training, containing the classification and regression
    losses.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
    follows:
        - boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values between
          0 and H and 0 and W
        - labels (Int64Tensor[N]): the predicted labels for each image
        - scores (Tensor[N]): the scores for each prediction

    Arguments:
        backbone (nn.Module): the network used to compute the features for the model.
            It should contain an out_channels attribute, which indicates the number of output
            channels that each feature map has (and it should be the same for all feature maps).
            The backbone should return a single Tensor or an OrderedDict[Tensor].
        num_classes (int): number of output classes of the model (excluding the background).
        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
        image_mean (Tuple[float, float, float]): mean values used for input normalization.
            They are generally the mean values of the dataset on which the backbone has been
            trained on
        image_std (Tuple[float, float, float]): std values used for input normalization.
            They are generally the std values of the dataset on which the backbone has been trained on
        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
            maps.
        head (nn.Module): Module run on top of the feature pyramid.
            Defaults to a module containing a classification and regression module.
        pre_nms_top_n (int): number of proposals to keep before applying NMS during testing.
        post_nms_top_n (int): number of proposals to keep after applying NMS during testing.
        nms_thresh (float): NMS threshold used for postprocessing the detections.
        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
            considered as positive during training.
        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
            considered as negative during training.

    Example::

        >>> import torch
        >>> import torchvision
        >>> from torchvision.models.detection import RetinaNet
        >>> from torchvision.models.detection.rpn import AnchorGenerator
        >>> # load a pre-trained model for classification and return
        >>> # only the features
        >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features
        >>> # RetinaNet needs to know the number of
        >>> # output channels in a backbone. For mobilenet_v2, it's 1280
        >>> # so we need to add it here
        >>> backbone.out_channels = 1280
        >>>
        >>> # let's make the network generate 5 x 3 anchors per spatial
        >>> # location, with 5 different sizes and 3 different aspect
        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
        >>> # map could potentially have different sizes and
        >>> # aspect ratios
        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
        >>>
        >>> # put the pieces together inside a RetinaNet model
        >>> model = RetinaNet(backbone,
        >>>                   num_classes=2,
        >>>                   anchor_generator=anchor_generator)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)
    """

    def __init__(self, backbone, num_classes,
                 # transform parameters
                 min_size=800, max_size=1333,
                 image_mean=None, image_std=None,
                 # Anchor parameters
                 anchor_generator=None, head=None,
                 pre_nms_top_n=1000, post_nms_top_n=1000,
                 nms_thresh=0.5,
                 fg_iou_thresh=0.5, bg_iou_thresh=0.4):
        super(RetinaNet, self).__init__()

        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)")
        self.backbone = backbone

        assert isinstance(anchor_generator, (AnchorGenerator, type(None)))

        if anchor_generator is None:
            anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
            aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
            anchor_generator = AnchorGenerator(
                anchor_sizes, aspect_ratios
            )
        self.anchor_generator = anchor_generator

        if head is None:
            # RetinaNetHead takes (in_channels, num_anchors, num_classes)
            head = RetinaNetHead(backbone.out_channels,
                                 anchor_generator.num_anchors_per_location()[0],
                                 num_classes)
        self.head = head

        # encodes / decodes boxes relative to the anchors, as in rpn.py
        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

        # kept for target assignment and postprocessing (both still TODO)
        self.pre_nms_top_n = pre_nms_top_n
        self.post_nms_top_n = post_nms_top_n
        self.nms_thresh = nms_thresh
        self.fg_iou_thresh = fg_iou_thresh
        self.bg_iou_thresh = bg_iou_thresh

        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)

        # used to avoid emitting the scripting warning more than once
        self._has_warned = False

    @torch.jit.unused
    def eager_outputs(self, losses, detections):
        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        if self.training:
            return losses

        return detections

    def forward(self, images, targets=None):
        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
        """
        Arguments:
            images (list[Tensor]): images to be processed
            targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)

        Returns:
            result (list[Dict[Tensor]] or dict[Tensor]): the output from the model.
                During training, it returns a dict[Tensor] which contains the losses.
                During testing, it returns a list[Dict[Tensor]] with the fields
                `boxes`, `labels` and `scores` for each image.
        """
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")

||
# get the original image sizes | ||
original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], []) | ||
for img in images: | ||
val = img.shape[-2:] | ||
assert len(val) == 2 | ||
original_image_sizes.append((val[0], val[1])) | ||
|
||
# transform the input | ||
images, targets = self.transform(images, targets) | ||
|
||
# get the features from the backbone | ||
features = self.backbone(images.tensors) | ||
if isinstance(features, torch.Tensor): | ||
features = OrderedDict([('0', features)]) | ||
|
||
# compute the retinanet heads outputs using the features | ||
head_outputs = self.head(images, features, targets) | ||
|
||
# create the set of anchors | ||
anchors = self.anchor_generator(images, features) | ||
|
||
losses = {} | ||
detections = torch.jit.annotate(List[Dict[str, torch.Tensor]], []) | ||
if self.training: | ||
assert targets is not None | ||
|
||
# compute the losses | ||
# TODO: Move necessary functions out of rpn.RegionProposalNetwork to a class or function | ||
# so that we can use it here and in rpn.RegionProposalNetwork | ||
labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets) | ||
regression_targets = self.box_coder.encode(matched_gt_boxes, anchors) | ||
losses = self.head.compute_loss(head_outputs, labels, matched_gt_boxes) | ||
else: | ||
# compute the detections | ||
# TODO: Implement postprocess_detections | ||
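            # Roughly, postprocess_detections is expected to, per image:
            #   1. decode the regression deltas against the anchors via self.box_coder,
            #   2. keep the pre_nms_top_n highest-scoring candidates per level,
            #   3. clip boxes to the image and apply NMS with nms_thresh,
            #   4. keep at most post_nms_top_n detections.
            # (A sketch of the intended behaviour, not the final implementation.)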
            boxes, scores, labels = self.postprocess_detections(
                head_outputs['logits'], head_outputs['bbox_reg'], anchors)
            num_images = len(images.image_sizes)
            for i in range(num_images):
                detections.append(
                    {
                        "boxes": boxes[i],
                        "labels": labels[i],
                        "scores": scores[i],
                    }
                )

            detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

        if torch.jit.is_scripting():
            if not self._has_warned:
                warnings.warn("RetinaNet always returns a (Losses, Detections) tuple in scripting")
                self._has_warned = True
            return (losses, detections)
        else:
            return self.eager_outputs(losses, detections)

model_urls = {
    'retinanet_resnet50_fpn_coco':
        '#TODO',
}

def retinanet_resnet50_fpn(pretrained=False, progress=True,
                           num_classes=91, pretrained_backbone=True, **kwargs):
    """
    Constructs a RetinaNet model with a ResNet-50-FPN backbone.

    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
    image, and should be in ``0-1`` range. Different images can have different sizes.

    The behavior of the model changes depending on whether it is in training or evaluation mode.

    During training, the model expects both the input tensors and targets (list of dictionaries),
    containing:
        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with values
          between ``0`` and ``H`` and ``0`` and ``W``
        - labels (``Int64Tensor[N]``): the class label for each ground-truth box

    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
    losses.

    During inference, the model requires only the input tensors, and returns the post-processed
    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
    follows:
        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with values between
          ``0`` and ``H`` and ``0`` and ``W``
        - labels (``Int64Tensor[N]``): the predicted labels for each image
        - scores (``Tensor[N]``): the scores for each prediction

    Example::

        >>> model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=True)
        >>> model.eval()
        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
        >>> predictions = model(x)

    Arguments:
        pretrained (bool): If True, returns a model pre-trained on COCO train2017
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    if pretrained:
        # no need to download the backbone if pretrained is set
        pretrained_backbone = False
    backbone = resnet_fpn_backbone('resnet50', pretrained_backbone)
    model = RetinaNet(backbone, num_classes, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls['retinanet_resnet50_fpn_coco'],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model