diff --git a/docs/source/models.rst b/docs/source/models.rst index 4c65eac8135..4daee5d5534 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -597,6 +597,7 @@ The models subpackage contains definitions for the following model architectures for detection: - `Faster R-CNN `_ +- `FCOS `_ - `Mask R-CNN `_ - `RetinaNet `_ - `SSD `_ @@ -642,6 +643,7 @@ Network box AP mask AP keypoint AP Faster R-CNN ResNet-50 FPN 37.0 - - Faster R-CNN MobileNetV3-Large FPN 32.8 - - Faster R-CNN MobileNetV3-Large 320 FPN 22.8 - - +FCOS ResNet-50 FPN 39.2 - - RetinaNet ResNet-50 FPN 36.4 - - SSD300 VGG16 25.1 - - SSDlite320 MobileNetV3-Large 21.3 - - @@ -702,6 +704,7 @@ Network train time (s / it) test time (s / it) Faster R-CNN ResNet-50 FPN 0.2288 0.0590 5.2 Faster R-CNN MobileNetV3-Large FPN 0.1020 0.0415 1.0 Faster R-CNN MobileNetV3-Large 320 FPN 0.0978 0.0376 0.6 +FCOS ResNet-50 FPN 0.1450 0.0539 3.3 RetinaNet ResNet-50 FPN 0.2514 0.0939 4.1 SSD300 VGG16 0.2093 0.0744 1.5 SSDlite320 MobileNetV3-Large 0.1773 0.0906 1.5 @@ -721,6 +724,15 @@ Faster R-CNN torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn +FCOS +---- + +.. autosummary:: + :toctree: generated/ + :template: function.rst + + torchvision.models.detection.fcos_resnet50_fpn + RetinaNet --------- diff --git a/mypy.ini b/mypy.ini index 931665240f3..c2012102143 100644 --- a/mypy.ini +++ b/mypy.ini @@ -70,6 +70,10 @@ ignore_errors = True ignore_errors = True +[mypy-torchvision.models.detection.fcos] + +ignore_errors = True + [mypy-torchvision.ops.*] ignore_errors = True diff --git a/references/detection/README.md b/references/detection/README.md index 4d44f67b4c0..3695644138b 100644 --- a/references/detection/README.md +++ b/references/detection/README.md @@ -41,6 +41,13 @@ torchrun --nproc_per_node=8 train.py\ --lr-steps 16 22 --aspect-ratio-group-factor 3 ``` +### FCOS ResNet-50 FPN +``` +torchrun --nproc_per_node=8 train.py\ + --dataset coco --model fcos_resnet50_fpn --epochs 26\ + --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --amp +``` + ### RetinaNet ``` torchrun --nproc_per_node=8 train.py\ diff --git a/test/expect/ModelTester.test_fcos_resnet50_fpn_expect.pkl b/test/expect/ModelTester.test_fcos_resnet50_fpn_expect.pkl new file mode 100644 index 00000000000..0657261d96c Binary files /dev/null and b/test/expect/ModelTester.test_fcos_resnet50_fpn_expect.pkl differ diff --git a/test/test_models.py b/test/test_models.py index f4f1828d8af..0c99d0dfe66 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -218,6 +218,7 @@ def _check_input_backprop(model, inputs): "retinanet_resnet50_fpn": lambda x: x[1], "ssd300_vgg16": lambda x: x[1], "ssdlite320_mobilenet_v3_large": lambda x: x[1], + "fcos_resnet50_fpn": lambda x: x[1], } @@ -274,6 +275,13 @@ def _check_input_backprop(model, inputs): "max_size": 224, "input_shape": (3, 224, 224), }, + "fcos_resnet50_fpn": { + "num_classes": 2, + "score_thresh": 0.05, + "min_size": 224, + "max_size": 224, + "input_shape": (3, 224, 224), + }, "maskrcnn_resnet50_fpn": { "num_classes": 10, "min_size": 224, @@ -325,6 +333,10 @@ def _check_input_backprop(model, inputs): "max_trainable": 6, "n_trn_params_per_layer": [96, 99, 138, 200, 239, 257, 266], }, + "fcos_resnet50_fpn": { + "max_trainable": 5, + "n_trn_params_per_layer": [54, 64, 83, 96, 106, 107], + }, } diff --git a/test/test_models_detection_utils.py b/test/test_models_detection_utils.py index 8d686023b1d..6551a1a759f 100644 --- 
a/test/test_models_detection_utils.py +++ b/test/test_models_detection_utils.py @@ -22,6 +22,19 @@ def test_balanced_positive_negative_sampler(self): assert neg[0].sum() == 3 assert neg[0][0:6].sum() == 3 + def test_box_linear_coder(self): + box_coder = _utils.BoxLinearCoder(normalize_by_size=True) + # Generate a random 10x4 boxes tensor, with coordinates < 50. + boxes = torch.rand(10, 4) * 50 + boxes.clamp_(min=1.0) # tiny boxes cause numerical instability in box regression + boxes[:, 2:] += boxes[:, :2] + + proposals = torch.tensor([0, 0, 101, 101] * 10).reshape(10, 4).float() + + rel_codes = box_coder.encode_single(boxes, proposals) + pred_boxes = box_coder.decode_single(rel_codes, boxes) + assert torch.allclose(proposals, pred_boxes) + @pytest.mark.parametrize("train_layers, exp_froz_params", [(0, 53), (1, 43), (2, 24), (3, 11), (4, 1), (5, 0)]) def test_resnet_fpn_backbone_frozen_layers(self, train_layers, exp_froz_params): # we know how many initial layers and parameters of the network should diff --git a/torchvision/models/detection/__init__.py b/torchvision/models/detection/__init__.py index 4772415b3b1..be46f950a61 100644 --- a/torchvision/models/detection/__init__.py +++ b/torchvision/models/detection/__init__.py @@ -4,3 +4,4 @@ from .retinanet import * from .ssd import * from .ssdlite import * +from .fcos import * diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index b870e6a2456..ef4f6550eef 100644 --- a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -217,6 +217,83 @@ def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor: return pred_boxes +class BoxLinearCoder: + """ + The linear box-to-box transform defined in FCOS. The transformation is parameterized + by the distance from the center of (square) src box to 4 edges of the target box. + """ + + def __init__(self, normalize_by_size: bool = True) -> None: + """ + Args: + normalize_by_size (bool): normalize deltas by the size of src (anchor) boxes. + """ + self.normalize_by_size = normalize_by_size + + def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor: + """ + Encode a set of proposals with respect to some reference boxes + + Args: + reference_boxes (Tensor): reference boxes + proposals (Tensor): boxes to be encoded + + Returns: + Tensor: the encoded relative box offsets that can be used to + decode the boxes. + """ + # get the center of reference_boxes + reference_boxes_ctr_x = 0.5 * (reference_boxes[:, 0] + reference_boxes[:, 2]) + reference_boxes_ctr_y = 0.5 * (reference_boxes[:, 1] + reference_boxes[:, 3]) + + # get box regression transformation deltas + target_l = reference_boxes_ctr_x - proposals[:, 0] + target_t = reference_boxes_ctr_y - proposals[:, 1] + target_r = proposals[:, 2] - reference_boxes_ctr_x + target_b = proposals[:, 3] - reference_boxes_ctr_y + + targets = torch.stack((target_l, target_t, target_r, target_b), dim=1) + if self.normalize_by_size: + reference_boxes_w = reference_boxes[:, 2] - reference_boxes[:, 0] + reference_boxes_h = reference_boxes[:, 3] - reference_boxes[:, 1] + reference_boxes_size = torch.stack( + (reference_boxes_w, reference_boxes_h, reference_boxes_w, reference_boxes_h), dim=1 + ) + targets = targets / reference_boxes_size + + return targets + + def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor: + """ + From a set of original boxes and encoded relative box offsets, + get the decoded boxes.
+ + Args: + rel_codes (Tensor): encoded boxes + boxes (Tensor): reference boxes. + + Returns: + Tensor: the predicted boxes with the encoded relative box offsets. + """ + + boxes = boxes.to(rel_codes.dtype) + + ctr_x = 0.5 * (boxes[:, 0] + boxes[:, 2]) + ctr_y = 0.5 * (boxes[:, 1] + boxes[:, 3]) + if self.normalize_by_size: + boxes_w = boxes[:, 2] - boxes[:, 0] + boxes_h = boxes[:, 3] - boxes[:, 1] + boxes_size = torch.stack((boxes_w, boxes_h, boxes_w, boxes_h), dim=1) + rel_codes = rel_codes * boxes_size + + pred_boxes1 = ctr_x - rel_codes[:, 0] + pred_boxes2 = ctr_y - rel_codes[:, 1] + pred_boxes3 = ctr_x + rel_codes[:, 2] + pred_boxes4 = ctr_y + rel_codes[:, 3] + pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=1) + return pred_boxes + + class Matcher: """ This class assigns to each predicted "element" (e.g., a box) a ground-truth diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py new file mode 100644 index 00000000000..71a6306e7e1 --- /dev/null +++ b/torchvision/models/detection/fcos.py @@ -0,0 +1,700 @@ +import math +import warnings +from collections import OrderedDict +from functools import partial +from typing import Callable, Dict, List, Tuple, Optional + +import torch +from torch import nn, Tensor + +from ..._internally_replaced_utils import load_state_dict_from_url +from ...ops import sigmoid_focal_loss, generalized_box_iou_loss +from ...ops import boxes as box_ops +from ...ops import misc as misc_nn_ops +from ...ops.feature_pyramid_network import LastLevelP6P7 +from ...utils import _log_api_usage_once +from ..resnet import resnet50 +from . import _utils as det_utils +from .anchor_utils import AnchorGenerator +from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers +from .transform import GeneralizedRCNNTransform + + +__all__ = ["FCOS", "fcos_resnet50_fpn"] + + +class FCOSHead(nn.Module): + """ + A regression and classification head for use in FCOS. + + Args: + in_channels (int): number of channels of the input feature + num_anchors (int): number of anchors to be predicted + num_classes (int): number of classes to be predicted + num_convs (Optional[int]): number of conv layer of head. Default: 4. 
+ """ + + __annotations__ = { + "box_coder": det_utils.BoxLinearCoder, + } + + def __init__(self, in_channels: int, num_anchors: int, num_classes: int, num_convs: Optional[int] = 4) -> None: + super().__init__() + self.box_coder = det_utils.BoxLinearCoder(normalize_by_size=True) + self.classification_head = FCOSClassificationHead(in_channels, num_anchors, num_classes, num_convs) + self.regression_head = FCOSRegressionHead(in_channels, num_anchors, num_convs) + + def compute_loss( + self, + targets: List[Dict[str, Tensor]], + head_outputs: Dict[str, Tensor], + anchors: List[Tensor], + matched_idxs: List[Tensor], + ) -> Dict[str, Tensor]: + + cls_logits = head_outputs["cls_logits"] # [N, HWA, C] + bbox_regression = head_outputs["bbox_regression"] # [N, HWA, 4] + bbox_ctrness = head_outputs["bbox_ctrness"] # [N, HWA, 1] + + all_gt_classes_targets = [] + all_gt_boxes_targets = [] + for targets_per_image, matched_idxs_per_image in zip(targets, matched_idxs): + gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)] + gt_classes_targets[matched_idxs_per_image < 0] = -1 # backgroud + gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)] + all_gt_classes_targets.append(gt_classes_targets) + all_gt_boxes_targets.append(gt_boxes_targets) + + all_gt_classes_targets = torch.stack(all_gt_classes_targets) + # compute foregroud + foregroud_mask = all_gt_classes_targets >= 0 + num_foreground = foregroud_mask.sum().item() + + # classification loss + gt_classes_targets = torch.zeros_like(cls_logits) + gt_classes_targets[foregroud_mask, all_gt_classes_targets[foregroud_mask]] = 1.0 + loss_cls = sigmoid_focal_loss(cls_logits, gt_classes_targets, reduction="sum") + + # regression loss: GIoU loss + # TODO: vectorize this instead of using a for loop + pred_boxes = [ + self.box_coder.decode_single(bbox_regression_per_image, anchors_per_image) + for anchors_per_image, bbox_regression_per_image in zip(anchors, bbox_regression) + ] + # amp issue: pred_boxes need to convert float + loss_bbox_reg = generalized_box_iou_loss( + torch.stack(pred_boxes)[foregroud_mask].float(), + torch.stack(all_gt_boxes_targets)[foregroud_mask], + reduction="sum", + ) + + # ctrness loss + bbox_reg_targets = [ + self.box_coder.encode_single(anchors_per_image, boxes_targets_per_image) + for anchors_per_image, boxes_targets_per_image in zip(anchors, all_gt_boxes_targets) + ] + bbox_reg_targets = torch.stack(bbox_reg_targets, dim=0) + if len(bbox_reg_targets) == 0: + bbox_reg_targets.new_zeros(len(bbox_reg_targets)) + left_right = bbox_reg_targets[:, :, [0, 2]] + top_bottom = bbox_reg_targets[:, :, [1, 3]] + gt_ctrness_targets = torch.sqrt( + (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) + * (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]) + ) + pred_centerness = bbox_ctrness.squeeze(dim=2) + loss_bbox_ctrness = nn.functional.binary_cross_entropy_with_logits( + pred_centerness[foregroud_mask], gt_ctrness_targets[foregroud_mask], reduction="sum" + ) + + return { + "classification": loss_cls / max(1, num_foreground), + "bbox_regression": loss_bbox_reg / max(1, num_foreground), + "bbox_ctrness": loss_bbox_ctrness / max(1, num_foreground), + } + + def forward(self, x: List[Tensor]) -> Dict[str, Tensor]: + cls_logits = self.classification_head(x) + bbox_regression, bbox_ctrness = self.regression_head(x) + return { + "cls_logits": cls_logits, + "bbox_regression": bbox_regression, + "bbox_ctrness": bbox_ctrness, + } + + +class FCOSClassificationHead(nn.Module): + """ + A 
classification head for use in FCOS. + + Args: + in_channels (int): number of channels of the input feature. + num_anchors (int): number of anchors to be predicted. + num_classes (int): number of classes to be predicted. + num_convs (Optional[int]): number of conv layer. Default: 4. + prior_probability (Optional[float]): probability of prior. Default: 0.01. + norm_layer: Module specifying the normalization layer to use. + """ + + def __init__( + self, + in_channels: int, + num_anchors: int, + num_classes: int, + num_convs: int = 4, + prior_probability: float = 0.01, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super().__init__() + + self.num_classes = num_classes + self.num_anchors = num_anchors + + if norm_layer is None: + norm_layer = partial(nn.GroupNorm, 32) + + conv = [] + for _ in range(num_convs): + conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)) + conv.append(norm_layer(in_channels)) + conv.append(nn.ReLU()) + self.conv = nn.Sequential(*conv) + + for layer in self.conv.children(): + if isinstance(layer, nn.Conv2d): + torch.nn.init.normal_(layer.weight, std=0.01) + torch.nn.init.constant_(layer.bias, 0) + + self.cls_logits = nn.Conv2d(in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1) + torch.nn.init.normal_(self.cls_logits.weight, std=0.01) + torch.nn.init.constant_(self.cls_logits.bias, -math.log((1 - prior_probability) / prior_probability)) + + def forward(self, x: List[Tensor]) -> Tensor: + all_cls_logits = [] + + for features in x: + cls_logits = self.conv(features) + cls_logits = self.cls_logits(cls_logits) + + # Permute classification output from (N, A * K, H, W) to (N, HWA, K). + N, _, H, W = cls_logits.shape + cls_logits = cls_logits.view(N, -1, self.num_classes, H, W) + cls_logits = cls_logits.permute(0, 3, 4, 1, 2) + cls_logits = cls_logits.reshape(N, -1, self.num_classes) # Size=(N, HWA, 4) + + all_cls_logits.append(cls_logits) + + return torch.cat(all_cls_logits, dim=1) + + +class FCOSRegressionHead(nn.Module): + """ + A regression head for use in FCOS. + + Args: + in_channels (int): number of channels of the input feature + num_anchors (int): number of anchors to be predicted + num_convs (Optional[int]): number of conv layer. Default: 4. + norm_layer: Module specifying the normalization layer to use. 
+ """ + + def __init__( + self, + in_channels: int, + num_anchors: int, + num_convs: int = 4, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ): + super().__init__() + + if norm_layer is None: + norm_layer = partial(nn.GroupNorm, 32) + + conv = [] + for _ in range(num_convs): + conv.append(nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)) + conv.append(norm_layer(in_channels)) + conv.append(nn.ReLU()) + self.conv = nn.Sequential(*conv) + + self.bbox_reg = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1) + self.bbox_ctrness = nn.Conv2d(in_channels, num_anchors * 1, kernel_size=3, stride=1, padding=1) + for layer in [self.bbox_reg, self.bbox_ctrness]: + torch.nn.init.normal_(layer.weight, std=0.01) + torch.nn.init.zeros_(layer.bias) + + for layer in self.conv.children(): + if isinstance(layer, nn.Conv2d): + torch.nn.init.normal_(layer.weight, std=0.01) + torch.nn.init.zeros_(layer.bias) + + def forward(self, x: List[Tensor]) -> Tuple[Tensor, Tensor]: + all_bbox_regression = [] + all_bbox_ctrness = [] + + for features in x: + bbox_feature = self.conv(features) + bbox_regression = nn.functional.relu(self.bbox_reg(bbox_feature)) + bbox_ctrness = self.bbox_ctrness(bbox_feature) + + # permute bbox regression output from (N, 4 * A, H, W) to (N, HWA, 4). + N, _, H, W = bbox_regression.shape + bbox_regression = bbox_regression.view(N, -1, 4, H, W) + bbox_regression = bbox_regression.permute(0, 3, 4, 1, 2) + bbox_regression = bbox_regression.reshape(N, -1, 4) # Size=(N, HWA, 4) + all_bbox_regression.append(bbox_regression) + + # permute bbox ctrness output from (N, 1 * A, H, W) to (N, HWA, 1). + bbox_ctrness = bbox_ctrness.view(N, -1, 1, H, W) + bbox_ctrness = bbox_ctrness.permute(0, 3, 4, 1, 2) + bbox_ctrness = bbox_ctrness.reshape(N, -1, 1) + all_bbox_ctrness.append(bbox_ctrness) + + return torch.cat(all_bbox_regression, dim=1), torch.cat(all_bbox_ctrness, dim=1) + + +class FCOS(nn.Module): + """ + Implements FCOS. + + The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each + image, and should be in 0-1 range. Different images can have different sizes. + + The behavior of the model changes depending if it is in training or evaluation mode. + + During training, the model expects both the input tensors, as well as a targets (list of dictionary), + containing: + - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (Int64Tensor[N]): the class label for each ground-truth box + + The model returns a Dict[Tensor] during training, containing the classification, regression + and centerness losses. + + During inference, the model requires only the input tensors, and returns the post-processed + predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as + follows: + - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (Int64Tensor[N]): the predicted labels for each image + - scores (Tensor[N]): the scores for each prediction + + Args: + backbone (nn.Module): the network used to compute the features for the model. + It should contain an out_channels attribute, which indicates the number of output + channels that each feature map has (and it should be the same for all feature maps). + The backbone should return a single Tensor or an OrderedDict[Tensor]. 
+ num_classes (int): number of output classes of the model (including the background). + min_size (int): minimum size of the image to be rescaled before feeding it to the backbone + max_size (int): maximum size of the image to be rescaled before feeding it to the backbone + image_mean (Tuple[float, float, float]): mean values used for input normalization. + They are generally the mean values of the dataset on which the backbone has been trained + on + image_std (Tuple[float, float, float]): std values used for input normalization. + They are generally the std values of the dataset on which the backbone has been trained on + anchor_generator (AnchorGenerator): module that generates the anchors for a set of feature + maps. For FCOS, only set one anchor for per position of each level, the width and height equal to + the stride of feature map, and set aspect ratio = 1.0, so the center of anchor is equivalent to the point + in FCOS paper. + head (nn.Module): Module run on top of the feature pyramid. + Defaults to a module containing a classification and regression module. + center_sampling_radius (int): radius of the "center" of a groundtruth box, + within which all anchor points are labeled positive. + score_thresh (float): Score threshold used for postprocessing the detections. + nms_thresh (float): NMS threshold used for postprocessing the detections. + detections_per_img (int): Number of best detections to keep after NMS. + topk_candidates (int): Number of best detections to keep before NMS. + + Example: + + >>> import torch + >>> import torchvision + >>> from torchvision.models.detection import FCOS + >>> from torchvision.models.detection.anchor_utils import AnchorGenerator + >>> # load a pre-trained model for classification and return + >>> # only the features + >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features + >>> # FCOS needs to know the number of + >>> # output channels in a backbone. For mobilenet_v2, it's 1280 + >>> # so we need to add it here + >>> backbone.out_channels = 1280 + >>> + >>> # let's make the network generate 5 x 3 anchors per spatial + >>> # location, with 5 different sizes and 3 different aspect + >>> # ratios. 
We have a Tuple[Tuple[int]] because each feature + >>> # map could potentially have different sizes and + >>> # aspect ratios + >>> anchor_generator = AnchorGenerator( + >>> sizes=((8,), (16,), (32,), (64,), (128,)), + >>> aspect_ratios=((1.0,),) + >>> ) + >>> + >>> # put the pieces together inside a FCOS model + >>> model = FCOS( + >>> backbone, + >>> num_classes=80, + >>> anchor_generator=anchor_generator, + >>> ) + >>> model.eval() + >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + >>> predictions = model(x) + """ + + __annotations__ = { + "box_coder": det_utils.BoxLinearCoder, + } + + def __init__( + self, + backbone: nn.Module, + num_classes: int, + # transform parameters + min_size: int = 800, + max_size: int = 1333, + image_mean: Optional[List[float]] = None, + image_std: Optional[List[float]] = None, + # Anchor parameters + anchor_generator: Optional[AnchorGenerator] = None, + head: Optional[nn.Module] = None, + center_sampling_radius: float = 1.5, + score_thresh: float = 0.2, + nms_thresh: float = 0.6, + detections_per_img: int = 100, + topk_candidates: int = 1000, + ): + super().__init__() + _log_api_usage_once(self) + + if not hasattr(backbone, "out_channels"): + raise ValueError( + "backbone should contain an attribute out_channels " + "specifying the number of output channels (assumed to be the " + "same for all the levels)" + ) + self.backbone = backbone + + assert isinstance(anchor_generator, (AnchorGenerator, type(None))) + + if anchor_generator is None: + anchor_sizes = ((8,), (16,), (32,), (64,), (128,)) # equal to strides of multi-level feature map + aspect_ratios = ((1.0,),) * len(anchor_sizes) # set only one anchor + anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios) + self.anchor_generator = anchor_generator + assert self.anchor_generator.num_anchors_per_location()[0] == 1 + + if head is None: + head = FCOSHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes) + self.head = head + + self.box_coder = det_utils.BoxLinearCoder(normalize_by_size=True) + + if image_mean is None: + image_mean = [0.485, 0.456, 0.406] + if image_std is None: + image_std = [0.229, 0.224, 0.225] + self.transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std) + + self.center_sampling_radius = center_sampling_radius + self.score_thresh = score_thresh + self.nms_thresh = nms_thresh + self.detections_per_img = detections_per_img + self.topk_candidates = topk_candidates + + # used only on torchscript mode + self._has_warned = False + + @torch.jit.unused + def eager_outputs( + self, losses: Dict[str, Tensor], detections: List[Dict[str, Tensor]] + ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]: + if self.training: + return losses + + return detections + + def compute_loss( + self, + targets: List[Dict[str, Tensor]], + head_outputs: Dict[str, Tensor], + anchors: List[Tensor], + num_anchors_per_level: List[int], + ) -> Dict[str, Tensor]: + matched_idxs = [] + for anchors_per_image, targets_per_image in zip(anchors, targets): + if targets_per_image["boxes"].numel() == 0: + matched_idxs.append( + torch.full((anchors_per_image.size(0),), -1, dtype=torch.int64, device=anchors_per_image.device) + ) + continue + + gt_boxes = targets_per_image["boxes"] + gt_centers = (gt_boxes[:, :2] + gt_boxes[:, 2:]) / 2 # Nx2 + anchor_centers = (anchors_per_image[:, :2] + anchors_per_image[:, 2:]) / 2 # N + anchor_sizes = anchors_per_image[:, 2] - anchors_per_image[:, 0] + # center sampling: anchor point must be close enough to gt 
center. + pairwise_match = (anchor_centers[:, None, :] - gt_centers[None, :, :]).abs_().max( + dim=2 + ).values < self.center_sampling_radius * anchor_sizes[:, None] + # compute pairwise distance between N points and M boxes + x, y = anchor_centers.unsqueeze(dim=2).unbind(dim=1) # (N, 1) + x0, y0, x1, y1 = gt_boxes.unsqueeze(dim=0).unbind(dim=2) # (1, M) + pairwise_dist = torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2) # (N, M) + + # anchor point must be inside gt + pairwise_match &= pairwise_dist.min(dim=2).values > 0 + + # each anchor is only responsible for certain scale range. + lower_bound = anchor_sizes * 4 + lower_bound[: num_anchors_per_level[0]] = 0 + upper_bound = anchor_sizes * 8 + upper_bound[-num_anchors_per_level[-1] :] = float("inf") + pairwise_dist = pairwise_dist.max(dim=2).values + pairwise_match &= (pairwise_dist > lower_bound[:, None]) & (pairwise_dist < upper_bound[:, None]) + + # match the GT box with minimum area, if there are multiple GT matches + gt_areas = (gt_boxes[:, 2] - gt_boxes[:, 0]) * (gt_boxes[:, 3] - gt_boxes[:, 1]) # M + pairwise_match = pairwise_match.to(torch.float32) * (1e8 - gt_areas[None, :]) + min_values, matched_idx = pairwise_match.max(dim=1) # R, per-anchor match + matched_idx[min_values < 1e-5] = -1 # unmatched anchors are assigned -1 + + matched_idxs.append(matched_idx) + + return self.head.compute_loss(targets, head_outputs, anchors, matched_idxs) + + def postprocess_detections( + self, head_outputs: Dict[str, List[Tensor]], anchors: List[List[Tensor]], image_shapes: List[Tuple[int, int]] + ) -> List[Dict[str, Tensor]]: + class_logits = head_outputs["cls_logits"] + box_regression = head_outputs["bbox_regression"] + box_ctrness = head_outputs["bbox_ctrness"] + + num_images = len(image_shapes) + + detections: List[Dict[str, Tensor]] = [] + + for index in range(num_images): + box_regression_per_image = [br[index] for br in box_regression] + logits_per_image = [cl[index] for cl in class_logits] + box_ctrness_per_image = [bc[index] for bc in box_ctrness] + anchors_per_image, image_shape = anchors[index], image_shapes[index] + + image_boxes = [] + image_scores = [] + image_labels = [] + + for box_regression_per_level, logits_per_level, box_ctrness_per_level, anchors_per_level in zip( + box_regression_per_image, logits_per_image, box_ctrness_per_image, anchors_per_image + ): + num_classes = logits_per_level.shape[-1] + + # remove low scoring boxes + scores_per_level = torch.sqrt( + torch.sigmoid(logits_per_level) * torch.sigmoid(box_ctrness_per_level) + ).flatten() + keep_idxs = scores_per_level > self.score_thresh + scores_per_level = scores_per_level[keep_idxs] + topk_idxs = torch.where(keep_idxs)[0] + + # keep only topk scoring predictions + num_topk = min(self.topk_candidates, topk_idxs.size(0)) + scores_per_level, idxs = scores_per_level.topk(num_topk) + topk_idxs = topk_idxs[idxs] + + anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode="floor") + labels_per_level = topk_idxs % num_classes + + boxes_per_level = self.box_coder.decode_single( + box_regression_per_level[anchor_idxs], anchors_per_level[anchor_idxs] + ) + boxes_per_level = box_ops.clip_boxes_to_image(boxes_per_level, image_shape) + + image_boxes.append(boxes_per_level) + image_scores.append(scores_per_level) + image_labels.append(labels_per_level) + + image_boxes = torch.cat(image_boxes, dim=0) + image_scores = torch.cat(image_scores, dim=0) + image_labels = torch.cat(image_labels, dim=0) + + # non-maximum suppression + keep = box_ops.batched_nms(image_boxes, 
image_scores, image_labels, self.nms_thresh) + keep = keep[: self.detections_per_img] + + detections.append( + { + "boxes": image_boxes[keep], + "scores": image_scores[keep], + "labels": image_labels[keep], + } + ) + + return detections + + def forward( + self, + images: List[Tensor], + targets: Optional[List[Dict[str, Tensor]]] = None, + ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]: + """ + Args: + images (list[Tensor]): images to be processed + targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional) + + Returns: + result (list[BoxList] or dict[Tensor]): the output from the model. + During training, it returns a dict[Tensor] which contains the losses. + During testing, it returns list[BoxList] contains additional fields + like `scores`, `labels` and `mask` (for Mask R-CNN models). + """ + if self.training: + if targets is None: + raise ValueError("In training mode, targets should be passed") + for target in targets: + boxes = target["boxes"] + if isinstance(boxes, torch.Tensor): + if len(boxes.shape) != 2 or boxes.shape[-1] != 4: + raise ValueError(f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.") + else: + raise ValueError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.") + + original_image_sizes: List[Tuple[int, int]] = [] + for img in images: + val = img.shape[-2:] + assert len(val) == 2 + original_image_sizes.append((val[0], val[1])) + + # transform the input + images, targets = self.transform(images, targets) + + # Check for degenerate boxes + if targets is not None: + for target_idx, target in enumerate(targets): + boxes = target["boxes"] + degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] + if degenerate_boxes.any(): + # print the first degenerate box + bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0] + degen_bb: List[float] = boxes[bb_idx].tolist() + raise ValueError( + "All bounding boxes should have positive height and width." + f" Found invalid box {degen_bb} for target at index {target_idx}." 
+ ) + + # get the features from the backbone + features = self.backbone(images.tensors) + if isinstance(features, torch.Tensor): + features = OrderedDict([("0", features)]) + + features = list(features.values()) + + # compute the fcos heads outputs using the features + head_outputs = self.head(features) + + # create the set of anchors + anchors = self.anchor_generator(images, features) + # recover level sizes + num_anchors_per_level = [x.size(2) * x.size(3) for x in features] + + losses = {} + detections: List[Dict[str, Tensor]] = [] + if self.training: + assert targets is not None + + # compute the losses + losses = self.compute_loss(targets, head_outputs, anchors, num_anchors_per_level) + else: + # split outputs per level + split_head_outputs: Dict[str, List[Tensor]] = {} + for k in head_outputs: + split_head_outputs[k] = list(head_outputs[k].split(num_anchors_per_level, dim=1)) + split_anchors = [list(a.split(num_anchors_per_level)) for a in anchors] + + # compute the detections + detections = self.postprocess_detections(split_head_outputs, split_anchors, images.image_sizes) + detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) + + if torch.jit.is_scripting(): + if not self._has_warned: + warnings.warn("FCOS always returns a (Losses, Detections) tuple in scripting") + self._has_warned = True + return losses, detections + return self.eager_outputs(losses, detections) + + +model_urls = { + "fcos_resnet50_fpn_coco": "https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth", +} + + +def fcos_resnet50_fpn( + pretrained: bool = False, + progress: bool = True, + num_classes: int = 91, + pretrained_backbone: bool = True, + trainable_backbone_layers: Optional[int] = None, + **kwargs, +): + """ + Constructs a FCOS model with a ResNet-50-FPN backbone. + + Reference: `"FCOS: Fully Convolutional One-Stage Object Detection" `_. + + The input to the model is expected to be a list of tensors, each of shape ``[C, H, W]``, one for each + image, and should be in ``0-1`` range. Different images can have different sizes. + + The behavior of the model changes depending if it is in training or evaluation mode. + + During training, the model expects both the input tensors, as well as a targets (list of dictionary), + containing: + + - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (``Int64Tensor[N]``): the class label for each ground-truth box + + The model returns a ``Dict[Tensor]`` during training, containing the classification and regression + losses. + + During inference, the model requires only the input tensors, and returns the post-processed + predictions as a ``List[Dict[Tensor]]``, one for each input image. The fields of the ``Dict`` are as + follows, where ``N`` is the number of detections: + + - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (``Int64Tensor[N]``): the predicted labels for each detection + - scores (``Tensor[N]``): the scores of each detection + + For more details on the output, you may refer to :ref:`instance_seg_output`. 
+ + Example: + + >>> model = torchvision.models.detection.fcos_resnet50_fpn(pretrained=True) + >>> model.eval() + >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + >>> predictions = model(x) + + Args: + pretrained (bool): If True, returns a model pre-trained on COCO train2017 + progress (bool): If True, displays a progress bar of the download to stderr + num_classes (int): number of output classes of the model (including the background) + pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet + trainable_backbone_layers (int, optional): number of trainable (not frozen) resnet layers starting + from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are + trainable. If ``None`` is passed (the default) this value is set to 3. Default: None + """ + trainable_backbone_layers = _validate_trainable_layers( + pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3 + ) + + if pretrained: + # no need to download the backbone if pretrained is set + pretrained_backbone = False + + backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) + backbone = _resnet_fpn_extractor( + backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) + ) + model = FCOS(backbone, num_classes, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls["fcos_resnet50_fpn_coco"], progress=progress) + model.load_state_dict(state_dict) + return model diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index d5cdf39d20f..33a48995869 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -13,6 +13,7 @@ from .deform_conv import deform_conv2d, DeformConv2d from .feature_pyramid_network import FeaturePyramidNetwork from .focal_loss import sigmoid_focal_loss +from .generalized_box_iou_loss import generalized_box_iou_loss from .misc import FrozenBatchNorm2d, ConvNormActivation, SqueezeExcitation from .poolers import MultiScaleRoIAlign from .ps_roi_align import ps_roi_align, PSRoIAlign @@ -52,4 +53,5 @@ "FrozenBatchNorm2d", "ConvNormActivation", "SqueezeExcitation", + "generalized_box_iou_loss", ] diff --git a/torchvision/ops/generalized_box_iou_loss.py b/torchvision/ops/generalized_box_iou_loss.py new file mode 100644 index 00000000000..1ac9433250d --- /dev/null +++ b/torchvision/ops/generalized_box_iou_loss.py @@ -0,0 +1,71 @@ +import torch + + +def generalized_box_iou_loss( + boxes1: torch.Tensor, + boxes2: torch.Tensor, + reduction: str = "none", + eps: float = 1e-7, +) -> torch.Tensor: + """ + Original implementation from + https://github.com/facebookresearch/fvcore/blob/bfff2ef/fvcore/nn/giou_loss.py + + Gradient-friendly IoU loss with an additional penalty that is non-zero when the + boxes do not overlap and scales with the size of their smallest enclosing box. + This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable. + + Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with + ``0 <= x1 < x2`` and ``0 <= y1 < y2``, and The two boxes should have the + same dimensions. + + Args: + boxes1 (Tensor[N, 4] or Tensor[4]): first set of boxes + boxes2 (Tensor[N, 4] or Tensor[4]): second set of boxes + reduction (string, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: No reduction will be + applied to the output. ``'mean'``: The output will be averaged. 
+ ``'sum'``: The output will be summed. Default: ``'none'`` + eps (float, optional): small number to prevent division by zero. Default: 1e-7 + + Reference: + Hamid Rezatofighi et. al: Generalized Intersection over Union: + A Metric and A Loss for Bounding Box Regression: + https://arxiv.org/abs/1902.09630 + """ + + x1, y1, x2, y2 = boxes1.unbind(dim=-1) + x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) + + assert (x2 >= x1).all(), "bad box: x1 larger than x2" + assert (y2 >= y1).all(), "bad box: y1 larger than y2" + + # Intersection keypoints + xkis1 = torch.max(x1, x1g) + ykis1 = torch.max(y1, y1g) + xkis2 = torch.min(x2, x2g) + ykis2 = torch.min(y2, y2g) + + intsctk = torch.zeros_like(x1) + mask = (ykis2 > ykis1) & (xkis2 > xkis1) + intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask]) + unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk + iouk = intsctk / (unionk + eps) + + # smallest enclosing box + xc1 = torch.min(x1, x1g) + yc1 = torch.min(y1, y1g) + xc2 = torch.max(x2, x2g) + yc2 = torch.max(y2, y2g) + + area_c = (xc2 - xc1) * (yc2 - yc1) + miouk = iouk - ((area_c - unionk) / (area_c + eps)) + + loss = 1 - miouk + + if reduction == "mean": + loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum() + elif reduction == "sum": + loss = loss.sum() + + return loss diff --git a/torchvision/prototype/models/detection/__init__.py b/torchvision/prototype/models/detection/__init__.py index 13edbf75575..4146651c737 100644 --- a/torchvision/prototype/models/detection/__init__.py +++ b/torchvision/prototype/models/detection/__init__.py @@ -1,4 +1,5 @@ from .faster_rcnn import * +from .fcos import * from .keypoint_rcnn import * from .mask_rcnn import * from .retinanet import * diff --git a/torchvision/prototype/models/detection/fcos.py b/torchvision/prototype/models/detection/fcos.py new file mode 100644 index 00000000000..d1f7f9ba361 --- /dev/null +++ b/torchvision/prototype/models/detection/fcos.py @@ -0,0 +1,79 @@ +from typing import Any, Optional + +from torchvision.prototype.transforms import CocoEval +from torchvision.transforms.functional import InterpolationMode + +from ....models.detection.fcos import ( + _resnet_fpn_extractor, + _validate_trainable_layers, + FCOS, + LastLevelP6P7, + misc_nn_ops, +) +from .._api import WeightsEnum, Weights +from .._meta import _COCO_CATEGORIES +from .._utils import handle_legacy_interface, _ovewrite_value_param +from ..resnet import ResNet50_Weights, resnet50 + + +__all__ = [ + "FCOS", + "FCOS_ResNet50_FPN_Weights", + "fcos_resnet50_fpn", +] + + +class FCOS_ResNet50_FPN_Weights(WeightsEnum): + COCO_V1 = Weights( + url="https://download.pytorch.org/models/fcos_resnet50_fpn_coco-99b0c9b7.pth", + transforms=CocoEval, + meta={ + "task": "image_object_detection", + "architecture": "FCOS", + "publication_year": 2019, + "num_params": 32269600, + "categories": _COCO_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#fcos-resnet-50-fpn", + "map": 39.2, + }, + ) + default = COCO_V1 + + +@handle_legacy_interface( + weights=("pretrained", FCOS_ResNet50_FPN_Weights.COCO_V1), + weights_backbone=("pretrained_backbone", ResNet50_Weights.ImageNet1K_V1), +) +def fcos_resnet50_fpn( + *, + weights: Optional[FCOS_ResNet50_FPN_Weights] = None, + progress: bool = True, + num_classes: Optional[int] = None, + weights_backbone: Optional[ResNet50_Weights] = None, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> FCOS: + weights 
= FCOS_ResNet50_FPN_Weights.verify(weights) + weights_backbone = ResNet50_Weights.verify(weights_backbone) + + if weights is not None: + weights_backbone = None + num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + elif num_classes is None: + num_classes = 91 + + trainable_backbone_layers = _validate_trainable_layers( + weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3 + ) + + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) + backbone = _resnet_fpn_extractor( + backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) + ) + model = FCOS(backbone, num_classes, **kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + + return model
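
Reviewer note (not part of the patch): a minimal sketch of how the two new public pieces in this change fit together — the `BoxLinearCoder` added to `torchvision.models.detection._utils` and the new `torchvision.ops.generalized_box_iou_loss` op used by `FCOSHead.compute_loss`. It assumes the patch is applied to a torchvision checkout; `_utils` is a private module and is imported here only for illustration, and the box values are made-up toy numbers.

```
import torch
from torchvision.models.detection import _utils as det_utils
from torchvision.ops import generalized_box_iou_loss

# FCOS-style linear coder: the deltas are the distances from the center of the
# reference (anchor) box to the four edges of the target box, optionally
# normalized by the reference box size.
box_coder = det_utils.BoxLinearCoder(normalize_by_size=True)

anchors = torch.tensor([[10.0, 10.0, 20.0, 20.0], [30.0, 30.0, 50.0, 50.0]])
gt_boxes = torch.tensor([[8.0, 9.0, 22.0, 25.0], [28.0, 31.0, 55.0, 49.0]])

# encode_single(reference_boxes, proposals) -> (N, 4) left/top/right/bottom deltas
rel_codes = box_coder.encode_single(anchors, gt_boxes)
# decode_single inverts the transform, recovering gt_boxes from the deltas
decoded = box_coder.decode_single(rel_codes, anchors)
assert torch.allclose(decoded, gt_boxes, atol=1e-4)

# GIoU loss between (decoded) predictions and targets, as done in
# FCOSHead.compute_loss; a perfect round trip gives a loss of ~0 per box.
loss = generalized_box_iou_loss(decoded, gt_boxes, reduction="sum")
print(loss.item())
```

Because the coder uses the same left/top/right/bottom parameterization as the FCOS paper, encode and decode are exact inverses, which is why the same `BoxLinearCoder` instance can be shared between loss-target computation and box decoding in `postprocess_detections`.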