minor fix for llama ckpt conversion script (#7387)
* minor fix for llama ckpt conversion script

Signed-off-by: Jason Wang <[email protected]>

* Update Jenkinsfile

Signed-off-by: Jason Wang <[email protected]>

* remove fast_swiglu configuration

Signed-off-by: Jason Wang <[email protected]>

---------

Signed-off-by: Jason Wang <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
2 people authored and yaoyu-33 committed Sep 18, 2023
1 parent 6021db8 commit 4582521
Showing 2 changed files with 11 additions and 15 deletions.
1 change: 0 additions & 1 deletion Jenkinsfile
@@ -120,7 +120,6 @@ pipeline {
 sh 'CUDA_VISIBLE_DEVICES=0 python scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py \
 --in-file=/home/TestData/nlp/megatron_llama/llama-ci-hf \
 --out-file=/home/TestData/nlp/megatron_llama/ci.nemo \
---fast-swiglu \
 --precision=16'
 sh 'rm -f /home/TestData/nlp/megatron_llama/ci.nemo'
 }
25 changes: 11 additions & 14 deletions scripts/nlp_language_modeling/convert_hf_llama_to_nemo.py
@@ -49,7 +49,6 @@ def get_args():
         "--in-file", type=str, default=None, required=True, help="Path to Huggingface LLaMA checkpoints",
     )
     parser.add_argument("--out-file", type=str, default=None, required=True, help="Path to output .nemo file.")
-    parser.add_argument("--fast-swiglu", action="store_true", help="Enable fast swiglu by combining gate and up gemm")
     parser.add_argument("--precision", type=str, default="32", help="Model precision")
     args = parser.parse_args()
     return args
@@ -109,7 +108,7 @@ def load_config(args, llama_config):
     if 'num_key_value_heads' in llama_config:
         nemo_config.num_query_groups = llama_config['num_key_value_heads']
     nemo_config.use_cpu_initialization = True
-    nemo_config.activation = 'fast-swiglu' if args.fast_swiglu else 'swiglu'
+    nemo_config.activation = 'fast-swiglu'
     nemo_config.tokenizer.model = llama_config['tokenizer_model']
     if llama_config['rope_scaling'] is not None:
         if llama_config['rope_scaling']['type'] == 'linear':
@@ -172,7 +171,7 @@ def convert(args):
         dtype = torch.float32
     elif precision in [16, "16", "16-mixed"]:
         dtype = torch.float16
-    elif precision == ["bf16", "bf16-mixed"]:
+    elif precision in ["bf16", "bf16-mixed"]:
         dtype = torch.bfloat16
     else:
         dtype = torch.float32  # fallback
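The one-character fix above addresses a real bug: `precision == ["bf16", "bf16-mixed"]` compares a string against a list, which is always `False`, so a bf16 request silently fell through to the `float32` fallback. A minimal standalone illustration, not part of the converter itself:

```python
precision = "bf16"

# Old check: a string is never equal to a list, so this branch was dead code.
print(precision == ["bf16", "bf16-mixed"])  # False

# Fixed check: a membership test matches either bf16 spelling.
print(precision in ["bf16", "bf16-mixed"])  # True -> dtype = torch.bfloat16
```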
@@ -254,19 +253,12 @@ def convert(args):
         # MLP
         mlp_down_weight = model.state_dict()[f'model.layers.{l}.mlp.gate_proj.weight']
         mlp_gate_weight = model.state_dict()[f'model.layers.{l}.mlp.up_proj.weight']
-        if args.fast_swiglu:
-            if mcore_gpt:
-                mlp_down_base_name = f'model.decoder.layers.{l}.mlp.linear_fc1.weight'
-            else:
-                mlp_down_base_name = f'model.language_model.encoder.layers.{l}.mlp.dense_h_to_4h.weight'
-            mlp_down_weight = torch.cat((mlp_down_weight, mlp_gate_weight), axis=0)
-            checkpoint['state_dict'][mlp_down_base_name] = param_to_weights(mlp_down_weight)
+        if mcore_gpt:
+            mlp_down_base_name = f'model.decoder.layers.{l}.mlp.linear_fc1.weight'
         else:
             mlp_down_base_name = f'model.language_model.encoder.layers.{l}.mlp.dense_h_to_4h.weight'
-            checkpoint['state_dict'][mlp_down_base_name] = param_to_weights(mlp_down_weight)
-
-            mlp_gate_base_name = f'model.language_model.encoder.layers.{l}.mlp.dense_h_to_4h_2.weight'
-            checkpoint['state_dict'][mlp_gate_base_name] = param_to_weights(mlp_gate_weight)
+        mlp_down_weight = torch.cat((mlp_down_weight, mlp_gate_weight), axis=0)
+        checkpoint['state_dict'][mlp_down_base_name] = param_to_weights(mlp_down_weight)
 
         mlp_up_weight = model.state_dict()[f'model.layers.{l}.mlp.down_proj.weight']
         if mcore_gpt:
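The hunk above is the heart of the change: instead of optionally writing `gate_proj` and `up_proj` as two separate NeMo tensors, the converter now always concatenates them into the single fused weight (`linear_fc1` for mcore, `dense_h_to_4h` otherwise) that the hard-coded `'fast-swiglu'` activation expects. A minimal sketch of why the concatenation is numerically equivalent, assuming fast-swiglu runs one fused GEMM and splits the result into gate and up halves; the tensor names below are illustrative, not from the repo:

```python
import torch
import torch.nn.functional as F

hidden, ffn = 8, 16
gate_w = torch.randn(ffn, hidden)  # stands in for mlp.gate_proj.weight
up_w = torch.randn(ffn, hidden)    # stands in for mlp.up_proj.weight

# Same concatenation as the converter: gate rows first, then up rows.
fc1_w = torch.cat((gate_w, up_w), dim=0)

x = torch.randn(2, hidden)
gate, up = (x @ fc1_w.T).chunk(2, dim=-1)  # one GEMM, split into halves
fused = F.silu(gate) * up                  # SwiGLU: silu(gate) * up

# Matches the two-GEMM formulation exactly.
ref = F.silu(x @ gate_w.T) * (x @ up_w.T)
assert torch.allclose(fused, ref, atol=1e-5)
```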
@@ -310,6 +302,11 @@ def convert(args):
 
     del model
 
+    if nemo_config.get('megatron_amp_O2', False):
+        keys = list(checkpoint['state_dict'].keys())
+        for key in keys:
+            checkpoint['state_dict'][key.replace('model.', 'model.module.', 1)] = checkpoint['state_dict'].pop(key)
+
     model = load_model(MegatronGPTModel, checkpoint, strict=False, trainer=trainer)
 
     model._save_restore_connector = NLPSaveRestoreConnector()
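The new `megatron_amp_O2` block remaps every state-dict key from the `model.` prefix to `model.module.`, matching the extra module wrapper that NeMo's O2 mixed precision adds around the model. The `count=1` argument to `str.replace` matters: only the leading prefix is rewritten, even if `model.` happens to recur later in a key. A standalone illustration with a hypothetical key:

```python
key = 'model.decoder.layers.0.mlp.linear_fc1.weight'

# Only the first occurrence of 'model.' is replaced, so the key gains
# exactly one 'module.' segment after the leading prefix.
print(key.replace('model.', 'model.module.', 1))
# -> model.module.decoder.layers.0.mlp.linear_fc1.weight
```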
