WSAM Optimizer

Weighted Sharpness as a Regularization Term

We present PyTorch code for Sharpness-Aware Minimization Revisited: Weighted Sharpness as a Regularization Term, KDD'23. The code is based on https://github.com/davda54/sam.

The generalization of deep neural networks (DNNs) is known to be closely related to the flatness of minima, which motivated Sharpness-Aware Minimization (SAM) as a way to seek flatter minima and better generalization. We propose a more general method, WSAM, that incorporates sharpness as a regularization term. WSAM achieves improved generalization, or is at least highly competitive, compared to the vanilla optimizer, SAM, and its variants.

WSAM can achieve different (flatter) minima by choosing different 𝛾.
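
As a rough sketch of how 𝛾 enters the objective (the notation here is assumed for illustration; the paper has the precise definitions), sharpness can be written as the worst-case loss increase within a ball of radius ρ, and WSAM adds it to the loss with weight 𝛾/(1−𝛾):

```latex
\begin{aligned}
h(w) &= \max_{\|\epsilon\| \le \rho} L(w + \epsilon) - L(w), \\
L^{\mathrm{WSAM}}(w) &= L(w) + \frac{\gamma}{1-\gamma}\, h(w)
  = \frac{1-2\gamma}{1-\gamma}\, L(w) + \frac{\gamma}{1-\gamma}\, L^{\mathrm{SAM}}(w),
\qquad L^{\mathrm{SAM}}(w) = \max_{\|\epsilon\| \le \rho} L(w + \epsilon).
\end{aligned}
```

Under this form, 𝛾 = 0 recovers the vanilla loss and 𝛾 = 1/2 recovers the SAM loss, while 𝛾 > 1/2 weights sharpness more heavily than SAM does.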

Usage

Similar to SAM, WSAM can be used in a two-step manner or with a single closure-based function.

import torch

from atorch.optimizers.wsam import WeightedSAM
from atorch.optimizers.utils import enable_running_stats, disable_running_stats

...

model = YourModel()
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.001) # initialize the base optimizer
optimizer = WeightedSAM(model, base_optimizer, rho=0.05, gamma=0.9, adaptive=False, decouple=True, max_norm=None)
...
# 1. two-step method
for input, output in data:
  enable_running_stats(model)
  with model.no_sync():  # requires a DistributedDataParallel-wrapped model; skips gradient all-reduce for this pass
    # first forward-backward pass
    loss = loss_function(output, model(input))  # use this loss for any training statistics
    loss.backward()
  optimizer.first_step(zero_grad=True)
  disable_running_stats(model)

  # second forward-backward pass
  loss_function(output, model(input)).backward()  # make sure to do a full forward pass
  optimizer.second_step(zero_grad=True)
...
# 2. closure-based method
for input, output in data:
  def closure():
    loss = loss_function(output, model(input))
    loss.backward()
    return loss

  loss = loss_function(output, model(input))
  loss.backward()
  optimizer.step(closure)
  optimizer.zero_grad()
...
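
To make the two passes above concrete, here is a minimal pure-Python sketch of one WSAM-style update on a toy quadratic. It is not the atorch implementation: the names `wsam_step` and `quad_grad`, the plain-SGD base step, and the exact form of the decoupled sharpness term (weight `gamma / (1 - gamma)` on the gradient difference) are assumptions made for illustration.

```python
import math


def wsam_step(w, grad_fn, lr=0.1, rho=0.05, gamma=0.9, eps=1e-12):
    """One WSAM-style update on a list of floats (illustrative sketch only)."""
    alpha = gamma / (1.0 - gamma)  # assumed weight on the sharpness term

    # First forward-backward pass: gradient at the current weights.
    g = grad_fn(w)
    norm = math.sqrt(sum(gi * gi for gi in g))

    # Analogue of first_step: move to the perturbed point w + e(w).
    e = [rho * gi / (norm + eps) for gi in g]
    w_adv = [wi + ei for wi, ei in zip(w, e)]

    # Second forward-backward pass: gradient at the perturbed weights.
    g_adv = grad_fn(w_adv)

    # Analogue of second_step: return to w, take a base SGD step with g,
    # then apply the sharpness term (g_adv - g) as a separate, decoupled update.
    sharpness = [ga - gi for ga, gi in zip(g_adv, g)]
    return [wi - lr * gi - lr * alpha * si for wi, gi, si in zip(w, g, sharpness)]


def quad_grad(w, curv=(1.0, 10.0)):
    """Gradient of the toy loss 0.5 * sum(curv[i] * w[i] ** 2)."""
    return [c * wi for c, wi in zip(curv, w)]


w0 = [1.0, 2.0]
sgd_like = wsam_step(w0, quad_grad, lr=0.1, gamma=0.0)  # no sharpness term
sam_like = wsam_step(w0, quad_grad, lr=0.1, gamma=0.5)  # sharpness weight of 1
```

In this sketch, `gamma=0.0` reduces to a plain SGD step and `gamma=0.5` reduces to a SAM step that uses only the perturbed gradient; larger values of `gamma` put more weight on the sharpness term.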

Extra Notes

  • Regularization mode: We recommend a decoupled update of the sharpness term (decouple=True), as used in our paper.
  • Gradient clipping: To keep training stable, WSAM clips gradients whenever max_norm is not None.
  • Gradient sync: This implementation synchronizes gradients correctly, corresponding to the m-sharpness used in the SAM paper.
  • Rho selection: If you are trying to reproduce the ViT results from the paper, use a larger rho when training on fewer GPUs. For more information, see this related link.
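
The gradient-clipping note above can be illustrated with a small pure-Python sketch. It assumes that clipping is done on the global L2 norm, in the style of `torch.nn.utils.clip_grad_norm_`; this README does not confirm which rule WSAM applies, and the helper name `clip_by_global_norm` is hypothetical.

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Rescale grads so their global L2 norm is at most max_norm.

    Returns the (possibly rescaled) gradients and the norm before clipping.
    When max_norm is None, the gradients are returned unchanged.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    if max_norm is None:
        return list(grads), total_norm
    # Only ever shrink the gradients; never scale them up.
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm


clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
untouched, _ = clip_by_global_norm([3.0, 4.0], max_norm=None)
```

Here the gradient `[3.0, 4.0]` has norm 5, so with `max_norm=1.0` it is scaled down to roughly unit norm while its direction is preserved.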