From a0716d3d15ca784e3ff6b60433f9f5f914353f9b Mon Sep 17 00:00:00 2001
From: ancestor-mithril
Date: Mon, 29 Apr 2024 14:30:11 +0300
Subject: [PATCH] Using fused step for optimizers

---
 utils/optimizer.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/utils/optimizer.py b/utils/optimizer.py
index bf65b94..b4a3e4a 100644
--- a/utils/optimizer.py
+++ b/utils/optimizer.py
@@ -1,3 +1,4 @@
+import torch
 from torch import optim
 from torch.optim import Optimizer
 
@@ -7,4 +8,15 @@ def init_optimizer(args, parameters) -> Optimizer:
     momentum = 0.9 if not hasattr(args, 'momentum') else args.momentum
     weight_decay = 5e-4 if not hasattr(args, 'weight_decay') else args.weight_decay
 
-    return optim.SGD(parameters, lr=args.lr, momentum=momentum, weight_decay=weight_decay)
+    kwargs = {
+        "lr": args.lr,
+        "momentum": momentum,
+        "weight_decay": weight_decay,
+    }
+    if torch.torch_version.TorchVersion(torch.__version__) >= '2.3.0':
+        kwargs['fused'] = True
+    else:
+        import warnings
+        warnings.warn("Upgrade torch to support fused optimizers")
+
+    return optim.SGD(parameters, **kwargs)
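
Usage sketch (not part of the commit): the `DummyArgs` namespace and the toy
linear model below are hypothetical stand-ins for whatever the training script
actually passes to `init_optimizer`. On torch >= 2.3.0 the returned SGD is
constructed with fused=True; older versions fall back to the non-fused
implementation and emit the warning added in this patch.

    from torch import nn
    from utils.optimizer import init_optimizer

    class DummyArgs:
        # Hypothetical argument container; only `lr` is strictly required,
        # the other two fall back to the defaults in init_optimizer.
        lr = 0.1
        momentum = 0.9
        weight_decay = 5e-4

    model = nn.Linear(10, 2)
    optimizer = init_optimizer(DummyArgs(), model.parameters())

    # `defaults` reflects the constructor kwargs; "fused" is True on
    # torch >= 2.3.0 and absent/False on older releases.
    print(type(optimizer).__name__, optimizer.defaults.get("fused"))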