FSDP + QLoRA (#1378)
* wip qlora + fsdp fixes

* more fixes

* make sure to load the lora 🤦

* only set up quantized meta on non-zero ranks

* only run setup_quantized_peft_meta_for_training for qlora+fsdp

* more fixes for qlora+fsdp

* chore: lint

* add example yml

* support mistral too

* fix for model_type and add mixtral support too

* set cpu_offload: false to reduce VRAM, constrain new accelerator logic to qlora + fsdp

* refactor for duplicate code
winglian authored Mar 8, 2024
1 parent 33a9fc3 commit e7a9c81
Showing 8 changed files with 502 additions and 9 deletions.
70 changes: 70 additions & 0 deletions examples/llama-2/qlora-fsdp.yml
@@ -0,0 +1,70 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
- path: yahma/alpaca-cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./qlora-out

adapter: qlora
lora_model_dir:

sequence_len: 512
sample_packing: false
pad_to_sequence_len: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 4
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.00001

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
- full_shard
fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
special_tokens:
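Note: this example can be launched with axolotl's usual entry point, e.g. `accelerate launch -m axolotl.cli.train examples/llama-2/qlora-fsdp.yml`. For reference, `load_in_4bit: true` plus `adapter: qlora` corresponds roughly to the quantization setup below — a sketch against the public transformers API; the NF4/double-quant/compute-dtype choices are typical QLoRA assumptions, not values read from this diff.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Typical NF4 setup for QLoRA; axolotl's actual defaults may differ.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # consistent with bf16: auto above
)
model = AutoModelForCausalLM.from_pretrained(
    "NousResearch/Llama-2-7b-hf",
    quantization_config=bnb_config,
)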
74 changes: 74 additions & 0 deletions examples/mistral/mixtral-qlora-fsdp.yml
@@ -0,0 +1,74 @@
base_model: mistralai/Mixtral-8x7B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./qlora-out

model_config:
output_router_logits: true

adapter: qlora
lora_model_dir:

sequence_len: 1024
sample_packing: false
pad_to_sequence_len: false

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
- full_shard
fsdp_config:
fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
special_tokens:
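Note: `lora_target_linear: true` asks axolotl to attach LoRA to every linear projection instead of an explicit module list. A rough peft equivalent of the adapter settings above — the `target_modules` list is an assumption (standard attention projections; for Mixtral the expert MLP weights w1/w2/w3 would also be picked up), not something this file pins down.

from peft import LoraConfig

lora_config = LoraConfig(
    r=32,               # lora_r
    lora_alpha=16,      # lora_alpha
    lora_dropout=0.05,  # lora_dropout
    bias="none",
    task_type="CAUSAL_LM",
    # Assumed expansion of lora_target_linear: true.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)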
3 changes: 2 additions & 1 deletion requirements.txt
@@ -3,7 +3,7 @@ packaging==23.2
 peft==0.9.0
 transformers==4.38.2
 tokenizers==0.15.0
-bitsandbytes>=0.41.1
+bitsandbytes>=0.43.0
 accelerate==0.26.1
 deepspeed==0.13.1
 pydantic==2.6.3
@@ -40,3 +40,4 @@ gcsfs
 # adlfs
 
 trl>=0.7.9
+fastcore>=1.5.29
0 changes: 0 additions & 0 deletions src/axolotl/core/policies/__init__.py
Empty file.
55 changes: 55 additions & 0 deletions src/axolotl/core/policies/auto_wrap.py
@@ -0,0 +1,55 @@
"""module for building the auto wrap policy for FSDP"""
import functools

from peft import PrefixEncoder, PromptEmbedding, PromptEncoder
from torch.distributed.fsdp.wrap import (
_or_policy,
lambda_auto_wrap_policy,
transformer_auto_wrap_policy,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer

SUPPORTED_AUTO_WRAP_MODEL_TYPES = [
"llama",
"mistral",
"mixtral",
]


def get_wrapping_policy_factory(model_type):
if model_type == "llama":
layer_to_wrap = LlamaDecoderLayer
elif model_type == "mistral":
layer_to_wrap = MistralDecoderLayer
elif model_type == "mixtral":
layer_to_wrap = MixtralDecoderLayer

def get_wrapping_policy():
"""This checks for lora layers (has weight and requires_grad)"""

def lambda_policy_fn(module):
return (
len(list(module.named_children())) == 0
and getattr(module, "weight", None) is not None
and module.weight.requires_grad
)

lambda_policy = functools.partial(
lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn
)
transformer_layer_name = layer_to_wrap
transformer_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls=(
PrefixEncoder,
PromptEncoder,
PromptEmbedding,
transformer_layer_name,
),
)
policies = [lambda_policy, transformer_wrap_policy]
return functools.partial(_or_policy, policies=policies)

return get_wrapping_policy
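Note: a minimal usage sketch of this factory, mirroring how `trainer_builder.py` wires it into Accelerate's FSDP plugin below (the plugin kwargs shown are the ones this commit sets):

from accelerate import FullyShardedDataParallelPlugin

from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory

policy_factory = get_wrapping_policy_factory("llama")
fsdp_plugin = FullyShardedDataParallelPlugin(
    # _or_policy: wrap transformer/prompt layers, plus any leaf module
    # with trainable weights (i.e. the LoRA layers).
    auto_wrap_policy=policy_factory(),
    cpu_offload=False,
    use_orig_params=False,
    limit_all_gathers=True,
)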
62 changes: 62 additions & 0 deletions src/axolotl/core/trainer_builder.py
@@ -8,6 +8,7 @@
import importlib.util
import logging
import math
import os
import sys
from abc import abstractmethod
from dataclasses import dataclass, field
@@ -17,7 +18,10 @@

import torch
import transformers
from accelerate import FullyShardedDataParallelPlugin
from accelerate.utils import str_to_bool
from datasets import Dataset
from torch.distributed.fsdp import MixedPrecision
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
Expand All @@ -30,6 +34,7 @@
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer

from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory
from axolotl.loraplus import create_loraplus_optimizer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
@@ -191,6 +196,10 @@ class AxolotlTrainingArguments(TrainingArguments):
default=1e-6,
metadata={"help": "loraplus learning rate for lora embedding layers."},
)
qlora: bool = field(
default=False,
metadata={"help": "whether this is a qlora training"},
)


class AxolotlTrainer(Trainer):
@@ -468,6 +477,56 @@ def push_to_hub(self, *args, **kwargs) -> str:

return super().push_to_hub(*args, **kwargs)

@wraps(Trainer.create_accelerator_and_postprocess)
def create_accelerator_and_postprocess(self):
rank = int(os.environ.get("LOCAL_RANK", 0))
res = super().create_accelerator_and_postprocess()

if self.args.qlora is False:
return res

# the rest of this method override is specific to fsdp + qlora (for now)
sync_module_states = (
str_to_bool(os.environ.get("FSDP_SYNC_MODULE_STATES", "True")) == 1
)

mp_policy = None
amp = os.environ["ACCELERATE_MIXED_PRECISION"]
if amp == "fp16":
mp_policy = MixedPrecision(
param_dtype=torch.float32,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
)
elif amp == "bf16":
mp_policy = MixedPrecision(
param_dtype=torch.float32,
reduce_dtype=torch.float32,
buffer_dtype=torch.float32,
)

# If somehow we figure out how we want to parameterize we want to autocast buffers...
# mp_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32)
# load_param_skip_names = ['inv_freq']

if self.is_fsdp_enabled:
wrapping_policy = get_wrapping_policy_factory(self.args.model_type)
fsdp_plugin = FullyShardedDataParallelPlugin(
auto_wrap_policy=wrapping_policy(),
cpu_offload=False,
use_orig_params=False,
limit_all_gathers=True,
param_init_fn=lambda module: module.to_empty(
device=torch.device("cuda"), recurse=False
)
if (rank != 0 and sync_module_states)
else None,
mixed_precision_policy=mp_policy,
)
self.accelerator.state.fsdp_plugin = fsdp_plugin

return res


class AxolotlMambaTrainer(AxolotlTrainer):
"""
@@ -787,6 +846,9 @@ def build(self, total_num_steps):
if self.cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)

if self.cfg.adapter == "qlora":
training_arguments_kwargs["qlora"] = True

# deepspeed
if self.cfg.deepspeed:
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
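Note: the `param_init_fn` above is what makes QLoRA viable under FSDP — only rank 0 holds real (quantized) weights at wrap time; every other rank materializes empty storage and receives the weights through `sync_module_states`. A standalone sketch of the same pattern (toy module, requires a CUDA device; `Module.to_empty` is the actual torch API used above):

import torch
import torch.nn as nn

# Construct on the meta device: shapes and dtypes only, no memory allocated.
with torch.device("meta"):
    layer = nn.Linear(4096, 4096)

# What param_init_fn does on non-zero ranks: allocate uninitialized GPU
# storage; FSDP then broadcasts rank 0's real weights because
# sync_module_states is enabled.
layer = layer.to_empty(device=torch.device("cuda"), recurse=False)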
2 changes: 1 addition & 1 deletion src/axolotl/utils/bench.py
@@ -24,9 +24,9 @@ def wrapper(*args, **kwargs):
or not torch.cuda.is_available()
or device == "auto"
or torch.device(device).type == "cpu"
or torch.device(device).type == "meta"
):
return default_value

return func(*args, **kwargs)

return wrapper
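Note: the new `meta` check complements the trick above — a tensor on the meta device carries shape and dtype but no storage, so the memory benchmark helpers should fall through to their default value rather than query CUDA. For example:

import torch

t = torch.empty(1024, 1024, device="meta")
print(t.device.type)  # "meta": shape/dtype exist, but nothing is allocated,
                      # so measuring its GPU memory use would be meaningless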