Skip to content

Commit

Permalink
Add High-res FasterRCNN MobileNetV3 and tune Low-res for speed (#3265)
Browse files Browse the repository at this point in the history
* Tag fasterrcnn mobilenetv3 model with 320, add new inference config that makes it 2x faster sacrificing a bit of mAP.

* Add a high resolution fasterrcnn mobilenetv3 model.

* Update tests and expected values.
  • Loading branch information
datumbox authored Jan 19, 2021
1 parent f7fae49 commit ae0d80b
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 45 deletions.
37 changes: 20 additions & 17 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -358,14 +358,15 @@ models return the predictions of the following classes:
Here are the summary of the accuracies for the models trained on
the instances set of COCO train2017 and evaluated on COCO val2017.

================================== ======= ======== ===========
Network box AP mask AP keypoint AP
================================== ======= ======== ===========
Faster R-CNN ResNet-50 FPN 37.0 - -
Faster R-CNN MobileNetV3-Large FPN 23.0 - -
RetinaNet ResNet-50 FPN 36.4 - -
Mask R-CNN ResNet-50 FPN 37.9 34.6 -
================================== ======= ======== ===========
====================================== ======= ======== ===========
Network box AP mask AP keypoint AP
====================================== ======= ======== ===========
Faster R-CNN ResNet-50 FPN 37.0 - -
Faster R-CNN MobileNetV3-Large FPN 32.8 - -
Faster R-CNN MobileNetV3-Large 320 FPN 22.8 - -
RetinaNet ResNet-50 FPN 36.4 - -
Mask R-CNN ResNet-50 FPN 37.9 34.6 -
====================================== ======= ======== ===========

For person keypoint detection, the accuracies for the pre-trained
models are as follows
Expand Down Expand Up @@ -415,22 +416,24 @@ For test time, we report the time for the model evaluation and postprocessing
(including mask pasting in image), but not the time for computing the
precision-recall.

================================== =================== ================== ===========
Network train time (s / it) test time (s / it) memory (GB)
================================== =================== ================== ===========
Faster R-CNN ResNet-50 FPN 0.2288 0.0590 5.2
Faster R-CNN MobileNetV3-Large FPN 0.0978 0.0376 0.6
RetinaNet ResNet-50 FPN 0.2514 0.0939 4.1
Mask R-CNN ResNet-50 FPN 0.2728 0.0903 5.4
Keypoint R-CNN ResNet-50 FPN 0.3789 0.1242 6.8
================================== =================== ================== ===========
====================================== =================== ================== ===========
Network train time (s / it) test time (s / it) memory (GB)
====================================== =================== ================== ===========
Faster R-CNN ResNet-50 FPN 0.2288 0.0590 5.2
Faster R-CNN MobileNetV3-Large FPN 0.1020 0.0415 1.0
Faster R-CNN MobileNetV3-Large 320 FPN 0.0978 0.0376 0.6
RetinaNet ResNet-50 FPN 0.2514 0.0939 4.1
Mask R-CNN ResNet-50 FPN 0.2728 0.0903 5.4
Keypoint R-CNN ResNet-50 FPN 0.3789 0.1242 6.8
====================================== =================== ================== ===========


Faster R-CNN
------------

.. autofunction:: torchvision.models.detection.fasterrcnn_resnet50_fpn
.. autofunction:: torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn
.. autofunction:: torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn


RetinaNet
Expand Down
7 changes: 7 additions & 0 deletions references/detection/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--lr-steps 16 22 --aspect-ratio-group-factor 3
```

### Faster R-CNN MobileNetV3-Large 320 FPN
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--dataset coco --model fasterrcnn_mobilenet_v3_large_320_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3
```

### RetinaNet
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
Expand Down
Binary file not shown.
Binary file not shown.
6 changes: 5 additions & 1 deletion test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def get_available_video_models():
'inception_v3': lambda x: x.logits,
"fasterrcnn_resnet50_fpn": lambda x: x[1],
"fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1],
"fasterrcnn_mobilenet_v3_large_320_fpn": lambda x: x[1],
"maskrcnn_resnet50_fpn": lambda x: x[1],
"keypointrcnn_resnet50_fpn": lambda x: x[1],
"retinanet_resnet50_fpn": lambda x: x[1],
Expand Down Expand Up @@ -106,8 +107,11 @@ def _test_detection_model(self, name, dev):
if "retinanet" in name:
# Reduce the default threshold to ensure the returned boxes are not empty.
kwargs["score_thresh"] = 0.01
elif "fasterrcnn_mobilenet" in name:
elif "fasterrcnn_mobilenet_v3_large" in name:
kwargs["box_score_thresh"] = 0.02076
if "fasterrcnn_mobilenet_v3_large_320_fpn" in name:
kwargs["rpn_pre_nms_top_n_test"] = 1000
kwargs["rpn_post_nms_top_n_test"] = 1000
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs)
model.eval().to(device=dev)
input_shape = (3, 300, 300)
Expand Down
3 changes: 2 additions & 1 deletion test/test_models_detection_negative_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ def test_assign_targets_to_proposals(self):
self.assertEqual(labels[0].dtype, torch.int64)

