Skip to content

Commit

Permalink
llama: Support MiniCPM-1B (with & w/o longrope) (#10559)
Browse files Browse the repository at this point in the history
  • Loading branch information
JFLFY2255 authored Dec 4, 2024
1 parent 2759916 commit 8d0cfd5
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 182 deletions.
55 changes: 33 additions & 22 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1831,29 +1831,40 @@ class MiniCPMModel(Model):
model_arch = gguf.MODEL_ARCH.MINICPM

def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
self.gguf_writer.add_file_type(self.ftype)
super().set_gguf_parameters()
embedding_scale = float(self.hparams["scale_emb"])
self.gguf_writer.add_embedding_scale(embedding_scale)
logger.info(f"gguf: (minicpm) embedding_scale = {embedding_scale}")
residual_scale = self.hparams["scale_depth"] / self.hparams["num_hidden_layers"] ** 0.5
self.gguf_writer.add_residual_scale(residual_scale)
logger.info(f"gguf: (minicpm) residual_scale = {residual_scale}")
logit_scale = self.hparams["hidden_size"] / self.hparams["dim_model_base"]
self.gguf_writer.add_logit_scale(logit_scale)
logger.info(f"gguf: (minicpm) logit_scale = {logit_scale}")
if self.hparams.get("rope_scaling") is not None:
if self.hparams["rope_scaling"].get("type") == "longrope":
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LONGROPE)
logger.info(f"gguf: (minicpm) rope_scaling_type = {gguf.RopeScalingType.LONGROPE}")

def set_vocab(self):
self._set_vocab_llama_hf()
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
rope_dims = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]

def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
if n_kv_head is not None and n_head != n_kv_head:
n_head //= n_kv_head
rope_scaling = self.find_hparam(['rope_scaling'], True)
if rope_scaling is not None:
long_factors = rope_scaling.get('long_factor', None)
short_factors = rope_scaling.get('short_factor', None)

return (
weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
.swapaxes(1, 2)
.reshape(weights.shape)
)
if long_factors is None or short_factors is None:
raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')

if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')

yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32))
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32))

def set_vocab(self):
self._set_vocab_sentencepiece()

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
Expand All @@ -1863,9 +1874,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

# HF models permute some of the tensors, so we need to undo that
if name.endswith(("q_proj.weight")):
data_torch = self._reverse_hf_permute(data_torch, n_head, n_head)
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
if name.endswith(("k_proj.weight")):
data_torch = self._reverse_hf_permute(data_torch, n_head, n_kv_head)
data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)

return [(self.map_tensor_name(name), data_torch)]

Expand Down
9 changes: 6 additions & 3 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,8 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ROPE_FACTORS_LONG,
MODEL_TENSOR.ROPE_FACTORS_SHORT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
Expand Down Expand Up @@ -1388,9 +1390,10 @@ class TokenType(IntEnum):


class RopeScalingType(Enum):
NONE = 'none'
LINEAR = 'linear'
YARN = 'yarn'
NONE = 'none'
LINEAR = 'linear'
YARN = 'yarn'
LONGROPE = 'longrope'


class PoolingType(IntEnum):
Expand Down
3 changes: 2 additions & 1 deletion include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ extern "C" {
LLAMA_ROPE_SCALING_TYPE_NONE = 0,
LLAMA_ROPE_SCALING_TYPE_LINEAR = 1,
LLAMA_ROPE_SCALING_TYPE_YARN = 2,
LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_YARN,
LLAMA_ROPE_SCALING_TYPE_LONGROPE = 3,
LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_LONGROPE,
};

