From 585374c9efebcb07b6da81e2501bd4a1848e7fef Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Sun, 20 Dec 2020 20:26:05 +0000 Subject: [PATCH] Adding rmsprop support on the train.py --- references/classification/train.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/references/classification/train.py b/references/classification/train.py index 789bb8134ff..5f4a6a24544 100644 --- a/references/classification/train.py +++ b/references/classification/train.py @@ -173,8 +173,15 @@ def main(args): criterion = nn.CrossEntropyLoss() - optimizer = torch.optim.SGD( - model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + opt_name = args.opt.lower() + if opt_name == 'sgd': + optimizer = torch.optim.SGD( + model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + elif opt_name == 'rmsprop': + optimizer = torch.optim.RMSprop( + model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + else: + raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt)) if args.apex: model, optimizer = amp.initialize(model, optimizer, @@ -238,6 +245,7 @@ def parse_args(): help='number of total epochs to run') parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', help='number of data loading workers (default: 16)') + parser.add_argument('--opt', default='sgd', type=str, help='optimizer') parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')