Add MobileNetV3 architecture for Classification (#3252)
* Add MobileNetV3 Architecture in TorchVision (#3182)

* Adding implementation of network architecture

* Adding rmsprop support on the train.py

* Adding auto-augment and random-erase in the training scripts.

* Adding support for reduced tail on MobileNetV3.

* Tagging blocks with comments.

* Adding documentation, pre-trained model URL and a minor refactoring.

* Better handling of supported models that lack pre-trained weights.
datumbox authored Jan 14, 2021
1 parent 8ebfd2f commit 7bf6e7b
Showing 10 changed files with 357 additions and 30 deletions.
20 changes: 16 additions & 4 deletions docs/source/models.rst
@@ -22,7 +22,8 @@ architectures for image classification:
- `Inception`_ v3
- `GoogLeNet`_
- `ShuffleNet`_ v2
- `MobileNet`_ v2
- `MobileNetV2`_
- `MobileNetV3`_
- `ResNeXt`_
- `Wide ResNet`_
- `MNASNet`_
@@ -40,7 +41,9 @@ You can construct a model with random weights by calling its constructor:
inception = models.inception_v3()
googlenet = models.googlenet()
shufflenet = models.shufflenet_v2_x1_0()
mobilenet = models.mobilenet_v2()
mobilenet_v2 = models.mobilenet_v2()
mobilenet_v3_large = models.mobilenet_v3_large()
mobilenet_v3_small = models.mobilenet_v3_small()
resnext50_32x4d = models.resnext50_32x4d()
wide_resnet50_2 = models.wide_resnet50_2()
mnasnet = models.mnasnet1_0()
@@ -59,7 +62,8 @@ These can be constructed by passing ``pretrained=True``:
inception = models.inception_v3(pretrained=True)
googlenet = models.googlenet(pretrained=True)
shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
mobilenet = models.mobilenet_v2(pretrained=True)
mobilenet_v2 = models.mobilenet_v2(pretrained=True)
mobilenet_v3_large = models.mobilenet_v3_large(pretrained=True)
resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
mnasnet = models.mnasnet1_0(pretrained=True)
@@ -137,6 +141,7 @@ Inception v3 22.55 6.44
GoogleNet 30.22 10.47
ShuffleNet V2 30.64 11.68
MobileNet V2 28.12 9.71
MobileNet V3 Large 25.96 8.66
ResNeXt-50-32x4d 22.38 6.30
ResNeXt-101-32x8d 20.69 5.47
Wide ResNet-50-2 21.49 5.91
@@ -153,7 +158,8 @@ MNASNet 1.0 26.49 8.456
.. _Inception: https://arxiv.org/abs/1512.00567
.. _GoogLeNet: https://arxiv.org/abs/1409.4842
.. _ShuffleNet: https://arxiv.org/abs/1807.11164
.. _MobileNet: https://arxiv.org/abs/1801.04381
.. _MobileNetV2: https://arxiv.org/abs/1801.04381
.. _MobileNetV3: https://arxiv.org/abs/1905.02244
.. _ResNeXt: https://arxiv.org/abs/1611.05431
.. _MNASNet: https://arxiv.org/abs/1807.11626

@@ -231,6 +237,12 @@ MobileNet v2

.. autofunction:: mobilenet_v2

MobileNet v3
-------------

.. autofunction:: mobilenet_v3_large
.. autofunction:: mobilenet_v3_small

ResNext
-------

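For context, the newly documented constructors follow the same pattern as the existing ones. A minimal usage sketch (not part of the diff; the dummy input shape is an assumption):

```python
import torch
from torchvision import models

# The two new MobileNetV3 variants, built with random weights.
large = models.mobilenet_v3_large()
small = models.mobilenet_v3_small()

# Both accept standard ImageNet-sized batches and return (N, 1000) logits.
x = torch.randn(1, 3, 224, 224)
print(large(x).shape, small(x).shape)  # torch.Size([1, 1000]) for both
```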
3 changes: 2 additions & 1 deletion hubconf.py
@@ -11,7 +11,8 @@
from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
from torchvision.models.googlenet import googlenet
from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
from torchvision.models.mobilenet import mobilenet_v2
from torchvision.models.mobilenetv2 import mobilenet_v2
from torchvision.models.mobilenetv3 import mobilenet_v3_large, mobilenet_v3_small
from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, \
mnasnet1_3

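Since `hubconf.py` now re-exports the new builders, they are also reachable through Torch Hub. A sketch (the `pytorch/vision` repo spec and the `pretrained=True` flag follow the Torch Hub conventions of this era):

```python
import torch

# Load the newly exposed MobileNetV3-Large entry point via Torch Hub.
model = torch.hub.load('pytorch/vision', 'mobilenet_v3_large', pretrained=True)
model.eval()
```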
10 changes: 10 additions & 0 deletions references/classification/README.md
@@ -53,6 +53,16 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--lr-step-size 1 --lr-gamma 0.98
```


### MobileNetV3 Large
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--model mobilenet_v3_large --epochs 600 --opt rmsprop --batch-size 128 --lr 0.064\
--wd 0.00001 --lr-step-size 2 --lr-gamma 0.973 --auto-augment imagenet --random-erase 0.2
```

We then averaged the parameters of the last 3 checkpoints that improved Acc@1, as sketched below; see [#3182](https://github.com/pytorch/vision/pull/3182) for details.

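A sketch of that averaging step, assuming the reference script's checkpoint layout (a dict with a `'model'` state-dict key) and hypothetical file names:

```python
import torch

# Hypothetical names for the last 3 checkpoints that improved Acc@1.
paths = ['model_597.pth', 'model_598.pth', 'model_599.pth']
state_dicts = [torch.load(p, map_location='cpu')['model'] for p in paths]

averaged = {}
for key, ref in state_dicts[0].items():
    if ref.is_floating_point():
        # Average weights and floating-point buffers across checkpoints.
        averaged[key] = torch.stack([sd[key] for sd in state_dicts]).mean(dim=0)
    else:
        # Integer buffers (e.g. BatchNorm's num_batches_tracked): keep the last.
        averaged[key] = state_dicts[-1][key]

torch.save(averaged, 'mobilenet_v3_large_averaged.pth')
```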
## Mixed precision training
Automatic Mixed Precision (AMP) training on GPU for PyTorch can be enabled with the [NVIDIA Apex extension](https://github.com/NVIDIA/apex).

49 changes: 33 additions & 16 deletions references/classification/train.py
@@ -79,7 +79,7 @@ def _get_cache_path(filepath):
return cache_path


def load_data(traindir, valdir, cache_dataset, distributed):
def load_data(traindir, valdir, args):
# Data loading code
print("Loading data")
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
@@ -88,28 +88,36 @@ def load_data(traindir, valdir, cache_dataset, distributed):
print("Loading training data")
st = time.time()
cache_path = _get_cache_path(traindir)
if cache_dataset and os.path.exists(cache_path):
if args.cache_dataset and os.path.exists(cache_path):
# Attention: the transforms are cached together with the dataset!
print("Loading dataset_train from {}".format(cache_path))
dataset, _ = torch.load(cache_path)
else:
trans = [
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
]
if args.auto_augment is not None:
aa_policy = transforms.AutoAugmentPolicy(args.auto_augment)
trans.append(transforms.AutoAugment(policy=aa_policy))
trans.extend([
transforms.ToTensor(),
normalize,
])
if args.random_erase > 0:
trans.append(transforms.RandomErasing(p=args.random_erase))
dataset = torchvision.datasets.ImageFolder(
traindir,
transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
if cache_dataset:
transforms.Compose(trans))
if args.cache_dataset:
print("Saving dataset_train to {}".format(cache_path))
utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset, traindir), cache_path)
print("Took", time.time() - st)

print("Loading validation data")
cache_path = _get_cache_path(valdir)
if cache_dataset and os.path.exists(cache_path):
if args.cache_dataset and os.path.exists(cache_path):
# Attention: the transforms are cached together with the dataset!
print("Loading dataset_test from {}".format(cache_path))
dataset_test, _ = torch.load(cache_path)
@@ -122,13 +130,13 @@ def load_data(traindir, valdir, cache_dataset, distributed):
transforms.ToTensor(),
normalize,
]))
if cache_dataset:
if args.cache_dataset:
print("Saving dataset_test to {}".format(cache_path))
utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset_test, valdir), cache_path)

print("Creating data loaders")
if distributed:
if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
else:
@@ -155,8 +163,7 @@ def main(args):

train_dir = os.path.join(args.data_path, 'train')
val_dir = os.path.join(args.data_path, 'val')
dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir,
args.cache_dataset, args.distributed)
dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
data_loader = torch.utils.data.DataLoader(
dataset, batch_size=args.batch_size,
sampler=train_sampler, num_workers=args.workers, pin_memory=True)
@@ -173,8 +180,15 @@ def main(args):

criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.SGD(
model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
opt_name = args.opt.lower()
if opt_name == 'sgd':
optimizer = torch.optim.SGD(
model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
elif opt_name == 'rmsprop':
optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum,
weight_decay=args.weight_decay, eps=0.0316, alpha=0.9)
else:
raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))

if args.apex:
model, optimizer = amp.initialize(model, optimizer,
@@ -238,6 +252,7 @@ def parse_args():
help='number of total epochs to run')
parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
help='number of data loading workers (default: 16)')
parser.add_argument('--opt', default='sgd', type=str, help='optimizer')
parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
Expand Down Expand Up @@ -275,6 +290,8 @@ def parse_args():
help="Use pre-trained models from the modelzoo",
action="store_true",
)
parser.add_argument('--auto-augment', default=None, help='auto augment policy (default: None)')
parser.add_argument('--random-erase', default=0.0, type=float, help='random erasing probability (default: 0.0)')

# Mixed precision training parameters
parser.add_argument('--apex', action='store_true',
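For reference, the training pipeline the new flags assemble is equivalent to the following standalone composition (a sketch mirroring the MobileNetV3 recipe above; the normalization constants are the ImageNet values from the script):

```python
from torchvision import transforms

# Equivalent of --auto-augment imagenet --random-erase 0.2 in the updated train.py.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.AutoAugment(policy=transforms.AutoAugmentPolicy('imagenet')),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.2),
])
```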
2 binary files not shown.
17 changes: 9 additions & 8 deletions test/test_models.py
@@ -275,16 +275,17 @@ def test_mobilenet_v2_residual_setting(self):
out = model(x)
self.assertEqual(out.shape[-1], 1000)

def test_mobilenetv2_norm_layer(self):
model = models.__dict__["mobilenet_v2"]()
self.assertTrue(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
def test_mobilenet_norm_layer(self):
for name in ["mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"]:
model = models.__dict__[name]()
self.assertTrue(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))

def get_gn(num_channels):
return nn.GroupNorm(32, num_channels)
def get_gn(num_channels):
return nn.GroupNorm(32, num_channels)

model = models.__dict__["mobilenet_v2"](norm_layer=get_gn)
self.assertFalse(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
self.assertTrue(any(isinstance(x, nn.GroupNorm) for x in model.modules()))
model = models.__dict__[name](norm_layer=get_gn)
self.assertFalse(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
self.assertTrue(any(isinstance(x, nn.GroupNorm) for x in model.modules()))

def test_inception_v3_eval(self):
# replacement for models.inception_v3(pretrained=True) that does not download weights
3 changes: 2 additions & 1 deletion torchvision/models/mobilenet.py
@@ -1,3 +1,4 @@
from .mobilenetv2 import MobileNetV2, mobilenet_v2, __all__ as mv2_all
from .mobilenetv3 import MobileNetV3, mobilenet_v3_large, mobilenet_v3_small, __all__ as mv3_all

__all__ = mv2_all
__all__ = mv2_all + mv3_all
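The shim keeps the legacy import path working alongside the new per-version modules; both of the following resolve after this commit:

```python
# The legacy umbrella module still re-exports everything.
from torchvision.models.mobilenet import MobileNetV2, MobileNetV3, mobilenet_v3_large

# The new dedicated modules are importable directly.
from torchvision.models.mobilenetv3 import mobilenet_v3_small
```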
3 changes: 3 additions & 0 deletions torchvision/models/mobilenetv2.py
@@ -53,6 +53,7 @@ def __init__(
norm_layer(out_planes),
activation_layer(inplace=True)
)
self.out_channels = out_planes


# necessary for backwards compatibility
@@ -90,6 +91,8 @@ def __init__(
norm_layer(oup),
])
self.conv = nn.Sequential(*layers)
self.out_channels = oup
self.is_strided = stride > 1

def forward(self, x: Tensor) -> Tensor:
if self.use_res_connect:
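The new `out_channels` / `is_strided` tags make block metadata introspectable without parsing layer internals; for instance (a sketch, not part of the commit):

```python
from torchvision.models import mobilenet_v2

model = mobilenet_v2()
# Each feature block now advertises its output width; InvertedResidual
# blocks additionally record whether they downsample.
for i, block in enumerate(model.features):
    print(i, type(block).__name__,
          getattr(block, 'out_channels', None),
          getattr(block, 'is_strided', False))
```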
(Diff for torchvision/models/mobilenetv3.py did not load.)
