forked from charent/ChatLM-mini-Chinese
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsft_train.py
134 lines (110 loc) · 4.96 KB
/
sft_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# coding=utf-8
from typing import Dict
import time
import os
import pandas as pd
import numpy as np
import torch
from datasets import Dataset, load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import PreTrainedTokenizerFast, Seq2SeqTrainer, DataCollatorForSeq2Seq,Seq2SeqTrainingArguments
from transformers.generation.configuration_utils import GenerationConfig
from model.chat_model import TextToTextModel
from config import SFTconfig, T5ModelConfig
from utils.functions import get_T5_config, MyTrainerCallback
# Environment tweak applied at import time.
# NOTE(review): TF_ENABLE_ONEDNN_OPTS='0' turns off TensorFlow's oneDNN custom
# ops; it only has an effect if TensorFlow has not been imported yet — confirm
# nothing above imports it eagerly.
os.environ.update(TF_ENABLE_ONEDNN_OPTS='0')

# Register tqdm's progress-bar helpers (e.g. `progress_apply`) on pandas objects.
tqdm.pandas()
def get_dataset(file: str, split: str, tokenizer: PreTrainedTokenizerFast, cache_dir: str='.cache') -> Dataset:
    """
    Load a prompt/response dataset and convert it into token-id sequences.

    Every record must provide a ``prompt`` and a ``response`` text field. Both
    are tokenized without truncation or padding, and the tokenizer's EOS id is
    appended to each resulting sequence.

    Args:
        file: value passed as ``data_files`` to ``datasets.load_dataset``
            (JSON data). To read parquet instead, change the ``'json'``
            builder name below to ``'parquet'``.
        split: dataset split to load, e.g. ``"train"``.
        tokenizer: tokenizer used to encode prompts and responses.
        cache_dir: cache directory handed to ``datasets``.

    Returns:
        A ``Dataset`` whose only columns are ``input_ids`` (prompt ids + EOS)
        and ``labels`` (response ids + EOS); all original columns are removed.

    Raises:
        ValueError: if the tokenizer defines no EOS token id.
    """
    eos_token_id = tokenizer.eos_token_id
    if eos_token_id is None:
        # Fail before loading/mapping the data instead of deep inside numpy.
        raise ValueError('tokenizer has no eos_token_id; cannot terminate sequences with EOS')

    # Store ids compactly: uint16 covers ids 0..65535, i.e. vocabularies of up to
    # 65536 tokens. Larger vocabularies use uint32 so ids are never wrapped around
    # (older numpy) or rejected with OverflowError (numpy >= 2).
    id_dtype = np.uint16 if len(tokenizer) <= np.iinfo(np.uint16).max + 1 else np.uint32

    # Load the JSON dataset; to load parquet, change 'json' to 'parquet'.
    dataset = load_dataset('json', data_files=file, split=split, cache_dir=cache_dir)

    def tokens_to_ids(samples: dict) -> Dict[str, list]:
        """Tokenize one batch of prompt/response pairs and append EOS to every sequence."""
        encoded_prompt = tokenizer(samples['prompt'], truncation=False, padding=False, return_attention_mask=False)
        encoded_response = tokenizer(samples['response'], truncation=False, padding=False, return_attention_mask=False)

        input_ids = [np.array(ids + [eos_token_id], dtype=id_dtype) for ids in encoded_prompt["input_ids"]]
        labels = [np.array(ids + [eos_token_id], dtype=id_dtype) for ids in encoded_response["input_ids"]]

        return {
            'input_ids': input_ids,
            'labels': labels,
        }

    dataset = dataset.map(tokens_to_ids, batched=True, batch_size=8192, remove_columns=dataset.column_names)
    return dataset
