Nemo 2.0 ckpt support in TRT-LLM export (NVIDIA#10891)
* fix minor import bug

Signed-off-by: Onur Yilmaz <[email protected]>

* Add registry to register all needed classes with artifacts in nemo.lightning.io

Signed-off-by: Hemil Desai <[email protected]>

* Apply isort and black reformatting

Signed-off-by: hemildesai <[email protected]>

* Fixes

Signed-off-by: Hemil Desai <[email protected]>

* Apply isort and black reformatting

Signed-off-by: hemildesai <[email protected]>

* Fix

Signed-off-by: Hemil Desai <[email protected]>

* nemo 2.0 support in export to trt-llm

Signed-off-by: Onur Yilmaz <[email protected]>

* get mixing from main

Signed-off-by: Onur Yilmaz <[email protected]>

* Apply isort and black reformatting

Signed-off-by: oyilmaz-nvidia <[email protected]>

* fix style

Signed-off-by: Onur Yilmaz <[email protected]>

---------

Signed-off-by: Onur Yilmaz <[email protected]>
Signed-off-by: Hemil Desai <[email protected]>
Signed-off-by: hemildesai <[email protected]>
Signed-off-by: oyilmaz-nvidia <[email protected]>
Co-authored-by: Hemil Desai <[email protected]>
Co-authored-by: hemildesai <[email protected]>
Co-authored-by: oyilmaz-nvidia <[email protected]>
4 people authored and XuesongYang committed Jan 18, 2025
1 parent b5797dd commit ac16e11
Showing 3 changed files with 71 additions and 9 deletions.
nemo/export/tensorrt_llm.py (3 additions, 0 deletions)
@@ -374,8 +374,11 @@ def export(
             )
 
             tokenizer_path = os.path.join(nemo_export_dir, "tokenizer.model")
+            tokenizer_path_nemo2 = os.path.join(nemo_export_dir, "nemo_context")
             if os.path.exists(tokenizer_path):
                 shutil.copy(tokenizer_path, self.model_dir)
+            elif os.path.exists(tokenizer_path_nemo2):
+                shutil.copytree(tokenizer_path_nemo2, Path(self.model_dir) / "nemo_context")
             else:
                 self.tokenizer.save_pretrained(os.path.join(self.model_dir, 'huggingface_tokenizer'))
 
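For context, a minimal sketch of how this branch gets exercised, assuming the TensorRTLLM exporter from nemo.export (the checkpoint path and keyword names here are illustrative and may differ across NeMo versions): exporting a NeMo 2.0 checkpoint leaves a nemo_context directory next to the built engines.

from nemo.export import TensorRTLLM

# Hypothetical paths: the checkpoint is a NeMo 2.0 directory containing
# "weights/" and "context/" rather than a packed NeMo 1.0 .nemo file.
exporter = TensorRTLLM(model_dir="/tmp/trtllm_engine")
exporter.export(
    nemo_checkpoint_path="/checkpoints/llama3_8b_nemo2",
    model_type="llama",
)
# After export, /tmp/trtllm_engine/nemo_context holds the copied tokenizer context.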
nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py (3 additions, 2 deletions)
@@ -157,7 +157,8 @@ def convert_model_to_trt_llm_ckpt(
             num_kv_heads = num_attention_heads
 
     export_config = {
-        "apply_layernorm_1p": nemo_model_config.get("normalization", "") == "layernorm1p",
+        "apply_layernorm_1p": nemo_model_config.get("normalization", "") == "layernorm1p"
+        or nemo_model_config.get("layernorm_zero_centered_gamma", False),
         "tp_size": training_tp_size,
         "split_gated_activation": nemo_model_config.get("activation", "gelu")
         in ["swiglu", "geglu", "fast-swiglu", "fast-geglu"]
@@ -195,7 +196,7 @@ def handle_model_level_weights(model, tp_idx: int, pp_idx: int):
 
             val = val.to(storage_type).cpu()
             model_level_weights["transformer.vocab_embedding.weight"].append(val)
-        if has_lm_head and pp_idx == training_pp_size - 1:
+        if has_lm_head and pp_idx == training_pp_size - 1 and decoder_type != "gemma":
             val = model.get("state_dict", model)[get_layer_name("output_layer", prefix)]
             val = val.to(storage_type).cpu()
             model_level_weights["lm_head.weight"].append(val)
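To make the first change concrete, a small sketch (the config dicts below are made up for illustration): NeMo 1.0 configs mark zero-centered LayerNorm with normalization="layernorm1p", while NeMo 2.0 / Megatron-Core configs carry layernorm_zero_centered_gamma, and either now enables apply_layernorm_1p.

# Sketch of the updated flag resolution.
def uses_layernorm_1p(nemo_model_config: dict) -> bool:
    return (
        nemo_model_config.get("normalization", "") == "layernorm1p"
        or nemo_model_config.get("layernorm_zero_centered_gamma", False)
    )

nemo1_cfg = {"normalization": "layernorm1p"}          # NeMo 1.0 style
nemo2_cfg = {"layernorm_zero_centered_gamma": True}   # NeMo 2.0 / MCore style
assert uses_layernorm_1p(nemo1_cfg) and uses_layernorm_1p(nemo2_cfg)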
nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py (65 additions, 7 deletions)
@@ -17,6 +17,8 @@
 import json
 import logging
 import os
+import re
+import shutil
 from io import BytesIO
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple, Union
@@ -114,6 +116,11 @@ def load_scaling_factors(state_dict: dict, basename: str, size: int) -> Optional
     return load_scales_from_bytes(bytes_list)
 
 
+def filter_experts_extra_states(state_dict: dict):
+    pattern = r'module\.decoder\.layers\.mlp\.experts\.experts\.linear_fc\d+\._extra_state/shard_\d+\.\d+_\d+\.\d+'
+    return {k: v for k, v in state_dict.items() if not re.fullmatch(pattern, k)}
+
+
 def standarize_distributed_scaling_factors(state_dict: dict) -> dict:
     while key := get_extra_state_key(state_dict):
         basename, size = unpack_extra_state_key(key)
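A quick illustration of what the new filter matches, using made-up state-dict keys: MoE expert _extra_state shards are dropped, regular weight entries pass through.

import re

pattern = r'module\.decoder\.layers\.mlp\.experts\.experts\.linear_fc\d+\._extra_state/shard_\d+\.\d+_\d+\.\d+'
moe_key = 'module.decoder.layers.mlp.experts.experts.linear_fc1._extra_state/shard_0.0_0.1'  # filtered out
weight_key = 'module.decoder.layers.mlp.linear_fc1.weight'                                   # kept
assert re.fullmatch(pattern, moe_key) is not None
assert re.fullmatch(pattern, weight_key) is None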
@@ -144,6 +151,7 @@ def load_sharded_metadata_torch_dist(checkpoint_dir: Union[Path, TarPath], torch
         storage_reader=fs_reader,
         no_dist=True,
     )
+    state_dict = filter_experts_extra_states(state_dict)
     state_dict = standarize_distributed_scaling_factors(state_dict)
 
     if not torch_tensor:
@@ -277,12 +285,20 @@ def copy_tokenizer_files(config, out_dir)
 
 def get_tokenzier(tokenizer_dir_or_path: Path) -> PreTrainedTokenizer:
     """Loads the tokenizer from the decoded NEMO weights dir."""
-    if os.path.isdir(os.path.join(tokenizer_dir_or_path, "huggingface_tokenizer")):
-        return AutoTokenizer.from_pretrained(os.path.join(tokenizer_dir_or_path, "huggingface_tokenizer"))
+    if (tokenizer_dir_or_path / "nemo_context").exists():
+        from nemo.lightning import io
+
+        tokenizer_spec = io.load_context((tokenizer_dir_or_path / "nemo_context"), subpath="model.tokenizer")
+        return build_tokenizer(tokenizer_spec)
+    else:
+        if os.path.isdir(os.path.join(tokenizer_dir_or_path, "huggingface_tokenizer")):
+            return AutoTokenizer.from_pretrained(os.path.join(tokenizer_dir_or_path, "huggingface_tokenizer"))
 
-    model_path = tokenizer_dir_or_path / "tokenizer.model" if tokenizer_dir_or_path.is_dir() else tokenizer_dir_or_path
-    tokenizer_config = {"library": "sentencepiece", "model": str(model_path)}
-    return build_tokenizer(tokenizer_config)
+        model_path = (
+            tokenizer_dir_or_path / "tokenizer.model" if tokenizer_dir_or_path.is_dir() else tokenizer_dir_or_path
+        )
+        tokenizer_config = {"library": "sentencepiece", "model": str(model_path)}
+        return build_tokenizer(tokenizer_config)
 
 
 def build_tokenizer(tokenizer):
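The resulting lookup order for a decoded checkpoint directory, roughly (the path in the call below is illustrative):

from pathlib import Path

# Priority implied by the branches above:
#   <dir>/nemo_context/           -> NeMo 2.0 context via io.load_context(..., subpath="model.tokenizer")
#   <dir>/huggingface_tokenizer/  -> Hugging Face AutoTokenizer.from_pretrained(...)
#   <dir>/tokenizer.model         -> SentencePiece model wrapped by build_tokenizer(...)
tokenizer = get_tokenzier(Path("/tmp/trtllm_engine"))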
@@ -309,6 +325,7 @@ def build_tokenizer(tokenizer)
     def batch_encode_patch(self, ids):
         if torch.is_tensor(ids):
             ids = ids.cpu().numpy()
+            ids = ids[0] if len(ids.shape) > 1 else ids
         return self.ids_to_text(ids)
 
     tokenizer.bos_token_id = tokenizer.bos_id
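The added line guards against TRT-LLM handing a batched [1, seq_len] tensor to the patched decode; a tiny sketch of the shape handling:

import torch

ids = torch.tensor([[101, 102, 103]])        # batched output, shape (1, 3)
ids = ids.cpu().numpy()
ids = ids[0] if len(ids.shape) > 1 else ids  # -> array([101, 102, 103]), ready for ids_to_text()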
@@ -331,11 +348,13 @@ def load_nemo_model(nemo_ckpt: Union[str, Path], nemo_export_dir: Union[str, Pat
     else:
         nemo_dir = TarPath(nemo_ckpt)
 
+    tokenizer = None
     try:
         unpacked_checkpoint_dir = UnpackedNemoCheckpointDir(nemo_dir, load_checkpoints_to_cpu=True)
 
-        dist_ckpt_folder = nemo_dir / "model_weights"
-        if dist_ckpt_folder.exists():
+        if (nemo_dir / "model_weights").exists():
+            dist_ckpt_folder = nemo_dir / "model_weights"
+
             model = load_sharded_metadata(dist_ckpt_folder)
             nemo_model_config = unpacked_checkpoint_dir.model_config
 
@@ -350,6 +369,45 @@
 
             tokenizer_config["model"] = os.path.join(nemo_export_dir, "tokenizer.model")
             tokenizer = build_tokenizer(tokenizer_config)
+        elif (nemo_dir / "weights").exists():
+            dist_ckpt_folder = nemo_dir / "weights"
+            model = load_sharded_metadata(dist_ckpt_folder)
+            io_folder = nemo_dir / "context"
+
+            if (io_folder / "model.yaml").exists():
+                with open(io_folder / "model.yaml", 'r') as stream:
+                    config = yaml.safe_load(stream)
+
+                nemo_model_config = {}
+                for k, v in config["config"].items():
+                    if isinstance(v, (float, int, str, bool)):
+                        nemo_model_config[k] = v
+                    elif k == "activation_func":
+                        nemo_model_config["activation"] = v["_target_"].rsplit('.', 1)[-1]
+            else:
+                from nemo.lightning import io
+
+                config = io.load_context(io_folder, subpath="model.config")
+
+                nemo_model_config = {}
+                for k, v in config.__dict__.items():
+                    if isinstance(v, (float, int, str, bool)):
+                        nemo_model_config[k] = v
+                    elif k == "activation_func":
+                        nemo_model_config["activation"] = v.__name__
+
+            if nemo_model_config.get("num_moe_experts") is None:
+                nemo_model_config["num_moe_experts"] = 0
+                nemo_model_config["moe_router_topk"] = 0
+            if nemo_model_config["activation"] == "silu":
+                nemo_model_config["activation"] = "fast-swiglu"
+            elif nemo_model_config["activation"] == "openai_gelu":
+                nemo_model_config["activation"] = "geglu"
+
+            nemo_model_config["mcore_gpt"] = True
+            nemo_model_config["max_position_embeddings"] = nemo_model_config.get("seq_length", 4096)
+
+            shutil.copytree(io_folder, nemo_export_dir / "nemo_context")
         else:
             raise Exception("Not a supported NeMo file format: only distributed MCore NeMo checkpoints are supported.")
     finally:
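As a rough sketch of what the model.yaml branch consumes, here is a hand-written NeMo 2.0-style context fragment (field names and values are illustrative, not taken from a real checkpoint) flattened the same way as the loader above:

import yaml

sample_model_yaml = """
config:
  num_layers: 32
  hidden_size: 4096
  seq_length: 8192
  activation_func:
    _target_: torch.nn.functional.silu
"""
config = yaml.safe_load(sample_model_yaml)

nemo_model_config = {}
for k, v in config["config"].items():
    if isinstance(v, (float, int, str, bool)):
        nemo_model_config[k] = v
    elif k == "activation_func":
        # Keep only the trailing function name ("silu"), which the loader
        # then remaps to the TRT-LLM activation name "fast-swiglu".
        nemo_model_config["activation"] = v["_target_"].rsplit('.', 1)[-1]

print(nemo_model_config)
# {'num_layers': 32, 'hidden_size': 4096, 'seq_length': 8192, 'activation': 'silu'}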
