Alit/mamba recipe (#10935)
* add some mamba recipe

* add 130m

* add the rest of the recipes

* add tokenizer

* add tokenizer

* minor fix

* minor fix

* minor fix

* minor fix

* minor fix

* minor fix

* minor fix

* minor fix

* minor fix

* minor fix

* minor fix

* add fixes to ssm for nemorun recipes

* add hybrid tokenizer

* updating some recipes

* Apply isort and black reformatting

Signed-off-by: JRD971000 <[email protected]>

* remove comments

* update gbs

* fix ckpt resume

* fix ckpt resume

* fix ckpt resume

* update recipes final

* Apply isort and black reformatting

Signed-off-by: JRD971000 <[email protected]>

* remove redundant imports

* ckpt convertor dtype fix

* Apply isort and black reformatting

Signed-off-by: JRD971000 <[email protected]>

---------

Signed-off-by: JRD971000 <[email protected]>
Signed-off-by: Ali Taghibakhshi <[email protected]>
Co-authored-by: JRD971000 <[email protected]>
2 people authored and titu1994 committed Oct 28, 2024
1 parent 8141251 commit d0be75c
Showing 11 changed files with 2,281 additions and 3 deletions.
14 changes: 13 additions & 1 deletion nemo/collections/llm/gpt/model/ssm.py
@@ -53,6 +53,9 @@ class SSMConfig(TransformerConfig, io.IOMixin):
     fp16_lm_cross_entropy: bool = False
     parallel_output: bool = True
     share_embeddings_and_output_weights: bool = False
+    params_dtype: torch.dtype = torch.bfloat16
+    fp16: bool = False
+    bf16: bool = True
     num_layers: int = 2
     mamba_ssm_ngroups: int = 8
     num_attention_heads: int = 1
@@ -81,6 +84,7 @@ class SSMConfig(TransformerConfig, io.IOMixin):

     forward_step_fn: Callable = ssm_forward_step
     data_step_fn: Callable = gpt_data_step
+    tokenizer_model_path: str = None

     def configure_model(self, tokenizer) -> "MCoreMambaModel":

@@ -127,9 +131,17 @@ def __init__(self, state_dict):
             def state_dict(self):
                 return self._state_dict

+            def to(self, dtype):
+                for k, v in self._state_dict.items():
+                    if v.dtype != dtype:
+                        logging.warning(f"Converting {k} from {v.dtype} (source model) to {dtype} (target model)")
+                    self._state_dict[k] = v.to(dtype)
+
         source = ModelState(source)
         target = self.init()
-        trainer = self.nemo_setup(target)
+        trainer = self.nemo_setup(target, ckpt_async_save=False)
+        source.to(self.config.params_dtype)
+        target.to(self.config.params_dtype)
         self.convert_state(source, target)
         self.nemo_save(output_path, trainer)

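The new `ModelState.to` helper above normalizes checkpoint dtypes before conversion, so a float32 PyTorch state dict can be imported into a bf16 target model without silent mismatches. A minimal standalone sketch of the same idea (the function and variable names here are illustrative, not part of the commit):

    import logging

    import torch

    def cast_state_dict(state_dict: dict, dtype: torch.dtype) -> dict:
        # Cast every tensor to the target dtype, warning on each mismatch,
        # mirroring what ModelState.to does in the hunk above.
        for key, value in state_dict.items():
            if value.dtype != dtype:
                logging.warning(f"Converting {key} from {value.dtype} to {dtype}")
                state_dict[key] = value.to(dtype)
        return state_dict

    # Example: align a float32 checkpoint with a bf16 target model.
    sd = cast_state_dict({"weight": torch.randn(2, 2)}, torch.bfloat16)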
14 changes: 14 additions & 0 deletions nemo/collections/llm/recipes/__init__.py
@@ -21,6 +21,13 @@
     llama3_70b_16k,
     llama3_70b_64k,
     llama31_405b,
+    mamba2_1_3b,
+    mamba2_2_7b,
+    mamba2_8b,
+    mamba2_130m,
+    mamba2_370m,
+    mamba2_780m,
+    mamba2_hybrid_8b,
     mistral_7b,
     mistral_nemo_12b,
     mixtral_8x7b,
@@ -49,6 +56,13 @@
     "llama3_70b_16k",
     "llama3_70b_64k",
     "llama31_405b",
+    "mamba2_130m",
+    "mamba2_370m",
+    "mamba2_780m",
+    "mamba2_1_3b",
+    "mamba2_2_7b",
+    "mamba2_8b",
+    "mamba2_hybrid_8b",
     "mistral_7b",
     "mistral_nemo_12b",
     "mixtral_8x7b",
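With the seven Mamba2 recipe modules imported and listed in `__all__`, they become discoverable through the recipes package. A minimal usage sketch (assuming a NeMo 2.0 environment with `nemo_run` installed):

    from nemo.collections.llm.recipes import mamba2_130m

    # Build the default pretraining recipe defined later in this commit.
    recipe = mamba2_130m.pretrain_recipe(name="mamba2_130m_pretrain", num_nodes=1)
    print(recipe)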
321 changes: 321 additions & 0 deletions nemo/collections/llm/recipes/mamba2_130m.py
@@ -0,0 +1,321 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Optional

import nemo_run as run
import pytorch_lightning as pl
import torch
from megatron.core.distributed import DistributedDataParallelConfig
from pytorch_lightning.callbacks.callback import Callback

from nemo import lightning as nl
from nemo.collections import llm
from nemo.collections.llm.api import finetune, pretrain
from nemo.collections.llm.gpt.data.mock import MockDataModule
from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
from nemo.utils.exp_manager import TimingCallback

NAME = "mamba2_130m"


@run.cli.factory(name=NAME)
def tokenizer(tokenizer_model: Optional[str] = None) -> run.Config:
    """Factory function to create a tokenizer configuration for the Mamba2 130M model."""
    return run.Config(
        get_nmt_tokenizer,
        library='huggingface',
        model_name="EleutherAI/gpt-neox-20b",
        tokenizer_model=tokenizer_model,
        use_fast=True,
    )


@run.cli.factory(name=NAME)
def model(tokenizer_model: Optional[str] = None) -> run.Config[pl.LightningModule]:
    """
    Factory function to create a Mamba2 130M model configuration.
    Args:
        tokenizer_model (Optional[str]): Path to the tokenizer model. Defaults to None,
            in which case the EleutherAI/gpt-neox-20b tokenizer is used.
    Returns:
        run.Config[pl.LightningModule]: Configuration for the Mamba2 130M model.
    Examples:
        CLI usage:
            $ nemo llm pretrain model=mamba2_130m ...
        Python API usage:
            >>> model_config = model()
            >>> print(model_config)
    """
    return run.Config(
        llm.GPTModel, config=run.Config(llm.BaseMambaConfig130M), tokenizer=tokenizer(tokenizer_model=tokenizer_model)
    )


def trainer(
    tensor_parallelism: int = 1,
    pipeline_parallelism: int = 1,
    pipeline_parallelism_type: Optional[torch.dtype] = None,
    virtual_pipeline_parallelism: Optional[int] = None,
    context_parallelism: int = 1,
    sequence_parallelism: bool = False,
    num_nodes: int = 1,
    num_gpus_per_node: int = 8,
    max_steps: int = 1168251,
    callbacks: Optional[list[run.Config[Callback]]] = None,
) -> run.Config[nl.Trainer]:
    """
    Configure the NeMo Lightning Trainer for the Mamba2 130M model.
    This function sets up the distributed training strategy and other training parameters.
    Args:
        tensor_parallelism (int): Degree of tensor model parallelism.
        pipeline_parallelism (int): Degree of pipeline model parallelism.
        pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism.
        virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism.
        context_parallelism (int): Degree of context parallelism.
        sequence_parallelism (bool): Whether to use sequence parallelism.
        num_nodes (int): Number of compute nodes to use.
        num_gpus_per_node (int): Number of GPUs per node.
        max_steps (int): Maximum number of training steps.
        callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations.
    Returns:
        run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer.
    Examples:
        CLI usage:
            $ nemo llm pretrain trainer=mamba2_130m ...
        Python API usage:
            >>> trainer_config = trainer(num_nodes=1, num_gpus_per_node=1)
            >>> print(trainer_config)
    Note:
        For more information on distributed training strategies, refer to the
        NeMo documentation on multi-GPU and multi-node training.
    """
    strategy = run.Config(
        nl.MegatronStrategy,
        tensor_model_parallel_size=tensor_parallelism,
        pipeline_model_parallel_size=pipeline_parallelism,
        pipeline_dtype=pipeline_parallelism_type,
        virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism,
        context_parallel_size=context_parallelism,
        sequence_parallel=sequence_parallelism,
        gradient_as_bucket_view=True,
        ckpt_async_save=False,
        ckpt_parallel_load=True,
        ddp=run.Config(
            DistributedDataParallelConfig,
            check_for_nan_in_grad=True,
            grad_reduce_in_fp32=True,
            overlap_grad_reduce=True,
            overlap_param_gather=True,
        ),
    )

    trainer = run.Config(
        nl.Trainer,
        accelerator="gpu",
        accumulate_grad_batches=1,
        callbacks=callbacks,
        devices=num_gpus_per_node,
        limit_test_batches=50,
        limit_val_batches=32,
        log_every_n_steps=10,
        max_steps=max_steps,
        num_nodes=num_nodes,
        plugins=bf16_mixed(),
        strategy=strategy,
        use_distributed_sampler=False,
        val_check_interval=2000,
    )

    return trainer


@run.cli.factory(target=pretrain, name=NAME)
def pretrain_recipe(
    dir: Optional[str] = None,
    name: str = "default",
    tokenizer_model: Optional[str] = None,
    num_nodes: int = 1,
    num_gpus_per_node: int = 8,
    fn=pretrain,
) -> run.Partial:
    """
    Create a pre-training recipe for the Mamba2 130M model.
    This function sets up a complete configuration for pre-training, including
    model, trainer, data, logging, optimization, and resumption settings.
    Args:
        dir (Optional[str]): Directory for saving logs and checkpoints.
        name (str): Name of the pre-training run.
        tokenizer_model (Optional[str]): Path to the tokenizer model (defaults to None).
        num_nodes (int): Number of compute nodes to use.
        num_gpus_per_node (int): Number of GPUs per node.
        fn (Callable): The pre-training function to use.
    Returns:
        run.Partial: Partial configuration for pre-training.
    Examples:
        CLI usage:
            $ nemo llm pretrain --factory mamba2_130m
            $ nemo llm pretrain --factory "mamba2_130m(num_nodes=1, name='my_pretrain')"
        Python API usage:
            >>> recipe = pretrain_recipe(name="mamba2_130m_pretrain", num_nodes=1)
            >>> print(recipe)
    Note:
        For more details on pre-training LLMs with NeMo, see the pre-training
        guide in the `examples/llm/pretrain/` directory.
    """
    return run.Partial(
        fn,
        model=model(tokenizer_model=tokenizer_model),
        trainer=trainer(
            num_nodes=num_nodes,
            num_gpus_per_node=num_gpus_per_node,
            callbacks=[run.Config(TimingCallback)],
        ),
        data=run.Config(
            MockDataModule,
            seq_length=4096,
            global_batch_size=8,
            micro_batch_size=1,
            tokenizer=tokenizer(tokenizer_model=tokenizer_model),
        ),
        log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)),
        optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4),
        resume=default_resume(),
    )


@run.cli.factory(target=finetune, name=NAME)
def finetune_recipe(
    dir: Optional[str] = None,
    name: str = "default",
    resume_path: Optional[str] = None,
    tokenizer_model: Optional[str] = None,
    num_nodes: int = 1,
    num_gpus_per_node: int = 8,
    gbs: int = 8,
    mbs: int = 1,
    peft_scheme: Optional[str] = 'none',
) -> run.Partial:
    """
    Create a fine-tuning recipe for the Mamba2 130M model.
    This function sets up a complete configuration for fine-tuning, including
    model, trainer, data, logging, optimization, and resumption settings.
    Args:
        dir (Optional[str]): Directory for saving logs and checkpoints.
        name (str): Name of the fine-tuning run.
        resume_path (Optional[str]): Path to the NeMo checkpoint (refer to the notes
            below on how to convert a PyTorch checkpoint to NeMo).
        tokenizer_model (Optional[str]): Path to the tokenizer model (defaults to None).
        num_nodes (int): Number of compute nodes to use.
        num_gpus_per_node (int): Number of GPUs per node.
        gbs (int): Global batch size.
        mbs (int): Micro batch size.
        peft_scheme (Optional[str]): PEFT scheme to use; only 'none' (full fine-tuning)
            is currently supported.
    Returns:
        run.Partial: Partial configuration for fine-tuning.
    Examples:
        CLI usage:
            $ nemo llm finetune --factory mamba2_130m
        Python API usage:
            >>> recipe = finetune_recipe(name="mamba2_130m_finetune", num_nodes=1)
            >>> print(recipe)
    Note:
        This recipe uses the SQuAD dataset for fine-tuning. For more information
        on fine-tuning LLMs with NeMo, see the fine-tuning guide in the
        `examples/llm/finetune/` directory.
        For converting an SSM PyTorch checkpoint, use the following lines of Python code:
            llm.GPTModel(llm.BaseMambaConfig130M(), tokenizer=tokenizer()).import_ckpt(
                path="pytorch://ABSOLUTE_PATH_TO_CKPT/your_pytorch_state_dict_file",
                model_config=llm.BaseMambaConfig130M())
        This will cache the NeMo checkpoint to the following directory:
            /root/.cache/nemo/models/your_pytorch_state_dict_file
    """
    nemo_resume = run.Config(
        nl.AutoResume,
        restore_config=run.Config(nl.RestoreConfig, path=resume_path),
    )
    strategy = run.Config(
        nl.MegatronStrategy,
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        gradient_as_bucket_view=True,
        ckpt_load_optimizer=False,
        ckpt_save_optimizer=False,
        ckpt_async_save=False,
    )
    checkpoint_callback = run.Config(
        nl.ModelCheckpoint,
        every_n_train_steps=10,
        dirpath=dir,
    )
    trainer = run.Config(
        nl.Trainer,
        accelerator="gpu",
        accumulate_grad_batches=1,
        devices=num_gpus_per_node,
        limit_test_batches=10,
        limit_val_batches=10,
        log_every_n_steps=20,
        max_steps=100,
        num_nodes=num_nodes,
        plugins=run.Config(
            nl.MegatronMixedPrecision,
            precision="bf16-mixed",
            params_dtype=torch.bfloat16,
        ),
        callbacks=[checkpoint_callback],
        strategy=strategy,
        use_distributed_sampler=False,
        val_check_interval=20,
    )
    recipe = run.Partial(
        llm.finetune,
        model=model(tokenizer_model=tokenizer_model),
        trainer=trainer,
        data=run.Config(
            llm.SquadDataModule,
            seq_length=2048,
            global_batch_size=gbs,
            micro_batch_size=mbs,
            tokenizer=tokenizer(tokenizer_model=tokenizer_model),
        ),
        log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)),
        optim=distributed_fused_adam_with_cosine_annealing(max_lr=1e-4, min_lr=0, warmup_steps=50),
        resume=nemo_resume,
    )
    if peft_scheme is None or peft_scheme.lower() == 'none':
        recipe.trainer.strategy.tensor_model_parallel_size = 1
        recipe.optim.config.lr = 5e-6
    else:
        raise ValueError(f"Unrecognized peft scheme: {peft_scheme}")
    return recipe
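Taken together, the intended flow for the new recipe is to convert a PyTorch SSM checkpoint to NeMo format once, then fine-tune from the cached result. A hedged end-to-end sketch following the docstring above (paths are placeholders; the exact cache location may differ):

    from nemo.collections import llm
    from nemo.collections.llm.recipes.mamba2_130m import finetune_recipe, tokenizer

    # One-time conversion of a PyTorch state dict to a NeMo checkpoint,
    # as described in the finetune_recipe docstring.
    llm.GPTModel(llm.BaseMambaConfig130M(), tokenizer=tokenizer()).import_ckpt(
        path="pytorch://ABSOLUTE_PATH_TO_CKPT/your_pytorch_state_dict_file",
        model_config=llm.BaseMambaConfig130M(),
    )

    # Fine-tune from the cached NeMo checkpoint.
    recipe = finetune_recipe(
        name="mamba2_130m_finetune",
        resume_path="/root/.cache/nemo/models/your_pytorch_state_dict_file",
        num_nodes=1,
    )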