def _load_model(config: SFTconfig, tokenizer: PreTrainedTokenizerFast) -> TextToTextModel:
    """Load the pre-trained model to fine-tune from a directory or a raw state-dict file."""
    if os.path.isdir(config.finetune_from_ckp_file):
        # A directory is loaded with from_pretrained.
        return TextToTextModel.from_pretrained(config.finetune_from_ckp_file)

    # Otherwise the path is a state-dict file: build the model from the default
    # T5 config sized to this tokenizer, then load the weights into it.
    t5_config = get_T5_config(
        T5ModelConfig(),
        vocab_size=len(tokenizer),
        decoder_start_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    model = TextToTextModel(t5_config)
    # Load onto CPU so a checkpoint saved on another device does not raise.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints
    # (consider weights_only=True if the installed torch version supports it).
    model.load_state_dict(torch.load(config.finetune_from_ckp_file, map_location='cpu'))
    return model


def _build_generation_config(tokenizer: PreTrainedTokenizerFast) -> GenerationConfig:
    """Build the greedy-decoding generation settings attached to the training arguments."""
    generation_config = GenerationConfig()
    generation_config.remove_invalid_values = True
    generation_config.eos_token_id = tokenizer.eos_token_id
    generation_config.pad_token_id = tokenizer.pad_token_id
    # Decoding starts from the pad token, matching the decoder_start_token_id
    # used when the T5 config is built from scratch.
    generation_config.decoder_start_token_id = tokenizer.pad_token_id
    generation_config.max_new_tokens = 320
    generation_config.repetition_penalty = 1.5
    generation_config.num_beams = 1  # greedy search
    generation_config.do_sample = False  # greedy search
    return generation_config


def _save_log_history(log_history: list, log_dir: str = './logs') -> None:
    """Write the Trainer's log history to a timestamped CSV file under ``log_dir``."""
    # makedirs(exist_ok=True) also creates missing parents and cannot race between
    # an existence check and the creation.
    os.makedirs(log_dir, exist_ok=True)
    loss_log = pd.DataFrame(log_history)
    loss_log.to_csv(f"{log_dir}/sft_train_log_{time.strftime('%Y%m%d-%H%M')}.csv")


def sft_train(config: SFTconfig) -> None:
    """
    Run supervised fine-tuning (SFT) of the seq2seq TextToTextModel.

    Loads the tokenizer and the pre-trained model, tokenizes the SFT training
    file, trains with ``Seq2SeqTrainer``, writes the training log history to
    ``./logs`` as CSV and saves the final model to ``config.output_dir``.

    Args:
        config: SFT settings (tokenizer / checkpoint / data paths, optimisation
            hyper-parameters, output directory, ...).
    """
    # Step 1: load the tokenizer.
    tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer_dir)

    # Step 2: load the pre-trained model.
    model = _load_model(config, tokenizer)

    # Step 3: load and tokenize the training data.
    dataset = get_dataset(file=config.sft_train_file, split="train", tokenizer=tokenizer)

    # Step 4: training arguments.
    # T5 is a sequence-to-sequence model, so Seq2SeqTrainingArguments,
    # DataCollatorForSeq2Seq and Seq2SeqTrainer are used; Hugging Face's SFT
    # tooling targets causal language models instead.
    training_args = Seq2SeqTrainingArguments(
        output_dir=config.output_dir,
        per_device_train_batch_size=config.batch_size,
        auto_find_batch_size=True,  # guard against OOM by shrinking the batch size
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        logging_steps=config.logging_steps,
        num_train_epochs=config.num_train_epochs,
        optim="adafactor",
        report_to='tensorboard',
        log_level='info',
        save_steps=config.save_steps,
        save_total_limit=3,
        fp16=config.fp16,
        logging_first_step=config.logging_first_step,
        warmup_steps=config.warmup_steps,
        seed=config.seed,
        generation_config=_build_generation_config(tokenizer),
    )

    # Step 5: batch collator.
    # NOTE(review): get_dataset tokenizes with truncation=False, and the collator's
    # max_length appears to matter only with padding='max_length', so
    # config.max_seq_len may not actually cap sequence length — confirm.
    collator = DataCollatorForSeq2Seq(tokenizer, max_length=config.max_seq_len)
    # Project-specific callback (defined in utils.functions); by its original
    # name it presumably empties the CUDA cache — verify there.
    empty_cuda_cache = MyTrainerCallback()

    # Step 6: trainer.
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
        data_collator=collator,
        callbacks=[empty_cuda_cache],
    )

    # Step 7: train (pass resume_from_checkpoint=True to continue from a saved checkpoint).
    trainer.train()

    # Step 8: persist the log history, then save the final model.
    _save_log_history(trainer.state.log_history)
    trainer.save_model(config.output_dir)
if __name__ == '__main__':
    # Script entry point: fine-tune with the default SFT configuration.
    config = SFTconfig()
    sft_train(config=config)