Add Quantizable MobilenetV3 architecture for Classification #3323
Conversation
Codecov Report

@@            Coverage Diff             @@
##           master    #3323      +/-   ##
==========================================
+ Coverage   73.90%   74.04%   +0.13%
==========================================
  Files         104      105       +1
  Lines        9618     9692      +74
  Branches     1544     1554      +10
==========================================
+ Hits         7108     7176      +68
- Misses       2028     2033       +5
- Partials      482      483       +1

Continue to review the full report at Codecov.
Looks great, thanks a lot!
I have a few questions for @raghuramank100 which are not blocking for getting this PR merged, but I would love to get his thoughts on a couple of points.
model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
torch.quantization.prepare_qat(model, inplace=True)

if pretrained:
    _load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress)

torch.quantization.convert(model, inplace=True)
@raghuramank100 the approach we had to follow here for loading the pre-trained weights of a quantized model is different from what we did for the other models. Would you happen to know why we can't use the previous approach?
Can you elaborate on why the other approach didn't work? For QAT, we should load the fp32 model (prior to prepare) and then start training with that.
For QAT, we should load the fp32 model (prior to prepare) and then start training with that.
We did not face any issues during the training process and what you describe is the approach we used for training the model.
Can you elaborate on why the other approach didnt work?
After the training was completed, we tried to load the weights of the quantized model (key "model_eval" of the checkpoint) to do inference. Unfortunately, doing so leads to extremely low accuracy (less than 1%). We are certain that the weights are loaded into the model.
As a workaround, we opted for loading the weights of the QAT model (key "model" of the checkpoint) and then converting it. This works fine and gives the same inference accuracy as observed during training.
We are trying to understand what could be the reason for this behaviour. I can provide a demo script if that helps.
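For illustration, a minimal sketch of the workaround described above, assuming a hypothetical checkpoint file whose "model" key holds the QAT (pre-convert) weights as discussed here; this is not the exact code in the PR:

import torch
import torchvision

# Hypothetical checkpoint produced by QAT training; the "model" key holds the
# QAT (pre-convert) weights, as described above.
checkpoint = torch.load('mobilenet_v3_large_qat.pth', map_location='cpu')

# Rebuild the quantizable model and prepare it for QAT, mirroring the training setup.
model = torchvision.models.quantization.mobilenet_v3_large(pretrained=False, quantize=False)
model.fuse_model()
model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
torch.quantization.prepare_qat(model, inplace=True)

# Workaround: load the QAT weights and only then convert to the quantized model.
model.load_state_dict(checkpoint['model'])
model.eval()
torch.quantization.convert(model, inplace=True)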
Model                            Acc@1         Acc@5
================================ ============= =============
MobileNet V2                     71.658        90.150
MobileNet V3 Large               73.004        90.858
@raghuramank100 this is roughly a 1-point Acc@1 drop compared to the fp32 reference. Would you have any tips on how to make this gap smaller?
@fmassa The non-quantized version of MobileNet V3 Large uses averaging of checkpoints, which I don't do here. That's possibly one of the reasons we get lower accuracy.
If you start quantization-aware training from the averaged checkpoint, you should get better accuracy, as the starting point is better.
Also, one additional hyper-parameter that helps is to turn on QAT in steps: we first turn observers on (i.e. collect statistics) and then turn fake-quantization on, and after some time we freeze batch norm. Currently, in train_quantization, steps 1 and 2 are combined. We have seen that separating them helps with QAT accuracy in some models. You could try something like:
# Initially only turn on observers, disable fake quant
model.apply(torch.quantization.enable_observer)
model.apply(torch.quantization.disable_fake_quant)
....
if epoch >= args.num_fake_quant_start_epochs:
    model.apply(torch.quantization.enable_fake_quant)
if epoch >= args.num_observer_update_epochs:
    print('Disabling observer for subseq epochs, epoch = ', epoch)
    model.apply(torch.quantization.disable_observer)
if epoch >= args.num_batch_norm_update_epochs:
    print('Freezing BN for subseq epochs, epoch = ', epoch)
    model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
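For reference, a sketch of how these thresholds could be exposed as command-line arguments; the flag names mirror the snippet above, but the defaults and exact interface are assumptions rather than the existing train_quantization arguments:

import argparse

# Hypothetical flags for the staged QAT schedule above; names and defaults are
# assumptions, not the current reference-script options.
parser = argparse.ArgumentParser(description='Staged QAT options')
parser.add_argument('--num-fake-quant-start-epochs', default=1, type=int,
                    help='epoch at which fake-quantization is enabled')
parser.add_argument('--num-observer-update-epochs', default=4, type=int,
                    help='epoch after which observers stop updating statistics')
parser.add_argument('--num-batch-norm-update-epochs', default=3, type=int,
                    help='epoch after which batch-norm statistics are frozen')
args = parser.parse_args()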
If you start quantization-aware training from the averaged checkpoint, you should get better accuracy, as the starting point is better.
We indeed start from an averaged checkpoint, but that's not what I mean here. I'm referring to the post-training averaging step, which is missing.
We first turn observers on (i.e. collect statistics) and then turn fake-quantization on.
That's worth integrating into the new quantization training script.
I believe the key reason the accuracy is lagging is that the quantization training script does not currently support all the enhancements made to the classification training script. These enhancements (multiple restarts, optimizer tuning, data augmentation, model averaging at the end, etc.) helped me push the accuracy up by 2 points.
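For illustration, a minimal sketch of what the post-training model-averaging step could look like, assuming hypothetical checkpoint files that each store the weights under a "model" key; this is not the exact averaging code used for the classification models:

import torch

# Hypothetical checkpoints from the last few training epochs.
paths = ['model_87.pth', 'model_88.pth', 'model_89.pth']
state_dicts = [torch.load(p, map_location='cpu')['model'] for p in paths]

# Average every parameter/buffer element-wise across the checkpoints.
averaged = {}
for key in state_dicts[0]:
    stacked = torch.stack([sd[key].float() for sd in state_dicts])
    averaged[key] = stacked.mean(dim=0).to(state_dicts[0][key].dtype)

torch.save({'model': averaged}, 'model_averaged.pth')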
BTW, test failures seem to be related
@fmassa Thanks for flagging. I fixed the issues, did 1 epoch of training and 1 validation to ensure everything still works fine. The rest of the failing tests are not related to this PR. So we should be OK.
backend = 'qnnpack'

model.fuse_model()
model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
One suggestion for improvement is to change the qconfig to the following. This configuration uses per-channel quantization for weights with qnnpack, which is now supported.
from torch.quantization import QConfig, default_per_channel_weight_fake_quant

qconfig = QConfig(
    activation=torch.quantization.FakeQuantize.with_args(
        observer=torch.quantization.MovingAverageMinMaxObserver,
        quant_min=0,
        quant_max=255,
        reduce_range=False),
    weight=default_per_channel_weight_fake_quant)
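For context, a sketch of how this qconfig could replace the default assignment in the QAT preparation flow; the model builder and backend mirror the snippets above, but this is an illustration rather than the merged code:

import torch
import torchvision
from torch.quantization import QConfig, default_per_channel_weight_fake_quant

# Build and fuse the quantizable model as in the snippet above, then assign the
# suggested per-channel qconfig before preparing for QAT.
model = torchvision.models.quantization.mobilenet_v3_large(pretrained=False, quantize=False)
model.fuse_model()
model.qconfig = QConfig(
    activation=torch.quantization.FakeQuantize.with_args(
        observer=torch.quantization.MovingAverageMinMaxObserver,
        quant_min=0,
        quant_max=255,
        reduce_range=False),
    weight=default_per_channel_weight_fake_quant)
torch.quantization.prepare_qat(model, inplace=True)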
Summary:
* Refactoring mobilenetv3 to make code reusable.
* Adding quantizable MobileNetV3 architecture.
* Fix bug on reference script.
* Moving documentation of quantized models in the right place.
* Update documentation.
* Workaround for loading correct weights of quant model.
* Update weight URL and readme.
* Adding eval.

Reviewed By: datumbox
Differential Revision: D26226613
fbshipit-source-id: 050d53d91abf68975f2dc3ede8db633a08b33a25
The pre-trained model was trained:
Submitted batch job 35496554
Validated with:
Accuracy metrics (Epoch 89):
Acc@1 73.004 Acc@5 90.858
Speed Benchmark:
0.0162 sec per image on CPU
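For illustration, a minimal sketch of how a per-image CPU timing like the one above could be measured; the warm-up count, iteration count, and input size are assumptions rather than the exact benchmark used here:

import time
import torch
import torchvision

# Load the quantized MobileNetV3 Large; the qnnpack engine must be available on this machine.
torch.backends.quantized.engine = 'qnnpack'
model = torchvision.models.quantization.mobilenet_v3_large(pretrained=True, quantize=True)
model.eval()

# Time single-image inference on CPU, averaged over several iterations.
x = torch.rand(1, 3, 224, 224)
with torch.no_grad():
    for _ in range(10):  # warm-up iterations
        model(x)
    start = time.time()
    n = 100
    for _ in range(n):
        model(x)
print('{:.4f} sec per image on CPU'.format((time.time() - start) / n))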