
Commit

Merge pull request #362 from ker2xu/main
Typo
shibing624 authored Apr 25, 2024
2 parents 70b9653 + 62aeb39 commit 9bd86ea
Showing 2 changed files with 11 additions and 11 deletions.
10 changes: 5 additions & 5 deletions docs/training_params.md
@@ -3,7 +3,7 @@

- Stage 1: PT (Continue PreTraining), incremental pre-training, `run_pt.sh`
- Stage 2: SFT (Supervised Fine-tuning), supervised fine-tuning, `run_sft.sh`
- Stage 3 (the overall stage order is sketched right after this list):
  - RLHF (Reinforcement Learning from Human Feedback), in two steps:
    - RM (Reward Model), reward modeling, `run_rm.sh`
    - RL (Reinforcement Learning), reinforcement learning from human feedback, `run_ppo.sh`
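As a quick illustration (an editorial sketch, not part of this commit's diff), the stages above are normally run in order with the scripts they name; the merge steps and any paths here are placeholders, and the per-stage flags are documented below.

```shell
# Sketch of the stage order using the scripts named above; everything else is a placeholder.
sh run_pt.sh     # Stage 1: PT, continue pre-training on domain text
# merge the PT LoRA weights into the base model (see the merge command below)
sh run_sft.sh    # Stage 2: SFT, supervised fine-tuning on instruction data
# merge again before the RLHF steps
sh run_rm.sh     # Stage 3a: RM, reward model training
sh run_ppo.sh    # Stage 3b: RL, PPO training guided by the reward model
```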
@@ -23,7 +23,7 @@
9. PT and SFT support QLoRA training. If you are using an RTX4090, A100, or H100 GPU, nf4 is supported; enable QLoRA with `--qlora True --load_in_4bit True`. QLoRA training reduces GPU memory usage and speeds up training, and it is recommended to also set `--torch_dtype bfloat16 --optim paged_adamw_32bit` to preserve training precision (a combined launch example follows this list).
10. For incremental pre-training after expanding the vocabulary, add `--modules_to_save embed_tokens,lm_head` in the PT stage; later stages such as SFT do not need it.
11. Added RoPE interpolation to extend the context length of GPT models: with the [position interpolation method](https://arxiv.org/abs/2306.15595) and training on incremental data, the model gains long-text capability. Train with `--rope_scaling linear` and run inference with `--rope_scaling dynamic`.
- 12. Added [FlashAttention-2](https://github.com/Dao-AILab/flash-attention) support for LLaMA models; if you are using an RTX4090, A100, or H100 GPU, use `--flash_attn` in SFT to enable FlashAttention-2.
+ 12. Added [FlashAttention-2](https://github.com/Dao-AILab/flash-attention) support for LLaMA models; if you are using an RTX3090, RTX4090, A100, or H100 GPU, use `--flash_attn` in SFT to enable FlashAttention-2.
13. Added the **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA), giving the model long-text capability; use `--shift_attn` in SFT to enable it.
14. Added support for [NEFTune](https://github.com/neelsjain/NEFTune), an SFT method that adds noise to the embeddings ([NEFTune paper](https://arxiv.org/abs/2310.05914)); enable it in SFT with `--neft_alpha`, e.g. `--neft_alpha 5`.
15. Supports fine-tuning the Mixtral mixture-of-experts (MoE) model **[Mixtral 8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)**. When fine-tuning with LoRA in SFT, you can enable 4-bit quantization and QLoRA with `--load_in_4bit True --qlora True` to save GPU memory; setting `--target_modules q_proj,k_proj,v_proj,o_proj` is recommended, since it avoids quantizing the MLP layers of the MoE experts, which are sparse and lose quality when quantized.
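To make the flags above concrete, here is a hedged sketch (not part of the diff) of a single-GPU SFT launch combining several of them; the direct `python supervised_finetuning.py` invocation, the model directory, and the omitted data arguments are assumptions, while the flag names come from the list above.

```shell
# Illustrative only: flag names are from the docs above; the model path and the
# remaining data arguments (replaced by "...") are placeholders.
python supervised_finetuning.py \
    --model_name_or_path merged_pt_model_dir \
    --qlora True --load_in_4bit True \
    --torch_dtype bfloat16 --optim paged_adamw_32bit \
    --flash_attn \
    --neft_alpha 5 \
    ...
```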
@@ -33,8 +33,8 @@

LoRA training is used by default. The LoRA weights of each stage must be merged into the base model; merge them with the command below, and set the next stage's `model_name_or_path` to the merged model directory.

LoRA layers were used at all stages to reduce memory requirements.
At each stage the PEFT adapter layers were merged with the base model, using the command below (a short sketch of starting the next stage from the merged model follows it):
```shell
python merge_peft_adapter.py \
    --base_model base_model_dir \
    ...
```
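A hedged sketch of the merge-between-stages flow just described (not part of the diff): only `--base_model` appears in the truncated command above, so any further `merge_peft_adapter.py` arguments are left elided here, and the merged directory name passed to the next stage is a placeholder.

```shell
# Merge the current stage's adapter into the base model (further arguments elided, see above).
python merge_peft_adapter.py --base_model base_model_dir ...
# Start the next stage from the merged folder (directory name is a placeholder).
python supervised_finetuning.py --model_name_or_path merged_model_dir ...
```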
@@ -98,7 +98,7 @@
```shell
node_rank=$1
echo ${node_rank}
master_addr="10.111.112.223"

torchrun --nproc_per_node 8 --nnodes 2 --master_addr ${master_addr} --master_port 14545 --node_rank ${node_rank} run_supervised_finetuning.py ...
```
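A usage note on the block above (an assumption, not shown in the diff): `node_rank=$1` suggests the script takes the node rank as its first argument, so a two-node run is launched once per machine, with `master_addr` set to node 0's address. The script name below is a placeholder.

```shell
# Hypothetical two-node launch of the script sketched above.
# On node 0 (the machine reachable at ${master_addr}):
bash run_multinode_sft.sh 0
# On node 1:
bash run_multinode_sft.sh 1
```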


12 changes: 6 additions & 6 deletions supervised_finetuning.py
@@ -148,7 +148,7 @@ class ModelArguments:
    )
    shift_attn: Optional[bool] = field(
        default=False,
-       metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
+       metadata={"help": "Enable shifted sparse attention (S^2-Attn) proposed by LongLoRA."}
    )
    neft_alpha: Optional[float] = field(
        default=0,
@@ -1270,16 +1270,16 @@ def filter_empty_labels(example):
        logger.warning("FlashAttention-2 is not installed.")
    elif model_args.shift_attn and getattr(config, "model_type", None) == "llama":
        logger.warning("Using `--flash_attn` for faster training in large context length, enable if your GPU"
-                      " is RTX4090, A100 or H100.")
+                      " is RTX3090, RTX4090, A100 or H100.")

-   # Set shift short attention (S^2-Attn)
+   # Set shifted sparse attention (S^2-Attn)
    if model_args.shift_attn:
        if getattr(config, "model_type", None) == "llama":
            setattr(config, "group_size_ratio", 0.25)
            apply_llama_patch()
-           logger.info("Using shift short attention with group_size_ratio=1/4.")
+           logger.info("Using shifted sparse attention with group_size_ratio=1/4.")
        else:
-           logger.warning("Current model does not support shift short attention.")
+           logger.warning("Current model does not support shifted sparse attention.")

    load_in_4bit = model_args.load_in_4bit
    load_in_8bit = model_args.load_in_8bit
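Tying the code above back to the training docs (a sketch under assumptions, not part of the diff): passing `--shift_attn` for a LLaMA model is what triggers the `group_size_ratio=0.25` setting and the LongLoRA patch shown above. The invocation style and model path below are placeholders.

```shell
# Illustrative only: enables shifted sparse attention (S^2-Attn) for a LLaMA model.
# The data collator further down pads to a multiple of 4 when tokenizer.padding_side == "right".
python supervised_finetuning.py \
    --model_name_or_path llama_base_model_dir \
    --shift_attn True \
    ...
```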
@@ -1407,7 +1407,7 @@ def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], o
        tokenizer=tokenizer,
        model=model,
        label_pad_token_id=IGNORE_INDEX,
-       pad_to_multiple_of=4 if tokenizer.padding_side == "right" else None,  # for shift short attention
+       pad_to_multiple_of=4 if tokenizer.padding_side == "right" else None,  # for shifted sparse attention
    )
    # Initialize our Trainer
    trainer = SavePeftModelTrainer(
