diff --git a/docs/supported_models.md b/docs/supported_models.md
index 05895a9f1..c4a658256 100644
--- a/docs/supported_models.md
+++ b/docs/supported_models.md
@@ -215,7 +215,9 @@ Neural Speed supports the following models:
Qwen-7B,
- Qwen-14B |
+ Qwen-14B,
+ Qwen1.5-7B,
+ Qwen1.5-0.5B
✅ |
|
|
@@ -358,6 +360,14 @@ Neural Speed supports the following models:
✅ |
|
+
+ TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUFF,
+ | ✅ |
+ ✅ |
+ ✅ |
+ ✅ |
+ |
+
TheBloke/SOLAR-10.7B-Instruct-v1.0-GGUF |
✅ |
@@ -410,7 +420,8 @@ Neural Speed supports the following models:
✅ |
- Qwen-7B-Chat |
+ Qwen-7B-Chat,
+ Qwen1.5-7B-Chat-GGUF |
✅ |
✅ |
✅ |
diff --git a/neural_speed/convert/convert_qwen.py b/neural_speed/convert/convert_qwen.py
index 704aa9ee6..7966717b2 100644
--- a/neural_speed/convert/convert_qwen.py
+++ b/neural_speed/convert/convert_qwen.py
@@ -100,9 +100,11 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("i", hparams["num_attention_heads"]))
fout.write(struct.pack("i", 0)) # multi-query attention
fout.write(struct.pack("i", hparams["num_hidden_layers"]))
- fout.write(struct.pack("i", hparams["kv_channels"]))
+ fout.write(struct.pack("i", hparams["kv_channels"] if "kv_channels" in hparams
+ else int(hparams["hidden_size"]/hparams["num_attention_heads"])))
fout.write(struct.pack("i", ftype))
- fout.write(struct.pack("i", hparams["seq_length"]))
+ fout.write(struct.pack("i", hparams["seq_length"] if "seq_length" in hparams
+ else hparams["max_position_embeddings"]))
fout.write(struct.pack("f", 0.0))
fout.write(struct.pack("f", 0.0))
fout.write(struct.pack("i", 0))
@@ -121,9 +123,10 @@ def main(args_in: Optional[List[str]] = None) -> None:
fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled
fout.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings
fout.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0))
-
- fout.write(struct.pack("i", tokenizer.special_tokens['<|endoftext|>']))
- fout.write(struct.pack("i", tokenizer.special_tokens['<|endoftext|>']))
+ fout.write(struct.pack("i", hparams["bos_token_id"] if hparams["bos_token_id"]
+ else tokenizer.special_tokens['<|endoftext|>']))
+ fout.write(struct.pack("i", hparams["eos_token_id"] if hparams["eos_token_id"]
+ else tokenizer.special_tokens['<|endoftext|>']))
fout.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1))
fout.write(struct.pack("i", tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1))
diff --git a/neural_speed/models/model_utils/gguf.h b/neural_speed/models/model_utils/gguf.h
index 251280fa3..2abdfc7f7 100644
--- a/neural_speed/models/model_utils/gguf.h
+++ b/neural_speed/models/model_utils/gguf.h
@@ -231,18 +231,17 @@ enum llm_arch {
LLM_ARCH_CHATGLM,
LLM_ARCH_CHATGLM2,
LLM_ARCH_PHI,
+ LLM_ARCH_QWEN2,
LLM_ARCH_UNKNOWN,
};
static std::map LLM_ARCH_NAMES = {
- {LLM_ARCH_LLAMA, "llama"}, {LLM_ARCH_FALCON, "falcon"},
- {LLM_ARCH_GPT2, "gpt2"}, {LLM_ARCH_GPTJ, "gptj"},
- {LLM_ARCH_GPTNEOX, "gptneox"}, {LLM_ARCH_MPT, "mpt"},
- {LLM_ARCH_BAICHUAN, "baichuan"}, {LLM_ARCH_STARCODER, "starcoder"},
- {LLM_ARCH_PERSIMMON, "persimmon"}, {LLM_ARCH_REFACT, "refact"},
- {LLM_ARCH_BLOOM, "bloom"}, {LLM_ARCH_STABLELM, "stablelm"},
- {LLM_ARCH_QWEN, "qwen"}, {LLM_ARCH_CHATGLM, "chatglm"},
- {LLM_ARCH_CHATGLM2, "chatglm2"}, {LLM_ARCH_PHI, "phi"}};
+ {LLM_ARCH_LLAMA, "llama"}, {LLM_ARCH_FALCON, "falcon"}, {LLM_ARCH_GPT2, "gpt2"},
+ {LLM_ARCH_GPTJ, "gptj"}, {LLM_ARCH_GPTNEOX, "gptneox"}, {LLM_ARCH_MPT, "mpt"},
+ {LLM_ARCH_BAICHUAN, "baichuan"}, {LLM_ARCH_STARCODER, "starcoder"}, {LLM_ARCH_PERSIMMON, "persimmon"},
+ {LLM_ARCH_REFACT, "refact"}, {LLM_ARCH_BLOOM, "bloom"}, {LLM_ARCH_STABLELM, "stablelm"},
+ {LLM_ARCH_QWEN, "qwen"}, {LLM_ARCH_CHATGLM, "chatglm"}, {LLM_ARCH_CHATGLM2, "chatglm2"},
+ {LLM_ARCH_PHI, "phi"}, {LLM_ARCH_QWEN2, "qwen2"}};
struct gguf_tensor_info {
struct gguf_str name;
diff --git a/neural_speed/models/qwen/qwen.cpp b/neural_speed/models/qwen/qwen.cpp
index 837c09021..2aa65eb2d 100644
--- a/neural_speed/models/qwen/qwen.cpp
+++ b/neural_speed/models/qwen/qwen.cpp
@@ -102,6 +102,12 @@ static bool qwen_model_eval_internal(model_context* ctx, const model_input* inpu
const int n_vocab = hparams.n_vocab;
const int n_rot = hparams.n_rot;
const int head_dim = n_embd / n_head;
+ int qwen_version = 0;
+ if (hparams.max_seq_len == 8192) {
+ qwen_version = 1;
+ } else {
+ qwen_version = 2;
+ }
auto& mem_per_token = lctx.mem_per_token;
auto& buf_compute = lctx.buf_compute;
@@ -164,20 +170,36 @@ static bool qwen_model_eval_internal(model_context* ctx, const model_input* inpu
}
// compute QKV
- {
+ struct ne_tensor* Qcur;
+ struct ne_tensor* Kcur;
+ struct ne_tensor* Vcur;
+
+ if (qwen_version == 1) {
cur = ne_mul_mat(ctx0, model.layers[il].attn[0], cur);
cur = ne_add(ctx0, ne_repeat(ctx0, model.layers[il].attn[1], cur), cur);
+ size_t fused_qkv_row_nb = (3 * n_embd) * sizeof(float);
+ Qcur = ne_cont(ctx0, ne_view_3d(ctx0, cur, head_dim, n_head, N, head_dim * sizeof(float), fused_qkv_row_nb,
+ 0 * sizeof(float) * n_embd));
+ // head_dim, n_head, N --> head_dim, N, n_head
+ Kcur = ne_cont(ctx0, ne_view_3d(ctx0, cur, head_dim, n_head, N, head_dim * sizeof(float), fused_qkv_row_nb,
+ 1 * sizeof(float) * n_embd));
+ // head_dim, n_head, N --> N, head_dim, n_head
+ Vcur = ne_cont(ctx0, ne_view_3d(ctx0, cur, head_dim, n_head, N, head_dim * sizeof(float), fused_qkv_row_nb,
+ 2 * sizeof(float) * n_embd));
+ } else {
+ Qcur = ne_mul_mat(ctx0, model.layers[il].attn[0], cur);
+ Qcur = ne_add(ctx0, ne_repeat(ctx0, model.layers[il].attn[1], Qcur), Qcur);
+ Qcur = ne_reshape_3d(ctx0, Qcur, head_dim, n_head, N);
+
+ Kcur = ne_mul_mat(ctx0, model.layers[il].attn[2], cur);
+ Kcur = ne_add(ctx0, ne_repeat(ctx0, model.layers[il].attn[3], Kcur), Kcur);
+ Kcur = ne_reshape_3d(ctx0, Kcur, head_dim, n_head, N);
+
+ Vcur = ne_mul_mat(ctx0, model.layers[il].attn[4], cur);
+ Vcur = ne_add(ctx0, ne_repeat(ctx0, model.layers[il].attn[5], Vcur), Vcur);
+ Vcur = ne_reshape_3d(ctx0, Vcur, head_dim, n_head, N);
}
- size_t fused_qkv_row_nb = (3 * n_embd) * sizeof(float);
- struct ne_tensor* Qcur = ne_cont(ctx0, ne_view_3d(ctx0, cur, head_dim, n_head, N, head_dim * sizeof(float),
- fused_qkv_row_nb, 0 * sizeof(float) * n_embd));
- // head_dim, n_head, N --> head_dim, N, n_head
- struct ne_tensor* Kcur = ne_cont(ctx0, ne_view_3d(ctx0, cur, head_dim, n_head, N, head_dim * sizeof(float),
- fused_qkv_row_nb, 1 * sizeof(float) * n_embd));
- // head_dim, n_head, N --> N, head_dim, n_head
- struct ne_tensor* Vcur = ne_cont(ctx0, ne_view_3d(ctx0, cur, head_dim, n_head, N, head_dim * sizeof(float),
- fused_qkv_row_nb, 2 * sizeof(float) * n_embd));
// using mode = 2 for GPT-NeoX mode
Qcur = ne_rope_inplace(ctx0, Qcur, n_past, n_rot, 2, 0, hparams.freq_base, hparams.freq_scale);
@@ -300,7 +322,11 @@ static bool qwen_model_eval_internal(model_context* ctx, const model_input* inpu
cur = ne_view_2d(ctx0, KQV_Out, n_embd, N, n_embd * ne_element_size(KQV_Out), 0);
}
// projection
- { cur = ne_mul_mat(ctx0, model.layers[il].attn[2], cur); }
+ if (qwen_version == 1) {
+ cur = ne_mul_mat(ctx0, model.layers[il].attn[2], cur);
+ } else {
+ cur = ne_mul_mat(ctx0, model.layers[il].attn[6], cur);
+ }
}
lctx.use_buf(ctx0, 1);
diff --git a/neural_speed/models/qwen/qwen_utils.cpp b/neural_speed/models/qwen/qwen_utils.cpp
index ae618a04b..2ea38492e 100644
--- a/neural_speed/models/qwen/qwen_utils.cpp
+++ b/neural_speed/models/qwen/qwen_utils.cpp
@@ -52,13 +52,16 @@ void QWEN::init(const char* path_model, model_context* ctx, int n_gpu_layer_, bo
model.hparams = ml->file_loaders.at(0)->hparams;
model_file_version file_version = ml->file_loaders.at(0)->file_version;
auto& hparams = model.hparams;
- n_ff = hparams.ffn_hidden_size / 2;
+ n_ff = hparams.ffn_hidden_size;
+ if (hparams.max_seq_len == 8192) {
+ n_ff = n_ff / 2;
+ }
fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab);
fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd);
fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head);
fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer);
fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot);
- fprintf(stderr, "%s: n_ff = %u\n", __func__, hparams.ffn_hidden_size / 2);
+ fprintf(stderr, "%s: n_ff = %u\n", __func__, hparams.ffn_hidden_size);
fprintf(stderr, "%s: n_parts = %zu\n", __func__, ml->file_loaders.size());
n_embd = hparams.n_embd;
n_vocab = hparams.n_vocab;
@@ -102,7 +105,7 @@ void QWEN::load(model_context* ctx, model_progress_callback progress_callback, v
model.layers.resize(n_layer);
size_t vram_total = 0;
- if (ml->verify_tensor("token_embd.weight")) {
+ if (ml->verify_tensor("token_embd.weight")) { // gguf
model.others[0] = ml->get_tensor("token_embd.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
model.others[1] = ml->get_tensor("output_norm.weight", {n_embd}, NE_BACKEND_CPU);
model.others[2] = ml->get_tensor("output.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
@@ -117,16 +120,26 @@ void QWEN::load(model_context* ctx, model_progress_callback progress_callback, v
layer.norm[1] = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend);
// qkv GEMM
- layer.attn[0] = ml->get_tensor(layers_i + ".attn_qkv.weight", {n_embd, 3 * n_embd}, backend);
- layer.attn[1] = ml->get_tensor(layers_i + ".attn_qkv.bias", {3 * n_embd}, backend);
- layer.attn[2] = ml->get_tensor(layers_i + ".attn_output.weight", {n_embd, n_embd}, backend);
+ if (ml->verify_tensor(layers_i + ".attn_qkv.weight")) {
+ layer.attn[0] = ml->get_tensor(layers_i + ".attn_qkv.weight", {n_embd, 3 * n_embd}, backend);
+ layer.attn[1] = ml->get_tensor(layers_i + ".attn_qkv.bias", {3 * n_embd}, backend);
+ layer.attn[2] = ml->get_tensor(layers_i + ".attn_output.weight", {n_embd, n_embd}, backend);
+ } else { // qwen2 gguf
+ layer.attn[0] = ml->get_tensor(layers_i + ".attn_q.weight", {n_embd, n_embd}, backend);
+ layer.attn[1] = ml->get_tensor(layers_i + ".attn_q.bias", {n_embd}, backend);
+ layer.attn[2] = ml->get_tensor(layers_i + ".attn_k.weight", {n_embd, n_embd}, backend);
+ layer.attn[3] = ml->get_tensor(layers_i + ".attn_k.bias", {n_embd}, backend);
+ layer.attn[4] = ml->get_tensor(layers_i + ".attn_v.weight", {n_embd, n_embd}, backend);
+ layer.attn[5] = ml->get_tensor(layers_i + ".attn_v.bias", {n_embd}, backend);
+ layer.attn[6] = ml->get_tensor(layers_i + ".attn_output.weight", {n_embd, n_embd}, backend);
+ }
// ffn GEMM
layer.ffn[0] = ml->get_tensor(layers_i + ".ffn_up.weight", {n_embd, n_ff}, backend);
layer.ffn[1] = ml->get_tensor(layers_i + ".ffn_gate.weight", {n_embd, n_ff}, backend);
layer.ffn[2] = ml->get_tensor(layers_i + ".ffn_down.weight", {n_ff, n_embd}, backend);
}
- } else {
+ } else if (ml->verify_tensor("transformer.wte.weight")) { // qwen1 bin
model.others[0] = ml->get_tensor("transformer.wte.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
model.others[1] = ml->get_tensor("transformer.ln_f.weight", {n_embd}, NE_BACKEND_CPU);
model.others[2] = ml->get_tensor("lm_head.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
@@ -150,6 +163,34 @@ void QWEN::load(model_context* ctx, model_progress_callback progress_callback, v
layer.ffn[1] = ml->get_tensor(layers_i + ".mlp.w2.weight", {n_embd, n_ff}, backend);
layer.ffn[2] = ml->get_tensor(layers_i + ".mlp.c_proj.weight", {n_ff, n_embd}, backend);
}
+ } else { // qwen2 bin
+ model.others[0] = ml->get_tensor("model.embed_tokens.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
+ model.others[1] = ml->get_tensor("model.norm.weight", {n_embd}, NE_BACKEND_CPU);
+ model.others[2] = ml->get_tensor("lm_head.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
+
+ for (uint32_t i = 0; i < n_layer; ++i) {
+ const ne_backend backend = static_cast(i) < i_gpu_start ? NE_BACKEND_CPU : MODEL_BACKEND_OFFLOAD;
+ auto& layer = model.layers[i];
+ std::string layers_i = "model.layers." + std::to_string(i);
+
+ // norm: cur = ln_1_g*cur + ln_1_b
+ layer.norm[0] = ml->get_tensor(layers_i + ".input_layernorm.weight", {n_embd}, backend);
+ layer.norm[1] = ml->get_tensor(layers_i + ".post_attention_layernorm.weight", {n_embd}, backend);
+
+ // qkv GEMM + out proj GEMM
+ layer.attn[0] = ml->get_tensor(layers_i + ".self_attn.q_proj.weight", {n_embd, n_embd}, backend);
+ layer.attn[1] = ml->get_tensor(layers_i + ".self_attn.q_proj.bias", {n_embd}, backend);
+ layer.attn[2] = ml->get_tensor(layers_i + ".self_attn.k_proj.weight", {n_embd, n_embd}, backend);
+ layer.attn[3] = ml->get_tensor(layers_i + ".self_attn.k_proj.bias", {n_embd}, backend);
+ layer.attn[4] = ml->get_tensor(layers_i + ".self_attn.v_proj.weight", {n_embd, n_embd}, backend);
+ layer.attn[5] = ml->get_tensor(layers_i + ".self_attn.v_proj.bias", {n_embd}, backend);
+ layer.attn[6] = ml->get_tensor(layers_i + ".self_attn.o_proj.weight", {n_embd, n_embd}, backend);
+
+ // ffn GEMM
+ layer.ffn[0] = ml->get_tensor(layers_i + ".mlp.up_proj.weight", {n_embd, n_ff}, backend);
+ layer.ffn[1] = ml->get_tensor(layers_i + ".mlp.gate_proj.weight", {n_embd, n_ff}, backend);
+ layer.ffn[2] = ml->get_tensor(layers_i + ".mlp.down_proj.weight", {n_ff, n_embd}, backend);
+ }
}
// print memory requirements
@@ -180,7 +221,7 @@ class qwen_quant_layer : public quant_layer_base {
public:
quant_params_internal get_layer_config(std::string layername, std::vector ne, ne_type type) override {
bool quantize = layername.rfind("weight") == layername.size() - 6; // ends with 'weight'?
- if (layername == "transformer.wte.weight") {
+ if (layername == "transformer.wte.weight" || layername == "model.embed_tokens.weight") {
// special layer process, can be loaded by config file
return quant_params_internal(); // return q4_0 to cover the usage of getrow
}
diff --git a/tests/model-test/cpp_graph_inference.sh b/tests/model-test/cpp_graph_inference.sh
index 500fb37fd..3abccb082 100644
--- a/tests/model-test/cpp_graph_inference.sh
+++ b/tests/model-test/cpp_graph_inference.sh
@@ -155,8 +155,10 @@ model_name_map["qwen-7b"]="Qwen/Qwen-7B-Chat"
model_name_map["magicoder"]="ise-uiuc/Magicoder-S-DS-6.7B"
model_name_map["whisper"]="openai/whisper-tiny"
model_name_map["phi2"]="microsoft/phi-2"
+model_name_map["qwen-1_5"]="Qwen/Qwen1.5-7B-Chat"
model_name_map["mixtral"]="mistralai/Mixtral-8x7B-Instruct-v0.1"
+
function main() {
conda_env="$1"
model="$2"
@@ -251,6 +253,10 @@ function main() {
quant_script="./build/bin/quant_qwen"
convert_script="${convert_script}/convert_qwen.py"
infer_cmd="./build/bin/run_qwen"
+ elif [[ "${model}" == "qwen-1_5" ]]; then
+ quant_script="./build/bin/quant_qwen"
+ convert_script="${convert_script}/convert_qwen.py"
+ infer_cmd="./build/bin/run_qwen"
elif [[ "${model}" == "magicoder" ]]; then
quant_script="./build/bin/quant_llama"
convert_script="${convert_script}/convert_llama.py"