
Replace MobileNetV3's SqueezeExcitation with EfficientNet's one #4487

Merged: 9 commits merged into pytorch:main on Sep 29, 2021

Conversation

@datumbox (Contributor) commented Sep 27, 2021

Fixes #4455

Partially resolves #4333

All model validation stats remain the same:

mobilenet_v3_large
torchrun --nproc_per_node=2 train.py --model mobilenet_v3_large --test-only --pretrained

Main Branch:
Test:  Acc@1 74.042 Acc@5 91.340

PR:
Test:  Acc@1 74.042 Acc@5 91.340

mobilenet_v3_small
torchrun --nproc_per_node=2 train.py --model mobilenet_v3_small --test-only --pretrained

Main Branch:
Test:  Acc@1 67.668 Acc@5 87.402

PR:
Test:  Acc@1 67.668 Acc@5 87.402

quantized mobilenet_v3_large
python -u train_quantization.py --device cpu --model mobilenet_v3_large --test-only

Main Branch:
Test:  Acc@1 73.004 Acc@5 90.858

PR:
Test:  Acc@1 73.004 Acc@5 90.858

ssd300_vgg16
torchrun --nproc_per_node=2 train.py --dataset coco --model ssd300_vgg16 --pretrained --test-only

Main Branch:
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.251
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.415
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.262
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.055
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.268
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.435
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.239
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.344
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.365
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.088
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.406
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.602

PR:
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.251
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.415
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.262
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.055
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.268
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.435
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.239
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.344
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.365
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.088
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.406
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.602

lraspp_mobilenet_v3_large
torchrun --nproc_per_node=2 train.py --dataset coco --model lraspp_mobilenet_v3_large --pretrained --test-only

Main Branch:
global correct: 91.2
average row correct: ['94.5', '84.3', '69.5', '72.8', '57.7', '42.0', '77.0', '57.0', '90.4', '36.1', '76.0', '60.8', '81.4', '78.9', '81.0', '87.6', '51.3', '83.9', '62.2', '84.2', '56.1']
IoU: ['90.2', '69.2', '57.7', '58.5', '47.8', '35.7', '69.5', '47.1', '79.1', '29.6', '62.6', '34.2', '65.5', '63.4', '70.0', '76.8', '30.1', '61.9', '46.8', '70.6', '49.1']
mean IoU: 57.9

PR:
global correct: 91.2
average row correct: ['94.5', '84.3', '69.5', '72.8', '57.7', '42.0', '77.0', '57.0', '90.4', '36.1', '76.0', '60.8', '81.4', '78.9', '81.0', '87.6', '51.3', '83.9', '62.2', '84.2', '56.1']
IoU: ['90.2', '69.2', '57.7', '58.5', '47.8', '35.7', '69.5', '47.1', '79.1', '29.6', '62.6', '34.2', '65.5', '63.4', '70.0', '76.8', '30.1', '61.9', '46.8', '70.6', '49.1']
mean IoU: 57.9
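The parity above is expected: with shared weights, the functional and module forms of the pooling and activations compute identical values. A minimal sketch (hypothetical layer sizes, not torchvision's actual code) comparing an old functional-style SE forward pass against the module-style one:

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)

# Shared 1x1 convs, as both SE variants load the same pretrained weights.
fc1 = nn.Conv2d(8, 2, kernel_size=1)
fc2 = nn.Conv2d(2, 8, kernel_size=1)

def se_functional(x):
    # Old style: functional pooling/activations (invisible to eager-mode QAT).
    s = F.adaptive_avg_pool2d(x, 1)
    s = F.relu(fc1(s))
    s = F.hardsigmoid(fc2(s))
    return x * s

avgpool = nn.AdaptiveAvgPool2d(1)
relu = nn.ReLU()
hardsigmoid = nn.Hardsigmoid()

def se_module_style(x):
    # New style: nn.Module equivalents, so QAT can attach observers to them.
    s = avgpool(x)
    s = relu(fc1(s))
    s = hardsigmoid(fc2(s))
    return x * s

x = torch.randn(1, 8, 4, 4)
print(torch.equal(se_functional(x), se_module_style(x)))  # True
```

The two paths dispatch to the same underlying kernels, so the outputs match bit-for-bit, which is why the validation stats are unchanged.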

@datumbox force-pushed the models/replace_se branch 2 times, most recently from 69d462e to e269817, on September 27, 2021 20:36
@datumbox (Contributor, Author) left a comment:
Highlighting some interesting bits of the implementation.

@@ -107,13 +110,13 @@ def _mobilenet_v3_model(
         torch.quantization.prepare_qat(model, inplace=True)

     if pretrained:
-        _load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress)
+        _load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress, False)
@datumbox commented:
Earlier versions of the SqueezeExcitation class used F.adaptive_avg_pool2d() and F.hardsigmoid() instead of their nn.Module equivalents. Using the latter is advised because QAT can further optimize them.

Loading the old weights is still possible, but the QAT bits of the above two layers will be missing. Passing strict=False allows us to use the previous weights and achieve the same accuracy.
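For reference, the consolidated SE block can be sketched roughly as follows. This is a minimal sketch in the spirit of EfficientNet's SqueezeExcitation; the attribute names and constructor signature are illustrative, not torchvision's exact API:

```python
import torch
from torch import nn

class SqueezeExcitation(nn.Module):
    """Minimal SE block sketch: module-based pooling/activations so that
    quantization-aware training can observe and optimize these layers."""

    def __init__(self, input_channels: int, squeeze_channels: int):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
        self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)
        self.activation = nn.ReLU()
        # MobileNetV3 scales with a hard sigmoid.
        self.scale_activation = nn.Hardsigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scale = self.avgpool(x)
        scale = self.fc1(scale)
        scale = self.activation(scale)
        scale = self.fc2(scale)
        scale = self.scale_activation(scale)
        return scale * x
```

When loading checkpoints produced before such a change, `model.load_state_dict(state_dict, strict=False)` tolerates missing or unexpected QAT-related entries, which is what threading `strict=False` through `_load_weights` enables.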

Two resolved review threads on torchvision/models/mobilenetv3.py.
@datumbox datumbox requested a review from kazhang September 27, 2021 20:45
@datumbox datumbox marked this pull request as ready for review September 27, 2021 20:46
@datumbox datumbox changed the title [WIP] Replace MobileNetV3's SqueezeExcitation with EfficientNet's one Replace MobileNetV3's SqueezeExcitation with EfficientNet's one Sep 27, 2021
@kazhang (Contributor) left a comment:

LGTM overall. Thanks for working on this!
I only have a question on quantizable module BC.

Resolved (outdated) review thread on torchvision/models/quantization/mobilenetv3.py.
@kazhang (Contributor) left a comment:

Thanks for consolidating the SE layers!

@fmassa (Member) left a comment:

Thanks for the PR!

I've left one comment suggesting what I think would be a better way of handling BC in the quantized model.

I'm approving the PR now, as I would be OK merging it as it currently stands.

Resolved (outdated) review thread on torchvision/models/quantization/mobilenetv3.py.
@datumbox datumbox merged commit ff126ae into pytorch:main Sep 29, 2021
@datumbox datumbox deleted the models/replace_se branch September 29, 2021 14:34
facebook-github-bot pushed a commit that referenced this pull request Sep 30, 2021:
Replace MobileNetV3's SqueezeExcitation with EfficientNet's one (#4487)

Summary:
* Reuse EfficientNet SE layer.

* Deprecating the mobilenetv3.SqueezeExcitation layer.

* Passing the right activation on quantization.

* Making strict named param.

* Set default params if missing.

* Fixing typos.

Reviewed By: datumbox

Differential Revision: D31270916

fbshipit-source-id: bd10285771f12f61f9b0d0a5487e8ae7aae0a2fc
cyyever pushed a commit to cyyever/vision that referenced this pull request Nov 16, 2021:
Replace MobileNetV3's SqueezeExcitation with EfficientNet's one (pytorch#4487)
