From 71d8bd6480fb672fba2bf8cd45225d5658db1167 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Wed, 1 May 2024 09:43:19 +0200 Subject: [PATCH 01/15] Added support for the snowflake-arctic model. --- convert-hf-to-gguf.py | 113 +++++++++++++++ gguf-py/gguf/constants.py | 25 ++++ gguf-py/gguf/tensor_mapping.py | 66 ++++++++- llama.cpp | 246 ++++++++++++++++++++++++++++++++- 4 files changed, 447 insertions(+), 3 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 2f146d7302a78..6f013a1d00217 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1516,6 +1516,119 @@ def write_tensors(self): if len(experts) > 0: raise ValueError(f"Unprocessed experts: {experts.keys()}") +@Model.register("ArcticForCausalLM") +class ArcticModel(Model): + model_arch = gguf.MODEL_ARCH.ARCTIC + + def set_vocab(self): + self._set_vocab_llama_hf() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + self.gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"]) + + # Same as super class, but permuting q_proj, k_proj + def write_tensors(self): + block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + n_head = self.hparams.get("num_attention_heads") + n_kv_head = self.hparams.get("num_key_value_heads") + n_experts = self.hparams.get("num_local_experts") + experts = dict() + for name, data_torch in self.get_tensors(): + # we don't need these + if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")): + continue + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + data = data_torch.numpy() + + if name.endswith("q_proj.weight"): + data = permute(data, n_head, n_head) + if name.endswith("k_proj.weight"): + data = permute(data, n_head, n_kv_head) + + data = data.squeeze() + + # process the experts separately + if name.find("block_sparse_moe.experts") != -1: + experts[name] = data + if len(experts) >= n_experts: + # merge the experts into a single 3d tensor + for bid in range(block_count): + for wid in range(1, 4): + full = True + for xid in range(n_experts): + ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.w{wid}.weight" + if ename not in experts: + full = False + break + if not full: + continue + + datas = [] + for xid in range(n_experts): + ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.w{wid}.weight" + datas.append(experts[ename]) + del experts[ename] + + data = np.stack(datas, axis=0) + data_dtype = data.dtype + + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + if self.ftype == 1 and data_dtype == np.float32: + data = data.astype(np.float16) + + merged_name = f"layers.{bid}.feed_forward.experts.w{wid}.weight" + + new_name = tensor_map.get_name(merged_name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + print(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + continue + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor 
{name!r}") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # 1d tensors need to be converted to float32 + if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts.keys()}") + + @Model.register("GrokForCausalLM") class GrokModel(Model): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 6d597bfd9d621..0478828182e8f 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -138,6 +138,7 @@ class MODEL_ARCH(IntEnum): COMMAND_R = auto() DBRX = auto() OLMO = auto() + ARCTIC = auto() class MODEL_TENSOR(IntEnum): @@ -180,6 +181,7 @@ class MODEL_TENSOR(IntEnum): SSM_A = auto() SSM_D = auto() SSM_OUT = auto() + FFN_NORM_EXP = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -215,6 +217,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.DBRX: "dbrx", MODEL_ARCH.OLMO: "olmo", + MODEL_ARCH.ARCTIC: "arctic", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -257,6 +260,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", + MODEL_TENSOR.FFN_NORM_EXP: "blk.{bid}.ffn_norm_exps", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -725,6 +729,27 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.ARCTIC: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_NORM_EXP, + ], # TODO } diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index e5750d4191f6b..6ecce589c9e29 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -370,6 +370,64 @@ class TensorNameMap: "model.layers.{bid}.out_proj", "backbone.layers.{bid}.mixer.out_proj", ), + + } + + # architecture-specific block mappings + arch_block_mappings_cfg: dict[MODEL_ARCH, dict[MODEL_TENSOR, tuple[str, ...]]] = { + MODEL_ARCH.ARCTIC: { + MODEL_TENSOR.TOKEN_EMBD: ( + "model.embed_tokens", + ), + MODEL_TENSOR.OUTPUT_NORM: ( + "model.norm", + ), + MODEL_TENSOR.OUTPUT: ( + "lm_head", + ), + MODEL_TENSOR.ATTN_NORM: ( + "model.layers.{bid}.input_layernorm", + ), + MODEL_TENSOR.ATTN_Q: ( + "model.layers.{bid}.self_attn.q_proj", + ), + MODEL_TENSOR.ATTN_K: ( + "model.layers.{bid}.self_attn.k_proj", + ), + MODEL_TENSOR.ATTN_V: ( + "model.layers.{bid}.self_attn.v_proj", + ), + MODEL_TENSOR.ATTN_OUT: ( + "model.layers.{bid}.self_attn.o_proj", + ), + MODEL_TENSOR.FFN_GATE_INP: ( + "model.layers.{bid}.block_sparse_moe.gate", + ), + MODEL_TENSOR.FFN_NORM: ( + 
"model.layers.{bid}.residual_layernorm", + ), + MODEL_TENSOR.FFN_GATE: ( + "model.layers.{bid}.residual_mlp.w1", + ), + MODEL_TENSOR.FFN_DOWN: ( + "model.layers.{bid}.residual_mlp.w2", + ), + MODEL_TENSOR.FFN_UP: ( + "model.layers.{bid}.residual_mlp.w3", + ), + MODEL_TENSOR.FFN_GATE_EXP: ( + "layers.{bid}.feed_forward.experts.w1", + ), + MODEL_TENSOR.FFN_DOWN_EXP: ( + "layers.{bid}.feed_forward.experts.w2", + ), + MODEL_TENSOR.FFN_UP_EXP: ( + "layers.{bid}.feed_forward.experts.w3", + ), + MODEL_TENSOR.FFN_NORM_EXP: ( + "model.layers.{bid}.post_attention_layernorm", + ), + }, } mapping: dict[str, tuple[MODEL_TENSOR, str]] @@ -383,12 +441,16 @@ def __init__(self, arch: MODEL_ARCH, n_blocks: int): self.mapping[tensor_name] = (tensor, tensor_name) for key in keys: self.mapping[key] = (tensor, tensor_name) + if arch in self.arch_block_mappings_cfg: + block_mappings = self.arch_block_mappings_cfg[arch] + else: + block_mappings = self.block_mappings_cfg for bid in range(n_blocks): - for tensor, keys in self.block_mappings_cfg.items(): + for tensor, keys in block_mappings.items(): if tensor not in MODEL_TENSORS[arch]: continue # TODO: make this configurable - n_experts = 60 + n_experts = 128 for xid in range(n_experts): tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid) self.mapping[tensor_name] = (tensor, tensor_name) diff --git a/llama.cpp b/llama.cpp index 18d6297ce1dfd..ee5a3226e753a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -106,7 +106,7 @@ #endif #define LLAMA_MAX_NODES 8192 -#define LLAMA_MAX_EXPERTS 60 +#define LLAMA_MAX_EXPERTS 128 // // logging @@ -224,6 +224,7 @@ enum llm_arch { LLM_ARCH_COMMAND_R, LLM_ARCH_DBRX, LLM_ARCH_OLMO, + LLM_ARCH_ARCTIC, LLM_ARCH_UNKNOWN, }; @@ -260,6 +261,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_DBRX, "dbrx" }, { LLM_ARCH_OLMO, "olmo" }, + { LLM_ARCH_ARCTIC, "arctic" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -457,6 +459,7 @@ enum llm_tensor { LLM_TENSOR_FFN_DOWN_EXPS, // merged experts LLM_TENSOR_FFN_GATE_EXPS, LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_NORM_EXPS, LLM_TENSOR_FFN_DOWN_SHEXP, LLM_TENSOR_FFN_GATE_SHEXP, LLM_TENSOR_FFN_UP_SHEXP, @@ -1027,6 +1030,28 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_ARCTIC, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_NORM_EXPS, "blk.%d.ffn_norm_exps" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -1803,6 +1828,7 @@ enum e_model { MODEL_8x7B, MODEL_8x22B, MODEL_16x12B, + MODEL_10B_128x3_66B, }; static const size_t kiB = 1024; @@ -1975,6 +2001,7 @@ struct llama_layer { struct ggml_tensor * ffn_norm_b; struct ggml_tensor * layer_out_norm; struct ggml_tensor * layer_out_norm_b; + struct ggml_tensor * ffn_norm_exps; // ff struct ggml_tensor * ffn_gate; // w1 @@ -3734,6 +3761,7 @@ static const char * 
llama_model_type_name(e_model type) { case MODEL_8x7B: return "8x7B"; case MODEL_8x22B: return "8x22B"; case MODEL_16x12B: return "16x12B"; + case MODEL_10B_128x3_66B: return "10B+128x3.66B"; default: return "?B"; } } @@ -4196,6 +4224,20 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_ARCTIC: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + if (hparams.n_expert == 128) { + switch (hparams.n_layer) { + case 35: model.type = e_model::MODEL_10B_128x3_66B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } else { + model.type = e_model::MODEL_UNKNOWN; + } + } break; + default: (void)0; } @@ -5932,6 +5974,55 @@ static bool llm_load_tensors( layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; + case LLM_ARCH_ARCTIC: + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + ml.n_created--; // artificial tensor + ml.size_data += ggml_nbytes(model.output); + } + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + // optional bias tensors + layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, false); + layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, false); + layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, false); + layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, false); + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_embd}); + + layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}); + layer.ffn_norm_exps = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM_EXPS, "weight", i), {n_embd}); + layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false); + layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}); + layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), 
{n_embd, n_ff, n_expert}); + } + } break; + default: throw std::runtime_error("unknown architecture"); } @@ -10682,6 +10773,154 @@ struct llm_build_context { return gf; } + + struct ggml_cgraph * build_arctic() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + struct ggml_tensor * ffn_out = ggml_add(ctx0, cur, ffn_inp); + cb(ffn_out, "ffn_out", il); + + // MoE + cur = llm_build_norm(ctx0, inpSA, hparams, + model.layers[il].ffn_norm_exps, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm_exps", il); + + cur = 
llm_build_moe_ffn(ctx0, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + cb, il); + cb(cur, "ffn_moe_out", il); + + cur = ggml_add(ctx0, cur, ffn_out); + cb(cur, "ffn_out", il); + + ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); + if (layer_dir != nullptr) { + cur = ggml_add(ctx0, cur, layer_dir); + } + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { @@ -10895,6 +11134,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_olmo(); } break; + case LLM_ARCH_ARCTIC: + { + result = llm.build_arctic(); + } break; default: GGML_ASSERT(false); } @@ -15783,6 +16026,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_XVERSE: case LLM_ARCH_COMMAND_R: case LLM_ARCH_OLMO: + case LLM_ARCH_ARCTIC: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 From c95013d1b5a2eb247e212889c57ca36e738b08fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Thu, 2 May 2024 09:53:59 +0200 Subject: [PATCH 02/15] Whitespace formatting fixes. --- convert-hf-to-gguf.py | 2 +- gguf-py/gguf/tensor_mapping.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 6f013a1d00217..17cdb0fc577f0 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1516,6 +1516,7 @@ def write_tensors(self): if len(experts) > 0: raise ValueError(f"Unprocessed experts: {experts.keys()}") + @Model.register("ArcticForCausalLM") class ArcticModel(Model): model_arch = gguf.MODEL_ARCH.ARCTIC @@ -1629,7 +1630,6 @@ def write_tensors(self): raise ValueError(f"Unprocessed experts: {experts.keys()}") - @Model.register("GrokForCausalLM") class GrokModel(Model): model_arch = gguf.MODEL_ARCH.GROK diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 6ecce589c9e29..844a270f4db9c 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -427,7 +427,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_NORM_EXP: ( "model.layers.{bid}.post_attention_layernorm", ), - }, + }, } mapping: dict[str, tuple[MODEL_TENSOR, str]] From c6f15a752ae7733f7e5bcd84caae1a4bc6a3b8f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Tue, 7 May 2024 20:57:28 +0200 Subject: [PATCH 03/15] Read vocabulary for ArcticForCausalLM from sentencepiece model instead of HF tokenizer. 
Add/redefine tokens accordingly to added_tokens_decoder from tokenizer_config.json --- convert-hf-to-gguf.py | 82 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 81 insertions(+), 1 deletion(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 17cdb0fc577f0..dd7b7c6fce64c 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1522,7 +1522,87 @@ class ArcticModel(Model): model_arch = gguf.MODEL_ARCH.ARCTIC def set_vocab(self): - self._set_vocab_llama_hf() + # The reason for using a custom implementation here is that the + # snowflake-arctic-instruct model redefined tokens 31998 and 31999 from + # tokenizer.model and used them as BOS and EOS instead of adding new tokens. + from sentencepiece import SentencePieceProcessor + + tokenizer_path = self.dir_model / 'tokenizer.model' + + if not tokenizer_path.is_file(): + print(f'Error: Missing {tokenizer_path}', file=sys.stderr) + sys.exit(1) + + # Read the whole vocabulary from the tokenizer.model file + tokenizer = SentencePieceProcessor(str(tokenizer_path)) + + vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + + tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] + scores: list[float] = [-10000.0] * vocab_size + toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size + + for token_id in range(tokenizer.vocab_size()): + + piece = tokenizer.id_to_piece(token_id) + text = piece.encode("utf-8") + score = tokenizer.get_score(token_id) + + toktype = SentencePieceTokenTypes.NORMAL + if tokenizer.is_unknown(token_id): + toktype = SentencePieceTokenTypes.UNKNOWN + elif tokenizer.is_control(token_id): + toktype = SentencePieceTokenTypes.CONTROL + elif tokenizer.is_unused(token_id): + toktype = SentencePieceTokenTypes.UNUSED + elif tokenizer.is_byte(token_id): + toktype = SentencePieceTokenTypes.BYTE + + tokens[token_id] = text + scores[token_id] = score + toktypes[token_id] = toktype + + # Use the added_tokens_decoder field from tokeniser_config.json as the source + # of information about added/redefined tokens and modify them accordingly. 
+ tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + if tokenizer_config_file.is_file(): + with open(tokenizer_config_file, "r", encoding="utf-8") as f: + tokenizer_config_json = json.load(f) + + if "added_tokens_decoder" in tokenizer_config_json: + added_tokens_decoder = tokenizer_config_json["added_tokens_decoder"] + for token_id, token_json in added_tokens_decoder.items(): + token_id = int(token_id) + if (token_id >= vocab_size): + print(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}') + continue + + token_content = token_json["content"] + token_type = SentencePieceTokenTypes.USER_DEFINED + token_score = -10000.0 + + # Map unk_token to UNKNOWN, other special tokens to CONTROL + # Set the score to 0.0 as in the original tokenizer.model + if ("special" in token_json) and token_json["special"]: + if token_content == tokenizer_config_json["unk_token"]: + token_type = SentencePieceTokenTypes.UNKNOWN + else: + token_type = SentencePieceTokenTypes.CONTROL + token_score = 0.0 + + print(f"Setting token {token_id} to '{token_content}' (type: {token_type}, score: {token_score:.2f})") + tokens[token_id] = token_content.encode("utf-8") + toktypes[token_id] = token_type + scores[token_id] = token_score + + self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): super().set_gguf_parameters() From 0cffda89b3af4f21bf967ac3383655c5636c3515 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Thu, 9 May 2024 15:17:37 +0200 Subject: [PATCH 04/15] Moved ArcticModel to the end of the file. --- convert-hf-to-gguf.py | 386 +++++++++++++++++++++--------------------- 1 file changed, 193 insertions(+), 193 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index dd7b7c6fce64c..c6b2bfe3e172a 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -1517,199 +1517,6 @@ def write_tensors(self): raise ValueError(f"Unprocessed experts: {experts.keys()}") -@Model.register("ArcticForCausalLM") -class ArcticModel(Model): - model_arch = gguf.MODEL_ARCH.ARCTIC - - def set_vocab(self): - # The reason for using a custom implementation here is that the - # snowflake-arctic-instruct model redefined tokens 31998 and 31999 from - # tokenizer.model and used them as BOS and EOS instead of adding new tokens. 
- from sentencepiece import SentencePieceProcessor - - tokenizer_path = self.dir_model / 'tokenizer.model' - - if not tokenizer_path.is_file(): - print(f'Error: Missing {tokenizer_path}', file=sys.stderr) - sys.exit(1) - - # Read the whole vocabulary from the tokenizer.model file - tokenizer = SentencePieceProcessor(str(tokenizer_path)) - - vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) - - tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] - scores: list[float] = [-10000.0] * vocab_size - toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size - - for token_id in range(tokenizer.vocab_size()): - - piece = tokenizer.id_to_piece(token_id) - text = piece.encode("utf-8") - score = tokenizer.get_score(token_id) - - toktype = SentencePieceTokenTypes.NORMAL - if tokenizer.is_unknown(token_id): - toktype = SentencePieceTokenTypes.UNKNOWN - elif tokenizer.is_control(token_id): - toktype = SentencePieceTokenTypes.CONTROL - elif tokenizer.is_unused(token_id): - toktype = SentencePieceTokenTypes.UNUSED - elif tokenizer.is_byte(token_id): - toktype = SentencePieceTokenTypes.BYTE - - tokens[token_id] = text - scores[token_id] = score - toktypes[token_id] = toktype - - # Use the added_tokens_decoder field from tokeniser_config.json as the source - # of information about added/redefined tokens and modify them accordingly. - tokenizer_config_file = self.dir_model / 'tokenizer_config.json' - if tokenizer_config_file.is_file(): - with open(tokenizer_config_file, "r", encoding="utf-8") as f: - tokenizer_config_json = json.load(f) - - if "added_tokens_decoder" in tokenizer_config_json: - added_tokens_decoder = tokenizer_config_json["added_tokens_decoder"] - for token_id, token_json in added_tokens_decoder.items(): - token_id = int(token_id) - if (token_id >= vocab_size): - print(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}') - continue - - token_content = token_json["content"] - token_type = SentencePieceTokenTypes.USER_DEFINED - token_score = -10000.0 - - # Map unk_token to UNKNOWN, other special tokens to CONTROL - # Set the score to 0.0 as in the original tokenizer.model - if ("special" in token_json) and token_json["special"]: - if token_content == tokenizer_config_json["unk_token"]: - token_type = SentencePieceTokenTypes.UNKNOWN - else: - token_type = SentencePieceTokenTypes.CONTROL - token_score = 0.0 - - print(f"Setting token {token_id} to '{token_content}' (type: {token_type}, score: {token_score:.2f})") - tokens[token_id] = token_content.encode("utf-8") - toktypes[token_id] = token_type - scores[token_id] = token_score - - self.gguf_writer.add_tokenizer_model("llama") - self.gguf_writer.add_tokenizer_pre("default") - self.gguf_writer.add_token_list(tokens) - self.gguf_writer.add_token_scores(scores) - self.gguf_writer.add_token_types(toktypes) - - special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) - special_vocab.add_to_gguf(self.gguf_writer) - - def set_gguf_parameters(self): - super().set_gguf_parameters() - hparams = self.hparams - self.gguf_writer.add_vocab_size(hparams["vocab_size"]) - self.gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"]) - - # Same as super class, but permuting q_proj, k_proj - def write_tensors(self): - block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) - tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) - n_head = 
self.hparams.get("num_attention_heads") - n_kv_head = self.hparams.get("num_key_value_heads") - n_experts = self.hparams.get("num_local_experts") - experts = dict() - for name, data_torch in self.get_tensors(): - # we don't need these - if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")): - continue - - old_dtype = data_torch.dtype - - # convert any unsupported data types to float32 - if data_torch.dtype not in (torch.float16, torch.float32): - data_torch = data_torch.to(torch.float32) - - data = data_torch.numpy() - - if name.endswith("q_proj.weight"): - data = permute(data, n_head, n_head) - if name.endswith("k_proj.weight"): - data = permute(data, n_head, n_kv_head) - - data = data.squeeze() - - # process the experts separately - if name.find("block_sparse_moe.experts") != -1: - experts[name] = data - if len(experts) >= n_experts: - # merge the experts into a single 3d tensor - for bid in range(block_count): - for wid in range(1, 4): - full = True - for xid in range(n_experts): - ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.w{wid}.weight" - if ename not in experts: - full = False - break - if not full: - continue - - datas = [] - for xid in range(n_experts): - ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.w{wid}.weight" - datas.append(experts[ename]) - del experts[ename] - - data = np.stack(datas, axis=0) - data_dtype = data.dtype - - if self.ftype == 0 and data_dtype == np.float16: - data = data.astype(np.float32) - - if self.ftype == 1 and data_dtype == np.float32: - data = data.astype(np.float16) - - merged_name = f"layers.{bid}.feed_forward.experts.w{wid}.weight" - - new_name = tensor_map.get_name(merged_name, try_suffixes=(".weight", ".bias")) - if new_name is None: - print(f"Can not map tensor {name!r}") - sys.exit() - - print(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}") - - self.gguf_writer.add_tensor(new_name, data) - continue - - # map tensor names - new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) - if new_name is None: - print(f"Can not map tensor {name!r}") - sys.exit() - - n_dims = len(data.shape) - data_dtype = data.dtype - - # if f32 desired, convert any float16 to float32 - if self.ftype == 0 and data_dtype == np.float16: - data = data.astype(np.float32) - - # 1d tensors need to be converted to float32 - if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: - data = data.astype(np.float32) - - # if f16 desired, convert any float32 2-dim weight tensors to float16 - if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: - data = data.astype(np.float16) - - print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") - - self.gguf_writer.add_tensor(new_name, data) - - if len(experts) > 0: - raise ValueError(f"Unprocessed experts: {experts.keys()}") - - @Model.register("GrokForCausalLM") class GrokModel(Model): model_arch = gguf.MODEL_ARCH.GROK @@ -3101,6 +2908,199 @@ def write_tensors(self): self.gguf_writer.add_tensor(new_name, data) +@Model.register("ArcticForCausalLM") +class ArcticModel(Model): + model_arch = gguf.MODEL_ARCH.ARCTIC + + def set_vocab(self): + # The reason for using a custom implementation here is that the + # snowflake-arctic-instruct model redefined tokens 31998 and 31999 from + # tokenizer.model and used them as BOS and EOS instead of adding new tokens. 
+ from sentencepiece import SentencePieceProcessor + + tokenizer_path = self.dir_model / 'tokenizer.model' + + if not tokenizer_path.is_file(): + print(f'Error: Missing {tokenizer_path}', file=sys.stderr) + sys.exit(1) + + # Read the whole vocabulary from the tokenizer.model file + tokenizer = SentencePieceProcessor(str(tokenizer_path)) + + vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + + tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] + scores: list[float] = [-10000.0] * vocab_size + toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size + + for token_id in range(tokenizer.vocab_size()): + + piece = tokenizer.id_to_piece(token_id) + text = piece.encode("utf-8") + score = tokenizer.get_score(token_id) + + toktype = SentencePieceTokenTypes.NORMAL + if tokenizer.is_unknown(token_id): + toktype = SentencePieceTokenTypes.UNKNOWN + elif tokenizer.is_control(token_id): + toktype = SentencePieceTokenTypes.CONTROL + elif tokenizer.is_unused(token_id): + toktype = SentencePieceTokenTypes.UNUSED + elif tokenizer.is_byte(token_id): + toktype = SentencePieceTokenTypes.BYTE + + tokens[token_id] = text + scores[token_id] = score + toktypes[token_id] = toktype + + # Use the added_tokens_decoder field from tokeniser_config.json as the source + # of information about added/redefined tokens and modify them accordingly. + tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + if tokenizer_config_file.is_file(): + with open(tokenizer_config_file, "r", encoding="utf-8") as f: + tokenizer_config_json = json.load(f) + + if "added_tokens_decoder" in tokenizer_config_json: + added_tokens_decoder = tokenizer_config_json["added_tokens_decoder"] + for token_id, token_json in added_tokens_decoder.items(): + token_id = int(token_id) + if (token_id >= vocab_size): + print(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}') + continue + + token_content = token_json["content"] + token_type = SentencePieceTokenTypes.USER_DEFINED + token_score = -10000.0 + + # Map unk_token to UNKNOWN, other special tokens to CONTROL + # Set the score to 0.0 as in the original tokenizer.model + if ("special" in token_json) and token_json["special"]: + if token_content == tokenizer_config_json["unk_token"]: + token_type = SentencePieceTokenTypes.UNKNOWN + else: + token_type = SentencePieceTokenTypes.CONTROL + token_score = 0.0 + + print(f"Setting token {token_id} to '{token_content}' (type: {token_type}, score: {token_score:.2f})") + tokens[token_id] = token_content.encode("utf-8") + toktypes[token_id] = token_type + scores[token_id] = token_score + + self.gguf_writer.add_tokenizer_model("llama") + self.gguf_writer.add_tokenizer_pre("default") + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_scores(scores) + self.gguf_writer.add_token_types(toktypes) + + special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) + special_vocab.add_to_gguf(self.gguf_writer) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + self.gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"]) + + # Same as super class, but permuting q_proj, k_proj + def write_tensors(self): + block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + n_head = 
self.hparams.get("num_attention_heads") + n_kv_head = self.hparams.get("num_key_value_heads") + n_experts = self.hparams.get("num_local_experts") + experts = dict() + for name, data_torch in self.get_tensors(): + # we don't need these + if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")): + continue + + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + data = data_torch.numpy() + + if name.endswith("q_proj.weight"): + data = permute(data, n_head, n_head) + if name.endswith("k_proj.weight"): + data = permute(data, n_head, n_kv_head) + + data = data.squeeze() + + # process the experts separately + if name.find("block_sparse_moe.experts") != -1: + experts[name] = data + if len(experts) >= n_experts: + # merge the experts into a single 3d tensor + for bid in range(block_count): + for wid in range(1, 4): + full = True + for xid in range(n_experts): + ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.w{wid}.weight" + if ename not in experts: + full = False + break + if not full: + continue + + datas = [] + for xid in range(n_experts): + ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.w{wid}.weight" + datas.append(experts[ename]) + del experts[ename] + + data = np.stack(datas, axis=0) + data_dtype = data.dtype + + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + if self.ftype == 1 and data_dtype == np.float32: + data = data.astype(np.float16) + + merged_name = f"layers.{bid}.feed_forward.experts.w{wid}.weight" + + new_name = tensor_map.get_name(merged_name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + print(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + continue + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # 1d tensors need to be converted to float32 + if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts.keys()}") + + ###### CONVERSION LOGIC ###### From a89257151f3f5edc13702704d0c11c1b1ef6318f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Thu, 9 May 2024 17:04:21 +0200 Subject: [PATCH 05/15] Applied changes from upstream PR: save memory with lazy evaluation #7075 (shameless copy from LlamaModel). 
--- convert-hf-to-gguf.py | 127 +++++++++++++++--------------------------- 1 file changed, 46 insertions(+), 81 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 7ea9cbc446b23..c8967a79f68e8 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2370,104 +2370,69 @@ def set_gguf_parameters(self): self.gguf_writer.add_vocab_size(hparams["vocab_size"]) self.gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"]) - # Same as super class, but permuting q_proj, k_proj - def write_tensors(self): - block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) - tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) - n_head = self.hparams.get("num_attention_heads") - n_kv_head = self.hparams.get("num_key_value_heads") - n_experts = self.hparams.get("num_local_experts") - experts = dict() - for name, data_torch in self.get_tensors(): - # we don't need these - if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")): - continue - - old_dtype = data_torch.dtype - - # convert any unsupported data types to float32 - if data_torch.dtype not in (torch.float16, torch.float32): - data_torch = data_torch.to(torch.float32) - - data = data_torch.numpy() - - if name.endswith("q_proj.weight"): - data = permute(data, n_head, n_head) - if name.endswith("k_proj.weight"): - data = permute(data, n_head, n_kv_head) - - data = data.squeeze() + @staticmethod + def permute(weights: Tensor, n_head: int, n_head_kv: int | None): + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape)) - # process the experts separately - if name.find("block_sparse_moe.experts") != -1: - experts[name] = data - if len(experts) >= n_experts: - # merge the experts into a single 3d tensor - for bid in range(block_count): - for wid in range(1, 4): - full = True - for xid in range(n_experts): - ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.w{wid}.weight" - if ename not in experts: - full = False - break - if not full: - continue + _experts: list[dict[str, Tensor]] | None = None - datas = [] - for xid in range(n_experts): - ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.w{wid}.weight" - datas.append(experts[ename]) - del experts[ename] + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + n_head = self.hparams["num_attention_heads"] + n_kv_head = self.hparams.get("num_key_value_heads") - data = np.stack(datas, axis=0) - data_dtype = data.dtype + if name.endswith("q_proj.weight"): + data_torch = LlamaModel.permute(data_torch, n_head, n_head) + if name.endswith("k_proj.weight"): + data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) - if self.ftype == 0 and data_dtype == np.float16: - data = data.astype(np.float32) + # process the experts separately + if name.find("block_sparse_moe.experts") != -1: + n_experts = self.hparams["num_local_experts"] - if self.ftype == 1 and data_dtype == np.float32: - data = data.astype(np.float16) + assert bid is not None - merged_name = f"layers.{bid}.feed_forward.experts.w{wid}.weight" + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] - new_name = tensor_map.get_name(merged_name, try_suffixes=(".weight", ".bias")) - if new_name is None: - print(f"Can not map tensor 
{name!r}") - sys.exit() + self._experts[bid][name] = data_torch - print(f"{new_name}, n_dims = {len(data.shape)}, shape = {data.shape} --> {data.dtype}") + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] - self.gguf_writer.add_tensor(new_name, data) - continue + # merge the experts into a single 3d tensor + for wid in ["w1", "w2", "w3"]: + datas: list[Tensor] = [] - # map tensor names - new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) - if new_name is None: - print(f"Can not map tensor {name!r}") - sys.exit() + for xid in range(n_experts): + ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] - n_dims = len(data.shape) - data_dtype = data.dtype + data_torch = torch.stack(datas, dim=0) - # if f32 desired, convert any float16 to float32 - if self.ftype == 0 and data_dtype == np.float16: - data = data.astype(np.float32) + merged_name = f"layers.{bid}.feed_forward.experts.{wid}.weight" - # 1d tensors need to be converted to float32 - if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: - data = data.astype(np.float32) + new_name = self.map_tensor_name(merged_name) - # if f16 desired, convert any float32 2-dim weight tensors to float16 - if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: - data = data.astype(np.float16) + tensors.append((new_name, data_torch)) + return tensors + else: + return [] - print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + return [(self.map_tensor_name(name), data_torch)] - self.gguf_writer.add_tensor(new_name, data) + def write_tensors(self): + super().write_tensors() - if len(experts) > 0: - raise ValueError(f"Unprocessed experts: {experts.keys()}") + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") ###### CONVERSION LOGIC ###### From 4ebb52cfc23e9944058333ab1375b15335179f0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Thu, 9 May 2024 17:13:12 +0200 Subject: [PATCH 06/15] Replaced prints with logger calls. 
--- convert-hf-to-gguf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index c8967a79f68e8..3c415a60442ac 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2290,7 +2290,7 @@ def set_vocab(self): tokenizer_path = self.dir_model / 'tokenizer.model' if not tokenizer_path.is_file(): - print(f'Error: Missing {tokenizer_path}', file=sys.stderr) + logger.error(f'Error: Missing {tokenizer_path}') sys.exit(1) # Read the whole vocabulary from the tokenizer.model file @@ -2334,7 +2334,7 @@ def set_vocab(self): for token_id, token_json in added_tokens_decoder.items(): token_id = int(token_id) if (token_id >= vocab_size): - print(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}') + logger.debug(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}') continue token_content = token_json["content"] @@ -2350,7 +2350,7 @@ def set_vocab(self): token_type = SentencePieceTokenTypes.CONTROL token_score = 0.0 - print(f"Setting token {token_id} to '{token_content}' (type: {token_type}, score: {token_score:.2f})") + logger.info(f"Setting added token {token_id} to '{token_content}' (type: {token_type}, score: {token_score:.2f})") tokens[token_id] = token_content.encode("utf-8") toktypes[token_id] = token_type scores[token_id] = token_score From 9acc3ecf34ca7ce579965876e57ed62201c2b95a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Thu, 9 May 2024 19:41:39 +0200 Subject: [PATCH 07/15] Removed unnecessary method - LlamaModel.permute is used instead. --- convert-hf-to-gguf.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 3c415a60442ac..170dea060056b 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2370,14 +2370,6 @@ def set_gguf_parameters(self): self.gguf_writer.add_vocab_size(hparams["vocab_size"]) self.gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"]) - @staticmethod - def permute(weights: Tensor, n_head: int, n_head_kv: int | None): - if n_head_kv is not None and n_head != n_head_kv: - n_head = n_head_kv - return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) - .swapaxes(1, 2) - .reshape(weights.shape)) - _experts: list[dict[str, Tensor]] | None = None def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: From f4421f7cd8994e138d74cd814d2bfa6ca2e1a6e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Tue, 14 May 2024 20:52:51 +0200 Subject: [PATCH 08/15] convert-hf : Corrected sentencepiece API calls. 
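
For reference, a condensed sketch of the corrected sentencepiece usage (the tokenizer.model
path is a placeholder; the point is the CamelCase method names used after this patch):

    # Sketch of the SentencePieceProcessor API as used after this patch.
    from sentencepiece import SentencePieceProcessor

    tokenizer = SentencePieceProcessor()
    tokenizer.LoadFromFile("tokenizer.model")   # placeholder path

    for token_id in range(tokenizer.vocab_size()):
        piece = tokenizer.IdToPiece(token_id)
        score = tokenizer.GetScore(token_id)
        is_control = tokenizer.IsControl(token_id)
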
--- convert-hf-to-gguf.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index 170dea060056b..5569ed031bb4c 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2294,7 +2294,8 @@ def set_vocab(self): sys.exit(1) # Read the whole vocabulary from the tokenizer.model file - tokenizer = SentencePieceProcessor(str(tokenizer_path)) + tokenizer = SentencePieceProcessor() + tokenizer.LoadFromFile(str(tokenizer_path)) vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) @@ -2304,18 +2305,18 @@ def set_vocab(self): for token_id in range(tokenizer.vocab_size()): - piece = tokenizer.id_to_piece(token_id) + piece = tokenizer.IdToPiece(token_id) text = piece.encode("utf-8") - score = tokenizer.get_score(token_id) + score = tokenizer.GetScore(token_id) toktype = SentencePieceTokenTypes.NORMAL - if tokenizer.is_unknown(token_id): + if tokenizer.IsUnknown(token_id): toktype = SentencePieceTokenTypes.UNKNOWN - elif tokenizer.is_control(token_id): + elif tokenizer.IsControl(token_id): toktype = SentencePieceTokenTypes.CONTROL - elif tokenizer.is_unused(token_id): + elif tokenizer.IsUnused(token_id): toktype = SentencePieceTokenTypes.UNUSED - elif tokenizer.is_byte(token_id): + elif tokenizer.IsByte(token_id): toktype = SentencePieceTokenTypes.BYTE tokens[token_id] = text From 85263f0568147f83829ebf367554fa2781f368cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Wed, 15 May 2024 09:30:58 +0200 Subject: [PATCH 09/15] Minor fixes after merging. --- llama.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index 7b8bfe2ad7c29..8cd0c18de2de9 100644 --- a/llama.cpp +++ b/llama.cpp @@ -267,7 +267,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_DBRX, "dbrx" }, { LLM_ARCH_OLMO, "olmo" }, - { LLM_ARCH_ARCTIC, "arctic" }, + { LLM_ARCH_ARCTIC, "arctic" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -10941,7 +10941,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } if (il == n_layer - 1) { From 5b2be25d9be413d953960a99cbb64f3af49b3865 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Thu, 16 May 2024 20:35:04 +0200 Subject: [PATCH 10/15] gguf-py : Moved non-conflicting block mappings from architecture-specific ARCTIC mappigs to general mappings. 
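
In other words, ARCTIC now reuses the generic block mappings for every name that does not
clash, and only FFN_NORM and FFN_NORM_EXP keep architecture-specific entries, merged in with
a plain dict.update(). A toy illustration of that override mechanism (keys and name lists
shortened, not the real tables):

    # Toy sketch: an arch-specific mapping overriding the generic one,
    # as done via self.block_mappings_cfg.update(...) in TensorNameMap.
    generic = {
        "FFN_NORM": ("model.layers.{bid}.post_attention_layernorm",),
        "FFN_GATE": ("model.layers.{bid}.mlp.gate_proj",
                     "model.layers.{bid}.residual_mlp.w1"),
    }
    arctic_overrides = {
        "FFN_NORM":     ("model.layers.{bid}.residual_layernorm",),
        "FFN_NORM_EXP": ("model.layers.{bid}.post_attention_layernorm",),
    }

    generic.update(arctic_overrides)   # ARCTIC entries win on conflict
    assert generic["FFN_NORM"] == ("model.layers.{bid}.residual_layernorm",)
    assert "model.layers.{bid}.residual_mlp.w1" in generic["FFN_GATE"]
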
--- gguf-py/gguf/tensor_mapping.py | 55 ++++------------------------------ 1 file changed, 5 insertions(+), 50 deletions(-) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 126e6d3802ed6..8b1b21d78bb09 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -244,6 +244,7 @@ class TensorNameMap: "encoder.layers.{bid}.mlp.fc11", # nomic-bert "model.layers.{bid}.mlp.c_fc", # starcoder2 "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2 + "model.layers.{bid}.residual_mlp.w3", # arctic ), MODEL_TENSOR.FFN_UP_EXP: ( @@ -272,6 +273,7 @@ class TensorNameMap: "encoder.layers.{bid}.mlp.fc12", # nomic-bert "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 "transformer.h.{bid}.mlp.linear_1", # refact + "model.layers.{bid}.residual_mlp.w1", # arctic ), MODEL_TENSOR.FFN_GATE_EXP: ( @@ -306,6 +308,7 @@ class TensorNameMap: "encoder.layers.{bid}.mlp.fc2", # nomic-bert "model.layers.{bid}.mlp.c_proj", # starcoder2 "encoder.layer.{bid}.mlp.wo", # jina-bert-v2 + "model.layers.{bid}.residual_mlp.w2", # arctic ), MODEL_TENSOR.FFN_DOWN_EXP: ( @@ -380,60 +383,14 @@ class TensorNameMap: "model.layers.{bid}.out_proj", "backbone.layers.{bid}.mixer.out_proj", ), - } # architecture-specific block mappings arch_block_mappings_cfg: dict[MODEL_ARCH, dict[MODEL_TENSOR, tuple[str, ...]]] = { MODEL_ARCH.ARCTIC: { - MODEL_TENSOR.TOKEN_EMBD: ( - "model.embed_tokens", - ), - MODEL_TENSOR.OUTPUT_NORM: ( - "model.norm", - ), - MODEL_TENSOR.OUTPUT: ( - "lm_head", - ), - MODEL_TENSOR.ATTN_NORM: ( - "model.layers.{bid}.input_layernorm", - ), - MODEL_TENSOR.ATTN_Q: ( - "model.layers.{bid}.self_attn.q_proj", - ), - MODEL_TENSOR.ATTN_K: ( - "model.layers.{bid}.self_attn.k_proj", - ), - MODEL_TENSOR.ATTN_V: ( - "model.layers.{bid}.self_attn.v_proj", - ), - MODEL_TENSOR.ATTN_OUT: ( - "model.layers.{bid}.self_attn.o_proj", - ), - MODEL_TENSOR.FFN_GATE_INP: ( - "model.layers.{bid}.block_sparse_moe.gate", - ), MODEL_TENSOR.FFN_NORM: ( "model.layers.{bid}.residual_layernorm", ), - MODEL_TENSOR.FFN_GATE: ( - "model.layers.{bid}.residual_mlp.w1", - ), - MODEL_TENSOR.FFN_DOWN: ( - "model.layers.{bid}.residual_mlp.w2", - ), - MODEL_TENSOR.FFN_UP: ( - "model.layers.{bid}.residual_mlp.w3", - ), - MODEL_TENSOR.FFN_GATE_EXP: ( - "layers.{bid}.feed_forward.experts.w1", - ), - MODEL_TENSOR.FFN_DOWN_EXP: ( - "layers.{bid}.feed_forward.experts.w2", - ), - MODEL_TENSOR.FFN_UP_EXP: ( - "layers.{bid}.feed_forward.experts.w3", - ), MODEL_TENSOR.FFN_NORM_EXP: ( "model.layers.{bid}.post_attention_layernorm", ), @@ -452,11 +409,9 @@ def __init__(self, arch: MODEL_ARCH, n_blocks: int): for key in keys: self.mapping[key] = (tensor, tensor_name) if arch in self.arch_block_mappings_cfg: - block_mappings = self.arch_block_mappings_cfg[arch] - else: - block_mappings = self.block_mappings_cfg + self.block_mappings_cfg.update(self.arch_block_mappings_cfg[arch]) for bid in range(n_blocks): - for tensor, keys in block_mappings.items(): + for tensor, keys in self.block_mappings_cfg.items(): if tensor not in MODEL_TENSORS[arch]: continue # TODO: make this configurable From 5553226f41599937be1546bbdcdcfeda8b0d5d3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Fri, 17 May 2024 09:34:15 +0200 Subject: [PATCH 11/15] Reordered tensors for visual consistency. 
--- gguf-py/gguf/constants.py | 6 +++--- llama.cpp | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 1ae568a37595f..cb932eeaa1f62 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -166,6 +166,7 @@ class MODEL_TENSOR(IntEnum): FFN_DOWN = auto() FFN_UP = auto() FFN_ACT = auto() + FFN_NORM_EXP = auto() FFN_GATE_EXP = auto() FFN_DOWN_EXP = auto() FFN_UP_EXP = auto() @@ -182,7 +183,6 @@ class MODEL_TENSOR(IntEnum): SSM_A = auto() SSM_D = auto() SSM_OUT = auto() - FFN_NORM_EXP = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -251,6 +251,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN_SHEXP: "blk.{bid}.ffn_down_shexp", MODEL_TENSOR.FFN_UP_SHEXP: "blk.{bid}.ffn_up_shexp", MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn", + MODEL_TENSOR.FFN_NORM_EXP: "blk.{bid}.ffn_norm_exps", MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps", MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps", MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps", @@ -262,7 +263,6 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", - MODEL_TENSOR.FFN_NORM_EXP: "blk.{bid}.ffn_norm_exps", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -763,10 +763,10 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_GATE, MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_NORM_EXP, MODEL_TENSOR.FFN_GATE_EXP, MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, - MODEL_TENSOR.FFN_NORM_EXP, ], # TODO } diff --git a/llama.cpp b/llama.cpp index 8cd0c18de2de9..5030c33ee81f7 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1071,10 +1071,10 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_NORM_EXPS, "blk.%d.ffn_norm_exps" }, { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, - { LLM_TENSOR_FFN_NORM_EXPS, "blk.%d.ffn_norm_exps" }, }, }, { From f93acb5e1f68b5b7761e07c9a7833d8902af3b62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Fri, 17 May 2024 12:13:33 +0200 Subject: [PATCH 12/15] llama : Removed usage of bias tensors in LLM_ARCH_ARCTIC, as they are not present in released models. 
--- llama.cpp | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/llama.cpp b/llama.cpp index 5030c33ee81f7..4c94304028fdd 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6131,12 +6131,6 @@ static bool llm_load_tensors( layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - // optional bias tensors - layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, false); - layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, false); - layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, false); - layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, false); - layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd}); @@ -10906,24 +10900,12 @@ struct llm_build_context { // compute Q and K and RoPE them struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } Qcur = ggml_rope_custom( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, @@ -10940,7 +10922,7 @@ struct llm_build_context { cb(Kcur, "Kcur", il); cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, NULL, Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); } From b53fd2956c6187fbcb03a23865c5c2a53f782883 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Fri, 17 May 2024 12:45:30 +0200 Subject: [PATCH 13/15] Reordered tensors for visual consistency. --- llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index 4c94304028fdd..e3ed625017dae 100644 --- a/llama.cpp +++ b/llama.cpp @@ -462,10 +462,10 @@ enum llm_tensor { LLM_TENSOR_FFN_DOWN_EXP, // split experts for backward compatibility LLM_TENSOR_FFN_GATE_EXP, LLM_TENSOR_FFN_UP_EXP, + LLM_TENSOR_FFN_NORM_EXPS, LLM_TENSOR_FFN_DOWN_EXPS, // merged experts LLM_TENSOR_FFN_GATE_EXPS, LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_NORM_EXPS, LLM_TENSOR_FFN_DOWN_SHEXP, LLM_TENSOR_FFN_GATE_SHEXP, LLM_TENSOR_FFN_UP_SHEXP, From a1a5508d67a7c38905e230ee0ffcf088a8e7b477 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Wed, 22 May 2024 15:52:10 +0200 Subject: [PATCH 14/15] llama : Replaced obsolete ggml_rope_custom() calls with ggml_rope_ext(). 
--- llama.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/llama.cpp b/llama.cpp index d5f3460432cb3..452cf4647cb98 100644 --- a/llama.cpp +++ b/llama.cpp @@ -10707,15 +10707,15 @@ struct llm_build_context { struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); cb(Vcur, "Vcur", il); - Qcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + Qcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); cb(Qcur, "Qcur", il); - Kcur = ggml_rope_custom( - ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + Kcur = ggml_rope_ext( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow ); From 602c80d918e609f8bd5120fcd346242ed2da5f74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?= Date: Fri, 24 May 2024 12:44:13 +0200 Subject: [PATCH 15/15] llama : fix whitespace formatting --- llama.cpp | 82 +++++++++++++++++++++++++++---------------------------- 1 file changed, 41 insertions(+), 41 deletions(-) diff --git a/llama.cpp b/llama.cpp index 7009e939dbadf..3c9fe15bb4596 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3808,48 +3808,48 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { static const char * llama_model_type_name(e_model type) { switch (type) { - case MODEL_14M: return "14M"; - case MODEL_17M: return "17M"; - case MODEL_22M: return "22M"; - case MODEL_33M: return "33M"; - case MODEL_70M: return "70M"; - case MODEL_109M: return "109M"; - case MODEL_137M: return "137M"; - case MODEL_160M: return "160M"; - case MODEL_335M: return "335M"; - case MODEL_410M: return "410M"; - case MODEL_0_5B: return "0.5B"; - case MODEL_1B: return "1B"; - case MODEL_1_4B: return "1.4B"; - case MODEL_2B: return "2B"; - case MODEL_2_8B: return "2.8B"; - case MODEL_3B: return "3B"; - case MODEL_4B: return "4B"; - case MODEL_6_9B: return "6.9B"; - case MODEL_7B: return "7B"; - case MODEL_8B: return "8B"; - case MODEL_12B: return "12B"; - case MODEL_13B: return "13B"; - case MODEL_14B: return "14B"; - case MODEL_15B: return "15B"; - case MODEL_20B: return "20B"; - case MODEL_30B: return "30B"; - case MODEL_34B: return "34B"; - case MODEL_35B: return "35B"; - case MODEL_40B: return "40B"; - case MODEL_65B: return "65B"; - case MODEL_70B: return "70B"; - case MODEL_314B: return "314B"; - case MODEL_SMALL: return "0.1B"; - case MODEL_MEDIUM: return "0.4B"; - case MODEL_LARGE: return "0.8B"; - case MODEL_XL: return "1.5B"; - case MODEL_A2_7B: return "A2.7B"; - case MODEL_8x7B: return "8x7B"; - case MODEL_8x22B: return "8x22B"; - case MODEL_16x12B: return "16x12B"; + case MODEL_14M: return "14M"; + case MODEL_17M: return "17M"; + case MODEL_22M: return "22M"; + case MODEL_33M: return "33M"; + case MODEL_70M: return "70M"; + case MODEL_109M: return "109M"; + case MODEL_137M: return "137M"; + case MODEL_160M: return "160M"; + case MODEL_335M: return "335M"; + case MODEL_410M: return "410M"; + case MODEL_0_5B: return "0.5B"; + case MODEL_1B: return "1B"; + case MODEL_1_4B: return "1.4B"; + case MODEL_2B: return "2B"; + case MODEL_2_8B: return "2.8B"; + case MODEL_3B: return "3B"; + case MODEL_4B: return "4B"; + case MODEL_6_9B: return "6.9B"; + case MODEL_7B: return "7B"; + case MODEL_8B: 
return "8B"; + case MODEL_12B: return "12B"; + case MODEL_13B: return "13B"; + case MODEL_14B: return "14B"; + case MODEL_15B: return "15B"; + case MODEL_20B: return "20B"; + case MODEL_30B: return "30B"; + case MODEL_34B: return "34B"; + case MODEL_35B: return "35B"; + case MODEL_40B: return "40B"; + case MODEL_65B: return "65B"; + case MODEL_70B: return "70B"; + case MODEL_314B: return "314B"; + case MODEL_SMALL: return "0.1B"; + case MODEL_MEDIUM: return "0.4B"; + case MODEL_LARGE: return "0.8B"; + case MODEL_XL: return "1.5B"; + case MODEL_A2_7B: return "A2.7B"; + case MODEL_8x7B: return "8x7B"; + case MODEL_8x22B: return "8x22B"; + case MODEL_16x12B: return "16x12B"; case MODEL_10B_128x3_66B: return "10B+128x3.66B"; - default: return "?B"; + default: return "?B"; } }