
Commit

update dpo fp16.
shibing624 committed Jan 11, 2024
1 parent 9c9eaf1 commit 9aec1e1
Showing 8 changed files with 22 additions and 11 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -125,7 +125,7 @@ Training Stage:
| Supervised Fine-tuning | 有监督微调 | [supervised_finetuning.py](https://github.com/shibing624/MedicalGPT/blob/main/supervised_finetuning.py) | [run_sft.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_sft.sh) |
| Direct Preference Optimization | 直接偏好优化 | [dpo_training.py](https://github.com/shibing624/MedicalGPT/blob/main/dpo_training.py) | [run_dpo.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_dpo.sh) |
| Reward Modeling | 奖励模型建模 | [reward_modeling.py](https://github.com/shibing624/MedicalGPT/blob/main/reward_modeling.py) | [run_rm.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_rm.sh) |
| Reinforcement Learning | 强化学习 | [rl_training.py](https://github.com/shibing624/MedicalGPT/blob/main/rl_training.py) | [run_rl.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_rl.sh) |
| Reinforcement Learning | 强化学习 | [ppo_training.py](https://github.com/shibing624/MedicalGPT/blob/main/ppo_training.py) | [run_ppo.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_ppo.sh) |

- A complete pipeline that chains the PT+SFT+DPO stages end to end is provided: [run_training_dpo_pipeline.ipynb](https://github.com/shibing624/MedicalGPT/blob/main/run_training_dpo_pipeline.ipynb), with its corresponding Colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/shibing624/MedicalGPT/blob/main/run_training_dpo_pipeline.ipynb); a full run takes about 15 minutes. A copy of my successful Colab run: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1kMIe3pTec2snQvLBA00Br8ND1_zwy3Gr?usp=sharing)
- A complete pipeline that chains the PT+SFT+RLHF stages end to end is provided: [run_training_ppo_pipeline.ipynb](https://github.com/shibing624/MedicalGPT/blob/main/run_training_ppo_pipeline.ipynb), with its corresponding Colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/shibing624/MedicalGPT/blob/main/run_training_ppo_pipeline.ipynb); a full run takes about 20 minutes. A copy of my successful Colab run: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1RGkbev8D85gR33HJYxqNdnEThODvGUsS?usp=sharing)
2 changes: 1 addition & 1 deletion docs/training_details.md
@@ -84,7 +84,7 @@ Reinforcement Learning fine-tuning of llama-7b-sft with the llama-7b-reward reward model
```shell
pip install git+https://github.com/lvwerra/trl
cd scripts
sh run_rl.sh
sh run_ppo.sh
```

### Stage 3: DPO(Direct Preference Optimization)
4 changes: 2 additions & 2 deletions docs/training_params.md
@@ -6,7 +6,7 @@
- Stage 3
  - RLHF (Reinforcement Learning from Human Feedback), in two steps:
    - RM (Reward Model) reward modeling: `run_rm.sh`
    - RL (Reinforcement Learning) reinforcement learning from human feedback: `run_rl.sh`
    - RL (Reinforcement Learning) reinforcement learning from human feedback: `run_ppo.sh`
  - DPO (Direct Preference Optimization): `run_dpo.sh`


@@ -17,7 +17,7 @@
3. Specify the training data: `--train_file_dir` sets the training data directory and `--validation_file_dir` sets the validation data directory; if neither is given, the HF datasets dataset named by `--dataset_name` is used by default. See [dataset format](https://github.com/shibing624/MedicalGPT/wiki/%E6%95%B0%E6%8D%AE%E9%9B%86) for the training set field format. It is recommended to mix some general dialogue data into the domain training set; dataset links are listed under [📚 Dataset](https://github.com/shibing624/MedicalGPT#-dataset). The default format is multi-turn dialogue and is compatible with single-turn dialogue; if the fine-tuning data is in alpaca format, convert it to shareGPT format with [convert_dataset.py](https://github.com/shibing624/MedicalGPT/blob/main/convert_dataset.py) before training.
4. If the environment supports deepspeed, add `--deepspeed deepspeed_zero_stage2_config.json` to enable ZeRO stage 2; if GPU memory is insufficient, add `--deepspeed deepspeed_zero_stage3_config.json --fp16` to enable ZeRO stage 3 mixed-precision mode.
5. If the GPU supports int8/int4 quantization, add `--load_in_4bit True` for 4-bit quantized training or `--load_in_8bit True` for 8-bit quantized training; either significantly reduces GPU memory usage.
6. Debugging: `--max_train_samples` and `--max_eval_samples` set the maximum number of training and validation samples, useful for quickly checking that the code runs; for real training, remove these two arguments or set them to -1.
6. Controlling dataset size: `--max_train_samples` and `--max_eval_samples` set the maximum number of training and validation samples, useful for quickly checking that the code runs; for real training, it is recommended to set `--max_train_samples -1` to use the full training set and `--max_eval_samples 50` to use 50 validation samples.
7. Training mode: `--use_peft False` means full-parameter training (remove `--fp16`), `--use_peft True` means LoRA training. Note: full-parameter training of LLaMA-7B needs 120GB of GPU memory; LoRA training needs 13GB.
8. Resuming training is supported: for LoRA training, set `--peft_path` to the folder containing the old adapter_model.bin; for full-parameter training, set `--resume_from_checkpoint` to the folder with the old model weights.
9. PT and SFT support QLoRA training. On RTX 4090, A100, or H100 GPUs, which support nf4, use `--qlora True --load_in_4bit True` to enable QLoRA; it reduces GPU memory usage and speeds up training. It is also recommended to set `--torch_dtype bfloat16 --optim paged_adamw_32bit` to preserve training precision (a rough sketch of what these flags do at load time follows this list).
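
For context, here is a minimal sketch (not part of this commit) of roughly what the 4-bit/QLoRA flags above amount to at model-load time with Hugging Face `transformers` and `bitsandbytes`; the checkpoint name is just the small BLOOM model used elsewhere in this repo's examples, and the project's actual scripts may wire these options differently:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Roughly what `--qlora True --load_in_4bit True --torch_dtype bfloat16` implies:
# weights stored as 4-bit NF4, compute done in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloomz-560m",  # placeholder checkpoint; requires a CUDA GPU with bitsandbytes installed
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
print(model.get_memory_footprint())  # sanity check: far smaller than the fp16 footprint
```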
3 changes: 3 additions & 0 deletions dpo_training.py
@@ -403,6 +403,9 @@ def main():
bnb_4bit_compute_dtype=torch_dtype,
) if args.qlora else None,
)
# Fix FP16 ValueError: cast trainable params to float32 so AMP can unscale their gradients
for param in filter(lambda p: p.requires_grad, model.parameters()):
param.data = param.data.to(torch.float32)

# Initialize our Trainer
if args.gradient_checkpointing:
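
A note on the new cast (my reading of the change, not something the commit states): with `--fp16` and LoRA, the trainable adapter weights would otherwise stay in half precision, and PyTorch AMP's GradScaler then fails with a ValueError when asked to unscale FP16 gradients. A standalone sketch of the same pattern:

```python
import torch
import torch.nn as nn

def upcast_trainable_params(model: nn.Module) -> None:
    """Cast only the trainable parameters (e.g. LoRA adapters) to float32.

    Frozen base weights keep their original fp16/quantized dtype, so the
    extra memory is limited to the adapter weights, while the optimizer and
    AMP GradScaler see float32 parameters for everything they update.
    """
    for param in filter(lambda p: p.requires_grad, model.parameters()):
        param.data = param.data.to(torch.float32)
```

The same cast is applied in ppo_training.py below.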
3 changes: 3 additions & 0 deletions rl_training.py → ppo_training.py
@@ -277,6 +277,9 @@ def main():
trust_remote_code=args.trust_remote_code,
peft_config=peft_config if args.use_peft else None,
)
for param in filter(lambda p: p.requires_grad, model.parameters()):
param.data = param.data.to(torch.float32)

print_trainable_parameters(model)
# Load reward model
default_device = "cuda" if torch.cuda.is_available() else "cpu"
2 changes: 1 addition & 1 deletion run_rl.sh → run_ppo.sh
@@ -1,4 +1,4 @@
CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 rl_training.py \
CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 ppo_training.py \
--model_type bloom \
--model_name_or_path bigscience/bloomz-560m \
--reward_model_name_or_path OpenAssistant/reward-model-deberta-v3-large-v2 \
12 changes: 6 additions & 6 deletions run_training_ppo_pipeline.ipynb
@@ -702,7 +702,7 @@
"execution_count": null,
"outputs": [],
"source": [
"!python rl_training.py \\\n",
"!python ppo_training.py \\\n",
" --model_type bloom \\\n",
" --model_name_or_path ./merged-sft \\\n",
" --reward_model_name_or_path ./merged-rm \\\n",
@@ -736,7 +736,7 @@
"execution_count": null,
"outputs": [],
"source": [
"%ls -lh outputs-rl-v1"
"%ls -lh outputs-ppo-v1"
],
"metadata": {
"collapsed": false
@@ -768,7 +768,7 @@
"outputs": [],
"source": [
"!python merge_peft_adapter.py --model_type bloom \\\n",
" --base_model merged-sft --lora_model outputs-rl-v1 --output_dir merged-rl/"
" --base_model merged-sft --lora_model outputs-ppo-v1 --output_dir merged-ppo/"
],
"metadata": {
"collapsed": false
@@ -779,7 +779,7 @@
"execution_count": null,
"outputs": [],
"source": [
"%ls -lh merged-rl/"
"%ls -lh merged-ppo/"
],
"metadata": {
"collapsed": false
@@ -790,7 +790,7 @@
"execution_count": null,
"outputs": [],
"source": [
"%cat merged-rl/config.json"
"%cat merged-ppo/config.json"
],
"metadata": {
"collapsed": false
@@ -861,7 +861,7 @@
"execution_count": null,
"outputs": [],
"source": [
"!python inference.py --model_type bloom --base_model merged-rl --interactive"
"!python inference.py --model_type bloom --base_model merged-ppo --interactive"
],
"metadata": {
"collapsed": false,
5 changes: 5 additions & 0 deletions supervised_finetuning.py
@@ -1122,6 +1122,11 @@ def filter_empty_labels(example):
if data_args.max_eval_samples is not None and data_args.max_eval_samples > 0:
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
eval_dataset = eval_dataset.select(range(max_eval_samples))
eval_size = len(eval_dataset)
logger.debug(f"Num eval_samples: {eval_size}")
if eval_size > 500:
logger.warning(f"Num eval_samples is large: {eval_size}, "
f"training slow, consider reduce it by `--max_eval_samples=50`")
logger.debug(f"Example eval_dataset[0]: {eval_dataset[0]}")
eval_dataset = eval_dataset.map(
preprocess_function,
