From fc728c1c18fe2a81771608b14fbb3b1d74d14414 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jan 2021 11:16:13 +0000 Subject: [PATCH 1/8] Refactoring mobilenetv3 to make code reusable. --- torchvision/models/mobilenetv3.py | 126 ++++++++++++++++-------------- 1 file changed, 67 insertions(+), 59 deletions(-) diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index a7d45264dc5..f6117dcb989 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -3,7 +3,7 @@ from functools import partial from torch import nn, Tensor from torch.nn import functional as F -from typing import Any, Callable, List, Optional, Sequence +from typing import Any, Callable, Dict, List, Optional, Sequence from torchvision.models.utils import load_state_dict_from_url from torchvision.models.mobilenetv2 import _make_divisible, ConvBNActivation @@ -24,14 +24,18 @@ def __init__(self, input_channels: int, squeeze_factor: int = 4): super().__init__() squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8) self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1) + self.relu = nn.ReLU(inplace=True) self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1) - def forward(self, input: Tensor) -> Tensor: + def _scale(self, input: Tensor, inplace: bool) -> Tensor: scale = F.adaptive_avg_pool2d(input, 1) scale = self.fc1(scale) - scale = F.relu(scale, inplace=True) + scale = self.relu(scale) scale = self.fc2(scale) - scale = F.hardsigmoid(scale, inplace=True) + return F.hardsigmoid(scale, inplace=inplace) + + def forward(self, input: Tensor) -> Tensor: + scale = self._scale(input, True) return scale * input @@ -55,7 +59,8 @@ def adjust_channels(channels: int, width_mult: float): class InvertedResidual(nn.Module): - def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module]): + def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module], + se_layer: Callable[..., nn.Module] = SqueezeExcitation): super().__init__() if not (1 <= cnf.stride <= 2): raise ValueError('illegal stride value') @@ -76,7 +81,7 @@ def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Mod stride=stride, dilation=cnf.dilation, groups=cnf.expanded_channels, norm_layer=norm_layer, activation_layer=activation_layer)) if cnf.use_se: - layers.append(SqueezeExcitation(cnf.expanded_channels)) + layers.append(se_layer(cnf.expanded_channels)) # project layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, @@ -179,7 +184,56 @@ def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) -def _mobilenet_v3( +def _mobilenet_v3_conf(arch: str, params: Dict[str, Any]): + # non-public config parameters + reduce_divider = 2 if params.pop('_reduced_tail', False) else 1 + dilation = 2 if params.pop('_dilated', False) else 1 + width_mult = params.pop('_width_mult', 1.0) + + bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult) + adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult) + + if arch == "mobilenet_v3_large": + inverted_residual_setting = [ + bneck_conf(16, 3, 16, 16, False, "RE", 1, 1), + bneck_conf(16, 3, 64, 24, False, "RE", 2, 1), # C1 + bneck_conf(24, 3, 72, 24, False, "RE", 1, 1), + bneck_conf(24, 5, 72, 40, True, "RE", 2, 1), # C2 + bneck_conf(40, 5, 120, 40, True, "RE", 1, 1), + bneck_conf(40, 5, 120, 40, True, "RE", 1, 1), + bneck_conf(40, 3, 240, 80, False, "HS", 2, 1), # C3 + 
bneck_conf(80, 3, 200, 80, False, "HS", 1, 1), + bneck_conf(80, 3, 184, 80, False, "HS", 1, 1), + bneck_conf(80, 3, 184, 80, False, "HS", 1, 1), + bneck_conf(80, 3, 480, 112, True, "HS", 1, 1), + bneck_conf(112, 3, 672, 112, True, "HS", 1, 1), + bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2, dilation), # C4 + bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation), + bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation), + ] + last_channel = adjust_channels(1280 // reduce_divider) # C5 + elif arch == "mobilenet_v3_small": + inverted_residual_setting = [ + bneck_conf(16, 3, 16, 16, True, "RE", 2, 1), # C1 + bneck_conf(16, 3, 72, 24, False, "RE", 2, 1), # C2 + bneck_conf(24, 3, 88, 24, False, "RE", 1, 1), + bneck_conf(24, 5, 96, 40, True, "HS", 2, 1), # C3 + bneck_conf(40, 5, 240, 40, True, "HS", 1, 1), + bneck_conf(40, 5, 240, 40, True, "HS", 1, 1), + bneck_conf(40, 5, 120, 48, True, "HS", 1, 1), + bneck_conf(48, 5, 144, 48, True, "HS", 1, 1), + bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2, dilation), # C4 + bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation), + bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation), + ] + last_channel = adjust_channels(1024 // reduce_divider) # C5 + else: + raise ValueError("Unsupported model type {}".format(arch)) + + return inverted_residual_setting, last_channel + + +def _mobilenet_v3_model( arch: str, inverted_residual_setting: List[InvertedResidualConfig], last_channel: int, @@ -205,34 +259,9 @@ def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - # non-public config parameters - reduce_divider = 2 if kwargs.pop('_reduced_tail', False) else 1 - dilation = 2 if kwargs.pop('_dilated', False) else 1 - width_mult = 1.0 - - bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult) - adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult) - - inverted_residual_setting = [ - bneck_conf(16, 3, 16, 16, False, "RE", 1, 1), - bneck_conf(16, 3, 64, 24, False, "RE", 2, 1), # C1 - bneck_conf(24, 3, 72, 24, False, "RE", 1, 1), - bneck_conf(24, 5, 72, 40, True, "RE", 2, 1), # C2 - bneck_conf(40, 5, 120, 40, True, "RE", 1, 1), - bneck_conf(40, 5, 120, 40, True, "RE", 1, 1), - bneck_conf(40, 3, 240, 80, False, "HS", 2, 1), # C3 - bneck_conf(80, 3, 200, 80, False, "HS", 1, 1), - bneck_conf(80, 3, 184, 80, False, "HS", 1, 1), - bneck_conf(80, 3, 184, 80, False, "HS", 1, 1), - bneck_conf(80, 3, 480, 112, True, "HS", 1, 1), - bneck_conf(112, 3, 672, 112, True, "HS", 1, 1), - bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2, dilation), # C4 - bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation), - bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1, dilation), - ] - last_channel = adjust_channels(1280 // reduce_divider) # C5 - - return _mobilenet_v3("mobilenet_v3_large", inverted_residual_setting, last_channel, pretrained, progress, **kwargs) + arch = "mobilenet_v3_large" + inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, kwargs) + return _mobilenet_v3_model(arch, inverted_residual_setting, 
last_channel, pretrained, progress, **kwargs) def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3: @@ -244,27 +273,6 @@ def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ - # non-public config parameters - reduce_divider = 2 if kwargs.pop('_reduced_tail', False) else 1 - dilation = 2 if kwargs.pop('_dilated', False) else 1 - width_mult = 1.0 - - bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult) - adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult) - - inverted_residual_setting = [ - bneck_conf(16, 3, 16, 16, True, "RE", 2, 1), # C1 - bneck_conf(16, 3, 72, 24, False, "RE", 2, 1), # C2 - bneck_conf(24, 3, 88, 24, False, "RE", 1, 1), - bneck_conf(24, 5, 96, 40, True, "HS", 2, 1), # C3 - bneck_conf(40, 5, 240, 40, True, "HS", 1, 1), - bneck_conf(40, 5, 240, 40, True, "HS", 1, 1), - bneck_conf(40, 5, 120, 48, True, "HS", 1, 1), - bneck_conf(48, 5, 144, 48, True, "HS", 1, 1), - bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2, dilation), # C4 - bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation), - bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1, dilation), - ] - last_channel = adjust_channels(1024 // reduce_divider) # C5 - - return _mobilenet_v3("mobilenet_v3_small", inverted_residual_setting, last_channel, pretrained, progress, **kwargs) + arch = "mobilenet_v3_small" + inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, kwargs) + return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs) From a4ec03654ce5a6340d42028dd47d72175d36fb32 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jan 2021 11:26:42 +0000 Subject: [PATCH 2/8] Adding quantizable MobileNetV3 architecture. 
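Quantized tensors do not support the regular float `*` and `+` operators, so the
quantizable variants introduced below route those ops through
`nn.quantized.FloatFunctional`, which the quantization observers can instrument.
A minimal sketch of the pattern (illustrative only; the class name here is
hypothetical, the real modules are in the diff below):

```python
from torch import nn, Tensor

class QuantizableScale(nn.Module):
    """Illustrates why a float `scale * input` becomes `FloatFunctional.mul`."""

    def __init__(self) -> None:
        super().__init__()
        # FloatFunctional is an observable module, so it survives quantization.
        self.skip_mul = nn.quantized.FloatFunctional()

    def forward(self, input: Tensor, scale: Tensor) -> Tensor:
        # float equivalent: return scale * input
        return self.skip_mul.mul(scale, input)
```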
--- torchvision/models/quantization/mobilenet.py | 3 +- .../models/quantization/mobilenetv3.py | 136 ++++++++++++++++++ 2 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 torchvision/models/quantization/mobilenetv3.py diff --git a/torchvision/models/quantization/mobilenet.py b/torchvision/models/quantization/mobilenet.py index 8b4e61f8ae6..4ffd14e2793 100644 --- a/torchvision/models/quantization/mobilenet.py +++ b/torchvision/models/quantization/mobilenet.py @@ -1,3 +1,4 @@ from .mobilenetv2 import QuantizableMobileNetV2, mobilenet_v2, __all__ as mv2_all +from .mobilenetv3 import QuantizableMobileNetV3, mobilenet_v3_large, mobilenet_v3_small, __all__ as mv3_all -__all__ = mv2_all +__all__ = mv2_all + mv3_all diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py new file mode 100644 index 00000000000..85e6b8a1297 --- /dev/null +++ b/torchvision/models/quantization/mobilenetv3.py @@ -0,0 +1,136 @@ +from torch import nn, Tensor +from torchvision.models.utils import load_state_dict_from_url +from torchvision.models.mobilenetv3 import InvertedResidual, InvertedResidualConfig, ConvBNActivation, MobileNetV3,\ + SqueezeExcitation, model_urls, _mobilenet_v3_conf +from torch.quantization import QuantStub, DeQuantStub, fuse_modules +from typing import Any, List +from .utils import _replace_relu, quantize_model + + +__all__ = ['QuantizableMobileNetV3', 'mobilenet_v3_large', 'mobilenet_v3_small'] + +# TODO: Add URLs +quant_model_urls = { + 'mobilenet_v3_large_qnnpack': None, + 'mobilenet_v3_small_qnnpack': None, +} + + +class QuantizableSqueezeExcitation(SqueezeExcitation): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.skip_mul = nn.quantized.FloatFunctional() + + def forward(self, input: Tensor) -> Tensor: + return self.skip_mul.mul(self._scale(input, False), input) + + def fuse_model(self): + fuse_modules(self, ['fc1', 'relu'], inplace=True) + + +class QuantizableInvertedResidual(InvertedResidual): + def __init__(self, *args, **kwargs): + super().__init__(*args, se_layer=QuantizableSqueezeExcitation, **kwargs) + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + if self.use_res_connect: + return self.skip_add.add(x, self.block(x)) + else: + return self.block(x) + + +class QuantizableMobileNetV3(MobileNetV3): + def __init__(self, *args, **kwargs): + """ + MobileNet V3 main class + + Args: + Inherits args from floating point MobileNetV3 + """ + super().__init__(*args, **kwargs) + self.quant = QuantStub() + self.dequant = DeQuantStub() + + def forward(self, x): + x = self.quant(x) + x = self._forward_impl(x) + x = self.dequant(x) + return x + + def fuse_model(self): + for m in self.modules(): + if type(m) == ConvBNActivation: + modules_to_fuse = ['0', '1'] + if type(m[2]) == nn.ReLU: + modules_to_fuse.append('2') + fuse_modules(m, modules_to_fuse, inplace=True) + elif type(m) == QuantizableSqueezeExcitation: + m.fuse_model() + + +def _mobilenet_v3_model( + arch: str, + inverted_residual_setting: List[InvertedResidualConfig], + last_channel: int, + pretrained: bool, + progress: bool, + quantize: bool, + **kwargs: Any +): + model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs) + _replace_relu(model) + + if quantize: + backend = 'qnnpack' + quantize_model(model, backend) + model_url = quant_model_urls.get(arch + '_' + backend, None) + else: + assert pretrained in [True, False] + model_url = model_urls.get(arch, 
None)
+
+    if pretrained:
+        if model_url is None:
+            raise ValueError("No checkpoint is available for {}".format(arch))
+        state_dict = load_state_dict_from_url(model_url, progress=progress)
+        model.load_state_dict(state_dict)
+
+    return model
+
+
+def mobilenet_v3_large(pretrained=False, progress=True, quantize=False, **kwargs):
+    """
+    Constructs a MobileNetV3 Large architecture from
+    `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
+
+    Note that quantize = True returns a quantized model with 8 bit
+    weights. Quantized models only support inference and run on CPUs.
+    GPU inference is not yet supported.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet.
+        progress (bool): If True, displays a progress bar of the download to stderr
+        quantize (bool): If True, returns a quantized model, else returns a float model
+    """
+    arch = "mobilenet_v3_large"
+    inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, kwargs)
+    return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, quantize, **kwargs)
+
+
+def mobilenet_v3_small(pretrained=False, progress=True, quantize=False, **kwargs):
+    """
+    Constructs a MobileNetV3 Small architecture from
+    `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
+
+    Note that quantize = True returns a quantized model with 8 bit
+    weights. Quantized models only support inference and run on CPUs.
+    GPU inference is not yet supported.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet.
+        progress (bool): If True, displays a progress bar of the download to stderr
+        quantize (bool): If True, returns a quantized model, else returns a float model
+    """
+    arch = "mobilenet_v3_small"
+    inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, kwargs)
+    return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, quantize, **kwargs)

From 4e03a0b089fd93a256f3d07290488cab5e55d31e Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Fri, 29 Jan 2021 12:31:52 +0000
Subject: [PATCH 3/8] Fix bug on reference script.
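The quantization reference script was still calling `load_data` with its
pre-refactoring signature, and `train.py` assumed that every caller defines the
new augmentation flags. The fix passes the full `args` namespace through and
reads optional attributes defensively; a small sketch of that pattern (the
function name is hypothetical):

```python
import argparse

def read_optional_flags(args: argparse.Namespace):
    # Older driver scripts may build a namespace without these flags, so
    # fall back to safe defaults instead of raising AttributeError.
    auto_augment_policy = getattr(args, "auto_augment", None)
    random_erase_prob = getattr(args, "random_erase", 0.0)
    return auto_augment_policy, random_erase_prob
```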
--- references/classification/train.py | 6 ++++-- references/classification/train_quantization.py | 3 +-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index 522ceaf3daa..232c3b5556b 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -92,10 +92,12 @@ def load_data(traindir, valdir, args): print("Loading dataset_train from {}".format(cache_path)) dataset, _ = torch.load(cache_path) else: + auto_augment_policy = getattr(args, "auto_augment", None) + random_erase_prob = getattr(args, "random_erase", 0.0) dataset = torchvision.datasets.ImageFolder( traindir, - presets.ClassificationPresetTrain(crop_size=crop_size, auto_augment_policy=args.auto_augment, - random_erase_prob=args.random_erase)) + presets.ClassificationPresetTrain(crop_size=crop_size, auto_augment_policy=auto_augment_policy, + random_erase_prob=random_erase_prob)) if args.cache_dataset: print("Saving dataset_train to {}".format(cache_path)) utils.mkdir(os.path.dirname(cache_path)) diff --git a/references/classification/train_quantization.py b/references/classification/train_quantization.py index b0452a9426e..dd41d0b3d1f 100644 --- a/references/classification/train_quantization.py +++ b/references/classification/train_quantization.py @@ -37,8 +37,7 @@ def main(args): train_dir = os.path.join(args.data_path, 'train') val_dir = os.path.join(args.data_path, 'val') - dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, - args.cache_dataset, args.distributed) + dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args) data_loader = torch.utils.data.DataLoader( dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True) From 5baebb462626eedb7ebb8dc169b53a0168aa9a14 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 29 Jan 2021 15:04:15 +0000 Subject: [PATCH 4/8] Moving documentation of quantized models in the right place. --- docs/source/models.rst | 47 +++++++++++++++++++++++++++++ references/classification/README.md | 34 ++++++--------------- 2 files changed, 57 insertions(+), 24 deletions(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index f4188a5ad1f..4d11c2d6205 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -263,6 +263,53 @@ MNASNet .. autofunction:: mnasnet1_0 .. autofunction:: mnasnet1_3 +Quantized Models +---------------- + +The following architectures provide support for INT8 quantized models. You can get +a model with random weights by calling its constructor: + +.. code:: python + + import torchvision.models as models + googlenet = models.quantization.googlenet() + inception_v3 = models.quantization.inception_v3() + mobilenet_v2 = models.quantization.mobilenet_v2() + mobilenet_v3_large = models.quantization.mobilenet_v3_large() + mobilenet_v3_small = models.quantization.mobilenet_v3_small() + resnet18 = models.quantization.resnet18() + resnet50 = models.quantization.resnet50() + resnext101_32x8d = models.quantization.resnext101_32x8d() + shufflenet_v2_x0_5 = models.quantization.shufflenet_v2_x0_5() + shufflenet_v2_x1_0 = models.quantization.shufflenet_v2_x1_0() + shufflenet_v2_x1_5 = models.quantization.shufflenet_v2_x1_5() + shufflenet_v2_x2_0 = models.quantization.shufflenet_v2_x2_0() + +Obtaining a pre-trained quantized model can be done with a few lines of code: + +.. 
code:: python

+    import torchvision.models as models
+    model = models.quantization.mobilenet_v2(pretrained=True, quantize=True)
+    model.eval()
+    # run the model with quantized inputs and weights
+    out = model(torch.rand(1, 3, 224, 224))
+
+We provide pre-trained quantized weights for the following models:
+
+================================  =============  =============
+Model                             Acc@1          Acc@5
+================================  =============  =============
+MobileNet V2                      71.658         90.150
+MobileNet V3 Large                TODO           TODO
+ShuffleNet V2                     68.360         87.582
+ResNet 18                         69.494         88.882
+ResNet 50                         75.920         92.814
+ResNext 101 32x8d                 78.986         94.480
+Inception V3                      77.176         93.354
+GoogleNet                         69.826         89.404
+================================  =============  =============
+
 Semantic Segmentation
 =====================

diff --git a/references/classification/README.md b/references/classification/README.md
index d8c5eff8c05..8e6ecdfd589 100644
--- a/references/classification/README.md
+++ b/references/classification/README.md
@@ -74,27 +74,6 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
 ```

 ## Quantized
-### INT8 models
-We add INT8 quantized models to follow the quantization support added in PyTorch 1.3.
-
-Obtaining a pre-trained quantized model can be obtained with a few lines of code:
-```
-model = torchvision.models.quantization.mobilenet_v2(pretrained=True, quantize=True)
-model.eval()
-# run the model with quantized inputs and weights
-out = model(torch.rand(1, 3, 224, 224))
-```
-We provide pre-trained quantized weights for the following models:
-
-| Model | Acc@1 | Acc@5 |
-|:-----------------:|:------:|:------:|
-| MobileNet V2 | 71.658 | 90.150 |
-| ShuffleNet V2: | 68.360 | 87.582 |
-| ResNet 18 | 69.494 | 88.882 |
-| ResNet 50 | 75.920 | 92.814 |
-| ResNext 101 32x8d | 78.986 | 94.480 |
-| Inception V3 | 77.176 | 93.354 |
-| GoogleNet | 69.826 | 89.404 |

 ### Parameters used for generating quantized models:

@@ -106,6 +85,10 @@ For all post training quantized models (All quantized models except mobilenet-v2
 4. eval_batch_size: 128
 5. backend: 'fbgemm'

+```
+python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' --model='<model_name>'
+```
+
 For Mobilenet-v2, the model was trained with quantization aware training, the settings used are:
 1. num_workers: 16
 2. batch_size: 32
@@ -119,14 +102,17 @@ For Mobilenet-v2, the model was trained with quantization aware training, the se
 9. momentum: 0.9
 10. lr_step_size:30
 11. lr_gamma: 0.1

+```
+python -m torch.distributed.launch --nproc_per_node=8 --use_env train_quantization.py --model='mobilenetv2'
+```
+
 Training converges at about 10 epochs.

 For post training quant, device is set to CPU. For training, the device is set to CUDA

 ### Command to evaluate quantized models using the pre-trained weights:
-For all quantized models:
+
 ```
-python references/classification/train_quantization.py --data-path='imagenet_full_size/' \
-    --device='cpu' --test-only --backend='fbgemm' --model='<model_name>'
+python train_quantization.py --device='cpu' --test-only --backend='fbgemm' --model='<model_name>'
 ```

From 3d69b2acb62214bfa8aa8ad16811534382587dba Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Tue, 2 Feb 2021 14:41:25 +0000
Subject: [PATCH 5/8] Update documentation.
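The README now documents the QAT recipe used to produce the quantized
MobileNet V3 Large weights. For context, this is how the documented model is
meant to be consumed once the pretrained checkpoint is published later in this
series (a sketch; the weights URL is still `None` at this point):

```python
import torch
import torchvision.models as models

# INT8 model as documented in models.rst; inference runs on CPU only.
model = models.quantization.mobilenet_v3_large(pretrained=True, quantize=True)
model.eval()

with torch.no_grad():
    out = model(torch.rand(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 1000])
```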
--- docs/source/models.rst | 2 +- references/classification/README.md | 24 ++++++++++++++++++++++-- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index 4d11c2d6205..8d4d7260746 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -301,7 +301,7 @@ We provide pre-trained quantized weights for the following models: Model Acc@1 Acc@5 ================================ ============= ============= MobileNet V2 71.658 90.150 -MobileNet V3 Large TODO TODO +MobileNet V3 Large 73.004 90.858 ShuffleNet V2 68.360 87.582 ResNet 18 69.494 88.882 ResNet 50 75.920 92.814 diff --git a/references/classification/README.md b/references/classification/README.md index 8e6ecdfd589..e2fcf2ae96b 100644 --- a/references/classification/README.md +++ b/references/classification/README.md @@ -101,14 +101,34 @@ For Mobilenet-v2, the model was trained with quantization aware training, the se 9. momentum: 0.9 10. lr_step_size:30 11. lr_gamma: 0.1 +12. weight-decay: 0.0001 ``` -python -m torch.distributed.launch --nproc_per_node=8 --use_env train_quantization.py --model='mobilenetv2' +python -m torch.distributed.launch --nproc_per_node=8 --use_env train_quantization.py --model='mobilenet_v2' ``` Training converges at about 10 epochs. -For post training quant, device is set to CPU. For training, the device is set to CUDA +For Mobilenet-v3 Large, the model was trained with quantization aware training, the settings used are: +1. num_workers: 16 +2. batch_size: 32 +3. eval_batch_size: 128 +4. backend: 'qnnpack' +5. learning-rate: 0.001 +6. num_epochs: 90 +7. num_observer_update_epochs:4 +8. num_batch_norm_update_epochs:3 +9. momentum: 0.9 +10. lr_step_size:30 +11. lr_gamma: 0.1 +12. weight-decay: 0.00001 + +``` +python -m torch.distributed.launch --nproc_per_node=8 --use_env train_quantization.py --model='mobilenet_v3_large' \ + --wd 0.00001 --lr 0.001 +``` + +For post training quant, device is set to CPU. For training, the device is set to CUDA. ### Command to evaluate quantized models using the pre-trained weights: From 274c6a1393384054876d701ffa1b54eb6750f1d8 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 2 Feb 2021 15:03:39 +0000 Subject: [PATCH 6/8] Workaround for loading correct weights of quant model. 
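QAT checkpoints contain observer and fake-quantization state, so they must be
restored into a model that has already been fused and prepared for QAT, and
only converted to INT8 afterwards; loading after `convert` would not match the
module hierarchy. A condensed sketch of the ordering this patch enforces
(eager-mode `torch.quantization` API, mirroring the diff below; the helper name
is hypothetical):

```python
import torch

def build_quantized(model, state_dict, backend='qnnpack'):
    # 1. Fold conv/bn/relu so module names match the checkpoint.
    model.fuse_model()
    # 2. Attach the QAT qconfig and insert observer/fake-quant modules.
    model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
    torch.quantization.prepare_qat(model, inplace=True)
    # 3. Restore the weights while the model is still in its QAT form.
    model.load_state_dict(state_dict)
    # 4. Convert to a true INT8 model.
    torch.quantization.convert(model, inplace=True)
    return model
```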
--- .../models/quantization/mobilenetv3.py | 42 ++++++++++++------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 85e6b8a1297..aa78b476506 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -1,17 +1,18 @@ +import torch from torch import nn, Tensor from torchvision.models.utils import load_state_dict_from_url from torchvision.models.mobilenetv3 import InvertedResidual, InvertedResidualConfig, ConvBNActivation, MobileNetV3,\ SqueezeExcitation, model_urls, _mobilenet_v3_conf from torch.quantization import QuantStub, DeQuantStub, fuse_modules -from typing import Any, List -from .utils import _replace_relu, quantize_model +from typing import Any, List, Optional +from .utils import _replace_relu __all__ = ['QuantizableMobileNetV3', 'mobilenet_v3_large', 'mobilenet_v3_small'] -# TODO: Add URLs quant_model_urls = { - 'mobilenet_v3_large_qnnpack': None, + 'mobilenet_v3_large_qnnpack': + "https://github.com/datumbox/torchvision-models/raw/main/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth", 'mobilenet_v3_small_qnnpack': None, } @@ -69,6 +70,18 @@ def fuse_model(self): m.fuse_model() +def _load_weights( + arch: str, + model: QuantizableMobileNetV3, + model_url: Optional[str], + progress: bool, +): + if model_url is None: + raise ValueError("No checkpoint is available for {}".format(arch)) + state_dict = load_state_dict_from_url(model_url, progress=progress) + model.load_state_dict(state_dict) + + def _mobilenet_v3_model( arch: str, inverted_residual_setting: List[InvertedResidualConfig], @@ -83,17 +96,18 @@ def _mobilenet_v3_model( if quantize: backend = 'qnnpack' - quantize_model(model, backend) - model_url = quant_model_urls.get(arch + '_' + backend, None) + + model.fuse_model() + model.qconfig = torch.quantization.get_default_qat_qconfig(backend) + torch.quantization.prepare_qat(model, inplace=True) + + if pretrained: + _load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress) + + torch.quantization.convert(model, inplace=True) else: - assert pretrained in [True, False] - model_url = model_urls.get(arch, None) - - if pretrained: - if model_url is None: - raise ValueError("No checkpoint is available for {}".format(arch)) - state_dict = load_state_dict_from_url(model_url, progress=progress) - model.load_state_dict(state_dict) + if pretrained: + _load_weights(arch, model, model_urls.get(arch, None), progress) return model From aa448560f20b9e066d466a481d51858c9c0a49bc Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 2 Feb 2021 15:30:35 +0000 Subject: [PATCH 7/8] Update weight URL and readme. --- references/classification/README.md | 2 +- torchvision/models/quantization/mobilenetv3.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/references/classification/README.md b/references/classification/README.md index e2fcf2ae96b..1694b25c7a8 100644 --- a/references/classification/README.md +++ b/references/classification/README.md @@ -133,6 +133,6 @@ For post training quant, device is set to CPU. 
For training, the device is set t
 ### Command to evaluate quantized models using the pre-trained weights:

 ```
-python train_quantization.py --device='cpu' --test-only --backend='fbgemm' --model='<model_name>'
+python train_quantization.py --device='cpu' --test-only --backend='<backend>' --model='<model_name>'
 ```

diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py
index aa78b476506..db6737267d1 100644
--- a/torchvision/models/quantization/mobilenetv3.py
+++ b/torchvision/models/quantization/mobilenetv3.py
@@ -12,7 +12,7 @@
 quant_model_urls = {
     'mobilenet_v3_large_qnnpack':
-        "https://github.com/datumbox/torchvision-models/raw/main/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
+        "https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
     'mobilenet_v3_small_qnnpack': None,
 }

From 6bd42ffbde174465a2af04d420a223e947fe4a8b Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Tue, 2 Feb 2021 16:35:55 +0000
Subject: [PATCH 8/8] Adding eval.
---
 torchvision/models/quantization/mobilenetv3.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py
index db6737267d1..eafe39bb041 100644
--- a/torchvision/models/quantization/mobilenetv3.py
+++ b/torchvision/models/quantization/mobilenetv3.py
@@ -105,6 +105,7 @@ def _mobilenet_v3_model(
             _load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress)

         torch.quantization.convert(model, inplace=True)
+        model.eval()
     else:
         if pretrained:
             _load_weights(arch, model, model_urls.get(arch, None), progress)
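With this final commit, `quantize=True` returns the converted model already in
eval mode, so observer and batch-norm statistics no longer update during a
forward pass. A quick smoke test of the finished API (assumes the pretrained
weights above are reachable):

```python
import torch
import torchvision.models as models

model = models.quantization.mobilenet_v3_large(pretrained=True, quantize=True)
assert not model.training  # convert() is now followed by model.eval()

with torch.no_grad():
    logits = model(torch.rand(1, 3, 224, 224))
print(logits.argmax(dim=1))  # predicted ImageNet class index
```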