This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

[Model Enhance] Add Baichuan-7B architecture and refactor Baichuan-13B. (#177)
Zhenzhong1 authored Mar 15, 2024
1 parent eed9b30 commit 8d5fe2d
Showing 5 changed files with 208 additions and 48 deletions.
124 changes: 116 additions & 8 deletions neural_speed/convert/convert_baichuan.py
@@ -48,6 +48,7 @@ def bytes_to_unicode():


class SentencePieceVocab:

def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None:
self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer))
added_tokens: Dict[str, int]
@@ -116,8 +117,7 @@ def load_vocab_for_baichuan(path: Path) -> SentencePieceVocab:
else:
raise FileNotFoundError(
f"Could not find tokenizer.model in {path} or its parent; if it's in another directory, \
pass the directory as --vocab-dir"
)
pass the directory as --vocab-dir")
added_tokens_path = path.parent / "added_tokens.json"
print(f"Loading vocab file {path}")
return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None)
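
For reference, a minimal usage sketch of the loader above (the checkpoint directory name is hypothetical; all_tokens() yields (token bytes, score) pairs, as consumed by the converters below):

    vocab = load_vocab_for_baichuan(Path("Baichuan-13B-Chat"))  # finds tokenizer.model here or in the parent dir
    for text, score in vocab.all_tokens():
        print(len(text), score)
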
@@ -161,9 +161,112 @@ def baichuan13B_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor

fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled
fout.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings
fout.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0))
fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled
fout.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings
fout.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0))

fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
fout.write(struct.pack("i", tokenizer.pad_token_id if tokenizer.pad_token_id is not None else -1))
fout.write(struct.pack("i", tokenizer.sep_token_id if tokenizer.sep_token_id is not None else -1))

tokenizer_path = Path(tokenizer.vocab_file).parent
vocab = load_vocab_for_baichuan(Path(tokenizer_path))
counter = 0
for text, score in vocab.all_tokens():
fout.write(struct.pack("i", len(text)))
fout.write(text)
fout.write(struct.pack("f", score))
counter += 1

while counter < hparams["vocab_size"]:
fout.write(struct.pack("i", len(text)))
fout.write(text)
fout.write(struct.pack("f", 0))
counter += 1

for name in list_vars.keys():
data = list_vars[name].squeeze().numpy()
print("Processing variable: " + name + " with shape: ", data.shape)
if 'inv_freq' in name:
continue

n_dims = len(data.shape)

# ftype == 0 -> float32, ftype == 1 -> float16
ftype_cur = 0
if ftype != 0:
if name[-7:] == ".weight" and n_dims == 2:
print(" Converting to float16")
data = data.astype(np.float16)
ftype_cur = 1
else:
print(" Converting to float32")
data = data.astype(np.float32)
ftype_cur = 0
else:
if data.dtype != np.float32:
print(" Converting to float32")
data = data.astype(np.float32)
ftype_cur = 0

# header
str = name.encode("utf-8")
fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
for i in range(n_dims):
fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
fout.write(str)

# data
data.tofile(fout)

fout.close()

print("Done. Output file: " + fname_out)
print("")


def baichuan7B_convert(model, tokenizer, dir_model, fname_out, ftype, hparams):
print("Baichuan-7B converting: ")
list_vars = model.state_dict()
for name in list_vars.keys():
print(name, list_vars[name].shape, list_vars[name].dtype)

fout = open(fname_out, "wb")

print(hparams)

fout.write(struct.pack("i", 0x67676d66))
fout.write(struct.pack("i", 1))

fout.write(struct.pack("i", hparams["vocab_size"]))
fout.write(struct.pack("i", hparams["hidden_size"]))
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", hparams["num_attention_heads"]))
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", hparams["num_hidden_layers"]))
fout.write(struct.pack("i", 128))
fout.write(struct.pack("i", ftype))
fout.write(struct.pack("i", hparams["model_max_length"]))
fout.write(struct.pack("f", 0))
fout.write(struct.pack("f", 0))
fout.write(struct.pack("i", 0))

fout.write(struct.pack("i", 0)) # word_embed_proj_dim (for opt)
fout.write(struct.pack("i", 0)) # do_layer_norm_before (for opt)

fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", 0))
fout.write(struct.pack("i", hparams["intermediate_size"]))
fout.write(struct.pack("i", 0)) # n_experts
fout.write(struct.pack("i", 0)) # n_expert_used
fout.write(struct.pack("f", hparams.get("rms_norm_eps", 1e-6))) # rms_norm_eps or layer_norm_eps
fout.write(struct.pack("f", 10000.0)) # freq_base
fout.write(struct.pack("f", 1.0)) # rope_factor

fout.write(struct.pack("f", 0.0)) # config.json "rope_scaling.factor", not enabled
fout.write(struct.pack("i", 0)) # rope_scaling.original_max_position_embeddings
fout.write(struct.pack("i", 0)) # params["rope_scaling"]["type"] =="yarn" else 0))

fout.write(struct.pack("i", tokenizer.bos_token_id if tokenizer.bos_token_id is not None else 1))
fout.write(struct.pack("i", tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2))
@@ -230,8 +333,10 @@ def main(args_in: Optional[List[str]] = None) -> None:
parser = argparse.ArgumentParser(description="Convert a model to a NE compatible file")
parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)")
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
parser.add_argument("--model_hub", choices=["huggingface","modelscope"],
default="huggingface", help="hub to load model")
parser.add_argument("--model_hub",
choices=["huggingface", "modelscope"],
default="huggingface",
help="hub to load model")
parser.add_argument("model", type=Path, help="directory containing model file")
args = parser.parse_args(args_in)
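
A typical invocation of the converter, using the arguments defined above (paths are illustrative):

    python neural_speed/convert/convert_baichuan.py --outtype f16 --outfile baichuan-f16.bin /path/to/Baichuan-7B
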

@@ -255,7 +360,10 @@ def main(args_in: Optional[List[str]] = None) -> None:

hparams = config.to_dict()

baichuan13B_convert(model, tokenizer, dir_model, fname_out, ftype, hparams)
if hparams['hidden_size'] == 4096:
baichuan7B_convert(model, tokenizer, dir_model, fname_out, ftype, hparams)
else:
baichuan13B_convert(model, tokenizer, dir_model, fname_out, ftype, hparams)


if __name__ == '__main__':
22 changes: 19 additions & 3 deletions neural_speed/models/baichuan/baichuan.cpp
@@ -74,8 +74,14 @@ static bool baichuan_model_eval_internal(model_context* ctx, const model_input*
int n_head = hparams.n_head;
const int n_vocab = hparams.n_vocab;
const int head_size = n_embd / n_head;
const int n_rot = n_embd / n_head / 2;
const float attn_scale = 1.f / std::sqrt(head_size);
const int n_rot = hparams.n_rot;
int baichuan_version = 0;
if (hparams.n_embd == 4096) {
baichuan_version = 7;
} else {
baichuan_version = 13;
}

bool enable_tp = false;
#ifdef NS_TP_MODEL
@@ -131,6 +137,7 @@ static bool baichuan_model_eval_internal(model_context* ctx, const model_input*
}

struct ne_tensor* embd = d_ne_new_tensor_1d(ctx0, NE_TYPE_I32, N);

for (int i = 0; i < batch_size; ++i) {
memcpy(static_cast<model_token*>(embd->data) + i * N, (inputs + i)->tokens, N * ne_element_size(embd));
}
@@ -152,7 +159,6 @@ static bool baichuan_model_eval_internal(model_context* ctx, const model_input*
{
// Linear::forward compute QKV
cur = ne_mul_mat(ctx0, model.layers[il].attn[0], cur);

ne_tensor* query_layer = ne_view_3d(ctx0, cur, head_size, n_head, N, head_size * ne_element_size(cur), cur->nb[1],
0); // [N, hidden]

@@ -162,6 +168,12 @@ static bool baichuan_model_eval_internal(model_context* ctx, const model_input*
ne_tensor* value_layer = ne_view_3d(ctx0, cur, head_size, n_head, N, head_size * ne_element_size(cur), cur->nb[1],
2 * hidden_size * ne_element_size(cur)); // [N, heads, head_size]

// using mode = 2 for neox mode
if (baichuan_version == 7) {
query_layer = ne_rope_inplace(ctx0, query_layer, n_past, n_rot, 2, 0, hparams.freq_base, hparams.freq_scale);
key_layer = ne_rope_inplace(ctx0, key_layer, n_past, n_rot, 2, 0, hparams.freq_base, hparams.freq_scale);
}

if (!run_mha_reordered) {
query_layer = ne_permute(ctx0, query_layer, 0, 2, 1, 3); // [heads, N, head_size]
key_layer = ne_permute(ctx0, key_layer, 0, 2, 1, 3); // [heads, N, head_size]
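
For intuition, a rough numpy sketch of the NeoX-style (mode 2) rotary embedding applied above to the first n_rot dimensions of every head; this is an illustration under simplified assumptions (freq_scale ignored), not the ne_rope_inplace kernel:

    import numpy as np

    def neox_rope(x, n_past, n_rot, freq_base=10000.0):
        # x: [N, n_head, head_size]; rotate dimension pairs (i, i + n_rot/2) within the first n_rot dims
        seq = x.shape[0]
        half = n_rot // 2
        pos = np.arange(n_past, n_past + seq, dtype=np.float32)
        inv_freq = freq_base ** (-2.0 * np.arange(half, dtype=np.float32) / n_rot)
        theta = np.outer(pos, inv_freq)                       # [seq, half]
        cos, sin = np.cos(theta)[:, None, :], np.sin(theta)[:, None, :]
        out = x.copy()
        x1, x2 = x[..., :half], x[..., half:n_rot]
        out[..., :half] = x1 * cos - x2 * sin
        out[..., half:n_rot] = x1 * sin + x2 * cos
        return out
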
@@ -193,7 +205,11 @@ static bool baichuan_model_eval_internal(model_context* ctx, const model_input*
// attention
struct ne_tensor* attn_scores = ne_mul_mat(ctx0, key_layer, query_layer); // [heads, N, klen]
attn_scores = ne_scale_inplace(ctx0, attn_scores, ne_new_f32(ctx0, attn_scale));
attn_scores = ne_alibi(ctx0, attn_scores, n_past, n_head, 8);

if (baichuan_version == 13) {
attn_scores = ne_alibi(ctx0, attn_scores, n_past, n_head, 8);
}

if (n_past == 0) {
attn_scores = ne_diag_mask_inf_inplace(ctx0, attn_scores, n_past);
}
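
For the 13B path, a rough numpy sketch of an ALiBi-style bias (the conventional formulation with bias_max = 8, matching the last argument of ne_alibi above; the exact sign and slope conventions inside the kernel may differ, so treat this as illustrative only):

    import numpy as np

    def alibi_bias(n_head, n_past, n_cur, bias_max=8.0):
        slopes = 2.0 ** (-bias_max * np.arange(1, n_head + 1, dtype=np.float32) / n_head)  # one slope per head
        key_pos = np.arange(n_past + n_cur, dtype=np.float32)       # cached + current key positions
        return slopes[:, None, None] * key_pos[None, None, :]       # [heads, 1, klen], added to attn_scores
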
6 changes: 6 additions & 0 deletions neural_speed/models/baichuan/baichuan.h
@@ -31,6 +31,12 @@ static const model_scratch baichuan_mem_req(int n_layers, float scratch_size_rat
static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
};
case 32:
return {
static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
};
default:
MODEL_ASSERT(false);
}
100 changes: 65 additions & 35 deletions neural_speed/models/baichuan/baichuan_utils.cpp
Expand Up @@ -63,15 +63,15 @@ void BAICHUAN::init(const char* path_model, model_context* ctx, int n_gpu_layer_
model.hparams = ml->file_loaders.at(0)->hparams;
model_file_version file_version = ml->file_loaders.at(0)->file_version;
auto& hparams = model.hparams;
n_ff = 4 * hparams.n_embd;
fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab);
fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd);
fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult);
fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head);
fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer);
fprintf(stderr, "%s: n_rot = %u\n", __func__, hparams.n_rot);
fprintf(stderr, "%s: n_ff = %u\n", __func__, n_ff);
fprintf(stderr, "%s: n_ff = %u\n", __func__, hparams.ffn_hidden_size);
fprintf(stderr, "%s: n_parts = %zu\n", __func__, ml->file_loaders.size());
fprintf(stderr, "%s: inner_hidden_size = %u\n", __func__, hparams.inner_hidden_size);
n_embd = hparams.n_embd;
n_vocab = hparams.n_vocab;
n_layer = hparams.n_layer;
@@ -92,10 +92,6 @@ void BAICHUAN::load(model_context* ctx, model_progress_callback progress_callbac
fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);

const auto& hparams = model.hparams;
const int head_dim = n_embd / hparams.n_head;
const int kv_heads = hparams.n_head; // 1 if MQA else hparams.n_head
const int kv_dim = kv_heads * head_dim;
const int max_len = 4096;

// create the ne context
lctx.model.buf.resize(ctx_size);
@@ -116,37 +112,71 @@ void BAICHUAN::load(model_context* ctx, model_progress_callback progress_callbac
}

ml->ne_ctx = ne_ctx;

model.others[0] = ml->get_tensor("model.embed_tokens.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
model.others[1] = ml->get_tensor("model.norm.weight", {n_embd}, NE_BACKEND_CPU);
model.others[2] = ml->get_tensor("lm_head.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
const int i_gpu_start = n_layer - n_gpu_layer;

model.layers.resize(n_layer);
size_t vram_total = 0;
for (uint32_t i = 0; i < n_layer; ++i) {
const ne_backend backend = static_cast<int>(i) < i_gpu_start ? NE_BACKEND_CPU : MODEL_BACKEND_OFFLOAD;
auto& layer = model.layers[i];
std::string layers_i = "model.layers." + std::to_string(i);
layer.norm[0] = ml->get_tensor(layers_i + ".input_layernorm.weight", {n_embd}, backend);

// qkv GEMM
layer.attn[0] = ml->get_tensor(layers_i + ".self_attn.W_pack.weight", {n_embd, 3 * n_embd}, backend);
layer.attn[1] = ml->get_tensor(layers_i + ".self_attn.o_proj.weight", {n_embd, n_embd}, backend);

layer.norm[1] = ml->get_tensor(layers_i + ".post_attention_layernorm.weight", {n_embd}, backend);

// ffn GEMM
layer.ffn[0] = ml->get_tensor(layers_i + ".mlp.gate_proj.weight",
{n_embd, uint32_t(model.hparams.inner_hidden_size)}, backend);

layer.ffn[1] = ml->get_tensor(layers_i + ".mlp.down_proj.weight",
{uint32_t(model.hparams.inner_hidden_size), n_embd}, backend);
layer.ffn[2] =
ml->get_tensor(layers_i + ".mlp.up_proj.weight", {n_embd, uint32_t(model.hparams.inner_hidden_size)}, backend);

layer.v_cache = nullptr;
layer.k_cache = nullptr;
if (ml->verify_tensor("token_embd.weight")) { // for gguf
model.others[0] = ml->get_tensor("token_embd.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
model.others[1] = ml->get_tensor("output_norm.weight", {n_embd}, NE_BACKEND_CPU);
model.others[2] = ml->get_tensor("output.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
const int i_gpu_start = n_layer - n_gpu_layer;

model.layers.resize(n_layer);
for (uint32_t i = 0; i < n_layer; ++i) {
const ne_backend backend = static_cast<int>(i) < i_gpu_start ? NE_BACKEND_CPU : MODEL_BACKEND_OFFLOAD;
auto& layer = model.layers[i];
std::string layers_i = "blk." + std::to_string(i);
layer.norm[0] = ml->get_tensor(layers_i + ".attn_norm.weight", {n_embd}, backend);

// qkv GEMM
std::string w_pack = "model.layers." + std::to_string(i);
layer.attn[0] = ml->get_tensor(w_pack + ".self_attn.W_pack.weight", {n_embd, 3 * n_embd}, backend);
layer.attn[1] = ml->get_tensor(layers_i + ".attn_output.weight", {n_embd, n_embd}, backend);

layer.norm[1] = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend);

// ffn GEMM
layer.ffn[0] =
ml->get_tensor(layers_i + ".ffn_gate.weight", {n_embd, uint32_t(model.hparams.ffn_hidden_size)}, backend);

layer.ffn[1] =
ml->get_tensor(layers_i + ".ffn_down.weight", {uint32_t(model.hparams.ffn_hidden_size), n_embd}, backend);
layer.ffn[2] =
ml->get_tensor(layers_i + ".ffn_up.weight", {n_embd, uint32_t(model.hparams.ffn_hidden_size)}, backend);

layer.v_cache = nullptr;
layer.k_cache = nullptr;
}
} else {
model.others[0] = ml->get_tensor("model.embed_tokens.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
model.others[1] = ml->get_tensor("model.norm.weight", {n_embd}, NE_BACKEND_CPU);
model.others[2] = ml->get_tensor("lm_head.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
const int i_gpu_start = n_layer - n_gpu_layer;

model.layers.resize(n_layer);
for (uint32_t i = 0; i < n_layer; ++i) {
const ne_backend backend = static_cast<int>(i) < i_gpu_start ? NE_BACKEND_CPU : MODEL_BACKEND_OFFLOAD;
auto& layer = model.layers[i];
std::string layers_i = "model.layers." + std::to_string(i);
layer.norm[0] = ml->get_tensor(layers_i + ".input_layernorm.weight", {n_embd}, backend);

// qkv GEMM
layer.attn[0] = ml->get_tensor(layers_i + ".self_attn.W_pack.weight", {n_embd, 3 * n_embd}, backend);
layer.attn[1] = ml->get_tensor(layers_i + ".self_attn.o_proj.weight", {n_embd, n_embd}, backend);

layer.norm[1] = ml->get_tensor(layers_i + ".post_attention_layernorm.weight", {n_embd}, backend);

// ffn GEMM
layer.ffn[0] = ml->get_tensor(layers_i + ".mlp.gate_proj.weight",
{n_embd, uint32_t(model.hparams.inner_hidden_size)}, backend);

layer.ffn[1] = ml->get_tensor(layers_i + ".mlp.down_proj.weight",
{uint32_t(model.hparams.inner_hidden_size), n_embd}, backend);
layer.ffn[2] = ml->get_tensor(layers_i + ".mlp.up_proj.weight",
{n_embd, uint32_t(model.hparams.inner_hidden_size)}, backend);

layer.v_cache = nullptr;
layer.k_cache = nullptr;
}
}

// print memory requirements
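The two branches above differ mainly in tensor naming. A hedged summary of the correspondence, restricted to the names visible in this hunk (note that the packed QKV weight is requested under its original model.layers.* name in both branches):

    # GGUF name                      -> native/HF checkpoint name (per-layer names use the layer index i)
    NAME_MAP = {
        "token_embd.weight":           "model.embed_tokens.weight",
        "output_norm.weight":          "model.norm.weight",
        "output.weight":               "lm_head.weight",
        "blk.{i}.attn_norm.weight":    "model.layers.{i}.input_layernorm.weight",
        "blk.{i}.attn_output.weight":  "model.layers.{i}.self_attn.o_proj.weight",
        "blk.{i}.ffn_norm.weight":     "model.layers.{i}.post_attention_layernorm.weight",
        "blk.{i}.ffn_gate.weight":     "model.layers.{i}.mlp.gate_proj.weight",
        "blk.{i}.ffn_down.weight":     "model.layers.{i}.mlp.down_proj.weight",
        "blk.{i}.ffn_up.weight":       "model.layers.{i}.mlp.up_proj.weight",
        # "model.layers.{i}.self_attn.W_pack.weight" is used unchanged in the GGUF branch
    }
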
4 changes: 2 additions & 2 deletions neural_speed/models/model_utils/model_files.h
@@ -1090,7 +1090,7 @@ struct model_file_loader {
printf("%-16s %d.hparams.n_head = %-30d\n", __func__, count++, hparams.n_head);
printf("%-16s %d.hparams.n_head_kv = %-30d\n", __func__, count++, hparams.n_head_kv);
printf("%-16s %d.hparams.n_layer = %-30d\n", __func__, count++, hparams.n_layer);
printf("%-16s %d.hparams.n_rot = %-30d\n", __func__, count++, hparams.n_vocab);
printf("%-16s %d.hparams.n_rot = %-30d\n", __func__, count++, hparams.n_rot);

hparams.ftype = (enum ne_ftype)file.read_u32();
hparams.max_seq_len = file.read_u32();
@@ -1122,7 +1122,7 @@ struct model_file_loader {
file.read_raw(&hparams.norm_eps, sizeof(float));
file.read_raw(&hparams.freq_base, sizeof(float));
file.read_raw(&hparams.freq_scale, sizeof(float));
printf("%-16s %d.hparams.inner_hidden_size = %-30d\n", __func__, count++, hparams.inner_hidden_size);
printf("%-16s %d.hparams.norm_eps = %-30f\n", __func__, count++, hparams.norm_eps);
printf("%-16s %d.hparams.freq_base = %-30.3f\n", __func__, count++, hparams.freq_base);
printf("%-16s %d.hparams.freq_scale = %-30.3f\n", __func__, count++, hparams.freq_scale);