def test_forward_negative_sample_frcnn(self):
for name in ["fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_fpn"]:
for name in ["fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_fpn",
"fasterrcnn_mobilenet_v3_large_320_fpn"]:
model = torchvision.models.detection.__dict__[name](
num_classes=2, min_size=100, max_size=100)

Expand Down
101 changes: 75 additions & 26 deletions torchvision/models/detection/faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@


__all__ = [
"FasterRCNN", "fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_fpn"
"FasterRCNN", "fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_320_fpn",
"fasterrcnn_mobilenet_v3_large_fpn"
]


Expand Down Expand Up @@ -288,8 +289,10 @@ def forward(self, x):
model_urls = {
'fasterrcnn_resnet50_fpn_coco':
'https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth',
'fasterrcnn_mobilenet_v3_large_320_fpn_coco':
'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth',
'fasterrcnn_mobilenet_v3_large_fpn_coco':
'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-907ea3f9.pth',
'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth'
}


Expand Down Expand Up @@ -368,16 +371,38 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
return model


def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
trainable_backbone_layers=None, min_size=320, max_size=640, rpn_score_thresh=0.05,
**kwargs):
def _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=False, progress=True, num_classes=91,
pretrained_backbone=True, trainable_backbone_layers=None, **kwargs):
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3)

if pretrained:
pretrained_backbone = False
backbone = mobilenet_backbone("mobilenet_v3_large", pretrained_backbone, True,
trainable_layers=trainable_backbone_layers)

anchor_sizes = ((32, 64, 128, 256, 512, ), ) * 3
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)

model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios),
**kwargs)
if pretrained:
if model_urls.get(weights_name, None) is None:
raise ValueError("No checkpoint is available for model {}".format(weights_name))
state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
model.load_state_dict(state_dict)
return model


def fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
trainable_backbone_layers=None, **kwargs):
"""
Constructs a Faster R-CNN model with a MobileNetV3-Large FPN backbone. It works similarly
to Faster R-CNN with ResNet-50 FPN backbone. See `fasterrcnn_resnet50_fpn` for more details.
Constructs a low resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone tunned for mobile use-cases.
It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See `fasterrcnn_resnet50_fpn` for more details.
Example::
>>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)
>>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=True)
>>> model.eval()
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
>>> predictions = model(x)
Expand All @@ -389,25 +414,49 @@ def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_class
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
rpn_score_thresh (float): during inference, only return proposals with a classification score
greater than rpn_score_thresh
"""
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3)
weights_name = "fasterrcnn_mobilenet_v3_large_320_fpn_coco"
defaults = {
"min_size": 320,
"max_size": 640,
"rpn_pre_nms_top_n_test": 150,
"rpn_post_nms_top_n_test": 150,
"rpn_score_thresh": 0.05,
}

if pretrained:
pretrained_backbone = False
backbone = mobilenet_backbone("mobilenet_v3_large", pretrained_backbone, True,
trainable_layers=trainable_backbone_layers)
kwargs = {**defaults, **kwargs}
return _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=pretrained, progress=progress,
num_classes=num_classes, pretrained_backbone=pretrained_backbone,
trainable_backbone_layers=trainable_backbone_layers, **kwargs)

anchor_sizes = ((32, 64, 128, 256, 512, ), ) * 3
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)

model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios),
min_size=min_size, max_size=max_size, rpn_score_thresh=rpn_score_thresh, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls['fasterrcnn_mobilenet_v3_large_fpn_coco'], progress=progress)
model.load_state_dict(state_dict)
return model
def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
trainable_backbone_layers=None, **kwargs):
"""
Constructs a high resolution Faster R-CNN model with a MobileNetV3-Large FPN backbone.
It works similarly to Faster R-CNN with ResNet-50 FPN backbone. See `fasterrcnn_resnet50_fpn` for more details.
Example::
>>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)
>>> model.eval()
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
>>> predictions = model(x)
Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
"""
weights_name = "fasterrcnn_mobilenet_v3_large_fpn_coco"
defaults = {
"rpn_score_thresh": 0.05,
}

kwargs = {**defaults, **kwargs}
return _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=pretrained, progress=progress,
num_classes=num_classes, pretrained_backbone=pretrained_backbone,
trainable_backbone_layers=trainable_backbone_layers, **kwargs)

0 comments on commit ae0d80b

Please sign in to comment.