Skip to content

Commit

Permalink
[fbsync] add FCOS (#4961)
Browse files Browse the repository at this point in the history
Summary:
* add fcos

* update fcos

* add giou_loss

* add BoxLinearCoder for FCOS

* add full code for FCOS

* add giou loss

* add fcos

* add __all__

* Fixing lint

* Fixing lint in giou_loss.py

* Add typing annotation to fcos

* Add trained checkpoints

* Use partial to replace lambda

* Minor fixes to docstrings

* Apply ufmt format

* Fixing docstrings

* Fixing jit scripting

* Minor fixes to docstrings

* Fixing jit scripting

* Ignore mypy in fcos

* Fixing trained checkpoints

* Fixing unit-test of jit script

* Fixing docstrings

* Add test/expect/ModelTester.test_fcos_resnet50_fpn_expect.pkl

* Fixing test_detection_model_trainable_backbone_layers

* Update test_fcos_resnet50_fpn_expect.pkl

* rename stride to box size

* remove TODO and fix some typo

* merge some code for better

* impove the comments

* remove decode and encode of BoxLinearCoder

* remove some unnecessary hints

* use the default value in detectron2.

* update doc

* Add unittest for BoxLinearCoder

* Add types in FCOS

* Add docstring for BoxLinearCoder

* Minor fix for the docstring

* update doc

* Update fcos_resnet50_fpn_coco pretained weights url

* Update torchvision/models/detection/fcos.py

* Update torchvision/models/detection/fcos.py

* Update torchvision/models/detection/fcos.py

* Update torchvision/models/detection/fcos.py

* Add FCOS model documentation

* Fix typo in FCOS documentation

* Add fcos to the prototype builder

* Capitalize COCO_V1

* Fix params of fcos

* fix bug for partial

* Fixing docs indentation

* Fixing docs format in giou_loss

* Adopt Reference for GIoU Loss

* Rename giou_loss to generalized_box_iou_loss

* remove overwrite_eps

* Update AP test values

* Minor fixes for the docs

* Minor fixes for the docs

* Update torchvision/models/detection/fcos.py

* Update torchvision/prototype/models/detection/fcos.py

Reviewed By: jdsgomes, prabhat00155

Differential Revision: D33739385

fbshipit-source-id: 7dab616adfd0c34fe21f0153c1da51f97ef43b95

Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Zhiqiang Wang <[email protected]>
Co-authored-by: Zhiqiang Wang <[email protected]>
Co-authored-by: zhiqiang <[email protected]>
Co-authored-by: Joao Gomes <[email protected]>
Co-authored-by: Vasilis Vryniotis <[email protected]>
Co-authored-by: Joao Gomes <[email protected]>
  • Loading branch information
5 people authored and facebook-github-bot committed Jan 26, 2022
1 parent 165a270 commit 1a64a9a
Show file tree
Hide file tree
Showing 13 changed files with 979 additions and 0 deletions.
12 changes: 12 additions & 0 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,7 @@ The models subpackage contains definitions for the following model
architectures for detection:

- `Faster R-CNN <https://arxiv.org/abs/1506.01497>`_
- `FCOS <https://arxiv.org/abs/1904.01355>`_
- `Mask R-CNN <https://arxiv.org/abs/1703.06870>`_
- `RetinaNet <https://arxiv.org/abs/1708.02002>`_
- `SSD <https://arxiv.org/abs/1512.02325>`_
Expand Down Expand Up @@ -642,6 +643,7 @@ Network box AP mask AP keypoint AP
Faster R-CNN ResNet-50 FPN 37.0 - -
Faster R-CNN MobileNetV3-Large FPN 32.8 - -
Faster R-CNN MobileNetV3-Large 320 FPN 22.8 - -
FCOS ResNet-50 FPN 39.2 - -
RetinaNet ResNet-50 FPN 36.4 - -
SSD300 VGG16 25.1 - -
SSDlite320 MobileNetV3-Large 21.3 - -
Expand Down Expand Up @@ -702,6 +704,7 @@ Network train time (s / it) test time (s / it)
Faster R-CNN ResNet-50 FPN 0.2288 0.0590 5.2
Faster R-CNN MobileNetV3-Large FPN 0.1020 0.0415 1.0
Faster R-CNN MobileNetV3-Large 320 FPN 0.0978 0.0376 0.6
FCOS ResNet-50 FPN 0.1450 0.0539 3.3
RetinaNet ResNet-50 FPN 0.2514 0.0939 4.1
SSD300 VGG16 0.2093 0.0744 1.5
SSDlite320 MobileNetV3-Large 0.1773 0.0906 1.5
Expand All @@ -721,6 +724,15 @@ Faster R-CNN
torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn
torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn

FCOS
----

.. autosummary::
:toctree: generated/
:template: function.rst

torchvision.models.detection.fcos_resnet50_fpn


RetinaNet
---------
Expand Down
4 changes: 4 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ ignore_errors = True

ignore_errors = True

[mypy-torchvision.models.detection.fcos]

ignore_errors = True

[mypy-torchvision.ops.*]

ignore_errors = True
Expand Down
7 changes: 7 additions & 0 deletions references/detection/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ torchrun --nproc_per_node=8 train.py\
--lr-steps 16 22 --aspect-ratio-group-factor 3
```

### FCOS ResNet-50 FPN
```
torchrun --nproc_per_node=8 train.py\
--dataset coco --model fcos_resnet50_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --amp
```

### RetinaNet
```
torchrun --nproc_per_node=8 train.py\
Expand Down
Binary file not shown.
12 changes: 12 additions & 0 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def _check_input_backprop(model, inputs):
"retinanet_resnet50_fpn": lambda x: x[1],
"ssd300_vgg16": lambda x: x[1],
"ssdlite320_mobilenet_v3_large": lambda x: x[1],
"fcos_resnet50_fpn": lambda x: x[1],
}


