Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dpo #180

Merged
merged 14 commits into from
Aug 27, 2023
Merged

Dpo #180

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 29 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,18 @@
## 📖 Introduction

**MedicalGPT** training medical GPT model with ChatGPT training pipeline, implemantation of Pretraining,
Supervised Finetuning, Reward Modeling and Reinforcement Learning.
Supervised Finetuning, RLHF(Reward Modeling and Reinforcement Learning) and DPO(Direct Preference Optimization).

**MedicalGPT** 训练医疗大模型,实现包括二次预训练、有监督微调、奖励建模、强化学习训练。
**MedicalGPT** 训练医疗大模型,实现了包括增量预训练、有监督微调、RLHF(奖励建模、强化学习训练)和DPO(直接偏好优化)

<img src="https://github.com/shibing624/MedicalGPT/blob/main/docs/GPT_Training.jpg" width="860" />
<img src="https://github.com/shibing624/MedicalGPT/blob/main/docs/dpo.jpg" width="860" />

分四阶段训练GPT模型,来自Andrej Karpathy的演讲PDF [State of GPT](https://karpathy.ai/stateofgpt.pdf),视频 [Video](https://build.microsoft.com/en-US/sessions/db3f4859-cd30-4445-a0cd-553c3304f8e2)
- RLHF training pipeline来自Andrej Karpathy的演讲PDF [State of GPT](https://karpathy.ai/stateofgpt.pdf),视频 [Video](https://build.microsoft.com/en-US/sessions/db3f4859-cd30-4445-a0cd-553c3304f8e2)
- DPO方法来自论文[Direct Preference Optimization:Your Language Model is Secretly a Reward Model](https://arxiv.org/pdf/2305.18290.pdf)

## 🔥 News
[2023/08/25] v1.5版本: 新增[DPO(直接偏好优化)](https://arxiv.org/pdf/2305.18290.pdf)方法,DPO通过直接优化语言模型来实现对其行为的精确控制,而无需使用复杂的强化学习,也可以有效学习到人类偏好,DPO相较于RLHF更容易实现且易于训练,效果更好。详见[Release-v1.5](https://github.com/shibing624/MedicalGPT/releases/tag/1.5.0)

[2023/08/08] v1.4版本: 发布基于ShareGPT4数据集微调的中英文Vicuna-13B模型[shibing624/vicuna-baichuan-13b-chat](https://huggingface.co/shibing624/vicuna-baichuan-13b-chat),和对应的LoRA模型[shibing624/vicuna-baichuan-13b-chat-lora](https://huggingface.co/shibing624/vicuna-baichuan-13b-chat-lora),详见[Release-v1.4](https://github.com/shibing624/MedicalGPT/releases/tag/1.4.0)

[2023/08/02] v1.3版本: 新增LLaMA, LLaMA2, Bloom, ChatGLM, ChatGLM2, Baichuan模型的多轮对话微调训练;新增领域词表扩充功能;新增中文预训练数据集和中文ShareGPT微调训练集,详见[Release-v1.3](https://github.com/shibing624/MedicalGPT/releases/tag/1.3.0)
Expand All @@ -40,13 +43,18 @@ Supervised Finetuning, Reward Modeling and Reinforcement Learning.
[2023/06/05] v0.2版本: 以医疗为例,训练领域大模型,实现了四阶段训练:包括二次预训练、有监督微调、奖励建模、强化学习训练。详见[Release-v0.2](https://github.com/shibing624/MedicalGPT/releases/tag/0.2.0)


## 😊 Feature
基于ChatGPT Training Pipeline,本项目实现了领域模型--医疗模型的四阶段训练:
## 😊 Features


基于ChatGPT Training Pipeline,本项目实现了领域模型--医疗行业语言大模型的训练:

- 第一阶段:PT(Continue PreTraining)增量预训练,在海量领域文档数据上二次预训练GPT模型,以注入领域知识

- 第一阶段:PT(Continue PreTraining)增量预训练,在海量领域文档数据上二次预训练GPT模型,以注入领域知识(可选)
- 第二阶段:SFT(Supervised Fine-tuning)有监督微调,构造指令微调数据集,在预训练模型基础上做指令精调,以对齐指令意图
- 第三阶段:RM(Reward Model)奖励模型建模,构造人类偏好排序数据集,训练奖励模型,用来对齐人类偏好,主要是"HHH"原则,具体是"helpful, honest, harmless"
- 第四阶段:RL(Reinforcement Learning)基于人类反馈的强化学习(RLHF),用奖励模型来训练SFT模型,生成模型使用奖励或惩罚来更新其策略,以便生成更高质量、更符合人类偏好的文本
- 第三阶段
- RLHF(Reinforcement Learning from Human Feedback)基于人类反馈对语言模型进行强化学习,分为两步:1)RM(Reward Model)奖励模型建模,构造人类偏好排序数据集,训练奖励模型,用来建模人类偏好,主要是"HHH"原则,具体是"helpful, honest, harmless";2)RL(Reinforcement Learning)强化学习,用奖励模型来训练SFT模型,生成模型使用奖励或惩罚来更新其策略,以便生成更高质量、更符合人类偏好的文本
- [DPO(Direct Preference Optimization)](https://arxiv.org/pdf/2305.18290.pdf)直接偏好优化方法,DPO通过直接优化语言模型来实现对其行为的精确控制,而无需使用复杂的强化学习,也可以有效学习到人类偏好,DPO相较于RLHF更容易实现且易于训练,效果更好



### Release Models
Expand Down Expand Up @@ -99,18 +107,21 @@ pip install -r requirements.txt --upgrade

Training Stage:

| Stage | Introduction | Python script | Shell script |
|:--------------------------------|:-------------|:------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------|
| Stage 1: Continue Pretraining | 增量预训练 | [pretraining.py](https://github.com/shibing624/MedicalGPT/blob/main/pretraining.py) | [run_pt.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_pt.sh) |
| Stage 2: Supervised Fine-tuning | 有监督微调 | [supervised_finetuning.py](https://github.com/shibing624/MedicalGPT/blob/main/supervised_finetuning.py) | [run_sft.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_sft.sh) |
| Stage 3: Reward Modeling | 奖励模型建模 | [reward_modeling.py](https://github.com/shibing624/MedicalGPT/blob/main/reward_modeling.py) | [run_rm.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_rm.sh) |
| Stage 4: Reinforcement Learning | 强化学习 | [rl_training.py](https://github.com/shibing624/MedicalGPT/blob/main/rl_training.py) | [run_rl.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_rl.sh) |
| Stage | Introduction | Python script | Shell script |
|:--------------------------------|:-------------|:--------------------------------------------------------------------------------------------------------|:----------------------------------------------------------------------------|
| Continue Pretraining | 增量预训练 | [pretraining.py](https://github.com/shibing624/MedicalGPT/blob/main/pretraining.py) | [run_pt.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_pt.sh) |
| Supervised Fine-tuning | 有监督微调 | [supervised_finetuning.py](https://github.com/shibing624/MedicalGPT/blob/main/supervised_finetuning.py) | [run_sft.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_sft.sh) |
| Direct Preference Optimization | 直接偏好优化 | [dpo_training.py](https://github.com/shibing624/MedicalGPT/blob/main/dpo_training.py) | [run_dpo.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_dpo.sh) |
| Reward Modeling | 奖励模型建模 | [reward_modeling.py](https://github.com/shibing624/MedicalGPT/blob/main/reward_modeling.py) | [run_rm.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_rm.sh) |
| Reinforcement Learning | 强化学习 | [rl_training.py](https://github.com/shibing624/MedicalGPT/blob/main/rl_training.py) | [run_rl.sh](https://github.com/shibing624/MedicalGPT/blob/main/run_rl.sh) |

- 提供完整四阶段串起来训练的pipeline:[run_training_pipeline.ipynb](https://github.com/shibing624/MedicalGPT/blob/main/run_training_pipeline.ipynb) ,其对应的colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/shibing624/MedicalGPT/blob/main/run_training_pipeline.ipynb) ,运行完大概需要15分钟,我运行成功后的副本colab:[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1RGkbev8D85gR33HJYxqNdnEThODvGUsS?usp=sharing)
- 提供完整PT+SFT+DPO全阶段串起来训练的pipeline:[run_training_pipeline.ipynb](https://github.com/shibing624/MedicalGPT/blob/main/run_training_pipeline.ipynb) ,其对应的colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/shibing624/MedicalGPT/blob/main/run_training_pipeline.ipynb)
- 提供完整PT+SFT+RLHF全阶段串起来训练的pipeline:[run_training_pipeline.ipynb](https://github.com/shibing624/MedicalGPT/blob/main/run_training_pipeline.ipynb) ,其对应的colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/shibing624/MedicalGPT/blob/main/run_training_pipeline.ipynb) ,运行完大概需要15分钟,我运行成功后的副本colab:[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1RGkbev8D85gR33HJYxqNdnEThODvGUsS?usp=sharing)
- [训练参数说明wiki](https://github.com/shibing624/MedicalGPT/wiki/%E8%AE%AD%E7%BB%83%E5%8F%82%E6%95%B0%E8%AF%B4%E6%98%8E)
- [数据集wiki](https://github.com/shibing624/MedicalGPT/wiki/%E6%95%B0%E6%8D%AE%E9%9B%86)
- [扩充词表wiki](https://github.com/shibing624/MedicalGPT/wiki/%E6%89%A9%E5%85%85%E4%B8%AD%E6%96%87%E8%AF%8D%E8%A1%A8)
- [FAQ](https://github.com/shibing624/MedicalGPT/wiki/FAQ)

#### Supported Models
The following models are tested:

Expand Down Expand Up @@ -233,6 +244,7 @@ CUDA_VISIBLE_DEVICES=0 python inference.py \
4. [x] add medical reward dataset
5. [x] add llama in8/int4 training
6. [x] add all training and predict demo in colab
7. [x] add dpo training

## ☎️ Contact

Expand Down Expand Up @@ -284,6 +296,7 @@ CUDA_VISIBLE_DEVICES=0 python inference.py \

## 💕 Acknowledgements

- [Direct Preference Optimization:Your Language Model is Secretly a Reward Model](https://arxiv.org/pdf/2305.18290.pdf)
- [tloen/alpaca-lora](https://github.com/tloen/alpaca-lora/blob/main/finetune.py)
- [ymcui/Chinese-LLaMA-Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca)

Expand Down
Binary file added docs/dpo.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
28 changes: 22 additions & 6 deletions docs/training_details.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Training Detail


### Stage 1: Continue Pretraining
### Stage 1: PT(Continue PreTraining)
第一阶段:PT(Continue PreTraining)增量预训练

使用百科类文档类数据集,用来在领域数据集上增量预训练或二次预训练,期望能把领域知识注入给模型,以医疗领域为例,希望增量预训练,能让模型理解感冒的症状、病因、治疗药品、治疗方法、药品疗效等知识,便于后续的SFT监督微调能激活这些内在知识。
Expand All @@ -25,7 +25,7 @@ sh run_pt.sh
- 如果你的显存不足,可以改小batch_size=1, block_size=512(影响训练的上下文最大长度);
- 如果你的显存更大,可以改大block_size=2048, 此为llama原始预训练长度,不能更大啦;调大batch_size。

### Stage 2: Supervised FineTuning
### Stage 2: SFT(Supervised Fine-tuning)
第二阶段:SFT(Supervised Fine-tuning)有监督微调

基于llama-7b-pt模型,使用医疗问答类数据进行有监督微调,得到llama-7b-sft模型
Expand All @@ -39,8 +39,9 @@ sh run_sft.sh

[训练参数说明wiki](https://github.com/shibing624/MedicalGPT/wiki/%E8%AE%AD%E7%BB%83%E7%BB%86%E8%8A%82%E8%AF%B4%E6%98%8E)

### Stage 3: Reward Modeling
第三阶段:RM(Reward Model)奖励模型建模
### Stage 3: RLHF(Reinforcement Learning from Human Feedback)
#### Reward Modeling
RM(Reward Model)奖励模型建模

RM(Reward Model)奖励模型,原则上,我们可以直接用人类标注来对模型做 RLHF 微调。

Expand All @@ -61,8 +62,8 @@ sh run_rm.sh
```
[训练参数说明wiki](https://github.com/shibing624/MedicalGPT/wiki/%E8%AE%AD%E7%BB%83%E7%BB%86%E8%8A%82%E8%AF%B4%E6%98%8E)

### Stage 4: Reinforcement Learning
第四阶段:RL(Reinforcement Learning)基于人类反馈的强化学习(RLHF)
#### Reinforcement Learning
RL(Reinforcement Learning)强化学习

RL(Reinforcement Learning)模型的目的是最大化奖励模型的输出,基于上面步骤,我们有了微调的语言模型(llama-7b-sft)和奖励模型(llama-7b-reward),
可以开始执行 RL 循环了。
Expand All @@ -86,3 +87,18 @@ cd scripts
sh run_rl.sh
```

### Stage 3: DPO(Direct Preference Optimization)
DPO(Direct Preference Optimization)直接偏好优化

DPO方法可以通过直接优化语言模型来实现对其行为的精确控制,而无需使用复杂的强化学习。

DPO 将奖励函数和最优策略之间的映射联系起来,从而把约束奖励最大化问题转化为一个单阶段的策略训练问题。
这种算法不仅不用拟合奖励模型,还避免了在微调过程中从语言模型中采样或调整超参数的需要。

实验结果表明,DPO 算法可以与现有RLHF方法一样有效地从人类偏好中学习,甚至在某些任务中表现更好,比如情感调节、摘要和单轮对话。

PS: 使用DPO训练LLaMA2-7B在fp16,batch_size为2时,需要70GB显存。

```shell
sh run_dpo.sh
```
Loading