Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add rwkv support #1198

Merged
merged 26 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
2741e89
add rwkv support
jahatef Mar 31, 2024
282f800
Update init_functions.py
jahatef Mar 31, 2024
09ba65e
rwkv model files
jahatef Mar 31, 2024
8892432
Merge branch 'rwkv' of github.com:EleutherAI/gpt-neox into rwkv
jahatef Mar 31, 2024
04b8fdb
configs
jahatef Mar 31, 2024
6e79cc2
kernels
jahatef Apr 10, 2024
cb49ff6
Cleanup
jahatef Apr 16, 2024
96ea6f5
Update 760M.yml
jahatef Apr 16, 2024
54f2775
remove preffn and mishglu
jahatef Apr 17, 2024
3606bdd
merge
jahatef Apr 17, 2024
8d60cef
Merge branch 'main' into rwkv
Quentin-Anthony Apr 19, 2024
276ffa9
Update NeoXArgs docs automatically
invalid-email-address Apr 19, 2024
e20138c
Add RWKV parallelism assertions
Quentin-Anthony Apr 19, 2024
428aad5
Update NeoXArgs docs automatically
invalid-email-address Apr 19, 2024
1b0bbab
pre-commit and config cleanup
Quentin-Anthony Apr 19, 2024
7550d64
Merge branch 'rwkv' of https://github.com/EleutherAI/gpt-neox into rwkv
Quentin-Anthony Apr 19, 2024
c0af563
Update NeoXArgs docs automatically
invalid-email-address Apr 19, 2024
1eb5f51
rwkv logging
jahatef May 3, 2024
330a802
Merge branch 'rwkv' of https://github.com/EleutherAI/gpt-neox into rwkv
jahatef May 3, 2024
1103663
Merge branch 'main' into rwkv
Quentin-Anthony May 4, 2024
a599ac7
Update NeoXArgs docs automatically
invalid-email-address May 4, 2024
682f7e5
Add rwkv version dirname, make hdim 3.5x
Quentin-Anthony May 4, 2024
921c41a
pre-commit
Quentin-Anthony May 4, 2024
8f60a43
Update NeoXArgs docs automatically
invalid-email-address May 4, 2024
6fb840e
fix bug and set batch size to 32
jahatef May 5, 2024
dd0138e
Update NeoXArgs docs automatically
invalid-email-address May 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions configs/neox_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = 11a5537
Default = 7550d64

current git hash of repository

Expand Down Expand Up @@ -432,7 +432,7 @@ Model Arguments
The first item in the list specifies the attention type(s), and should be a list of strings. The second item
specifies the number of times to repeat those attention types in the full list.

attention type choices: [global, local, sparse_fixed, sparse_variable, bslongformer, bigbird, "gmlp", "amlp", "flash", "mamba"]
attention type choices: [global, local, sparse_fixed, sparse_variable, bslongformer, bigbird, "gmlp", "amlp", "flash", "mamba", "rwkv"]

So a 12 layer network with only global attention could be specified like:
[[[`global`], 12]]
Expand Down Expand Up @@ -1965,7 +1965,9 @@ Args for deepspeed config

Default = None

Configuration for using bfloat16 floating-point format as an alternative to FP16. BFLOAT16 requires hardware support (e.g., NVIDIA A100). Dictionary options as described in Deepspeed documentation: https://www.deepspeed.ai/docs/config-json/#bfloat16-training-options
Configuration for using bfloat16 floating-point format as an alternative to FP16. BFLOAT16 requires hardware support (e.g., NVIDIA A100).

Dictionary options as described in Deepspeed documentation: https://www.deepspeed.ai/docs/config-json/#bfloat16-training-options



Expand Down
102 changes: 102 additions & 0 deletions configs/rwkv/170M.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
{
# Parallelism is not yet supported for rwkv
"pipe_parallel_size": 0,
"model_parallel_size": 1,

"num_layers": 12,
"hidden_size": 768,
"num_attention_heads": 12, # head_size = dim_att / num_attention_heads.
# head_size is 64 for all rwkv models
"seq_length": 512,
"max_position_embeddings": 2048,
"output_layer_parallelism": "column",
"norm": "rmsnorm",
"rms_norm_epsilon": 1.0e-5,
"train_micro_batch_size_per_gpu": 1,

"attention_config": [[["rwkv"], 12]],

"activation": "silu",

# model settings

#"pos_emb": "rotary",
"rotary_pct": 0.25,
"no_weight_tying": true,
"gpt_j_residual": true,

# these should provide some speedup but takes a while to build, set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,
"layernorm_fusion": false,


# init methods
"init_method": "small_init",
"output_layer_init_method": "wang_init",

# optimizer settings
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.0008,
"betas": [0.9, 0.95],
"eps": 1.0e-8,
}
},
"min_lr": 0.00008,

# for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
"zero_optimization": {
"stage": 1,
"allgather_partitions": True,
"allgather_bucket_size": 500000000,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 500000000,
"contiguous_gradients": True,
},

# batch / data settings
"data_impl": "mmap",
"num_workers": 1,

# activation checkpointing
"checkpoint_activations": true,
"checkpoint_num_layers": 1,
"partition_activations": true,
"synchronize_each_layer": true,

# regularization
"gradient_clipping": 1.0,
"weight_decay": 0.1,
"hidden_dropout": 0,
"attention_dropout": 0,

# precision settings
"bf16": {
"bf16": true,
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 12,
"hysteresis": 2,
"min_loss_scale": 1,
},

# misc. training settings
"train_iters": 500,
"lr_decay_iters": 500,
"distributed_backend": "nccl",
"lr_decay_style": "constant",
"warmup": 0.01,
"checkpoint_factor": 100,
"eval_interval": 100000,
"eval_iters": 10,

# logging
"log_interval": 10,
"steps_per_print": 10,
"wall_clock_breakdown": true,
}
10 changes: 10 additions & 0 deletions megatron/model/gpt2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
ParallelLinear,
)
from megatron.model.gmlp import GMLPBlock
from megatron.model.rwkv import RWKVResidualLayerPipe
from megatron.model.mamba import ParallelMambaResidualLayerPipe
from megatron.model.word_embeddings import EmbeddingPipe, SoftEmbedding

Expand Down Expand Up @@ -175,6 +176,7 @@ def insert_layers(
"GMLPBlock",
"ParallelTransformerLayerPipe",
"ParallelMambaResidualLayerPipe",
"RWKVResidualLayerPipe",
],
)

Expand Down Expand Up @@ -251,6 +253,14 @@ def init_specs(self):
mask_fn=gpt2_attention_mask_func,
)
)
elif layer_type == "rwkv":
self.specs.append(
LayerSpec(
RWKVResidualLayerPipe,
neox_args=self.neox_args,
layer_number=i,
)
)
elif layer_type in ["mamba"]:
self.specs.append(
LayerSpec(
Expand Down
1 change: 1 addition & 0 deletions megatron/model/rwkv/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .rwkv import RWKVResidualLayerPipe, RWKVResidualLayer
Loading