llm_sft.py
from typing import Any, Dict

from modelscope import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from torch import dtype as Dtype
from transformers.utils.versions import require_version

from swift.llm import LoRATM, TemplateType, register_model, sft_main
from swift.utils import get_logger

logger = get_logger()

class CustomModelType:
    # Identifiers for the custom SOLAR-family checkpoints registered below;
    # these are the values passed to --model_type on the command line.
    SOLAR_10_7B_v1 = "solar-10-7b-v1"
    SOLAR_10_7B_v1_instruct = "solar-10-7b-instruct-v1"
    SolarM_SakuraSolar_SLERP = "SolarM-SakuraSolar-SLERP"
    CarbonVillain_en_10_7B_v1 = "CarbonVillain-en-10-7B-v1"
    SOLAR_10B_OrcaDPO_Jawade = "SOLAR-10B-OrcaDPO-Jawade"

# Register each checkpoint with swift. All of these models follow the
# Llama-2 architecture, so they share the llama2 LoRA target modules and
# the llama chat template, and all support vLLM inference.
@register_model(
    CustomModelType.SOLAR_10_7B_v1,
    "upstage/SOLAR-10.7B-v1.0",
    LoRATM.llama2,
    TemplateType.llama,
    support_vllm=True,
)
@register_model(
    CustomModelType.SOLAR_10_7B_v1_instruct,
    "upstage/SOLAR-10.7B-Instruct-v1.0",
    LoRATM.llama2,
    TemplateType.llama,
    support_vllm=True,
)
@register_model(
    CustomModelType.SolarM_SakuraSolar_SLERP,
    "kodonho/SolarM-SakuraSolar-SLERP",
    LoRATM.llama2,
    TemplateType.llama,
    support_vllm=True,
)
@register_model(
    CustomModelType.CarbonVillain_en_10_7B_v1,
    "jeonsworld/CarbonVillain-en-10.7B-v1",
    LoRATM.llama2,
    TemplateType.llama,
    support_vllm=True,
)
@register_model(
    CustomModelType.SOLAR_10B_OrcaDPO_Jawade,
    "bhavinjawade/SOLAR-10B-OrcaDPO-Jawade",
    LoRATM.llama2,
    TemplateType.llama,
    support_vllm=True,
)
def get_model_tokenizer(
    model_dir: str,
    torch_dtype: Dtype,
    model_kwargs: Dict[str, Any],
    load_model: bool = True,
    **kwargs,
):
    # FlashAttention-2 requires transformers>=4.34; enable it only on request.
    use_flash_attn = kwargs.pop("use_flash_attn", False)
    if use_flash_attn:
        require_version("transformers>=4.34")
        logger.info("Setting use_flash_attention_2: True")
        model_kwargs["use_flash_attention_2"] = True

    model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    # pretraining_tp=1 uses the standard fused linear layers instead of
    # replaying tensor-parallel slicing from pretraining.
    model_config.pretraining_tp = 1
    model_config.torch_dtype = torch_dtype
    logger.info(f"model_config: {model_config}")

    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    model = None
    if load_model:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            config=model_config,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
            **model_kwargs,
        )
    return model, tokenizer
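
# A minimal sketch (not part of the original script) of calling the loader
# directly. The path is hypothetical: swift normally resolves the ModelScope
# model id to a local directory before invoking this function.
#
#   import torch
#   model, tokenizer = get_model_tokenizer(
#       "/path/to/SOLAR-10.7B-Instruct-v1.0", torch.bfloat16, {},
#       load_model=False)
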
if __name__ == "__main__":
    output = sft_main()
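
# Typical launch (a sketch, assuming ms-swift's standard sft CLI flags; the
# dataset name is illustrative, not prescribed by this script):
#
#   python llm_sft.py \
#       --model_type solar-10-7b-instruct-v1 \
#       --sft_type lora \
#       --dataset alpaca-en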