Skip to content

Commit

Permalink
Adding label smoothing on classification reference (#4335)
Browse files Browse the repository at this point in the history
* Adding label smoothing on classification reference.

* Replace underscore with dash.
  • Loading branch information
datumbox authored Sep 2, 2021
1 parent 388b19c commit f52ddb0
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion references/classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def main(args):
if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

criterion = nn.CrossEntropyLoss()
criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

opt_name = args.opt.lower()
if opt_name == 'sgd':
Expand Down Expand Up @@ -256,6 +256,9 @@ def get_args_parser(add_help=True):
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)',
dest='weight_decay')
parser.add_argument('--label-smoothing', default=0.0, type=float,
help='label smoothing (default: 0.0)',
dest='label_smoothing')
parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
Expand Down

0 comments on commit f52ddb0

Please sign in to comment.