Add CLI args
AAnoosheh committed Jan 14, 2025
1 parent 80cbfd2 commit b748ae8
Showing 1 changed file with 114 additions and 74 deletions.
188 changes: 114 additions & 74 deletions scripts/llm/gpt_distillation.py
@@ -14,8 +14,9 @@

import types
from abc import ABCMeta
from argparse import ArgumentParser
from contextlib import contextmanager
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from typing import Any, Callable, Dict, Optional, Tuple

import modelopt.torch.distill as mtd
@@ -34,15 +35,55 @@
from nemo import lightning as nl
from nemo.collections import llm
from nemo.collections.llm.gpt.model.base import get_batch_on_this_context_parallel_rank
from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec

# from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec
from nemo.collections.nlp.modules.common.megatron.module import Float16Module
from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group
from nemo.lightning import io
from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
from nemo.lightning.megatron_parallel import DDP, MaskedTokenLossReduction
from nemo.lightning.pytorch.callbacks import ModelCheckpoint
from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
from nemo.utils import logging
from nemo.utils.model_utils import unwrap_model


def get_args():
"""
Parse the command line arguments.
"""
parser = ArgumentParser(description="""Run Knowledge Distillation from a teacher model to a student.""")

parser.add_argument("--name", type=str, required=True, help="""Experiment name""")
parser.add_argument("--teacher_path", type=str, required=True, help="""Path to NeMo 2 checkpoint""")
parser.add_argument("--student_path", type=str, required=True, help="""Path to NeMo 2 checkpoint""")
parser.add_argument("--tp", type=int, default=1, help="""Tensor parallel size""")
parser.add_argument("--cp", type=int, default=1, help="""Context parallel size""")
parser.add_argument("--pp", type=int, default=1, help="""Pipeline parallel size""")
parser.add_argument("--enable_sp", action="store_true", help="""Enable Sequence parallelism""")
parser.add_argument("--precision", type=str, default="bf16-mixed", help="""Datatype for models and optimizer""")
parser.add_argument("--devices", type=int, default=1, help="""Number of GPUs to use per node""")
parser.add_argument("--nodes", type=int, default=1, help="""Number of nodes to use""")
parser.add_argument("--log_dir", type=str, required=True, help="""Folder for logging and checkpoint saving""")
parser.add_argument("--steps", type=int, required=True, help="""Number of global batches to process""")
parser.add_argument("--global_batch_size", type=int, required=True, help="""Data samples per optimizer step""")
parser.add_argument("--micro_batch_size", type=int, required=True, help="""Data samples per forward pass""")
parser.add_argument("--data_paths", nargs='+', required=True, help="""List of tokenized data paths to load from""")
parser.add_argument("--split", type=str, default="99,1,0", help="""""")
parser.add_argument("--index_mapping_dir", type=str, default=None, help="""""")
parser.add_argument("--sequence_length", type=int, required=True, help="""Number of tokens per input sample""")
parser.add_argument("--lr", type=float, default=3e-5, help="""""")
parser.add_argument("--min_lr", type=float, default=2e-7, help="""""")
parser.add_argument("--warmup_steps", type=int, default=50, help="""""")
parser.add_argument("--val_check_interval", type=int, default=100, help="""""")
parser.add_argument("--limit_val_batches", type=int, default=32, help="""""")
parser.add_argument("--limit_test_batches", type=int, default=32, help="""""")
parser.add_argument("--log_interval", type=int, default=10, help="""""")

args = parser.parse_args()
return args
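
# Illustrative invocation using the new CLI arguments (a sketch only; the paths,
# sizes, and step counts below are placeholders, not values taken from this commit):
#
#   python scripts/llm/gpt_distillation.py \
#       --name kd_experiment \
#       --teacher_path /path/to/teacher_nemo2_ckpt \
#       --student_path /path/to/student_nemo2_ckpt \
#       --log_dir /path/to/logs \
#       --steps 1000 \
#       --global_batch_size 128 \
#       --micro_batch_size 2 \
#       --data_paths /path/to/tokenized_dataset_text_document \
#       --sequence_length 4096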


def gpt_distillation_data_step(dataloader_iter, attn_mask_cpu=False) -> Dict[str, torch.Tensor]:
batch = next(dataloader_iter)

@@ -173,7 +214,8 @@ def data_step(self, dataloader_iter, cache_num_batches: Optional[int] = None) ->
return _LoopingCachedDataIterator(batches)
elif isinstance(dataloader_iter, _LoopingCachedDataIterator):
batch = next(dataloader_iter)
batch["attention_mask"] = batch["attention_mask"].cuda(non_blocking=True) # move back to GPU
if "attention_mask" in batch:
batch["attention_mask"] = batch["attention_mask"].cuda(non_blocking=True) # move back to GPU
return batch
else:
return self.config.data_step_fn(dataloader_iter)
@@ -390,6 +432,7 @@ def _teacher_provider(cfg: DistillationGPTConfig) -> MCoreGPTModel:
plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
)

# TODO(aanoosheh): Replace spec with modelopt one
model, _ = io.ModelConnector().nemo_load(cfg.kd_teacher_restore_from_path, trainer, cpu=False)
model = unwrap_model(model.module, (DDP, Float16Module, MCoreFloat16Module))

@@ -479,95 +522,92 @@ def _swap_teacher_config(self, model_wrapper):
if __name__ == "__main__":
logging.info("Distillation enabled.")

TEACHER_PATH = "./test_teacher/"
args = get_args()

seq_length = 2048
global_batch_size = 16
tp = 1
pp = 1

# TODO: setup the dummy dataset
data = llm.MockDataModule(seq_length=seq_length, global_batch_size=global_batch_size)

## initialize the strategy
## initialize the strategy and trainer
strategy = nl.MegatronStrategy(
tensor_model_parallel_size=tp,
pipeline_model_parallel_size=pp,
tensor_model_parallel_size=args.tp,
pipeline_model_parallel_size=args.pp,
context_parallel_size=args.cp,
sequence_parallel=args.enable_sp,
)
trainer = nl.Trainer(
devices=1, ## you can change the number of devices to suit your setup
max_steps=50,
accelerator="gpu",
devices=args.devices,
num_nodes=args.nodes,
max_steps=args.steps,
log_every_n_steps=args.log_interval,
val_check_interval=args.val_check_interval,
limit_val_batches=args.limit_val_batches,
limit_test_batches=args.limit_test_batches,
strategy=strategy,
plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
)

common_model_kwargs = dict(
seq_length=seq_length,
init_method_std=0.023,
hidden_dropout=0.0,
attention_dropout=0.0,
layernorm_epsilon=1e-5,
make_vocab_size_divisible_by=128,
transformer_layer_spec=get_gpt_layer_modelopt_spec(),
accelerator="gpu",
plugins=nl.MegatronMixedPrecision(precision=args.precision),
)

############# TEACHER HACK #############
import os
import sys

if not os.path.exists(TEACHER_PATH):
from lightning.pytorch.trainer.states import TrainerFn
## load student model
model = nl.io.load_context(path=ckpt_to_context_subdir(args.student_path), subpath="model")
assert hasattr(model, "tokenizer"), "Please provide a model checkpoint with tokenizer included."
model_config, tokenizer = model.config, model.tokenizer
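    # Only the student checkpoint's config and tokenizer are reused here; the config is
    # re-wrapped as a DistillationGPTConfig further below.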

gpt_config = llm.GPTConfig(
num_layers=9,
hidden_size=384,
ffn_hidden_size=1536,
num_attention_heads=6,
**common_model_kwargs,
)
model = llm.GPTModel(gpt_config, tokenizer=data.tokenizer)

strategy.ckpt_save_optimizer = False # otherwise need to do `model._trainer = trainer`
trainer.state.fn = TrainerFn.FITTING # needed for proper save.
trainer.strategy.connect(model)
trainer.strategy.setup_environment()
with trainer.init_module():
model.configure_model()

io.ModelConnector().nemo_save(TEACHER_PATH, trainer)

sys.exit(0)
##########################################

## initialize a small GPT model
gpt_config = DistillationGPTConfig(
num_layers=6,
hidden_size=384,
ffn_hidden_size=1536,
num_attention_heads=6,
**common_model_kwargs,
kd_teacher_restore_from_path=TEACHER_PATH,
model_config = DistillationGPTConfig(
**asdict(model_config),
# transformer_layer_spec=get_gpt_layer_modelopt_spec(), # TODO(aanoosheh)
kd_teacher_restore_from_path=args.teacher_path,
)
model = DistillationGPTModel(model_config, tokenizer=tokenizer)
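    # The teacher model itself is restored from --teacher_path (kd_teacher_restore_from_path)
    # by _teacher_provider, defined above.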

# setup the dataset
data = llm.PreTrainingDataModule(
paths=args.data_paths,
seq_length=args.sequence_length,
micro_batch_size=args.micro_batch_size,
global_batch_size=args.global_batch_size,
split=args.split,
index_mapping_dir=args.index_mapping_dir,
tokenizer=tokenizer,
)
model = DistillationGPTModel(gpt_config, tokenizer=data.tokenizer)

# auto-resume setup
    # resume = nl.AutoResume(
    #     resume_if_exists=True,
    #     resume_ignore_no_checkpoint=True,
    #     resume_from_directory=args.log_dir,
    #     restore_config=nl.RestoreConfig(path=args.student_path) if args.student_path else None,
    # )

## setup the optimizer
opt_config = OptimizerConfig(
optimizer='adam',
lr=3e-5,
bf16=True,
optimizer="adam",
lr=args.lr,
bf16=("bf16" in args.precision),
use_distributed_optimizer=True,
)
sched = CosineAnnealingScheduler(
max_steps=args.steps,
warmup_steps=args.warmup_steps,
constant_steps=0,
min_lr=args.min_lr,
)
opt = nl.MegatronOptimizerModule(config=opt_config)
opt = nl.MegatronOptimizerModule(opt_config, sched)
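    # Distributed Adam paired with the cosine annealing schedule defined above
    # (warmup to --lr over --warmup_steps, decay toward --min_lr over --steps).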

nemo_logger = nl.NeMoLogger(
log_dir="test_logdir", ## logs and checkpoints will be written here
# checkpointing and logging
checkpoint_callback = ModelCheckpoint(
monitor="val_loss",
save_top_k=1,
every_n_train_steps=args.val_check_interval,
)
logger = nl.NeMoLogger(
name=args.name,
log_dir=args.log_dir,
ckpt=checkpoint_callback,
)

# run
llm.train(
model=model,
data=data,
trainer=trainer,
log=nemo_logger,
tokenizer='data',
optim=opt,
tokenizer='model',
trainer=trainer,
log=logger,
)
