RetinaNet with MobileNetV3 FPN backbone #3223

Merged 14 commits into pytorch:mobilenetv3 on Jan 12, 2021

Conversation

@datumbox (Contributor) commented on Jan 6, 2021

Partially fixes #1999


RetinaNet + MobileNetV3 large + FPN

Trained using the code committed at 7af35c3.

The current temporary pre-trained model was trained with:

python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
 --dataset coco --model retinanet_mobilenet_v3_large_fpn --epochs 26 --lr-steps 16 22\
 --aspect-ratio-group-factor 3 --lr 0.01

Submitted batch job 34643976

We then took the last 2 checkpoints that improved the AP (epochs 22 and 18) and averaged their parameters with the following script:

# from https://github.com/pytorch/fairseq/blob/master/scripts/average_checkpoints.py
import collections
import torch


def average_checkpoints(inputs):
    params_dict = collections.OrderedDict()
    params_keys = None
    new_state = None
    num_models = len(inputs)
    for fpath in inputs:
        with open(fpath, "rb") as f:
            state = torch.load(
                f,
                map_location=(
                    lambda s, _: torch.serialization.default_restore_location(s, "cpu")
                ),
            )
        # Copies over the settings from the first checkpoint
        if new_state is None:
            new_state = state
        model_params = state["model"]
        model_params_keys = list(model_params.keys())
        if params_keys is None:
            params_keys = model_params_keys
        elif params_keys != model_params_keys:
            raise KeyError(
                "For checkpoint {}, expected list of params: {}, "
                "but found: {}".format(fpath, params_keys, model_params_keys)
            )
        for k in params_keys:
            p = model_params[k]
            if isinstance(p, torch.HalfTensor):
                p = p.float()
            if k not in params_dict:
                params_dict[k] = p.clone()
                # NOTE: clone() is needed in case of p is a shared parameter
            else:
                params_dict[k] += p
    averaged_params = collections.OrderedDict()
    for k, v in params_dict.items():
        averaged_params[k] = v
        if averaged_params[k].is_floating_point():
            averaged_params[k].div_(num_models)
        else:
            averaged_params[k] //= num_models
    new_state["model"] = averaged_params
    return new_state


def avg(epochs, filename):
    paths = ["model_{}.pth".format(i) for i in epochs]
    weights = average_checkpoints(paths)
    torch.save(weights, filename.format(len(epochs)))

avg([22, 18], "model_best{}avg.pth")

Accuracy metrics:

0: Test: Total time: 0:00:33 (0.0528 s / it)
0: Averaged stats: model_time: 0.0228 (0.0248)  evaluator_time: 0.0159 (0.0221)
0: Accumulating evaluation results...
0: DONE (t=17.23s).
0: IoU metric: bbox
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.256
0:  Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.423
0:  Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.262
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.061
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.275
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.428
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.248
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.380
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.410
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.145
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.468
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.645 

Validated with:

python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
 --model retinanet_mobilenet_v3_large_fpn --test-only --pretrained

Submitted batch job 34643680

Speed benchmark:
0.74 sec per image on CPU
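
As a rough reference, the benchmark can be approximated with a sketch along the following lines; the constructor name is the one introduced in this PR (it may not be available in released torchvision packages), and the input size and iteration count are assumptions:

import time
import torch
import torchvision

# Constructor lookup for the model added in this PR (branch-only at this point).
model = torchvision.models.detection.__dict__["retinanet_mobilenet_v3_large_fpn"](pretrained=True)
model.eval()

images = [torch.rand(3, 800, 800)]  # one dummy image (assumed size)
with torch.no_grad():
    model(images)  # warm-up pass
    start = time.time()
    for _ in range(10):
        model(images)
print("{:.3f} sec per image on CPU".format((time.time() - start) / 10))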

@datumbox (Contributor, Author) left a comment:

Please review the comments below; they highlight the important changes.

@@ -104,7 +104,7 @@ def _test_detection_model(self, name, dev):
         kwargs = {}
         if "retinanet" in name:
             # Reduce the default threshold to ensure the returned boxes are not empty.
-            kwargs["score_thresh"] = 0.01
+            kwargs["score_thresh"] = 0.0099999
@datumbox commented:
A slight adjustment was necessary to get non-zero results on MobileNetV3.
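
For illustration, the same threshold can also be passed directly when constructing the model; a minimal sketch, assuming the constructor name introduced in this PR:

import torchvision

# score_thresh filters out low-scoring boxes during postprocessing.
model = torchvision.models.detection.__dict__["retinanet_mobilenet_v3_large_fpn"](
    num_classes=2, pretrained_backbone=False, score_thresh=0.0099999)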

            num_classes=2, min_size=100, max_size=100)
    for name in ["retinanet_resnet50_fpn", "retinanet_mobilenet_v3_large_fpn"]:
        model = torchvision.models.detection.__dict__[name](
            num_classes=2, min_size=100, max_size=100, pretrained_backbone=False)
@datumbox commented:

The previous version seemed to download the backbone weights unnecessarily. I fixed this in place by adding
pretrained_backbone=False.


    # Gather the indices of blocks which are strided. These are the locations of the C1, ..., Cn-1 blocks.
    # The first and last blocks are always included because they are the C0 (conv1) and Cn blocks.
    stage_indeces = [0] + [i for i, b in enumerate(backbone) if getattr(b, "is_strided", False)] + [len(backbone) - 1]
@datumbox commented:

Instead of hard-coding metadata with the locations of the blocks that downsample, I detect them by checking a new attribute called is_strided. This attribute is added to both the MobileNetV2 and V3 residual blocks and indicates whether the specific block downsamples. These are typically the locations of the C1...Cn-1 blocks.

Note that the blocks at the first and last positions of the features module are always included.
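
A self-contained toy illustration of this stage-detection logic (DummyBlock is a stand-in for the real inverted residual blocks):

class DummyBlock:
    def __init__(self, is_strided=False):
        self.is_strided = is_strided

# Blocks 1 and 3 downsample; the first and last blocks are always stage boundaries.
backbone = [DummyBlock(), DummyBlock(True), DummyBlock(), DummyBlock(True), DummyBlock()]
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "is_strided", False)] + [len(backbone) - 1]
print(stage_indices)  # [0, 1, 3, 4]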

freeze_before = num_stages if trainable_layers == 0 else stage_indeces[num_stages - trainable_layers]

# freeze layers only if pretrained backbone is used
for b in backbone[:freeze_before]:
@datumbox commented:

Unlike the ResNet implementation, here we need to find the location of the first block that we fine-tune and freeze everything before it.
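
The loop body is truncated in the hunk above; a sketch of the assumed freezing step, following the same requires_grad_ convention as resnet_fpn_backbone:

# Freeze every parameter of the blocks that come before the first trainable stage.
for b in backbone[:freeze_before]:
    for parameter in b.parameters():
        parameter.requires_grad_(False)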

"""
if pretrained:
pretrained_backbone = False
backbone = mobilenet_fpn_backbone("mobilenet_v3_large", pretrained_backbone, returned_layers=[4, 5])
@datumbox commented on Jan 6, 2021:

We use the outputs of blocks C4 and C5 in our feature pyramid.

For C5, we do the same as the paper and use the layer just before pooling.

For C4, we deviate from the original paper, which suggests using "the expansion layer of the 13th bottleneck block". In our case we use the output of the 13th bottleneck, because it's very hard to get the output of the expansion layer without completely refactoring the entire MobileNetV2 and V3 architectures. As a result, our C4 feature map is 160x7x7 instead of 672x14x14.

This could lead to a faster model but might reduce the accuracy metrics. We'll run experiments to assess the difference. Perhaps, instead of completely refactoring the implementation, we could use a workaround based on hooks?
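
A minimal sketch of what such a hook-based workaround could look like; the exact module path to the expansion layer is hypothetical:

# Capture the output of an inner layer (e.g. the expansion conv of the 13th
# bottleneck) without refactoring the architecture.
activations = {}

def save_output(name):
    def hook(module, inputs, output):
        activations[name] = output
    return hook

expansion_layer = backbone[13].block[0]  # hypothetical path to the expansion conv
handle = expansion_layer.register_forward_hook(save_output("C4_expansion"))
# ... run a forward pass, then read activations["C4_expansion"] ...
handle.remove()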

pretrained_backbone = False
backbone = mobilenet_fpn_backbone("mobilenet_v3_large", pretrained_backbone, returned_layers=[4, 5])

anchor_sizes = ((128,), (256,), (512,))
@datumbox commented:

Anchor sizes for C4, C5 and pool. It's important to note that C4 and C5 have the same output stride of 32.
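
For illustration, this is how such sizes would typically be wired into torchvision's AnchorGenerator; the aspect ratios below are an assumption (torchvision's RetinaNet default is (0.5, 1.0, 2.0) per level):

from torchvision.models.detection.anchor_utils import AnchorGenerator

anchor_sizes = ((128,), (256,), (512,))  # one tuple per feature map: C4, C5, pool
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)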

@@ -90,6 +91,8 @@ def __init__(
             norm_layer(oup),
         ])
         self.conv = nn.Sequential(*layers)
+        self.output_channels = oup
+        self.is_strided = stride > 1
@datumbox commented:

Metadata is added to the blocks to make it easier to detect the C1...Cn blocks and their output channels in detection models. We do this on both MobileNetV2 and MobileNetV3.
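
A brief sketch of how this metadata can be consumed (attribute names as added in this PR; getattr skips blocks that lack them):

import torchvision

# Requires a torchvision build that includes mobilenet_v3_large (the branch this PR targets).
backbone = torchvision.models.mobilenet_v3_large(pretrained=False).features
for i, b in enumerate(backbone):
    print(i, getattr(b, "output_channels", None), getattr(b, "is_strided", False))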

@datumbox mentioned this pull request on Jan 6, 2021
@datumbox force-pushed the mobilenetv3/object_detection branch from 217350a to a56fe27 on January 6, 2021 14:35
@datumbox requested a review from fmassa on January 6, 2021 14:49
@datumbox force-pushed the mobilenetv3/object_detection branch from 7d4dd3d to 06e3e72 on January 7, 2021 21:30
@datumbox changed the title from "[WIP] RetinaNet with MobileNetV3 FPN backbone" to "[WIP] RetinaNet with MobileNetV3 backbone" on Jan 8, 2021
@datumbox force-pushed the mobilenetv3/object_detection branch from e211494 to 0419dbc on January 8, 2021 14:07
@datumbox force-pushed the mobilenetv3/object_detection branch from 4a0c7a4 to 6c53bfc on January 8, 2021 14:58
@datumbox force-pushed the mobilenetv3/object_detection branch 2 times, most recently from 5c15a2c to 75933da, on January 9, 2021 10:59
@datumbox changed the title from "[WIP] RetinaNet with MobileNetV3 backbone" to "[WIP] RetinaNet with MobileNetV3 FPN backbone" on Jan 9, 2021
@datumbox force-pushed the mobilenetv3/object_detection branch from 75fa2a6 to 81800cd on January 9, 2021 18:38
@datumbox changed the title from "[WIP] RetinaNet with MobileNetV3 FPN backbone" to "RetinaNet with MobileNetV3 FPN backbone" on Jan 12, 2021
@datumbox merged commit f883796 into pytorch:mobilenetv3 on Jan 12, 2021
@datumbox deleted the mobilenetv3/object_detection branch on January 12, 2021 10:42