-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsft3.py
104 lines (92 loc) · 3.59 KB
/
sft3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#!/usr/bin/env python
# _*_ coding: utf-8 _*_
# @Time : 2023/11/18 09:52
# @Author : Roger
# @Version : V 0.1
# @Email : [email protected]
# @File : sft3.py
import re
from tqdm import tqdm
import json
import hashlib
from datetime import datetime
def get_hashid(content):
    """Return the hexadecimal MD5 digest of *content* encoded as UTF-8.

    The digest serves as a deterministic content fingerprint. MD5 is not
    collision-resistant, so it must not be relied on for security.
    """
    digest = hashlib.md5(content.encode("utf-8"))
    return digest.hexdigest()
class QAParser(object):
    """Convert MOSS-style multi-turn chat JSONL into flat QA samples.

    Each input line is a JSON object read via the keys ``meta_instruction``,
    ``num_turns`` and ``chat`` (a mapping ``turn_1``, ``turn_2``, ... whose
    values carry ``Human`` and ``MOSS`` strings). Every turn becomes one QA
    dict, written to the output file as one JSON object per line.
    """

    def __init__(self, source, model):
        # Label stored in each sample's '来源' (source) field and mixed into its id.
        self.source = source
        # Model name recorded in each sample's '扩展字段' (extension) metadata.
        self.model = model
        # One timestamp per parser instance, shared by every sample it emits.
        self.now_time = datetime.now().strftime('%Y%m%d %H:%M:%S')
        # Name of the input file currently being parsed; the caller sets it
        # before calling parser() so it is recorded in each sample.
        self.file_name = ''

    def parser(self, input_path, output_path, mode='w', batch_size=20000):
        """Parse a JSONL file of conversations and write QA samples in batches.

        :param input_path: path to a UTF-8 JSONL file, one conversation per line.
        :param output_path: path of the JSON-lines file to write samples to.
        :param mode: file mode for the FIRST write ('w' truncates, 'a' appends).
            Later batches always append, so earlier batches are never lost.
        :param batch_size: flush to disk once more than this many samples
            are buffered (bounds memory use).
        :raises json.JSONDecodeError: if a non-blank line is not valid JSON.
        """
        sample_buff_list = []
        write_mode = mode
        # Iterate lazily instead of readlines() so the whole file is not held
        # in memory; `with` guarantees the handle is closed.
        with open(input_path, 'r', encoding='utf-8') as f:
            for line in tqdm(f):
                # Skip blank lines (e.g. a trailing newline) instead of crashing.
                if not line.strip():
                    continue
                sample_buff_list += self.line_parser(json.loads(line))
                if len(sample_buff_list) > batch_size:
                    self.output(sample_buff_list, output_path, write_mode)
                    sample_buff_list = []
                    # After the first flush, always append: reusing 'w' would
                    # truncate the file and discard the batches already written.
                    write_mode = 'a'
        # Write the remainder. In 'w' mode with no prior flush, this still
        # creates or truncates the output file even if there are no samples.
        self.output(sample_buff_list, output_path, write_mode)

    def line_parser(self, json_text):
        """Split one decoded conversation into per-turn QA sample dicts.

        :param json_text: decoded JSON object for one conversation.
        :return: list of sample dicts, one per turn (empty when there are no turns).
        :raises KeyError: if a ``turn_<i>`` entry or its 'Human'/'MOSS' field is missing.
        """
        meta_instruction = json_text.get('meta_instruction')
        chat = json_text.get('chat') or {}
        num_turns = json_text.get('num_turns')
        if num_turns is None:
            # No explicit count: fall back to the number of turns present.
            num_turns = len(chat)
        qa_list = []
        for turn_no in range(1, num_turns + 1):
            plain_text = chat['turn_{}'.format(turn_no)]
            # Remove the role prefixes and end-of-turn markers. Unlike proc(),
            # surrounding whitespace is intentionally left untouched.
            question = plain_text['Human'].replace('<eoh>\n', '').replace('<|Human|>: ', '')
            answer = plain_text['MOSS'].replace('<eom>\n', '').replace('<|MOSS|>: ', '')
            qa_list.append({
                'id': get_hashid('{}_{}_{}'.format(self.source, question, answer)),
                '问': question,
                '答': answer,
                '来源': self.source,
                '元数据': {
                    'create_time': self.now_time,
                    # Hard-coded role markers; may need to be extracted from the data.
                    '问题明细': "\"from\": \"human\"",
                    '回答明细': "\"from\": \"moss\"",
                    '扩展字段': json.dumps({
                        # NOTE(review): same value as "多轮序号"; an earlier comment
                        # hinted this was meant to be a conversation (qa) id — confirm.
                        "会话": turn_no,
                        "多轮序号": turn_no,
                        "解析模型": self.model,
                        "meta_instruction": meta_instruction,
                        "原始文件名": self.file_name,
                    }, ensure_ascii=False)
                }
            })
        return qa_list

    def proc(self, content):
        """Strip MOSS role prefixes, end-of-turn markers and outer whitespace.

        Removes '<eoh>\\n', '<|Human|>: ', '<eom>\\n' and '<|MOSS|>: ' wherever
        they occur, then strips leading/trailing whitespace. Not called by
        line_parser(), which does its own (non-stripping) cleanup.
        """
        content = content.replace('<eoh>\n', '').replace('<|Human|>: ', '')
        content = content.replace('<eom>\n', '').replace('<|MOSS|>: ', '')
        return content.strip()

    def output(self, qa_list, output_path, mode='w'):
        """Write samples to *output_path* as UTF-8 JSON lines.

        :param qa_list: iterable of JSON-serializable sample dicts.
        :param output_path: destination file path.
        :param mode: 'w' to truncate or 'a' to append.
        """
        # Explicit UTF-8: ensure_ascii=False emits raw CJK characters, which
        # the platform default encoding (e.g. cp1252) may not be able to write.
        with open(output_path, mode, encoding='utf-8') as f:
            f.writelines(
                json.dumps(qa_sample, ensure_ascii=False) + '\n'
                for qa_sample in qa_list
            )
if __name__ == '__main__':
    import os

    # Model that produced the answers; recorded in each sample's metadata.
    model = 'MOSS'
    # Source label stored in every sample and mixed into its id hash.
    # NOTE(review): the input and output files are moss-003, but this label
    # says moss-002 — confirm which is intended before regenerating data.
    source = 'moss-002-sft-data'
    # Build the parser.
    qa_parser = QAParser(source, model)
    # Input files to parse (relative to data/moss-003-sft-data/).
    files = [
        'moss-003-sft-no-tools.jsonl',
    ]
    # All input files are merged into this single JSON-lines output file.
    output_path = 'data/sample/moss-003-sft-data.json'
    # The output directory may not exist yet; create it so opening the
    # output file does not fail with FileNotFoundError.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    for file_name in files:
        qa_parser.file_name = file_name
        input_path = os.path.join('data', 'moss-003-sft-data', file_name)
        # Append mode so each input file accumulates into the same output.
        # NOTE: re-running the script appends onto any previous output file.
        qa_parser.parser(input_path, output_path, mode='a')