
Commit

Add CLI args
AAnoosheh committed Jan 14, 2025
1 parent 80cbfd2 commit 0ce6e7d
Showing 1 changed file with 260 additions and 73 deletions.
333 changes: 260 additions & 73 deletions scripts/llm/gpt_distillation.py
@@ -14,8 +14,9 @@

import types
from abc import ABCMeta
from argparse import ArgumentParser
from contextlib import contextmanager
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from typing import Any, Callable, Dict, Optional, Tuple

import modelopt.torch.distill as mtd
@@ -34,15 +35,204 @@
from nemo import lightning as nl
from nemo.collections import llm
from nemo.collections.llm.gpt.model.base import get_batch_on_this_context_parallel_rank
from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec

# from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec
from nemo.collections.nlp.modules.common.megatron.module import Float16Module
from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group

# from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
from nemo.lightning import io
from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
from nemo.lightning.megatron_parallel import DDP, MaskedTokenLossReduction
from nemo.lightning.pytorch.callbacks import ModelCheckpoint
from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
from nemo.utils import logging
from nemo.utils.model_utils import unwrap_model


def get_args():
"""
Parse the command line arguments.
"""
parser = ArgumentParser(description="""Run Knowledge Distillation from a teacher model to a student.""")
parser.add_argument(
"--name",
type=str,
required=True,
help="""Experiment name""",
)
parser.add_argument(
"--teacher_path",
type=str,
required=True,
help="""Path to NeMo 2 checkpoint""",
)
parser.add_argument(
"--student_path",
type=str,
required=True,
help="""Path to NeMo 2 checkpoint""",
)
parser.add_argument(
"--tp",
type=int,
default=1,
help="""Tensor parallel size""",
)
parser.add_argument(
"--cp",
type=int,
default=1,
help="""Context parallel size""",
)
parser.add_argument(
"--pp",
type=int,
default=1,
help="""Pipeline parallel size""",
)
parser.add_argument(
"--enable_sp",
action="store_true",
help="""Enable Sequence parallelism""",
)
parser.add_argument(
"--precision",
type=str,
default="bf16-mixed",
help="""Datatype for models and optimizer""",
)
parser.add_argument(
"--devices",
type=int,
default=1,
help="""Number of GPUs to use per node""",
)
parser.add_argument(
"--nodes",
type=int,
default=1,
help="""Number of nodes to use""",
)
parser.add_argument(
"--log_dir",
type=str,
required=True,
help="""Folder for logging and checkpoint saving""",
)
parser.add_argument(
"--steps",
type=int,
required=True,

help="""Number of global batches to process""",
)
parser.add_argument(
"--global_batch_size",
type=int,
required=True,
help="""Total number of data samples per batch""",
)
parser.add_argument(
"--micro_batch_size",
type=int,
required=True,
help="""Number of samples to process through model simultaneously""",
)
parser.add_argument(
"--data_paths",
nargs='+',
required=True,
help="""List of pre-tokenized data paths to load from""",
)
parser.add_argument(
"--split",
type=str,
default="98,1,1",
help="""""",
)
parser.add_argument(
"--index_mapping_dir",
type=str,
default=None,
help="""""",
)
parser.add_argument(
"--sequence_length",
type=int,
required=True,
help="""Number of tokens per input sample""",
)
# parser.add_argument(
# "--tokenizer_library",
# type=str,
# default="sentencepiece",
# help="""Library to load tokenizer from""",
# )
# parser.add_argument(
# "--tokenizer_name_or_path",
# type=str,
# required=True,
# help="""Either the name of the tokenizer model in the library, or a path to a file""",
# )
parser.add_argument(
"--optimzer",
type=str,
default="adam",
help="""Optimizer name""",
)
parser.add_argument(
"--lr",
type=float,
default=3e-5,
help="""""",
)
parser.add_argument(
"--min_lr",
type=float,
default=2e-7,
help="""""",
)
parser.add_argument(
"--warmup_steps",
type=int,
default=50,
help="""""",
)
parser.add_argument(
"--val_check_interval",
type=int,
default=100,
help="""""",
)
parser.add_argument(
"--limit_val_batches",
type=int,
default=32,
help="""""",
)
parser.add_argument(
"--limit_test_batches",
type=int,
default=32,
help="""""",
)
parser.add_argument(
"--log_interval",
type=int,
default=10,
help="""""",
)
parser.add_argument(
"--save_top_k",
type=int,
default=1,
help="""""",
)

args = parser.parse_args()
return args
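
# A minimal example invocation of this script using the CLI arguments defined above.
# This is an illustrative sketch only: the launcher choice, checkpoint paths, data
# paths, and batch/step sizes are placeholders, not values from this repository.
#
#   torchrun --nproc_per_node=8 scripts/llm/gpt_distillation.py \
#       --name kd_example \
#       --teacher_path /checkpoints/teacher_nemo2 \
#       --student_path /checkpoints/student_nemo2 \
#       --tp 2 --pp 1 --cp 1 --devices 8 --nodes 1 \
#       --log_dir /results/kd_example \
#       --steps 1000 --global_batch_size 128 --micro_batch_size 2 \
#       --data_paths /data/tokenized/train_text_document \
#       --sequence_length 4096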


def gpt_distillation_data_step(dataloader_iter, attn_mask_cpu=False) -> Dict[str, torch.Tensor]:
batch = next(dataloader_iter)

@@ -390,6 +580,7 @@ def _teacher_provider(cfg: DistillationGPTConfig) -> MCoreGPTModel:
plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
)

# TODO(aanoosheh): Replace spec with modelopt one
model, _ = io.ModelConnector().nemo_load(cfg.kd_teacher_restore_from_path, trainer, cpu=False)
model = unwrap_model(model.module, (DDP, Float16Module, MCoreFloat16Module))

@@ -479,95 +670,91 @@ def _swap_teacher_config(self, model_wrapper):
if __name__ == "__main__":
logging.info("Distillation enabled.")

TEACHER_PATH = "./test_teacher/"
args = get_args()

seq_length = 2048
global_batch_size = 16
tp = 1
pp = 1

# TODO: setup the dummy dataset
data = llm.MockDataModule(seq_length=seq_length, global_batch_size=global_batch_size)

## initialize the strategy
## initialize the strategy and trainer
strategy = nl.MegatronStrategy(
tensor_model_parallel_size=tp,
pipeline_model_parallel_size=pp,
tensor_model_parallel_size=args.tp,
pipeline_model_parallel_size=args.pp,
context_parallel_size=args.cp,
sequence_parallel=args.enable_sp,
)
trainer = nl.Trainer(
devices=1, ## you can change the number of devices to suit your setup
max_steps=50,
accelerator="gpu",
devices=args.devices,
num_nodes=args.nodes,
max_steps=args.steps,
log_every_n_steps=args.log_interval,
val_check_interval=args.val_check_interval,
limit_val_batches=args.limit_val_batches,
limit_test_batches=args.limit_test_batches,
strategy=strategy,
plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
)

common_model_kwargs = dict(
seq_length=seq_length,
init_method_std=0.023,
hidden_dropout=0.0,
attention_dropout=0.0,
layernorm_epsilon=1e-5,
make_vocab_size_divisible_by=128,
transformer_layer_spec=get_gpt_layer_modelopt_spec(),
accelerator="gpu",
plugins=nl.MegatronMixedPrecision(precision=args.precision),
)

############# TEACHER HACK #############
import os
import sys

if not os.path.exists(TEACHER_PATH):
from lightning.pytorch.trainer.states import TrainerFn

gpt_config = llm.GPTConfig(
num_layers=9,
hidden_size=384,
ffn_hidden_size=1536,
num_attention_heads=6,
**common_model_kwargs,
)
model = llm.GPTModel(gpt_config, tokenizer=data.tokenizer)

strategy.ckpt_save_optimizer = False # otherwise need to do `model._trainer = trainer`
trainer.state.fn = TrainerFn.FITTING # needed for proper save.
trainer.strategy.connect(model)
trainer.strategy.setup_environment()
with trainer.init_module():
model.configure_model()

io.ModelConnector().nemo_save(TEACHER_PATH, trainer)

sys.exit(0)
##########################################
## load student model
model = nl.io.load_context(path=ckpt_to_context_subdir(args.student_path), subpath="model")
assert hasattr(model, "tokenizer"), "Please provide a model checkpoint with tokenizer included."
model_config, tokenizer = model.config, model.tokenizer

## initialize a small GPT model
gpt_config = DistillationGPTConfig(
num_layers=6,
hidden_size=384,
ffn_hidden_size=1536,
num_attention_heads=6,
**common_model_kwargs,
kd_teacher_restore_from_path=TEACHER_PATH,
model_config = DistillationGPTConfig(
**asdict(model_config),
# transformer_layer_spec=get_gpt_layer_modelopt_spec(), # TODO(aanoosheh)
kd_teacher_restore_from_path=args.teacher_path,
)
model = DistillationGPTModel(gpt_config, tokenizer=data.tokenizer)
model = DistillationGPTModel(model_config, tokenizer=tokenizer)

# setup the dataset
data = llm.PreTrainingDataModule(
paths=args.data_paths,
seq_length=args.sequence_length,
micro_batch_size=args.micro_batch_size,
global_batch_size=args.global_batch_size,
split=args.split,
index_mapping_dir=args.index_mapping_dir,
tokenizer=tokenizer,
)

# auto-resume setup
# resume = nl.AutoResume(
# resume_if_exists=True,
# resume_ignore_no_checkpoint=True,
# resume_from_directory=LOG_DIR,
# restore_config=nl.RestoreConfig(path=STUDENT_PATH) if STUDENT_PATH else None,
# )

## setup the optimizer
opt_config = OptimizerConfig(
optimizer='adam',
lr=3e-5,
bf16=True,
optimizer=args.optimizer,
lr=args.lr,
bf16=("bf16" in args.precision),
)
sched = CosineAnnealingScheduler(
max_steps=args.steps,
warmup_steps=args.warmup_steps,
constant_steps=0,
min_lr=args.min_lr,
)
opt = nl.MegatronOptimizerModule(config=opt_config)
opt = nl.MegatronOptimizerModule(opt_config, sched)

nemo_logger = nl.NeMoLogger(
log_dir="test_logdir", ## logs and checkpoints will be written here
# checkpointing and logging
checkpoint_callback = ModelCheckpoint(
monitor="val_loss",
save_top_k=args.save_top_k,
every_n_train_steps=args.val_check_interval,
)
logger = nl.NeMoLogger(
name=args.name,
log_dir=args.log_dir,
ckpt=checkpoint_callback,
)

# run
llm.train(
model=model,
data=data,
trainer=trainer,
log=nemo_logger,
tokenizer='data',
optim=opt,
tokenizer='model',
trainer=trainer,
log=logger,
)
