Add Quantizable MobilenetV3 architecture for Classification #3323

Merged

merged 10 commits into pytorch:master from mobilenetv3_quantized Feb 2, 2021

Conversation

Contributor

@datumbox datumbox commented Jan 29, 2021

The pre-trained model was trained with:

python -m torch.distributed.launch --nproc_per_node=8 --use_env train_quantization.py\
--model='mobilenet_v3_large' --wd 0.00001 --lr 0.001

Submitted batch job 35496554

Validated with:

python train_quantization.py --device='cpu' --model='mobilenet_v3_large' --test-only

Accuracy metrics (Epoch 89):
Acc@1 73.004 Acc@5 90.858

Speed Benchmark: 0.0162 sec per image on CPU
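For reference, a minimal usage sketch of the resulting model, assuming the entry point mirrors the existing quantized mobilenet_v2 (i.e. torchvision.models.quantization.mobilenet_v3_large with a quantize flag):

import time
import torch
from torchvision.models.quantization import mobilenet_v3_large

# Load the INT8 model with the pre-trained quantized weights (assumed API,
# mirroring the existing quantized mobilenet_v2 entry point).
model = mobilenet_v3_large(pretrained=True, quantize=True)
model.eval()

# Rough per-image CPU latency check on a dummy input.
x = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    start = time.time()
    for _ in range(100):
        model(x)
    print('sec per image:', (time.time() - start) / 100)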

@datumbox datumbox force-pushed the mobilenetv3_quantized branch from 1ab143e to a4ec036 on January 29, 2021 11:26

codecov bot commented Jan 29, 2021

Codecov Report

Merging #3323 (bc27744) into master (859a535) will increase coverage by 0.13%.
The diff coverage is 92.85%.


@@            Coverage Diff             @@
##           master    #3323      +/-   ##
==========================================
+ Coverage   73.90%   74.04%   +0.13%     
==========================================
  Files         104      105       +1     
  Lines        9618     9692      +74     
  Branches     1544     1554      +10     
==========================================
+ Hits         7108     7176      +68     
- Misses       2028     2033       +5     
- Partials      482      483       +1     
Impacted Files                                    Coverage Δ
torchvision/models/quantization/mobilenetv3.py    92.42% <92.42%> (ø)
torchvision/models/mobilenetv3.py                  92.42% <93.33%> (-0.38%) ⬇️
torchvision/models/quantization/mobilenet.py      100.00% <100.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@datumbox datumbox force-pushed the mobilenetv3_quantized branch from 44dd528 to f3ddbf5 on January 29, 2021 12:29
@datumbox datumbox force-pushed the mobilenetv3_quantized branch from f3ddbf5 to 4e03a0b on January 29, 2021 12:32
@datumbox datumbox mentioned this pull request Jan 29, 2021
@datumbox datumbox changed the title [WIP] Add Quantizable MobilenetV3 architecture for Classification Add Quantizable MobilenetV3 architecture for Classification Feb 2, 2021
@datumbox datumbox requested a review from fmassa February 2, 2021 15:31
Member

@fmassa fmassa left a comment

Looks great, thanks a lot!

I have a few questions for @raghuramank100 which are not blocking for merging this PR, but I would love to get his thoughts on a couple of points.

Comment on lines +101 to +107
model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
torch.quantization.prepare_qat(model, inplace=True)

if pretrained:
    _load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress)

torch.quantization.convert(model, inplace=True)
Member

@raghuramank100 the approach we had to follow here for loading the pre-trained weights for a quantized model is different from what we did for the other models. Would you happen to know why we can't use the previous approach?

Contributor

Can you elaborate on why the other approach didn't work? For QAT, we should load the fp32 model (prior to prepare) and then start training with that.

Contributor Author

For QAT, we should load the fp32 model (prior to prepare) and then start training with that.

We did not face any issues during the training process and what you describe is the approach we used for training the model.

Can you elaborate on why the other approach didn't work?

After the training was completed, we tried to load the weights of the quantized model (key "model_eval" of the checkpoint) to do inference. Unfortunately, doing so leads to extremely low accuracy (less than 1%). We are certain that the weights are loaded onto the model.

As a workaround, we opted for loading the weights of the QAT model (key "model" of the checkpoint) and then converting it. This works fine and gives the same inference accuracy as observed during training.

We are trying to understand what could be the reason for this behaviour. I can provide a demo script if that helps.
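For illustration, here is a rough sketch of that workaround, assuming the checkpoint layout produced by references/classification/train_quantization.py and the quantizable entry point added in this PR (the checkpoint path is hypothetical):

import torch
from torchvision.models.quantization import mobilenet_v3_large

# Hypothetical checkpoint produced by train_quantization.py, with the QAT
# state dict under the "model" key and the converted one under "model_eval".
checkpoint = torch.load('model_89.pth', map_location='cpu')

# Rebuild the float model and prepare it for QAT exactly as during training.
model = mobilenet_v3_large(pretrained=False, quantize=False)
model.fuse_model()
model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
torch.quantization.prepare_qat(model, inplace=True)

# Workaround: load the QAT weights (key "model") and convert afterwards,
# instead of loading "model_eval" into an already converted model.
model.load_state_dict(checkpoint['model'])
model.eval()
torch.quantization.convert(model, inplace=True)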

Model                  Acc@1     Acc@5
====================   =======   =======
MobileNet V2           71.658    90.150
MobileNet V3 Large     73.004    90.858
Member

@raghuramank100 this is ~ 1 acc@1 point drop compared to the fp32 reference. Would you have any tips on how to make this gap smaller?

Contributor Author

@fmassa The non-quantized version of MobileNet V3 Large uses averaging of checkpoints, which I don't do here. That's possibly one of the reasons we get lower accuracy.
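For illustration, a minimal sketch of such a post-training averaging step (the checkpoint files and the averaging window are hypothetical, not the exact recipe used for the fp32 model):

import torch

# Hypothetical list of late-epoch checkpoints to average.
paths = ['model_85.pth', 'model_86.pth', 'model_87.pth', 'model_88.pth', 'model_89.pth']
state_dicts = [torch.load(p, map_location='cpu')['model'] for p in paths]

# Average each parameter/buffer element-wise across the checkpoints.
averaged = {}
for key in state_dicts[0]:
    stacked = torch.stack([sd[key].float() for sd in state_dicts])
    averaged[key] = stacked.mean(dim=0).to(state_dicts[0][key].dtype)

torch.save({'model': averaged}, 'model_averaged.pth')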

Contributor

If you start quantization aware training from the averaged checkpoint, you should get better accuracy as the starting point is better.

Contributor

@raghuramank100 raghuramank100 Feb 3, 2021

Also, one additional hyper-parameter that helps is to turn on QAT in steps: we first turn observers on (i.e. collect statistics), then turn fake-quantization on, and after some time we turn batch norm off. Currently, in train_quantization, steps 1 and 2 are combined. We have seen that separating them helps with QAT accuracy in some models. You could try something like:

# Initially only turn on observers, disable fake quant
model.apply(torch.quantization.enable_observer)
model.apply(torch.quantization.disable_fake_quant)
...

# Inside the training loop:
if epoch >= args.num_fake_quant_start_epochs:
    model.apply(torch.quantization.enable_fake_quant)
if epoch >= args.num_observer_update_epochs:
    print('Disabling observer for subseq epochs, epoch = ', epoch)
    model.apply(torch.quantization.disable_observer)
if epoch >= args.num_batch_norm_update_epochs:
    print('Freezing BN for subseq epochs, epoch = ', epoch)
    model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

Contributor Author

If you start quantization aware training from the averaged checkpoint, you should get better accuracy as the starting point is better.

We indeed start from an averaged checkpoint but that's not what I mean here. I'm referring to the post-training averaging step which is missing.

We first turn observers on (i.e. collect statistics) and then turn fake-quantization on.

That's worth integrating into the new quant training script.

I believe the key reason why the accuracy is lagging is that the quant training script does not currently support all the enhancements made to the classification training script. These enhancements (multiple restarts, optimizer tuning, data augmentation, model averaging at the end, etc.) helped me push the accuracy up by 2 points.

Member

fmassa commented Feb 2, 2021

BTW, test failures seem to be related

Contributor Author

datumbox commented Feb 2, 2021

@fmassa Thanks for flagging. I fixed the issues and did 1 epoch of training plus 1 validation run to ensure everything still works fine. The rest of the failing tests are not related to this PR, so we should be OK.

@datumbox datumbox merged commit 8317295 into pytorch:master Feb 2, 2021
@datumbox datumbox deleted the mobilenetv3_quantized branch February 2, 2021 17:36
backend = 'qnnpack'

model.fuse_model()
model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
Contributor

One suggestion for improvement is to change the qconfig to the following. This configuration uses per-channel quantization for weights with qnnpack, which is now supported:

from torch.quantization import (QConfig, FakeQuantize, MovingAverageMinMaxObserver,
                                default_per_channel_weight_fake_quant)

qconfig = QConfig(
    activation=FakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver,
        quant_min=0,
        quant_max=255,
        reduce_range=False),
    weight=default_per_channel_weight_fake_quant)
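For context, this qconfig would be assigned in place of the get_default_qat_qconfig call shown above (e.g. model.qconfig = qconfig before torch.quantization.prepare_qat(model, inplace=True)).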

facebook-github-bot pushed a commit that referenced this pull request Feb 4, 2021
Summary:
* Refactoring mobilenetv3 to make code reusable.

* Adding quantizable MobileNetV3 architecture.

* Fix bug on reference script.

* Moving documentation of quantized models in the right place.

* Update documentation.

* Workaround for loading correct weights of quant model.

* Update weight URL and readme.

* Adding eval.

Reviewed By: datumbox

Differential Revision: D26226613

fbshipit-source-id: 050d53d91abf68975f2dc3ede8db633a08b33a25