diff --git a/docs/prompt_template.md b/docs/prompt_template.md
new file mode 100644
index 000000000..81694f51a
--- /dev/null
+++ b/docs/prompt_template.md
@@ -0,0 +1,81 @@
+# Prompt template
+
+This document shows, with a few examples, how to correctly use prompt templates in Neural Speed and [ITREX](https://github.com/intel/intel-extension-for-transformers).
+
+For a base model (pre-trained only, without SFT), the prompt can be encoded directly into token ids, with no special prefix or suffix tokens added. Chat models, however, are usually fine-tuned with specific prompt templates, so the matching template has to be applied to obtain correct, human-readable responses.
+
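+To make the difference concrete, here is a minimal sketch. It assumes a recent `transformers` release whose tokenizers expose `apply_chat_template`, and uses `meta-llama/Llama-2-7b-chat-hf` purely as an example model id.
+
+```python
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+
+# Base-model style: encode the raw text directly, without any template.
+raw_ids = tokenizer("Tell me about Intel.", return_tensors="pt").input_ids
+
+# Chat-model style: wrap the user turn in the model's chat template first.
+messages = [{"role": "user", "content": "Tell me about Intel."}]
+chat_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
+
+print(raw_ids.shape, chat_ids.shape)  # the templated ids contain extra special/prompt tokens
+```
+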
+## Chat with ChatGLM3:
+```python
+from transformers import AutoTokenizer
+from neural_speed import Model
+
+model_name = "THUDM/chatglm3-6b"        # or local path to model
+model_file = "path/to/chatglm3_ne.bin"  # converted / quantized Neural Speed model file
+
+prompt = "你好"
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+inputs = tokenizer.build_chat_input(prompt)['input_ids']  # applies the ChatGLM3 chat template
+model = Model()
+model.init_from_bin("chatglm3", model_file)
+outputs = model.generate(inputs, max_new_tokens=300, do_sample=True)
+words = tokenizer.decode(outputs[0])
+```
+
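+For multi-turn conversations, the ChatGLM3 tokenizer's `build_chat_input` also accepts a conversation history. The call below is only a sketch based on the tokenizer shipped with `THUDM/chatglm3-6b`; argument names may differ between tokenizer revisions.
+
+```python
+history = [
+    {"role": "user", "content": "你好"},
+    {"role": "assistant", "content": "你好,请问有什么可以帮您?"},
+]
+inputs = tokenizer.build_chat_input("介绍一下Neural Speed", history=history, role="user")['input_ids']
+```
+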
+## Chat with LLaMA2:
+
+```python
+from transformers import AutoTokenizer, TextStreamer
+from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig
+
+# Please change to a local path to the model; llama2 does not currently support online conversion.
+model_name = "meta-llama/Llama-2-7b-chat-hf"
+woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4")
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+streamer = TextStreamer(tokenizer)
+model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True)
+
+while True:
+    prompt = input("> ").strip()
+    if prompt == "quit":
+        break
+    b_prompt = "[INST]{}[/INST]".format(prompt)  # prompt template for llama2
+    inputs = tokenizer(b_prompt, return_tensors="pt").input_ids
+    outputs = model.generate(inputs, streamer=streamer, interactive=True, ignore_prompt=True, do_sample=True)
+```
+
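+The `[INST] ... [/INST]` wrapper above covers a single user turn without a system prompt. Llama-2 chat models place an optional system prompt inside `<<SYS>>` markers within the first instruction; the helper below is a small sketch of that documented format (the system text is just an example).
+
+```python
+def build_llama2_prompt(user_msg, system_msg=None):
+    # Llama-2 chat format: optional <<SYS>> block inside the first [INST] ... [/INST]
+    if system_msg is not None:
+        return "[INST] <<SYS>>\n{}\n<</SYS>>\n\n{} [/INST]".format(system_msg, user_msg)
+    return "[INST] {} [/INST]".format(user_msg)
+
+b_prompt = build_llama2_prompt("Tell me about Intel.", system_msg="You are a helpful assistant.")
+```
+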
+## Chat with ChatGLM2:
+```python
+from transformers import AutoTokenizer, TextStreamer
+from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig
+
+model_name = "THUDM/chatglm2-6b" # or local path to model
+woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4")
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+streamer = TextStreamer(tokenizer)
+model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True)
+
+while True:
+    prompt = input("> ").strip()
+    if prompt == "quit":
+        break
+    prompt = tokenizer.build_prompt(prompt)  # prompt template for chatglm2
+    inputs = tokenizer([prompt], return_tensors="pt").input_ids
+    outputs = model.generate(inputs, streamer=streamer, interactive=True, ignore_prompt=True, do_sample=True, n_keep=2)
+```
+
+## Chat with Qwen:
+```python
+from transformers import AutoTokenizer, TextStreamer
+from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig
+
+model_name = "Qwen/Qwen-7B-Chat" # or local path to model
+woq_config = WeightOnlyQuantConfig(compute_dtype="int8", weight_dtype="int4")
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+streamer = TextStreamer(tokenizer)
+model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=woq_config, trust_remote_code=True)
+
+while True:
+    prompt = input("> ").strip()
+    if prompt == "quit":
+        break
+    prompt = "\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n".format(prompt)  # prompt template for qwen
+    inputs = tokenizer([prompt], return_tensors="pt").input_ids
+    outputs = model.generate(inputs, streamer=streamer, interactive=True, ignore_prompt=True, do_sample=True)
+```
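+
+The Qwen template above is the ChatML layout without a system turn. When a system message is needed, ChatML places it in its own `<|im_start|>system ... <|im_end|>` block before the user turn; the line below sketches that layout (the system text is only an example).
+
+```python
+system = "You are a helpful assistant."
+prompt = "<|im_start|>system\n{}<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n".format(system, prompt)
+```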
diff --git a/docs/supported_models.md b/docs/supported_models.md
index fac7d1b0b..03c63e9ed 100644
--- a/docs/supported_models.md
+++ b/docs/supported_models.md
@@ -219,7 +219,8 @@ Neural Speed supports the following models:
ChatGLM-6B,
- ChatGLM2-6B |
+ ChatGLM2-6B,
+ ChatGLM3-6B
✅ |
|
|
diff --git a/neural_speed/__init__.py b/neural_speed/__init__.py
index 322a69dcd..22566b2e1 100644
--- a/neural_speed/__init__.py
+++ b/neural_speed/__init__.py
@@ -24,6 +24,7 @@
class Model:
+
def __init__(self):
self.module = None
self.model = None
@@ -55,7 +56,7 @@ def __import_package(self, model_type):
import neural_speed.bloom_cpp as cpp_model
elif model_type == "chatglm":
import neural_speed.chatglm_cpp as cpp_model
- elif model_type == "chatglm2":
+ elif model_type == "chatglm2" or model_type == "chatglm3":
import neural_speed.chatglm2_cpp as cpp_model
elif model_type == "baichuan":
import neural_speed.baichuan_cpp as cpp_model
@@ -85,6 +86,11 @@ def get_model_type(model_config):
if model_type == "chatglm" and "chatglm2" in model_config._name_or_path:
model_type = "chatglm2"
+ # ChatGLM3 reuses the ChatGLM2 model architecture, so map it to the same model type.
+ if model_type == "chatglm" and "chatglm3" in model_config._name_or_path:
+ model_type = "chatglm2"
+
# for TheBloke/falcon-40b-instruct-GPTQ & TheBloke/Falcon-7B-Instruct-GPTQ
if model_type == "RefinedWebModel" or model_type == "RefinedWeb":
model_type = "falcon"
@@ -200,7 +206,7 @@ def init_from_bin(self, model_type, model_path, **generate_kwargs):
def get_max_seq_length():
config = self.config.to_dict()
- # chatglm2, bloom
+ # chatglm2, bloom, chatglm3
if 'seq_length' in config:
return config['seq_length']
# qwen2, llama-2, llama, dolly, gptneox, qwen, qwen1.5, opt, phi
diff --git a/neural_speed/application/CMakeLists.txt b/neural_speed/application/CMakeLists.txt
index cde77862e..2a34f5db3 100644
--- a/neural_speed/application/CMakeLists.txt
+++ b/neural_speed/application/CMakeLists.txt
@@ -65,6 +65,7 @@ compile_quant(quant_bloom quant_model.cpp bloom bloom)
compile_quant(quant_chatglm quant_model.cpp chatglm chatglm)
compile_quant(quant_chatglm2 quant_model.cpp chatglm2 chatglm2)
+compile_quant(quant_chatglm3 quant_model.cpp chatglm2 chatglm2)
compile_quant(quant_baichuan quant_model.cpp baichuan baichuan)
compile_quant(quant_mistral quant_model.cpp mistral llama)
compile_quant(quant_mixtral quant_model.cpp mixtral llama)
@@ -97,7 +98,7 @@ set(mymap_phi 16)
set(mymap_stablelm 17)
set(mymap_whisper 18)
set(mymap_mixtral 19)
-
+set(mymap_chatglm3 20)
function(compile_run TARGET MAIN_CPP MAIN_PY MODEL_NAME MODEL_LIB)
@@ -128,6 +129,7 @@ compile_run(run_starcoder main_run.cpp main_pybind.cpp starcoder starcoder)
compile_run(run_opt main_run.cpp main_pybind.cpp opt opt)
compile_run(run_bloom main_run.cpp main_pybind.cpp bloom bloom)
compile_run(run_chatglm2 main_run.cpp main_pybind.cpp chatglm2 chatglm2)
+compile_run(run_chatglm3 main_run.cpp main_pybind.cpp chatglm3 chatglm3)
compile_run(run_chatglm main_run.cpp main_pybind.cpp chatglm chatglm)
compile_run(run_baichuan main_run.cpp main_pybind.cpp baichuan baichuan)
compile_run(run_mistral main_run.cpp main_pybind.cpp mistral llama)
diff --git a/neural_speed/application/main_pybind.cpp b/neural_speed/application/main_pybind.cpp
index 532df4b04..eb3fb8777 100644
--- a/neural_speed/application/main_pybind.cpp
+++ b/neural_speed/application/main_pybind.cpp
@@ -921,6 +921,10 @@ PYBIND11_MODULE(whisper_cpp, m)
PYBIND11_MODULE(mixtral_cpp, m)
+#elif MODEL_NAME_ID == 20
+
+PYBIND11_MODULE(chatglm3_cpp, m)
+
#endif
{
m.doc() = "cpp model python binding";
diff --git a/neural_speed/application/main_run.cpp b/neural_speed/application/main_run.cpp
index 55cd4fec8..c784348eb 100644
--- a/neural_speed/application/main_run.cpp
+++ b/neural_speed/application/main_run.cpp
@@ -240,7 +240,8 @@ int main(int argc, char** argv) { // NOLINT
std::string prompt = build_prompt_glm2(prompts);
embd_inp = ::model_tokenize(ctx, prompt, false);
embd_inp.insert(embd_inp.begin(), {64790, 64792}); // special prefix
- } else if (params.model_arch == MODEL_CHATGLM || params.model_arch == MODEL_BAICHUAN) {
+ } else if (params.model_arch == MODEL_CHATGLM || params.model_arch == MODEL_BAICHUAN ||
+ params.model_arch == MODEL_CHATGLM3) {
for (auto& i : params.ids) {
embd_inp.emplace_back(i);
}
@@ -646,7 +647,7 @@ int main(int argc, char** argv) { // NOLINT
// display text
if (params.model_arch == MODEL_CHATGLM || params.model_arch == MODEL_CHATGLM2 ||
- params.model_arch == MODEL_BAICHUAN) {
+ params.model_arch == MODEL_BAICHUAN || params.model_arch == MODEL_CHATGLM3) {
static bool is_prompt = true;
if (input_echo) {
if (is_prompt == true) {
diff --git a/neural_speed/convert/common.py b/neural_speed/convert/common.py
index d4e5f49cc..9a9d3a320 100644
--- a/neural_speed/convert/common.py
+++ b/neural_speed/convert/common.py
@@ -90,7 +90,56 @@ def quantize_q4_1(tensor: torch.Tensor) -> torch.CharTensor:
return tensor
+def quantize_q8_0(tensor: torch.Tensor) -> torch.Tensor:
+ # equivalent to ggml_quantize_q8_0 in ggml.c
+ assert tensor.shape[1] % GGML_QK8_0 == 0
+ tensor = tensor.view(-1, GGML_QK8_0)
+ scale = tensor.abs().max(dim=-1, keepdim=True).values / ((1 << 7) - 1)
+ tensor = (tensor / scale).round().clamp(min=-128, max=127).char()
+ # add scale into each block
+ tensor = torch.cat((scale.half().view(torch.int8), tensor), dim=-1)
+ return tensor
+
+
+def quantize_q5_0(tensor: torch.Tensor) -> torch.Tensor:
+ # equivalent to ggml_quantize_q5_0 in ggml.c
+ assert tensor.shape[1] % GGML_QK5_0 == 0
+ tensor = tensor.view(-1, GGML_QK5_0)
+ abs_max_indices = tensor.abs().max(dim=-1, keepdim=True).indices
+ max_values = torch.take_along_dim(tensor, abs_max_indices, dim=-1)
+ scale = max_values / -16
+ tensor = (tensor / scale + 16).round().clamp(min=0, max=31).char()
+ qs = (tensor[:, :16] & 0x0F) | (tensor[:, 16:] << 4)
+ qh = torch.zeros(tensor.shape[:-1], dtype=torch.int32)
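+ # gather the fifth (high) bit of each of the 32 quantized values into one 32-bit mask; the low 4 bits are already packed in qs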
+ for i in range(32):
+ qh |= ((tensor[:, i] & 0x10) >> 4).int() << i
+
+ # add scale into each block
+ tensor = torch.cat((scale.half().view(torch.int8), qh[..., None].view(torch.int8), qs), dim=-1)
+ return tensor
+
+
+def quantize_q5_1(tensor: torch.Tensor) -> torch.Tensor:
+ # equivalent to ggml_quantize_q5_1 in ggml.c
+ assert tensor.shape[1] % GGML_QK5_1 == 0
+ tensor = tensor.view(-1, GGML_QK5_1)
+ min_vals = tensor.min(dim=-1, keepdim=True).values
+ max_vals = tensor.max(dim=-1, keepdim=True).values
+ scale = (max_vals - min_vals) / ((1 << 5) - 1)
+ tensor = ((tensor - min_vals) / scale).round().clamp(min=0, max=31).char()
+ qs = (tensor[:, :16] & 0x0F) | (tensor[:, 16:] << 4)
+ qh = torch.zeros(tensor.shape[:-1], dtype=torch.int32)
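+ # as in q5_0: collect the fifth (high) bit of each of the 32 quantized values into a 32-bit mask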
+ for i in range(32):
+ qh |= ((tensor[:, i] & 0x10) >> 4).int() << i
+
+ # add scale & min into each block
+ tensor = torch.cat(
+ (scale.half().view(torch.int8), min_vals.half().view(torch.int8), qh[..., None].view(torch.int8), qs), dim=-1)
+ return tensor
+
+
class SentencePieceVocab:
+
def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
added_tokens: Dict[str, int]
diff --git a/neural_speed/convert/convert_chatglm.py b/neural_speed/convert/convert_chatglm.py
index 5cbec2e3a..be84facd6 100644
--- a/neural_speed/convert/convert_chatglm.py
+++ b/neural_speed/convert/convert_chatglm.py
@@ -145,6 +145,173 @@ def load_vocab_for_glm2(path: Path) -> SentencePieceVocab:
return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None)
+def chatglm3_convert_gguf(model, tokenizer, dir_model, fname_out, ftype, hparams):
+ print("ChatGLM-3.gguf converting: ")
+ list_vars = model.state_dict()
+ for name in list_vars.keys():
+ print("%-80s" % name, list_vars[name].shape, list_vars[name].dtype)
+
+ print(hparams)
+
+ gguf_file = fname_out
+ gguf_writer = gguf.GGUFWriter(gguf_file, "chatglm3")
+ gguf_writer.add_uint32('magic', 0x67676d66)
+ gguf_writer.add_uint32('version', 1)
+ gguf_writer.add_uint32('n_vocab', hparams["padded_vocab_size"])
+ gguf_writer.add_embedding_length(hparams["hidden_size"])
+
+ gguf_writer.add_uint32('n_mult', 0)
+ gguf_writer.add_head_count(hparams["num_attention_heads"])
+ gguf_writer.add_head_count_kv(0)
+ gguf_writer.add_block_count(hparams["num_layers"])
+
+ gguf_writer.add_rope_dimension_count(0)
+ gguf_writer.add_uint32('ftype', ftype)
+
+ gguf_writer.add_context_length(hparams["seq_length"])
+
+ gguf_writer.add_max_alibi_bias(0)
+
+ gguf_writer.add_uint32('clip_qkv', 0)
+ gguf_writer.add_uint32('par_res', 0)
+
+ gguf_writer.add_uint32('word_embed_proj_dim', 0)
+ gguf_writer.add_uint32('do_layer_norm_before', 0)
+
+ gguf_writer.add_uint32('multi_query_group_num', hparams["multi_query_group_num"])
+
+ gguf_writer.add_feed_forward_length(hparams["ffn_hidden_size"])
+
+ gguf_writer.add_uint32('inner_hidden_size', 0)
+
+ gguf_writer.add_bos_token_id(tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 0)
+ gguf_writer.add_eos_token_id(tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0)
+ gguf_writer.add_pad_token_id(tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0)
+ gguf_writer.add_sep_token_id(tokenizer.sep_token_id if tokenizer.sep_token_id is not None else 0)
+
+ def write_vocab_gguf(dir_model):
+ print("gguf: get tokenizer metadata")
+
+ tokens: List[bytes] = []
+ scores: List[float] = []
+ toktypes: List[int] = []
+
+ if Path(dir_model + "/tokenizer.model").is_file():
+ # vocab type sentencepiece
+ print("gguf: get sentencepiece tokenizer vocab, scores and token types")
+
+ vocab = load_vocab_for_glm2(Path(dir_model))
+
+ # NOTE: `all_tokens` returns the base vocabulary and added tokens
+ for text, score in vocab.all_tokens():
+ tokens.append(text)
+ scores.append(score)
+
+ gguf_writer.add_tokenizer_model("chatglm3")
+ gguf_writer.add_token_list(tokens)
+ gguf_writer.add_token_scores(scores)
+
+ print("gguf: get special token ids")
+
+ # If no tokenizer.json: Look for special tokens in config.json
+ if "bos_token_id" in hparams and hparams["bos_token_id"] != None:
+ gguf_writer.add_bos_token_id(hparams["bos_token_id"])
+
+ if "eos_token_id" in hparams and hparams["eos_token_id"] != None:
+ gguf_writer.add_eos_token_id(hparams["eos_token_id"])
+
+ if "unk_token_id" in hparams and hparams["unk_token_id"] != None:
+ gguf_writer.add_unk_token_id(hparams["unk_token_id"])
+
+ if "sep_token_id" in hparams and hparams["sep_token_id"] != None:
+ gguf_writer.add_sep_token_id(hparams["sep_token_id"])
+
+ if "pad_token_id" in hparams and hparams["pad_token_id"] != None:
+ gguf_writer.add_pad_token_id(hparams["pad_token_id"])
+
+ write_vocab_gguf(dir_model)
+
+ # tensor info
+ print("gguf: get tensor metadata")
+ for name in list_vars.keys():
+ data = list_vars[name].squeeze().numpy()
+ if 'inv_freq' in name:
+ print("Converting: %-75s" % name, " shape: %-15s" % str(data.shape))
+ continue
+
+ print("Converting: %-75s" % name, " shape: %-15s" % str(data.shape), end=" ")
+ n_dims = len(data.shape)
+
+ # ftype == 0 -> float32, ftype == 1 -> float16
+ ftype_cur = 0
+ if ftype != 0:
+ if name[-7:] == ".weight" and n_dims == 2:
+ print(" to float16".rjust(15))
+ data = data.astype(np.float16)
+ ftype_cur = 1
+ else:
+ print(" to float32".rjust(15))
+ data = data.astype(np.float32)
+ ftype_cur = 0
+ else:
+ if data.dtype != np.float32:
+ print(" to float32".rjust(15))
+ data = data.astype(np.float32)
+ ftype_cur = 0
+
+ gguf_writer.add_tensor(name, data)
+
+ if "mlp.dense_h_to_4h" in name:
+ name_0 = name.replace("dense_h_to_4h", "dense_h_to_4h_0")
+ name_1 = name.replace("dense_h_to_4h", "dense_h_to_4h_1")
+ shape_0 = data.shape[0]
+ half_shape_0 = int(shape_0 / 2)
+ data_0 = data[0:half_shape_0, :]
+ data_1 = data[half_shape_0:shape_0, :]
+
+ print("Converting: %-75s" % name_0, " shape: %-15s" % str(data_0.shape))
+ print("Converting: %-75s" % name_1, " shape: %-15s" % str(data_1.shape))
+
+ n_dims = len(data_0.shape)
+ assert (len(data_0.shape) == len(data_1.shape))
+ # ftype == 0 -> float32, ftype == 1 -> float16
+ ftype_cur = 0
+ if ftype != 0:
+ if name_0[-7:] == ".weight" and n_dims == 2:
+ print(" to float16".rjust(15))
+ data_0 = data_0.astype(np.float16)
+ data_1 = data_1.astype(np.float16)
+ ftype_cur = 1
+ else:
+ print(" to float32".rjust(15))
+ data_0 = data_0.astype(np.float32)
+ data_1 = data_1.astype(np.float32)
+ ftype_cur = 0
+ else:
+ if data_0.dtype != np.float32:
+ print(" to float32".rjust(15))
+ data_0 = data_0.astype(np.float32)
+ data_1 = data_1.astype(np.float32)
+ ftype_cur = 0
+
+ gguf_writer.add_tensor(name_0, data_0)
+ gguf_writer.add_tensor(name_1, data_1)
+
+ print("gguf: write header")
+ gguf_writer.write_header_to_file()
+ print("gguf: write metadata")
+ gguf_writer.write_kv_data_to_file()
+ print("gguf: write tensors")
+ gguf_writer.write_tensors_to_file()
+
+ gguf_writer.close()
+
+ print("Done. Output file: " + fname_out)
+ print("")
+
+
def chatglm2_convert_gguf(model, tokenizer, dir_model, fname_out, ftype, hparams):
print("ChatGLM-2.gguf converting: ")
list_vars = model.state_dict()
@@ -360,6 +527,156 @@ def write_vocab_gguf(dir_model):
print("")
+def chatglm3_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
+ print("ChatGLM-3 converting: ")
+ list_vars = model.state_dict()
+ for name in list_vars.keys():
+ print(name, list_vars[name].shape, list_vars[name].dtype)
+
+ fout = open(fname_out, "wb")
+
+ print(hparams)
+
+ fout.write(struct.pack("i", 0x67676d66))
+ fout.write(struct.pack("i", 1))
+
+ fout.write(struct.pack("i", hparams["padded_vocab_size"]))
+ fout.write(struct.pack("i", hparams["hidden_size"]))
+ fout.write(struct.pack("i", 0))
+ fout.write(struct.pack("i", hparams["num_attention_heads"]))
+ fout.write(struct.pack("i", 0))
+ fout.write(struct.pack("i", hparams["num_layers"]))
+ fout.write(struct.pack("i", 0))
+ fout.write(struct.pack("i", ftype))
+ fout.write(struct.pack("i", hparams["seq_length"]))
+ fout.write(struct.pack("f", 0))
+ fout.write(struct.pack("f", 0))
+ fout.write(struct.pack("i", 0))
+
+ fout.write(struct.pack("i", 0)) # word_embed_proj_dim (for opt)
+ fout.write(struct.pack("i", 0)) # do_layer_norm_before (for opt)
+
+ fout.write(struct.pack("i", hparams["multi_query_group_num"]))
+ fout.write(struct.pack("i", hparams["ffn_hidden_size"]))
+ fout.write(struct.pack("i", 0))
+ fout.write(struct.pack("i", 0)) # n_experts
+ fout.write(struct.pack("i", 0)) # n_expert_used
+ fout.write(struct.pack("f", hparams.get("layernorm_epsilon", 1e-5))) # rms_norm_eps or layer_norm_eps
+ fout.write(struct.pack("f", 10000.0)) # freq_base
+ fout.write(struct.pack("f", 1.0)) # rope_factor
+
+ fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled
+ fout.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings
+ fout.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0))
+
+ fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
+ fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
+ fout.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1))
+ fout.write(struct.pack("i", tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1))
+
+ tokenizer_path = Path(tokenizer.vocab_file).parent
+ vocab = load_vocab_for_glm2(Path(tokenizer_path))
+
+ counter = 0
+ for text, score in vocab.all_tokens():
+ fout.write(struct.pack("i", len(text)))
+ fout.write(text)
+ fout.write(struct.pack("f", score))
+ counter += 1
+
+ while counter < hparams["padded_vocab_size"]:
+ fout.write(struct.pack("i", len(text)))
+ fout.write(text)
+ fout.write(struct.pack("f", 0))
+ counter += 1
+
+ for name in list_vars.keys():
+ data = list_vars[name].squeeze().numpy()
+ print("Processing variable: " + name + " with shape: ", data.shape)
+ if 'inv_freq' in name:
+ continue
+
+ n_dims = len(data.shape)
+
+ # ftype == 0 -> float32, ftype == 1 -> float16
+ ftype_cur = 0
+ if ftype != 0:
+ if name[-7:] == ".weight" and n_dims == 2:
+ print(" Converting to float16")
+ data = data.astype(np.float16)
+ ftype_cur = 1
+ else:
+ print(" Converting to float32")
+ data = data.astype(np.float32)
+ ftype_cur = 0
+ else:
+ if data.dtype != np.float32:
+ print(" Converting to float32")
+ data = data.astype(np.float32)
+ ftype_cur = 0
+
+ # header
+ str = name.encode("utf-8")
+ fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
+ for i in range(n_dims):
+ fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
+ fout.write(str)
+ # data
+ data.tofile(fout)
+
+ if "mlp.dense_h_to_4h" in name:
+ name_0 = name.replace("dense_h_to_4h", "dense_h_to_4h_0")
+ name_1 = name.replace("dense_h_to_4h", "dense_h_to_4h_1")
+ shape_0 = data.shape[0]
+ half_shape_0 = int(shape_0 / 2)
+ data_0 = data[0:half_shape_0, :]
+ data_1 = data[half_shape_0:shape_0, :]
+
+ print("Converting: %-75s" % name_0, " shape: ", data_0.shape)
+ print("Converting: %-75s" % name_1, " shape: ", data_1.shape)
+
+ n_dims = len(data_0.shape)
+ assert (len(data_0.shape) == len(data_1.shape))
+ # ftype == 0 -> float32, ftype == 1 -> float16
+ ftype_cur = 0
+ if ftype != 0:
+ if name_0[-7:] == ".weight" and n_dims == 2:
+ print(" to float16".rjust(15))
+ data_0 = data_0.astype(np.float16)
+ data_1 = data_1.astype(np.float16)
+ ftype_cur = 1
+ else:
+ print(" to float32".rjust(15))
+ data_0 = data_0.astype(np.float32)
+ data_1 = data_1.astype(np.float32)
+ ftype_cur = 0
+ else:
+ if data_0.dtype != np.float32:
+ print(" to float32".rjust(15))
+ data_0 = data_0.astype(np.float32)
+ data_1 = data_1.astype(np.float32)
+ ftype_cur = 0
+
+ str_0 = name_0.encode("utf-8")
+ fout.write(struct.pack("iii", n_dims, len(str_0), ftype_cur))
+ for i in range(n_dims):
+ fout.write(struct.pack("i", data_0.shape[n_dims - 1 - i]))
+ fout.write(str_0)
+ data_0.tofile(fout)
+
+ str_1 = name_1.encode("utf-8")
+ fout.write(struct.pack("iii", n_dims, len(str_1), ftype_cur))
+ for i in range(n_dims):
+ fout.write(struct.pack("i", data_1.shape[n_dims - 1 - i]))
+ fout.write(str_1)
+ data_1.tofile(fout)
+
+ fout.close()
+
+ print("Done. Output file: " + fname_out)
+ print("")
+
+
def chatglm2_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
print("ChatGLM-2 converting: ")
list_vars = model.state_dict()
@@ -616,10 +933,15 @@ def chatglm1_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
def main(args_in: Optional[List[str]] = None) -> None:
parser = argparse.ArgumentParser(description="Convert a model to a NE compatible file")
- parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)")
+ parser.add_argument("--outtype",
+ choices=["f32", "f16"],
+ default="f32",
+ help="output format (default: based on input)")
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
- parser.add_argument("--model_hub", choices=["huggingface","modelscope"],
- default="huggingface", help="hub to load model")
+ parser.add_argument("--model_hub",
+ choices=["huggingface", "modelscope"],
+ default="huggingface",
+ help="hub to load model")
parser.add_argument("model", type=Path, help="directory containing model file")
parser.add_argument("--format",
type=str,
@@ -647,7 +969,15 @@ def main(args_in: Optional[List[str]] = None) -> None:
hparams = config.to_dict()
- if hasattr(model.config, "multi_query_attention"):
+ # ChatGLM3 shares the same architecture and model config as ChatGLM2,
+ # but its tokenizer additionally defines system-prompt special tokens,
+ # so we check for the <|system|> token to distinguish ChatGLM3 from ChatGLM2.
+ if "<|system|>" in tokenizer.tokenizer.special_tokens:
+ if args.format == "GGUF":
+ chatglm3_convert_gguf(model, tokenizer, dir_model, fname_out, ftype, hparams)
+ else:
+ chatglm3_convert(model, tokenizer, dir_model, fname_out, ftype, hparams)
+ elif hasattr(model.config, "multi_query_attention"):
if args.format == "GGUF":
chatglm2_convert_gguf(model, tokenizer, dir_model, fname_out, ftype, hparams)
else:
diff --git a/neural_speed/models/CMakeLists.txt b/neural_speed/models/CMakeLists.txt
index a0bcf1f1a..a1edeca6f 100644
--- a/neural_speed/models/CMakeLists.txt
+++ b/neural_speed/models/CMakeLists.txt
@@ -36,3 +36,4 @@ add_model(chatglm chatglm/chatglm.cpp chatglm/chatglm_utils.cpp ${MODEL_UTILS_SO
add_model(chatglm2 chatglm/chatglm2.cpp chatglm/chatglm2_utils.cpp ${MODEL_UTILS_SOURCE})
add_model(phi phi/phi.cpp phi/phi_utils.cpp ${MODEL_UTILS_SOURCE})
add_model(stablelm stablelm/stablelm.cpp stablelm/stablelm_utils.cpp ${MODEL_UTILS_SOURCE})
+add_model(chatglm3 chatglm/chatglm2.cpp chatglm/chatglm2_utils.cpp ${MODEL_UTILS_SOURCE})
diff --git a/neural_speed/models/model_utils/gguf.h b/neural_speed/models/model_utils/gguf.h
index 2abdfc7f7..22d0435ca 100644
--- a/neural_speed/models/model_utils/gguf.h
+++ b/neural_speed/models/model_utils/gguf.h
@@ -230,18 +230,30 @@ enum llm_arch {
LLM_ARCH_QWEN,
LLM_ARCH_CHATGLM,
LLM_ARCH_CHATGLM2,
+ LLM_ARCH_CHATGLM3,
LLM_ARCH_PHI,
LLM_ARCH_QWEN2,
LLM_ARCH_UNKNOWN,
};
-static std::map<llm_arch, const char*> LLM_ARCH_NAMES = {
- {LLM_ARCH_LLAMA, "llama"}, {LLM_ARCH_FALCON, "falcon"}, {LLM_ARCH_GPT2, "gpt2"},
- {LLM_ARCH_GPTJ, "gptj"}, {LLM_ARCH_GPTNEOX, "gptneox"}, {LLM_ARCH_MPT, "mpt"},
- {LLM_ARCH_BAICHUAN, "baichuan"}, {LLM_ARCH_STARCODER, "starcoder"}, {LLM_ARCH_PERSIMMON, "persimmon"},
- {LLM_ARCH_REFACT, "refact"}, {LLM_ARCH_BLOOM, "bloom"}, {LLM_ARCH_STABLELM, "stablelm"},
- {LLM_ARCH_QWEN, "qwen"}, {LLM_ARCH_CHATGLM, "chatglm"}, {LLM_ARCH_CHATGLM2, "chatglm2"},
- {LLM_ARCH_PHI, "phi"}, {LLM_ARCH_QWEN2, "qwen2"}};
+static std::map<llm_arch, const char*> LLM_ARCH_NAMES = {{LLM_ARCH_LLAMA, "llama"},
+ {LLM_ARCH_FALCON, "falcon"},
+ {LLM_ARCH_GPT2, "gpt2"},
+ {LLM_ARCH_GPTJ, "gptj"},
+ {LLM_ARCH_GPTNEOX, "gptneox"},
+ {LLM_ARCH_MPT, "mpt"},
+ {LLM_ARCH_BAICHUAN, "baichuan"},
+ {LLM_ARCH_STARCODER, "starcoder"},
+ {LLM_ARCH_PERSIMMON, "persimmon"},
+ {LLM_ARCH_REFACT, "refact"},
+ {LLM_ARCH_BLOOM, "bloom"},
+ {LLM_ARCH_STABLELM, "stablelm"},
+ {LLM_ARCH_QWEN, "qwen"},
+ {LLM_ARCH_CHATGLM, "chatglm"},
+ {LLM_ARCH_CHATGLM2, "chatglm2"},
+ {LLM_ARCH_CHATGLM3, "chatglm3"},
+ {LLM_ARCH_PHI, "phi"},
+ {LLM_ARCH_QWEN2, "qwen2"}};
struct gguf_tensor_info {
struct gguf_str name;
diff --git a/neural_speed/models/model_utils/model_types.h b/neural_speed/models/model_utils/model_types.h
index 619130004..3575a7e2b 100644
--- a/neural_speed/models/model_utils/model_types.h
+++ b/neural_speed/models/model_utils/model_types.h
@@ -79,6 +79,7 @@ enum model_archs {
MODEL_OPT,
MODEL_BLOOM,
MODEL_BAICHUAN,
+ MODEL_CHATGLM3,
MODEL_CHATGLM2,
MODEL_CHATGLM,
MODEL_QWEN,
@@ -485,7 +486,7 @@ class model_name_to_arch {
{"falcon", MODEL_FALCON}, {"bloom", MODEL_BLOOM}, {"chatglm2", MODEL_CHATGLM2},
{"chatglm", MODEL_CHATGLM}, {"baichuan", MODEL_BAICHUAN}, {"mistral", MODEL_LLAMA},
{"qwen", MODEL_QWEN}, {"phi", MODEL_PHI}, {"stablelm", MODEL_STABLELM},
- {"whisper", MODEL_WHISPER}, {"mixtral", MODEL_LLAMA}};
+ {"whisper", MODEL_WHISPER}, {"chatglm3", MODEL_CHATGLM3}, {"mixtral", MODEL_LLAMA}};
};
#ifdef __cplusplus
diff --git a/neural_speed/models/model_utils/model_utils.cpp b/neural_speed/models/model_utils/model_utils.cpp
index eb6c0b4f0..bb5b65fa7 100644
--- a/neural_speed/models/model_utils/model_utils.cpp
+++ b/neural_speed/models/model_utils/model_utils.cpp
@@ -143,7 +143,7 @@ static bool kv_cache_init(const struct model_hparams& hparams, struct model_kv_c
ne_set_name(cache.cossin, "cossin(-1)");
float freq_base = hparams.freq_base;
float theta = -1 * hparams.freq_scale;
- float theta_scale = (model != nullptr && model->arch == MODEL_CHATGLM2)
+ float theta_scale = (model != nullptr && (model->arch == MODEL_CHATGLM2 || model->arch == MODEL_CHATGLM3))
? std::pow(freq_base, -2.0f / (head_size / 2)) // chatglm2 has their DIM_SCALE of 2
: hparams.n_rot > 0 ? std::pow(freq_base, -2.0f / hparams.n_rot)
: std::pow(freq_base, -2.0f / head_size);
@@ -929,7 +929,7 @@ struct model_context* model_init_from_file(const char* path_model, struct model_
const auto& hparams = ctx->model.hparams;
if (params.shift_roped_k) {
- const std::array supported{MODEL_LLAMA, MODEL_GPTJ, MODEL_CHATGLM2};
+ const std::array supported{MODEL_LLAMA, MODEL_GPTJ, MODEL_CHATGLM2, MODEL_CHATGLM3};
NE_ASSERT(("Current model does not support shifting RoPE-ed K cache",
std::any_of(supported.cbegin(), supported.cend(), [arch](auto m) { return arch == m; })));
}
@@ -951,7 +951,8 @@ struct model_context* model_init_from_file(const char* path_model, struct model_
: NE_TYPE_COUNT;
NE_ASSERT(memory_type != NE_TYPE_COUNT);
- const bool kv_in_layers = (arch == MODEL_CHATGLM2 || arch == MODEL_CHATGLM || arch == MODEL_BAICHUAN);
+ const bool kv_in_layers =
+ (arch == MODEL_CHATGLM3 || arch == MODEL_CHATGLM2 || arch == MODEL_CHATGLM || arch == MODEL_BAICHUAN);
if (!kv_cache_init(ctx->model.hparams, ctx->model.kv_self, memory_type, ctx->n_ctx, ctx->max_request_num,
ctx->beam_size, params.shift_roped_k, (kv_in_layers ? &ctx->model : nullptr))) {
fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
diff --git a/scripts/inference.py b/scripts/inference.py
index c78940777..27dc8164c 100644
--- a/scripts/inference.py
+++ b/scripts/inference.py
@@ -168,6 +168,15 @@ def main(args_in: Optional[List[str]] = None) -> None:
token_ids_list = map(str, token_ids_list)
token_ids_str = ', '.join(token_ids_list)
cmd.extend(["--ids", token_ids_str])
+
+ elif (args.model_name == "chatglm3"):
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)
+ token_ids_tensor = tokenizer.build_chat_input(prompt_text)['input_ids']
+ token_ids_list = token_ids_tensor.tolist()[0]
+ token_ids_list = map(str, token_ids_list)
+ token_ids_str = ', '.join(token_ids_list)
+ cmd.extend(["--ids", token_ids_str])
+
elif (args.model_name == "baichuan" or args.model_name == "qwen"):
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)
diff --git a/scripts/run.py b/scripts/run.py
index 4552136e8..b615f1f26 100644
--- a/scripts/run.py
+++ b/scripts/run.py
@@ -41,11 +41,16 @@ def main(args_in: Optional[List[str]] = None) -> None:
# quantization related arguments.
parser.add_argument(
"--weight_dtype",
- choices=["int4", "int8", "fp8", "fp8_e5m2", "fp8_e4m3",
- "fp4", "fp4_e2m1", "nf4"],
+ choices=["int4", "int8", "fp8", "fp8_e5m2", "fp8_e4m3", "fp4", "fp4_e2m1", "nf4"],
help="Data type of quantized weight: int4/int8 (default int4)",
default="int4",
)
+ parser.add_argument(
+ "--model_type",
+ type=str,
+ help="Setting the model_type manually)",
+ default=None,
+ )
parser.add_argument(
"--alg",
type=str,
@@ -74,13 +79,11 @@ def main(args_in: Optional[List[str]] = None) -> None:
help="Data type of Gemm computation: int8/bf16/fp32 (default: int8)",
default="int8",
)
- parser.add_argument(
- "--format",
- type=str,
- default="NE",
- choices=["NE", "GGUF"],
- help="Convert to the GGUF or NE format"
- )
+ parser.add_argument("--format",
+ type=str,
+ default="NE",
+ choices=["NE", "GGUF"],
+ help="Convert to the GGUF or NE format")
parser.add_argument(
"--use_ggml",
action="store_true",
@@ -175,12 +178,19 @@ def main(args_in: Optional[List[str]] = None) -> None:
# Handles Missing token ID for gated models
except Exception as e:
if e.response.status_code == 401:
- print("You are required to input an access token ID for {}, please add it in option --token or download model weights locally".format(args.model))
+ print(
+ "You are required to input an access token ID for {}, please add it in option --token or download model weights locally"
+ .format(args.model))
sys.exit(f"{e}")
parent_path = Path(__file__).parent.absolute()
- config = AutoConfig.from_pretrained(dir_model)
- model_type = model_maps.get(config.model_type, config.model_type)
+ config = AutoConfig.from_pretrained(dir_model, trust_remote_code=True)
+
+ if args.model_type is None:
+ model_type = model_maps.get(config.model_type, config.model_type)
+ else:
+ model_type = args.model_type
+
work_path = Path(model_type + "_files")
if not work_path.exists():
Path.mkdir(work_path)
@@ -198,7 +208,8 @@ def main(args_in: Optional[List[str]] = None) -> None:
# 2. quantize
path = Path(parent_path, "quantize.py")
- quant_file = f"gguf_{model_type}_{args.weight_dtype}.gguf" if str(args.format) == "GGUF" else f"ne_{model_type}_{args.weight_dtype}.bin"
+ quant_file = f"gguf_{model_type}_{args.weight_dtype}.gguf" if str(
+ args.format) == "GGUF" else f"ne_{model_type}_{args.weight_dtype}.bin"
quant_cmd = ["python", path]
quant_cmd.extend(["--model_name", model_type])
quant_cmd.extend(["--model_file", Path(work_path, outfile + ".gguf" if str(args.format) == "GGUF" else outfile)])
@@ -220,7 +231,8 @@ def main(args_in: Optional[List[str]] = None) -> None:
infer_cmd.extend(["--model_name", model_type])
infer_cmd.extend(["-m", Path(work_path, quant_file)])
infer_cmd.extend(["--prompt", args.prompt])
- infer_cmd.extend(["--file", args.file])
+ if args.file is not None:
+ infer_cmd.extend(["--file", args.file])
infer_cmd.extend(["--n_predict", str(args.n_predict)])
infer_cmd.extend(["--threads", str(args.threads)])
infer_cmd.extend(["--batch_size_truncate", str(args.batch_size_truncate)])
@@ -232,7 +244,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
infer_cmd.extend(["--one_click_run", "True"])
if args.shift_roped_k:
infer_cmd.extend(["--shift-roped-k"])
- if (model_type == "baichuan" or model_type == "qwen"):
+ if (model_type == "baichuan" or model_type == "qwen" or model_type == "chatglm3"):
infer_cmd.extend(["--tokenizer", dir_model])
print("Inference model ...")
subprocess.run(infer_cmd)
diff --git a/tests/model-test/cpp_graph_inference.sh b/tests/model-test/cpp_graph_inference.sh
index 288cbe0cf..7702f3db6 100644
--- a/tests/model-test/cpp_graph_inference.sh
+++ b/tests/model-test/cpp_graph_inference.sh
@@ -146,6 +146,7 @@ model_name_map["starcoder-3b"]="bigcode/starcoder"
model_name_map["bloom-7b"]="bigscience/bloom-7b1"
model_name_map["opt-1.3b"]="facebook/opt-1.3b"
model_name_map["dolly-v2-3b"]="databricks/dolly-v2-3b"
+model_name_map["chatglm3"]="THUDM/chatglm3-6b"
model_name_map["chatglm2"]="THUDM/chatglm2-6b"
model_name_map["chatglm-6b"]="THUDM/chatglm-6b"
model_name_map["baichuan2-13b"]="baichuan-inc/Baichuan2-13B-Chat"
@@ -234,6 +235,13 @@ function main() {
convert_script="${convert_script}/convert_chatglm.py --format=GGUF"
infer_cmd="./build/bin/run_chatglm2"
input_list=(32 1024)
+ elif [[ "${model}" == "chatglm3-6b" ]]; then
+ quant_script="./build/bin/quant_chatglm3"
+ convert_script="${convert_script}/convert_chatglm.py"
+ infer_cmd="python $working_dir/scripts/inference.py"
+ extension=" --model_name chatglm3 --tokenizer $model_path"
+ requirements_file="$working_dir/neural_speed/models/requirements/chatglm-6b.sh"
+ input_list=(32 1024)
elif [[ "${model}" == "chatglm-6b" ]]; then
quant_script="./build/bin/quant_chatglm"
convert_script="${convert_script}/convert_chatglm.py"