diff --git a/train_gpt2.py b/train_gpt2.py index 055f457e..a6f7d50a 100644 --- a/train_gpt2.py +++ b/train_gpt2.py @@ -106,8 +106,7 @@ def step(self): state['momentum_buffer'] = torch.zeros_like(g) buf = state['momentum_buffer'] buf.mul_(momentum).add_(g) - if group['nesterov']: - g = g.add(buf, alpha=momentum) + g = g.add(buf, alpha=momentum) if group['nesterov'] else buf g = zeropower_backend(g, steps=group['backend_steps']) g *= max(1, g.size(0)/g.size(1))**0.5 updates_flat[curr_idx:curr_idx+p.numel()] = g.flatten()