-
Notifications
You must be signed in to change notification settings - Fork 3
/
sft_script.py
59 lines (55 loc) · 1.87 KB
/
sft_script.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from train.sft import sft_train, sft_train_v2
from argparse import ArgumentParser
from utils.config import llama3_path_a800
import yaml
import os
# Keys of the YAML recipe that this script reads and forwards to sft_train_v2.
# They are validated up front so a typo fails before any directory is created.
_REQUIRED_CONFIG_KEYS = (
    "origin_yaml_path",
    "initial_model_path",
    "initial_dataset_path",
    "dataset_type",
    "mid_yaml_root_path",
    "mid_jsonl_root_path",
    "mid_dataset_root_path",
    "check_point_root_path",
    "initial_episilon",
    "iteration_times",
    "port",
    "devices",
    "tokenizer_first_path",
    "tokenizer_second_path",
    "sample_count",
    "explore_count",
    "thread_count",
    "prompt_pool_path",
    "cal_ppl",
    "from_initial",
    "lambda1",
    "lambda2",
)

# Intermediate output directories that must exist before training starts.
_MID_DIR_KEYS = (
    "mid_yaml_root_path",
    "mid_dataset_root_path",
    "mid_jsonl_root_path",
)

argumentParser = ArgumentParser(
    description="Run iterative SFT training (sft_train_v2) from a YAML recipe."
)
argumentParser.add_argument(
    "--train_config_path",
    type=str,
    default="train/sft_recipes/hotpot_qa.yaml",
    help="Path to the YAML training recipe.",
)
argumentParser.add_argument(
    "--skipping",
    type=int,
    default=0,
    help="Forwarded to sft_train_v2 as `skipping`.",
)
argumentParser.add_argument(
    "--vllm_env",
    type=str,
    required=True,
    help="Forwarded to sft_train_v2 as `vllm_env`.",
)
argumentParser.add_argument(
    "--alignment_env",
    type=str,
    required=True,
    help="Forwarded to sft_train_v2 as `alignment_env`.",
)
argumentParser.add_argument(
    "--skip_iteration",
    type=int,
    default=0,
    help="Forwarded to sft_train_v2 as `skip_iteration`.",
)


def _load_config(path):
    """Load the YAML recipe at *path* and check it has every required key.

    Exits via ``argumentParser.error`` (status 2, with a readable message)
    if the file is not a top-level mapping or if any key listed in
    ``_REQUIRED_CONFIG_KEYS`` is missing.
    """
    with open(path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    # safe_load returns None for an empty file and a list/scalar for other
    # top-level shapes; both would otherwise fail later with a TypeError.
    if not isinstance(config, dict):
        argumentParser.error(
            f"{path}: expected a YAML mapping at the top level, "
            f"got {type(config).__name__}"
        )
    missing = [key for key in _REQUIRED_CONFIG_KEYS if key not in config]
    if missing:
        argumentParser.error(
            f"{path}: missing required config keys: {', '.join(missing)}"
        )
    return config


def main(args):
    """Create the intermediate directories and run ``sft_train_v2``.

    Args:
        args: parsed CLI namespace from ``argumentParser``.
    """
    config = _load_config(args.train_config_path)
    for key in _MID_DIR_KEYS:
        os.makedirs(config[key], exist_ok=True)
    sft_train_v2(
        config["origin_yaml_path"],
        config["initial_model_path"],
        config["initial_dataset_path"],
        config["dataset_type"],
        config["mid_yaml_root_path"],
        config["mid_jsonl_root_path"],
        config["mid_dataset_root_path"],
        config["check_point_root_path"],
        config["initial_episilon"],
        config["iteration_times"],
        config["port"],
        config["devices"],
        config["tokenizer_first_path"],
        config["tokenizer_second_path"],
        config["sample_count"],
        config["explore_count"],
        config["thread_count"],
        config["prompt_pool_path"],
        skipping=args.skipping,
        cal_ppl=config["cal_ppl"],
        skip_iteration=args.skip_iteration,
        from_initial=config["from_initial"],
        lambda1=config["lambda1"],
        lambda2=config["lambda2"],
        mix_dataset=[],
        vllm_env=args.vllm_env,
        alignment_env=args.alignment_env,
    )


if __name__ == "__main__":
    # Parse only when run as a script, so importing this module does not
    # consume (and exit on) the importer's sys.argv.
    args = argumentParser.parse_args()
    main(args)