Skip to content

Commit

Permalink
add rwkv support (#1198)
Browse files Browse the repository at this point in the history
* add rwkv support

* Update init_functions.py

* rwkv model files

* configs

* kernels

* Cleanup

* Update 760M.yml

* remove preffn and mishglu

* Update NeoXArgs docs automatically

* Add RWKV parallelism assertions

* Update NeoXArgs docs automatically

* pre-commit and config cleanup

* Update NeoXArgs docs automatically

* rwkv logging

* Update NeoXArgs docs automatically

* Add rwkv version dirname, make hdim 3.5x

* pre-commit

* Update NeoXArgs docs automatically

* fix bug and set batch size to 32

* Update NeoXArgs docs automatically

---------

Co-authored-by: Quentin Anthony <[email protected]>
Co-authored-by: github-actions <[email protected]>
  • Loading branch information
3 people authored May 6, 2024
1 parent c814959 commit 4bc6670
Show file tree
Hide file tree
Showing 13 changed files with 882 additions and 18 deletions.
8 changes: 5 additions & 3 deletions configs/neox_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = 47c93fb
Default = 6fb840e

current git hash of repository

Expand Down Expand Up @@ -432,7 +432,7 @@ Model Arguments
The first item in the list specifies the attention type(s), and should be a list of strings. The second item
specifies the number of times to repeat those attention types in the full list.
attention type choices: [global, local, sparse_fixed, sparse_variable, bslongformer, bigbird, "gmlp", "amlp", "flash", "mamba"]
attention type choices: [global, local, sparse_fixed, sparse_variable, bslongformer, bigbird, "gmlp", "amlp", "flash", "mamba", "rwkv"]
So a 12 layer network with only global attention could be specified like:
[[[`global`], 12]]
Expand Down Expand Up @@ -1965,7 +1965,9 @@ Args for deepspeed config
Default = None
Configuration for using bfloat16 floating-point format as an alternative to FP16. BFLOAT16 requires hardware support (e.g., NVIDIA A100). Dictionary options as described in Deepspeed documentation: https://www.deepspeed.ai/docs/config-json/#bfloat16-training-options
Configuration for using bfloat16 floating-point format as an alternative to FP16. BFLOAT16 requires hardware support (e.g., NVIDIA A100).
Dictionary options as described in Deepspeed documentation: https://www.deepspeed.ai/docs/config-json/#bfloat16-training-options
Expand Down
102 changes: 102 additions & 0 deletions configs/rwkv/170M.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
{
# Parallelism is not yet supported for rwkv
"pipe_parallel_size": 1,
"model_parallel_size": 1,

"num_layers": 12,
"hidden_size": 768,
"num_attention_heads": 12, # head_size = dim_att / num_attention_heads.
# head_size is 64 for all rwkv models
"seq_length": 512,
"max_position_embeddings": 2048,
"output_layer_parallelism": "column",
"norm": "rmsnorm",
"rms_norm_epsilon": 1.0e-5,
"train_micro_batch_size_per_gpu": 32,

"attention_config": [[["rwkv"], 12]],

"activation": "silu",

# model settings

#"pos_emb": "rotary",
"rotary_pct": 0.25,
"no_weight_tying": true,
"gpt_j_residual": true,

# these should provide some speedup but takes a while to build, set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,
"layernorm_fusion": false,


# init methods
"init_method": "small_init",
"output_layer_init_method": "wang_init",

# optimizer settings
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.0008,
"betas": [0.9, 0.95],
"eps": 1.0e-8,
}
},
"min_lr": 0.00008,

# for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
"zero_optimization": {
"stage": 1,
"allgather_partitions": True,
"allgather_bucket_size": 500000000,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 500000000,
"contiguous_gradients": True,
},

# batch / data settings
"data_impl": "mmap",
"num_workers": 1,

# activation checkpointing
"checkpoint_activations": true,
"checkpoint_num_layers": 1,
"partition_activations": true,
"synchronize_each_layer": true,

# regularization
"gradient_clipping": 1.0,
"weight_decay": 0.1,
"hidden_dropout": 0,
"attention_dropout": 0,

# precision settings
"bf16": {
"bf16": true,
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 12,
"hysteresis": 2,
"min_loss_scale": 1,
},

# misc. training settings
"train_iters": 500,
"lr_decay_iters": 500,
"distributed_backend": "nccl",
"lr_decay_style": "constant",
"warmup": 0.01,
"checkpoint_factor": 100,
"eval_interval": 100000,
"eval_iters": 10,

# logging
"log_interval": 10,
"steps_per_print": 10,
"wall_clock_breakdown": true,
}
39 changes: 27 additions & 12 deletions megatron/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,19 +92,34 @@ def get_flops(neox_args, iter_time_s) -> float:
hidden_size = neox_args.hidden_size
num_layers = neox_args.num_layers
ckpt_activations_factor = 4 if neox_args.checkpoint_activations else 3
flops_per_iteration = (
24
* ckpt_activations_factor
* batch_size
* seq_len
* num_layers
* (hidden_size**2)
* (
1.0
+ (seq_len / (6.0 * hidden_size))
+ (vocab_size / (16.0 * num_layers * hidden_size))
if "rwkv" in neox_args.attention_config:
num_heads = neox_args.num_attention_heads

flops_per_iteration = (
batch_size
* seq_len
* (
78 * hidden_size * hidden_size * num_layers
+ 84 * hidden_size * num_layers
+ 16 * hidden_size
+ 12 * hidden_size * vocab_size
+ 18 * hidden_size * hidden_size * num_layers / num_heads
)
)
else:
flops_per_iteration = (
24
* ckpt_activations_factor
* batch_size
* seq_len
* num_layers
* (hidden_size**2)
* (
1.0
+ (seq_len / (6.0 * hidden_size))
+ (vocab_size / (16.0 * num_layers * hidden_size))
)
)
)
return flops_per_iteration / (iter_time_s * world_size)


Expand Down
10 changes: 10 additions & 0 deletions megatron/model/gpt2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
ParallelLinear,
)
from megatron.model.gmlp import GMLPBlock
from megatron.model.rwkv.v6 import RWKVResidualLayerPipe
from megatron.model.mamba import ParallelMambaResidualLayerPipe
from megatron.model.word_embeddings import EmbeddingPipe, SoftEmbedding

Expand Down Expand Up @@ -175,6 +176,7 @@ def insert_layers(
"GMLPBlock",
"ParallelTransformerLayerPipe",
"ParallelMambaResidualLayerPipe",
"RWKVResidualLayerPipe",
],
)

Expand Down Expand Up @@ -251,6 +253,14 @@ def init_specs(self):
mask_fn=gpt2_attention_mask_func,
)
)
elif layer_type == "rwkv":
self.specs.append(
LayerSpec(
RWKVResidualLayerPipe,
neox_args=self.neox_args,
layer_number=i,
)
)
elif layer_type in ["mamba"]:
self.specs.append(
LayerSpec(
Expand Down
Empty file added megatron/model/rwkv/__init__.py
Empty file.
1 change: 1 addition & 0 deletions megatron/model/rwkv/v6/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .rwkv import RWKVResidualLayerPipe, RWKVResidualLayer
Loading

0 comments on commit 4bc6670

Please sign in to comment.