From 5198385f6f9ecef2b1f532df94aaa7d0c6de1610 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 1 Jan 2021 12:08:22 +0000 Subject: [PATCH] Adding rmsprop support on the train.py --- references/classification/train.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index 789bb8134ff..77f3127782e 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -173,8 +173,15 @@ def main(args): criterion = nn.CrossEntropyLoss() - optimizer = torch.optim.SGD( - model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + opt_name = args.opt.lower() + if opt_name == 'sgd': + optimizer = torch.optim.SGD( + model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + elif opt_name == 'rmsprop': + optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum, + weight_decay=args.weight_decay, eps=0.0316, alpha=0.9) + else: + raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt)) if args.apex: model, optimizer = amp.initialize(model, optimizer, @@ -238,6 +245,7 @@ def parse_args(): help='number of total epochs to run') parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', help='number of data loading workers (default: 16)') + parser.add_argument('--opt', default='sgd', type=str, help='optimizer') parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')