From 8d5fe2df1e83d24b575ebc9904a64c140be7fccf Mon Sep 17 00:00:00 2001
From: Zhenzhong1 <109137058+Zhenzhong1@users.noreply.github.com>
Date: Fri, 15 Mar 2024 11:00:51 +0800
Subject: [PATCH] [Model Enhance] Add Baichuan-7B architecture and refactor
 Baichuan-13B. (#177)

---
 neural_speed/convert/convert_baichuan.py      | 124 ++++++++++++++++--
 neural_speed/models/baichuan/baichuan.cpp     |  22 +++-
 neural_speed/models/baichuan/baichuan.h       |   6 +
 .../models/baichuan/baichuan_utils.cpp        | 100 +++++++++-----
 neural_speed/models/model_utils/model_files.h |   4 +-
 5 files changed, 208 insertions(+), 48 deletions(-)

diff --git a/neural_speed/convert/convert_baichuan.py b/neural_speed/convert/convert_baichuan.py
index 85bcd9dc9..b6893e310 100644
--- a/neural_speed/convert/convert_baichuan.py
+++ b/neural_speed/convert/convert_baichuan.py
@@ -48,6 +48,7 @@ def bytes_to_unicode():
 
 
 class SentencePieceVocab:
+
     def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
         self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
         added_tokens: Dict[str, int]
@@ -116,8 +117,7 @@ def load_vocab_for_baichuan(path: Path) -> SentencePieceVocab:
     else:
         raise FileNotFoundError(
             f"Could not find tokenizer.model in {path} or its parent; if it's in another directory, \
-                pass the directory as --vocab-dir"
-        )
+                pass the directory as --vocab-dir")
     added_tokens_path = path.parent / "added_tokens.json"
     print(f"Loading vocab file {path}")
     return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None)
@@ -161,9 +161,112 @@ def baichuan13B_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
     fout.write(struct.pack("f", 10000.0))  # freq_base
     fout.write(struct.pack("f", 1.0))  # rope_factor

-    fout.write(struct.pack("f", 0.0))  # config.json "rope_scaling.factor", not enabled
-    fout.write(struct.pack("i", 0))  # rope_scaling.original_max_position_embeddings
-    fout.write(struct.pack("i", 0))  # params["rope_scaling"]["type"] =="yarn" else 0))
+    fout.write(struct.pack("f", 0.0))  # config.json "rope_scaling.factor", not enabled
+    fout.write(struct.pack("i", 0))  # rope_scaling.original_max_position_embeddings
+    fout.write(struct.pack("i", 0))  # params["rope_scaling"]["type"] =="yarn" else 0))
+
+    fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
+    fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
+    fout.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1))
+    fout.write(struct.pack("i", tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1))
+
+    tokenizer_path = Path(tokenizer.vocab_file).parent
+    vocab = load_vocab_for_baichuan(Path(tokenizer_path))
+    counter = 0
+    for text, score in vocab.all_tokens():
+        fout.write(struct.pack("i", len(text)))
+        fout.write(text)
+        fout.write(struct.pack("f", score))
+        counter += 1
+
+    while counter < hparams["vocab_size"]:
+        fout.write(struct.pack("i", len(text)))
+        fout.write(text)
+        fout.write(struct.pack("f", 0))
+        counter += 1
+
+    for name in list_vars.keys():
+        data = list_vars[name].squeeze().numpy()
+        print("Processing variable: " + name + " with shape: ", data.shape)
+        if 'inv_freq' in name:
+            continue
+
+        n_dims = len(data.shape)
+
+        # ftype == 0 -> float32, ftype == 1 -> float16
+        ftype_cur = 0
+        if ftype != 0:
+            if name[-7:] == ".weight" and n_dims == 2:
+                print(" Converting to float16")
+                data = data.astype(np.float16)
+                ftype_cur = 1
+            else:
+                print(" Converting to float32")
float32") + data = data.astype(np.float32) + ftype_cur = 0 + else: + if data.dtype != np.float32: + print(" Converting to float32") + data = data.astype(np.float32) + ftype_cur = 0 + + # header + str = name.encode("utf-8") + fout.write(struct.pack("iii", n_dims, len(str), ftype_cur)) + for i in range(n_dims): + fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) + fout.write(str) + + # data + data.tofile(fout) + + fout.close() + + print("Done. Output file: " + fname_out) + print("") + + +def baichuan7B_convert(model, tokenizer, dir_model, fname_out, ftype, hparams): + print("Baichuan-7B converting: ") + list_vars = model.state_dict() + for name in list_vars.keys(): + print(name, list_vars[name].shape, list_vars[name].dtype) + + fout = open(fname_out, "wb") + + print(hparams) + + fout.write(struct.pack("i", 0x67676d66)) + fout.write(struct.pack("i", 1)) + + fout.write(struct.pack("i", hparams["vocab_size"])) + fout.write(struct.pack("i", hparams["hidden_size"])) + fout.write(struct.pack("i", 0)) + fout.write(struct.pack("i", hparams["num_attention_heads"])) + fout.write(struct.pack("i", 0)) + fout.write(struct.pack("i", hparams["num_hidden_layers"])) + fout.write(struct.pack("i", 128)) + fout.write(struct.pack("i", ftype)) + fout.write(struct.pack("i", hparams["model_max_length"])) + fout.write(struct.pack("f", 0)) + fout.write(struct.pack("f", 0)) + fout.write(struct.pack("i", 0)) + + fout.write(struct.pack("i", 0)) # word_embed_proj_dim (for opt) + fout.write(struct.pack("i", 0)) # do_layer_norm_before (for opt) + + fout.write(struct.pack("i", 0)) + fout.write(struct.pack("i", 0)) + fout.write(struct.pack("i", hparams["intermediate_size"])) + fout.write(struct.pack("i", 0)) # n_experts + fout.write(struct.pack("i", 0)) # n_expert_used + fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms_norm_eps or layer_norm_eps + fout.write(struct.pack("f", 10000.0)) # freq_base + fout.write(struct.pack("f", 1.0)) # rope_factor + + fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled + fout.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings + fout.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0)) fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1)) fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2)) @@ -230,8 +333,10 @@ def main(args_in: Optional[List[str]] = None) -> None: parser = argparse.ArgumentParser(description="Convert a model to a NE compatible file") parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)") parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") - parser.add_argument("--model_hub", choices=["huggingface","modelscope"], - default="huggingface", help="hub to load model") + parser.add_argument("--model_hub", + choices=["huggingface", "modelscope"], + default="huggingface", + help="hub to load model") parser.add_argument("model", type=Path, help="directory containing model file") args = parser.parse_args(args_in) @@ -255,7 +360,10 @@ def main(args_in: Optional[List[str]] = None) -> None: hparams = config.to_dict() - baichuan13B_convert(model, tokenizer, dir_model, fname_out, ftype, hparams) + if hparams['hidden_size'] == 4096: + baichuan7B_convert(model, tokenizer, dir_model, fname_out, ftype, hparams) + else: + baichuan13B_convert(model, tokenizer, dir_model, fname_out, ftype, hparams) if __name__ == 
diff --git a/neural_speed/models/baichuan/baichuan.cpp b/neural_speed/models/baichuan/baichuan.cpp
index 00ae82d29..38bb2ae05 100644
--- a/neural_speed/models/baichuan/baichuan.cpp
+++ b/neural_speed/models/baichuan/baichuan.cpp
@@ -74,8 +74,14 @@ static bool baichuan_model_eval_internal(model_context* ctx, const model_input*
   int n_head = hparams.n_head;
   const int n_vocab = hparams.n_vocab;
   const int head_size = n_embd / n_head;
-  const int n_rot = n_embd / n_head / 2;
   const float attn_scale = 1.f / std::sqrt(head_size);
+  const int n_rot = hparams.n_rot;
+  int baichuan_version = 0;
+  if (hparams.n_embd == 4096) {
+    baichuan_version = 7;
+  } else {
+    baichuan_version = 13;
+  }

   bool enable_tp = false;
 #ifdef NS_TP_MODEL
@@ -131,6 +137,7 @@ static bool baichuan_model_eval_internal(model_context* ctx, const model_input*
   }

   struct ne_tensor* embd = d_ne_new_tensor_1d(ctx0, NE_TYPE_I32, N);
+
   for (int i = 0; i < batch_size; ++i) {
     memcpy(static_cast(embd->data) + i * N, (inputs + i)->tokens, N * ne_element_size(embd));
   }
@@ -152,7 +159,6 @@ static bool baichuan_model_eval_internal(model_context* ctx, const model_input*
     {
       // Linear::forward compute QKV
       cur = ne_mul_mat(ctx0, model.layers[il].attn[0], cur);
-
       ne_tensor* query_layer = ne_view_3d(ctx0, cur, head_size, n_head, N, head_size * ne_element_size(cur),
                                           cur->nb[1], 0);  // [N, hidden]

@@ -162,6 +168,12 @@ static bool baichuan_model_eval_internal(model_context* ctx, const model_input*
       ne_tensor* value_layer = ne_view_3d(ctx0, cur, head_size, n_head, N, head_size * ne_element_size(cur),
                                           cur->nb[1], 2 * hidden_size * ne_element_size(cur));  // [N, heads, head_size]

+      // using mode = 2 for neox mode
+      if (baichuan_version == 7) {
+        query_layer = ne_rope_inplace(ctx0, query_layer, n_past, n_rot, 2, 0, hparams.freq_base, hparams.freq_scale);
+        key_layer = ne_rope_inplace(ctx0, key_layer, n_past, n_rot, 2, 0, hparams.freq_base, hparams.freq_scale);
+      }
+
       if (!run_mha_reordered) {
         query_layer = ne_permute(ctx0, query_layer, 0, 2, 1, 3);  // [heads, N, head_size]
         key_layer = ne_permute(ctx0, key_layer, 0, 2, 1, 3);      // [heads, N, head_size]
@@ -193,7 +205,11 @@ static bool baichuan_model_eval_internal(model_context* ctx, const model_input*
       // attention
       struct ne_tensor* attn_scores = ne_mul_mat(ctx0, key_layer, query_layer);  // [heads, N, klen]
       attn_scores = ne_scale_inplace(ctx0, attn_scores, ne_new_f32(ctx0, attn_scale));
-      attn_scores = ne_alibi(ctx0, attn_scores, n_past, n_head, 8);
+
+      if (baichuan_version == 13) {
+        attn_scores = ne_alibi(ctx0, attn_scores, n_past, n_head, 8);
+      }
+
       if (n_past == 0) {
         attn_scores = ne_diag_mask_inf_inplace(ctx0, attn_scores, n_past);
       }
diff --git a/neural_speed/models/baichuan/baichuan.h b/neural_speed/models/baichuan/baichuan.h
index 2803d9617..df12fc7bc 100644
--- a/neural_speed/models/baichuan/baichuan.h
+++ b/neural_speed/models/baichuan/baichuan.h
@@ -31,6 +31,12 @@ static const model_scratch baichuan_mem_req(int n_layers, float scratch_size_rat
           static_cast(scratch_size_ratio * 2048) * MB,
           static_cast(scratch_size_ratio * 4096) * MB,
       };
+    case 32:
+      return {
+          static_cast(scratch_size_ratio * 4096) * MB,
+          static_cast(scratch_size_ratio * 2048) * MB,
+          static_cast(scratch_size_ratio * 4096) * MB,
+      };
     default:
       MODEL_ASSERT(false);
   }
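Note (illustrative, not part of the patch): the runtime changes above key everything off the embedding width. hparams.n_embd == 4096 selects the Baichuan-7B path, which applies rotary embeddings (ne_rope_inplace with mode 2, i.e. NeoX style) to the query and key views; any other width falls through to the Baichuan-13B path, which keeps the ALiBi bias on the attention scores, and the added case 32 in baichuan_mem_req covers the 7B model's 32 hidden layers. A small sketch of the same decision, driven from a model's config.json:

import json

def detect_baichuan_variant(config_path):
    # Mirrors baichuan_model_eval_internal(): hidden_size 4096 -> version 7, otherwise version 13.
    with open(config_path) as f:
        cfg = json.load(f)
    if cfg["hidden_size"] == 4096:
        return {"version": 7, "position_encoding": "rope (neox mode)"}
    return {"version": 13, "position_encoding": "alibi"}

print(detect_baichuan_variant("config.json"))  # path is illustrative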
diff --git a/neural_speed/models/baichuan/baichuan_utils.cpp b/neural_speed/models/baichuan/baichuan_utils.cpp
index 155bc1e61..9318e2580 100644
--- a/neural_speed/models/baichuan/baichuan_utils.cpp
+++ b/neural_speed/models/baichuan/baichuan_utils.cpp
@@ -63,15 +63,15 @@ void BAICHUAN::init(const char* path_model, model_context* ctx, int n_gpu_layer_
   model.hparams = ml->file_loaders.at(0)->hparams;
   model_file_version file_version = ml->file_loaders.at(0)->file_version;
   auto& hparams = model.hparams;
-  n_ff = 4 * hparams.n_embd;
   fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab);
   fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd);
   fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult);
   fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head);
   fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer);
   fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot);
-  fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff);
+  fprintf(stderr, "%s: n_ff = %u\n", __func__, hparams.ffn_hidden_size);
   fprintf(stderr, "%s: n_parts = %zu\n", __func__, ml->file_loaders.size());
+  fprintf(stderr, "%s: inner_hidden_size = %u\n", __func__, hparams.inner_hidden_size);
   n_embd = hparams.n_embd;
   n_vocab = hparams.n_vocab;
   n_layer = hparams.n_layer;
@@ -92,10 +92,6 @@ void BAICHUAN::load(model_context* ctx, model_progress_callback progress_callbac
   fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);

   const auto& hparams = model.hparams;
-  const int head_dim = n_embd / hparams.n_head;
-  const int kv_heads = hparams.n_head;  // 1 if MQA else hparams.n_head
-  const int kv_dim = kv_heads * head_dim;
-  const int max_len = 4096;

   // create the ne context
   lctx.model.buf.resize(ctx_size);
@@ -116,37 +112,71 @@ void BAICHUAN::load(model_context* ctx, model_progress_callback progress_callbac
   }

   ml->ne_ctx = ne_ctx;
-
-  model.others[0] = ml->get_tensor("model.embed_tokens.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
-  model.others[1] = ml->get_tensor("model.norm.weight", {n_embd}, NE_BACKEND_CPU);
-  model.others[2] = ml->get_tensor("lm_head.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
-  const int i_gpu_start = n_layer - n_gpu_layer;
-
-  model.layers.resize(n_layer);
   size_t vram_total = 0;
-  for (uint32_t i = 0; i < n_layer; ++i) {
-    const ne_backend backend = static_cast(i) < i_gpu_start ? NE_BACKEND_CPU : MODEL_BACKEND_OFFLOAD;
-    auto& layer = model.layers[i];
-    std::string layers_i = "model.layers." + std::to_string(i);
-    layer.norm[0] = ml->get_tensor(layers_i + ".input_layernorm.weight", {n_embd}, backend);
-
-    // qkv GEMM
-    layer.attn[0] = ml->get_tensor(layers_i + ".self_attn.W_pack.weight", {n_embd, 3 * n_embd}, backend);
-    layer.attn[1] = ml->get_tensor(layers_i + ".self_attn.o_proj.weight", {n_embd, n_embd}, backend);
-
-    layer.norm[1] = ml->get_tensor(layers_i + ".post_attention_layernorm.weight", {n_embd}, backend);
-    // ffn GEMM
-    layer.ffn[0] = ml->get_tensor(layers_i + ".mlp.gate_proj.weight",
-                                  {n_embd, uint32_t(model.hparams.inner_hidden_size)}, backend);
-
-    layer.ffn[1] = ml->get_tensor(layers_i + ".mlp.down_proj.weight",
-                                  {uint32_t(model.hparams.inner_hidden_size), n_embd}, backend);
-    layer.ffn[2] =
-        ml->get_tensor(layers_i + ".mlp.up_proj.weight", {n_embd, uint32_t(model.hparams.inner_hidden_size)}, backend);
-
-    layer.v_cache = nullptr;
-    layer.k_cache = nullptr;
+  if (ml->verify_tensor("token_embd.weight")) {  // for gguf
+    model.others[0] = ml->get_tensor("token_embd.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
+    model.others[1] = ml->get_tensor("output_norm.weight", {n_embd}, NE_BACKEND_CPU);
+    model.others[2] = ml->get_tensor("output.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
+    const int i_gpu_start = n_layer - n_gpu_layer;
+
+    model.layers.resize(n_layer);
+    for (uint32_t i = 0; i < n_layer; ++i) {
+      const ne_backend backend = static_cast(i) < i_gpu_start ? NE_BACKEND_CPU : MODEL_BACKEND_OFFLOAD;
+      auto& layer = model.layers[i];
+      std::string layers_i = "blk." + std::to_string(i);
+      layer.norm[0] = ml->get_tensor(layers_i + ".attn_norm.weight", {n_embd}, backend);
+
+      // qkv GEMM
+      std::string w_pack = "model.layers." + std::to_string(i);
+      layer.attn[0] = ml->get_tensor(w_pack + ".self_attn.W_pack.weight", {n_embd, 3 * n_embd}, backend);
+      layer.attn[1] = ml->get_tensor(layers_i + ".attn_output.weight", {n_embd, n_embd}, backend);
+
+      layer.norm[1] = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend);
+
+      // ffn GEMM
+      layer.ffn[0] =
+          ml->get_tensor(layers_i + ".ffn_gate.weight", {n_embd, uint32_t(model.hparams.ffn_hidden_size)}, backend);
+
+      layer.ffn[1] =
+          ml->get_tensor(layers_i + ".ffn_down.weight", {uint32_t(model.hparams.ffn_hidden_size), n_embd}, backend);
+      layer.ffn[2] =
+          ml->get_tensor(layers_i + ".ffn_up.weight", {n_embd, uint32_t(model.hparams.ffn_hidden_size)}, backend);
+
+      layer.v_cache = nullptr;
+      layer.k_cache = nullptr;
+    }
+  } else {
+    model.others[0] = ml->get_tensor("model.embed_tokens.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
+    model.others[1] = ml->get_tensor("model.norm.weight", {n_embd}, NE_BACKEND_CPU);
+    model.others[2] = ml->get_tensor("lm_head.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
+    const int i_gpu_start = n_layer - n_gpu_layer;
+
+    model.layers.resize(n_layer);
+    for (uint32_t i = 0; i < n_layer; ++i) {
+      const ne_backend backend = static_cast(i) < i_gpu_start ? NE_BACKEND_CPU : MODEL_BACKEND_OFFLOAD;
+      auto& layer = model.layers[i];
+      std::string layers_i = "model.layers." + std::to_string(i);
+      layer.norm[0] = ml->get_tensor(layers_i + ".input_layernorm.weight", {n_embd}, backend);
+
+      // qkv GEMM
+      layer.attn[0] = ml->get_tensor(layers_i + ".self_attn.W_pack.weight", {n_embd, 3 * n_embd}, backend);
+      layer.attn[1] = ml->get_tensor(layers_i + ".self_attn.o_proj.weight", {n_embd, n_embd}, backend);
+
+      layer.norm[1] = ml->get_tensor(layers_i + ".post_attention_layernorm.weight", {n_embd}, backend);
+
+      // ffn GEMM
+      layer.ffn[0] = ml->get_tensor(layers_i + ".mlp.gate_proj.weight",
+                                    {n_embd, uint32_t(model.hparams.inner_hidden_size)}, backend);
+
+      layer.ffn[1] = ml->get_tensor(layers_i + ".mlp.down_proj.weight",
+                                    {uint32_t(model.hparams.inner_hidden_size), n_embd}, backend);
+      layer.ffn[2] = ml->get_tensor(layers_i + ".mlp.up_proj.weight",
+                                    {n_embd, uint32_t(model.hparams.inner_hidden_size)}, backend);
+
+      layer.v_cache = nullptr;
+      layer.k_cache = nullptr;
+    }
   }

   // print memory requirements
diff --git a/neural_speed/models/model_utils/model_files.h b/neural_speed/models/model_utils/model_files.h
index 85d65891a..b702afc52 100644
--- a/neural_speed/models/model_utils/model_files.h
+++ b/neural_speed/models/model_utils/model_files.h
@@ -1090,7 +1090,7 @@ struct model_file_loader {
     printf("%-16s %d.hparams.n_head = %-30d\n", __func__, count++, hparams.n_head);
     printf("%-16s %d.hparams.n_head_kv = %-30d\n", __func__, count++, hparams.n_head_kv);
     printf("%-16s %d.hparams.n_layer = %-30d\n", __func__, count++, hparams.n_layer);
-    printf("%-16s %d.hparams.n_rot = %-30d\n", __func__, count++, hparams.n_vocab);
+    printf("%-16s %d.hparams.n_rot = %-30d\n", __func__, count++, hparams.n_rot);

     hparams.ftype = (enum ne_ftype)file.read_u32();
     hparams.max_seq_len = file.read_u32();
@@ -1122,7 +1122,7 @@ struct model_file_loader {
     file.read_raw(&hparams.norm_eps, sizeof(float));
     file.read_raw(&hparams.freq_base, sizeof(float));
     file.read_raw(&hparams.freq_scale, sizeof(float));
-    printf("%-16s %d.hparams.inner_hidden_size = %-30d\n", __func__, count++, hparams.inner_hidden_size);
+    printf("%-16s %d.hparams.norm_eps = %-30f\n", __func__, count++, hparams.norm_eps);
     printf("%-16s %d.hparams.freq_base = %-30.3f\n", __func__, count++, hparams.freq_base);
     printf("%-16s %d.hparams.freq_scale = %-30.3f\n", __func__, count++, hparams.freq_scale);
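Note (illustrative, not part of the patch): with these changes a single entry point covers both model sizes; main() routes on hidden_size exactly as sketched below. The transformers-based loading and the import path are assumptions about the surrounding script rather than something shown in the hunks above:

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
# Import path assumed; the functions are defined in neural_speed/convert/convert_baichuan.py.
from neural_speed.convert.convert_baichuan import baichuan7B_convert, baichuan13B_convert

def convert_baichuan(dir_model: str, fname_out: str, ftype: int = 0) -> None:
    config = AutoConfig.from_pretrained(dir_model, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(dir_model, trust_remote_code=True)
    hparams = config.to_dict()
    if hparams["hidden_size"] == 4096:
        # Baichuan-7B layout: RoPE at runtime, header written by baichuan7B_convert().
        baichuan7B_convert(model, tokenizer, dir_model, fname_out, ftype, hparams)
    else:
        # Baichuan-13B layout: ALiBi at runtime, header written by baichuan13B_convert().
        baichuan13B_convert(model, tokenizer, dir_model, fname_out, ftype, hparams)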