MegaBlocks release (#1102)
* [Stage] Megablocks release (#241)

* V1 of MegaBlocks
---------

* fix hf ckptr

* rename

* lint

* lint

---------

Co-authored-by: Abhinav Venigalla <[email protected]>
Co-authored-by: Sasha Doubov <[email protected]>
Co-authored-by: Cheng Li <[email protected]>
Co-authored-by: Ning Wang <[email protected]>
Co-authored-by: Irene Dea <[email protected]>
Co-authored-by: Shashank Rajput <[email protected]>
Co-authored-by: Chuck Tang <[email protected]>
Co-authored-by: Jose Javier <[email protected]>
Co-authored-by: Angel Ruiz <[email protected]>
Co-authored-by: Denny Lee <[email protected]>
Co-authored-by: Jane Zhang <[email protected]>
Co-authored-by: Daniel King <[email protected]>
Co-authored-by: Chuck Tang <[email protected]>
Co-authored-by: Vitaliy Chiley <[email protected]>
15 people authored Apr 9, 2024
1 parent 2939cc9 commit 53160f4
Showing 26 changed files with 2,392 additions and 146 deletions.
4 changes: 4 additions & 0 deletions llmfoundry/callbacks/__init__.py
@@ -11,6 +11,8 @@
from llmfoundry.callbacks.eval_gauntlet_callback import EvalGauntlet
from llmfoundry.callbacks.fdiff_callback import FDiffMetrics
from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer
from llmfoundry.callbacks.log_mbmoe_tok_per_expert_callback import \
MegaBlocksMoE_TokPerExpert
from llmfoundry.callbacks.monolithic_ckpt_callback import \
MonolithicCheckpointSaver
from llmfoundry.callbacks.resumption_callbacks import (GlobalLRScaling,
@@ -34,6 +36,7 @@
callbacks.register('scheduled_gc', func=ScheduledGarbageCollector)
callbacks.register('oom_observer', func=OOMObserver)
callbacks.register('eval_output_logging', func=EvalOutputLogging)
callbacks.register('mbmoe_tok_per_expert', func=MegaBlocksMoE_TokPerExpert)

callbacks_with_config.register('async_eval', func=AsyncEval)
callbacks_with_config.register('curriculum_learning', func=CurriculumLearning)
@@ -46,6 +49,7 @@
'ScheduledGarbageCollector',
'EvalGauntlet',
'HuggingFaceCheckpointer',
'MegaBlocksMoE_TokPerExpert',
'AsyncEval',
'CurriculumLearning',
]
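
A minimal usage sketch (not part of this commit) of the callback exported above: the model and dataloader names are placeholders, and it assumes Composer and MegaBlocks are installed. In a training YAML, the same callback can be selected via its registry key, `mbmoe_tok_per_expert`.

from composer import Trainer

from llmfoundry.callbacks import MegaBlocksMoE_TokPerExpert

# Log aggregate tokens-per-expert statistics every 10 batches.
tok_per_expert_cb = MegaBlocksMoE_TokPerExpert(
    log_interval=10,
    log_every_layer=False,   # log only the all-layer aggregates
    all_reduce_stats=False,  # use rank-0 counts without an all-reduce
    normalize=True,          # report fractions of routed tokens rather than raw counts
)

trainer = Trainer(
    model=composer_mpt_moe_model,       # placeholder: an MPT ComposerModel with a MegaBlocks dMoE FFN
    train_dataloader=train_dataloader,  # placeholder dataloader
    max_duration='100ba',
    callbacks=[tok_per_expert_cb],
)
trainer.fit()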
70 changes: 58 additions & 12 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -15,6 +15,7 @@
from typing import Any, Dict, List, Optional, Sequence, Union

import torch
import torch.nn as nn
from composer.core import Callback, Event, State, Time, TimeUnit
from composer.core.state import fsdp_state_dict_type_context
from composer.loggers import Logger, MLFlowLogger
@@ -24,6 +25,7 @@
parse_uri)
from composer.utils.misc import create_interval_scheduler
from mlflow.transformers import _fetch_model_card, _write_license_information
from packaging import version
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
@@ -312,28 +314,72 @@ def _save_checkpoint(self, state: State, logger: Logger):
state_dict_model = state.model.model
original_tokenizer = state.model.tokenizer

state_dict_context = fsdp_state_dict_type_context(
original_model,
state_dict_type='full') if ((not state.is_model_ddp) and isinstance(
state_dict_model, FSDP)) else contextlib.nullcontext()

with state_dict_context:
state_dict = state_dict_model.state_dict()

# convert the state dict to the requested precision
for k, v in state_dict.items():
if isinstance(v, torch.Tensor):
state_dict[k] = v.to(dtype=self.dtype)
if version.parse(torch.__version__) > version.parse('2.2.9'):
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.state_dict import (
StateDictOptions, get_model_state_dict)
cpu_offload = True

# Add a dtensor->cpu tensor hook to avoid CUDA OOM
def dtensor_to_tensor_hook(
module: nn.Module,
state_dict: Dict[str, Any],
prefix: str,
*args: Any,
) -> Dict[str, Any]:
dtensor_fqns = []
for fqn in state_dict.keys():
tensor = state_dict[fqn]
if isinstance(tensor, DTensor):
dtensor_fqns.append(fqn)
tensor = tensor.full_tensor() # type: ignore
if dist.get_global_rank() == 0:
if cpu_offload:
tensor = tensor.cpu()
state_dict[fqn] = tensor
if dist.get_global_rank() != 0:
for fqn in dtensor_fqns:
del state_dict[fqn]
return state_dict

hooks = []
for _, module in state_dict_model.named_modules():
if isinstance(module, FSDP):
hooks.append(
module._register_state_dict_hook(
dtensor_to_tensor_hook))

state_dict = get_model_state_dict(state_dict_model,
options=StateDictOptions(
full_state_dict=True,
cpu_offload=cpu_offload))
for hook in hooks:
hook.remove()
else:
state_dict_context = fsdp_state_dict_type_context(
original_model, state_dict_type='full') if (
(not state.is_model_ddp) and isinstance(
state_dict_model, FSDP)) else contextlib.nullcontext()
with state_dict_context:
state_dict = state_dict_model.state_dict()

# Convert the state dict to the requested precision
for k, v in state_dict.items():
if isinstance(v, torch.Tensor):
state_dict[k] = v.to(dtype=self.dtype)

new_model_instance = None # Need this for pyright because variable could be unbound

if dist.get_global_rank() == 0:
log.debug('Saving Hugging Face checkpoint in global rank 0')

# Edit HF config before building 2nd model copy
copied_config = copy.deepcopy(original_model.config)
if copied_config.model_type == 'mpt':
copied_config.attn_config['attn_impl'] = 'torch'
copied_config.init_device = 'cpu'
if 'moe_world_size' in getattr(copied_config, 'ffn_config', {}):
copied_config.ffn_config['moe_world_size'] = 1

log.debug(f'Creating new model instance')

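The checkpointer change above gates on the installed torch version: on newer torch it gathers a full state dict with `get_model_state_dict` while a registered state-dict hook materializes each DTensor's shards as a full tensor on rank 0 (optionally offloaded to CPU) and drops those entries on every other rank. The sketch below is only an illustration of that hook's effect, using plain tensors in place of gathered DTensors and a hypothetical helper name.

import torch

def gather_full_on_rank0(state_dict, rank, cpu_offload=True):
    # Mirrors the effect of dtensor_to_tensor_hook: keep gathered tensors only on
    # rank 0, optionally moving them to CPU to avoid CUDA OOM, and drop them elsewhere.
    gathered_fqns = []
    for fqn, tensor in state_dict.items():
        # In the real hook, `tensor.full_tensor()` assembles a DTensor's shards;
        # here a plain tensor stands in for that gathered result.
        gathered_fqns.append(fqn)
        if rank == 0:
            state_dict[fqn] = tensor.cpu() if cpu_offload else tensor
    if rank != 0:
        for fqn in gathered_fqns:
            del state_dict[fqn]
    return state_dict

# Toy single-process demonstration with hypothetical ranks.
sd = {'w': torch.randn(4, 4), 'b': torch.randn(4)}
print(gather_full_on_rank0(dict(sd), rank=0).keys())  # dict_keys(['w', 'b'])
print(gather_full_on_rank0(dict(sd), rank=1).keys())  # dict_keys([])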
140 changes: 140 additions & 0 deletions llmfoundry/callbacks/log_mbmoe_tok_per_expert_callback.py
@@ -0,0 +1,140 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""Log tokens per expert for MegaBlocks MoE."""
from __future__ import annotations

import torch
from composer.core import Callback, State
from composer.loggers import Logger
from composer.utils import dist


class MegaBlocksMoE_TokPerExpert(Callback):
"""Log tokens per expert for MegaBlocks MoE.
To compute the load balancing loss, MegaBlocks caches information including `tokens_per_expert`
(tpe). At the :attr:`.Event.BATCH_END` event, this callback retrieves the cached load-balancing
loss information from MegaBlocks to obtain `tokens_per_expert`, then logs statistics (<STAT>) of the
number of tokens assigned to experts at each layer index (l_idx) under ``mb_moe/layer<l_idx>_<STAT>_tpe``.
The tokens_per_expert statistics are logged by the :class:`.Logger` to the following keys as
described below.
+----------------------------------+-----------------------------------------------------------+
| Key | Logged data |
+==================================+===========================================================+
| `mb_moe/all_layers_min_tpe` | Minimum tokens per expert across all layers |
+----------------------------------+-----------------------------------------------------------+
| `mb_moe/all_layers_max_tpe` | Maximum tokens per expert across all layers |
+----------------------------------+-----------------------------------------------------------+
| `mb_moe/all_layers_median_tpe` | Median tokens per expert across all layers |
+----------------------------------+-----------------------------------------------------------+
| `mb_moe/all_layers_std_tpe` | Standard deviation of tokens per expert across all layers |
+----------------------------------+-----------------------------------------------------------+
| `mb_moe/layer<l_idx>_min_tpe` | Minimum tokens per expert at l_idx layer |
+----------------------------------+-----------------------------------------------------------+
| `mb_moe/layer<l_idx>_max_tpe` | Maximum tokens per expert at l_idx layer |
+----------------------------------+-----------------------------------------------------------+
| `mb_moe/layer<l_idx>_median_tpe` | Median tokens per expert at l_idx layer |
+----------------------------------+-----------------------------------------------------------+
| `mb_moe/layer<l_idx>_std_tpe` | Standard deviation of tokens per expert at l_idx layer |
+----------------------------------+-----------------------------------------------------------+
Args:
log_interval (int, optional): The interval on which to log (Default: 10).
log_every_layer (bool, optional): Enable logging every layer's statistics (True) or log
only aggregate statistics (Default: False).
all_reduce_stats (bool, optional): Enable aggregating statistics across GPUs (True) or log
statistics for GPU 0 only (Default: False).
normalize (bool, optional): Normalize token counts by total tokens (Default: True) or output
raw token count (False). When normalize is True, the callback displays the fraction of
unique tokens routed to each expert. When normalize is False, the callback displays the
total number of tokens routed to each expert.
"""

def __init__(
self,
log_interval: int = 10,
log_every_layer: bool = False,
all_reduce_stats: bool = False,
normalize: bool = True,
):
self.log_interval = log_interval
self.log_every_layer = log_every_layer
self.all_reduce_stats = all_reduce_stats
self.normalize = normalize

self.topk = None

def fit_start(self, state: State, logger: Logger) -> None:
if self.topk is None and self.normalize:
try:
from megablocks.layers.dmoe import dMoE
from megablocks.layers.moe import MoE
except:
raise RuntimeError(
'Requirements for MegaBlocks not installed; see install instructions in `README.md`.'
)
for module in state.model.modules():
if isinstance(module, (MoE, dMoE)):
self.topk = module.experts.args.moe_top_k
return

raise RuntimeError(
f'Callback not initialized correctly; self.topk not instantiated.'
)

def batch_end(self, state: State, logger: Logger) -> None:
if state.timestamp.batch.value % self.log_interval == 0:
try:
from megablocks.layers.moe import get_load_balancing_loss
except:
raise RuntimeError(
'Requirements for MegaBlocks not installed; see install instructions in `README.md`.'
)
tokens_per_expert, _ = zip(*get_load_balancing_loss())

tokens_per_expert = [
tpe.clone().detach() for tpe in tokens_per_expert
]
if self.all_reduce_stats:
for tpe in tokens_per_expert:
dist.all_reduce(tpe)

if self.normalize:
tokens_per_expert = [
tpe / (tpe.sum() / self.topk) for tpe in tokens_per_expert
]

all_tokens_per_expert = torch.concat(tokens_per_expert)

min_tpe = all_tokens_per_expert.min().item()
max_tpe = all_tokens_per_expert.max().item()
median_tpe = all_tokens_per_expert.median().item()
std_tpe = all_tokens_per_expert.float().std().item()

log_info = {
f'mb_moe/all_layers_min_tpe': min_tpe,
f'mb_moe/all_layers_max_tpe': max_tpe,
f'mb_moe/all_layers_median_tpe': median_tpe,
f'mb_moe/all_layers_std_tpe': std_tpe,
}

if self.log_every_layer:
for l_idx, tpe_layer in enumerate(tokens_per_expert):

min_tpe = tpe_layer.min().item()
max_tpe = tpe_layer.max().item()
median_tpe = tpe_layer.median().item()
std_tpe = tpe_layer.float().std().item()

log_info.update({
f'mb_moe/layer{l_idx}_min_tpe': min_tpe,
f'mb_moe/layer{l_idx}_max_tpe': max_tpe,
f'mb_moe/layer{l_idx}_median_tpe': median_tpe,
f'mb_moe/layer{l_idx}_std_tpe': std_tpe,
})

logger.log_metrics(log_info)
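
As a worked example of the `normalize=True` arithmetic above, with made-up numbers: 40 tokens routed with `moe_top_k = 2` produce 80 expert assignments, so each expert's count is divided by sum / top_k = 40 to give the fraction of tokens routed to it.

import torch

moe_top_k = 2
tokens_per_expert = torch.tensor([30., 10., 20., 20.])  # hypothetical counts for one layer

normalized = tokens_per_expert / (tokens_per_expert.sum() / moe_top_k)
print(normalized)  # tensor([0.7500, 0.2500, 0.5000, 0.5000]), i.e. the fraction of tokens routed to each expert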