Skip to content

Commit

Permalink
remove references to 'ckpt'
Browse files (browse the repository at this point in the history)
Signed-off-by: Chen Cui <[email protected]>
  • Loading branch information
cuichenx committed Oct 20, 2023
1 parent c6b0f53 commit 68ca6ad
Showing 1 changed file with 19 additions and 19 deletions.
38 changes: 19 additions & 19 deletions scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -34,8 +34,8 @@
<follow the readme in that script> \
--target_tensor_model_parallel_size=1 \
--target_pipeline_model_parallel_size=1
2) extract your checkpoint to a folder with
tar -xvf your_ckpt.nemo
2) extract your nemo file to a folder with
tar -xvf filename.nemo
Then, run this conversion script:
python convert_nemo_gpt_to_mcore.py \
Expand All @@ -56,7 +56,7 @@ def get_args():
return args


def get_mcore_model_from_nemo_ckpt(nemo_restore_from_path):
def get_mcore_model_from_nemo_file(nemo_restore_from_path):
model_cfg = MegatronGPTModel.restore_from(nemo_restore_from_path, return_config=True)
model_cfg.tokenizer.vocab_file = None
model_cfg.tokenizer.merge_file = None
Expand All @@ -78,7 +78,7 @@ def get_mcore_model_from_nemo_ckpt(nemo_restore_from_path):


def print_mcore_parameter_names(restore_from_path):
mcore_model = get_mcore_model_from_nemo_ckpt(restore_from_path)
mcore_model = get_mcore_model_from_nemo_file(restore_from_path)

print("*********")
print('\n'.join(sorted([k + '###' + str(v.shape) for k, v in mcore_model.named_parameters()])))
Expand Down Expand Up @@ -150,31 +150,31 @@ def load_model(model, state_dict):
return model


def convert(input_ckpt_file, output_ckpt_file, skip_if_output_exists=True):
if skip_if_output_exists and os.path.exists(output_ckpt_file):
logging.info(f"Output file already exists ({output_ckpt_file}), skipping conversion...")
def convert(input_nemo_file, output_nemo_file, skip_if_output_exists=True):
if skip_if_output_exists and os.path.exists(output_nemo_file):
logging.info(f"Output file already exists ({output_nemo_file}), skipping conversion...")
return
dummy_trainer = Trainer(devices=1, accelerator='cpu')

nemo_model = MegatronGPTModel.restore_from(input_ckpt_file, trainer=dummy_trainer)
nemo_model = MegatronGPTModel.restore_from(input_nemo_file, trainer=dummy_trainer)
nemo_tokenizer_model = nemo_model.cfg.tokenizer.model
nemo_state_dict = nemo_model.state_dict()
mcore_state_dict = OrderedDict()
for mcore_param, nemo_param in build_key_mapping(nemo_model.cfg).items():
mcore_state_dict[mcore_param] = nemo_state_dict[nemo_param]

mcore_model = get_mcore_model_from_nemo_ckpt(input_ckpt_file)
mcore_model = get_mcore_model_from_nemo_file(input_nemo_file)
mcore_model = load_model(mcore_model, mcore_state_dict)

if nemo_model.cfg.tokenizer.model is not None:
logging.info("registering artifact: tokenizer.model = " + nemo_tokenizer_model)
mcore_model.register_artifact("tokenizer.model", nemo_tokenizer_model)

mcore_model.save_to(output_ckpt_file)
logging.info(f"Done. Model saved to {output_ckpt_file}")
mcore_model.save_to(output_nemo_file)
logging.info(f"Done. Model saved to {output_nemo_file}")


def run_sanity_checks(nemo_ckpt_file, mcore_ckpt_file):
def run_sanity_checks(nemo_file, mcore_file):
cfg = OmegaConf.load(
os.path.join(
os.path.dirname(__file__),
Expand All @@ -185,8 +185,8 @@ def run_sanity_checks(nemo_ckpt_file, mcore_ckpt_file):
cfg.trainer.precision = 'bf16' # change me
dtype = torch.bfloat16
trainer = MegatronTrainerBuilder(cfg).create_trainer()
nemo_model = MegatronGPTModel.restore_from(nemo_ckpt_file, trainer=trainer).eval().to(dtype)
mcore_model = MegatronGPTModel.restore_from(mcore_ckpt_file, trainer=trainer).eval().to(dtype)
nemo_model = MegatronGPTModel.restore_from(nemo_file, trainer=trainer).eval().to(dtype)
mcore_model = MegatronGPTModel.restore_from(mcore_file, trainer=trainer).eval().to(dtype)

logging.debug("*** Mcore model restored config")
logging.debug(OmegaConf.to_yaml(mcore_model.cfg))
Expand Down Expand Up @@ -220,9 +220,9 @@ def run_sanity_checks(nemo_ckpt_file, mcore_ckpt_file):
if __name__ == '__main__':
args = get_args()

input_ckpt = args.in_file
output_ckpt = args.out_file
os.makedirs(os.path.dirname(output_ckpt), exist_ok=True)
convert(input_ckpt, output_ckpt, skip_if_output_exists=True)
input_nemo_file = args.in_file
output_nemo_file = args.out_file
os.makedirs(os.path.dirname(output_nemo_file), exist_ok=True)
convert(input_nemo_file, output_nemo_file, skip_if_output_exists=True)
torch.cuda.empty_cache()
run_sanity_checks(input_ckpt, output_ckpt)
run_sanity_checks(input_nemo_file, output_nemo_file)

0 comments on commit 68ca6ad

Please sign in to comment.