Skip to content

Commit

Permalink
fixed copy_weights_phi for phi-4 series
Browse files Browse the repository at this point in the history
  • Loading branch information
ysjprojects committed Mar 1, 2025
1 parent 646e372 commit f6d2fd2
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions litgpt/scripts/convert_lit_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def copy_weights_phi(
"lm_head.weight": "lm_head.weight",
"lm_head.bias": "lm_head.bias",
}
- if config.name.startswith(("Phi-3", "phi-4")):
+ if config.name.startswith(("Phi-3", "phi-4", "Phi-4")):
weight_map.update(
{
"transformer.h.{}.attn.qkv.weight": "model.layers.{}.self_attn.qkv_proj.weight",
Expand All @@ -249,10 +249,12 @@ def copy_weights_phi(
gate_up_proj_weights = defaultdict(dict)

for from_name, param in lit_weights.items():
+ if from_name == "lm_head.weight" and config.name.startswith("Phi-4"):
+     continue
name_template, layer_idx = layer_template(from_name)
param = load_param(param, from_name, None)
if from_name.endswith((".attn.qkv.weight", ".attn.qkv.bias")):
- if config.name.startswith("Phi-3"):
+ if config.name.startswith(("Phi-3", "phi-4", "Phi-4")):
to_names = (weight_map[name_template].format(layer_idx),)
params = (param,)
else:
Expand Down Expand Up @@ -282,7 +284,7 @@ def copy_weights_phi(
param = saver.store_early(param)
state_dict[to_name] = param

- if config.name.startswith("Phi-3"):
+ if config.name.startswith(("Phi-3", "phi-4", "Phi-4")):
for layer_idx in list(gate_up_proj_weights):
fc_1_weight = gate_up_proj_weights[layer_idx]["fc_1"]
fc_2_weight = gate_up_proj_weights[layer_idx]["fc_2"]
Expand Down

0 comments on commit f6d2fd2

Please sign in to comment.