diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py
new file mode 100644
index 00000000000..09b9885db4f
--- /dev/null
+++ b/torchvision/models/detection/retinanet.py
@@ -0,0 +1,371 @@
+import warnings
+from collections import OrderedDict
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+
+from ..utils import load_state_dict_from_url
+
+from . import _utils as det_utils
+from .rpn import AnchorGenerator
+from .transform import GeneralizedRCNNTransform
+from .backbone_utils import resnet_fpn_backbone
+
+
+__all__ = [
+    "RetinaNet", "retinanet_resnet50_fpn",
+]
+
+
+class RetinaNetHead(nn.Module):
+    """
+    A regression and classification head for use in RetinaNet.
+
+    Arguments:
+        in_channels (int): number of channels of the input feature
+        num_anchors (int): number of anchors to be predicted
+        num_classes (int): number of classes to be predicted
+    """
+
+    def __init__(self, in_channels, num_anchors, num_classes):
+        super(RetinaNetHead, self).__init__()
+        self.classification_head = RetinaNetClassificationHead(in_channels, num_anchors, num_classes)
+        self.regression_head = RetinaNetRegressionHead(in_channels, num_anchors)
+
+    def compute_loss(self, outputs, labels, matched_gt_boxes):
+        return {
+            'classification': self.classification_head.compute_loss(outputs, labels, matched_gt_boxes),
+            'regression': self.regression_head.compute_loss(outputs, labels, matched_gt_boxes),
+        }
+
+    def forward(self, x):
+        return dict(
+            logits=self.classification_head(x),
+            bbox_reg=self.regression_head(x),
+        )
+
+
+class RetinaNetClassificationHead(nn.Module):
+    """
+    A classification head for use in RetinaNet.
+
+    Arguments:
+        in_channels (int): number of channels of the input feature
+        num_anchors (int): number of anchors to be predicted
+        num_classes (int): number of classes to be predicted
+    """
+
+    def __init__(self, in_channels, num_anchors, num_classes):
+        super(RetinaNetClassificationHead, self).__init__()
+        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+        self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+        self.conv4 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+        self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
+
+        for l in self.children():
+            torch.nn.init.normal_(l.weight, std=0.01)
+            torch.nn.init.constant_(l.bias, 0)
+
+    def compute_loss(self, outputs, labels, matched_gt_boxes):
+        # TODO Implement focal loss, is there an existing function for this?
+        # (a minimal sketch follows this class)
+        return 0
+
+    def forward(self, x):
+        logits = []
+        for feature in x:
+            t = F.relu(self.conv1(feature))
+            t = F.relu(self.conv2(t))
+            t = F.relu(self.conv3(t))
+            t = F.relu(self.conv4(t))
+            logits.append(self.cls_logits(t))
+        return logits
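+
+
+def _sigmoid_focal_loss_sketch(logits, targets, alpha=0.25, gamma=2.0):
+    # Minimal sketch for the focal-loss TODO above, not wired into the head
+    # yet. It assumes `targets` is a one-hot float tensor with the same shape
+    # as `logits`; the alpha/gamma defaults follow the RetinaNet paper
+    # (Lin et al., "Focal Loss for Dense Object Detection"). Normalizing by
+    # the number of foreground anchors is left to the caller.
+    p = torch.sigmoid(logits)
+    ce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
+    # down-weight well-classified examples by (1 - p_t) ** gamma
+    p_t = p * targets + (1 - p) * (1 - targets)
+    loss = ce_loss * ((1 - p_t) ** gamma)
+    # alpha-balance positives vs. negatives
+    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+    return (alpha_t * loss).sum()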
+
+
+class RetinaNetRegressionHead(nn.Module):
+    """
+    A regression head for use in RetinaNet.
+
+    Arguments:
+        in_channels (int): number of channels of the input feature
+        num_anchors (int): number of anchors to be predicted
+    """
+
+    def __init__(self, in_channels, num_anchors):
+        super(RetinaNetRegressionHead, self).__init__()
+        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+        self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+        self.conv4 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+        self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)
+
+        for l in self.children():
+            torch.nn.init.normal_(l.weight, std=0.01)
+            torch.nn.init.constant_(l.bias, 0)
+
+    def compute_loss(self, outputs, labels, matched_gt_boxes):
+        # TODO Use SmoothL1 loss for regression, or just L1 like in rpn.py ?
+        # (a minimal sketch follows this class)
+        return 0
+
+    def forward(self, x):
+        bbox_reg = []
+        for feature in x:
+            t = F.relu(self.conv1(feature))
+            t = F.relu(self.conv2(t))
+            t = F.relu(self.conv3(t))
+            t = F.relu(self.conv4(t))
+            bbox_reg.append(self.bbox_reg(t))
+        return bbox_reg
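+
+
+def _regression_loss_sketch(bbox_reg, regression_targets):
+    # Minimal sketch for the SmoothL1-vs-L1 TODO above, assuming the rows for
+    # foreground anchors have already been selected and both tensors have
+    # shape [num_foreground, 4]. F.smooth_l1_loss is the existing built-in;
+    # rpn.py uses an equivalent hand-rolled smooth-L1 with beta = 1 / 9.
+    # Normalizing (e.g. by the number of anchors) is left to the caller.
+    return F.smooth_l1_loss(bbox_reg, regression_targets, reduction='sum')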
+
+
+class RetinaNet(nn.Module):
+    """
+    Implements RetinaNet.
+
+    The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each
+    image, and should be in 0-1 range. Different images can have different sizes.
+
+    The behavior of the model changes depending on whether it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and a targets list (one dictionary per
+    image), containing:
+        - boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format, with values
+          between 0 and H and 0 and W
+        - labels (Int64Tensor[N]): the class label for each ground-truth box
+
+    The model returns a Dict[Tensor] during training, containing the classification and regression
+    losses.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as
+    follows:
+        - boxes (FloatTensor[N, 4]): the predicted boxes in [x1, y1, x2, y2] format, with values
+          between 0 and H and 0 and W
+        - labels (Int64Tensor[N]): the predicted labels for each image
+        - scores (Tensor[N]): the scores for each prediction
+
+    Arguments:
+        backbone (nn.Module): the network used to compute the features for the model.
+            It should contain an out_channels attribute, which indicates the number of output
+            channels that each feature map has (and it should be the same for all feature maps).
+            The backbone should return a single Tensor or an OrderedDict[Tensor].
+        num_classes (int): number of output classes of the model (excluding the background).
+        min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
+        max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
+        image_mean (Tuple[float, float, float]): mean values used for input normalization.
+            They are generally the mean values of the dataset on which the backbone has been
+            trained on
+        image_std (Tuple[float, float, float]): std values used for input normalization.
+            They are generally the std values of the dataset on which the backbone has been trained on
+        anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature
+            maps.
+        head (nn.Module): Module run on top of the feature pyramid.
+            Defaults to a module containing a classification and regression module.
+        pre_nms_top_n (int): number of best detections to keep before applying NMS during testing.
+        post_nms_top_n (int): number of detections to keep after applying NMS during testing.
+        nms_thresh (float): NMS threshold used for postprocessing the detections.
+        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
+            considered as positive during training.
+        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so that they can be
+            considered as negative during training.
+
+    Example::
+
+        >>> import torch
+        >>> import torchvision
+        >>> from torchvision.models.detection import RetinaNet
+        >>> from torchvision.models.detection.rpn import AnchorGenerator
+        >>> # load a pre-trained model for classification and return
+        >>> # only the features
+        >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features
+        >>> # RetinaNet needs to know the number of
+        >>> # output channels in a backbone. For mobilenet_v2, it's 1280
+        >>> # so we need to add it here
+        >>> backbone.out_channels = 1280
+        >>>
+        >>> # let's make the network generate 5 x 3 anchors per spatial
+        >>> # location, with 5 different sizes and 3 different aspect
+        >>> # ratios. We have a Tuple[Tuple[int]] because each feature
+        >>> # map could potentially have different sizes and
+        >>> # aspect ratios
+        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
+        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
+        >>>
+        >>> # put the pieces together inside a RetinaNet model
+        >>> model = RetinaNet(backbone,
+        >>>                   num_classes=2,
+        >>>                   anchor_generator=anchor_generator)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+    """
+
+    def __init__(self, backbone, num_classes,
+                 # transform parameters
+                 min_size=800, max_size=1333,
+                 image_mean=None, image_std=None,
+                 # Anchor parameters
+                 anchor_generator=None, head=None,
+                 pre_nms_top_n=1000, post_nms_top_n=1000,
+                 nms_thresh=0.5,
+                 fg_iou_thresh=0.5, bg_iou_thresh=0.4):
+        super(RetinaNet, self).__init__()
+
+        if not hasattr(backbone, "out_channels"):
+            raise ValueError(
+                "backbone should contain an attribute out_channels "
+                "specifying the number of output channels (assumed to be the "
+                "same for all the levels)")
+        self.backbone = backbone
+
+        assert isinstance(anchor_generator, (AnchorGenerator, type(None)))
+
+        if anchor_generator is None:
+            anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
+            aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
+            anchor_generator = AnchorGenerator(
+                anchor_sizes, aspect_ratios
+            )
+        self.anchor_generator = anchor_generator
+
+        if head is None:
+            head = RetinaNetHead(backbone.out_channels,
+                                 anchor_generator.num_anchors_per_location()[0],
+                                 num_classes)
+        self.head = head
+
+        # matcher and box coder used by the training path in forward()
+        self.proposal_matcher = det_utils.Matcher(
+            fg_iou_thresh,
+            bg_iou_thresh,
+            allow_low_quality_matches=True,
+        )
+        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))
+
+        self.pre_nms_top_n = pre_nms_top_n
+        self.post_nms_top_n = post_nms_top_n
+        self.nms_thresh = nms_thresh
+
+        if image_mean is None:
+            image_mean = [0.485, 0.456, 0.406]
+        if image_std is None:
+            image_std = [0.229, 0.224, 0.225]
+        self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
+
+        self._has_warned = False
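+
+    def postprocess_detections(self, class_logits, box_regression, anchors):
+        # Minimal sketch for the postprocess TODO in forward(). It assumes the
+        # per-level head outputs have already been reshaped and concatenated
+        # into one Tensor[num_anchors, num_classes] / Tensor[num_anchors, 4]
+        # per image (forward() does not do that yet), and the 0.05 score
+        # threshold is an assumed default.
+        from torchvision.ops import boxes as box_ops
+        all_boxes, all_scores, all_labels = [], [], []
+        for logits, regression, anchors_per_image in zip(class_logits, box_regression, anchors):
+            # decode the regression deltas relative to the anchors
+            boxes = self.box_coder.decode_single(regression, anchors_per_image)
+            # keep the best class per anchor and drop low-scoring detections
+            scores, labels = torch.sigmoid(logits).max(dim=1)
+            keep = scores > 0.05
+            boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
+            # class-aware NMS, then cap the number of detections per image
+            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
+            keep = keep[:self.post_nms_top_n]
+            all_boxes.append(boxes[keep])
+            all_scores.append(scores[keep])
+            all_labels.append(labels[keep])
+        return all_boxes, all_scores, all_labels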
+
+    @torch.jit.unused
+    def eager_outputs(self, losses, detections):
+        # type: (Dict[str, Tensor], List[Dict[str, Tensor]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
+        if self.training:
+            return losses
+
+        return detections
+
+    def forward(self, images, targets=None):
+        # type: (List[Tensor], Optional[List[Dict[str, Tensor]]]) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]
+        """
+        Arguments:
+            images (list[Tensor]): images to be processed
+            targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
+
+        Returns:
+            result (list[Dict[Tensor]] or dict[Tensor]): the output from the model.
+                During training, it returns a dict[Tensor] which contains the losses.
+                During testing, it returns a list[Dict[Tensor]] which contains the predicted
+                boxes, labels and scores, one dict per input image.
+
+        """
+        if self.training and targets is None:
+            raise ValueError("In training mode, targets should be passed")
+
+        # get the original image sizes
+        original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
+        for img in images:
+            val = img.shape[-2:]
+            assert len(val) == 2
+            original_image_sizes.append((val[0], val[1]))
+
+        # transform the input
+        images, targets = self.transform(images, targets)
+
+        # get the features from the backbone
+        features = self.backbone(images.tensors)
+        if isinstance(features, torch.Tensor):
+            features = OrderedDict([('0', features)])
+        features = list(features.values())
+
+        # compute the retinanet heads outputs using the features
+        head_outputs = self.head(features)
+
+        # create the set of anchors
+        anchors = self.anchor_generator(images, features)
+
+        losses = {}
+        detections = torch.jit.annotate(List[Dict[str, torch.Tensor]], [])
+        if self.training:
+            assert targets is not None
+
+            # compute the losses
+            # TODO: Move necessary functions out of rpn.RegionProposalNetwork to a class or function
+            # so that we can use it here and in rpn.RegionProposalNetwork
+            # (a per-image matching sketch follows this class)
+            labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
+            regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
+            losses = self.head.compute_loss(head_outputs, labels, matched_gt_boxes)
+        else:
+            # compute the detections
+            # TODO: Implement postprocess_detections properly (a minimal sketch is
+            # above; the per-level head outputs still need flattening first)
+            boxes, scores, labels = self.postprocess_detections(
+                head_outputs['logits'], head_outputs['bbox_reg'], anchors)
+            num_images = len(images.image_sizes)
+            for i in range(num_images):
+                detections.append(
+                    {
+                        "boxes": boxes[i],
+                        "labels": labels[i],
+                        "scores": scores[i],
+                    }
+                )
+
+            detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
+
+        if torch.jit.is_scripting():
+            if not self._has_warned:
+                warnings.warn("RetinaNet always returns a (Losses, Detections) tuple in scripting")
+                self._has_warned = True
+            return (losses, detections)
+        else:
+            return self.eager_outputs(losses, detections)
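+
+
+def _assign_targets_to_anchors_sketch(proposal_matcher, anchors, targets):
+    # Minimal sketch for the training-path TODO in RetinaNet.forward, mirroring
+    # rpn.RegionProposalNetwork.assign_targets_to_anchors. `proposal_matcher`
+    # is assumed to be a det_utils.Matcher like the one built in
+    # RetinaNet.__init__; this helper is illustrative and not wired into the
+    # class yet.
+    from torchvision.ops import boxes as box_ops
+    labels, matched_gt_boxes = [], []
+    for anchors_per_image, targets_per_image in zip(anchors, targets):
+        gt_boxes = targets_per_image["boxes"]
+        match_quality_matrix = box_ops.box_iou(gt_boxes, anchors_per_image)
+        matched_idxs = proposal_matcher(match_quality_matrix)
+        # negative matches get clamped to index 0 here; the labels below mark
+        # them as background (0) or ignored (-1), so those box values are unused
+        matched_gt_boxes.append(gt_boxes[matched_idxs.clamp(min=0)])
+        labels_per_image = (matched_idxs >= 0).to(dtype=torch.float32)
+        labels_per_image[matched_idxs == proposal_matcher.BELOW_LOW_THRESHOLD] = 0.0
+        labels_per_image[matched_idxs == proposal_matcher.BETWEEN_THRESHOLDS] = -1.0
+        labels.append(labels_per_image)
+    return labels, matched_gt_boxes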
+
+
+model_urls = {
+    'retinanet_resnet50_fpn_coco':
+        '#TODO',
+}
+
+
+def retinanet_resnet50_fpn(pretrained=False, progress=True,
+                           num_classes=91, pretrained_backbone=True, **kwargs):
+    """
+    Constructs a RetinaNet model with a ResNet-50-FPN backbone.
+
+    The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each
+    image, and should be in ``0-1`` range. Different images can have different sizes.
+
+    The behavior of the model changes depending on whether it is in training or evaluation mode.
+
+    During training, the model expects both the input tensors and a targets list (one dictionary per
+    image), containing:
+        - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with
+          values between ``0`` and ``H`` and ``0`` and ``W``
+        - labels (``Int64Tensor[N]``): the class label for each ground-truth box
+
+    The model returns a ``Dict[Tensor]`` during training, containing the classification and regression
+    losses.
+
+    During inference, the model requires only the input tensors, and returns the post-processed
+    predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as
+    follows:
+        - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with
+          values between ``0`` and ``H`` and ``0`` and ``W``
+        - labels (``Int64Tensor[N]``): the predicted labels for each image
+        - scores (``Tensor[N]``): the scores for each prediction
+
+    Example::
+
+        >>> model = torchvision.models.detection.retinanet_resnet50_fpn(pretrained=True)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+
+    Arguments:
+        pretrained (bool): If True, returns a model pre-trained on COCO train2017
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    if pretrained:
+        # no need to download the backbone if pretrained is set
+        pretrained_backbone = False
+    backbone = resnet_fpn_backbone('resnet50', pretrained_backbone)
+    model = RetinaNet(backbone, num_classes, **kwargs)
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls['retinanet_resnet50_fpn_coco'],
+                                              progress=progress)
+        model.load_state_dict(state_dict)
+    return model