Load CICD model specially
Signed-off-by: Asha Anoosheh <[email protected]>
AAnoosheh committed Jan 27, 2025
1 parent 0f967f8 commit 2bd6e40
Showing 3 changed files with 23 additions and 7 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/cicd-main.yml
@@ -4894,8 +4894,8 @@ jobs:
       SCRIPT: |
         python scripts/llm/gpt_distillation.py \
           --name nemo2_llama_distill \
-          --teacher_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
-          --student_path /home/TestData/nemo2_ckpt/llama_68M_v2 \
+          --teacher_path /home/TestData/nemo2_ckpt/llama_68M_v3 \
+          --student_path /home/TestData/nemo2_ckpt/llama_68M_v3 \
           --tp_size 1 \
           --cp_size 1 \
           --pp_size 2 \
14 changes: 14 additions & 0 deletions nemo/collections/llm/distillation/utils.py
@@ -152,3 +152,17 @@ def _swap_teacher_config(self, model_wrapper):
     # HACK: Pipeline-parallel forward function relies on the config in the model to know what
     # hidden size of tensor to communicate to next stage.
     model.swap_teacher_config = MethodType(_swap_teacher_config, model)
+
+
+def load_cicd_models(student_path: str):
+    # pylint: disable=C0116
+    import os.path
+
+    from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
+    from tests.collections.llm.common import Llama3ConfigCI
+
+    tokenizer = get_nmt_tokenizer(tokenizer_model=os.path.join(student_path, "dummy_tokenizer.model"))
+    student_model = llm.LlamaModel(Llama3ConfigCI(), tokenizer=tokenizer)
+    teacher_model = llm.LlamaModel(Llama3ConfigCI(), tokenizer=tokenizer)
+
+    return student_model, teacher_model, tokenizer
12 changes: 7 additions & 5 deletions scripts/llm/gpt_distillation.py
@@ -88,13 +88,15 @@ def get_args():
 
     ## Load both models and combine into an aggregate module
     if args.cicd_run:
-        from tests.collections.llm.common import Llama3ConfigCI  # pylint: disable=W0611
+        from nemo.collections.llm.distillation.utils import load_cicd_models
 
-    _student_model = nl.io.load_context(path=ckpt_to_context_subdir(args.student_path), subpath="model")
-    _teacher_model = nl.io.load_context(path=ckpt_to_context_subdir(args.teacher_path), subpath="model")
+        _student_model, _teacher_model, tokenizer = load_cicd_models(args.student_path)
+    else:
+        _student_model = nl.io.load_context(path=ckpt_to_context_subdir(args.student_path), subpath="model")
+        _teacher_model = nl.io.load_context(path=ckpt_to_context_subdir(args.teacher_path), subpath="model")
 
-    tokenizer = getattr(_student_model, "tokenizer", None) or getattr(_teacher_model, "tokenizer", None)
-    assert tokenizer is not None, "Please provide a model checkpoint with tokenizer included."
+        tokenizer = getattr(_student_model, "tokenizer", None) or getattr(_teacher_model, "tokenizer", None)
+        assert tokenizer is not None, "Please provide a model checkpoint with tokenizer included."
 
     model = distill.DistillationGPTModel(
         _student_model.config,
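
For reference, a minimal usage sketch of the new CI-only loading path, using only what the diffs above show: the load_cicd_models helper and the llama_68M_v3 test checkpoint referenced in the workflow change. The argument parsing, parallelism setup, and DistillationGPTModel construction in gpt_distillation.py are omitted here.

# Sketch only, not part of the commit: exercises the helper added in
# nemo/collections/llm/distillation/utils.py.
from nemo.collections.llm.distillation.utils import load_cicd_models

# CI test asset from the workflow change above.
student_path = "/home/TestData/nemo2_ckpt/llama_68M_v3"

# Builds small Llama3ConfigCI student/teacher models and a dummy tokenizer from
# the checkpoint directory instead of restoring the full saved model context.
_student_model, _teacher_model, tokenizer = load_cicd_models(student_path)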
