perf recipes and Mcore DistOpt params #10883

Merged: 18 commits, Oct 18, 2024
9 changes: 9 additions & 0 deletions nemo/collections/llm/gpt/model/base.py
@@ -204,6 +204,9 @@ class GPTConfig5B(GPTConfig):
ffn_hidden_size: int = 16384
num_attention_heads: int = 32

bias_activation_fusion: bool = True
bias_dropout_add_fusion: bool = True


@dataclass
class GPTConfig7B(GPTConfig):
@@ -222,6 +225,9 @@ class GPTConfig20B(GPTConfig):
ffn_hidden_size: int = 24576
num_attention_heads: int = 48

bias_activation_fusion: bool = True
bias_dropout_add_fusion: bool = True


@dataclass
class GPTConfig40B(GPTConfig):
@@ -240,6 +246,9 @@ class GPTConfig175B(GPTConfig):
ffn_hidden_size: int = 49152
num_attention_heads: int = 96

bias_activation_fusion: bool = True
bias_dropout_add_fusion: bool = True


class GPTModel(L.LightningModule, io.IOMixin, io.ConnectorMixin, fn.FNMixin):
def __init__(
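
As a quick, hedged sanity check of the new defaults (illustrative snippet, not part of the diff; it assumes the configs are importable as they are in the recipe below):

from nemo.collections.llm.gpt.model import GPTConfig175B

# The 5B/20B/175B configs in this PR default both fusion flags to True.
cfg = GPTConfig175B()
assert cfg.bias_activation_fusion and cfg.bias_dropout_add_fusion
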
245 changes: 245 additions & 0 deletions nemo/collections/llm/recipes/gpt3_175b.py
@@ -0,0 +1,245 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Callable, Optional

import nemo_run as run
import pytorch_lightning as pl
import torch
from megatron.core.distributed import DistributedDataParallelConfig
from pytorch_lightning.callbacks.callback import Callback

from nemo import lightning as nl
from nemo.collections.llm.api import pretrain
from nemo.collections.llm.gpt.data.mock import MockDataModule
from nemo.collections.llm.gpt.model import GPTConfig175B, GPTModel
from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import (
userbuffers_bf16_h100_h12288_tp4_mbs1_seqlen2048,
)
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.utils.exp_manager import TimingCallback

NAME = "gpt3_175b"


@run.cli.factory(name=NAME)
def model() -> run.Config[pl.LightningModule]:
"""
Factory function to create a GPT3 175B model configuration.

Returns:
run.Config[pl.LightningModule]: Configuration for the GPT3 175B model.

Examples:
CLI usage:
$ nemo llm pretrain model=gpt3_175b ...

Python API usage:
>>> model_config = model()
>>> print(model_config)
"""
return run.Config(GPTModel, config=run.Config(GPTConfig175B))


def trainer(
tensor_parallelism: int = 4,
pipeline_parallelism: int = 8,
pipeline_parallelism_type: Optional[torch.dtype] = torch.bfloat16,
virtual_pipeline_parallelism: Optional[int] = 6,
context_parallelism: int = 1,
sequence_parallelism: bool = True,
num_nodes: int = 64,
num_gpus_per_node: int = 8,
max_steps: int = 1168251,
callbacks: Optional[list[run.Config[Callback]]] = None,
) -> run.Config[nl.Trainer]:
"""
Configure the NeMo Lightning Trainer for GPT3 175B model.

This function sets up the distributed training strategy optimized for the large 175B model.

Args:
tensor_parallelism (int): Degree of tensor model parallelism.
pipeline_parallelism (int): Degree of pipeline model parallelism.
pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism.
virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism.
context_parallelism (int): Degree of context parallelism.
sequence_parallelism (bool): Whether to use sequence parallelism.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
max_steps (int): Maximum number of training steps.
callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations.

Returns:
run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer.

Examples:
CLI usage:
$ nemo llm pretrain trainer=gpt3_175b ...

Python API usage:
>>> trainer_config = trainer(num_nodes=64, num_gpus_per_node=8)
>>> print(trainer_config)

Note:
This configuration uses extensive parallelism to handle the large model size efficiently.
"""
strategy = run.Config(
nl.MegatronStrategy,
tensor_model_parallel_size=tensor_parallelism,
pipeline_model_parallel_size=pipeline_parallelism,
pipeline_dtype=pipeline_parallelism_type,
virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism,
context_parallel_size=context_parallelism,
sequence_parallel=sequence_parallelism,
gradient_as_bucket_view=True,
ckpt_async_save=True,
ckpt_parallel_load=True,
ddp=run.Config(
DistributedDataParallelConfig,
check_for_nan_in_grad=True,
grad_reduce_in_fp32=True,
overlap_grad_reduce=True,
overlap_param_gather=True,
average_in_collective=True,
),
)

trainer = run.Config(
nl.Trainer,
accelerator="gpu",
accumulate_grad_batches=1,
callbacks=callbacks,
devices=num_gpus_per_node,
limit_test_batches=50,
limit_val_batches=32,
log_every_n_steps=10,
max_steps=max_steps,
num_nodes=num_nodes,
plugins=bf16_mixed(),
strategy=strategy,
use_distributed_sampler=False,
val_check_interval=2000,
)

return trainer


@run.cli.factory(target=pretrain, name=NAME)
def pretrain_recipe(
dir: Optional[str] = None, name: str = "default", num_nodes: int = 1, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
"""
Create a pre-training recipe for GPT3 175B model.

This function sets up a complete configuration for pre-training, including
model, trainer, data, logging, optimization, and resumption settings.

Args:
dir (Optional[str]): Directory for saving logs and checkpoints.
name (str): Name of the pre-training run.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
fn (Callable): The pre-training function to use.

Returns:
run.Partial: Partial configuration for pre-training.

Examples:
CLI usage:
$ nemo llm pretrain --factory gpt3_175b
$ nemo llm pretrain --factory "gpt3_175b(num_nodes=64, name='my_175b_pretrain')"

Python API usage:
>>> recipe = pretrain_recipe(name="gpt3_175b_pretrain", num_nodes=64)
>>> print(recipe)

Note:
This recipe is optimized for the large 175B model and requires significant computational resources.
"""
return run.Partial(
fn,
model=model(),
trainer=trainer(
num_nodes=num_nodes,
num_gpus_per_node=num_gpus_per_node,
callbacks=[run.Config(TimingCallback)],
),
data=run.Config(MockDataModule, seq_length=2048, global_batch_size=2048, micro_batch_size=2),
log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)),
optim=distributed_fused_adam_with_cosine_annealing(max_lr=0.9e-4),
resume=default_resume(),
)


@run.cli.factory(target=pretrain, name=NAME + "_performance")
def pretrain_recipe_performance(
Collaborator:

I would be a bit afraid that this will blow up the number of recipes that we show in the CLI. Since we allow the user to call the recipe with arguments, what about adding a new arg to the pretrain recipes that dictates whether to use the performance one or not?
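
A minimal sketch of that suggestion, for illustration only (the performance_mode flag is hypothetical and not part of this PR; the rest reuses names from the recipe above):

@run.cli.factory(target=pretrain, name=NAME)
def pretrain_recipe(
    dir: Optional[str] = None,
    name: str = "default",
    num_nodes: int = 1,
    num_gpus_per_node: int = 8,
    performance_mode: bool = False,  # hypothetical flag toggling the perf optimizations
    fn: Callable = pretrain,
) -> run.Partial:
    recipe = run.Partial(
        fn,
        model=model(),
        trainer=trainer(
            num_nodes=num_nodes,
            num_gpus_per_node=num_gpus_per_node,
            callbacks=[run.Config(TimingCallback)],
        ),
        data=run.Config(MockDataModule, seq_length=2048, global_batch_size=2048, micro_batch_size=2),
        log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)),
        optim=distributed_fused_adam_with_cosine_annealing(max_lr=0.9e-4),
        resume=default_resume(),
    )
    if performance_mode:
        # Opt in: append the comm-overlap callback (and any other perf-only settings) here.
        recipe.trainer.callbacks.append(
            run.Config(MegatronCommOverlapCallback, tp_comm_overlap=True)
        )
    return recipe
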

Collaborator Author:

I agree, @marcromeyn. But should we do this as a follow-up PR? It would involve performance recipes for all models (I assume), and the interface for the perf recipe/perf mode would also change, so reviewing it here would be difficult and error-prone.

Also, cc: @erhoo82 @JimmyZhang12

Collaborator:

@marcromeyn We need this performance recipe for every key model training because the baseline recipe is missing many performance features, and it is hard to communicate with users without a recipe that includes all the important perf features. There is a high risk that users will miss many of them unless they ask us for the exact CLI command every time.

Collaborator (@erhoo82), Oct 16, 2024:

@ericharper As we discussed, the baseline recipe is for debugging, and we need a performance recipe for all model training.

BTW, are we supposed to have different recipes for different sequence lengths?
Also, one question on the global batch size here:
#10905

Collaborator (@erhoo82), Oct 16, 2024:

@ericharper
Yes, I see that it uses a different CP size for the target seqlen, so I think only the TP and CP mapping needs to change. Do we need a full recipe just for this change? Also, the global batch size there seems to be wrong.

Collaborator:

We previously decided to include the full recipe for this, but we could revisit it. If there's a bug, please send a fix.

@malay-nagda we can merge this as is, but please send a follow-up PR with @marcromeyn's suggestion.

Collaborator Author:

I have created another draft PR with @marcromeyn's suggestion: #10926. I'm waiting for this PR to be merged so that I can update #10926.

dir: Optional[str] = None,
name: str = "default",
num_nodes: int = 1,
num_gpus_per_node: int = 8,
fn: Callable = pretrain,
) -> run.Partial:
"""
Create a performance-optimized pre-training recipe for GPT3 175B model.

This recipe enables performance optimizations that may not be suitable for all use cases.
It builds upon the standard pre-training recipe and adds additional performance enhancements.

Args:
dir (Optional[str]): Directory for saving logs and checkpoints.
name (str): Name of the pre-training run.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
fn (Callable): The pre-training function to use.

Returns:
run.Partial: Partial configuration for performance-optimized pre-training.

Examples:
CLI usage:
$ nemo llm pretrain --factory "gpt3_175b.pretrain_recipe_performance(num_nodes=64, name='perf_pretrain')"

Python API usage:
>>> recipe = pretrain_recipe_performance(name="gpt3_175b_perf", num_nodes=64)
>>> print(recipe)

Note:
Use this recipe with caution and only when you need maximum performance.
It may not be suitable for all hardware configurations or use cases.
"""
recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn)

# The 'overlap_param_gather_with_optimizer_step' and 'align_param_gather' params are set automatically by MegatronCommOverlapCallback;
# they are listed here for the user's reference.
# overlap_param_gather_with_optimizer_step: if True, overlap the param all-gather of the first bucket with the optimizer step.
# align_param_gather: if True, all PP stages launch param all-gathers simultaneously; otherwise each PP stage launches independently as needed.

recipe.trainer.callbacks.append(
run.Config(
MegatronCommOverlapCallback,
tp_comm_overlap=True,
tp_comm_overlap_cfg=userbuffers_bf16_h100_h12288_tp4_mbs1_seqlen2048,
defer_embedding_wgrad_compute=True,
wgrad_deferral_limit=50,
overlap_param_gather_with_optimizer_step=True,
align_param_gather=True,
)
)

return recipe
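
As a usage sketch, the performance recipe can also be built and launched from Python with NeMo-Run (the LocalExecutor below is a placeholder assumption; substitute the executor that matches your cluster):

import nemo_run as run
from nemo.collections.llm.recipes import gpt3_175b

# Build the perf-optimized recipe added in this PR and hand it to a NeMo-Run executor.
recipe = gpt3_175b.pretrain_recipe_performance(name="gpt3_175b_perf", num_nodes=64)
run.run(recipe, executor=run.LocalExecutor())  # placeholder executor for illustration
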
73 changes: 72 additions & 1 deletion nemo/collections/llm/recipes/llama31_405b.py
@@ -13,11 +13,12 @@
# limitations under the License.


from typing import Optional
from typing import Callable, Optional

import nemo_run as run
import pytorch_lightning as pl
import torch
from megatron.core.distributed import DistributedDataParallelConfig
from pytorch_lightning.callbacks.callback import Callback

from nemo import lightning as nl
@@ -27,6 +28,10 @@
from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import (
userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192,
)
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.utils.exp_manager import TimingCallback

NAME = "llama31_405b"
@@ -107,6 +112,14 @@ def trainer(
gradient_as_bucket_view=True,
ckpt_async_save=True,
ckpt_parallel_load=True,
ddp=run.Config(
DistributedDataParallelConfig,
check_for_nan_in_grad=True,
grad_reduce_in_fp32=True,
overlap_grad_reduce=True,
overlap_param_gather=True,
average_in_collective=True,
),
)

trainer = run.Config(
@@ -174,3 +187,61 @@ def pretrain_recipe(
optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4),
resume=default_resume(),
)


@run.cli.factory(target=pretrain, name=NAME + "_performance")
def pretrain_recipe_performance(
dir: Optional[str] = None,
name: str = "default",
num_nodes: int = 1,
num_gpus_per_node: int = 8,
fn: Callable = pretrain,
) -> run.Partial:
"""
Create a performance-optimized pre-training recipe for Llama3.1 405B model.

This recipe enables performance optimizations that may not be suitable for all use cases.
It builds upon the standard pre-training recipe and adds additional performance enhancements.

Args:
dir (Optional[str]): Directory for saving logs and checkpoints.
name (str): Name of the pre-training run.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
fn (Callable): The pre-training function to use.

Returns:
run.Partial: Partial configuration for performance-optimized pre-training.

Examples:
CLI usage:
$ nemo llm pretrain --factory "llama31_405b.pretrain_recipe_performance(num_nodes=4, name='perf_pretrain')"

Python API usage:
>>> recipe = pretrain_recipe_performance(name="llama31_405b_perf", num_nodes=4)
>>> print(recipe)

Note:
Use this recipe with caution and only when you need maximum performance.
It may not be suitable for all hardware configurations or use cases.
"""
recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn)

# The 'overlap_param_gather_with_optimizer_step' and 'align_param_gather' params are set automatically by MegatronCommOverlapCallback;
# they are listed here for the user's reference.
# overlap_param_gather_with_optimizer_step: if True, overlap the param all-gather of the first bucket with the optimizer step.
# align_param_gather: if True, all PP stages launch param all-gathers simultaneously; otherwise each PP stage launches independently as needed.

recipe.trainer.callbacks.append(
run.Config(
MegatronCommOverlapCallback,
tp_comm_overlap=True,
tp_comm_overlap_cfg=userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192,
defer_embedding_wgrad_compute=True,
wgrad_deferral_limit=50,
overlap_param_gather_with_optimizer_step=True,
align_param_gather=True,
)
)

return recipe