[Nemo-UX] Expose transformer_layer_spec inside GPTConfig (NVIDIA#9592)

* Expose transformer_layer_spec inside GPTConfig

* Apply isort and black reformatting

Signed-off-by: marcromeyn <[email protected]>

* Expose layer-specs

* Apply isort and black reformatting

Signed-off-by: marcromeyn <[email protected]>

---------

Signed-off-by: marcromeyn <[email protected]>
Co-authored-by: marcromeyn <[email protected]>
Signed-off-by: tonyjie <[email protected]>
2 people authored and tonyjie committed Aug 6, 2024
1 parent eb4684d commit 6d097a9
Showing 2 changed files with 33 additions and 4 deletions.
4 changes: 4 additions & 0 deletions nemo/collections/llm/gpt/model/__init__.py
@@ -4,6 +4,8 @@
     MaskedTokenLossReduction,
     gpt_data_step,
     gpt_forward_step,
+    local_layer_spec,
+    transformer_engine_layer_spec,
 )
 from nemo.collections.llm.gpt.model.gemma import (
     CodeGemmaConfig2B,
@@ -56,4 +58,6 @@
     "MaskedTokenLossReduction",
     "gpt_data_step",
     "gpt_forward_step",
+    "transformer_engine_layer_spec",
+    "local_layer_spec",
 ]
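With the re-exports above, both layer-spec helpers are importable from the package root alongside the existing GPT utilities. A minimal sketch, assuming a NeMo build that includes this commit:

# Resolve a Megatron-Core ModuleSpec from a GPTConfig via the newly exported helpers.
from nemo.collections.llm.gpt.model import local_layer_spec, transformer_engine_layer_spec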
33 changes: 29 additions & 4 deletions nemo/collections/llm/gpt/model/base.py
@@ -1,10 +1,12 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Dict, Literal, Optional
+from typing import TYPE_CHECKING, Callable, Dict, Literal, Optional, Union

 import pytorch_lightning as L
 import torch
 import torch.distributed
+from megatron.core.models.gpt import gpt_layer_specs
 from megatron.core.optimizer import OptimizerConfig
+from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_config import TransformerConfig
 from torch import nn

@@ -63,6 +65,18 @@ def gpt_forward_step(model, batch) -> torch.Tensor:
     return model(**forward_args)


+def transformer_engine_layer_spec(config: "GPTConfig") -> ModuleSpec:
+    return gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec(
+        num_experts=config.num_moe_experts, moe_grouped_gemm=config.moe_grouped_gemm, qk_layernorm=config.qk_layernorm
+    )
+
+
+def local_layer_spec(config: "GPTConfig") -> ModuleSpec:
+    return gpt_layer_specs.get_gpt_layer_local_spec(
+        num_experts=config.num_moe_experts, moe_grouped_gemm=config.moe_grouped_gemm, qk_layernorm=config.qk_layernorm
+    )
+
+
 @dataclass
 class GPTConfig(TransformerConfig, io.IOMixin):
     # From megatron.core.models.gpt.gpt_model.GPTModel
@@ -79,6 +93,7 @@ class GPTConfig(TransformerConfig, io.IOMixin):
     # TODO: Move this to better places?
     get_attention_mask_from_fusion: bool = False

+    transformer_layer_spec: Union[ModuleSpec, Callable[["GPTConfig"], ModuleSpec]] = transformer_engine_layer_spec
     forward_step_fn: Callable = gpt_forward_step
     data_step_fn: Callable = gpt_data_step

@@ -91,12 +106,15 @@ def configure_model(self, tokenizer) -> "MCoreGPTModel":
         ) % vp_size == 0, "Make sure the number of model chunks is the same across all pipeline stages."

         from megatron.core import parallel_state
-        from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
         from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel

+        transformer_layer_spec = self.transformer_layer_spec
+        if not isinstance(transformer_layer_spec, ModuleSpec):
+            transformer_layer_spec = transformer_layer_spec(self)
+
         return MCoreGPTModel(
             self,
-            transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec(self.num_moe_experts),
+            transformer_layer_spec=transformer_layer_spec,
             vocab_size=get_vocab_size(self, tokenizer.vocab_size, self.make_vocab_size_divisible_by),
             max_sequence_length=self.seq_length,
             fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
@@ -225,4 +243,11 @@ def get_packed_seq_params(batch):
     )


-__all__ = ["GPTModel", "GPTConfig", "gpt_data_step", "gpt_forward_step"]
+__all__ = [
+    "GPTModel",
+    "GPTConfig",
+    "gpt_data_step",
+    "gpt_forward_step",
+    "transformer_engine_layer_spec",
+    "local_layer_spec",
+]
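Taken together, configure_model() no longer hard-codes the Transformer Engine layer spec: GPTConfig.transformer_layer_spec accepts either a ready-made ModuleSpec or a callable that builds one from the config, and callables are resolved by calling them with the config itself. A usage sketch, assuming a NeMo build with this commit; the model dimensions below are illustrative placeholders, not values from the commit:

from nemo.collections.llm.gpt.model.base import GPTConfig, local_layer_spec

# Default behaviour is unchanged: transformer_layer_spec falls back to the
# transformer_engine_layer_spec callable, so existing configs build as before.
config = GPTConfig(num_layers=12, hidden_size=768, num_attention_heads=12, seq_length=2048)

# Swap in the Transformer-Engine-free ("local") spec by passing the helper;
# a prebuilt ModuleSpec instance would be passed through to MCoreGPTModel unchanged.
local_config = GPTConfig(
    num_layers=12,
    hidden_size=768,
    num_attention_heads=12,
    seq_length=2048,
    transformer_layer_spec=local_layer_spec,
)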
