Using fused step for optimizers
ancestor-mithril committed Apr 29, 2024
1 parent 9d2fcbe commit a0716d3
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion utils/optimizer.py
@@ -1,3 +1,4 @@
+import torch
 from torch import optim
 from torch.optim import Optimizer

@@ -7,4 +8,15 @@ def init_optimizer(args, parameters) -> Optimizer:
     momentum = 0.9 if not hasattr(args, 'momentum') else args.momentum
     weight_decay = 5e-4 if not hasattr(args, 'weight_decay') else args.weight_decay
 
-    return optim.SGD(parameters, lr=args.lr, momentum=momentum, weight_decay=weight_decay)
+    kwargs = {
+        "lr": args.lr,
+        "momentum": momentum,
+        "weight_decay": weight_decay,
+    }
+    if torch.torch_version.TorchVersion(torch.__version__) >= '2.3.0':
+        kwargs['fused'] = True
+    else:
+        import warnings
+        warnings.warn("Upgrade torch to support fused optimizers")
+
+    return optim.SGD(parameters, **kwargs)
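
For context, a minimal usage sketch of the updated helper, assuming a training script with a hypothetical args namespace and toy model (neither is part of the commit):

# Hypothetical usage sketch (not part of the commit): how init_optimizer
# might be called from a training script.
from types import SimpleNamespace

import torch
from torch import nn

from utils.optimizer import init_optimizer

# Only lr is required; init_optimizer falls back to momentum=0.9 and
# weight_decay=5e-4 when those attributes are absent from args.
args = SimpleNamespace(lr=0.1)

# Fused optimizer steps are mainly aimed at GPU training, so place the
# parameters on CUDA when it is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(10, 2).to(device)

optimizer = init_optimizer(args, model.parameters())
print(optimizer)

With torch >= 2.3.0 the returned SGD performs its parameter updates with the fused implementation, which folds the per-parameter work into fewer kernel launches and reduces Python overhead; on older versions the helper only emits a warning and returns a standard SGD instance.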
