diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index 686a358ab..c8e1492ae 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -111,7 +111,7 @@ Logging Arguments - **git_hash**: str - Default = 47c93fb + Default = 6fb840e current git hash of repository @@ -432,7 +432,7 @@ Model Arguments The first item in the list specifies the attention type(s), and should be a list of strings. The second item specifies the number of times to repeat those attention types in the full list. - attention type choices: [global, local, sparse_fixed, sparse_variable, bslongformer, bigbird, "gmlp", "amlp", "flash", "mamba"] + attention type choices: [global, local, sparse_fixed, sparse_variable, bslongformer, bigbird, "gmlp", "amlp", "flash", "mamba", "rwkv"] So a 12 layer network with only global attention could be specified like: [[[`global`], 12]] @@ -1965,7 +1965,9 @@ Args for deepspeed config Default = None - Configuration for using bfloat16 floating-point format as an alternative to FP16. BFLOAT16 requires hardware support (e.g., NVIDIA A100). Dictionary options as described in Deepspeed documentation: https://www.deepspeed.ai/docs/config-json/#bfloat16-training-options + Configuration for using bfloat16 floating-point format as an alternative to FP16. BFLOAT16 requires hardware support (e.g., NVIDIA A100). + + Dictionary options as described in Deepspeed documentation: https://www.deepspeed.ai/docs/config-json/#bfloat16-training-options diff --git a/configs/rwkv/170M.yml b/configs/rwkv/170M.yml new file mode 100644 index 000000000..11311f441 --- /dev/null +++ b/configs/rwkv/170M.yml @@ -0,0 +1,102 @@ +{ + # Parallelism is not yet supported for rwkv + "pipe_parallel_size": 1, + "model_parallel_size": 1, + + "num_layers": 12, + "hidden_size": 768, + "num_attention_heads": 12, # head_size = dim_att / num_attention_heads. 
+ # head_size is 64 for all rwkv models + "seq_length": 512, + "max_position_embeddings": 2048, + "output_layer_parallelism": "column", + "norm": "rmsnorm", + "rms_norm_epsilon": 1.0e-5, + "train_micro_batch_size_per_gpu": 32, + + "attention_config": [[["rwkv"], 12]], + + "activation": "silu", + + # model settings + + #"pos_emb": "rotary", + "rotary_pct": 0.25, + "no_weight_tying": true, + "gpt_j_residual": true, + + # these should provide some speedup but takes a while to build, set to true if desired + "scaled_upper_triang_masked_softmax_fusion": false, + "bias_gelu_fusion": false, + "rope_fusion": false, + "layernorm_fusion": false, + + + # init methods + "init_method": "small_init", + "output_layer_init_method": "wang_init", + + # optimizer settings + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.0008, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 0.00008, + + # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training + "zero_optimization": { + "stage": 1, + "allgather_partitions": True, + "allgather_bucket_size": 500000000, + "overlap_comm": True, + "reduce_scatter": True, + "reduce_bucket_size": 500000000, + "contiguous_gradients": True, + }, + + # batch / data settings + "data_impl": "mmap", + "num_workers": 1, + + # activation checkpointing + "checkpoint_activations": true, + "checkpoint_num_layers": 1, + "partition_activations": true, + "synchronize_each_layer": true, + + # regularization + "gradient_clipping": 1.0, + "weight_decay": 0.1, + "hidden_dropout": 0, + "attention_dropout": 0, + + # precision settings + "bf16": { + "bf16": true, + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "initial_scale_power": 12, + "hysteresis": 2, + "min_loss_scale": 1, + }, + + # misc. 
training settings + "train_iters": 500, + "lr_decay_iters": 500, + "distributed_backend": "nccl", + "lr_decay_style": "constant", + "warmup": 0.01, + "checkpoint_factor": 100, + "eval_interval": 100000, + "eval_iters": 10, + + # logging + "log_interval": 10, + "steps_per_print": 10, + "wall_clock_breakdown": true, +} diff --git a/megatron/logging.py b/megatron/logging.py index 6c9b7915e..247aeb1b5 100644 --- a/megatron/logging.py +++ b/megatron/logging.py @@ -92,19 +92,34 @@ def get_flops(neox_args, iter_time_s) -> float: hidden_size = neox_args.hidden_size num_layers = neox_args.num_layers ckpt_activations_factor = 4 if neox_args.checkpoint_activations else 3 - flops_per_iteration = ( - 24 - * ckpt_activations_factor - * batch_size - * seq_len - * num_layers - * (hidden_size**2) - * ( - 1.0 - + (seq_len / (6.0 * hidden_size)) - + (vocab_size / (16.0 * num_layers * hidden_size)) + if "rwkv" in neox_args.attention_config: + num_heads = neox_args.num_attention_heads + + flops_per_iteration = ( + batch_size + * seq_len + * ( + 78 * hidden_size * hidden_size * num_layers + + 84 * hidden_size * num_layers + + 16 * hidden_size + + 12 * hidden_size * vocab_size + + 18 * hidden_size * hidden_size * num_layers / num_heads + ) + ) + else: + flops_per_iteration = ( + 24 + * ckpt_activations_factor + * batch_size + * seq_len + * num_layers + * (hidden_size**2) + * ( + 1.0 + + (seq_len / (6.0 * hidden_size)) + + (vocab_size / (16.0 * num_layers * hidden_size)) + ) ) - ) return flops_per_iteration / (iter_time_s * world_size) diff --git a/megatron/model/gpt2_model.py b/megatron/model/gpt2_model.py index 9c86d98d3..9e643874a 100644 --- a/megatron/model/gpt2_model.py +++ b/megatron/model/gpt2_model.py @@ -37,6 +37,7 @@ ParallelLinear, ) from megatron.model.gmlp import GMLPBlock +from megatron.model.rwkv.v6 import RWKVResidualLayerPipe from megatron.model.mamba import ParallelMambaResidualLayerPipe from megatron.model.word_embeddings import EmbeddingPipe, SoftEmbedding @@ -175,6 +176,7 @@ def insert_layers( "GMLPBlock", "ParallelTransformerLayerPipe", "ParallelMambaResidualLayerPipe", + "RWKVResidualLayerPipe", ], ) @@ -251,6 +253,14 @@ def init_specs(self): mask_fn=gpt2_attention_mask_func, ) ) + elif layer_type == "rwkv": + self.specs.append( + LayerSpec( + RWKVResidualLayerPipe, + neox_args=self.neox_args, + layer_number=i, + ) + ) elif layer_type in ["mamba"]: self.specs.append( LayerSpec( diff --git a/megatron/model/rwkv/__init__.py b/megatron/model/rwkv/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/megatron/model/rwkv/v6/__init__.py b/megatron/model/rwkv/v6/__init__.py new file mode 100644 index 000000000..c0d8d4ba1 --- /dev/null +++ b/megatron/model/rwkv/v6/__init__.py @@ -0,0 +1 @@ +from .rwkv import RWKVResidualLayerPipe, RWKVResidualLayer diff --git a/megatron/model/rwkv/v6/cuda/wkv6_cuda.cu b/megatron/model/rwkv/v6/cuda/wkv6_cuda.cu new file mode 100644 index 000000000..2b228e90f --- /dev/null +++ b/megatron/model/rwkv/v6/cuda/wkv6_cuda.cu @@ -0,0 +1,270 @@ +#include +#include +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; + +template +__global__ void kernel_forward(const int B, + const int T, + const int C, + const int H, + const F* __restrict__ const _r, + const F* __restrict__ const _k, + const F* __restrict__ const _v, + const float* __restrict__ _w, + const F* __restrict__ _u, + F* __restrict__ const _y) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _u += h * _N_; + + __shared__ float r[_N_], k[_N_], u[_N_], 
w[_N_]; + float state[_N_] = {0}; + + __syncthreads(); + u[i] = float(_u[i]); + __syncthreads(); + + for (int t = b * T * C + h * _N_ + i; t < (b + 1) * T * C + h * _N_ + i; t += C) { + __syncthreads(); + w[i] = exp(_w[t]); + r[i] = float(_r[t]); + k[i] = float(_k[t]); + __syncthreads(); + + const float v = float(_v[t]); + float y = 0; + +#pragma unroll + for (int j = 0; j < _N_; j += 4) { + const float4& r_ = (float4&)(r[j]); + const float4& k_ = (float4&)(k[j]); + const float4& w_ = (float4&)(w[j]); + const float4& u_ = (float4&)(u[j]); + float4& s = (float4&)(state[j]); + float4 x; + + x.x = k_.x * v; + x.y = k_.y * v; + x.z = k_.z * v; + x.w = k_.w * v; + + y += r_.x * (u_.x * x.x + s.x); + y += r_.y * (u_.y * x.y + s.y); + y += r_.z * (u_.z * x.z + s.z); + y += r_.w * (u_.w * x.w + s.w); + + s.x = s.x * w_.x + x.x; + s.y = s.y * w_.y + x.y; + s.z = s.z * w_.z + x.z; + s.w = s.w * w_.w + x.w; + } + _y[t] = F(y); + } +} + +template +__global__ void kernel_backward_111(const int B, + const int T, + const int C, + const int H, + const F* __restrict__ const _r, + const F* __restrict__ const _k, + const F* __restrict__ const _v, + const float* __restrict__ _w, + const F* __restrict__ _u, + const F* __restrict__ const _gy, + F* __restrict__ const _gr, + F* __restrict__ const _gk, + F* __restrict__ const _gv, + F* __restrict__ const _gu) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + _u += h * _N_; + + __shared__ float u_[_N_]; + __shared__ float r[_N_], k[_N_], v[_N_], w_[_N_], gy[_N_]; + __syncthreads(); + u_[i] = float(_u[i]); + __syncthreads(); + + const float u = u_[i]; + + float state[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0}; + + const int t_0 = b * T * C + h * _N_ + i; + const int t_T_1 = t_0 + (T - 1) * C; + const int t_T = t_0 + T * C; + + float gu = 0; + for (int t = t_0; t < t_T; t += C) { + __syncthreads(); + v[i] = float(_v[t]); + gy[i] = float(_gy[t]); + __syncthreads(); + + const float k = float(_k[t]); + const float w = exp(_w[t]); + float gr = 0, gu_ = 0; + +#pragma unroll + for (int j = 0; j < _N_; j++) { + float& s = state[j]; + float x = k * v[j]; + + gr += (u * x + s) * gy[j]; + gu_ += x * gy[j]; + s = s * w + x; + } + _gr[t] = F(gr); + gu += float(_r[t]) * gu_; + } + _gu[b * C + h * _N_ + i] = F(gu); + + for (int t = t_T_1; t >= t_0; t -= C) { + __syncthreads(); + v[i] = float(_v[t]); + gy[i] = float(_gy[t]); + __syncthreads(); + + const float rr = float(_r[t]); + const float w = exp(_w[t]); + float gk = 0; + +#pragma unroll + for (int j = 0; j < _N_; j++) { + float& s = scccc[j]; + float x = rr * gy[j]; + + gk += (u * x + s) * v[j]; + s = x + s * w; + } + _gk[t] = F(gk); + } + + for (int t = t_T_1; t >= t_0; t -= C) { + __syncthreads(); + r[i] = float(_r[t]); + k[i] = float(_k[t]); + w_[i] = exp(_w[t]); + __syncthreads(); + + const float gyy = float(_gy[t]); + float gv = 0; + +#pragma unroll + for (int j = 0; j < _N_; j++) { + float& s = sdddd[j]; + float x = gyy * r[j]; + + gv += (u_[j] * x + s) * k[j]; + s = x + s * w_[j]; + } + _gv[t] = F(gv); + } +} + +template +__global__ void kernel_backward_222(const int B, + const int T, + const int C, + const int H, + const F* __restrict__ const _r, + const F* __restrict__ const _k, + const F* __restrict__ const _v, + const float* __restrict__ _w, + const F* __restrict__ _u, + const F* __restrict__ const _gy, + F* __restrict__ const _gw) +{ + const int b = blockIdx.x / H; + const int h = blockIdx.x % H; + const int i = threadIdx.x; + + __shared__ float v[_N_], gy[_N_]; + 
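+    // Backward sweep for the per-channel decay w: saaaa and scccc carry the
+    // running per-head state across timesteps, while sbbbb buffers one partial
+    // sum per timestep that is folded into _gw on the second pass.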
float saaaa[_N_] = {0}, sbbbb[_T_ - 2] = {0}, scccc[_N_] = {0}; + + const int t_0 = b * T * C + h * _N_ + i; + const int t_1 = t_0 + C; + const int t_2 = t_0 + 2 * C; + const int t_T_1 = t_0 + (T - 1) * C; + + for (int t = t_T_1; t > t_1; t -= C) { + __syncthreads(); + gy[i] = float(_gy[t]); + v[i] = float(_v[t - 2 * C]); + __syncthreads(); + + const float r = float(_r[t]); + const float w = exp(_w[t - C]); + float sum = 0.0f; + +#pragma unroll + for (int j = 0; j < _N_; j++) { + float& s = saaaa[j]; + float x = r * gy[j]; + s = (s + x) * w; + sum += s * v[j]; + } + sbbbb[(t - t_2) / C] = sum * float(_k[t - 2 * C]); + } + + float sss = sbbbb[0]; + _gw[t_0] = 0; + _gw[t_1] = F(sss * _w[t_1]); + + for (int t = t_2; t < t_T_1; t += C) { + __syncthreads(); + gy[i] = float(_gy[t]); + v[i] = float(_v[t - 2 * C]); + __syncthreads(); + + const float w = exp(_w[t - C]); + const float k = float(_k[t - 2 * C]); + float sum = 0.0f; + +#pragma unroll + for (int j = 0; j < _N_; j++) { + float& s = scccc[j]; + float x = k * v[j]; + s = (s + x) * w; + sum += s * gy[j]; + } + sss += sbbbb[(t - t_1) / C] - (sum * float(_r[t])); + _gw[t] = F(sss * _w[t]); + } + _gw[t_T_1] = 0; +} + +void cuda_forward(int B, int T, int C, int H, bf16* r, bf16* k, bf16* v, float* w, bf16* u, bf16* y) +{ + assert(H * _N_ == C); + assert(_N_ % 4 == 0); + kernel_forward<<>>(B, T, C, H, r, k, v, w, u, y); +} + +void cuda_backward(int B, + int T, + int C, + int H, + bf16* r, + bf16* k, + bf16* v, + float* w, + bf16* u, + bf16* gy, + bf16* gr, + bf16* gk, + bf16* gv, + bf16* gw, + bf16* gu) +{ + assert(H * _N_ == C); + assert(_N_ % 4 == 0); + kernel_backward_111<<>>(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gu); + kernel_backward_222<<>>(B, T, C, H, r, k, v, w, u, gy, gw); +} diff --git a/megatron/model/rwkv/v6/cuda/wkv6_op.cpp b/megatron/model/rwkv/v6/cuda/wkv6_op.cpp new file mode 100644 index 000000000..385b47487 --- /dev/null +++ b/megatron/model/rwkv/v6/cuda/wkv6_op.cpp @@ -0,0 +1,95 @@ +#include +#include "ATen/ATen.h" +typedef at::BFloat16 bf16; + +void cuda_forward(int B, + int T, + int C, + int H, + bf16* r, + bf16* k, + bf16* v, + float* w, + bf16* u, + bf16* y); +void cuda_backward(int B, + int T, + int C, + int H, + bf16* r, + bf16* k, + bf16* v, + float* w, + bf16* u, + bf16* gy, + bf16* gr, + bf16* gk, + bf16* gv, + bf16* gw, + bf16* gu); + +void forward(int64_t B, + int64_t T, + int64_t C, + int64_t H, + torch::Tensor& r, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& w, + torch::Tensor& u, + torch::Tensor& y) +{ + cuda_forward(B, + T, + C, + H, + r.data_ptr(), + k.data_ptr(), + v.data_ptr(), + w.data_ptr(), + u.data_ptr(), + y.data_ptr()); +} +void backward(int64_t B, + int64_t T, + int64_t C, + int64_t H, + torch::Tensor& r, + torch::Tensor& k, + torch::Tensor& v, + torch::Tensor& w, + torch::Tensor& u, + torch::Tensor& gy, + torch::Tensor& gr, + torch::Tensor& gk, + torch::Tensor& gv, + torch::Tensor& gw, + torch::Tensor& gu) +{ + cuda_backward(B, + T, + C, + H, + r.data_ptr(), + k.data_ptr(), + v.data_ptr(), + w.data_ptr(), + u.data_ptr(), + gy.data_ptr(), + gr.data_ptr(), + gk.data_ptr(), + gv.data_ptr(), + gw.data_ptr(), + gu.data_ptr()); +} +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("forward", &forward, "wkv6 forward"); + m.def("backward", &backward, "wkv6 backward"); +} + +TORCH_LIBRARY(wkv6, m) +{ + m.def("forward", forward); + m.def("backward", backward); +} diff --git a/megatron/model/rwkv/v6/rwkv.py b/megatron/model/rwkv/v6/rwkv.py new file mode 100644 index 000000000..5d4e0d144 --- 
/dev/null +++ b/megatron/model/rwkv/v6/rwkv.py @@ -0,0 +1,357 @@ +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## + +import os, math, gc, importlib +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch.utils.cpp_extension import load + + +class WKV(torch.autograd.Function): + """ + WKV block, using cuda kernel. + """ + + @staticmethod + def forward(ctx, B, T, C, H, r, k, v, w, u): + with torch.no_grad(): + assert r.dtype == torch.bfloat16 + assert k.dtype == torch.bfloat16 + assert v.dtype == torch.bfloat16 + assert w.dtype == torch.bfloat16 + assert u.dtype == torch.bfloat16 + ctx.B = B + ctx.T = T + ctx.C = C + ctx.H = H + assert r.is_contiguous() + assert k.is_contiguous() + assert v.is_contiguous() + assert w.is_contiguous() + assert u.is_contiguous() + ew = (-torch.exp(w.float())).contiguous() + ctx.save_for_backward(r, k, v, ew, u) + y = torch.empty( + (B, T, C), + device=r.device, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + wkv_cuda.forward(B, T, C, H, r, k, v, ew, u, y) + return y + + @staticmethod + def backward(ctx, gy): + with torch.no_grad(): + assert gy.dtype == torch.bfloat16 + B = ctx.B + T = ctx.T + C = ctx.C + H = ctx.H + assert gy.is_contiguous() + r, k, v, ew, u = ctx.saved_tensors + gr = torch.empty( + (B, T, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + gk = torch.empty( + (B, T, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + gv = torch.empty( + (B, T, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + gw = torch.empty( + (B, T, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + gu = torch.empty( + (B, C), + device=gy.device, + requires_grad=False, + dtype=torch.bfloat16, + memory_format=torch.contiguous_format, + ) # .uniform_(-100, 100) + wkv_cuda.backward(B, T, C, H, r, k, v, ew, u, gy, gr, gk, gv, gw, gu) + gu = torch.sum(gu, 0).view(H, C // H) + return (None, None, None, None, gr, gk, gv, gw, gu) + + +def RUN_CUDA_RWKV(B, T, C, H, r, k, v, w, u): + return WKV.apply(B, T, C, H, r, k, v, w, u) + + +# RWKV6 time mix +class RWKV_TimeMix(nn.Module): + """ + Time Mixing Layer + The RWKV substitute for attention. + TODO: fix jit compiling. 
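+
+    Each token is mixed with a learned fraction of the previous token (token
+    shift), projected to r, k, v, g and a data-dependent decay w, and the
+    resulting WKV recurrence is evaluated by the fused CUDA kernel; the output
+    is group-normalized per head and gated by g before the final projection.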
+ """ + + def __init__(self, neox_args, layer_number): + super().__init__() + self.neox_args = neox_args + self.layer_number = layer_number + + with torch.no_grad(): + ratio_0_to_1 = layer_number / (neox_args.num_layers - 1) # 0 to 1 + ratio_1_to_almost0 = 1.0 - (layer_number / neox_args.num_layers) # 1 to ~0 + ddd = torch.ones(1, 1, neox_args.hidden_size) + for i in range(neox_args.hidden_size): + ddd[0, 0, i] = i / neox_args.hidden_size + + # fancy time_mix + self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) + self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) + self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) + self.time_maa_v = nn.Parameter( + 1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1) + ) + self.time_maa_r = nn.Parameter( + 1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0) + ) + self.time_maa_g = nn.Parameter( + 1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0) + ) + + TIME_MIX_EXTRA_DIM = 32 # generate TIME_MIX for w,k,v,r,g + self.time_maa_w1 = nn.Parameter( + torch.zeros(neox_args.hidden_size, TIME_MIX_EXTRA_DIM * 5).uniform_( + -1e-4, 1e-4 + ) + ) + self.time_maa_w2 = nn.Parameter( + torch.zeros(5, TIME_MIX_EXTRA_DIM, neox_args.hidden_size).uniform_( + -1e-4, 1e-4 + ) + ) + + # fancy time_decay + decay_speed = torch.ones(neox_args.dim_att) + for n in range(neox_args.dim_att): + decay_speed[n] = -6 + 5 * (n / (neox_args.dim_att - 1)) ** ( + 0.7 + 1.3 * ratio_0_to_1 + ) + self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, neox_args.dim_att)) + + TIME_DECAY_EXTRA_DIM = 64 + self.time_decay_w1 = nn.Parameter( + torch.zeros(neox_args.hidden_size, TIME_DECAY_EXTRA_DIM).uniform_( + -1e-4, 1e-4 + ) + ) + self.time_decay_w2 = nn.Parameter( + torch.zeros(TIME_DECAY_EXTRA_DIM, neox_args.dim_att).uniform_( + -1e-4, 1e-4 + ) + ) + + tmp = torch.zeros(neox_args.dim_att) + for n in range(neox_args.dim_att): + zigzag = ((n + 1) % 3 - 1) * 0.1 + tmp[n] = ratio_0_to_1 * (1 - (n / (neox_args.dim_att - 1))) + zigzag + + self.time_faaaa = nn.Parameter( + tmp.reshape(neox_args.num_attention_heads, neox_args.head_size) + ) + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + self.receptance = nn.Linear( + neox_args.hidden_size, neox_args.dim_att, bias=False + ) + self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) + + self.value = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) + self.output = nn.Linear(neox_args.dim_att, neox_args.hidden_size, bias=False) + self.gate = nn.Linear(neox_args.hidden_size, neox_args.dim_att, bias=False) + self.ln_x = nn.GroupNorm( + neox_args.num_attention_heads, neox_args.dim_att, eps=(1e-5) * (8**2) + ) + + def jit_func(self, x): + B, T, C = x.size() + + xx = self.time_shift(x) - x + + xxx = x + xx * self.time_maa_x + xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1) + xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1) + mw, mk, mv, mr, mg = xxx.unbind(dim=0) + + xw = x + xx * (self.time_maa_w + mw) + xk = x + xx * (self.time_maa_k + mk) + xv = x + xx * (self.time_maa_v + mv) + xr = x + xx * (self.time_maa_r + mr) + xg = x + xx * (self.time_maa_g + mg) + + r = self.receptance(xr) + k = self.key(xk) + v = self.value(xv) + g = F.silu(self.gate(xg)) + + ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2 + w = self.time_decay + ww + + return r, k, v, g, w + + def jit_func_2(self, x, g): + B, T, C = x.size() + x = x.view(B * T, C) + + x = self.ln_x(x).view(B, T, C) + x = self.output(x * g) + return x + + def 
forward(self, x): + B, T, C = x.size() + H = self.neox_args.num_attention_heads + + r, k, v, g, w = self.jit_func(x) + x = RUN_CUDA_RWKV(B, T, C, H, r, k, v, w, u=self.time_faaaa) + + return self.jit_func_2(x, g) + + +class RWKV_ChannelMix(nn.Module): + """ + Channel Mix layer. The ffn in RWKV + """ + + def __init__(self, neox_args, layer_number): + super().__init__() + self.neox_args = neox_args + self.layer_number = layer_number + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + + with torch.no_grad(): # fancy init of time_mix + ratio_1_to_almost0 = 1.0 - (layer_number / neox_args.num_layers) # 1 to ~0 + ddd = torch.ones(1, 1, neox_args.hidden_size) + for i in range(neox_args.hidden_size): + ddd[0, 0, i] = i / neox_args.hidden_size + self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) + self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0)) + + self.key = nn.Linear(neox_args.hidden_size, neox_args.dim_ffn, bias=False) + self.receptance = nn.Linear( + neox_args.hidden_size, neox_args.hidden_size, bias=False + ) + self.value = nn.Linear(neox_args.dim_ffn, neox_args.hidden_size, bias=False) + + def forward(self, x): + xx = self.time_shift(x) - x + xk = x + xx * self.time_maa_k + xr = x + xx * self.time_maa_r + + k = self.key(xk) + k = torch.relu(k) ** 2 + kv = self.value(k) + return torch.sigmoid(self.receptance(xr)) * kv + + +class RWKVResidualLayer(nn.Module): + """ + RWKV layer definition + """ + + def __init__(self, neox_args, layer_number): + super().__init__() + self.neox_args = neox_args + self.layer_number = layer_number + self.fp16 = neox_args.precision == "fp16" + self.bf16 = neox_args.precision == "bfloat16" + if not hasattr(neox_args, "dim_att"): + neox_args.dim_att = neox_args.hidden_size + if not hasattr(neox_args, "dim_ffn"): + # Make hidden size 3.5x. Round to nearest multiple of 32 until we add hdim rounding logic + neox_args.dim_ffn = int((neox_args.hidden_size * 3.5) // 32 * 32) + assert neox_args.hidden_size % 32 == 0 + assert neox_args.dim_att % 32 == 0 + assert neox_args.dim_ffn % 32 == 0 + self.neox_args.head_size = neox_args.dim_att // neox_args.num_attention_heads + self.head_size = self.neox_args.head_size + self.num_attention_heads = neox_args.num_attention_heads + assert neox_args.dim_att % self.num_attention_heads == 0 + + if neox_args.attention_dropout > 0: + self.drop0 = nn.Dropout(p=neox_args.attention_dropout) + + self.ln1 = nn.LayerNorm(neox_args.hidden_size) + self.ln2 = nn.LayerNorm(neox_args.hidden_size) + + self.att = RWKV_TimeMix(neox_args, layer_number) + + self.ffn = RWKV_ChannelMix(neox_args, layer_number) + + if neox_args.attention_dropout > 0: + self.drop0 = nn.Dropout(p=neox_args.attention_dropout) + if neox_args.hidden_dropout > 0: + self.drop1 = nn.Dropout(p=neox_args.hidden_dropout) + + if layer_number == 0: + global wkv_cuda + """ + Load cuda kernel at runtime. The kernel uses run time variables to build, ideally it should not. 
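+            The head size (_N_) and sequence length (_T_) are baked in at
+            compile time via the -D flags below, so changing either setting
+            requires the extension to be rebuilt.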
+ """ + wkv_cuda = load( + name="wkv6", + sources=[ + "megatron/model/rwkv/v6/cuda/wkv6_op.cpp", + f"megatron/model/rwkv/v6/cuda/wkv6_cuda.cu", + ], + verbose=True, + extra_cuda_cflags=[ + "-res-usage", + "--use_fast_math", + "-O3", + "-Xptxas -O3", + "--extra-device-vectorization", + f"-D_N_={self.neox_args.head_size}", + f"-D_T_={self.neox_args.seq_length}", + ], + ) + + def forward(self, x): + neox_args = self.neox_args + B, T, C = x.size() + if self.layer_number == 0: + x = self.ln1(x) + + if self.neox_args.attention_dropout == 0: + x = x + self.att(self.ln1(x)) + else: + x = self.drop0(x + self.att(self.ln1(x))) + + if self.neox_args.hidden_dropout == 0: + x = x + self.ffn(self.ln2(x)) + else: + x = self.drop1(x + self.ffn(self.ln2(x))) + + return x + + +class RWKVResidualLayerPipe(RWKVResidualLayer): + """ + RWKV Pipeline Layer + """ + + def forward(self, args): + assert len(args) == 2 + hidden_states, mask = args + neox_args = self.neox_args + return super().forward(hidden_states), mask diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py index d96c48af3..ff4f4bc21 100644 --- a/megatron/neox_arguments/arguments.py +++ b/megatron/neox_arguments/arguments.py @@ -1068,6 +1068,15 @@ def calculate_derived(self): assert ( self.hidden_dropout == 0.0, ), "Mamba does not yet have dropout implemented" + if "rwkv" in self.attention_config: + assert ( + not self.is_pipe_parallel and self.model_parallel_size == 1 + ), "RWKV not currently compatible with parallelism" + if isinstance(self.zero_stage, int): + assert self.zero_stage <= 2, "Zero stage 3 not compatible with RWKV" + assert ( + self.hidden_dropout == 0.0, + ), "RWKV does not yet have dropout implemented" # Sparsity config if self.sparsity_config is None: diff --git a/megatron/neox_arguments/deepspeed_args.py b/megatron/neox_arguments/deepspeed_args.py index 708a5f5b1..270e67f8c 100644 --- a/megatron/neox_arguments/deepspeed_args.py +++ b/megatron/neox_arguments/deepspeed_args.py @@ -100,7 +100,9 @@ class NeoXArgsDeepspeedConfig(NeoXArgsTemplate): bf16: dict = None """ - Configuration for using bfloat16 floating-point format as an alternative to FP16. BFLOAT16 requires hardware support (e.g., NVIDIA A100). Dictionary options as described in Deepspeed documentation: https://www.deepspeed.ai/docs/config-json/#bfloat16-training-options + Configuration for using bfloat16 floating-point format as an alternative to FP16. BFLOAT16 requires hardware support (e.g., NVIDIA A100). + + Dictionary options as described in Deepspeed documentation: https://www.deepspeed.ai/docs/config-json/#bfloat16-training-options """ # ---Automatic Mixed Precision (AMP) Training Options--- diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 8a216a25b..febefb3c2 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -36,6 +36,7 @@ "gmlp", "amlp", "flash", + "rwkv", "mamba", ] @@ -216,7 +217,7 @@ class NeoXArgsModel(NeoXArgsTemplate): The first item in the list specifies the attention type(s), and should be a list of strings. The second item specifies the number of times to repeat those attention types in the full list. 
- attention type choices: [global, local, sparse_fixed, sparse_variable, bslongformer, bigbird, "gmlp", "amlp", "flash", "mamba"] + attention type choices: [global, local, sparse_fixed, sparse_variable, bslongformer, bigbird, "gmlp", "amlp", "flash", "mamba", "rwkv"] So a 12 layer network with only global attention could be specified like: [[[`global`], 12]] diff --git a/tests/cpu_tests/action.yml b/tests/cpu_tests/action.yml index a7847d1ec..f8180605f 100644 --- a/tests/cpu_tests/action.yml +++ b/tests/cpu_tests/action.yml @@ -5,7 +5,7 @@ inputs: required: true type: string runs: - using: composite + using: composite steps: - uses: actions/checkout@v4 with:
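
A minimal, unoptimized reference for what the fused `wkv6` forward kernel above computes may help when reviewing: per head it keeps an N x N state S, emits y_t = (S + v_t k_t^T diag(u)) r_t, then updates S <- S * diag(exp(-exp(w_t))) + v_t k_t^T. The sketch below assumes the (B, T, H, head_size) layout used by `RUN_CUDA_RWKV` (i.e. a (B, T, C) tensor viewed as (B, T, H, C // H)); the function name is illustrative and this is a sanity-check aid, not part of the change.

```python
import torch


def wkv6_forward_reference(r, k, v, w, u):
    """Naive per-timestep WKV recurrence (reference sketch, not the fused kernel).

    r, k, v, w: (B, T, H, N) with N = head_size; u: (H, N).
    w is the raw time_decay input; the kernel is handed exp(-exp(w)) as the decay.
    """
    B, T, H, N = r.shape
    decay = torch.exp(-torch.exp(w.float()))           # per-channel decay in (0, 1)
    u = u.float().view(1, H, 1, N)                     # broadcast over batch and output index
    y = torch.zeros(B, T, H, N, dtype=torch.float32, device=r.device)
    state = torch.zeros(B, H, N, N, device=r.device)   # state[b, h, i, j]
    for t in range(T):
        rt, kt, vt = r[:, t].float(), k[:, t].float(), v[:, t].float()
        kv = torch.einsum("bhi,bhj->bhij", vt, kt)      # outer product v_t k_t^T
        y[:, t] = torch.einsum("bhij,bhj->bhi", state + kv * u, rt)
        state = state * decay[:, t].unsqueeze(-2) + kv  # decay old state, add new kv
    return y.to(r.dtype)
```

When comparing against the CUDA path, view the module's bfloat16 (B, T, C) tensors as (B, T, H, C // H) before calling this, and expect small numerical differences since the kernel accumulates in float32 from bf16 inputs.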