convert_params.py
from jax import Array
import torch
import torch.nn as tnn
from transformers import LlamaForCausalLM, LlamaModel as LlamaModelPt
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer

from ..array_utils import pt2jax
from ..llama import Llama, LlamaModel, ModelConfig
from ..llama.attention import Attention
from ..llama.decoder_block import DecoderBlock
from ..tree_utils import stack_leaves

def convert_proj(x: tnn.Linear) -> Array:
    # nn.Linear stores its weight as (out_features, in_features); transpose to
    # the (in_features, out_features) layout used on the JAX side.
    return pt2jax(x.weight.T)
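
# `pt2jax` is provided by ..array_utils. A minimal sketch of what it is assumed
# to do here (move a torch.Tensor over to a jax.Array) might look like the helper
# below; the repository's real implementation may differ (e.g. dtype handling).
def _pt2jax_sketch(t: torch.Tensor) -> Array:
    import jax.numpy as jnp
    return jnp.asarray(t.detach().cpu().numpy())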

def convert_q_proj(x: tnn.Linear, *, model_config: ModelConfig) -> Array:
    # Query projection, split per grouped-query head: (d_model, n_rep_kv, n_heads_kv, d_k).
    return pt2jax(x.weight.T.reshape(model_config.d_model, model_config.n_rep_kv, model_config.n_heads_kv, model_config.d_k))

def convert_k_proj(x: tnn.Linear, *, model_config: ModelConfig) -> Array:
    # Key projection, one per KV head: (d_model, n_heads_kv, d_k).
    return pt2jax(x.weight.T.reshape(model_config.d_model, model_config.n_heads_kv, model_config.d_k))

def convert_v_proj(x: tnn.Linear, *, model_config: ModelConfig) -> Array:
    # Value projection, one per KV head: (d_model, n_heads_kv, d_v).
    return pt2jax(x.weight.T.reshape(model_config.d_model, model_config.n_heads_kv, model_config.d_v))

def convert_out_proj(x: tnn.Linear, *, model_config: ModelConfig) -> Array:
    # Output projection, mapping per-head values back to the model dimension:
    # (n_rep_kv, n_heads_kv, d_v, d_model).
    return pt2jax(x.weight.T.reshape(model_config.n_rep_kv, model_config.n_heads_kv, model_config.d_v, model_config.d_model))
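
# Sanity-check sketch for the query reshape above, using toy dimensions that are
# assumptions for illustration only (not a real Llama configuration).
def _q_proj_shape_example() -> None:
    d_model, n_rep_kv, n_heads_kv, d_k = 8, 2, 2, 2
    q_proj = tnn.Linear(d_model, n_rep_kv * n_heads_kv * d_k, bias=False)
    # weight is (out_features, in_features); after the transpose and reshape the
    # head axes become explicit.
    w = q_proj.weight.T.reshape(d_model, n_rep_kv, n_heads_kv, d_k)
    assert w.shape == (d_model, n_rep_kv, n_heads_kv, d_k)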

def convert_attention(x: LlamaAttention, *, model_config: ModelConfig) -> Attention:
    q_proj = convert_q_proj(x.q_proj, model_config=model_config)
    k_proj = convert_k_proj(x.k_proj, model_config=model_config)
    v_proj = convert_v_proj(x.v_proj, model_config=model_config)
    out_proj = convert_out_proj(x.o_proj, model_config=model_config)
    return Attention(q_proj=q_proj, k_proj=k_proj, v_proj=v_proj, out_proj=out_proj)

def convert_decoder_block(x: LlamaDecoderLayer, *, model_config: ModelConfig) -> DecoderBlock:
    input_norm = pt2jax(x.input_layernorm.weight)
    attention = convert_attention(x.self_attn, model_config=model_config)
    post_attn_norm = pt2jax(x.post_attention_layernorm.weight)
    gate_proj = convert_proj(x.mlp.gate_proj)
    up_proj = convert_proj(x.mlp.up_proj)
    down_proj = convert_proj(x.mlp.down_proj)
    return DecoderBlock(input_norm=input_norm, attention=attention, post_attn_norm=post_attn_norm, gate_proj=gate_proj, up_proj=up_proj, down_proj=down_proj)

def convert_llama_model(model: LlamaModelPt, *, model_config: ModelConfig) -> LlamaModel:
    embedding = pt2jax(model.embed_tokens.weight)
    # Convert every decoder layer, then stack the per-layer pytrees so each leaf
    # gains a leading n_layers axis.
    decoder = stack_leaves([convert_decoder_block(model.layers[i], model_config=model_config) for i in range(model_config.n_layers)])
    norm = pt2jax(model.norm.weight)
    return LlamaModel(embedding=embedding, decoder=decoder, norm=norm)
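
# `stack_leaves` comes from ..tree_utils; a minimal version consistent with its use
# above (an assumption about its behaviour, not the repository's actual code):
def _stack_leaves_sketch(pytrees: list):
    import jax
    import jax.numpy as jnp
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *pytrees)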

def convert_llama(model_pt: LlamaForCausalLM, *, model_config: ModelConfig) -> Llama:
    # no_grad avoids building an autograd graph while reading out the weights.
    with torch.no_grad():
        model = convert_llama_model(model_pt.model, model_config=model_config)
        lm_head = convert_proj(model_pt.lm_head)
    return Llama(model=model, lm_head=lm_head)
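
# Hedged usage sketch. The checkpoint id below is only an example, and `model_config`
# must be a ModelConfig matching that checkpoint; how the config is constructed lives
# elsewhere in the repository.
#
#     model_pt = LlamaForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf')
#     params = convert_llama(model_pt, model_config=model_config)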