Update train_gpt2.py
KellerJordan authored Nov 28, 2024
1 parent b9e4d52 commit 9e35b93
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions train_gpt2.py
@@ -106,8 +106,7 @@ def step(self):
     state['momentum_buffer'] = torch.zeros_like(g)
 buf = state['momentum_buffer']
 buf.mul_(momentum).add_(g)
-if group['nesterov']:
-    g = g.add(buf, alpha=momentum)
+g = g.add(buf, alpha=momentum) if group['nesterov'] else buf
 g = zeropower_backend(g, steps=group['backend_steps'])
 g *= max(1, g.size(0)/g.size(1))**0.5
 updates_flat[curr_idx:curr_idx+p.numel()] = g.flatten()

1 comment on commit 9e35b93

@Triang-jyed-driung

This is extremely obfuscated. So, after this change, with momentum $m$ and gradient $g_t$ at step $t$, $g$ is

$$\begin{cases} (1+m)\,g_t + \sum_{i=1}^{t} m^{i+1} g_{t-i}, & \text{if nesterov} \\ g_t + \sum_{i=1}^{t} m^{i} g_{t-i}, & \text{otherwise} \end{cases}$$
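
For anyone who wants to check the expansion: the buffer update `buf.mul_(momentum).add_(g)` unrolls to $\text{buf}_t = \sum_{i=0}^{t} m^{i} g_{t-i}$, and both cases follow from that. Below is a minimal sketch that compares the recurrence from the diff with the closed form. It is not the repository's code: it uses plain Python floats in place of tensors, assumes the buffer starts at zero (as the `torch.zeros_like(g)` line suggests) and that steps are indexed from 0, and the names `recurrence` and `closed_form` are made up for illustration.

```python
import math


def recurrence(grads, m, nesterov):
    """Replay the momentum update from the diff on scalar gradients.

    Returns the value of g after the last step, before orthogonalization.
    """
    buf = 0.0  # stands in for state['momentum_buffer'], assumed zero-initialized
    out = None
    for g in grads:
        buf = m * buf + g                        # buf.mul_(momentum).add_(g)
        out = g + m * buf if nesterov else buf   # the line added by the commit
    return out


def closed_form(grads, m, nesterov):
    """Evaluate the closed-form expression from the comment at the last step t."""
    t = len(grads) - 1
    g_t = grads[t]
    if nesterov:
        tail = sum(m ** (i + 1) * grads[t - i] for i in range(1, t + 1))
        return (1 + m) * g_t + tail
    tail = sum(m ** i * grads[t - i] for i in range(1, t + 1))
    return g_t + tail


if __name__ == "__main__":
    grads = [1.0, 2.0, 3.0]  # arbitrary example gradients g_0, g_1, g_2
    m = 0.5
    for nesterov in (True, False):
        a = recurrence(grads, m, nesterov)
        b = closed_form(grads, m, nesterov)
        assert math.isclose(a, b)
        print(f"nesterov={nesterov}: recurrence={a}, closed form={b}")
```

Reading the removed lines the same way, the earlier non-Nesterov path appears to have left `g` as the raw gradient, so the `else buf` branch is what makes plain momentum take effect.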
