Skip to content

Commit

Permalink
Add megablocks dropless MoE (#1192)
Browse files Browse the repository at this point in the history
* Add megablocks dropless MoE

* pre-commit

---------

Co-authored-by: Yang Zhang <[email protected]>
Co-authored-by: Quentin Anthony <[email protected]>
  • Loading branch information
3 people authored May 4, 2024
1 parent 06e5f0c commit 916c883
Show file tree
Hide file tree
Showing 10 changed files with 464 additions and 43 deletions.
75 changes: 75 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Prior to 3/9/2023, GPT-NeoX relied on [DeeperSpeed](https://github.com/EleutherA
+ [Containerized Setup](#containerized-setup)
* [Usage](#usage)
- [Configuration](#configuration)
* [Mixture of Experts](#mixture-of-experts)
- [Datasets](#datasets)
* [Preconfigured Datasets](#preconfigured-datasets)
* [Using Custom Data](#using-custom-data)
Expand Down Expand Up @@ -322,6 +323,80 @@ These files are generally complete, but non-optimal. For example, depending on y

For a more detailed guide to the features available and how to configure them, see [the configuration README](configs/README.md), and for documentation of every possible argument, see [configs/neox_arguments.md](configs/neox_arguments.md).

## Mixture of Experts

GPT-NeoX includes multiple expert implementations for MoE. To select between them, specify `moe_type` of `megablocks` (default) or `deepspeed`.

Both are based on the DeepSpeed MoE parallelism framework, which supports tensor-expert-data parallelism.
Both allow you to toggle between token-dropping and dropless (default, and this is what Megablocks was designed for).
Sinkhorn routing to come soon!

For an example of a basic complete configuration, see configs/125M-dmoe.yml (for Megablocks dropless) or configs/125M-moe.yml.

Most MoE related configuration arguments are prefixed with `moe`. Some common configuration parameters and their defaults are as follows:

```
moe_type: megablocks
moe_num_experts: 1 # 1 disables MoE. 8 is a reasonable value.
moe_loss_coeff: 0.1
expert_interval: 2 # See details below
enable_expert_tensor_parallelism: false # See details below
moe_expert_parallel_size: 1 # See details below
moe_token_dropping: false
```

DeepSpeed can be further configured with the following:

```
moe_top_k: 1
moe_min_capacity: 4
moe_train_capacity_factor: 1.0 # Setting to 1.0
moe_eval_capacity_factor: 1.0 # Setting to 1.0
```

One MoE layer is present every `expert_interval` transformer layers including the first, so with 12 layers total:

```
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11
```

Experts would be in these layers:

```
0, 2, 4, 6, 8, 10
```

By default, we use expert-data parallelism, so any available tensor parallelism (`model_parallel_size`) will be used for expert routing. For instance, given the following:

```
expert_parallel_size: 4
model_parallel_size: 2 # aka tensor parallelism
```

With 32 GPUs, the behavior will be look like:

- In non-expert layers:
- Tensor parallelism is 2. (There are 32 / 2 = 16 such tensor parallel groups, each of size 2.)
- Data parallelism implicitly becomes 32 / 2 = 16.
- In expert layers:
- There is no tensor parallelism.
- Expert parallelism is 4. (There are 32 / 4 = 8 expert parallel groups, each of size 4.)
- Data parallelism implicitly becomes 32 / 4 = 8. Some cross-node token routing happens as a result of this redivision of data parallelism between 16 and 8. To avoid it, ensure that `expert_parallel_size == model_parallel_size`.

Setting `enable_expert_tensor_parallelism` enables tensor-expert-data (TED) parallelism. The way to interpret the above would then be:

- In non-expert layers: same as before.
- In expert layers:
- Tensor parallelism is 2. (There are 32 / 2 = 16 tensor parallel groups, each of size 2.)
- Expert parallelism is 4. (There are 32 / 4 = 8 expert parallel groups, each of size 4.)
- Data parallelism implicitly becomes 32 / (2 * 4) = 4. Again, cross-node token routing happens. To avoid, ensure `expert_parallel_size == 1` or `model_parallel_size == 1`.

So note that DP must be divisible by (MP * EP). For more details, see the [TED paper].

Pipeline parallelism is not yet supported - coming soon!

[TED paper]: https://arxiv.org/abs/2303.06318

# Datasets

## Preconfigured Datasets
Expand Down
101 changes: 101 additions & 0 deletions configs/125M-dmoe.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# GPT-2 pretraining setup
{
# See README for MoE config docs!
"moe_type": "megablocks",
"moe_token_dropping": false,
# Have 4 experts per layer (every 2 layers by default)
"moe_num_experts": 4,
# parallelism settings
"enable_expert_tensor_parallelism": true,
"pipe_parallel_size": 1, # not yet supported for MoE
"model_parallel_size": 1,
"moe_expert_parallel_size": 1,

# model settings
"num_layers": 12,
"hidden_size": 768,
"num_attention_heads": 12,
"seq_length": 2048,
"max_position_embeddings": 2048,
"norm": "layernorm",
"pos_emb": "rotary",
"no_weight_tying": true,
"gpt_j_residual": false,
"output_layer_parallelism": "column",

# these should provide some speedup but takes a while to build, set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,

# init methods
"init_method": "small_init",
"output_layer_init_method": "wang_init",


# optimizer settings
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.0006,
"betas": [0.9, 0.95],
"eps": 1.0e-8,
}
},
"min_lr": 0.00006,

# for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
"zero_optimization": {
"stage": 0,
"allgather_partitions": True,
"allgather_bucket_size": 500000000,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 500000000,
"contiguous_gradients": True,
},

# batch / data settings
"train_micro_batch_size_per_gpu": 4,
"data_impl": "mmap",

# activation checkpointing
"checkpoint_activations": true,
"checkpoint_num_layers": 1,
"partition_activations": true,
"synchronize_each_layer": true,

# regularization
"gradient_clipping": 1.0,
"weight_decay": 0.1,
"hidden_dropout": 0.0,
"attention_dropout": 0.0,

# precision settings
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},

# misc. training settings
"train_iters": 320000,
"lr_decay_iters": 320000,
"distributed_backend": "nccl",
"lr_decay_style": "cosine",
"warmup": 0.01,
"checkpoint_factor": 10000,
"eval_interval": 1000,
"eval_iters": 10,

# logging
"log_interval": 10,
"steps_per_print": 10,
"keep_last_n_checkpoints": 4,
"wall_clock_breakdown": true,

# networking
"hostfile": "/mock_path"
}
16 changes: 7 additions & 9 deletions configs/125M-moe.yml
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
# GPT-2 pretraining setup
{
# See README for MoE config docs!
"moe_type": "deepspeed",
"moe_token_dropping": true,
# Have 4 experts per layer (every 2 layers by default)
# So with 12 layers total:
# 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11
# Experts would be in layers:
# 0, 2, 4, 6, 8, 10
"num_experts": 4,

# parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages
# across the node boundaries )
"pipe_parallel_size": 1,
"moe_num_experts": 4,
# parallelism settings
"enable_expert_tensor_parallelism": true,
"pipe_parallel_size": 1, # not yet supported for MoE
"model_parallel_size": 1,
"moe_expert_parallel_size": 1,

Expand Down
12 changes: 6 additions & 6 deletions megatron/data/helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,9 +428,9 @@ py::array build_mapping_impl(const py::array_t<int64_t>& docs_,
}

} // for (auto sent_index=sent_index_first; ...
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {

if (!second) {
if (verbose) {
Expand Down Expand Up @@ -660,9 +660,9 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
num_sent = 0;
}
} // for (auto sent_index=sent_index_first; ...
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
} // for (int epoch=0; epoch < num_epochs; ++epoch) {

if (!second) {
if (verbose) {
Expand Down
34 changes: 34 additions & 0 deletions megatron/model/megablocks_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Adapter to expose MegaBlocks package, if available."""

try:
import megablocks
except ImportError:
megablocks = None


def megablocks_is_available():
return megablocks is not None


def assert_megablocks_is_available():
assert (
megablocks_is_available()
), "MegaBlocks not available. Please run `pip install megablocks`."


moe = megablocks.layers.moe if megablocks_is_available() else None
dmoe = megablocks.layers.dmoe if megablocks_is_available() else None
arguments = megablocks.layers.arguments if megablocks_is_available() else None


def as_megablocks_args(neox_args):
import copy

tmp = copy.copy(neox_args)
delattr(tmp, "mlp_type")
tmp.mlp_type = "mlp"
args = arguments.from_megatron(tmp)
args.moe_lbl_in_fp32 = True
args.fp16 = neox_args.precision == "fp16"
args.moe_loss_weight = neox_args.moe_loss_coeff
return args
Loading

0 comments on commit 916c883

Please sign in to comment.