Add Quantizable MobilenetV3 architecture for Classification #3323

Merged
merged 10 commits into from
Feb 2, 2021
47 changes: 47 additions & 0 deletions docs/source/models.rst
@@ -263,6 +263,53 @@ MNASNet
.. autofunction:: mnasnet1_0
.. autofunction:: mnasnet1_3

Quantized Models
----------------

The following architectures provide support for INT8 quantized models. You can get
a model with random weights by calling its constructor:

.. code:: python

    import torchvision.models as models
    googlenet = models.quantization.googlenet()
    inception_v3 = models.quantization.inception_v3()
    mobilenet_v2 = models.quantization.mobilenet_v2()
    mobilenet_v3_large = models.quantization.mobilenet_v3_large()
    mobilenet_v3_small = models.quantization.mobilenet_v3_small()
    resnet18 = models.quantization.resnet18()
    resnet50 = models.quantization.resnet50()
    resnext101_32x8d = models.quantization.resnext101_32x8d()
    shufflenet_v2_x0_5 = models.quantization.shufflenet_v2_x0_5()
    shufflenet_v2_x1_0 = models.quantization.shufflenet_v2_x1_0()
    shufflenet_v2_x1_5 = models.quantization.shufflenet_v2_x1_5()
    shufflenet_v2_x2_0 = models.quantization.shufflenet_v2_x2_0()

Obtaining a pre-trained quantized model can be done with a few lines of code:

.. code:: python

    import torchvision.models as models
    model = models.quantization.mobilenet_v2(pretrained=True, quantize=True)
    model.eval()
    # run the model with quantized inputs and weights
    out = model(torch.rand(1, 3, 224, 224))

We provide pre-trained quantized weights for the following models:

================================  =============  =============
Model                             Acc@1          Acc@5
================================  =============  =============
MobileNet V2                      71.658         90.150
MobileNet V3 Large                73.004         90.858
Member

@raghuramank100 this is ~ 1 acc@1 point drop compared to the fp32 reference. Would you have any tips on how to make this gap smaller?

Contributor Author

@fmassa The non-quantized version of MobileNet V3 Large uses averaging of checkpoints which I don't do here. That's possibly one of the reasons we get lower accuracy.

Contributor

If you start with the averaged checkpoint to start quantization aware training, you should get better accuracy as the starting point is better.

Contributor

@raghuramank100, Feb 3, 2021

Also, one additional hyper-parameter that helps is to turn on QAT in steps: we first turn observers on (i.e., collect statistics), then turn fake-quantization on, and after some time freeze the batch norm statistics. Currently, in train_quantization, steps 1 and 2 are combined. We have seen that separating them helps with QAT accuracy in some models. You could try something like:

# Initially only turn on observers, disable fake quant
model.apply(torch.quantization.enable_observer)
model.apply(torch.quantization.disable_fake_quant)
....

if epoch >= args.num_fake_quant_start_epochs:
    model.apply(torch.quantization.enable_fake_quant)
if epoch >= args.num_observer_update_epochs:
    print('Disabling observer for subseq epochs, epoch = ', epoch)
    model.apply(torch.quantization.disable_observer)
if epoch >= args.num_batch_norm_update_epochs:
    print('Freezing BN for subseq epochs, epoch = ', epoch)
    model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
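
The staged schedule suggested in this comment can be sketched without any PyTorch dependency, as a pure function reporting which QAT switches should be active at a given epoch. The helper name `qat_switches` and its parameter names are hypothetical. `fake_quant_start` stands in for the proposed `args.num_fake_quant_start_epochs`, which is a suggestion in this thread rather than a flag the script is known to have. The other two thresholds play the role of `num_observer_update_epochs` and `num_batch_norm_update_epochs`.

```python
from typing import Dict


def qat_switches(epoch: int, fake_quant_start: int,
                 observer_stop: int, bn_freeze: int) -> Dict[str, bool]:
    """Return which QAT components should be active at `epoch`.

    Hypothetical helper: the thresholds stand in for
    num_fake_quant_start_epochs (proposed), num_observer_update_epochs
    and num_batch_norm_update_epochs.
    """
    return {
        # Observers collect statistics until observer_stop is reached.
        "observer": epoch < observer_stop,
        # Fake quantization starts only after a warm-up of pure observation.
        "fake_quant": epoch >= fake_quant_start,
        # Batch-norm statistics keep updating until bn_freeze is reached.
        "bn_update": epoch < bn_freeze,
    }


if __name__ == "__main__":
    # Illustrative thresholds only; they are not tuned values.
    for epoch in range(6):
        print(epoch, qat_switches(epoch, fake_quant_start=1,
                                  observer_stop=4, bn_freeze=3))
```

In a training loop, each flag would map onto the corresponding `model.apply(...)` call from the snippet in this comment.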

Contributor Author

> If you start with the averaged checkpoint to start quantization aware training, you should get better accuracy as the starting point is better.

We indeed start from an averaged checkpoint but that's not what I mean here. I'm referring to the post-training averaging step which is missing.

> We first turn observers on (i.e collect statistics) and then turn fake-quantization on.

That's worth integrating into the new quant training script.

I believe the key reason the accuracy is lagging is that the quant training script does not currently support all the enhancements made to the classification training script. These enhancements (multiple restarts, optimizer tuning, data augmentation, model averaging at the end, etc.) helped me push the accuracy up by 2 points.
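
As a rough illustration of the "model averaging at the end" step mentioned here, the sketch below takes an equal-weight mean of each parameter across several checkpoints. That is an assumption about what the averaging does, since the thread does not spell out the exact scheme. Plain Python lists of floats stand in for tensors, and `average_checkpoints` is a hypothetical name.

```python
from typing import Dict, List, Sequence


def average_checkpoints(
    checkpoints: Sequence[Dict[str, List[float]]]
) -> Dict[str, List[float]]:
    """Element-wise mean of each parameter across checkpoints (assumed scheme)."""
    if not checkpoints:
        raise ValueError("need at least one checkpoint")
    keys = checkpoints[0].keys()
    if any(ckpt.keys() != keys for ckpt in checkpoints):
        raise ValueError("checkpoints must share the same parameter names")
    n = len(checkpoints)
    averaged = {}
    for name in keys:
        # zip(*...) groups the i-th value of this parameter from every checkpoint.
        columns = zip(*(ckpt[name] for ckpt in checkpoints))
        averaged[name] = [sum(values) / n for values in columns]
    return averaged


if __name__ == "__main__":
    ckpts = [
        {"fc.weight": [1.0, 2.0], "fc.bias": [0.0]},
        {"fc.weight": [3.0, 4.0], "fc.bias": [1.0]},
    ]
    print(average_checkpoints(ckpts))
```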

ShuffleNet V2                     68.360         87.582
ResNet 18                         69.494         88.882
ResNet 50                         75.920         92.814
ResNext 101 32x8d                 78.986         94.480
Inception V3                      77.176         93.354
GoogleNet                         69.826         89.404
================================  =============  =============


Semantic Segmentation
=====================
56 changes: 31 additions & 25 deletions references/classification/README.md
@@ -74,27 +74,6 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
```

## Quantized
### INT8 models
We add INT8 quantized models to follow the quantization support added in PyTorch 1.3.

Obtaining a pre-trained quantized model can be obtained with a few lines of code:
```
model = torchvision.models.quantization.mobilenet_v2(pretrained=True, quantize=True)
model.eval()
# run the model with quantized inputs and weights
out = model(torch.rand(1, 3, 224, 224))
```
We provide pre-trained quantized weights for the following models:

| Model | Acc@1 | Acc@5 |
|:-----------------:|:------:|:------:|
| MobileNet V2 | 71.658 | 90.150 |
| ShuffleNet V2: | 68.360 | 87.582 |
| ResNet 18 | 69.494 | 88.882 |
| ResNet 50 | 75.920 | 92.814 |
| ResNext 101 32x8d | 78.986 | 94.480 |
| Inception V3 | 77.176 | 93.354 |
| GoogleNet | 69.826 | 89.404 |

### Parameters used for generating quantized models:

@@ -106,6 +85,10 @@ For all post training quantized models (All quantized models except mobilenet-v2
4. eval_batch_size: 128
5. backend: 'fbgemm'

```
python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' --model='<model_name>'
```

For Mobilenet-v2, the model was trained with quantization aware training, the settings used are:
1. num_workers: 16
2. batch_size: 32
@@ -118,15 +101,38 @@ For Mobilenet-v2, the model was trained with quantization aware training, the se
9. momentum: 0.9
10. lr_step_size:30
11. lr_gamma: 0.1
12. weight-decay: 0.0001

```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train_quantization.py --model='mobilenet_v2'
```

Training converges at about 10 epochs.

For post training quant, device is set to CPU. For training, the device is set to CUDA
For Mobilenet-v3 Large, the model was trained with quantization aware training, the settings used are:
1. num_workers: 16
2. batch_size: 32
3. eval_batch_size: 128
4. backend: 'qnnpack'
5. learning-rate: 0.001
6. num_epochs: 90
7. num_observer_update_epochs:4
8. num_batch_norm_update_epochs:3
9. momentum: 0.9
10. lr_step_size:30
11. lr_gamma: 0.1
12. weight-decay: 0.00001

```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train_quantization.py --model='mobilenet_v3_large' \
--wd 0.00001 --lr 0.001
```
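
To see what the `learning-rate`, `lr_step_size` and `lr_gamma` settings listed for Mobilenet-v3 Large imply over 90 epochs, here is a small sketch. It assumes a plain step-decay schedule (the learning rate is multiplied by `lr_gamma` every `lr_step_size` epochs); the helper name `step_lr` is hypothetical.

```python
def step_lr(base_lr: float, epoch: int, step_size: int, gamma: float) -> float:
    """Learning rate at `epoch` under step decay: base_lr * gamma ** (epoch // step_size)."""
    return base_lr * gamma ** (epoch // step_size)


if __name__ == "__main__":
    # Values taken from the Mobilenet-v3 Large settings: lr=0.001, step=30, gamma=0.1.
    for epoch in (0, 29, 30, 59, 60, 89):
        print(epoch, step_lr(0.001, epoch, step_size=30, gamma=0.1))
```

Under that assumption, the rate stays at 0.001 for epochs 0 to 29, drops to roughly 0.0001 for epochs 30 to 59, and to roughly 0.00001 for epochs 60 to 89.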

For post training quant, device is set to CPU. For training, the device is set to CUDA.

### Command to evaluate quantized models using the pre-trained weights:
For all quantized models:

```
python references/classification/train_quantization.py --data-path='imagenet_full_size/' \
--device='cpu' --test-only --backend='fbgemm' --model='<model_name>'
python train_quantization.py --device='cpu' --test-only --backend='<backend>' --model='<model_name>'
```

6 changes: 4 additions & 2 deletions references/classification/train.py
@@ -92,10 +92,12 @@ def load_data(traindir, valdir, args):
         print("Loading dataset_train from {}".format(cache_path))
         dataset, _ = torch.load(cache_path)
     else:
+        auto_augment_policy = getattr(args, "auto_augment", None)
+        random_erase_prob = getattr(args, "random_erase", 0.0)
         dataset = torchvision.datasets.ImageFolder(
             traindir,
-            presets.ClassificationPresetTrain(crop_size=crop_size, auto_augment_policy=args.auto_augment,
-                                              random_erase_prob=args.random_erase))
+            presets.ClassificationPresetTrain(crop_size=crop_size, auto_augment_policy=auto_augment_policy,
+                                              random_erase_prob=random_erase_prob))
         if args.cache_dataset:
             print("Saving dataset_train to {}".format(cache_path))
             utils.mkdir(os.path.dirname(cache_path))
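
The switch to `getattr` with a default lets `load_data` accept an `args` namespace that does not define the augmentation flags, which appears to be what allows `train_quantization.py` to pass its own `args` straight through. A minimal sketch of that fallback follows; `augmentation_options` is a hypothetical helper and the flag values are illustrative only.

```python
import argparse


def augmentation_options(args: argparse.Namespace):
    """Read optional augmentation settings, falling back to 'disabled' defaults."""
    auto_augment_policy = getattr(args, "auto_augment", None)
    random_erase_prob = getattr(args, "random_erase", 0.0)
    return auto_augment_policy, random_erase_prob


if __name__ == "__main__":
    # A namespace that defines the flags (illustrative values).
    full = argparse.Namespace(auto_augment="imagenet", random_erase=0.2)
    # A namespace from a script that never defined them.
    minimal = argparse.Namespace(cache_dataset=False)
    print(augmentation_options(full))
    print(augmentation_options(minimal))
```

With direct attribute access (`args.auto_augment`), the second call would raise `AttributeError` instead of returning the defaults.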
3 changes: 1 addition & 2 deletions references/classification/train_quantization.py
@@ -37,8 +37,7 @@ def main(args):
     train_dir = os.path.join(args.data_path, 'train')
     val_dir = os.path.join(args.data_path, 'val')
 
-    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir,
-                                                                   args.cache_dataset, args.distributed)
+    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
     data_loader = torch.utils.data.DataLoader(
         dataset, batch_size=args.batch_size,
         sampler=train_sampler, num_workers=args.workers, pin_memory=True)
126 changes: 67 additions & 59 deletions torchvision/models/mobilenetv3.py
@@ -3,7 +3,7 @@
 from functools import partial
 from torch import nn, Tensor
 from torch.nn import functional as F
-from typing import Any, Callable, List, Optional, Sequence
+from typing import Any, Callable, Dict, List, Optional, Sequence
 
 from torchvision.models.utils import load_state_dict_from_url
 from torchvision.models.mobilenetv2 import _make_divisible, ConvBNActivation
@@ -24,14 +24,18 @@ def __init__(self, input_channels: int, squeeze_factor: int = 4):
         super().__init__()
         squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
         self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
+        self.relu = nn.ReLU(inplace=True)
         self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)
 
-    def forward(self, input: Tensor) -> Tensor:
+    def _scale(self, input: Tensor, inplace: bool) -> Tensor:
         scale = F.adaptive_avg_pool2d(input, 1)
         scale = self.fc1(scale)
-        scale = F.relu(scale, inplace=True)
+        scale = self.relu(scale)
         scale = self.fc2(scale)
-        scale = F.hardsigmoid(scale, inplace=True)
+        return F.hardsigmoid(scale, inplace=inplace)
+
+    def forward(self, input: Tensor) -> Tensor:
+        scale = self._scale(input, True)
         return scale * input
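
To make the `squeeze_channels` computation in `SqueezeExcitation.__init__` concrete, the sketch below re-implements the rounding helper the way `_make_divisible` is commonly written. Treat the body of `make_divisible` as an assumption rather than the library source, and `squeeze_channels` as a hypothetical wrapper around the expression shown in the hunk.

```python
from typing import Optional


def make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    """Round `v` to the nearest multiple of `divisor`, never dropping more than 10%.

    Assumed to mirror torchvision's `_make_divisible`; not copied from this PR.
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Bump up one step if rounding down lost more than 10% of the channels.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


def squeeze_channels(input_channels: int, squeeze_factor: int = 4) -> int:
    """Channel count of the SE bottleneck, following the expression in __init__."""
    return make_divisible(input_channels // squeeze_factor, 8)


if __name__ == "__main__":
    # Expanded channel counts that use SE in the mobilenet_v3_large config.
    for channels in (72, 120, 480, 672, 960):
        print(channels, squeeze_channels(channels))
```

For example, 72 input channels give 72 // 4 = 18, which rounds down to 16; since 16 is below 90% of 18, the helper bumps it to 24.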


@@ -55,7 +59,8 @@ def adjust_channels(channels: int, width_mult: float):
 
 class InvertedResidual(nn.Module):
 
-    def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module]):
+    def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module],
+                 se_layer: Callable[..., nn.Module] = SqueezeExcitation):
         super().__init__()
         if not (1 <= cnf.stride <= 2):
             raise ValueError('illegal stride value')
@@ -76,7 +81,7 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod
                                        stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels,
                                        norm_layer=norm_layer, activation_layer=activation_layer))
         if cnf.use_se:
-            layers.append(SqueezeExcitation(cnf.expanded_channels))
+            layers.append(se_layer(cnf.expanded_channels))
 
         # project
         layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer,
@@ -179,7 +184,56 @@ def forward(self, x: Tensor) -> Tensor:
         return self._forward_impl(x)
 
 
-def _mobilenet_v3(
+def _mobilenet_v3_conf(arch: str, params: Dict[str, Any]):
+    # non-public config parameters
+    reduce_divider = 2 if params.pop('_reduced_tail', False) else 1
+    dilation = 2 if params.pop('_dilated', False) else 1
+    width_mult = params.pop('_width_mult', 1.0)
+
+    bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
+    adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)
+
+    if arch == "mobilenet_v3_large":
+        inverted_residual_setting = [
+            bneck_conf(16, 3, 16, 16, False, "RE", 1, 1),
+            bneck_conf(16, 3, 64, 24, False, "RE", 2, 1),  # C1
+            bneck_conf(24, 3, 72, 24, False, "RE", 1, 1),
+            bneck_conf(24, 5, 72, 40, True, "RE", 2, 1),  # C2
+            bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
+            bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
+            bneck_conf(40, 3, 240, 80, False, "HS", 2, 1),  # C3
+            bneck_conf(80, 3, 200, 80, False, "HS", 1, 1),
+            bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
+            bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
+            bneck_conf(80, 3, 480, 112, True, "HS", 1, 1),
+            bneck_conf(112, 3, 672, 112, True, "HS", 1, 1),
+            bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2, dilation),  # C4
+            bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
+            bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
+        ]
+        last_channel = adjust_channels(1280 // reduce_divider)  # C5
+    elif arch == "mobilenet_v3_small":
+        inverted_residual_setting = [
+            bneck_conf(16, 3, 16, 16, True, "RE", 2, 1),  # C1
+            bneck_conf(16, 3, 72, 24, False, "RE", 2, 1),  # C2
+            bneck_conf(24, 3, 88, 24, False, "RE", 1, 1),
+            bneck_conf(24, 5, 96, 40, True, "HS", 2, 1),  # C3
+            bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
+            bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
+            bneck_conf(40, 5, 120, 48, True, "HS", 1, 1),
+            bneck_conf(48, 5, 144, 48, True, "HS", 1, 1),
+            bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2, dilation),  # C4
+            bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
+            bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
+        ]
+        last_channel = adjust_channels(1024 // reduce_divider)  # C5
+    else:
+        raise ValueError("Unsupported model type {}".format(arch))
+
+    return inverted_residual_setting, last_channel
+
+
+def _mobilenet_v3_model(
     arch: str,
     inverted_residual_setting: List[InvertedResidualConfig],
     last_channel: int,
@@ -205,34 +259,9 @@ def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs
         pretrained (bool): If True, returns a model pre-trained on ImageNet
         progress (bool): If True, displays a progress bar of the download to stderr
     """
-    # non-public config parameters
-    reduce_divider = 2 if kwargs.pop('_reduced_tail', False) else 1
-    dilation = 2 if kwargs.pop('_dilated', False) else 1
-    width_mult = 1.0
-
-    bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
-    adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)
-
-    inverted_residual_setting = [
-        bneck_conf(16, 3, 16, 16, False, "RE", 1, 1),
-        bneck_conf(16, 3, 64, 24, False, "RE", 2, 1),  # C1
-        bneck_conf(24, 3, 72, 24, False, "RE", 1, 1),
-        bneck_conf(24, 5, 72, 40, True, "RE", 2, 1),  # C2
-        bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
-        bneck_conf(40, 5, 120, 40, True, "RE", 1, 1),
-        bneck_conf(40, 3, 240, 80, False, "HS", 2, 1),  # C3
-        bneck_conf(80, 3, 200, 80, False, "HS", 1, 1),
-        bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
-        bneck_conf(80, 3, 184, 80, False, "HS", 1, 1),
-        bneck_conf(80, 3, 480, 112, True, "HS", 1, 1),
-        bneck_conf(112, 3, 672, 112, True, "HS", 1, 1),
-        bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2, dilation),  # C4
-        bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
-        bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation),
-    ]
-    last_channel = adjust_channels(1280 // reduce_divider)  # C5
-
-    return _mobilenet_v3("mobilenet_v3_large", inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
+    arch = "mobilenet_v3_large"
+    inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, kwargs)
+    return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
 
 
 def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3:
@@ -244,27 +273,6 @@ def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs
         pretrained (bool): If True, returns a model pre-trained on ImageNet
         progress (bool): If True, displays a progress bar of the download to stderr
     """
-    # non-public config parameters
-    reduce_divider = 2 if kwargs.pop('_reduced_tail', False) else 1
-    dilation = 2 if kwargs.pop('_dilated', False) else 1
-    width_mult = 1.0
-
-    bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
-    adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)
-
-    inverted_residual_setting = [
-        bneck_conf(16, 3, 16, 16, True, "RE", 2, 1),  # C1
-        bneck_conf(16, 3, 72, 24, False, "RE", 2, 1),  # C2
-        bneck_conf(24, 3, 88, 24, False, "RE", 1, 1),
-        bneck_conf(24, 5, 96, 40, True, "HS", 2, 1),  # C3
-        bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
-        bneck_conf(40, 5, 240, 40, True, "HS", 1, 1),
-        bneck_conf(40, 5, 120, 48, True, "HS", 1, 1),
-        bneck_conf(48, 5, 144, 48, True, "HS", 1, 1),
-        bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2, dilation),  # C4
-        bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
-        bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation),
-    ]
-    last_channel = adjust_channels(1024 // reduce_divider)  # C5
-
-    return _mobilenet_v3("mobilenet_v3_small", inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
+    arch = "mobilenet_v3_small"
+    inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, kwargs)
+    return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
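
One detail of this refactoring worth illustrating: `_mobilenet_v3_conf` receives `kwargs` as a plain dict (not unpacked) and pops the underscore-prefixed options out of it, so only public options remain when the same dict is later forwarded with `**kwargs`. The sketch below isolates that pattern. `split_private_options` is a hypothetical helper, and `num_classes` is used only as an example of a public option.

```python
from typing import Any, Dict, Tuple


def split_private_options(params: Dict[str, Any]) -> Tuple[int, int, float]:
    """Pop the underscore-prefixed options out of `params`, mutating it in place."""
    reduce_divider = 2 if params.pop("_reduced_tail", False) else 1
    dilation = 2 if params.pop("_dilated", False) else 1
    width_mult = params.pop("_width_mult", 1.0)
    return reduce_divider, dilation, width_mult


if __name__ == "__main__":
    kwargs = {"_reduced_tail": True, "num_classes": 10}
    reduce_divider, dilation, width_mult = split_private_options(kwargs)
    # Only public options remain, so **kwargs is safe to forward to the model.
    print(kwargs)
    # Tail channel counts shrink by the divider, e.g. 160 // 2 and 1280 // 2.
    print(160 // reduce_divider, 1280 // reduce_divider)
```

Because the dict is mutated in place, a caller that needs the private options afterwards would have to pass a copy.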
3 changes: 2 additions & 1 deletion torchvision/models/quantization/mobilenet.py
@@ -1,3 +1,4 @@
 from .mobilenetv2 import QuantizableMobileNetV2, mobilenet_v2, __all__ as mv2_all
+from .mobilenetv3 import QuantizableMobileNetV3, mobilenet_v3_large, mobilenet_v3_small, __all__ as mv3_all
 
-__all__ = mv2_all
+__all__ = mv2_all + mv3_all