From aea11915e8641bdc584095df0827204215f469a4 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 5 Jan 2021 19:14:37 +0000 Subject: [PATCH] Add MobileNetV3 Architecture in TorchVision (#3182) * Adding implementation of network architecture * Adding rmsprop support on the train.py * Adding auto-augment and random-erase in the training scripts. * Adding support for reduced tail on MobileNetV3. * Tagging blocks with comments. --- hubconf.py | 3 +- references/classification/train.py | 49 ++- ...lTester.test_mobilenet_v3_large_expect.pkl | Bin 0 -> 953 bytes ...lTester.test_mobilenet_v3_small_expect.pkl | Bin 0 -> 953 bytes test/test_models.py | 17 +- torchvision/models/mobilenet.py | 3 +- torchvision/models/mobilenetv3.py | 279 ++++++++++++++++++ 7 files changed, 325 insertions(+), 26 deletions(-) create mode 100644 test/expect/ModelTester.test_mobilenet_v3_large_expect.pkl create mode 100644 test/expect/ModelTester.test_mobilenet_v3_small_expect.pkl create mode 100644 torchvision/models/mobilenetv3.py diff --git a/hubconf.py b/hubconf.py index 79c22bd938b..dec4a7fb196 100644 --- a/hubconf.py +++ b/hubconf.py @@ -11,7 +11,8 @@ from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn from torchvision.models.googlenet import googlenet from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0 -from torchvision.models.mobilenet import mobilenet_v2 +from torchvision.models.mobilenetv2 import mobilenet_v2 +from torchvision.models.mobilenetv3 import mobilenet_v3_large, mobilenet_v3_small from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, \ mnasnet1_3 diff --git a/references/classification/train.py b/references/classification/train.py index 789bb8134ff..47a7e5955e6 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -79,7 +79,7 @@ def _get_cache_path(filepath): return cache_path -def load_data(traindir, valdir, cache_dataset, distributed): +def load_data(traindir, valdir, args): # Data loading code print("Loading data") normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], @@ -88,20 +88,28 @@ def load_data(traindir, valdir, cache_dataset, distributed): print("Loading training data") st = time.time() cache_path = _get_cache_path(traindir) - if cache_dataset and os.path.exists(cache_path): + if args.cache_dataset and os.path.exists(cache_path): # Attention, as the transforms are also cached! print("Loading dataset_train from {}".format(cache_path)) dataset, _ = torch.load(cache_path) else: + trans = [ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + ] + if args.auto_augment is not None: + aa_policy = transforms.AutoAugmentPolicy(args.auto_augment) + trans.append(transforms.AutoAugment(policy=aa_policy)) + trans.extend([ + transforms.ToTensor(), + normalize, + ]) + if args.random_erase > 0: + trans.append(transforms.RandomErasing(p=args.random_erase)) dataset = torchvision.datasets.ImageFolder( traindir, - transforms.Compose([ - transforms.RandomResizedCrop(224), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - normalize, - ])) - if cache_dataset: + transforms.Compose(trans)) + if args.cache_dataset: print("Saving dataset_train to {}".format(cache_path)) utils.mkdir(os.path.dirname(cache_path)) utils.save_on_master((dataset, traindir), cache_path) @@ -109,7 +117,7 @@ def load_data(traindir, valdir, cache_dataset, distributed): print("Loading validation data") cache_path = _get_cache_path(valdir) - if cache_dataset and os.path.exists(cache_path): + if args.cache_dataset and os.path.exists(cache_path): # Attention, as the transforms are also cached! print("Loading dataset_test from {}".format(cache_path)) dataset_test, _ = torch.load(cache_path) @@ -122,13 +130,13 @@ def load_data(traindir, valdir, cache_dataset, distributed): transforms.ToTensor(), normalize, ])) - if cache_dataset: + if args.cache_dataset: print("Saving dataset_test to {}".format(cache_path)) utils.mkdir(os.path.dirname(cache_path)) utils.save_on_master((dataset_test, valdir), cache_path) print("Creating data loaders") - if distributed: + if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test) else: @@ -155,8 +163,7 @@ def main(args): train_dir = os.path.join(args.data_path, 'train') val_dir = os.path.join(args.data_path, 'val') - dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, - args.cache_dataset, args.distributed) + dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args) data_loader = torch.utils.data.DataLoader( dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True) @@ -173,8 +180,15 @@ def main(args): criterion = nn.CrossEntropyLoss() - optimizer = torch.optim.SGD( - model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + opt_name = args.opt.lower() + if opt_name == 'sgd': + optimizer = torch.optim.SGD( + model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + elif opt_name == 'rmsprop': + optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum, + weight_decay=args.weight_decay, eps=0.0316, alpha=0.9) + else: + raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt)) if args.apex: model, optimizer = amp.initialize(model, optimizer, @@ -238,6 +252,7 @@ def parse_args(): help='number of total epochs to run') parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', help='number of data loading workers (default: 16)') + parser.add_argument('--opt', default='sgd', type=str, help='optimizer') parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') @@ -275,6 +290,8 @@ def parse_args(): help="Use pre-trained models from the modelzoo", action="store_true", ) + parser.add_argument('--auto-augment', default=None, help='auto augment policy (default: None)') + parser.add_argument('--random-erase', default=0.0, type=float, help='random erasing probability (default: 0.0)') # Mixed precision training parameters parser.add_argument('--apex', action='store_true', diff --git a/test/expect/ModelTester.test_mobilenet_v3_large_expect.pkl b/test/expect/ModelTester.test_mobilenet_v3_large_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..9691daf18c7c691a970ac8204305a28061a4b394 GIT binary patch literal 953 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf*+{H(8R#h(7+4?&CM;13YjCAfuhL;rG+fsMkR%;f!++>jNV3Vg>2qN-i&R9>>bI; z`8heM$t9WjdBt4*MJcI8sVOd*$t7Hc9GXFeoF#=^L519U0p9E!7Z**lSphT+gadH; zTZVxT#ozj9PI7bNLJETmd>U(d(^K@U%C6{f|C3)=?4PmDCVIiT?EhYRGyE*{td8DT zSF(Kj+RNWn*83bWT^BZY=DI%a-)kFE^tB8cGS{xlY+ak9`9N>?TeEe)zDTU!*LG@M zmf_!Z8CzZTc0Cr>^JqA;POYj^Z(pvH-u~$ox=IS?^-etWS^I@^u};7o9=$H3rgf9O z4(R3Fc3!`BDeL-^wwHA$-CnwGu3EUhqvnNm8C=bJ@lO}6i)i_vWAIFNosB-9-XV<@ z>!xhJqx;FSa$SfJtFGGbW#I6+bM3*i6Tq+mVcg-v&tMG?pR&}VVqmztIhi8`2}=4P z#4+Xq$IQI+P$r;ieH6lFm<0AR$n$KVoWc+eRlo>j2Y9ox Nfy9`B5TqWW767`o`& zf*+{H(8R#h(7?jn%+$=nz|5qOIf5A|np{v?$l`5OQpg(U&EU=GZPZrC=56H7*jC8i zk(`{LlarcUl9``Z%;jH{l3J9S;*yzM!d1wj8C1wwQpgol$gLOP&CZeF9)5cT&@>PZ z!0B%p20j#j>!Ufz&4~*s4Dz={uT#GJYkk|{b>EJ)vW8wHa%JQ_FIGg zp_+;7nHROMe;0d9-|M;CdM`Wo4SD)T}KtvjysOMlr9(e)jga_cLew(2`JMCf<_ zU%xKw+-m)xtvu^jz2B;TApEzEQRF4PGn@GgI6kaffB0+RI}yLBzv~;#UZB5hQ`Gw9-k0^|<_oQ#yX5w|uP0XN&-i&+cbcQ@`g^C1 z*01GStv`qR&bp$42kVUUOZ7Qlz66KQ^fIO+CxBrC!nnhSpTQa)K4qyz#lUcPb23K? z5|s2oh-1tJikXY^(nFbmwt{egHzSCGr%B{k697q|0Q3}!t{d4;GAKHK0C~u|(c>B2 zBxE;&QYr#S0No2S2^t~+-fV0-P!)2_x^T6i#0&y3`Y43UFbV8skmuPzIfWq{s(=y5 R4)A7W1Bo#MAxJ$$EdbN+|G@wN literal 0 HcmV?d00001 diff --git a/test/test_models.py b/test/test_models.py index d40649ffb65..dfbaf88be6c 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -275,16 +275,17 @@ def test_mobilenet_v2_residual_setting(self): out = model(x) self.assertEqual(out.shape[-1], 1000) - def test_mobilenetv2_norm_layer(self): - model = models.__dict__["mobilenet_v2"]() - self.assertTrue(any(isinstance(x, nn.BatchNorm2d) for x in model.modules())) + def test_mobilenet_norm_layer(self): + for name in ["mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"]: + model = models.__dict__[name]() + self.assertTrue(any(isinstance(x, nn.BatchNorm2d) for x in model.modules())) - def get_gn(num_channels): - return nn.GroupNorm(32, num_channels) + def get_gn(num_channels): + return nn.GroupNorm(32, num_channels) - model = models.__dict__["mobilenet_v2"](norm_layer=get_gn) - self.assertFalse(any(isinstance(x, nn.BatchNorm2d) for x in model.modules())) - self.assertTrue(any(isinstance(x, nn.GroupNorm) for x in model.modules())) + model = models.__dict__[name](norm_layer=get_gn) + self.assertFalse(any(isinstance(x, nn.BatchNorm2d) for x in model.modules())) + self.assertTrue(any(isinstance(x, nn.GroupNorm) for x in model.modules())) def test_inception_v3_eval(self): # replacement for models.inception_v3(pretrained=True) that does not download weights diff --git a/torchvision/models/mobilenet.py b/torchvision/models/mobilenet.py index 0c7cc10df5b..4108305d3f5 100644 --- a/torchvision/models/mobilenet.py +++ b/torchvision/models/mobilenet.py @@ -1,3 +1,4 @@ from .mobilenetv2 import MobileNetV2, mobilenet_v2, __all__ as mv2_all +from .mobilenetv3 import MobileNetV3, mobilenet_v3_large, mobilenet_v3_small, __all__ as mv3_all -__all__ = mv2_all +__all__ = mv2_all + mv3_all diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py new file mode 100644 index 00000000000..6282cd45434 --- /dev/null +++ b/torchvision/models/mobilenetv3.py @@ -0,0 +1,279 @@ +import torch + +from functools import partial +from torch import nn, Tensor +from torch.nn import functional as F +from typing import Any, Callable, List, Optional, Sequence + +from torchvision.models.utils import load_state_dict_from_url +from torchvision.models.mobilenetv2 import _make_divisible, ConvBNActivation + + +__all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"] + + +# TODO: add pretrained +model_urls = { + "mobilenet_v3_large": None, + "mobilenet_v3_small": None, +} + + +class Identity(nn.Module): + + def __init__(self, inplace: bool = False): + super().__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return input + + +class SqueezeExcitation(nn.Module): + + def __init__(self, input_channels: int, squeeze_factor: int = 4): + super().__init__() + squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8) + self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1) + self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1) + + def forward(self, input: Tensor) -> Tensor: + scale = F.adaptive_avg_pool2d(input, 1) + scale = self.fc1(scale) + scale = F.relu(scale, inplace=True) + scale = self.fc2(scale) + scale = F.hardsigmoid(scale, inplace=True) + return scale * input + + +class InvertedResidualConfig: + + def __init__(self, input_channels: int, kernel: int, expanded_channels: int, output_channels: int, use_se: bool, + activation: str, stride: int, width_mult: float): + self.input_channels = self.adjust_channels(input_channels, width_mult) + self.kernel = kernel + self.expanded_channels = self.adjust_channels(expanded_channels, width_mult) + self.output_channels = self.adjust_channels(output_channels, width_mult) + self.use_se = use_se + self.use_hs = activation == "HS" + self.stride = stride + + @staticmethod + def adjust_channels(channels: int, width_mult: float): + return _make_divisible(channels * width_mult, 8) + + +class InvertedResidual(nn.Module): + + def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module]): + super().__init__() + if not (1 <= cnf.stride <= 2): + raise ValueError('illegal stride value') + + self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.output_channels + + layers: List[nn.Module] = [] + activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU + + # expand + if cnf.expanded_channels != cnf.input_channels: + layers.append(ConvBNActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1, + norm_layer=norm_layer, activation_layer=activation_layer)) + + # depthwise + layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel, + stride=cnf.stride, groups=cnf.expanded_channels, norm_layer=norm_layer, + activation_layer=activation_layer)) + if cnf.use_se: + layers.append(SqueezeExcitation(cnf.expanded_channels)) + + # project + layers.append(ConvBNActivation(cnf.expanded_channels, cnf.output_channels, kernel_size=1, norm_layer=norm_layer, + activation_layer=Identity)) + + self.block = nn.Sequential(*layers) + + def forward(self, input: Tensor) -> Tensor: + result = self.block(input) + if self.use_res_connect: + result += input + return result + + +class MobileNetV3(nn.Module): + + def __init__( + self, + inverted_residual_setting: List[InvertedResidualConfig], + last_channel: int, + num_classes: int = 1000, + block: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + """ + MobileNet V3 main class + + Args: + inverted_residual_setting (List[InvertedResidualConfig]): Network structure + last_channel (int): The number of channels on the penultimate layer + num_classes (int): Number of classes + block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for mobilenet + norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use + """ + super().__init__() + + if not inverted_residual_setting: + raise ValueError("The inverted_residual_setting should not be empty") + elif not (isinstance(inverted_residual_setting, Sequence) and + all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])): + raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]") + + if block is None: + block = InvertedResidual + + if norm_layer is None: + norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01) + + layers: List[nn.Module] = [] + + # building first layer + firstconv_output_channels = inverted_residual_setting[0].input_channels + layers.append(ConvBNActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, + activation_layer=nn.Hardswish)) + + # building inverted residual blocks + for cnf in inverted_residual_setting: + layers.append(block(cnf, norm_layer)) + + # building last several layers + lastconv_input_channels = inverted_residual_setting[-1].output_channels + lastconv_output_channels = 6 * lastconv_input_channels + layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1, + norm_layer=norm_layer, activation_layer=nn.Hardswish)) + + self.features = nn.Sequential(*layers) + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Sequential( + nn.Linear(lastconv_output_channels, last_channel), + nn.Hardswish(inplace=True), + nn.Dropout(p=0.2, inplace=True), + nn.Linear(last_channel, num_classes), + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + + def _forward_impl(self, x: Tensor) -> Tensor: + x = self.features(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + + x = self.classifier(x) + + return x + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + +def _mobilenet_v3( + arch: str, + inverted_residual_setting: List[InvertedResidualConfig], + last_channel: int, + pretrained: bool, + progress: bool, + **kwargs: Any +): + model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) + model.load_state_dict(state_dict) + return model + + +def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, reduced_tail: bool = False, + **kwargs: Any) -> MobileNetV3: + """ + Constructs a large MobileNetV3 architecture from + `"Searching for MobileNetV3" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + reduced_tail (bool): If True, reduces the channel counts of all feature layers + between C4 and C5 by 2. It is used to reduce the channel redundancy in the + backbone for Detection and Segmentation. + """ + width_mult = 1.0 + bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult) + adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult) + + reduce_divider = 2 if reduced_tail else 1 + + inverted_residual_setting = [ + bneck_conf(16, 3, 16, 16, False, "RE", 1), + bneck_conf(16, 3, 64, 24, False, "RE", 2), # C1 + bneck_conf(24, 3, 72, 24, False, "RE", 1), + bneck_conf(24, 5, 72, 40, True, "RE", 2), # C2 + bneck_conf(40, 5, 120, 40, True, "RE", 1), + bneck_conf(40, 5, 120, 40, True, "RE", 1), + bneck_conf(40, 3, 240, 80, False, "HS", 2), # C3 + bneck_conf(80, 3, 200, 80, False, "HS", 1), + bneck_conf(80, 3, 184, 80, False, "HS", 1), + bneck_conf(80, 3, 184, 80, False, "HS", 1), + bneck_conf(80, 3, 480, 112, True, "HS", 1), + bneck_conf(112, 3, 672, 112, True, "HS", 1), + bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2), # C4 + bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1), + bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1), + ] + last_channel = adjust_channels(1280 // reduce_divider) # C5 + + return _mobilenet_v3("mobilenet_v3_large", inverted_residual_setting, last_channel, pretrained, progress, **kwargs) + + +def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, reduced_tail: bool = False, + **kwargs: Any) -> MobileNetV3: + """ + Constructs a small MobileNetV3 architecture from + `"Searching for MobileNetV3" `_. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + reduced_tail (bool): If True, reduces the channel counts of all feature layers + between C4 and C5 by 2. It is used to reduce the channel redundancy in the + backbone for Detection and Segmentation. + """ + width_mult = 1.0 + bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult) + adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult) + + reduce_divider = 2 if reduced_tail else 1 + + inverted_residual_setting = [ + bneck_conf(16, 3, 16, 16, True, "RE", 2), # C1 + bneck_conf(16, 3, 72, 24, False, "RE", 2), # C2 + bneck_conf(24, 3, 88, 24, False, "RE", 1), + bneck_conf(24, 5, 96, 40, True, "HS", 2), # C3 + bneck_conf(40, 5, 240, 40, True, "HS", 1), + bneck_conf(40, 5, 240, 40, True, "HS", 1), + bneck_conf(40, 5, 120, 48, True, "HS", 1), + bneck_conf(48, 5, 144, 48, True, "HS", 1), + bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2), # C4 + bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1), + bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1), + ] + last_channel = adjust_channels(1024 // reduce_divider) # C5 + + return _mobilenet_v3("mobilenet_v3_small", inverted_residual_setting, last_channel, pretrained, progress, **kwargs)