tried setting pos weight
bryce13950 committed Jul 5, 2024
1 parent fad4f34 commit 2f44d5b
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions transformer_lens/pretrained/weight_conversions/baichuan.py
@@ -59,6 +59,7 @@ def convert_baichuan_weights(baichuan, cfg: HookedTransformerConfig):
state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=W_O.dtype)

state_dict["ln_final.w"] = baichuan.model.norm.weight
+state_dict["pos_embed.W_pos"] = baichuan.model.transformer.wpe.weight
state_dict["unembed.W_U"] = baichuan.lm_head.weight.T
state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=W_O.dtype)
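The one-line addition above follows the conversion function's general pattern: parameters from the source model are copied into TransformerLens-style state-dict keys. Below is a minimal, hypothetical sketch of that pattern; the `SimpleNamespace` stand-in and its attribute names (`norm_weight`, `wpe_weight`, `lm_head_weight`) are illustrative only and are not the real Baichuan or TransformerLens API.

```python
from types import SimpleNamespace

def convert_weights_sketch(model, d_vocab):
    """Copy source-model parameters into TransformerLens-style keys (sketch)."""
    state_dict = {}
    state_dict["ln_final.w"] = model.norm_weight
    # Analogous to the line this commit adds: copy the positional embedding.
    state_dict["pos_embed.W_pos"] = model.wpe_weight
    # The real code transposes the LM head weight (lm_head.weight.T).
    state_dict["unembed.W_U"] = model.lm_head_weight
    # Zero unembedding bias, as in the diff above.
    state_dict["unembed.b_U"] = [0.0] * d_vocab
    return state_dict

# Hypothetical stand-in for a loaded source model.
model = SimpleNamespace(
    norm_weight=[1.0, 1.0],
    wpe_weight=[[0.1, 0.2]],
    lm_head_weight=[[0.3], [0.4]],
)
sd = convert_weights_sketch(model, d_vocab=3)
```

The point of the pattern is that each destination key (e.g. `pos_embed.W_pos`) names a parameter slot in the target architecture, so conversion is a series of direct assignments plus the occasional transpose or zero-fill.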
