Add MobileNetV3 Architecture in TorchVision (#3182)
* Adding implementation of the network architecture

* Adding RMSprop support to train.py

* Adding auto-augment and random-erase to the training scripts

* Adding support for reduced tail on MobileNetV3

* Tagging blocks with comments
datumbox authored Jan 5, 2021
1 parent 6315358 commit aea1191
Showing 7 changed files with 325 additions and 26 deletions.
hubconf.py: 2 additions & 1 deletion

@@ -11,7 +11,8 @@
 from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
 from torchvision.models.googlenet import googlenet
 from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
-from torchvision.models.mobilenet import mobilenet_v2
+from torchvision.models.mobilenetv2 import mobilenet_v2
+from torchvision.models.mobilenetv3 import mobilenet_v3_large, mobilenet_v3_small
 from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, \
     mnasnet1_3
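
With these entry points registered, the new models can be built straight from the hub. A minimal sketch, assuming a torchvision build that contains this commit; the models are randomly initialized here, since pretrained weights are a separate concern:

    import torch

    # Construct the new architectures through the hub entry points added above.
    large = torch.hub.load('pytorch/vision', 'mobilenet_v3_large', pretrained=False)
    small = torch.hub.load('pytorch/vision', 'mobilenet_v3_small', pretrained=False)

    with torch.no_grad():
        out = large(torch.rand(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 1000])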

references/classification/train.py: 33 additions & 16 deletions

@@ -79,7 +79,7 @@ def _get_cache_path(filepath):
     return cache_path


-def load_data(traindir, valdir, cache_dataset, distributed):
+def load_data(traindir, valdir, args):
     # Data loading code
     print("Loading data")
     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
@@ -88,28 +88,36 @@ def load_data(traindir, valdir, cache_dataset, distributed):
print("Loading training data")
st = time.time()
cache_path = _get_cache_path(traindir)
if cache_dataset and os.path.exists(cache_path):
if args.cache_dataset and os.path.exists(cache_path):
# Attention, as the transforms are also cached!
print("Loading dataset_train from {}".format(cache_path))
dataset, _ = torch.load(cache_path)
else:
trans = [
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
]
if args.auto_augment is not None:
aa_policy = transforms.AutoAugmentPolicy(args.auto_augment)
trans.append(transforms.AutoAugment(policy=aa_policy))
trans.extend([
transforms.ToTensor(),
normalize,
])
if args.random_erase > 0:
trans.append(transforms.RandomErasing(p=args.random_erase))
dataset = torchvision.datasets.ImageFolder(
traindir,
transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]))
if cache_dataset:
transforms.Compose(trans))
if args.cache_dataset:
print("Saving dataset_train to {}".format(cache_path))
utils.mkdir(os.path.dirname(cache_path))
utils.save_on_master((dataset, traindir), cache_path)
print("Took", time.time() - st)

print("Loading validation data")
cache_path = _get_cache_path(valdir)
if cache_dataset and os.path.exists(cache_path):
if args.cache_dataset and os.path.exists(cache_path):
# Attention, as the transforms are also cached!
print("Loading dataset_test from {}".format(cache_path))
dataset_test, _ = torch.load(cache_path)
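
With both new flags enabled, the training pipeline assembled above is equivalent to the standalone composition below. A sketch using public torchvision.transforms APIs: the 'imagenet' policy and 0.1 erase probability are example values, and the std constants complete the Normalize call truncated in this diff:

    from torchvision import transforms

    # Equivalent of the train-time pipeline when
    # --auto-augment imagenet --random-erase 0.1 are passed (example values).
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy('imagenet')),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
        transforms.RandomErasing(p=0.1),
    ])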
@@ -122,13 +130,13 @@ def load_data(traindir, valdir, cache_dataset, distributed):
             transforms.ToTensor(),
             normalize,
         ]))
-        if cache_dataset:
+        if args.cache_dataset:
             print("Saving dataset_test to {}".format(cache_path))
             utils.mkdir(os.path.dirname(cache_path))
             utils.save_on_master((dataset_test, valdir), cache_path)

     print("Creating data loaders")
-    if distributed:
+    if args.distributed:
         train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
         test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
     else:
@@ -155,8 +163,7 @@ def main(args):

     train_dir = os.path.join(args.data_path, 'train')
     val_dir = os.path.join(args.data_path, 'val')
-    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir,
-                                                                   args.cache_dataset, args.distributed)
+    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
     data_loader = torch.utils.data.DataLoader(
         dataset, batch_size=args.batch_size,
         sampler=train_sampler, num_workers=args.workers, pin_memory=True)
@@ -173,8 +180,15 @@ def main(args):

     criterion = nn.CrossEntropyLoss()

-    optimizer = torch.optim.SGD(
-        model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+    opt_name = args.opt.lower()
+    if opt_name == 'sgd':
+        optimizer = torch.optim.SGD(
+            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
+    elif opt_name == 'rmsprop':
+        optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum,
+                                        weight_decay=args.weight_decay, eps=0.0316, alpha=0.9)
+    else:
+        raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))

     if args.apex:
         model, optimizer = amp.initialize(model, optimizer,
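
The RMSprop branch pins eps=0.0316 (approximately sqrt(0.001)) and alpha=0.9, which mirrors the RMSProp configuration commonly used in TensorFlow-style MobileNet training setups. As a standalone sketch, with placeholder model and hyperparameter values:

    import torch

    model = torch.nn.Linear(8, 2)  # stand-in for the real network
    lr, momentum, weight_decay = 0.064, 0.9, 1e-5  # hypothetical values

    optimizer = torch.optim.RMSprop(model.parameters(), lr=lr, momentum=momentum,
                                    weight_decay=weight_decay, eps=0.0316, alpha=0.9)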
@@ -238,6 +252,7 @@ def parse_args():
                         help='number of total epochs to run')
     parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                         help='number of data loading workers (default: 16)')
+    parser.add_argument('--opt', default='sgd', type=str, help='optimizer')
     parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
     parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                         help='momentum')
@@ -275,6 +290,8 @@ def parse_args():
help="Use pre-trained models from the modelzoo",
action="store_true",
)
parser.add_argument('--auto-augment', default=None, help='auto augment policy (default: None)')
parser.add_argument('--random-erase', default=0.0, type=float, help='random erasing probability (default: 0.0)')

# Mixed precision training parameters
parser.add_argument('--apex', action='store_true',
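
Taken together, the new options compose into invocations like the one below. A hypothetical example: the data path and hyperparameter values are placeholders, and --data-path/--model come from the script's pre-existing argument list:

    python train.py --data-path /path/to/imagenet --model mobilenet_v3_large \
        --opt rmsprop --lr 0.064 --auto-augment imagenet --random-erase 0.2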
2 binary files changed (contents not shown).
test/test_models.py: 9 additions & 8 deletions

@@ -275,16 +275,17 @@ def test_mobilenet_v2_residual_setting(self):
         out = model(x)
         self.assertEqual(out.shape[-1], 1000)

-    def test_mobilenetv2_norm_layer(self):
-        model = models.__dict__["mobilenet_v2"]()
-        self.assertTrue(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
+    def test_mobilenet_norm_layer(self):
+        for name in ["mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"]:
+            model = models.__dict__[name]()
+            self.assertTrue(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))

-        def get_gn(num_channels):
-            return nn.GroupNorm(32, num_channels)
+            def get_gn(num_channels):
+                return nn.GroupNorm(32, num_channels)

-        model = models.__dict__["mobilenet_v2"](norm_layer=get_gn)
-        self.assertFalse(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
-        self.assertTrue(any(isinstance(x, nn.GroupNorm) for x in model.modules()))
+            model = models.__dict__[name](norm_layer=get_gn)
+            self.assertFalse(any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))
+            self.assertTrue(any(isinstance(x, nn.GroupNorm) for x in model.modules()))

     def test_inception_v3_eval(self):
         # replacement for models.inception_v3(pretrained=True) that does not download weights
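
Outside the test suite, the same norm_layer hook can be used directly. A minimal sketch; it uses a single GroupNorm group so that the group count divides every channel width in the network:

    import torch
    from torch import nn
    from torchvision import models

    def get_gn(num_channels):
        # One group normalizes across all channels, so it divides any width.
        return nn.GroupNorm(1, num_channels)

    model = models.mobilenet_v3_small(norm_layer=get_gn)
    with torch.no_grad():
        out = model(torch.rand(2, 3, 224, 224))
    print(out.shape)  # torch.Size([2, 1000])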
torchvision/models/mobilenet.py: 2 additions & 1 deletion

@@ -1,3 +1,4 @@
 from .mobilenetv2 import MobileNetV2, mobilenet_v2, __all__ as mv2_all
+from .mobilenetv3 import MobileNetV3, mobilenet_v3_large, mobilenet_v3_small, __all__ as mv3_all

-__all__ = mv2_all
+__all__ = mv2_all + mv3_all
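
The shim keeps the historical torchvision.models.mobilenet import path working while the implementations live in the per-version modules. A quick check:

    # Old-style imports keep resolving through the shim module...
    from torchvision.models.mobilenet import MobileNetV2, mobilenet_v3_large

    # ...and the new dedicated modules are importable directly.
    from torchvision.models.mobilenetv2 import mobilenet_v2
    from torchvision.models.mobilenetv3 import MobileNetV3, mobilenet_v3_small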
torchvision/models/mobilenetv3.py: new file (the bulk of the 325 added lines); diff not shown here.
