Merge pull request #271 from OpenAccess-AI-Collective/quadratic-warmup
Quadratic warmup
winglian authored Jul 10, 2023
2 parents cbe3eb1 + 6b08fe5 commit fb48ad7
Showing 2 changed files with 118 additions and 6 deletions.
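What the change does: when the cosine scheduler is selected and the new lr_quadratic_warmup flag is enabled, the warmup multiplier becomes (step / num_warmup_steps) ** 2 instead of the stock linear step / num_warmup_steps; the cosine decay after warmup is unchanged. A quick illustration of the two warmup factors (editorial, not part of the diff; plain Python):

warmup_steps = 100
for step in (10, 50, 100):
    linear = step / warmup_steps             # stock warmup factor:    0.1, 0.5, 1.0
    quadratic = (step / warmup_steps) ** 2   # this PR's warmup factor: 0.01, 0.25, 1.0
    print(step, linear, quadratic)

Quadratic warmup holds the learning rate much lower through most of the warmup window, which can make the earliest updates gentler.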
src/axolotl/utils/schedulers.py: 59 additions & 1 deletion
@@ -1,6 +1,9 @@
 """Module for custom LRScheduler class"""
+import math
+from functools import partial
 
-from torch.optim.lr_scheduler import LRScheduler
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LambdaLR, LRScheduler
 
 
 class InterpolatingLogScheduler(LRScheduler):
@@ -42,3 +45,58 @@ def get_lr(self):
         lrs = [self.max_lr for base_lr in self.base_lrs]
 
         return lrs
+
+
+def _get_cosine_schedule_with_quadratic_warmup_lr_lambda(
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    num_cycles: float,
+):
+    if current_step < num_warmup_steps:
+        return (float(current_step) / float(max(1, num_warmup_steps))) ** 2
+    progress = float(current_step - num_warmup_steps) / float(
+        max(1, num_training_steps - num_warmup_steps)
+    )
+    return max(
+        0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
+    )
+
+
+def get_cosine_schedule_with_quadratic_warmup(
+    optimizer: Optimizer,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    num_cycles: float = 0.5,
+    last_epoch: int = -1,
+):
+    """
+    Create a schedule with a learning rate that decreases following the values of the cosine function between the
+    initial lr set in the optimizer to 0, after a warmup period during which it increases quadratically between 0
+    and the initial lr set in the optimizer.
+
+    Args:
+        optimizer ([`~torch.optim.Optimizer`]):
+            The optimizer for which to schedule the learning rate.
+        num_warmup_steps (`int`):
+            The number of steps for the warmup phase.
+        num_training_steps (`int`):
+            The total number of training steps.
+        num_cycles (`float`, *optional*, defaults to 0.5):
+            The number of waves in the cosine schedule (the default is to just decrease from the max value to 0
+            following a half-cosine).
+        last_epoch (`int`, *optional*, defaults to -1):
+            The index of the last epoch when resuming training.
+
+    Return:
+        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
+    """
+
+    lr_lambda = partial(
+        _get_cosine_schedule_with_quadratic_warmup_lr_lambda,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        num_cycles=num_cycles,
+    )
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
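A minimal usage sketch for the new scheduler (editorial, not part of the commit; the toy model, SGD optimizer, and step counts are arbitrary):

import torch

from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scheduler = get_cosine_schedule_with_quadratic_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=1000
)

for step in range(1000):
    optimizer.step()  # a real loop would compute a loss and call backward() first
    scheduler.step()
    if (step + 1) in (10, 50, 100):
        # during warmup the lr tracks base_lr * ((step + 1) / 100) ** 2
        print(step + 1, scheduler.get_last_lr())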
src/axolotl/utils/trainer.py: 59 additions & 5 deletions
@@ -5,6 +5,7 @@
 import math
 import os
 import sys
+from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Optional
 
@@ -13,17 +14,67 @@
 import transformers
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
-from transformers import EarlyStoppingCallback, Trainer
+from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import get_parameter_names
 
 from axolotl.utils.callbacks import (
     SaveBetterTransformerModelCallback,
     SavePeftModelCallback,
 )
-from axolotl.utils.schedulers import InterpolatingLogScheduler
+from axolotl.utils.schedulers import (
+    InterpolatingLogScheduler,
+    get_cosine_schedule_with_quadratic_warmup,
+)
 
 
-class OneCycleLRSchedulerTrainer(Trainer):
+@dataclass
+class AxolotlTrainingArguments(TrainingArguments):
+    """
+    Extend the base TrainingArguments for axolotl helpers
+    """
+
+    lr_quadratic_warmup: bool = field(
+        default=False,
+        metadata={"help": "Use quadratic warmup for cosine scheduling."},
+    )
+
+
+class AxolotlTrainer(Trainer):
+    """
+    Extend the base Trainer for axolotl helpers
+    """
+
+    args = None  # type: AxolotlTrainingArguments
+
+    def create_scheduler(
+        self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
+    ):
+        """
+        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
+        passed as an argument.
+
+        Args:
+            num_training_steps (int): The number of training steps to do.
+            optimizer (torch.optim.Optimizer): The training optimizer
+        """
+
+        # fmt: off
+        if self.lr_scheduler is None:  # type: ignore # pylint: disable=access-member-before-definition
+            # fmt: on
+            if (
+                self.args.lr_scheduler_type == "cosine"
+                and self.args.lr_quadratic_warmup is True
+            ):
+                self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup(  # pylint: disable=attribute-defined-outside-init
+                    optimizer,
+                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
+                    num_training_steps=num_training_steps,
+                )
+            else:
+                return super().create_scheduler(num_training_steps, optimizer)
+        return self.lr_scheduler
+
+
+class OneCycleLRSchedulerTrainer(AxolotlTrainer):
     """
     Trainer subclass that uses the OneCycleLR scheduler
     """
@@ -103,6 +154,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.fsdp_config:
         training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
 
+    if cfg.lr_quadratic_warmup is not None:
+        training_arguments_kwargs["lr_quadratic_warmup"] = cfg.lr_quadratic_warmup
+
     # deepspeed
     if (
         os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true"
@@ -128,7 +182,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         training_arguments_kwargs["hub_model_id"] = cfg.hub_model_id
         training_arguments_kwargs["push_to_hub"] = True
 
-    training_args = transformers.TrainingArguments(
+    training_args = AxolotlTrainingArguments(
         per_device_train_batch_size=cfg.micro_batch_size,
         per_device_eval_batch_size=cfg.eval_batch_size
         if cfg.eval_batch_size is not None
@@ -278,7 +332,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     trainer_cls = (
         OneCycleLRSchedulerTrainer
         if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
-        else transformers.Trainer
+        else AxolotlTrainer
     )
     trainer = trainer_cls(
         model=model,
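Finally, a sketch of how the flag is consumed (editorial, not part of the commit; assumes the @dataclass decorator shown above, and the output_dir and hyperparameter values are arbitrary). Because AxolotlTrainingArguments is a dataclass extending TrainingArguments, the new field can be passed alongside the stock arguments, and AxolotlTrainer.create_scheduler then routes to the quadratic-warmup cosine schedule:

from axolotl.utils.trainer import AxolotlTrainingArguments

args = AxolotlTrainingArguments(
    output_dir="./out",          # arbitrary example path
    lr_scheduler_type="cosine",  # the quadratic-warmup branch requires cosine
    lr_quadratic_warmup=True,    # new field from this commit (defaults to False)
    warmup_steps=100,
    learning_rate=2e-4,
)
assert args.lr_quadratic_warmup is True

In an axolotl YAML config this corresponds to setting lr_scheduler: cosine together with lr_quadratic_warmup: true; setup_trainer forwards the latter into the training arguments as shown in the diff.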