def filter_experts_extra_states(state_dict: dict) -> dict:
    """Return a copy of ``state_dict`` without expert ``_extra_state`` shard entries.

    A key is dropped when the *whole* key has the form
    ``<prefix>.decoder.layers.mlp.experts.experts.linear_fc<N>._extra_state/shard_<a>.<b>_<c>.<d>``
    where ``<prefix>`` is either ``model`` or ``module``. Every other entry is
    kept unchanged, and the input mapping is not modified.

    Args:
        state_dict: Mapping from string checkpoint keys to their values.

    Returns:
        A new dict holding only the entries whose key does not match the
        expert extra-state shard pattern.
    """
    # Both root prefixes are accepted: the hard-coded prefix has been switched
    # between ``module`` and ``model``, and matching only one of them silently
    # lets the other spelling through unfiltered.
    # NOTE(review): presumably the prefix depends on the checkpoint format
    # version -- confirm against the checkpoint writer.
    pattern = (
        r'(?:model|module)\.decoder\.layers\.mlp\.experts\.experts\.'
        r'linear_fc\d+\._extra_state/shard_\d+\.\d+_\d+\.\d+'
    )
    # Compile once per call; the bound ``fullmatch`` is then reused for every
    # key instead of re-resolving the pattern on each iteration.
    is_expert_extra_state = re.compile(pattern).fullmatch
    return {k: v for k, v in state_dict.items() if not is_expert_extra_state(k)}