enum llama_pooling_type {
Expand Down
175 changes: 19 additions & 156 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,8 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" },
{ LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
Expand Down Expand Up @@ -1683,9 +1685,10 @@ struct LLM_TN {
//

static const std::map<llama_rope_scaling_type, const char *> LLAMA_ROPE_SCALING_TYPES = {
{ LLAMA_ROPE_SCALING_TYPE_NONE, "none" },
{ LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" },
{ LLAMA_ROPE_SCALING_TYPE_YARN, "yarn" },
{ LLAMA_ROPE_SCALING_TYPE_NONE, "none" },
{ LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" },
{ LLAMA_ROPE_SCALING_TYPE_YARN, "yarn" },
{ LLAMA_ROPE_SCALING_TYPE_LONGROPE, "longrope" },
};

static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) {
Expand Down Expand Up @@ -5580,8 +5583,12 @@ static void llm_load_hparams(
case LLM_ARCH_MINICPM:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale);
ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale);
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);

switch (hparams.n_layer) {
case 52: model.type = e_model::MODEL_1B; break;
case 40: model.type = e_model::MODEL_2B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
Expand Down Expand Up @@ -7065,7 +7072,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
}

if (model.arch == LLM_ARCH_GRANITE || model.arch == LLM_ARCH_GRANITE_MOE) {
if (model.arch == LLM_ARCH_MINICPM || model.arch == LLM_ARCH_GRANITE || model.arch == LLM_ARCH_GRANITE_MOE) {
LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale);
LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
Expand Down Expand Up @@ -7690,7 +7697,13 @@ static bool llm_load_tensors(

layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);

layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {
layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
}
else {
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
}

if (n_expert == 0) {
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
Expand Down Expand Up @@ -13497,153 +13510,6 @@ struct llm_build_context {
return gf;
}

// ref: https://arxiv.org/abs/2203.03466
// https://github.com/ggerganov/llama.cpp/issues/5276#issuecomment-1925774738
// based on the original build_llama() function
struct ggml_cgraph * build_minicpm() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);

const int64_t n_embd = hparams.n_embd;
//TODO: if the model varies, these parameters need to be read from the model
const int64_t n_embd_base = 256;
const float scale_embd = 12.0f;
const float scale_depth = 1.4f;

struct ggml_tensor * cur;
struct ggml_tensor * inpL;

inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);

// scale the input embeddings
inpL = ggml_scale(ctx0, inpL, scale_embd);
cb(inpL, "inp_scaled", -1);

// inp_pos - contains the positions
struct ggml_tensor * inp_pos = build_inp_pos();

// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();

for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;

// norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);

// self-attention
{
// compute Q and K and RoPE them
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}

struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}

struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
}

Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);

Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);

cur = llm_build_kv(ctx0, lctx, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}

if (il == n_layer - 1) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}

// scale_res - scale the hidden states for residual connection
const float scale_res = scale_depth/sqrtf(float(n_layer));
cur = ggml_scale(ctx0, cur, scale_res);
cb(cur, "hidden_scaled", -1);

struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);

// feed-forward network
{
cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);

cur = llm_build_ffn(ctx0, lctx, cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
}

// scale the hidden states for residual connection
cur = ggml_scale(ctx0, cur, scale_res);
cb(cur, "hidden_scaled_ffn", -1);

cur = ggml_add(ctx0, cur, ffn_inp);
cur = lctx.cvec.apply_to(ctx0, cur, il);
cb(cur, "l_out", il);

// input for next layer
inpL = cur;
}

cur = inpL;

cur = llm_build_norm(ctx0, cur, hparams,
model.output_norm, NULL,
LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);

// lm_head scaling
const float scale_lmhead = float(n_embd_base)/float(n_embd);
cur = ggml_scale(ctx0, cur, scale_lmhead);
cb(cur, "lmhead_scaling", -1);

// lm_head
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
cb(cur, "result_output", -1);

ggml_build_forward_expand(gf, cur);

return gf;
}

struct ggml_cgraph * build_minicpm3() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

Expand Down Expand Up @@ -16742,6 +16608,7 @@ static struct ggml_cgraph * llama_build_graph(

switch (model.arch) {
case LLM_ARCH_LLAMA:
case LLM_ARCH_MINICPM:
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
{
Expand Down Expand Up @@ -16825,10 +16692,6 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_internlm2();
} break;
case LLM_ARCH_MINICPM:
{
result = llm.build_minicpm();
} break;
case LLM_ARCH_MINICPM3:
{
result = llm.build_minicpm3();
Expand Down

0 comments on commit 8d0cfd5

Please sign in to comment.