Fix lion8b error correction with torch 2.1 (#656)
dblalock authored Oct 9, 2023
1 parent df945fa commit aa2ba9f
Showing 2 changed files with 258 additions and 303 deletions.
30 changes: 19 additions & 11 deletions llmfoundry/optim/lion8b.py
@@ -4,7 +4,6 @@
 from typing import Any, Callable, Dict, Iterable, Optional, Tuple

 import torch
-from packaging import version


 class DecoupledLionW_8bit(torch.optim.Optimizer):
@@ -68,11 +67,6 @@ def __init__(self,
                  compress_state_dict: bool = False,
                  error_correction: bool = False,
                  _fused: bool = True):  # XXX this flag is mostly for testing...
-        if version.parse(torch.__version__) >= version.parse(
-                '2.1.0') and error_correction:
-            raise RuntimeError(
-                'DecoupledLionW_8bit with error correction requires PyTorch <2.1.0'
-            )

         if lr < 0.0:
             raise ValueError('Invalid learning rate: {}'.format(lr))
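
The deleted block above is the guard that raised a RuntimeError for error correction on PyTorch >= 2.1; the buffer changes in the hunks below make it unnecessary. A hypothetical usage sketch (the import path is inferred from the file path in this commit and is not shown in the diff):

# Sketch only: with the version guard removed, constructing the optimizer with
# error_correction=True on torch >= 2.1 no longer raises.
import torch
from llmfoundry.optim.lion8b import DecoupledLionW_8bit

model = torch.nn.Linear(16, 16, dtype=torch.bfloat16)
opt = DecoupledLionW_8bit(model.parameters(),
                          lr=1e-4,
                          weight_decay=1e-5,
                          error_correction=True)
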
@@ -138,11 +132,19 @@ def step_param(self, p: torch.Tensor, hparams: Dict[str, Any]) -> None:
                 mom, try_quantize=self._quantize)
         need_errs = (p.dtype != torch.float32) and self._error_correction
         if state.get('errors') is None and need_errs:
-            state['errors'] = torch.zeros(p.shape,
-                                          dtype=torch.uint8,
-                                          device=p.device)
+            numel = p.numel()
+            numel += numel % 2  # ensure even number of bytes
+            errors = torch.zeros(numel, dtype=torch.uint8, device=p.device)
+            # as of torch 2.1, FSDP can't shard ints for no reason
+            state['errors'] = errors.view(torch.bfloat16)
         decay_factor = hparams['weight_decay']
         decay_factor *= hparams['lr'] / hparams['initial_lr']
+        errors: Optional[torch.Tensor] = None
+        if 'errors' in state:
+            errors = state['errors']
+            assert errors is not None  # pyright
+            errors = errors.view(dtype=torch.uint8)
+            errors = errors[:p.numel()].view(p.shape)  # strip padding + reshape
         _lion8b_step(momentums=state['exp_avg'],
                      weights=p,
                      grads=p.grad,
@@ -151,7 +153,7 @@ def step_param(self, p: torch.Tensor, hparams: Dict[str, Any]) -> None:
                      lr=hparams['lr'],
                      weight_decay=decay_factor,
                      fused=hparams['fused'],
-                     errors=state.get('errors'))
+                     errors=errors)

     def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None:
         # we override this function to quantize optimizer states when
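
The two step_param hunks above implement the actual fix: per the comment in the diff, FSDP in torch 2.1 refuses to shard integer state, so the per-byte error tensor is padded to an even number of bytes and reinterpreted (not value-converted) as bfloat16 for storage, then reinterpreted back to uint8 and stripped of padding before being handed to _lion8b_step. A minimal sketch of that reinterpretation, with illustrative helper names that are not part of the optimizer:

# Minimal sketch of the byte-reinterpretation trick used above; helper names
# are illustrative and not part of DecoupledLionW_8bit.
import torch


def make_error_buffer(p: torch.Tensor) -> torch.Tensor:
    """Allocate per-byte error state, stored as a bfloat16-typed view."""
    numel = p.numel()
    numel += numel % 2  # bfloat16 is 2 bytes wide, so pad to an even byte count
    errors = torch.zeros(numel, dtype=torch.uint8, device=p.device)
    return errors.view(torch.bfloat16)  # reinterpret bytes, no copy or cast


def usable_errors(stored: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    """Recover a uint8 view with the parameter's shape."""
    flat = stored.view(torch.uint8)        # back to raw bytes
    return flat[:p.numel()].view(p.shape)  # strip padding, reshape


p = torch.randn(5, 3).to(torch.bfloat16)
buf = make_error_buffer(p)     # dtype=torch.bfloat16, numel=8
errs = usable_errors(buf, p)   # dtype=torch.uint8, shape=(5, 3), aliases buf
assert errs.dtype == torch.uint8 and errs.shape == p.shape

Because Tensor.view(dtype) reinterprets storage rather than copying, errs aliases buf; in-place updates during the step therefore land in the bfloat16-typed tensor that FSDP shards.
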
@@ -173,7 +175,8 @@ def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None:
                 # we need to cast back to the correct dtype since optimizer
                 # load_state_dict casts to param dtype for fp params; see
                 # https://github.com/pytorch/pytorch/blob/a25eee1d77d93079614fab3ea4ac66e64fb2343b/torch/optim/optimizer.py#L626C7-L626C7  # noqa
-                errs = param_state['errors'].to(dtype=torch.uint8)
+                errs = param_state['errors'].to(dtype=torch.uint8).view(
+                    torch.bfloat16)
                 new_state['errors'] = errs
             opt_state[param_id] = new_state
         super().__setstate__(state)
@@ -199,6 +202,11 @@ def state_dict(self):
                     qtensor.state_dict(
                         name='exp_avg',
                         allow_quantized=self._compress_state_dict))
+            if 'errors' in param_state:
+                # fsdp apparently needs the states to be the same shape
+                # as the params
+                param_state['errors'] = param_state['errors'].view(
+                    torch.uint8).to(dtype=torch.bfloat16)
             opt_state[param_id] = param_state
         return d

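The state_dict and __setstate__ hunks above handle checkpoints: on save, the stored buffer is reinterpreted back to bytes and then value-cast to bfloat16 so the serialized state is a float tensor with one element per byte; on load, the value cast is undone and the byte-level bfloat16 view is restored. Every uint8 value (0-255) is exactly representable in bfloat16, so the round trip is lossless. A sketch under those assumptions, again with illustrative helper names:

# Sketch of the checkpoint round trip for the error state; function names are
# illustrative, not the optimizer's API.
import torch


def errors_to_checkpoint(stored: torch.Tensor) -> torch.Tensor:
    # stored is the bfloat16-typed *view* of raw uint8 bytes kept in state.
    # Reinterpret back to bytes, then value-cast so the saved tensor is a
    # float tensor with one element per byte.
    return stored.view(torch.uint8).to(dtype=torch.bfloat16)


def errors_from_checkpoint(saved: torch.Tensor) -> torch.Tensor:
    # Undo the value cast, then restore the byte-level bfloat16 view that
    # step_param expects.
    return saved.to(dtype=torch.uint8).view(torch.bfloat16)


raw = torch.randint(0, 256, (16,), dtype=torch.uint8)
stored = raw.clone().view(torch.bfloat16)
restored = errors_from_checkpoint(errors_to_checkpoint(stored))
assert torch.equal(restored.view(torch.uint8), raw)  # bytes survive the trip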
