Commit
* wip qlora + fsdp fixes
* more fixes
* make sure to load the lora 🤦
* only setup quantized meta on non-zero rank
* only run setup_quantized_peft_meta_for_training for qlora+fsdp
* more fixes for qlora+fsdp
* chore: lint
* add example yml
* support mistral too
* fix for model_type and add mixtral support too
* set cpu_offload: false to reduce vram, constrain new accelerator logic to qlora + fsdp
* refactor for duplicate code
Showing 8 changed files with 502 additions and 9 deletions.
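One item in the commit message, "only setup quantized meta on non-zero rank", refers to a memory-saving pattern used when combining QLoRA with FSDP. The snippet below is a minimal sketch of that general idea, not code from this commit; the helper name and the CPU/meta split are illustrative assumptions.

# Hedged sketch (assumption, not from this commit): with FSDP + QLoRA, only
# rank 0 needs to materialize the quantized base weights; the other ranks can
# create them on the meta device and receive the real values when FSDP
# broadcasts module states across ranks.
import torch
import torch.distributed as dist


def weight_init_device() -> torch.device:
    """Illustrative helper: where this rank should materialize base weights."""
    rank = dist.get_rank() if dist.is_initialized() else 0
    return torch.device("cpu") if rank == 0 else torch.device("meta")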
@@ -0,0 +1,70 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: yahma/alpaca-cleaned
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./qlora-out

adapter: qlora
lora_model_dir:

sequence_len: 512
sample_packing: false
pad_to_sequence_len: true

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 4
num_epochs: 4
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.00001

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
  - full_shard
fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
special_tokens:
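As a quick sanity check on the config above, here is a hedged sketch (not part of the commit) that reads the example with PyYAML and reproduces the kind of "is this QLoRA + FSDP?" test the commit message describes; the file path is an assumed location and the check itself is illustrative.

# Hedged sketch: parse the example config and detect the QLoRA + FSDP combo.
# The path below is an assumption about where the example lives.
import yaml

with open("examples/llama-2/qlora-fsdp.yml", encoding="utf-8") as fin:
    cfg = yaml.safe_load(fin)

is_qlora_fsdp = (
    cfg.get("adapter") == "qlora"        # LoRA adapter on a quantized base
    and cfg.get("load_in_4bit") is True  # 4-bit quantization enabled
    and bool(cfg.get("fsdp"))            # FSDP sharding strategy list is set
)
print("qlora + fsdp config:", is_qlora_fsdp)  # expected: True for this file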
@@ -0,0 +1,74 @@
base_model: mistralai/Mixtral-8x7B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./qlora-out

model_config:
  output_router_logits: true

adapter: qlora
lora_model_dir:

sequence_len: 1024
sample_packing: false
pad_to_sequence_len: false

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

loss_watchdog_threshold: 5.0
loss_watchdog_patience: 3

warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_max_new_tokens: 128
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
  - full_shard
fsdp_config:
  fsdp_transformer_layer_cls_to_wrap: MixtralSparseMoeBlock
special_tokens:
Empty file.
@@ -0,0 +1,55 @@
"""module for building the auto wrap policy for FSDP""" | ||
import functools | ||
|
||
from peft import PrefixEncoder, PromptEmbedding, PromptEncoder | ||
from torch.distributed.fsdp.wrap import ( | ||
_or_policy, | ||
lambda_auto_wrap_policy, | ||
transformer_auto_wrap_policy, | ||
) | ||
from transformers.models.llama.modeling_llama import LlamaDecoderLayer | ||
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer | ||
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer | ||
|
||
SUPPORTED_AUTO_WRAP_MODEL_TYPES = [ | ||
"llama", | ||
"mistral", | ||
"mixtral", | ||
] | ||
|
||
|
||
def get_wrapping_policy_factory(model_type): | ||
if model_type == "llama": | ||
layer_to_wrap = LlamaDecoderLayer | ||
elif model_type == "mistral": | ||
layer_to_wrap = MistralDecoderLayer | ||
elif model_type == "mixtral": | ||
layer_to_wrap = MixtralDecoderLayer | ||
|
||
def get_wrapping_policy(): | ||
"""This checks for lora layers (has weight and requires_grad)""" | ||
|
||
def lambda_policy_fn(module): | ||
return ( | ||
len(list(module.named_children())) == 0 | ||
and getattr(module, "weight", None) is not None | ||
and module.weight.requires_grad | ||
) | ||
|
||
lambda_policy = functools.partial( | ||
lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn | ||
) | ||
transformer_layer_name = layer_to_wrap | ||
transformer_wrap_policy = functools.partial( | ||
transformer_auto_wrap_policy, | ||
transformer_layer_cls=( | ||
PrefixEncoder, | ||
PromptEncoder, | ||
PromptEmbedding, | ||
transformer_layer_name, | ||
), | ||
) | ||
policies = [lambda_policy, transformer_wrap_policy] | ||
return functools.partial(_or_policy, policies=policies) | ||
|
||
return get_wrapping_policy |
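To show how the factory above is meant to be consumed, here is a hedged usage sketch that is not part of the diff: `model` is assumed to be a QLoRA/PEFT-wrapped causal LM loaded elsewhere, and any FSDP arguments beyond auto_wrap_policy are omitted as illustrative detail.

# Hedged usage sketch: build the combined wrapping policy and hand it to FSDP.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

wrap_policy = get_wrapping_policy_factory("llama")()  # an _or_policy partial
model = FSDP(
    model,  # assumption: a QLoRA/PEFT-wrapped LlamaForCausalLM loaded elsewhere
    auto_wrap_policy=wrap_policy,
)

Because the two sub-policies are combined with _or_policy, a module becomes its own FSDP unit if either matches: whole decoder layers (and prompt-tuning modules) via the transformer policy, and the small trainable LoRA leaf modules via the lambda policy.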