Skip to content

Commit

Permalink
add role play data.
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed Aug 2, 2024
1 parent 5751d23 commit d475416
Show file tree
Hide file tree
Showing 7 changed files with 416 additions and 0 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ Supervised Finetuning, RLHF(Reward Modeling and Reinforcement Learning) and DPO(
- DPO方法来自论文[Direct Preference Optimization:Your Language Model is Secretly a Reward Model](https://arxiv.org/pdf/2305.18290.pdf)
- ORPO方法来自论文[ORPO: Monolithic Preference Optimization without Reference Model](https://arxiv.org/abs/2403.07691)
## 🔥 News
[2024/08/02] v2.2版本:支持了角色扮演模型训练,新增了医患对话SFT数据生成代码,详见[Release-v2.2](https://github.com/shibing624/MedicalGPT/releases/tag/2.2.0)

[2024/06/11] v2.1版本:支持了 **[Qwen-2](https://qwenlm.github.io/blog/qwen2/)** 系列模型,详见[Release-v2.1](https://github.com/shibing624/MedicalGPT/releases/tag/2.1.0)

[2024/04/24] v2.0版本:支持了 **[Llama-3](https://huggingface.co/meta-llama)** 系列模型,详见[Release-v2.0](https://github.com/shibing624/MedicalGPT/releases/tag/2.0.0)
Expand Down
35 changes: 35 additions & 0 deletions role_play_data/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@

## 构造训练数据

### 数据生成框架
本数据集使用OpenAI API接口生成,流程:

- **种子特征集和基础设定**
  - 手工编写的种子集包含基本角色特征。
  - LLM从这个种子集生成角色的基础设定。
- **角色设定的进化**
  - 第二个种子集包含指导角色设定进化的指令Prompt。
  - 这些进化角色的指令Prompt被放到一个指令池中。基于这些进化Prompt,LLM对基础设定实施进化。
- **反馈循环**
  - 由人类评估者和GPT-4组成的混合评价系统。此系统对进化后的设定给出反馈。
  - 反馈用于迭代更新种子集。如此迭代,我们最终得到一个细致的角色设定数据集。
- **角色扮演和对话生成**
  - 使用self-instruction框架基于角色设定生成角色的对话数据。


1. 生成角色设定,分别生成护士角色和患者角色
```bash
cd role_play_data

python role_generate.py
```


2. 生成护士与患者之间的多轮对话

   LLM选择:分别使用 gpt-4o 的 API 和豆包 doubao-character-pro-32k 的 API 生成对话
```bash
python roleplay_data_generate_gpt4.py

python roleplay_data_generate_doubao.py
```

72 changes: 72 additions & 0 deletions role_play_data/role_generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import json
import random

from openai import OpenAI
from tqdm import tqdm

# Module-level OpenAI client shared by generate(). It is built with no
# arguments, so key and endpoint come from the library defaults
# (presumably the OPENAI_API_KEY environment variable — confirm for your setup).
client = OpenAI()
# Prints the client object's repr when the module is loaded.
print(client)


def generate(prompt):
    """Send ``prompt`` to gpt-4o as a single user message and return the reply text.

    Both the prompt and the reply are echoed to stdout.
    """
    print(prompt)
    completion = client.chat.completions.create(
        model='gpt-4o',
        temperature=1,
        messages=[{"role": "user", "content": prompt}],
    )
    answer = completion.choices[0].message.content
    print("回答:", answer)
    return answer


def generate_role(input_file, save_file, total_lines):
    """Grow a seed role file into ``total_lines`` newly generated role settings.

    Each round shuffles the seed prompts, numbers up to five of them into a
    list, asks the model (via ``generate``) to continue that list, and appends
    every numbered item of the reply to ``save_file`` as one JSON line of the
    form ``{"system_prompt": ...}``. Rounds repeat until ``total_lines``
    records have been written.

    Args:
        input_file: Path to a JSONL file whose lines are JSON objects with a
            ``"system_prompt"`` key.
        save_file: Path of the JSONL file the generated roles are appended to.
        total_lines: Number of records to write before stopping.

    Raises:
        ValueError: If records are requested but ``input_file`` holds no
            usable seed line.
    """
    seeds_per_round = 5

    # Decode the seed file once instead of re-parsing every line each round.
    seeds = []
    with open(input_file, "r", encoding="utf-8") as file:
        for raw in file:
            raw = raw.strip()
            if not raw:
                continue  # blank lines are not errors
            try:
                seeds.append(json.loads(raw)["system_prompt"])
            except (json.JSONDecodeError, KeyError, TypeError):
                print("error:", raw)

    if total_lines > 0 and not seeds:
        # Without usable seeds the loop below could never make progress.
        raise ValueError(f"no usable seed prompts found in {input_file}")

    with tqdm(total=total_lines, desc="指令进度") as pbar:
        while pbar.n < total_lines:
            random.shuffle(seeds)
            # Only successfully parsed seeds are numbered, so a bad line can no
            # longer make the round skip its model call; a file with fewer
            # than five seeds simply contributes all of them.
            numbered = "".join(
                f"{idx}.{seed}\n\n"
                for idx, seed in enumerate(seeds[:seeds_per_round], start=1)
            )
            # Guard against a reply without text content.
            reply = generate(f'请续写下面内容,不少于10条,增加些多样性。\n\n{numbered}') or ""

            with open(save_file, 'a', encoding='utf-8') as f:
                for item in reply.split("\n\n"):
                    number, dot, text = item.strip().partition(".")
                    text = text.strip()
                    # Keep only "<number>.<text>" items. Un-numbered chunks
                    # (preambles, closing remarks) are dropped rather than
                    # being cut at whatever period they happen to contain.
                    if not dot or not number.strip().isdigit() or not text:
                        continue
                    f.write(json.dumps({'system_prompt': text}, ensure_ascii=False) + '\n')

                    # Progress counts written records, not model calls.
                    pbar.update(1)
                    if pbar.n >= total_lines:
                        break


if __name__ == '__main__':
    # Expand the nurse seeds first, then the patient seeds; 50 new roles each.
    total_lines = 50
    for input_file, save_file in (
        ("seed_nurse_role.jsonl", "seed_nurse_role_output.jsonl"),
        ("seed_patient_role.jsonl", "seed_patient_role_output.jsonl"),
    ):
        generate_role(input_file, save_file, total_lines)
81 changes: 81 additions & 0 deletions role_play_data/roleplay_data_generate_doubao.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import json
import random

from openai import OpenAI
from tqdm import tqdm

# Module-level client shared by generate(); it talks to the endpoint in base_url
# through the OpenAI-compatible SDK.
# NOTE(review): "xxx" looks like a placeholder — supply a real API key before running.
client = OpenAI(
api_key="xxx",
base_url="https://ark.cn-beijing.volces.com/api/v3",
)
# Prints the client object's repr when the module is loaded.
print(client)


def generate(prompt, system_prompt=''):
    """Request one completion from the configured endpoint and return its text.

    ``system_prompt`` is sent as the system message and ``prompt`` as the user
    message. The prompt and the reply are both echoed to stdout.
    """
    print('提示:', prompt)
    # Endpoint IDs noted by the author:
    #   pro-32k: ep-20240623141021-r77gl
    #   lite-4k: ep-20240623140948-92n2g
    request = {
        "model": "ep-20240623141021-r77gl",  # your model endpoint ID
        "max_tokens": 3048,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
    }
    reply = client.chat.completions.create(**request).choices[0].message.content
    print("生成的对话:", reply)
    return reply


file_role1 = "seed_nurse_role.jsonl"
file_role2 = "seed_patient_role.jsonl"
save_file = "roleplay_train_data_v2.jsonl"
total_lines = 1000  # 10000
max_history_len = 10


def load_system_prompts(path):
    """Return the ``system_prompt`` value of every non-blank JSONL line in ``path``.

    Parsing happens once, up front, so a malformed seed line fails immediately
    instead of crashing a long generation run when it is eventually sampled.
    """
    with open(path, "r", encoding="utf-8") as file:
        return [json.loads(line)["system_prompt"] for line in file if line.strip()]


def build_turn_prompt(speaker, history, history_len=max_history_len):
    """Build the prompt that asks the model to say one line as ``speaker``.

    Args:
        speaker: Role label, "患者" or "护士".
        history: Earlier turns, each already formatted as "<label>:<text>".
        history_len: How many of the most recent turns to include.

    Returns:
        The instruction, the recent turns one per line, and a trailing
        "<speaker>:" cue for the model to complete.
    """
    # The instruction ends with a newline for both roles so the first history
    # turn is never glued onto the instruction sentence.
    instruction = f"要求你扮演{speaker},并且根据角色的设定内容模仿 角色相应的对话口吻和风格。你说一句话,完成本轮对话即可。\n"
    recent_turns = "".join(turn + "\n" for turn in history[-history_len:])
    return instruction + recent_turns + f"{speaker}:"


def main():
    """Generate ``total_lines`` six-round patient/nurse dialogues into ``save_file``.

    For every dialogue a random nurse role and patient role are paired, the
    model is called once per turn (patient first), and the finished
    conversation is appended to ``save_file`` as one JSON line.
    """
    nurse_roles = load_system_prompts(file_role1)
    patient_roles = load_system_prompts(file_role2)

    with tqdm(total=total_lines, desc="生成对话") as pbar:
        while pbar.n < total_lines:
            data1 = random.choice(nurse_roles)
            data2 = random.choice(patient_roles)
            p = "你是护士,跟患者对话。\n\n护士角色:" + data1 + '\n患者角色:' + data2
            conversation = {"id": str(pbar.n), "system_prompt": p, "conversations": []}

            system_prompt = f"护士角色:{data1}\n患者角色:{data2}\n"
            print('------' * 10)
            print('system_prompt:', system_prompt)
            history = []

            for _ in range(6):
                # Patient speaks first in every round, then the nurse answers.
                for speaker, source in (("患者", "human"), ("护士", "gpt")):
                    prompt = build_turn_prompt(speaker, history)
                    # Guard against a reply without text content.
                    reply = (generate(prompt, system_prompt) or "").strip()
                    conversation["conversations"].append({"from": source, "value": reply})
                    # Both roles use the same "<label>:<text>" history format,
                    # matching the cue that ends each prompt.
                    history.append(f"{speaker}:{reply}")

            with open(save_file, 'a', encoding='utf-8') as f:
                f.write(json.dumps(conversation, ensure_ascii=False) + '\n')

            pbar.update(1)


if __name__ == '__main__':
    main()
65 changes: 65 additions & 0 deletions role_play_data/roleplay_data_generate_gpt4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import json
import random

from openai import OpenAI
from tqdm import tqdm

# Module-level OpenAI client shared by generate(). It is built with no
# arguments, so key and endpoint come from the library defaults
# (presumably the OPENAI_API_KEY environment variable — confirm for your setup).
client = OpenAI()
# Prints the client object's repr when the module is loaded.
print(client)


def generate(prompt):
    """Send ``prompt`` to gpt-4o as one user message and return the generated text.

    The prompt and the reply are both echoed to stdout.
    """
    print('提示:', prompt)
    request = {
        "model": 'gpt-4o',
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 1,
        "max_tokens": 3048,  # raised so that longer dialogues can be generated
    }
    dialogue = client.chat.completions.create(**request).choices[0].message.content
    print("生成的对话:", dialogue)
    return dialogue


file_role1 = "seed_nurse_role.jsonl"
file_role2 = "seed_patient_role.jsonl"
save_file = "roleplay_train_data_v1.jsonl"
total_lines = 1000

# Speaker label at the start of a transcript line -> "from" value stored in
# the training record.
SPEAKER_SOURCES = {"患者": "human", "护士": "gpt"}


def load_system_prompts(path):
    """Return the ``system_prompt`` value of every non-blank JSONL line in ``path``.

    Parsing happens once, up front, so a malformed seed line fails immediately
    instead of crashing a long generation run when it is eventually sampled.
    """
    with open(path, "r", encoding="utf-8") as file:
        return [json.loads(line)["system_prompt"] for line in file if line.strip()]


def parse_dialogue(response):
    """Split a generated transcript into a list of ``{"from", "value"}`` turns.

    Only lines of the form "<label>:<text>" are kept, where the label is one
    of the keys of ``SPEAKER_SOURCES`` and the colon may be ASCII or
    full-width. Other lines (headings, blank lines, lines with no text after
    the colon) are ignored.

    Args:
        response: Raw model output; ``None`` or an empty string yields ``[]``.

    Returns:
        The turns in the order they appear in the transcript.
    """
    turns = []
    for line in (response or "").strip().split('\n'):
        line = line.strip()
        for label, source in SPEAKER_SOURCES.items():
            if not line.startswith(label):
                continue
            # Cut by label length instead of split(label): the utterance may
            # mention the label again ("我是护士小王") and must not be
            # truncated at that second occurrence.
            rest = line[len(label):].lstrip()
            if rest[:1] in (":", "\uff1a"):
                text = rest[1:].strip()
                if text:
                    turns.append({"from": source, "value": text})
            break
    return turns


def main():
    """Generate ``total_lines`` nurse/patient dialogues into ``save_file``.

    For every dialogue a random nurse role and patient role are paired, the
    model writes the whole multi-turn exchange in a single call, and the
    parsed conversation is appended to ``save_file`` as one JSON line.
    """
    nurse_roles = load_system_prompts(file_role1)
    patient_roles = load_system_prompts(file_role2)

    with tqdm(total=total_lines, desc="生成对话") as pbar:
        while pbar.n < total_lines:
            data1 = random.choice(nurse_roles)
            data2 = random.choice(patient_roles)
            p = "你是护士,跟患者对话。\n\n护士角色:" + data1 + '\n患者角色:' + data2

            combined_prompt = f"你扮演一个护士,以下对话是你和患者之间的对话。\n护士角色:{data1}\n患者角色:{data2}\n"
            combined_prompt += "进行多轮问答(6轮以上)。患者说话以`患者:`开头,护士说话以`护士:`开头。患者先提问。\n"

            prompt = combined_prompt + "\n对话开始:\n "
            response = generate(prompt)

            # Parse the generated multi-turn dialogue.
            turns = parse_dialogue(response)
            if not turns:
                # A record with no turns is useless as training data; retry
                # with a fresh role pair instead of writing and counting it.
                print("skip: no dialogue turns could be parsed from the reply")
                continue

            conversation = {"id": str(pbar.n), "system_prompt": p, "conversations": turns}
            with open(save_file, 'a', encoding='utf-8') as f:
                f.write(json.dumps(conversation, ensure_ascii=False) + '\n')

            pbar.update(1)


if __name__ == '__main__':
    main()
Loading

0 comments on commit d475416

Please sign in to comment.