From c7c5ae6a6bf5b7d1ef2c8e14c53a52d5994aeb89 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 10 Aug 2023 06:59:36 +0000 Subject: [PATCH 01/23] add decoupled lion8b optimizer + tests + builder option + deps --- llmfoundry/optim/__init__.py | 1 + llmfoundry/optim/lion8b.py | 334 +++++++++++++++++++++++++++ llmfoundry/utils/builders.py | 5 +- setup.py | 1 + tests/test_lion8b.py | 422 +++++++++++++++++++++++++++++++++++ 5 files changed, 762 insertions(+), 1 deletion(-) create mode 100644 llmfoundry/optim/lion8b.py create mode 100644 tests/test_lion8b.py diff --git a/llmfoundry/optim/__init__.py b/llmfoundry/optim/__init__.py index ea996b3305..7936f9a8c7 100644 --- a/llmfoundry/optim/__init__.py +++ b/llmfoundry/optim/__init__.py @@ -3,5 +3,6 @@ from llmfoundry.optim.adaptive_lion import DecoupledAdaLRLion, DecoupledClipLion from llmfoundry.optim.lion import DecoupledLionW +from llmfoundry.optim.lion8b import DecoupledLionW_8bit __all__ = ['DecoupledLionW', 'DecoupledClipLion', 'DecoupledAdaLRLion'] diff --git a/llmfoundry/optim/lion8b.py b/llmfoundry/optim/lion8b.py new file mode 100644 index 0000000000..3c11e4e282 --- /dev/null +++ b/llmfoundry/optim/lion8b.py @@ -0,0 +1,334 @@ +from typing import Any, Callable, Dict, Iterable, Optional, Tuple + +import torch + +_KEY_MOMENTUM = 'exp_avg' +_KEY_ERRORS = 'errors' + + +class DecoupledLionW_8bit(torch.optim.Optimizer): + """LION optimizer with ~8 bits of state per parameter. + + This optimizer is a drop-in replacement for our regular LION optimizer + with decoupled weight decay, but uses less memory, writes smaller + checkpoints, and offers almost-numerically-identical convergence. + + Its state saved per parameter is just an int8, though there are auxiliary + scaling factors that bring the total memory per parameter to ~8.5 bits. + The exact quantization scheme is considered an implementation detail + and may change. + + When training on CPUs, however, no quantization will actually take place. + + See the LION paper (https://arxiv.org/abs/2302.06675) for details about + the algorithm itself. + + Args: + params: iterable of parameters to optimize or dicts defining + parameter groups + lr: learning rate (Default: 1e-3) + betas: two coefficients between 0 and 1 used to combine the current + gradients and the momentum. The first coefficient is the weight + of the gradient when computing the update. The second is the + weight of the gradient when computing the new momentum. + (Default: .9, .99) + weight decay: Weights are multiplied by 1 - `weight_decay` after + each optimizer step. Note that we use decoupled weight decay, + meaning that this decay does not contribute to the momentum. + (Default: 0.) + l2_penalty: adds `l2_penalty * param` to the gradient at the + start of the optimizer step. This term *is* added to the momentum. + compress_state_dict: if True, this optimizer's `state_dict` will + include quantized optimizer states. Otherwise, the optimizer + states are converted to bfloat16 Tensors matching the shapes of + their corresponding parameters. The former uses ~8.5 bits per + parameter while the latter uses 16 bits per parameter. However, + the former is less thoroughly tested and will not work with + FSDP or other weight sharding approaches. + quantize: If False, optimizer states will not actually be quantized. + This option is available so that one can easily debug whether + the quantization is causing any convergence issues. Quantization + is always disabled when training without a CUDA device. 
+ error_correction: If True, float16 and bfloat16 parameters will be + given an extra state variable, "errors." This tensor will be + of the same shape as the parameter but of dtype uint8. This + auxiliary variable is used to better approximate float32 updates + by retaining information across optimizer steps. + """ + + def __init__( + self, + params: Iterable[torch.Tensor], + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.99), + weight_decay: float = 0, + compress_state_dict: bool = False, + quantize: bool = True, + _fused: bool = True, # XXX this flag is mostly for testing... + error_correction: bool = False): + if not 0.0 <= lr: + raise ValueError('Invalid learning rate: {}'.format(lr)) + # if not 0.0 < betas[0] < 1.0: + if not 0.0 <= betas[0] <= 1.0: + raise ValueError('Invalid beta parameter at index 0: {}'.format( + betas[0])) + if not 0.0 <= betas[1] <= 1.0: + raise ValueError('Invalid beta parameter at index 1: {}'.format( + betas[1])) + if not 0.0 <= weight_decay: + raise ValueError( + 'Invalid weight_decay value: {}'.format(weight_decay)) + + self._quantize = quantize and torch.cuda.is_available() + self._compress_state_dict = compress_state_dict + self._error_correction = error_correction + if error_correction and not _fused: + raise NotImplementedError( + "Error correction requires fused kernels.") + defaults = dict(lr=lr, + initial_lr=lr, + betas=betas, + weight_decay=weight_decay, + fused=_fused) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure: Optional[Callable] = None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + self.step_param(p, group) + + return loss + + def step_param(self, p: torch.Tensor, hparams: Dict[str, Any]) -> None: + if not p.requires_grad or p.grad is None: + return + if self._quantize and not p.is_cuda: + raise NotImplementedError( + f"Can't use quantization with param on {p.device} " + + f"({p.shape}, {p.dtype}). 
If you need " + + "to use DecoupledLionW_8bit without a CUDA device, try " + + "creating this optimizer with quantize=False.") + state = self.state[p] # type:ignore using tensor as key + if _KEY_MOMENTUM not in state: + mom = torch.zeros_like(p) + state[_KEY_MOMENTUM] = _MaybeQuantizedTensor( + mom, try_quantize=self._quantize) + need_errs = (p.dtype != torch.float32) and self._error_correction + if state.get(_KEY_ERRORS) is None and need_errs: + state[_KEY_ERRORS] = torch.zeros(p.shape, + dtype=torch.uint8, + device=p.device) + decay_factor = hparams['weight_decay'] + decay_factor *= hparams['lr'] / hparams['initial_lr'] + _lion8b_step(momentums=state[_KEY_MOMENTUM], + weights=p, + grads=p.grad, + beta1=hparams['betas'][0], + beta2=hparams['betas'][1], + lr=hparams['lr'], + weight_decay=decay_factor, + fused=hparams['fused'], + errors=state.get(_KEY_ERRORS)) + + def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None: + # we override this function to quantize optimizer states when + # loading a state dict + opt_state, _ = state.values() # other val is param_groups + for param_id in opt_state: + param_state = opt_state[param_id] + new_state = {} + if _KEY_MOMENTUM in param_state: + qtensor = _MaybeQuantizedTensor(None, + try_quantize=self._quantize) + qtensor.load_state_dict(param_state[_KEY_MOMENTUM]) + new_state[_KEY_MOMENTUM] = qtensor + if self._error_correction and _KEY_ERRORS in param_state: + # we need to cast back to the correct dtype since optimizer + # load_state_dict casts to param dtype for fp params; see + # https://github.com/pytorch/pytorch/blob/a25eee1d77d93079614fab3ea4ac66e64fb2343b/torch/optim/optimizer.py#L626C7-L626C7 # noqa + errs = param_state[_KEY_ERRORS].to(dtype=torch.uint8) + new_state[_KEY_ERRORS] = errs + opt_state[param_id] = new_state + super().__setstate__(state) + + def state_dict(self): + # If the user hasn't opted into storing compressed state dicts + # we have to make sure our states are regular torch.Tensors. This + # is mostly needed to make FSDP happy in the case that we want to + # resume training with a number of devices where + # (param numel / device count) % quantization group size != 0 + # for any param. + d = super().state_dict() + opt_state, _ = d.values() # other val is param_groups + for param_id in opt_state: + # make a copy so that we don't mutate our self.state; opt_state + # isn't the same as self.state, but its consituent dicts are + # the same as those in self.state + param_state = {k: v for k, v in opt_state[param_id].items()} + if _KEY_MOMENTUM in param_state: + qtensor = param_state[_KEY_MOMENTUM] + assert isinstance(qtensor, _MaybeQuantizedTensor) # pyright + param_state[_KEY_MOMENTUM] = qtensor.state_dict( + allow_quantized=self._compress_state_dict) + opt_state[param_id] = param_state + return d + + +class _MaybeQuantizedTensor: + """Helper class so 8b LION doesn't have to know quantization details. 
+ + Important points about this class: + * It handles CPU tensors not being quantized + * It knows how to save + load state dicts, handling both the quantized + and not quantized cases + * It implements some parts of the torch.Tensor interface that we need, + but is not intended to be a full torch.Tensor replacement + """ + + def __init__(self, data: Optional[torch.Tensor], try_quantize: bool = True): + super().__init__() + self.data: Optional[torch.Tensor] = None + self.quantized: Optional[torch.Tensor] = None + self.scales: Optional[torch.Tensor] = None + self._try_quantize = try_quantize and torch.cuda.is_available() + + # conditionally import CUDA kernels + self._f_encode = None + self._f_decode = None + if self._try_quantize: + from turbo import dequantize8b, quantize8b + self._f_encode = quantize8b + self._f_decode = dequantize8b + + if data is not None: + self.set_data(data) + + def state_dict(self, + allow_quantized: bool = False) -> Dict[str, torch.Tensor]: + if self.is_quantized() and allow_quantized: + assert self.quantized is not None # pyright + assert self.scales is not None # pyright + return {'quantized': self.quantized, 'scales': self.scales} + return {'data': self.materialize().to(dtype=torch.bfloat16)} + + def load_state_dict(self, d: Dict[str, torch.Tensor]) -> None: + if 'data' in d: + if len(d) != 1: + raise ValueError('If state dict specifies "data", it must not' + + f'specify other keys. Got {list(d.keys())}') + self.set_data(d['data']) + return + + self.quantized = d['quantized'].to(dtype=torch.int8) + self.scales = d['scales'].to(dtype=torch.float16) + + def set_data(self, data: torch.Tensor) -> None: + if not (self._try_quantize and data.is_cuda): + self.data = data.to(dtype=torch.float32) + self.quantized = None + self.scales = None + else: + self.data = None + assert self._f_encode is not None # pyright + self.quantized, self.scales = self._f_encode(data) + + def is_quantized(self) -> bool: + return self.data is None + + def materialize(self) -> torch.Tensor: + if not self.is_quantized(): + assert self.data is not None # pyright + return self.data + assert self._f_decode is not None # pyright + assert self.quantized is not None # pyright + assert self.scales is not None # pyright + return self._f_decode(self.quantized, self.scales) + + @property # property to mirror Tensor interface + def is_cuda(self) -> bool: + if self.is_quantized(): + assert self.quantized is not None # pyright + return self.quantized.is_cuda + assert self.data is not None # pyright + return self.data.is_cuda + + @property # property to mirror Tensor interface + def shape(self) -> Tuple[int]: + if self.is_quantized(): + assert self.quantized is not None # pyright + return self.quantized.shape + assert self.data is not None # pyright + return self.data.shape + + def numel(self) -> int: + if self.is_quantized(): + assert self.quantized is not None # pyright + return self.quantized.numel() + assert self.data is not None # pyright + return self.data.numel() + + def __repr__(self): + return (f'{self.__class__.__name__} quantized={self.is_quantized()} ' + + f'shape={self.shape}') + + +def lion_step_unfused(grads: torch.Tensor, + weights: torch.Tensor, + momentums: torch.Tensor, + lr: float, + beta1: float, + beta2: float, + weight_decay: float = 0) -> torch.Tensor: + # f32 cast to match fused impl + for compatibility with f32 grads or weights + momentums = momentums.to(torch.float32) + grads = grads.to(dtype=torch.float32) + + update = momentums.lerp(grads, 1 - beta1).sign_() + if weight_decay > 
0: + weights.mul_(1. - weight_decay) + + weights.add_(update, alpha=-lr) + momentums.lerp_(grads, 1. - beta2) + return momentums # f32 upcast means not necessarily modified in place + + +def _lion8b_step(grads: torch.Tensor, + weights: torch.Tensor, + momentums: _MaybeQuantizedTensor, + lr: float, + beta1: float, + beta2: float, + weight_decay: float = 0, + errors: Optional[torch.Tensor] = None, + fused: bool = True) -> None: + + if momentums.is_quantized() and fused: + from turbo import lion8b_step as lion8b_step_fused + + assert momentums.quantized is not None # pyright + assert momentums.scales is not None # pyright + return lion8b_step_fused(grads=grads, + weights=weights, + momentums=momentums.quantized, + scales=momentums.scales, + lr=lr, + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + errors=errors) + + momentums_float = momentums.materialize() + new_momentums = lion_step_unfused(grads=grads, + weights=weights, + momentums=momentums_float, + lr=lr, + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay) + momentums.set_data(new_momentums) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index bf16ddc663..a00ee72d8a 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -26,7 +26,7 @@ LayerFreezing, MonolithicCheckpointSaver, ScheduledGarbageCollector) from llmfoundry.optim import (DecoupledAdaLRLion, DecoupledClipLion, - DecoupledLionW) + DecoupledLionW, DecoupledLionW_8bit) Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] @@ -115,6 +115,9 @@ def build_optimizer(cfg, model): timeout=cfg.timeout, lr_penalty=cfg.lr_penalty, min_scale=cfg.min_scale) + elif cfg.name.lower() == 'decoupled_lionw_8b': + kwargs = {k: v for k, v in cfg.items() if k != 'name'} + return DecoupledLionW_8bit(model.parameters(), **kwargs) else: raise ValueError(f'Not sure how to build optimizer: {cfg.name}') diff --git a/setup.py b/setup.py index 43b7939fba..141b245891 100644 --- a/setup.py +++ b/setup.py @@ -50,6 +50,7 @@ 'mosaicml[libcloud,nlp,wandb]>=0.15.0,<0.16', 'accelerate>=0.20,<0.21', # for HF inference `device_map` 'mosaicml-streaming>=0.5.1,<0.6', + 'mosaicml-turbo>=0.0.2,<0.1', 'torch>=1.13.1,<=2.0.1', 'datasets==2.10.1', 'sentencepiece==0.1.97', diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py new file mode 100644 index 0000000000..24887b2b0b --- /dev/null +++ b/tests/test_lion8b.py @@ -0,0 +1,422 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import time +import warnings + +import numpy as np +import pytest +import torch + +from llmfoundry.optim import DecoupledLionW_8bit as Lion8bit + +warnings.filterwarnings('ignore') + +_MANY_PARAM_SHAPES = [(1, 1), (1, 2), (17, 23), (64, 32)] +_FLOAT_DTYPES = [torch.bfloat16, torch.float16, torch.float32] + +np.set_printoptions(linewidth=160, formatter={'float': lambda f: f'{f:5.3f}'}) + + +@pytest.mark.gpu +@pytest.mark.parametrize('N,D', _MANY_PARAM_SHAPES) +@pytest.mark.parametrize('dtype', _FLOAT_DTYPES) +@pytest.mark.parametrize('fused,use_errors', [(False, False), (True, False), + (True, True)]) +def test_modifies_weights_and_momentums(N: int, D: int, dtype: torch.dtype, + fused: bool, use_errors: bool) -> None: + device = 'cuda' + torch.manual_seed(123) + X = torch.randn((N, D), device=device, requires_grad=False, dtype=dtype) + W = torch.randn((D, D), device=device, requires_grad=True, dtype=dtype) + W_orig = W.detach().clone() + + opt = Lion8bit([W], + lr=1.0, + _fused=fused, + betas=(.75, .75), + weight_decay=.2, + 
error_correction=use_errors) + + Y = X @ W + loss = Y.sum() + loss.backward() + torch.testing.assert_close(W_orig, W) # no weight modification yet + opt.step() + opt.zero_grad() + + with pytest.raises(AssertionError): # opt step modified the weights + torch.testing.assert_close(W_orig, W) + + # every momentum should be nonzero with infinite precision, but + # might be zero after quantization + param_state = opt.state[W] # type:ignore using tensor as key + momentum = param_state['exp_avg'].materialize() + assert momentum.shape == (D, D) + momentum = momentum.ravel() + assert momentum is not None + if momentum.numel() == 1: + assert momentum.item() != 0 + else: + assert torch.std(momentum).item() > 0 + + +@pytest.mark.gpu +@pytest.mark.parametrize('N,D', _MANY_PARAM_SHAPES) +@pytest.mark.parametrize('device', ['cpu', 'cuda']) +@pytest.mark.parametrize('dtype', _FLOAT_DTYPES) +@pytest.mark.parametrize('weight_decay', [0, .1]) +@pytest.mark.parametrize('fused,use_errors', [(False, False), (True, False), + (True, True)]) +def test_changes_with_zero_grads(N: int, D: int, device: str, + dtype: torch.dtype, weight_decay: float, + fused: bool, use_errors: bool) -> None: + if (dtype != torch.float32) and device == 'cpu': + return + torch.manual_seed(123) + W = torch.rand((D, D), device=device, requires_grad=True) + with torch.no_grad(): + W += torch.sign(W) # bound away from zero so decay won't change sign + W_orig = W.detach().clone() + + opt = Lion8bit([W], + _fused=fused, + betas=(.5, .5), + quantize=(device != 'cpu'), + weight_decay=weight_decay, + error_correction=use_errors) + + zeros_grad = torch.zeros_like(W) + for _ in range(5): + W.grad = zeros_grad + opt.step() + opt.zero_grad() + + mom = opt.state[W]['exp_avg'] # type:ignore using tensor as key + assert torch.all(mom.materialize() == 0) + if mom.is_quantized(): + assert torch.all(mom.quantized == 0) + + if weight_decay: + assert torch.all(W_orig.abs() > W.abs()) + else: + torch.testing.assert_close(W_orig, W) # no weight modification + + +@pytest.mark.gpu +@pytest.mark.parametrize('N,D', [(1, 8), (17, 23), (32, 32)]) +@pytest.mark.parametrize('device', ['cpu', 'cuda']) +@pytest.mark.parametrize('dtype', _FLOAT_DTYPES) +@pytest.mark.parametrize('fused,use_errors', [(False, False), (True, False), + (True, True)]) +def test_descends(N: int, D: int, device: str, dtype: torch.dtype, fused: bool, + use_errors: bool) -> None: + if (dtype != torch.float32) and device == 'cpu': + return + torch.manual_seed(123) + X = torch.randn((N, D), device=device, requires_grad=False, dtype=dtype) + W = torch.randn((D, D), device=device, requires_grad=True, dtype=dtype) + + # we use tiny beta1 so we move almost entirely in the gradient direction + opt = Lion8bit([W], + lr=1e-2, + betas=(.5, .5), + quantize=(device != 'cpu'), + _fused=fused, + error_correction=use_errors) + + prev_loss = np.inf + prev_momentum = None + num_iters = 10 if device == 'cuda' else 2 # keep test fast + for _ in range(num_iters): + Y = X @ W + loss = (Y * Y).mean() + loss.backward() + opt.step() + opt.zero_grad() + + loss_val = loss.item() + assert loss_val < prev_loss + prev_loss = loss_val + + # since we're getting the same batch every time and have a small + # learning rate, our gradients should point in the same direction + # at each step. Consequently, our momentum should grow each step. 
+ state_for_param = opt.state[W] # type:ignore using tensor as key + momentum = state_for_param['exp_avg'].materialize() + assert momentum is not None and momentum.shape == W.shape + if prev_momentum is not None: + momentum_changes = momentum - prev_momentum + assert torch.all(momentum_changes >= 0) + assert momentum_changes.max() > 0 + prev_momentum = momentum + + +def _nmse(vals_true: torch.Tensor, + vals_hat: torch.Tensor, + norm_how: str = 'l2_sq'): + diffs = vals_true - vals_hat + mse = (diffs * diffs).mean() + if norm_how == 'var': + return mse / vals_true.var() + return mse / (vals_true * vals_true).mean() + + +@pytest.mark.gpu +@pytest.mark.parametrize('w_init', ['cyclic', 'rand']) +@pytest.mark.parametrize('grad_strategy', ['zero', 'ones', 'const', 'rand']) +@pytest.mark.parametrize('D', [4, 12]) # vectorized and unvectorized impls +@pytest.mark.parametrize('dtype', _FLOAT_DTYPES) +def test_lion8b_fused_unfused_unquantized_same(w_init: str, grad_strategy: str, + D: int, + dtype: torch.dtype) -> None: + torch.manual_seed(123) + device = 'cuda' + + # each optimizer gets a different weight matrix to optimize + if w_init == 'cyclic': + W0 = torch.arange(D * D, + device=device, + requires_grad=False, + dtype=dtype).reshape(D, D) + W0 = ((W0 // 2 % 3) - 1).to(dtype=dtype) + elif w_init == 'rand': + W0 = torch.rand( + size=(D, D), device=device, requires_grad=False, + dtype=dtype) * 2 - 1 + W0 += .01 * torch.sign(W0) # bound away from 0 to cap rel errors + W0 = W0.to(dtype=dtype) + else: # here for pyright + raise ValueError("Unrecognized w_init: ", w_init) + W0.add_(W0.sign()) # bound away from zero so decay won't flip sign + W_true = torch.empty_like(W0, requires_grad=True, + dtype=torch.float32) # ground truth + W_uq = torch.empty_like(W0, requires_grad=True) # unquantized + W_uf = torch.empty_like(W0, requires_grad=True) # unfused + W_fq = torch.empty_like(W0, requires_grad=True) # fused and quantized + W_fqe = torch.empty_like(W0, requires_grad=True) # fused, quantized, ecc + W_sgd = torch.empty_like(W0, requires_grad=True) + with torch.no_grad(): + W_true.copy_(W0.to(W_true.dtype)) + W_uq.copy_(W0) + W_uf.copy_(W0) + W_fq.copy_(W0) + W_fqe.copy_(W0) + W_sgd.copy_(W0) + + # we use a high LR, low betas, and regularization so that there will + # hopefully be differences if *any* of the logic is wrong + lr = .1 + # weight_decay = .25 + weight_decay = .01 + # weight_decay = .0 + kwargs = {'lr': lr, 'weight_decay': weight_decay, 'betas': (.5, .75)} + # kwargs = {'lr': lr, 'weight_decay': weight_decay, 'betas': (0, 0)} # f16 fq works + # kwargs = {'lr': lr, 'weight_decay': weight_decay, 'betas': (.5, 0)} # f16 fq works + # kwargs = {'lr': lr, 'weight_decay': weight_decay, 'betas': (0, .5)} # f16 fq works + opt_true = Lion8bit([W_true], quantize=False, **kwargs) + opt_uq = Lion8bit([W_uq], quantize=False, **kwargs) + opt_uf = Lion8bit([W_uf], _fused=False, **kwargs) + opt_fq = Lion8bit([W_fq], _fused=True, **kwargs) + opt_fqe = Lion8bit([W_fqe], _fused=True, error_correction=True, **kwargs) + opt_sgd = torch.optim.SGD([W_sgd], lr=lr) + + W_list = [W_true, W_uq, W_uf, W_fq, W_fqe, W_sgd] + opt_list = [opt_true, opt_uq, opt_uf, opt_fq, opt_fqe, opt_sgd] + + if grad_strategy == 'zero': + grads = torch.zeros_like(W0) + elif grad_strategy == 'ones': + grads = ((torch.arange(W0.numel()) % 2) * 2 - 1).reshape(W0.shape) + elif grad_strategy == 'const': + # arange makes blocks have different distros, so we can't + # get away with bugs like always using the first scale_scale + grads = 
torch.arange(W0.numel(), + device=device, + requires_grad=False, + dtype=W0.dtype).view(W0.shape) + # next two conditions are just here for pyright + elif grad_strategy == 'rand': + grads = torch.tensor([-1]) + else: + raise ValueError("bad grad_strategy: ", grad_strategy) + + # for _ in range(3): + # for _ in range(1): + # for _ in range(10): + for _ in range(4): + if grad_strategy == 'rand': # type:ignore (reportUnnecessaryComparison) + grads = torch.rand(W0.shape, + device=device, + requires_grad=False, + dtype=W0.dtype) + for W, opt in zip(W_list, opt_list): + W.grad = grads.clone().to(dtype=W.dtype, device=W.device) + opt.step() + opt.zero_grad() + + W0_f = W0.float() + diffs_true = (W_true.detach().float() - W0_f).ravel() + diffs_uq = (W_uq.detach().float() - W0_f).ravel() + diffs_uf = (W_uf.detach().float() - W0_f).ravel() + diffs_fq = (W_fq.detach().float() - W0_f).ravel() + diffs_fqe = (W_fqe.detach().float() - W0_f).ravel() + diffs_sgd = (W_sgd.detach().float() - W0_f).ravel() + + # a bunch of made-up numbers; should be tight enough to detect + # regressions, but aren't enough to 100% guarantee correct numerics + if dtype != torch.bfloat16: + min_cossim = .99 + max_nmse = .01 + else: + min_cossim = .98 + max_nmse = .04 + + cossim = torch.cosine_similarity # avoid ugly linewraps + + assert cossim(diffs_true, diffs_uq, dim=-1) > min_cossim + assert _nmse(diffs_true, diffs_uq) < max_nmse + + assert cossim(diffs_true, diffs_uf, dim=-1) > min_cossim + assert _nmse(diffs_true, diffs_uf) < max_nmse + + # fused and unfused should be almost identical; the only differences + # are intermediate upcasting in the fused impl + assert cossim(diffs_uf, diffs_fq, dim=-1) > min_cossim + assert _nmse(diffs_uf, diffs_fq) < max_nmse + + # fused impl should be close to unfused version with no quantization + # at all; latter is "ground truth" + assert cossim(diffs_true, diffs_fq, dim=-1) > min_cossim + assert _nmse(diffs_true, diffs_fq) < max_nmse + + # fused impl with errors should also be close to "true" updates; + assert cossim(diffs_true, diffs_fqe, dim=-1) > min_cossim + assert _nmse(diffs_true, diffs_fqe) < max_nmse + + # error correction should reduce error, or at least do no worse + assert _nmse(diffs_true, diffs_fqe) <= _nmse(diffs_true, diffs_fq) + + # if sgd weights aren't different than LION weights, we haven't + # changed them enough to meaningfully test the LION logic + if grad_strategy not in ('zero', 'ones'): + assert torch.cosine_similarity( + diffs_true, # type:ignore (reportUnboundVariable) + diffs_sgd, # type:ignore (reportUnboundVariable) + dim=-1) < .99 + assert _nmse( + diffs_true, # type:ignore (reportUnboundVariable) + diffs_sgd # type:ignore (reportUnboundVariable) + ) > .01 + + +@pytest.mark.gpu +@pytest.mark.parametrize('device', ['cpu', 'cuda']) +@pytest.mark.parametrize('quantized_state', [False, True]) +@pytest.mark.parametrize('dtype', _FLOAT_DTYPES) +@pytest.mark.parametrize('use_errors', [False, True]) +def test_state_dict_save_load(device: str, quantized_state: bool, + dtype: torch.dtype, use_errors: bool): + torch.manual_seed(123) + params = [] + for shape in _MANY_PARAM_SHAPES: + p = torch.rand(shape, device=device, dtype=dtype, requires_grad=True) + p.grad = torch.rand_like(p) + params.append(p) + + # create optimizer and have it step so that state gets populated + opt = Lion8bit(params, + compress_state_dict=quantized_state, + error_correction=use_errors) + if device == 'cpu': + with pytest.raises(NotImplementedError): + opt.step() + return + else: + opt.step() + 
opt.zero_grad() + + # copy state dict into a new instance + state_dict = opt.state_dict() + opt_new = Lion8bit(params, + compress_state_dict=quantized_state, + error_correction=use_errors) + opt_new.load_state_dict(state_dict) + + for p in params: + d_orig = opt.state[p] + d_new = opt_new.state[p] + assert list(d_orig.keys()) == list(d_new.keys()) + mom_orig = d_orig['exp_avg'] + mom_new = d_new['exp_avg'] + if quantized_state: + # Optimizer load_state_dict insists on converting scales to + # dtype of param, which is lossy for bf16 params. + # Ideally we'd require == for everything but it's less complexity + # to just relax the bf16 test + assert torch.all(mom_orig.quantized == mom_new.quantized) + if dtype == torch.bfloat16: + torch.testing.assert_close(mom_orig.scales, + mom_new.scales, + atol=1e-3, + rtol=1e-2) + else: + assert torch.all(mom_orig.scales == mom_new.scales) + + torch.testing.assert_close(mom_orig.materialize(), + mom_new.materialize(), + atol=1. / (2 * 127), + rtol=np.inf) + if use_errors and (dtype != torch.float32): + torch.testing.assert_close(d_orig['errors'], d_new['errors']) + + +@pytest.mark.gpu +@pytest.mark.parametrize('N,D', [(32, 32), (256, 256), (1024, 1024), + (4096, 4096), [16384, 16384]]) +def test_fused_as_fast_as_unfused(N: int, + D: int, + min_elems_traversed: int = int(1e6)): + W = torch.randn((N, D), device='cuda', requires_grad=True) + W.grad = torch.randn((N, D), device='cuda', requires_grad=False) + + num_iters = int(np.ceil(min_elems_traversed / W.grad.numel())) + num_iters = min(100, num_iters) # don't take all day when overhead-bound + + times = {} + kwargs = {'weight_decay': .01} + combos = [(True, False), (True, True), (False, False), ('NA', False)] + for fused, use_errors in combos: + if fused == 'NA': + opt = Lion8bit([W], quantize=False, + **kwargs) # type:ignore (reportGeneralTypeIssues) + else: + opt = Lion8bit([W], + _fused=fused, + error_correction=use_errors, + **kwargs) # type:ignore (reportGeneralTypeIssues) + for _ in range(3): + opt.step() # warmup iters + torch.cuda.synchronize() + t_start = time.time() + for _ in range(num_iters): + opt.step() + torch.cuda.synchronize() + t_end = time.time() + dur = (t_end - t_start) / num_iters + if use_errors: + times['ecc'] = dur + else: + times[fused] = dur + + atol = 20e-6 # should always be faster, but avoids rare flakiness + assert times[True] < times[False] + atol + assert times[True] < times['NA'] + atol + assert times['ecc'] < times['NA'] + atol + + if False: # change to True to check on thruput + print("") + print("time fused (ms): ", times[True] * 1e3) + print("time fused+ecc (ms): ", times['ecc'] * 1e3) + print("time unfused (ms): ", times[False] * 1e3) + print("time unquantized (ms): ", times['NA'] * 1e3) From 1a14857264dca03dc62f244edb7afb0180bbdd88 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 11 Aug 2023 17:15:00 +0000 Subject: [PATCH 02/23] pre-commit fixes --- llmfoundry/optim/lion8b.py | 25 ++++++++++++++----------- tests/test_lion8b.py | 14 +++++++------- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/llmfoundry/optim/lion8b.py b/llmfoundry/optim/lion8b.py index 3c11e4e282..35d2addc6b 100644 --- a/llmfoundry/optim/lion8b.py +++ b/llmfoundry/optim/lion8b.py @@ -1,3 +1,6 @@ +# Copyright 2022 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + from typing import Any, Callable, Dict, Iterable, Optional, Tuple import torch @@ -36,8 +39,6 @@ class DecoupledLionW_8bit(torch.optim.Optimizer): each optimizer step. 
Note that we use decoupled weight decay, meaning that this decay does not contribute to the momentum. (Default: 0.) - l2_penalty: adds `l2_penalty * param` to the gradient at the - start of the optimizer step. This term *is* added to the momentum. compress_state_dict: if True, this optimizer's `state_dict` will include quantized optimizer states. Otherwise, the optimizer states are converted to bfloat16 Tensors matching the shapes of @@ -84,12 +85,14 @@ def __init__( self._error_correction = error_correction if error_correction and not _fused: raise NotImplementedError( - "Error correction requires fused kernels.") - defaults = dict(lr=lr, - initial_lr=lr, - betas=betas, - weight_decay=weight_decay, - fused=_fused) + 'Error correction requires fused kernels.') + defaults = { + 'lr': lr, + 'initial_lr': lr, + 'betas': betas, + 'weight_decay': weight_decay, + 'fused': _fused + } super().__init__(params, defaults) @torch.no_grad() @@ -111,9 +114,9 @@ def step_param(self, p: torch.Tensor, hparams: Dict[str, Any]) -> None: if self._quantize and not p.is_cuda: raise NotImplementedError( f"Can't use quantization with param on {p.device} " + - f"({p.shape}, {p.dtype}). If you need " + - "to use DecoupledLionW_8bit without a CUDA device, try " + - "creating this optimizer with quantize=False.") + f'({p.shape}, {p.dtype}). If you need ' + + 'to use DecoupledLionW_8bit without a CUDA device, try ' + + 'creating this optimizer with quantize=False.') state = self.state[p] # type:ignore using tensor as key if _KEY_MOMENTUM not in state: mom = torch.zeros_like(p) diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py index 24887b2b0b..019866bf6f 100644 --- a/tests/test_lion8b.py +++ b/tests/test_lion8b.py @@ -187,7 +187,7 @@ def test_lion8b_fused_unfused_unquantized_same(w_init: str, grad_strategy: str, W0 += .01 * torch.sign(W0) # bound away from 0 to cap rel errors W0 = W0.to(dtype=dtype) else: # here for pyright - raise ValueError("Unrecognized w_init: ", w_init) + raise ValueError('Unrecognized w_init: ', w_init) W0.add_(W0.sign()) # bound away from zero so decay won't flip sign W_true = torch.empty_like(W0, requires_grad=True, dtype=torch.float32) # ground truth @@ -239,7 +239,7 @@ def test_lion8b_fused_unfused_unquantized_same(w_init: str, grad_strategy: str, elif grad_strategy == 'rand': grads = torch.tensor([-1]) else: - raise ValueError("bad grad_strategy: ", grad_strategy) + raise ValueError('bad grad_strategy: ', grad_strategy) # for _ in range(3): # for _ in range(1): @@ -415,8 +415,8 @@ def test_fused_as_fast_as_unfused(N: int, assert times['ecc'] < times['NA'] + atol if False: # change to True to check on thruput - print("") - print("time fused (ms): ", times[True] * 1e3) - print("time fused+ecc (ms): ", times['ecc'] * 1e3) - print("time unfused (ms): ", times[False] * 1e3) - print("time unquantized (ms): ", times['NA'] * 1e3) + print('') + print('time fused (ms): ', times[True] * 1e3) + print('time fused+ecc (ms): ', times['ecc'] * 1e3) + print('time unfused (ms): ', times[False] * 1e3) + print('time unquantized (ms): ', times['NA'] * 1e3) From 16ca215e5702913e92a0d342d9614851bf8318cd Mon Sep 17 00:00:00 2001 From: root Date: Fri, 11 Aug 2023 20:38:09 +0000 Subject: [PATCH 03/23] move lion8b kernels dep to "gpu" extra_deps --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index e0d1087ac6..9ad0f5ac4a 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,6 @@ 'mosaicml[libcloud,nlp,wandb,mlflow]>=0.15.0,<0.16', 'accelerate>=0.20,<0.21', # for 
HF inference `device_map` 'mosaicml-streaming>=0.5.1,<0.6', - 'mosaicml-turbo>=0.0.2,<0.1', 'torch>=1.13.1,<=2.0.1', 'datasets==2.10.1', 'sentencepiece==0.1.97', @@ -83,6 +82,7 @@ extra_deps['gpu'] = [ 'flash-attn==v1.0.3.post0', + 'mosaicml-turbo>=0.0.2,<0.1', # PyPI does not support direct dependencies, so we remove this line before uploading from PyPI 'xentropy-cuda-lib@git+https://github.com/HazyResearch/flash-attention.git@v1.0.3#subdirectory=csrc/xentropy', ] From f391b7fe2d1c641bdd2f943e78394a147330e85d Mon Sep 17 00:00:00 2001 From: root Date: Fri, 11 Aug 2023 20:59:08 +0000 Subject: [PATCH 04/23] move fused error checks to llmfoundry --- llmfoundry/optim/lion8b.py | 67 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 65 insertions(+), 2 deletions(-) diff --git a/llmfoundry/optim/lion8b.py b/llmfoundry/optim/lion8b.py index 35d2addc6b..a47ea476dd 100644 --- a/llmfoundry/optim/lion8b.py +++ b/llmfoundry/optim/lion8b.py @@ -301,6 +301,71 @@ def lion_step_unfused(grads: torch.Tensor, return momentums # f32 upcast means not necessarily modified in place +def lion8b_step_fused(grads: torch.Tensor, + weights: torch.Tensor, + momentums: torch.Tensor, + scales: torch.Tensor, + lr: float, + beta1: float, + beta2: float, + weight_decay: float, + errors: Optional[torch.Tensor] = None) -> None: + # just to save space in lists of allowed dtypes + f16, bf16, f32 = torch.float16, torch.bfloat16, torch.float32 + + use_errors = (errors is not None) and (weights.dtype in (f16, bf16)) + orig_shape = weights.shape + + # ------------------------------------------------ wall of error checking + quantize_group_size = 32 + num_groups = (weights.numel() + quantize_group_size - + 1) // quantize_group_size + if (num_groups != scales.numel()): + raise ValueError(f'Expected {num_groups} quantization scales but ' + + f' received {scales.numel()}') + + for name, tensor, allowed_dtypes in [('grad', grads, (f16, bf16, f32)), + ('param', weights, (f16, bf16, f32)), + ('momentum', momentums, [torch.int8]), + ('scales', scales, [f16]), + ('errors', errors, [torch.uint8])]: + if name == 'errors' and not use_errors: + continue + if not tensor.is_cuda: + raise ValueError( + f'{name} must be on a CUDA device, not {tensor.device}') + if not tensor.is_contiguous(): + raise ValueError(f'{name} is not contiguous!') + strides_unequal = tensor.stride() != weights.stride() + if name not in ('scales', 'errors') and strides_unequal: + raise ValueError(f'{name} stride {tensor.stride()} != ' + + f'param stride {weights.stride()}') + if tensor.dtype not in allowed_dtypes: + raise ValueError(f'{name} must have dtype {allowed_dtypes}, not ' + + f'{tensor.dtype}') + if (name != 'scales') and (orig_shape != tensor.shape): + raise ValueError(f'Param shape {orig_shape} != ' + + f'{name} shape {tensor.shape}') + + if grads.dtype in (torch.float16, torch.bfloat16): + allowed_dtypes = (grads.dtype, torch.float32) + if weights.dtype not in allowed_dtypes: + raise ValueError( + f'Weights must be f32 or match grad dtype {grads.dtype}') + + # ------------------------------------------------ actual function call + from turbo import lion8b_step_cuda + return lion8b_step_cuda(grads=grads, + weights=weights, + momentums=momentums, + scales=scales, + lr=lr, + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + errors=errors) + + def _lion8b_step(grads: torch.Tensor, weights: torch.Tensor, momentums: _MaybeQuantizedTensor, @@ -312,8 +377,6 @@ def _lion8b_step(grads: torch.Tensor, fused: bool = True) -> None: if 
momentums.is_quantized() and fused: - from turbo import lion8b_step as lion8b_step_fused - assert momentums.quantized is not None # pyright assert momentums.scales is not None # pyright return lion8b_step_fused(grads=grads, From bcf55bfa02a6d23a35ba9966c017d8b25ae633d1 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 11 Aug 2023 20:59:37 +0000 Subject: [PATCH 05/23] make precommit + CodeQL happy? --- llmfoundry/optim/__init__.py | 5 ++++- llmfoundry/utils/builders.py | 3 ++- tests/test_lion8b.py | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/llmfoundry/optim/__init__.py b/llmfoundry/optim/__init__.py index 7936f9a8c7..1d0e5caced 100644 --- a/llmfoundry/optim/__init__.py +++ b/llmfoundry/optim/__init__.py @@ -5,4 +5,7 @@ from llmfoundry.optim.lion import DecoupledLionW from llmfoundry.optim.lion8b import DecoupledLionW_8bit -__all__ = ['DecoupledLionW', 'DecoupledClipLion', 'DecoupledAdaLRLion'] +__all__ = [ + 'DecoupledLionW', 'DecoupledLionW_8bit', 'DecoupledClipLion', + 'DecoupledAdaLRLion' +] diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 3ce4a9a6c9..4e47021755 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -116,7 +116,8 @@ def build_optimizer(cfg: DictConfig, model: torch.nn.Module): lr_penalty=cfg.lr_penalty, min_scale=cfg.min_scale) elif cfg.name.lower() == 'decoupled_lionw_8b': - kwargs = {k: v for k, v in cfg.items() if k != 'name'} + # str() cast is just for pyright + kwargs = {str(k): v for k, v in cfg.items() if k != 'name'} return DecoupledLionW_8bit(model.parameters(), **kwargs) else: raise ValueError(f'Not sure how to build optimizer: {cfg.name}') diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py index 019866bf6f..85927d2184 100644 --- a/tests/test_lion8b.py +++ b/tests/test_lion8b.py @@ -376,7 +376,7 @@ def test_state_dict_save_load(device: str, quantized_state: bool, (4096, 4096), [16384, 16384]]) def test_fused_as_fast_as_unfused(N: int, D: int, - min_elems_traversed: int = int(1e6)): + min_elems_traversed: int = 1000000): W = torch.randn((N, D), device='cuda', requires_grad=True) W.grad = torch.randn((N, D), device='cuda', requires_grad=False) From 6fc178274f2b5a6bcc4bab49e51e6a748378886e Mon Sep 17 00:00:00 2001 From: root Date: Sat, 12 Aug 2023 03:31:42 +0000 Subject: [PATCH 06/23] disable fsdp param_dtype for low-bit master weights --- llmfoundry/utils/config_utils.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 6c12775bfc..6e6fd0c89e 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -86,6 +86,23 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]): # Set defaults for mixed initialization fsdp_config.setdefault('use_orig_params', False) fsdp_config.setdefault('load_monolith_rank0_only', True) + + # no mixed precision needed for weights when they're already 16 bits + master_dtype = model_cfg.get('master_weights_dtype') + if fsdp_config and master_dtype in ('bf16', 'f16', 'float16', 'bfloat16'): + reduce_dtype = None + buffer_dtype = None + mixed_precision = fsdp_config.get('mixed_precision') + if isinstance(mixed_precision, Mapping): + reduce_dtype = mixed_precision.get('reduce_dtype') + buffer_dtype = mixed_precision.get('buffer_dtype') + fsdp_config['mixed_precision'] = { + 'param_dtype': None, + 'reduce_dtype': reduce_dtype, + 'buffer_dtype': buffer_dtype, + 'keep_low_precision_grads': True, + } + return init_context 
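For reference, here is a minimal usage sketch tying together the optimizer introduced in PATCH 01 and the 16-bit master-weights flow that PATCH 06 and the following PATCH 07 wire up. This sketch is not part of the patches themselves; it assumes a CUDA device and the `mosaicml-turbo` kernels installed via the `gpu` extra (otherwise construct the optimizer with quantize=False), and the layer sizes, learning rate, and other hyperparameters are arbitrary illustrative values.

import torch

from llmfoundry.optim import DecoupledLionW_8bit

# bf16 "master" weights, as enabled by model.master_weights_dtype in PATCH 07
model = torch.nn.Linear(256, 256, device='cuda', dtype=torch.bfloat16)
opt = DecoupledLionW_8bit(model.parameters(),
                          lr=1e-4,
                          betas=(0.9, 0.99),
                          weight_decay=1e-5,
                          error_correction=True)  # extra uint8 'errors' state for bf16 params

x = torch.randn(8, 256, device='cuda', dtype=torch.bfloat16)
model(x).sum().backward()
opt.step()       # momentum is stored as int8 plus per-group scales (~8.5 bits per parameter)
opt.zero_grad()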
From ba0e317a48b8f5e8ea42aa54a60163180c5e0437 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 12 Aug 2023 03:32:40 +0000 Subject: [PATCH 07/23] add low-precision master weights option + rm needles .get(..., None) --- scripts/train/train.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/scripts/train/train.py b/scripts/train/train.py index c6a86503e3..611eabb155 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -194,15 +194,15 @@ def main(cfg: DictConfig): cfg = update_batch_size_info(cfg) # Read FSDP Config as a dict - fsdp_config = cfg.get('fsdp_config', None) - fsdp_config = om.to_container(fsdp_config, - resolve=True) if fsdp_config else None - assert isinstance(fsdp_config, Dict) or fsdp_config is None - if dist.get_world_size() == 1 and fsdp_config is not None: - warnings.warn( - 'FSDP is not applicable for single-GPU training. Reverting to DDP.') - cfg.pop('fsdp_config') - fsdp_config = None + fsdp_config = cfg.get('fsdp_config') + if fsdp_config is not None: + fsdp_config = om.to_container(fsdp_config, resolve=True) + assert isinstance(fsdp_config, Dict) + if dist.get_world_size() == 1: + warnings.warn( + 'FSDP is not applicable for single-GPU training. Reverting to DDP.') + cfg.pop('fsdp_config') + fsdp_config = None init_context = process_init_device(cfg.model, fsdp_config) @@ -212,13 +212,16 @@ def main(cfg: DictConfig): # Build Model print('Initializing model...') with init_context: - if cfg.get('lora', - None) is not None: # frozen model + trainable lora modules + if cfg.get('lora') is not None: # frozen model + trainable lora modules model: ComposerHFCausalLM = build_composer_peft_model( cfg.model, cfg.lora, tokenizer) print_trainable_parameters(model) # should not be 100% else: # standard model model = build_composer_model(cfg.model, tokenizer) + if cfg.model.get('master_weights_dtype') in ('bf16', 'bfloat16'): + model = model.to(dtype=torch.bfloat16) + elif cfg.model.get('master_weights_dtype') in ('f16', 'float16'): + model = model.to(dtype=torch.float16) cfg.n_params = sum(p.numel() for p in model.parameters()) print(f'{cfg.n_params=:.2e}') @@ -342,5 +345,6 @@ def main(cfg: DictConfig): yaml_cfg = om.load(f) cli_cfg = om.from_cli(args_list) cfg = om.merge(yaml_cfg, cli_cfg) + om.resolve(cfg) assert isinstance(cfg, DictConfig) main(cfg) From 7a55e0741d7328d423d6a91f568b61652ddc887e Mon Sep 17 00:00:00 2001 From: root Date: Mon, 14 Aug 2023 00:49:53 +0000 Subject: [PATCH 08/23] fix missing import in config_utils --- llmfoundry/utils/config_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 6e6fd0c89e..e8d130230c 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -4,7 +4,7 @@ import contextlib import math import warnings -from typing import Dict, Optional, Union +from typing import Dict, Optional, Union, Mapping from composer.utils import dist from omegaconf import DictConfig @@ -89,7 +89,8 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]): # no mixed precision needed for weights when they're already 16 bits master_dtype = model_cfg.get('master_weights_dtype') - if fsdp_config and master_dtype in ('bf16', 'f16', 'float16', 'bfloat16'): + small_dtypes = ('bf16', 'f16', 'float16', 'bfloat16', 'amp_fp16', 'amp_bf16') + if fsdp_config and master_dtype in small_dtypes: reduce_dtype = None buffer_dtype = None mixed_precision = 
fsdp_config.get('mixed_precision') From 225ceac0ee57ed8c6a537341b92502b08f28c985 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 14 Aug 2023 01:25:44 +0000 Subject: [PATCH 09/23] hopefully fix lion8b fsdp checkpointing --- llmfoundry/optim/lion8b.py | 100 +++++++++++++++++++++---------------- 1 file changed, 56 insertions(+), 44 deletions(-) diff --git a/llmfoundry/optim/lion8b.py b/llmfoundry/optim/lion8b.py index a47ea476dd..16cebf77b3 100644 --- a/llmfoundry/optim/lion8b.py +++ b/llmfoundry/optim/lion8b.py @@ -29,16 +29,14 @@ class DecoupledLionW_8bit(torch.optim.Optimizer): Args: params: iterable of parameters to optimize or dicts defining parameter groups - lr: learning rate (Default: 1e-3) + lr: learning rate betas: two coefficients between 0 and 1 used to combine the current gradients and the momentum. The first coefficient is the weight of the gradient when computing the update. The second is the weight of the gradient when computing the new momentum. - (Default: .9, .99) weight decay: Weights are multiplied by 1 - `weight_decay` after each optimizer step. Note that we use decoupled weight decay, meaning that this decay does not contribute to the momentum. - (Default: 0.) compress_state_dict: if True, this optimizer's `state_dict` will include quantized optimizer states. Otherwise, the optimizer states are converted to bfloat16 Tensors matching the shapes of @@ -57,19 +55,17 @@ class DecoupledLionW_8bit(torch.optim.Optimizer): by retaining information across optimizer steps. """ - def __init__( - self, - params: Iterable[torch.Tensor], - lr: float = 1e-3, - betas: Tuple[float, float] = (0.9, 0.99), - weight_decay: float = 0, - compress_state_dict: bool = False, - quantize: bool = True, - _fused: bool = True, # XXX this flag is mostly for testing... - error_correction: bool = False): - if not 0.0 <= lr: + def __init__(self, + params: Iterable[torch.Tensor], + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.99), + weight_decay: float = 0, + quantize: bool = True, + compress_state_dict: bool = False, + error_correction: bool = False, + _fused: bool = True): # XXX this flag is mostly for testing... + if lr < 0.0: raise ValueError('Invalid learning rate: {}'.format(lr)) - # if not 0.0 < betas[0] < 1.0: if not 0.0 <= betas[0] <= 1.0: raise ValueError('Invalid beta parameter at index 0: {}'.format( betas[0])) @@ -80,19 +76,24 @@ def __init__( raise ValueError( 'Invalid weight_decay value: {}'.format(weight_decay)) - self._quantize = quantize and torch.cuda.is_available() - self._compress_state_dict = compress_state_dict + if not torch.cuda.is_available(): + needs_cuda = ' requires a CUDA device.' 
+ if quantize: + raise NotImplementedError('Quantization' + needs_cuda) + if error_correction: + raise NotImplementedError('Error correction' + needs_cuda) + if compress_state_dict: + raise NotImplementedError('Quantized state dict' + needs_cuda) + + self._quantize = quantize self._error_correction = error_correction - if error_correction and not _fused: - raise NotImplementedError( - 'Error correction requires fused kernels.') - defaults = { - 'lr': lr, - 'initial_lr': lr, - 'betas': betas, - 'weight_decay': weight_decay, - 'fused': _fused - } + self._compress_state_dict = compress_state_dict + + defaults = {'lr': lr, + 'initial_lr': lr, + 'betas': betas, + 'weight_decay': weight_decay, + 'fused': _fused} super().__init__(params, defaults) @torch.no_grad() @@ -146,12 +147,16 @@ def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None: for param_id in opt_state: param_state = opt_state[param_id] new_state = {} - if _KEY_MOMENTUM in param_state: + if any(k.startswith(_KEY_MOMENTUM) for k in param_state): + # the keys can either be just "exp_avg" or + # "exp_avg::quantized" and "exp_avg::scales", depending on + # whether we saved it as quantized or not. The former case + # gives us interop with regular LION. qtensor = _MaybeQuantizedTensor(None, try_quantize=self._quantize) - qtensor.load_state_dict(param_state[_KEY_MOMENTUM]) + qtensor.load_state_dict(param_state, name=_KEY_MOMENTUM) new_state[_KEY_MOMENTUM] = qtensor - if self._error_correction and _KEY_ERRORS in param_state: + if _KEY_ERRORS in param_state: # we need to cast back to the correct dtype since optimizer # load_state_dict casts to param dtype for fp params; see # https://github.com/pytorch/pytorch/blob/a25eee1d77d93079614fab3ea4ac66e64fb2343b/torch/optim/optimizer.py#L626C7-L626C7 # noqa @@ -174,11 +179,12 @@ def state_dict(self): # isn't the same as self.state, but its consituent dicts are # the same as those in self.state param_state = {k: v for k, v in opt_state[param_id].items()} - if _KEY_MOMENTUM in param_state: - qtensor = param_state[_KEY_MOMENTUM] + if _KEY_MOMENTUM in param_state: # true if we've taken any steps + qtensor = param_state.pop(_KEY_MOMENTUM) assert isinstance(qtensor, _MaybeQuantizedTensor) # pyright - param_state[_KEY_MOMENTUM] = qtensor.state_dict( - allow_quantized=self._compress_state_dict) + param_state.update(qtensor.state_dict( + name=_KEY_MOMENTUM, + allow_quantized=self._compress_state_dict)) opt_state[param_id] = param_state return d @@ -213,23 +219,29 @@ def __init__(self, data: Optional[torch.Tensor], try_quantize: bool = True): self.set_data(data) def state_dict(self, + name: str, allow_quantized: bool = False) -> Dict[str, torch.Tensor]: if self.is_quantized() and allow_quantized: assert self.quantized is not None # pyright assert self.scales is not None # pyright - return {'quantized': self.quantized, 'scales': self.scales} - return {'data': self.materialize().to(dtype=torch.bfloat16)} - - def load_state_dict(self, d: Dict[str, torch.Tensor]) -> None: - if 'data' in d: + return {f'{name}::quantized': self.quantized, + f'{name}::scales': self.scales} + return {name: self.materialize().to(dtype=torch.bfloat16)} + + def load_state_dict(self, d: Dict[str, torch.Tensor], name: str) -> None: + # we allow other keys in the state dict for convenience, so you can + # just pass this the whole opt state for a parameters + d = {k: v for k, v in d.items() if k.startswith(name)} + if name in d: if len(d) != 1: - raise ValueError('If state dict specifies "data", it must not' + - f'specify 
other keys. Got {list(d.keys())}') - self.set_data(d['data']) + raise ValueError( + f'If state dict specifies {name}, it must not ' + + f'specify other keys. Got {list(d.keys())}') + self.set_data(d[name]) return - self.quantized = d['quantized'].to(dtype=torch.int8) - self.scales = d['scales'].to(dtype=torch.float16) + self.quantized = d[f'{name}::quantized'].to(dtype=torch.int8) + self.scales = d[f'{name}::scales'].to(dtype=torch.float16) def set_data(self, data: torch.Tensor) -> None: if not (self._try_quantize and data.is_cuda): From d53f0e58c2b69aeead7314962bd7e61a80d09aed Mon Sep 17 00:00:00 2001 From: root Date: Mon, 14 Aug 2023 05:49:10 +0000 Subject: [PATCH 10/23] pre-commit fixes --- llmfoundry/optim/lion8b.py | 25 +++++++++++++++---------- llmfoundry/utils/config_utils.py | 5 +++-- scripts/train/train.py | 3 ++- 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/llmfoundry/optim/lion8b.py b/llmfoundry/optim/lion8b.py index 16cebf77b3..c22689a212 100644 --- a/llmfoundry/optim/lion8b.py +++ b/llmfoundry/optim/lion8b.py @@ -89,11 +89,13 @@ def __init__(self, self._error_correction = error_correction self._compress_state_dict = compress_state_dict - defaults = {'lr': lr, - 'initial_lr': lr, - 'betas': betas, - 'weight_decay': weight_decay, - 'fused': _fused} + defaults = { + 'lr': lr, + 'initial_lr': lr, + 'betas': betas, + 'weight_decay': weight_decay, + 'fused': _fused + } super().__init__(params, defaults) @torch.no_grad() @@ -182,9 +184,10 @@ def state_dict(self): if _KEY_MOMENTUM in param_state: # true if we've taken any steps qtensor = param_state.pop(_KEY_MOMENTUM) assert isinstance(qtensor, _MaybeQuantizedTensor) # pyright - param_state.update(qtensor.state_dict( - name=_KEY_MOMENTUM, - allow_quantized=self._compress_state_dict)) + param_state.update( + qtensor.state_dict( + name=_KEY_MOMENTUM, + allow_quantized=self._compress_state_dict)) opt_state[param_id] = param_state return d @@ -224,8 +227,10 @@ def state_dict(self, if self.is_quantized() and allow_quantized: assert self.quantized is not None # pyright assert self.scales is not None # pyright - return {f'{name}::quantized': self.quantized, - f'{name}::scales': self.scales} + return { + f'{name}::quantized': self.quantized, + f'{name}::scales': self.scales + } return {name: self.materialize().to(dtype=torch.bfloat16)} def load_state_dict(self, d: Dict[str, torch.Tensor], name: str) -> None: diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index e8d130230c..79c9fe8011 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -4,7 +4,7 @@ import contextlib import math import warnings -from typing import Dict, Optional, Union, Mapping +from typing import Dict, Mapping, Optional, Union from composer.utils import dist from omegaconf import DictConfig @@ -89,7 +89,8 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]): # no mixed precision needed for weights when they're already 16 bits master_dtype = model_cfg.get('master_weights_dtype') - small_dtypes = ('bf16', 'f16', 'float16', 'bfloat16', 'amp_fp16', 'amp_bf16') + small_dtypes = ('bf16', 'f16', 'float16', 'bfloat16', 'amp_fp16', + 'amp_bf16') if fsdp_config and master_dtype in small_dtypes: reduce_dtype = None buffer_dtype = None diff --git a/scripts/train/train.py b/scripts/train/train.py index 611eabb155..2602bf8b43 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -200,7 +200,8 @@ def main(cfg: DictConfig): assert isinstance(fsdp_config, Dict) if 
dist.get_world_size() == 1: warnings.warn( - 'FSDP is not applicable for single-GPU training. Reverting to DDP.') + 'FSDP is not applicable for single-GPU training. Reverting to DDP.' + ) cfg.pop('fsdp_config') fsdp_config = None From b1125aae0785f3ee5ebbee67b699bfadef9206fe Mon Sep 17 00:00:00 2001 From: root Date: Mon, 14 Aug 2023 23:40:13 +0000 Subject: [PATCH 11/23] address pr comments --- tests/test_lion8b.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py index 85927d2184..fc67ded4a4 100644 --- a/tests/test_lion8b.py +++ b/tests/test_lion8b.py @@ -48,13 +48,13 @@ def test_modifies_weights_and_momentums(N: int, D: int, dtype: torch.dtype, with pytest.raises(AssertionError): # opt step modified the weights torch.testing.assert_close(W_orig, W) - # every momentum should be nonzero with infinite precision, but - # might be zero after quantization + # Every momentum should be nonzero with infinite precision, but + # might be zero after quantization. We turn the _MaybeQuantizedTensor + # instance into a regular torch Tensor to simplify this check. param_state = opt.state[W] # type:ignore using tensor as key momentum = param_state['exp_avg'].materialize() assert momentum.shape == (D, D) momentum = momentum.ravel() - assert momentum is not None if momentum.numel() == 1: assert momentum.item() != 0 else: From 476a9ecb735d2b2ec9c31d4a9b445928a26d3119 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 15 Aug 2023 02:04:08 +0000 Subject: [PATCH 12/23] fix descent + zero grad tests not being as stringent as intended --- tests/test_lion8b.py | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py index fc67ded4a4..dad2bbb61f 100644 --- a/tests/test_lion8b.py +++ b/tests/test_lion8b.py @@ -75,8 +75,6 @@ def test_changes_with_zero_grads(N: int, D: int, device: str, return torch.manual_seed(123) W = torch.rand((D, D), device=device, requires_grad=True) - with torch.no_grad(): - W += torch.sign(W) # bound away from zero so decay won't change sign W_orig = W.detach().clone() opt = Lion8bit([W], @@ -146,10 +144,10 @@ def test_descends(N: int, D: int, device: str, dtype: torch.dtype, fused: bool, momentum = state_for_param['exp_avg'].materialize() assert momentum is not None and momentum.shape == W.shape if prev_momentum is not None: - momentum_changes = momentum - prev_momentum - assert torch.all(momentum_changes >= 0) - assert momentum_changes.max() > 0 - prev_momentum = momentum + momentum_abs_changes = (momentum - prev_momentum).abs() + assert torch.all(momentum_abs_changes >= 0) + assert momentum_abs_changes.max() > 0 + prev_momentum = momentum.clone() # gpu, f32 on cpu write in place def _nmse(vals_true: torch.Tensor, @@ -173,7 +171,7 @@ def test_lion8b_fused_unfused_unquantized_same(w_init: str, grad_strategy: str, torch.manual_seed(123) device = 'cuda' - # each optimizer gets a different weight matrix to optimize + # each optimizer gets a different copy of the weight matrix to optimize if w_init == 'cyclic': W0 = torch.arange(D * D, device=device, @@ -207,13 +205,8 @@ def test_lion8b_fused_unfused_unquantized_same(w_init: str, grad_strategy: str, # we use a high LR, low betas, and regularization so that there will # hopefully be differences if *any* of the logic is wrong lr = .1 - # weight_decay = .25 weight_decay = .01 - # weight_decay = .0 kwargs = {'lr': lr, 'weight_decay': weight_decay, 'betas': (.5, .75)} - # kwargs = {'lr': lr, 'weight_decay': 
weight_decay, 'betas': (0, 0)} # f16 fq works - # kwargs = {'lr': lr, 'weight_decay': weight_decay, 'betas': (.5, 0)} # f16 fq works - # kwargs = {'lr': lr, 'weight_decay': weight_decay, 'betas': (0, .5)} # f16 fq works opt_true = Lion8bit([W_true], quantize=False, **kwargs) opt_uq = Lion8bit([W_uq], quantize=False, **kwargs) opt_uf = Lion8bit([W_uf], _fused=False, **kwargs) @@ -241,9 +234,6 @@ def test_lion8b_fused_unfused_unquantized_same(w_init: str, grad_strategy: str, else: raise ValueError('bad grad_strategy: ', grad_strategy) - # for _ in range(3): - # for _ in range(1): - # for _ in range(10): for _ in range(4): if grad_strategy == 'rand': # type:ignore (reportUnnecessaryComparison) grads = torch.rand(W0.shape, From b87ca31a90c5cb005f15ad3913cc6178bcab3963 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 15 Aug 2023 02:04:36 +0000 Subject: [PATCH 13/23] tiny style change --- llmfoundry/optim/lion8b.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/optim/lion8b.py b/llmfoundry/optim/lion8b.py index c22689a212..db6629e73c 100644 --- a/llmfoundry/optim/lion8b.py +++ b/llmfoundry/optim/lion8b.py @@ -306,7 +306,7 @@ def lion_step_unfused(grads: torch.Tensor, beta2: float, weight_decay: float = 0) -> torch.Tensor: # f32 cast to match fused impl + for compatibility with f32 grads or weights - momentums = momentums.to(torch.float32) + momentums = momentums.to(dtype=torch.float32) grads = grads.to(dtype=torch.float32) update = momentums.lerp(grads, 1 - beta1).sign_() From f90a71c33944558e6748777991278cb3a4569c52 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 15 Aug 2023 06:23:43 +0000 Subject: [PATCH 14/23] address more pr comments + WIP draft of FSDP checkpointing test --- llmfoundry/optim/lion8b.py | 29 +++++++++++++---- tests/test_lion8b.py | 67 ++++++++++++++++++++++++++++++++++---- 2 files changed, 83 insertions(+), 13 deletions(-) diff --git a/llmfoundry/optim/lion8b.py b/llmfoundry/optim/lion8b.py index db6629e73c..b9d4ec0874 100644 --- a/llmfoundry/optim/lion8b.py +++ b/llmfoundry/optim/lion8b.py @@ -46,13 +46,19 @@ class DecoupledLionW_8bit(torch.optim.Optimizer): FSDP or other weight sharding approaches. quantize: If False, optimizer states will not actually be quantized. This option is available so that one can easily debug whether - the quantization is causing any convergence issues. Quantization - is always disabled when training without a CUDA device. + the quantization is causing any convergence issues. Because + quantization is only supported for CUDA parameters, attempting to + update a non-CUDA tensor will raise an error. error_correction: If True, float16 and bfloat16 parameters will be given an extra state variable, "errors." This tensor will be of the same shape as the parameter but of dtype uint8. This auxiliary variable is used to better approximate float32 updates by retaining information across optimizer steps. + + Raises: + NotImplemenetedError - If any of `quantize`, `compress_state_dict`, + or `error_correction` are `True` and either a) there is no CUDA + device, or b) step() is executed on a non-CUDA parameter. 
""" def __init__(self, @@ -85,6 +91,7 @@ def __init__(self, if compress_state_dict: raise NotImplementedError('Quantized state dict' + needs_cuda) + _fused = _fused and quantize self._quantize = quantize self._error_correction = error_correction self._compress_state_dict = compress_state_dict @@ -249,14 +256,18 @@ def load_state_dict(self, d: Dict[str, torch.Tensor], name: str) -> None: self.scales = d[f'{name}::scales'].to(dtype=torch.float16) def set_data(self, data: torch.Tensor) -> None: - if not (self._try_quantize and data.is_cuda): - self.data = data.to(dtype=torch.float32) - self.quantized = None - self.scales = None - else: + if self._try_quantize: + if not data.is_cuda: + raise NotImplementedError( + f'Attempting to quantize a non-CUDA {data.dtype} tensor ' + + f'on device {data.device} with shape {data.shape}.') self.data = None assert self._f_encode is not None # pyright self.quantized, self.scales = self._f_encode(data) + else: + self.data = data.to(dtype=torch.float32) + self.quantized = None + self.scales = None def is_quantized(self) -> bool: return self.data is None @@ -393,6 +404,10 @@ def _lion8b_step(grads: torch.Tensor, errors: Optional[torch.Tensor] = None, fused: bool = True) -> None: + if fused and not momentums.is_quantized(): + raise NotImplementedError( + 'Fused LION step only implemented with quantization.') + if momentums.is_quantized() and fused: assert momentums.quantized is not None # pyright assert momentums.scales is not None # pyright diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py index dad2bbb61f..ef21132750 100644 --- a/tests/test_lion8b.py +++ b/tests/test_lion8b.py @@ -7,6 +7,8 @@ import numpy as np import pytest import torch +import torch.nn as nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from llmfoundry.optim import DecoupledLionW_8bit as Lion8bit @@ -63,16 +65,17 @@ def test_modifies_weights_and_momentums(N: int, D: int, dtype: torch.dtype, @pytest.mark.gpu @pytest.mark.parametrize('N,D', _MANY_PARAM_SHAPES) -@pytest.mark.parametrize('device', ['cpu', 'cuda']) -@pytest.mark.parametrize('dtype', _FLOAT_DTYPES) +@pytest.mark.parametrize('device,dtype', [('cpu', torch.float32), + ('cuda', torch.bfloat16), ('cuda', torch.float16), ('cuda', torch.float32)]) @pytest.mark.parametrize('weight_decay', [0, .1]) @pytest.mark.parametrize('fused,use_errors', [(False, False), (True, False), (True, True)]) def test_changes_with_zero_grads(N: int, D: int, device: str, dtype: torch.dtype, weight_decay: float, fused: bool, use_errors: bool) -> None: - if (dtype != torch.float32) and device == 'cpu': + if (device == 'cpu') and (fused or use_errors): return + torch.manual_seed(123) W = torch.rand((D, D), device=device, requires_grad=True) W_orig = W.detach().clone() @@ -103,13 +106,13 @@ def test_changes_with_zero_grads(N: int, D: int, device: str, @pytest.mark.gpu @pytest.mark.parametrize('N,D', [(1, 8), (17, 23), (32, 32)]) -@pytest.mark.parametrize('device', ['cpu', 'cuda']) -@pytest.mark.parametrize('dtype', _FLOAT_DTYPES) +@pytest.mark.parametrize('device,dtype', [('cpu', torch.float32), + ('cuda', torch.bfloat16), ('cuda', torch.float16), ('cuda', torch.float32)]) @pytest.mark.parametrize('fused,use_errors', [(False, False), (True, False), (True, True)]) def test_descends(N: int, D: int, device: str, dtype: torch.dtype, fused: bool, use_errors: bool) -> None: - if (dtype != torch.float32) and device == 'cpu': + if (device == 'cpu') and (fused or use_errors): return torch.manual_seed(123) X = torch.randn((N, D), device=device, 
requires_grad=False, dtype=dtype) @@ -361,6 +364,58 @@ def test_state_dict_save_load(device: str, quantized_state: bool, torch.testing.assert_close(d_orig['errors'], d_new['errors']) +class _DummyModule(nn.Module): + def __init__(self, device: str, dtype: torch.dtype): + super().__init__() + self.linear0 = nn.Linear(4, 3, device=device, dtype=dtype) + self.linear1 = nn.Linear(3, 4, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: # type:ignore + return self.linear1(self.linear0(x)) + + +# run just this test with: +# python3 -m composer.cli.launcher -n 2 --master_port 26000 -m pytest -m gpu tests/test_lion8b.py::test_fsdp_save_load # noqa +@pytest.mark.gpu +@pytest.mark.parametrize('dtype', _FLOAT_DTYPES) +@pytest.mark.parametrize('use_errors', [False, True]) +@pytest.mark.parametrize('world_size', [2]) +def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool, world_size: int): + device = 'cuda' + if torch.cuda.device_count() < 2: + pytest.skip(f'This test requires 2+ GPUs.') + assert torch.distributed.get_world_size() >= 2, 'Misconfigured test run!' + + # assert False + + # # # torch.cuda.set_device(0) + # # # os.environ['RANK'] = 0 + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group() + + mod = FSDP(_DummyModule(device=device, dtype=dtype)) + + # actual forward pass instead of setting p.grad to avoid FSDP issues + X = torch.rand(size=(5, 4), device=device, dtype=dtype) + Y = mod(X) + Y.sum().backward() + for p in mod.parameters(): + p.grad = torch.rand_like(p) + + # create optimizer and have it step so that state gets populated + opt = Lion8bit(mod.parameters(), error_correction=use_errors) + opt.step() + opt.zero_grad() + + # copy state dict into a new instance + state_dict = opt.state_dict() + # del opt + mod_new = FSDP(_DummyModule(device=device, dtype=dtype)) + opt_new = Lion8bit(mod_new.parameters(), error_correction=use_errors) + opt_new.load_state_dict(state_dict) + # assert False + + @pytest.mark.gpu @pytest.mark.parametrize('N,D', [(32, 32), (256, 256), (1024, 1024), (4096, 4096), [16384, 16384]]) From 0eb34208b283e0cf25eb14d0957058c5a39067d7 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 17 Aug 2023 07:06:32 +0000 Subject: [PATCH 15/23] partial fix of fsdp state dict test --- tests/test_lion8b.py | 35 ++++++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py index ef21132750..595d95dd97 100644 --- a/tests/test_lion8b.py +++ b/tests/test_lion8b.py @@ -1,6 +1,7 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import os import time import warnings @@ -377,21 +378,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # type:ignore # run just this test with: # python3 -m composer.cli.launcher -n 2 --master_port 26000 -m pytest -m gpu tests/test_lion8b.py::test_fsdp_save_load # noqa @pytest.mark.gpu -@pytest.mark.parametrize('dtype', _FLOAT_DTYPES) -@pytest.mark.parametrize('use_errors', [False, True]) -@pytest.mark.parametrize('world_size', [2]) -def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool, world_size: int): +@pytest.mark.world_size(2) +# @pytest.mark.parametrize('dtype', _FLOAT_DTYPES) +# @pytest.mark.parametrize('use_errors', [False, True]) +@pytest.mark.parametrize('dtype', [torch.float32]) +@pytest.mark.parametrize('use_errors', [False]) +# @pytest.mark.parametrize('world_size', [2]) +# @pytest.world_size(2) +def test_fsdp_save_load(dtype: torch.dtype, use_errors: 
bool): device = 'cuda' if torch.cuda.device_count() < 2: pytest.skip(f'This test requires 2+ GPUs.') - assert torch.distributed.get_world_size() >= 2, 'Misconfigured test run!' - # assert False - - # # # torch.cuda.set_device(0) - # # # os.environ['RANK'] = 0 + torch.cuda.set_device(f'cuda:{os.environ["RANK"]}') # needed for fsdp if not torch.distributed.is_initialized(): torch.distributed.init_process_group() + assert torch.distributed.get_world_size() >= 2, 'Misconfigured test run!' mod = FSDP(_DummyModule(device=device, dtype=dtype)) @@ -409,11 +411,22 @@ def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool, world_size: int): # copy state dict into a new instance state_dict = opt.state_dict() - # del opt mod_new = FSDP(_DummyModule(device=device, dtype=dtype)) opt_new = Lion8bit(mod_new.parameters(), error_correction=use_errors) opt_new.load_state_dict(state_dict) - # assert False + + for p in mod.parameters(): + d_orig = opt.state[p] + d_new = opt_new.state[p] + assert list(d_orig.keys()) == list(d_new.keys()) + mom_orig = d_orig['exp_avg'] + mom_new = d_new['exp_avg'] + torch.testing.assert_close(mom_orig.materialize(), + mom_new.materialize(), + atol=1. / (2 * 127), + rtol=np.inf) + if use_errors and (dtype != torch.float32): + torch.testing.assert_close(d_orig['errors'], d_new['errors']) @pytest.mark.gpu From a2a010467ac560f739b38e49bfb33cf49f9cae4e Mon Sep 17 00:00:00 2001 From: root Date: Fri, 18 Aug 2023 06:50:46 +0000 Subject: [PATCH 16/23] fsdp state dict test passing --- tests/test_lion8b.py | 55 +++++++++++++++++++++++++++++++++----------- 1 file changed, 41 insertions(+), 14 deletions(-) diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py index 595d95dd97..c7e4600971 100644 --- a/tests/test_lion8b.py +++ b/tests/test_lion8b.py @@ -10,6 +10,7 @@ import torch import torch.nn as nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed import fsdp from llmfoundry.optim import DecoupledLionW_8bit as Lion8bit @@ -379,10 +380,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # type:ignore # python3 -m composer.cli.launcher -n 2 --master_port 26000 -m pytest -m gpu tests/test_lion8b.py::test_fsdp_save_load # noqa @pytest.mark.gpu @pytest.mark.world_size(2) -# @pytest.mark.parametrize('dtype', _FLOAT_DTYPES) -# @pytest.mark.parametrize('use_errors', [False, True]) -@pytest.mark.parametrize('dtype', [torch.float32]) -@pytest.mark.parametrize('use_errors', [False]) +@pytest.mark.parametrize('dtype', _FLOAT_DTYPES) +@pytest.mark.parametrize('use_errors', [False, True]) +# @pytest.mark.parametrize('dtype', [torch.float32]) +# @pytest.mark.parametrize('use_errors', [False]) # @pytest.mark.parametrize('world_size', [2]) # @pytest.world_size(2) def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool): @@ -409,22 +410,48 @@ def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool): opt.step() opt.zero_grad() - # copy state dict into a new instance - state_dict = opt.state_dict() + def _set_state_dict_type(model: nn.Module): + FSDP.set_state_dict_type(model, + fsdp.StateDictType.FULL_STATE_DICT, + fsdp.FullStateDictConfig(rank0_only=False), + fsdp.api.FullOptimStateDictConfig(rank0_only=False)) + + # load FSDP state dict + _set_state_dict_type(mod) + opt_state_dict = FSDP.optim_state_dict(mod, opt) + + # make a new model and optimizer mod_new = FSDP(_DummyModule(device=device, dtype=dtype)) opt_new = Lion8bit(mod_new.parameters(), error_correction=use_errors) - opt_new.load_state_dict(state_dict) + 
_set_state_dict_type(mod_new) - for p in mod.parameters(): - d_orig = opt.state[p] - d_new = opt_new.state[p] + print("initial opt state dict: ", opt_state_dict) + + # load state dict into the new optimizer + opt_state_dict_slice = FSDP.optim_state_dict_to_load( + opt_state_dict, mod_new, opt_new) + opt_new.load_state_dict(opt_state_dict_slice) + + new_opt_state_dict = FSDP.optim_state_dict(mod_new, opt_new) + print("new opt state dict: ", new_opt_state_dict) + + orig_state = opt_state_dict['state'] + orig_param_groups = opt_state_dict['param_groups'] + new_state = new_opt_state_dict['state'] + new_param_groups = new_opt_state_dict['param_groups'] + + all_keys = set(orig_state.keys()) | set(new_state.keys()) + assert orig_param_groups == new_param_groups # works since strs, not ptrs + for k in all_keys: # keys are param paths in module as strings + d_orig = orig_state[k] + d_new = new_state[k] assert list(d_orig.keys()) == list(d_new.keys()) mom_orig = d_orig['exp_avg'] mom_new = d_new['exp_avg'] - torch.testing.assert_close(mom_orig.materialize(), - mom_new.materialize(), - atol=1. / (2 * 127), - rtol=np.inf) + # momentums may not be bit-for-bit identical because Optimizer upcasts + # to f32 and we convert back to bf16, possibly with different rounding + torch.testing.assert_close(mom_orig, mom_new) + # errors not bit-for-bit identical because scales get upcast too if use_errors and (dtype != torch.float32): torch.testing.assert_close(d_orig['errors'], d_new['errors']) From afd969967d4b612bc44f710a5d3e467fcc189241 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 18 Aug 2023 07:13:53 +0000 Subject: [PATCH 17/23] get fsdp state dict test passing with different sharding strategies --- tests/test_lion8b.py | 43 +++++++++++++++++++++++++++++++------------ 1 file changed, 31 insertions(+), 12 deletions(-) diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py index c7e4600971..ddcd6ba023 100644 --- a/tests/test_lion8b.py +++ b/tests/test_lion8b.py @@ -376,17 +376,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # type:ignore return self.linear1(self.linear0(x)) +_FULL_STATE = fsdp.StateDictType.FULL_STATE_DICT +_SHARDED_STATE = fsdp.StateDictType.SHARDED_STATE_DICT +_LOCAL_STATE = fsdp.StateDictType.LOCAL_STATE_DICT + # run just this test with: # python3 -m composer.cli.launcher -n 2 --master_port 26000 -m pytest -m gpu tests/test_lion8b.py::test_fsdp_save_load # noqa @pytest.mark.gpu @pytest.mark.world_size(2) @pytest.mark.parametrize('dtype', _FLOAT_DTYPES) @pytest.mark.parametrize('use_errors', [False, True]) -# @pytest.mark.parametrize('dtype', [torch.float32]) -# @pytest.mark.parametrize('use_errors', [False]) -# @pytest.mark.parametrize('world_size', [2]) -# @pytest.world_size(2) -def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool): +@pytest.mark.parametrize('state_sharding', [_FULL_STATE, _SHARDED_STATE, _LOCAL_STATE]) +def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool, state_sharding: fsdp.StateDictType): device = 'cuda' if torch.cuda.device_count() < 2: pytest.skip(f'This test requires 2+ GPUs.') @@ -411,10 +412,19 @@ def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool): opt.zero_grad() def _set_state_dict_type(model: nn.Module): - FSDP.set_state_dict_type(model, - fsdp.StateDictType.FULL_STATE_DICT, - fsdp.FullStateDictConfig(rank0_only=False), - fsdp.api.FullOptimStateDictConfig(rank0_only=False)) + # for mapping between state dict types and optim state dict types, see: + # 
https://github.com/pytorch/pytorch/blob/a815e719e85899d4229616617e7827d4de191c2d/torch/distributed/fsdp/fully_sharded_data_parallel.py#L664 # noqa + state_dict_cfg = { + _FULL_STATE: fsdp.FullStateDictConfig(rank0_only=False), + _SHARDED_STATE: fsdp.api.ShardedStateDictConfig(), + _LOCAL_STATE: fsdp.api.LocalStateDictConfig(), + }[state_sharding] + optim_cfg = { + _FULL_STATE: fsdp.api.FullOptimStateDictConfig(rank0_only=False), + _SHARDED_STATE: fsdp.api.ShardedOptimStateDictConfig(), + _LOCAL_STATE: fsdp.api.LocalOptimStateDictConfig(), + }[state_sharding] + FSDP.set_state_dict_type(model, state_sharding, state_dict_cfg, optim_cfg) # load FSDP state dict _set_state_dict_type(mod) @@ -425,15 +435,12 @@ def _set_state_dict_type(model: nn.Module): opt_new = Lion8bit(mod_new.parameters(), error_correction=use_errors) _set_state_dict_type(mod_new) - print("initial opt state dict: ", opt_state_dict) - # load state dict into the new optimizer opt_state_dict_slice = FSDP.optim_state_dict_to_load( opt_state_dict, mod_new, opt_new) opt_new.load_state_dict(opt_state_dict_slice) new_opt_state_dict = FSDP.optim_state_dict(mod_new, opt_new) - print("new opt state dict: ", new_opt_state_dict) orig_state = opt_state_dict['state'] orig_param_groups = opt_state_dict['param_groups'] @@ -448,6 +455,18 @@ def _set_state_dict_type(model: nn.Module): assert list(d_orig.keys()) == list(d_new.keys()) mom_orig = d_orig['exp_avg'] mom_new = d_new['exp_avg'] + + assert mom_orig.shape == mom_new.shape + assert mom_orig.dtype == mom_new.dtype + if use_errors: + errs_orig = d_orig['errors'] + errs_new = d_new['errors'] + assert errs_orig.shape == errs_new.shape + assert errs_orig.dtype == errs_new.dtype + + if state_sharding != _FULL_STATE: + continue # more detailed checks lean on FSDP impl details + # momentums may not be bit-for-bit identical because Optimizer upcasts # to f32 and we convert back to bf16, possibly with different rounding torch.testing.assert_close(mom_orig, mom_new) From dd4ccb31ded9c763770230a76fcf4772fcc32c71 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 18 Aug 2023 07:26:36 +0000 Subject: [PATCH 18/23] remove state key name indirection as per pr comments --- llmfoundry/optim/lion8b.py | 37 +++++++++++++++++-------------------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/llmfoundry/optim/lion8b.py b/llmfoundry/optim/lion8b.py index b9d4ec0874..806dbdbd14 100644 --- a/llmfoundry/optim/lion8b.py +++ b/llmfoundry/optim/lion8b.py @@ -5,9 +5,6 @@ import torch -_KEY_MOMENTUM = 'exp_avg' -_KEY_ERRORS = 'errors' - class DecoupledLionW_8bit(torch.optim.Optimizer): """LION optimizer with ~8 bits of state per parameter. 
@@ -128,18 +125,18 @@ def step_param(self, p: torch.Tensor, hparams: Dict[str, Any]) -> None: 'to use DecoupledLionW_8bit without a CUDA device, try ' + 'creating this optimizer with quantize=False.') state = self.state[p] # type:ignore using tensor as key - if _KEY_MOMENTUM not in state: + if 'exp_avg' not in state: mom = torch.zeros_like(p) - state[_KEY_MOMENTUM] = _MaybeQuantizedTensor( + state['exp_avg'] = _MaybeQuantizedTensor( mom, try_quantize=self._quantize) need_errs = (p.dtype != torch.float32) and self._error_correction - if state.get(_KEY_ERRORS) is None and need_errs: - state[_KEY_ERRORS] = torch.zeros(p.shape, - dtype=torch.uint8, - device=p.device) + if state.get('errors') is None and need_errs: + state['errors'] = torch.zeros(p.shape, + dtype=torch.uint8, + device=p.device) decay_factor = hparams['weight_decay'] decay_factor *= hparams['lr'] / hparams['initial_lr'] - _lion8b_step(momentums=state[_KEY_MOMENTUM], + _lion8b_step(momentums=state['exp_avg'], weights=p, grads=p.grad, beta1=hparams['betas'][0], @@ -147,7 +144,7 @@ def step_param(self, p: torch.Tensor, hparams: Dict[str, Any]) -> None: lr=hparams['lr'], weight_decay=decay_factor, fused=hparams['fused'], - errors=state.get(_KEY_ERRORS)) + errors=state.get('errors')) def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None: # we override this function to quantize optimizer states when @@ -156,21 +153,21 @@ def __setstate__(self, state: Dict[str, Dict[str, Any]]) -> None: for param_id in opt_state: param_state = opt_state[param_id] new_state = {} - if any(k.startswith(_KEY_MOMENTUM) for k in param_state): + if any(k.startswith('exp_avg') for k in param_state): # the keys can either be just "exp_avg" or # "exp_avg::quantized" and "exp_avg::scales", depending on # whether we saved it as quantized or not. The former case # gives us interop with regular LION. 
qtensor = _MaybeQuantizedTensor(None, try_quantize=self._quantize) - qtensor.load_state_dict(param_state, name=_KEY_MOMENTUM) - new_state[_KEY_MOMENTUM] = qtensor - if _KEY_ERRORS in param_state: + qtensor.load_state_dict(param_state, name='exp_avg') + new_state['exp_avg'] = qtensor + if 'errors' in param_state: # we need to cast back to the correct dtype since optimizer # load_state_dict casts to param dtype for fp params; see # https://github.com/pytorch/pytorch/blob/a25eee1d77d93079614fab3ea4ac66e64fb2343b/torch/optim/optimizer.py#L626C7-L626C7 # noqa - errs = param_state[_KEY_ERRORS].to(dtype=torch.uint8) - new_state[_KEY_ERRORS] = errs + errs = param_state['errors'].to(dtype=torch.uint8) + new_state['errors'] = errs opt_state[param_id] = new_state super().__setstate__(state) @@ -188,12 +185,12 @@ def state_dict(self): # isn't the same as self.state, but its consituent dicts are # the same as those in self.state param_state = {k: v for k, v in opt_state[param_id].items()} - if _KEY_MOMENTUM in param_state: # true if we've taken any steps - qtensor = param_state.pop(_KEY_MOMENTUM) + if 'exp_avg' in param_state: # true if we've taken any steps + qtensor = param_state.pop('exp_avg') assert isinstance(qtensor, _MaybeQuantizedTensor) # pyright param_state.update( qtensor.state_dict( - name=_KEY_MOMENTUM, + name='exp_avg', allow_quantized=self._compress_state_dict)) opt_state[param_id] = param_state return d From abdc6a6d883e795a63c63c9e11a6e445c04a94fa Mon Sep 17 00:00:00 2001 From: root Date: Fri, 18 Aug 2023 07:27:44 +0000 Subject: [PATCH 19/23] make precommit + pyright happy --- tests/test_lion8b.py | 47 ++++++++++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py index ddcd6ba023..999bf0fa67 100644 --- a/tests/test_lion8b.py +++ b/tests/test_lion8b.py @@ -8,9 +8,13 @@ import numpy as np import pytest import torch +import torch.distributed as dist import torch.nn as nn -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed import fsdp +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.api import ( # type:ignore .api not in public API + FullOptimStateDictConfig, LocalOptimStateDictConfig, + ShardedOptimStateDictConfig) from llmfoundry.optim import DecoupledLionW_8bit as Lion8bit @@ -68,7 +72,9 @@ def test_modifies_weights_and_momentums(N: int, D: int, dtype: torch.dtype, @pytest.mark.gpu @pytest.mark.parametrize('N,D', _MANY_PARAM_SHAPES) @pytest.mark.parametrize('device,dtype', [('cpu', torch.float32), - ('cuda', torch.bfloat16), ('cuda', torch.float16), ('cuda', torch.float32)]) + ('cuda', torch.bfloat16), + ('cuda', torch.float16), + ('cuda', torch.float32)]) @pytest.mark.parametrize('weight_decay', [0, .1]) @pytest.mark.parametrize('fused,use_errors', [(False, False), (True, False), (True, True)]) @@ -109,7 +115,9 @@ def test_changes_with_zero_grads(N: int, D: int, device: str, @pytest.mark.gpu @pytest.mark.parametrize('N,D', [(1, 8), (17, 23), (32, 32)]) @pytest.mark.parametrize('device,dtype', [('cpu', torch.float32), - ('cuda', torch.bfloat16), ('cuda', torch.float16), ('cuda', torch.float32)]) + ('cuda', torch.bfloat16), + ('cuda', torch.float16), + ('cuda', torch.float32)]) @pytest.mark.parametrize('fused,use_errors', [(False, False), (True, False), (True, True)]) def test_descends(N: int, D: int, device: str, dtype: torch.dtype, fused: bool, @@ -152,7 +160,7 @@ def test_descends(N: int, D: int, 
device: str, dtype: torch.dtype, fused: bool, momentum_abs_changes = (momentum - prev_momentum).abs() assert torch.all(momentum_abs_changes >= 0) assert momentum_abs_changes.max() > 0 - prev_momentum = momentum.clone() # gpu, f32 on cpu write in place + prev_momentum = momentum.clone() # {gpu, f32 on cpu} write in place def _nmse(vals_true: torch.Tensor, @@ -367,12 +375,13 @@ def test_state_dict_save_load(device: str, quantized_state: bool, class _DummyModule(nn.Module): + def __init__(self, device: str, dtype: torch.dtype): super().__init__() self.linear0 = nn.Linear(4, 3, device=device, dtype=dtype) self.linear1 = nn.Linear(3, 4, device=device, dtype=dtype) - def forward(self, x: torch.Tensor) -> torch.Tensor: # type:ignore + def forward(self, x: torch.Tensor) -> torch.Tensor: # type:ignore return self.linear1(self.linear0(x)) @@ -380,22 +389,25 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # type:ignore _SHARDED_STATE = fsdp.StateDictType.SHARDED_STATE_DICT _LOCAL_STATE = fsdp.StateDictType.LOCAL_STATE_DICT + # run just this test with: # python3 -m composer.cli.launcher -n 2 --master_port 26000 -m pytest -m gpu tests/test_lion8b.py::test_fsdp_save_load # noqa @pytest.mark.gpu @pytest.mark.world_size(2) @pytest.mark.parametrize('dtype', _FLOAT_DTYPES) @pytest.mark.parametrize('use_errors', [False, True]) -@pytest.mark.parametrize('state_sharding', [_FULL_STATE, _SHARDED_STATE, _LOCAL_STATE]) -def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool, state_sharding: fsdp.StateDictType): +@pytest.mark.parametrize('state_sharding', + [_FULL_STATE, _SHARDED_STATE, _LOCAL_STATE]) +def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool, + state_sharding: fsdp.StateDictType): device = 'cuda' if torch.cuda.device_count() < 2: pytest.skip(f'This test requires 2+ GPUs.') - torch.cuda.set_device(f'cuda:{os.environ["RANK"]}') # needed for fsdp - if not torch.distributed.is_initialized(): - torch.distributed.init_process_group() - assert torch.distributed.get_world_size() >= 2, 'Misconfigured test run!' + torch.cuda.set_device(f'cuda:{os.environ["RANK"]}') # needed for fsdp + if not dist.is_initialized(): + dist.init_process_group() + assert dist.get_world_size() >= 2, 'Misconfigured test run!' 
mod = FSDP(_DummyModule(device=device, dtype=dtype)) @@ -416,15 +428,16 @@ def _set_state_dict_type(model: nn.Module): # https://github.com/pytorch/pytorch/blob/a815e719e85899d4229616617e7827d4de191c2d/torch/distributed/fsdp/fully_sharded_data_parallel.py#L664 # noqa state_dict_cfg = { _FULL_STATE: fsdp.FullStateDictConfig(rank0_only=False), - _SHARDED_STATE: fsdp.api.ShardedStateDictConfig(), - _LOCAL_STATE: fsdp.api.LocalStateDictConfig(), + _SHARDED_STATE: fsdp.ShardedStateDictConfig(), + _LOCAL_STATE: fsdp.LocalStateDictConfig(), }[state_sharding] optim_cfg = { - _FULL_STATE: fsdp.api.FullOptimStateDictConfig(rank0_only=False), - _SHARDED_STATE: fsdp.api.ShardedOptimStateDictConfig(), - _LOCAL_STATE: fsdp.api.LocalOptimStateDictConfig(), + _FULL_STATE: FullOptimStateDictConfig(rank0_only=False), + _SHARDED_STATE: ShardedOptimStateDictConfig(), + _LOCAL_STATE: LocalOptimStateDictConfig(), }[state_sharding] - FSDP.set_state_dict_type(model, state_sharding, state_dict_cfg, optim_cfg) + FSDP.set_state_dict_type(model, state_sharding, state_dict_cfg, + optim_cfg) # load FSDP state dict _set_state_dict_type(mod) From 50829660368f12b35bf456aa8e8c190e262bda8c Mon Sep 17 00:00:00 2001 From: root Date: Fri, 18 Aug 2023 20:28:42 +0000 Subject: [PATCH 20/23] fix broken merge --- llmfoundry/utils/builders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index fd1b7b2a10..8bc6316edf 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -98,7 +98,7 @@ def build_optimizer(model: torch.nn.Module, name: str, return DecoupledClipLion(model.parameters(), **optimizer_config) elif name == 'adalr_lion': return DecoupledAdaLRLion(model.parameters(), **optimizer_config) - elif cfg.name.lower() == 'decoupled_lionw_8b': + elif name == 'decoupled_lionw_8b': return DecoupledLionW_8bit(model.parameters(), **optimizer_config) else: raise ValueError(f'Not sure how to build optimizer: {name}') From fbde16b5779de790c79a8bbda2cf0ea5b40eaf6e Mon Sep 17 00:00:00 2001 From: root Date: Sat, 19 Aug 2023 00:27:02 +0000 Subject: [PATCH 21/23] skip fsdp checkpoint test for torch 1.13.1 since...config classes missing? 
--- tests/test_lion8b.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py index 999bf0fa67..7a4cd25982 100644 --- a/tests/test_lion8b.py +++ b/tests/test_lion8b.py @@ -6,15 +6,23 @@ import warnings import numpy as np +import packaging.version as version import pytest import torch import torch.distributed as dist import torch.nn as nn from torch.distributed import fsdp from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp.api import ( # type:ignore .api not in public API - FullOptimStateDictConfig, LocalOptimStateDictConfig, - ShardedOptimStateDictConfig) + +if version.parse(torch.__version__) >= version.parse('2.0.1'): + from torch.distributed.fsdp.api import ( # type:ignore .api not in public API + FullOptimStateDictConfig, LocalOptimStateDictConfig, + ShardedOptimStateDictConfig) +else: + from unittest.mock import MagicMock # for pyright so vars aren't None + FullOptimStateDictConfig = MagicMock() + LocalOptimStateDictConfig = MagicMock() + ShardedOptimStateDictConfig = MagicMock() from llmfoundry.optim import DecoupledLionW_8bit as Lion8bit @@ -403,6 +411,8 @@ def test_fsdp_save_load(dtype: torch.dtype, use_errors: bool, device = 'cuda' if torch.cuda.device_count() < 2: pytest.skip(f'This test requires 2+ GPUs.') + if version.parse(torch.__version__) < version.parse('2.0.1'): + pytest.skip(f'This test requires torch 2.0.1 or greater.') torch.cuda.set_device(f'cuda:{os.environ["RANK"]}') # needed for fsdp if not dist.is_initialized(): From 7adfd57b21d81555a98759caba5027f15524fe4b Mon Sep 17 00:00:00 2001 From: root Date: Sat, 19 Aug 2023 00:32:20 +0000 Subject: [PATCH 22/23] fix wrong var for model config (manual merge fail) --- scripts/train/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/train/train.py b/scripts/train/train.py index 2d101af78d..0d9e4e9d10 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -400,9 +400,9 @@ def main(cfg: DictConfig): print_trainable_parameters(model) # should not be 100% else: # standard model model = build_composer_model(model_config, tokenizer) - if cfg.model.get('master_weights_dtype') in ('bf16', 'bfloat16'): + if model_config.get('master_weights_dtype') in ('bf16', 'bfloat16'): model = model.to(dtype=torch.bfloat16) - elif cfg.model.get('master_weights_dtype') in ('f16', 'float16'): + elif model_config.get('master_weights_dtype') in ('f16', 'float16'): model = model.to(dtype=torch.float16) # Log number of parameters From 284a855a0643e4eb4d3875e53461f288d3936fb5 Mon Sep 17 00:00:00 2001 From: root Date: Thu, 24 Aug 2023 21:41:21 +0000 Subject: [PATCH 23/23] print thruputs in thruput test as per pr comments --- tests/test_lion8b.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tests/test_lion8b.py b/tests/test_lion8b.py index 7a4cd25982..2852d99b8b 100644 --- a/tests/test_lion8b.py +++ b/tests/test_lion8b.py @@ -541,9 +541,8 @@ def test_fused_as_fast_as_unfused(N: int, assert times[True] < times['NA'] + atol assert times['ecc'] < times['NA'] + atol - if False: # change to True to check on thruput - print('') - print('time fused (ms): ', times[True] * 1e3) - print('time fused+ecc (ms): ', times['ecc'] * 1e3) - print('time unfused (ms): ', times[False] * 1e3) - print('time unquantized (ms): ', times['NA'] * 1e3) + print('') + print('time fused (ms): ', times[True] * 1e3) + print('time fused+ecc (ms): ', times['ecc'] * 1e3) + print('time unfused 
(ms): ', times[False] * 1e3) + print('time unquantized (ms): ', times['NA'] * 1e3)
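
A minimal end-to-end usage sketch, assuming only the interfaces visible in the diffs above (construct DecoupledLionW_8bit, take a step, round-trip its state dict). The tensor W, its shape, and the hyperparameter values here are illustrative assumptions, not values taken from the patch series.

    import torch

    from llmfoundry.optim import DecoupledLionW_8bit

    # Quantized momentum only applies to CUDA parameters; fall back to the
    # unquantized float32 path on CPU, as the tests above do.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    W = torch.randn(256, 256, device=device, requires_grad=True)

    opt = DecoupledLionW_8bit([W],
                              lr=1e-3,
                              betas=(0.9, 0.99),
                              weight_decay=1e-2,
                              quantize=(device == 'cuda'))

    # One step populates the (possibly int8-quantized) 'exp_avg' state.
    W.grad = torch.randn_like(W)
    opt.step()
    opt.zero_grad()

    # With compress_state_dict left at its default (False), the momentum is
    # exported as a bfloat16 tensor under the plain 'exp_avg' key, keeping
    # the checkpoint interoperable with the regular LION optimizer.
    state_dict = opt.state_dict()

    opt_new = DecoupledLionW_8bit([W], lr=1e-3, quantize=(device == 'cuda'))
    opt_new.load_state_dict(state_dict)

Per the builders.py change in PATCH 20/23, the same optimizer can also be selected from a training config via the optimizer name 'decoupled_lionw_8b'.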