diff --git a/scripts/llm/gpt_distillation.py b/scripts/llm/gpt_distillation.py
index fd6909e89d1f4..f832809774e46 100644
--- a/scripts/llm/gpt_distillation.py
+++ b/scripts/llm/gpt_distillation.py
@@ -14,8 +14,9 @@
 
 import types
 from abc import ABCMeta
+from argparse import ArgumentParser
 from contextlib import contextmanager
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
 from typing import Any, Callable, Dict, Optional, Tuple
 
 import modelopt.torch.distill as mtd
@@ -34,15 +35,204 @@
 from nemo import lightning as nl
 from nemo.collections import llm
 from nemo.collections.llm.gpt.model.base import get_batch_on_this_context_parallel_rank
-from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec
+
+# from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec
 from nemo.collections.nlp.modules.common.megatron.module import Float16Module
 from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group
+
+# from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
 from nemo.lightning import io
+from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
 from nemo.lightning.megatron_parallel import DDP, MaskedTokenLossReduction
+from nemo.lightning.pytorch.callbacks import ModelCheckpoint
+from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
 from nemo.utils import logging
 from nemo.utils.model_utils import unwrap_model
 
 
+def get_args():
+    """
+    Parse the command line arguments.
+    """
+    parser = ArgumentParser(description="""Run Knowledge Distillation from a teacher model to a student.""")
+    parser.add_argument(
+        "--name",
+        type=str,
+        required=True,
+        help="""Experiment name""",
+    )
+    parser.add_argument(
+        "--teacher_path",
+        type=str,
+        required=True,
+        help="""Path to NeMo 2 checkpoint""",
+    )
+    parser.add_argument(
+        "--student_path",
+        type=str,
+        required=True,
+        help="""Path to NeMo 2 checkpoint""",
+    )
+    parser.add_argument(
+        "--tp",
+        type=int,
+        default=1,
+        help="""Tensor parallel size""",
+    )
+    parser.add_argument(
+        "--cp",
+        type=int,
+        default=1,
+        help="""Context parallel size""",
+    )
+    parser.add_argument(
+        "--pp",
+        type=int,
+        default=1,
+        help="""Pipeline parallel size""",
+    )
+    parser.add_argument(
+        "--enable_sp",
+        action="store_true",
+        help="""Enable Sequence parallelism""",
+    )
+    parser.add_argument(
+        "--precision",
+        type=str,
+        default="bf16-mixed",
+        help="""Datatype for models and optimizer""",
+    )
+    parser.add_argument(
+        "--devices",
+        type=int,
+        default=1,
+        help="""Number of GPUs to use per node""",
+    )
+    parser.add_argument(
+        "--nodes",
+        type=int,
+        default=1,
+        help="""Number of nodes to use""",
+    )
+    parser.add_argument(
+        "--log_dir",
+        type=str,
+        required=True,
+        help="""Folder for logging and checkpoint saving""",
+    )
+    parser.add_argument(
+        "--steps",
+        type=int,
+        required=True,
+        help="""Number of global batches to process""",
+    )
+    parser.add_argument(
+        "--global_batch_size",
+        type=int,
+        required=True,
+        help="""Total number of data samples per batch""",
+    )
+    parser.add_argument(
+        "--micro_batch_size",
+        type=int,
+        required=True,
+        help="""Number of samples to process through model simultaneously""",
+    )
+    parser.add_argument(
+        "--data_paths",
+        nargs='+',
+        required=True,
+        help="""List of pre-tokenized data paths to load from""",
+    )
+    parser.add_argument(
+        "--split",
+        type=str,
+        default="98,1,1",
help="""""", + ) + parser.add_argument( + "--index_mapping_dir", + type=str, + default=None, + help="""""", + ) + parser.add_argument( + "--sequence_length", + type=int, + required=True, + help="""Number of tokens per input sample""", + ) + # parser.add_argument( + # "--tokenizer_library", + # type=str, + # default="sentencepiece", + # help="""Library to load tokenizer from""", + # ) + # parser.add_argument( + # "--tokenizer_name_or_path", + # type=str, + # required=True, + # help="""Either the name of the tokenizer model in the library, or a path to a file""", + # ) + parser.add_argument( + "--optimzer", + type=str, + default="adam", + help="""Optimizer name""", + ) + parser.add_argument( + "--lr", + type=float, + default=3e-5, + help="""""", + ) + parser.add_argument( + "--min_lr", + type=float, + default=2e-7, + help="""""", + ) + parser.add_argument( + "--warmup_steps", + type=int, + default=50, + help="""""", + ) + parser.add_argument( + "--val_check_interval", + type=int, + default=100, + help="""""", + ) + parser.add_argument( + "--limit_val_batches", + type=int, + default=32, + help="""""", + ) + parser.add_argument( + "--limit_test_batches", + type=int, + default=32, + help="""""", + ) + parser.add_argument( + "--log_interval", + type=int, + default=10, + help="""""", + ) + parser.add_argument( + "--save_top_k", + type=int, + default=1, + help="""""", + ) + + args = parser.parse_args() + return args + + def gpt_distillation_data_step(dataloader_iter, attn_mask_cpu=False) -> Dict[str, torch.Tensor]: batch = next(dataloader_iter) @@ -390,6 +580,7 @@ def _teacher_provider(cfg: DistillationGPTConfig) -> MCoreGPTModel: plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), ) + # TODO(aanoosheh): Replace spec with modelopt one model, _ = io.ModelConnector().nemo_load(cfg.kd_teacher_restore_from_path, trainer, cpu=False) model = unwrap_model(model.module, (DDP, Float16Module, MCoreFloat16Module)) @@ -479,95 +670,91 @@ def _swap_teacher_config(self, model_wrapper): if __name__ == "__main__": logging.info("Distillation enabled.") - TEACHER_PATH = "./test_teacher/" + args = get_args() - seq_length = 2048 - global_batch_size = 16 - tp = 1 - pp = 1 - - # TODO: setup the dummy dataset - data = llm.MockDataModule(seq_length=seq_length, global_batch_size=global_batch_size) - - ## initialize the strategy + ## initialize the strategy and trainer strategy = nl.MegatronStrategy( - tensor_model_parallel_size=tp, - pipeline_model_parallel_size=pp, + tensor_model_parallel_size=args.tp, + pipeline_model_parallel_size=args.pp, + context_parallel_size=args.cp, + sequence_parallel=args.enable_sp, ) trainer = nl.Trainer( - devices=1, ## you can change the number of devices to suit your setup - max_steps=50, - accelerator="gpu", + devices=args.devices, + num_nodes=args.nodes, + max_steps=args.steps, + log_every_n_steps=args.log_interval, + val_check_interval=args.val_check_interval, + limit_val_batches=args.limit_val_batches, + limit_test_batches=args.limit_test_batches, strategy=strategy, - plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), - ) - - common_model_kwargs = dict( - seq_length=seq_length, - init_method_std=0.023, - hidden_dropout=0.0, - attention_dropout=0.0, - layernorm_epsilon=1e-5, - make_vocab_size_divisible_by=128, - transformer_layer_spec=get_gpt_layer_modelopt_spec(), + accelerator="gpu", + plugins=nl.MegatronMixedPrecision(precision=args.precision), ) - ############# TEACHER HACK ############# - import os - import sys - - if not os.path.exists(TEACHER_PATH): - from 
-
-        gpt_config = llm.GPTConfig(
-            num_layers=9,
-            hidden_size=384,
-            ffn_hidden_size=1536,
-            num_attention_heads=6,
-            **common_model_kwargs,
-        )
-        model = llm.GPTModel(gpt_config, tokenizer=data.tokenizer)
-
-        strategy.ckpt_save_optimizer = False  # otherwise need to do `model._trainer = trainer`
-        trainer.state.fn = TrainerFn.FITTING  # needed for proper save.
-        trainer.strategy.connect(model)
-        trainer.strategy.setup_environment()
-        with trainer.init_module():
-            model.configure_model()
-
-        io.ModelConnector().nemo_save(TEACHER_PATH, trainer)
-
-        sys.exit(0)
-    ##########################################
+    ## load student model
+    model = nl.io.load_context(path=ckpt_to_context_subdir(args.student_path), subpath="model")
+    assert hasattr(model, "tokenizer"), "Please provide a model checkpoint with tokenizer included."
+    model_config, tokenizer = model.config, model.tokenizer
 
-    ## initialize a small GPT model
-    gpt_config = DistillationGPTConfig(
-        num_layers=6,
-        hidden_size=384,
-        ffn_hidden_size=1536,
-        num_attention_heads=6,
-        **common_model_kwargs,
-        kd_teacher_restore_from_path=TEACHER_PATH,
+    model_config = DistillationGPTConfig(
+        **asdict(model_config),
+        # transformer_layer_spec=get_gpt_layer_modelopt_spec(),  # TODO(aanoosheh)
+        kd_teacher_restore_from_path=args.teacher_path,
     )
-    model = DistillationGPTModel(gpt_config, tokenizer=data.tokenizer)
+    model = DistillationGPTModel(model_config, tokenizer=tokenizer)
+
+    # setup the dataset
+    data = llm.PreTrainingDataModule(
+        paths=args.data_paths,
+        seq_length=args.sequence_length,
+        micro_batch_size=args.micro_batch_size,
+        global_batch_size=args.global_batch_size,
+        split=args.split,
+        index_mapping_dir=args.index_mapping_dir,
+        tokenizer=tokenizer,
+    )
+
+    # auto-resume setup
+    # resume = nl.AutoResume(
+    #     resume_if_exists=True,
+    #     resume_ignore_no_checkpoint=True,
+    #     resume_from_directory=LOG_DIR,
+    #     restore_config=nl.RestoreConfig(path=STUDENT_PATH) if STUDENT_PATH else None,
+    # )
 
     ## setup the optimizer
     opt_config = OptimizerConfig(
-        optimizer='adam',
-        lr=3e-5,
-        bf16=True,
+        optimizer=args.optimizer,
+        lr=args.lr,
+        bf16=("bf16" in args.precision),
+    )
+    sched = CosineAnnealingScheduler(
+        max_steps=args.steps,
+        warmup_steps=args.warmup_steps,
+        constant_steps=0,
+        min_lr=args.min_lr,
     )
-    opt = nl.MegatronOptimizerModule(config=opt_config)
+    opt = nl.MegatronOptimizerModule(opt_config, sched)
 
-    nemo_logger = nl.NeMoLogger(
-        log_dir="test_logdir",  ## logs and checkpoints will be written here
+    # checkpointing and logging
+    checkpoint_callback = ModelCheckpoint(
+        monitor="val_loss",
+        save_top_k=args.save_top_k,
+        every_n_train_steps=args.val_check_interval,
+    )
+    logger = nl.NeMoLogger(
+        name=args.name,
+        log_dir=args.log_dir,
+        ckpt=checkpoint_callback,
     )
 
+    # run
     llm.train(
         model=model,
        data=data,
        optim=opt,
+        tokenizer='model',
+        trainer=trainer,
+        log=logger,
    )
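
A possible way to launch the updated script, shown for reference only: the checkpoint paths, data prefix, parallelism sizes, and hyperparameter values below are illustrative placeholders (not taken from this change), and the torchrun launch is just one option depending on your environment. The required flags are --name, --teacher_path, --student_path, --log_dir, --steps, --global_batch_size, --micro_batch_size, --data_paths, and --sequence_length.

    torchrun --nproc_per_node=2 scripts/llm/gpt_distillation.py \
        --name kd_experiment \
        --teacher_path /checkpoints/teacher_nemo2 \
        --student_path /checkpoints/student_nemo2 \
        --tp 2 \
        --devices 2 \
        --log_dir ./kd_logs \
        --steps 1000 \
        --global_batch_size 128 \
        --micro_batch_size 2 \
        --data_paths /data/tokenized/corpus_text_document \
        --sequence_length 4096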