Expand Down Expand Up @@ -274,6 +275,13 @@ def _check_input_backprop(model, inputs):
"max_size": 224,
"input_shape": (3, 224, 224),
},
"fcos_resnet50_fpn": {
"num_classes": 2,
"score_thresh": 0.05,
"min_size": 224,
"max_size": 224,
"input_shape": (3, 224, 224),
},
"maskrcnn_resnet50_fpn": {
"num_classes": 10,
"min_size": 224,
Expand Down Expand Up @@ -325,6 +333,10 @@ def _check_input_backprop(model, inputs):
"max_trainable": 6,
"n_trn_params_per_layer": [96, 99, 138, 200, 239, 257, 266],
},
"fcos_resnet50_fpn": {
"max_trainable": 5,
"n_trn_params_per_layer": [54, 64, 83, 96, 106, 107],
},
}


Expand Down
13 changes: 13 additions & 0 deletions test/test_models_detection_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,19 @@ def test_balanced_positive_negative_sampler(self):
assert neg[0].sum() == 3
assert neg[0][0:6].sum() == 3

def test_box_linear_coder(self):
box_coder = _utils.BoxLinearCoder(normalize_by_size=True)
# Generate a random 10x4 boxes tensor, with coordinates < 50.
boxes = torch.rand(10, 4) * 50
boxes.clamp_(min=1.0) # tiny boxes cause numerical instability in box regression
boxes[:, 2:] += boxes[:, :2]

proposals = torch.tensor([0, 0, 101, 101] * 10).reshape(10, 4).float()

rel_codes = box_coder.encode_single(boxes, proposals)
pred_boxes = box_coder.decode_single(rel_codes, boxes)
torch.allclose(proposals, pred_boxes)

@pytest.mark.parametrize("train_layers, exp_froz_params", [(0, 53), (1, 43), (2, 24), (3, 11), (4, 1), (5, 0)])
def test_resnet_fpn_backbone_frozen_layers(self, train_layers, exp_froz_params):
# we know how many initial layers and parameters of the network should
Expand Down
1 change: 1 addition & 0 deletions torchvision/models/detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from .retinanet import *
from .ssd import *
from .ssdlite import *
from .fcos import *
77 changes: 77 additions & 0 deletions torchvision/models/detection/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,83 @@ def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
return pred_boxes


class BoxLinearCoder:
"""
The linear box-to-box transform defined in FCOS. The transformation is parameterized
by the distance from the center of (square) src box to 4 edges of the target box.
"""

def __init__(self, normalize_by_size: bool = True) -> None:
"""
Args:
normalize_by_size (bool): normalize deltas by the size of src (anchor) boxes.
"""
self.normalize_by_size = normalize_by_size

def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
"""
Encode a set of proposals with respect to some reference boxes
Args:
reference_boxes (Tensor): reference boxes
proposals (Tensor): boxes to be encoded
Returns:
Tensor: the encoded relative box offsets that can be used to
decode the boxes.
"""
# get the center of reference_boxes
reference_boxes_ctr_x = 0.5 * (reference_boxes[:, 0] + reference_boxes[:, 2])
reference_boxes_ctr_y = 0.5 * (reference_boxes[:, 1] + reference_boxes[:, 3])

# get box regression transformation deltas
target_l = reference_boxes_ctr_x - proposals[:, 0]
target_t = reference_boxes_ctr_y - proposals[:, 1]
target_r = proposals[:, 2] - reference_boxes_ctr_x
target_b = proposals[:, 3] - reference_boxes_ctr_y

targets = torch.stack((target_l, target_t, target_r, target_b), dim=1)
if self.normalize_by_size:
reference_boxes_w = reference_boxes[:, 2] - reference_boxes[:, 0]
reference_boxes_h = reference_boxes[:, 3] - reference_boxes[:, 1]
reference_boxes_size = torch.stack(
(reference_boxes_w, reference_boxes_h, reference_boxes_w, reference_boxes_h), dim=1
)
targets = targets / reference_boxes_size

return targets

def decode_single(self, rel_codes: Tensor, boxes: Tensor) -> Tensor:
"""
From a set of original boxes and encoded relative box offsets,
get the decoded boxes.
Args:
rel_codes (Tensor): encoded boxes
boxes (Tensor): reference boxes.
Returns:
Tensor: the predicted boxes with the encoded relative box offsets.
"""

boxes = boxes.to(rel_codes.dtype)

ctr_x = 0.5 * (boxes[:, 0] + boxes[:, 2])
ctr_y = 0.5 * (boxes[:, 1] + boxes[:, 3])
if self.normalize_by_size:
boxes_w = boxes[:, 2] - boxes[:, 0]
boxes_h = boxes[:, 3] - boxes[:, 1]
boxes_size = torch.stack((boxes_w, boxes_h, boxes_w, boxes_h), dim=1)
rel_codes = rel_codes * boxes_size

pred_boxes1 = ctr_x - rel_codes[:, 0]
pred_boxes2 = ctr_y - rel_codes[:, 1]
pred_boxes3 = ctr_x + rel_codes[:, 2]
pred_boxes4 = ctr_y + rel_codes[:, 3]
pred_boxes = torch.stack((pred_boxes1, pred_boxes2, pred_boxes3, pred_boxes4), dim=1)
return pred_boxes


class Matcher:
"""
This class assigns to each predicted "element" (e.g., a box) a ground-truth
Expand Down
Loading

0 comments on commit 1a64a9a

Please sign in to comment.