feat: 🎸 support more lr schedulers and optimizers
Added support for new lr schedulers: CosineWarmup and CosineWarmupRestarts (see basicts.runner.optim.lr_schedulers), and new optimizers: Muon and AdamW_nanoGPT (see basicts.runner.optim.optimizers).
Showing 8 changed files with 348 additions and 4 deletions.
@@ -1,6 +1,6 @@
 from .launcher import launch_training, launch_evaluation
 from .runners import BaseEpochRunner

-__version__ = '0.4.5.1'
+__version__ = '0.4.6'

 __all__ = ['__version__', 'launch_training', 'launch_evaluation', 'BaseEpochRunner']
@@ -0,0 +1 @@
from .builder import build_lr_scheduler, build_optim
@@ -0,0 +1,105 @@
# Modified from easytorch: https://github.com/cnstark/easytorch/blob/96c903eb10deb0f96ca16a7b677ad0051b6f33f6/easytorch/core/optimizer_builder.py

from typing import Dict

from torch import nn, optim
from torch.optim import lr_scheduler

from . import lr_schedulers as basicts_lr_scheduler
from . import optimizers as basicts_optim

def build_optim(optim_cfg: Dict, model: nn.Module) -> optim.Optimizer:
    """Build an optimizer from `optim_cfg`.

    `optim_cfg` is the part of the config that describes the optimizer. Its structure is:
        {
            'TYPE': (str or type) optimizer name or type, such as ``Adam``, ``SGD``,
                or a custom optimizer type.
            'PARAM': (Dict) optimizer init params, except the first param `params`.
        }

    Note:
        The optimizer is initialized by reflection: if optim_cfg['TYPE'] is not found in
        `torch.optim`, it is looked up in the local `optimizers` module (`basicts_optim`).

    Examples:
        optim_cfg = {
            'TYPE': 'Adam',
            'PARAM': {
                'lr': 1e-3,
                'betas': (0.9, 0.99),
                'eps': 1e-8,
                'weight_decay': 0
            }
        }
        An `Adam` optimizer will be built.

    Args:
        optim_cfg (Dict): optimizer config
        model (nn.Module): model defined by the user

    Returns:
        optimizer (optim.Optimizer)
    """

    if isinstance(optim_cfg['TYPE'], type):
        optim_type = optim_cfg['TYPE']
    else:
        if hasattr(optim, optim_cfg['TYPE']):
            optim_type = getattr(optim, optim_cfg['TYPE'])
        else:
            optim_type = getattr(basicts_optim, optim_cfg['TYPE'])
    optim_param = optim_cfg['PARAM'].copy()
    optimizer = optim_type(model.parameters(), **optim_param)
    return optimizer


def build_lr_scheduler(lr_scheduler_cfg: Dict, optimizer: optim.Optimizer) -> lr_scheduler._LRScheduler:
    """Build an lr_scheduler from `lr_scheduler_cfg`.

    `lr_scheduler_cfg` is the part of the config that describes the lr_scheduler. Its structure is:
        {
            'TYPE': (str or type) lr_scheduler name or type, such as ``MultiStepLR``, ``CosineAnnealingLR``,
                or a custom lr_scheduler type.
            'PARAM': (Dict) lr_scheduler init params, except the first param `optimizer`.
        }

    Note:
        The lr_scheduler is initialized by reflection: if lr_scheduler_cfg['TYPE'] is not found in
        `torch.optim.lr_scheduler`, it is looked up in the local `lr_schedulers` module
        (`basicts_lr_scheduler`).

    Examples:
        lr_scheduler_cfg = {
            'TYPE': 'MultiStepLR',
            'PARAM': {
                'milestones': [100, 200, 300],
                'gamma': 0.1
            }
        }
        A `MultiStepLR` lr_scheduler will be built.

    Args:
        lr_scheduler_cfg (Dict): lr_scheduler config
        optimizer (optim.Optimizer): the optimizer to schedule

    Returns:
        LRScheduler
    """

    if isinstance(lr_scheduler_cfg['TYPE'], type):
        scheduler_type = lr_scheduler_cfg['TYPE']
    else:
        if hasattr(lr_scheduler, lr_scheduler_cfg['TYPE']):
            scheduler_type = getattr(lr_scheduler, lr_scheduler_cfg['TYPE'])
        else:
            scheduler_type = getattr(basicts_lr_scheduler, lr_scheduler_cfg['TYPE'])
    scheduler_param = lr_scheduler_cfg['PARAM'].copy()
    scheduler_param['optimizer'] = optimizer
    scheduler = scheduler_type(**scheduler_param)
    return scheduler
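
For illustration only, here is a minimal sketch of how these builders would be called. The toy model and hyper-parameter values are assumptions and not part of the commit; only the TYPE/PARAM structure comes from the docstrings above, and `build_optim` / `build_lr_scheduler` are assumed to be imported from this builder module.

from torch import nn

model = nn.Linear(16, 1)  # toy model, for illustration only

# 'Adam' is resolved from torch.optim; a name such as 'Muon' or 'AdamW_nanoGPT'
# would instead be looked up in the local optimizers module (basicts_optim).
optim_cfg = {
    'TYPE': 'Adam',
    'PARAM': {'lr': 1e-3, 'weight_decay': 0.0}
}
optimizer = build_optim(optim_cfg, model)

# 'CosineWarmup' is not in torch.optim.lr_scheduler, so it is resolved
# from the local lr_schedulers module (basicts_lr_scheduler).
lr_scheduler_cfg = {
    'TYPE': 'CosineWarmup',
    'PARAM': {'num_warmup_steps': 100, 'num_training_steps': 1000}
}
scheduler = build_lr_scheduler(lr_scheduler_cfg, optimizer)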
@@ -0,0 +1,94 @@
# define more learning rate schedulers here

import math
from functools import partial

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR

__all__ = ['CosineWarmup', 'CosineWarmupRestarts']

class CosineWarmup(LambdaLR):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function from the
    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly from 0 to the
    initial lr set in the optimizer.

    Modified from https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/optimization.py#L144

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
            following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def __init__(self, optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1):
        lr_lambda = partial(
            self._get_cosine_schedule_with_warmup_lr_lambda,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            num_cycles=num_cycles,
        )
        super().__init__(optimizer, lr_lambda, last_epoch)

    @staticmethod
    def _get_cosine_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float):
        # linear warmup
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # cosine decay
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))


class CosineWarmupRestarts(LambdaLR):
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function from the
    initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it
    increases linearly from 0 to the initial lr set in the optimizer.

    Modified from https://github.com/huggingface/transformers/blob/c2820c94916e34baf4486accae74760972183a2f/src/transformers/optimization.py#L144

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`int`, *optional*, defaults to 1):
            The number of hard restarts to use.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def __init__(self, optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1):
        lr_lambda = partial(
            self._get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            num_cycles=num_cycles,
        )
        super().__init__(optimizer, lr_lambda, last_epoch)

    @staticmethod
    def _get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda(
        current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: int
    ):
        # linear warmup
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # cosine decay with hard restarts
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        if progress >= 1.0:
            return 0.0
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
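
A minimal usage sketch (the toy model and all numbers below are illustrative assumptions, not part of the commit): these schedulers count progress in optimizer steps, so in this sketch `scheduler.step()` is called once per training iteration.

from torch import nn
from torch.optim import SGD

model = nn.Linear(8, 1)                       # toy model
optimizer = SGD(model.parameters(), lr=0.1)
scheduler = CosineWarmup(optimizer, num_warmup_steps=10, num_training_steps=100)

lrs = []
for step in range(100):
    optimizer.step()                          # would normally follow loss.backward()
    scheduler.step()                          # one scheduler step per training iteration
    lrs.append(scheduler.get_last_lr()[0])

# lrs rises roughly linearly from 0 toward 0.1 over the first 10 steps,
# then decays back to 0 along a half cosine by step 100.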
@@ -0,0 +1,145 @@
# define more optimizers here
import os
import inspect
from typing import Union, Tuple, Optional

import torch
from torch import Tensor
from torch.optim import AdamW
from torch.optim.optimizer import ParamsT

from easytorch.device import _DEVICE_TYPE


__all__ = ['AdamW_nanoGPT', 'Muon']


class AdamW_nanoGPT(AdamW):
    """AdamW with nanoGPT-style parameter grouping: weight decay is applied only to 2D+ parameters."""

    def __init__(self,
                 params: ParamsT,
                 lr: Union[float, Tensor] = 1e-3,
                 betas: Tuple[float, float] = (0.9, 0.999),
                 eps: float = 1e-8,
                 weight_decay: float = 1e-2,
                 amsgrad: bool = False,
                 *,
                 maximize: bool = False,
                 foreach: Optional[bool] = None,
                 capturable: bool = False,
                 differentiable: bool = False,
                 ):
        params = [p for p in params if p.requires_grad]
        # Create optim groups. Any parameter that is 2D will be weight decayed, otherwise not,
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for p in params if p.dim() >= 2]
        nodecay_params = [p for p in params if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        # Use the fused AdamW kernel when available and running on an accelerator.
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and _DEVICE_TYPE in ['gpu', 'mlu']
        super().__init__(optim_groups, lr=lr, betas=betas,
                         eps=eps, amsgrad=amsgrad, maximize=maximize, foreach=foreach,
                         capturable=capturable, differentiable=differentiable, fused=use_fused)


# -----------------------------------------------------------------------------
# Muon optimizer
# Source: https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt2.py
class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    Some warnings:
    - This optimizer assumes that all parameters passed in are 2D.
    - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D
      parameters; those should all be optimized by a standard method (e.g., AdamW).
    - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
    - We believe it is unlikely to work well for training with small batch size.
    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
    - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M).

    Arguments:
        lr: The learning rate used by the internal SGD.
        momentum: The momentum used by the internal SGD.
        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
        backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5')
        backend_steps: The number of iteration steps to use in the backend, if it is iterative.
    """

    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, backend='newtonschulz5', backend_steps=5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps)
        super().__init__(params, defaults)
        self.zeropower_backends = dict(svd=self._zeropower_via_svd, newtonschulz5=self._zeropower_via_newtonschulz5)

    def _zeropower_via_svd(self, G, steps=None):
        U, S, V = G.svd()
        return U @ V.T

    def _zeropower_via_newtonschulz5(self, G, steps=10, eps=1e-7):
        """
        Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
        quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
        of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
        zero even beyond the point where the iteration no longer converges all the way to one everywhere
        on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
        where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
        performance at all relative to UV^T, where USV^T = G is the SVD.
        """
        assert len(G.shape) == 2
        a, b, c = (3.4445, -4.7750, 2.0315)
        X = G.bfloat16()
        X /= (X.norm() + eps)  # ensure top singular value <= 1
        if G.size(0) > G.size(1):
            X = X.T
        for _ in range(steps):
            A = X @ X.T
            B = A @ X
            X = a * X + b * B + c * A @ B
        if G.size(0) > G.size(1):
            X = X.T
        return X

    def step(self):

        for group in self.param_groups:

            lr = group['lr']
            momentum = group['momentum']
            zeropower_backend = self.zeropower_backends[group['backend']]

            # Generate weight updates in a distributed fashion: each rank handles a strided subset
            # of the parameters (requires torch.distributed with WORLD_SIZE/RANK set and CUDA).
            total_params = sum(p.numel() for p in group['params'])
            updates_flat = torch.zeros(total_params, device='cuda', dtype=torch.bfloat16)
            curr_idx = 0
            for i, p in enumerate(group['params']):
                # luckily this will perfectly distribute a transformer with multiple of 4 layers to 8 GPUs
                if i % int(os.environ['WORLD_SIZE']) == int(os.environ['RANK']):
                    g = p.grad
                    assert g is not None
                    state = self.state[p]
                    if 'momentum_buffer' not in state:
                        state['momentum_buffer'] = torch.zeros_like(g)
                    buf = state['momentum_buffer']
                    buf.mul_(momentum).add_(g)
                    if group['nesterov']:
                        g = g.add(buf, alpha=momentum)
                    g = zeropower_backend(g, steps=group['backend_steps'])
                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
                    updates_flat[curr_idx:curr_idx + p.numel()] = g.flatten()
                curr_idx += p.numel()

            # sync updates across devices. we are not memory-constrained so can do this simple deserialization
            torch.distributed.all_reduce(updates_flat, op=torch.distributed.ReduceOp.SUM)

            # deserialize and apply updates
            curr_idx = 0
            for p in group['params']:
                g = updates_flat[curr_idx:curr_idx + p.numel()].view_as(p.data).type_as(p.data)
                p.data.add_(g, alpha=-lr)
                curr_idx += p.numel()
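
Because Muon only updates 2D (matrix) parameters and its step() relies on torch.distributed and CUDA (it reads WORLD_SIZE/RANK and allocates its update buffer on 'cuda'), a typical setup splits parameters between Muon and a standard optimizer such as AdamW. The sketch below is illustrative only; the toy model and hyper-parameter values are assumptions, not taken from this commit.

import torch
from torch import nn

model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 1))  # toy model

params = [p for p in model.parameters() if p.requires_grad]
matrix_params = [p for p in params if p.dim() >= 2]   # weight matrices -> Muon
other_params = [p for p in params if p.dim() < 2]     # biases etc. -> AdamW

muon = Muon(matrix_params, lr=0.02, momentum=0.95)
adamw = torch.optim.AdamW(other_params, lr=1e-3, weight_decay=0.0)

# In a training loop, both optimizers would be stepped after each backward pass.
# Note: Muon.step() as written assumes a distributed launch (e.g. torchrun) with an
# initialized process group, WORLD_SIZE/RANK environment variables, and CUDA tensors.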