Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix lion8b error correction with torch 2.1 #656

Merged
merged 3 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions llmfoundry/optim/lion8b.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Any, Callable, Dict, Iterable, Optional, Tuple

import torch
from packaging import version


class DecoupledLionW_8bit(torch.optim.Optimizer):
Expand Down Expand Up @@ -68,11 +67,6 @@ def __init__(self,
compress_state_dict: bool = False,
error_correction: bool = False,
_fused: bool = True): # XXX this flag is mostly for testing...
if version.parse(torch.__version__) >= version.parse(
'2.1.0') and error_correction:
raise RuntimeError(
'DecoupledLionW_8bit with error correction requires PyTorch <2.1.0'
)

if lr < 0.0:
raise ValueError('Invalid learning rate: {}'.format(lr))
Expand Down Expand Up @@ -138,11 +132,19 @@ def step_param(self, p: torch.Tensor, hparams: Dict[str, Any]) -> None:
mom, try_quantize=self._quantize)
need_errs = (p.dtype != torch.float32) and self._error_correction
if state.get('errors') is None and need_errs:
state['errors'] = torch.zeros(p.shape,
dtype=torch.uint8,
device=p.device)
numel = p.numel()
numel += numel % 2 # ensure even number of bytes
errors = torch.zeros(numel, dtype=torch.uint8, device=p.device)
# as of torch 2.1, FSDP can't shard ints for no reason
state['errors'] = errors.view(torch.bfloat16)
decay_factor = hparams['weight_decay']
decay_factor *= hparams['lr'] / hparams['initial_lr']
errors: Optional[torch.Tensor] = None
if 'errors' in state:
errors = state['errors']
assert errors is not None # pyright
errors = errors.view(dtype=torch.uint8)
errors = errors[:p.numel()].view(p.shape) # strip padding + reshape
_lion8b_step(momentums=state['exp_avg'],
weights=p,
grads=p.grad,
Expand All @@ -151,7 +153,7 @@ def step_param(self, p: torch.Tensor, hparams: Dict[str, Any]) -> None:
lr=hparams['lr'],
weight_decay=decay_factor,
fused=hparams['fused'],
errors=state.get('errors'))
errors=errors)

def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None:
# we override this function to quantize optimizer states when
Expand All @@ -173,7 +175,8 @@ def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None:
# we need to cast back to the correct dtype since optimizer
# load_state_dict casts to param dtype for fp params; see
# https://github.com/pytorch/pytorch/blob/a25eee1d77d93079614fab3ea4ac66e64fb2343b/torch/optim/optimizer.py#L626C7-L626C7 # noqa
errs = param_state['errors'].to(dtype=torch.uint8)
errs = param_state['errors'].to(dtype=torch.uint8).view(
torch.bfloat16)
new_state['errors'] = errs
opt_state[param_id] = new_state
super().__setstate__(state)
Expand All @@ -199,6 +202,11 @@ def state_dict(self):
qtensor.state_dict(
name='exp_avg',
allow_quantized=self._compress_state_dict))
if 'errors' in param_state:
# fsdp apparently needs the states to be the same shape
# as the params
param_state['errors'] = param_state['errors'].view(
torch.uint8).to(dtype=torch.bfloat16)
opt_state[param_id] = param_state
return d

Expand Down
Loading
Loading