From 7d945e994cc105182a3c4d62f0cc8990a62cb5ec Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Fri, 27 Sep 2024 19:25:09 +0000 Subject: [PATCH] reshuffle repkv a bit, i wrote it from scratch. the kernel is still correct. repkv backward looks correct. rope backward is trivial so i don't see how it's not correct, and i also checked it. basically i'm really confused right now --- dev/cuda/repkv_backward.cu | 52 +++++++++++++++++++++----------------- llmc/repkv.cuh | 4 +-- train_llama3.cu | 43 ++++++++++++++++--------------- train_llama3.py | 12 ++++----- 4 files changed, 60 insertions(+), 51 deletions(-) diff --git a/dev/cuda/repkv_backward.cu b/dev/cuda/repkv_backward.cu index 9a00205d9..84064c530 100644 --- a/dev/cuda/repkv_backward.cu +++ b/dev/cuda/repkv_backward.cu @@ -11,39 +11,46 @@ Block size 128 seems fastest on H100 // cpu reference code void repkv_backward_cpu(float* dinp, const float* dout, - const int B, const int T, const int Cout, - const int hd, const int qh, const int kh, const int vh) { - - assert(Cout == (hd * (3 * qh))); + int B, int T, int C, + int hd, int qh, int kh, int vh) { + // inp is (B, T, C) + // out is (B, T, 3, NH, HD) + // hd = head dimension + // qh, kh, vh = number of query, key, value heads + assert(C == hd * (qh + kh + vh)); assert(kh == vh); int nrep = qh / kh; // number of times to replicate key/value vectors - int Cin = hd * (qh + kh + vh); // output channels + int Cout = hd * (qh * 3); // output channels for (int b = 0; b < B; b++) { for (int t = 0; t < T; t++) { - // seek to the input position dout[b,t,:] - const float* x = dout + b * T * Cout + t * Cout; + // seek to the input position inp[b,t,:] + float* dx = dinp + b * T * C + t * C; // seek to the output position out[b,t,:] - float* y = dinp + b * T * Cin + t * Cin; + const float* dy = dout + b * T * Cout + t * Cout; // copy all the query vectors, no changes - for (int i = 0; i < hd * qh; i++) { y[i] = x[i]; } - x += hd * qh; // advance input pointer - y += hd * qh; // advance output pointer - // copy key vectors, and replicate them nrep times + for (int i = 0; i < hd * qh; i++) { dx[i] = dy[i]; } + dx += hd * qh; // advance input pointer + dy += hd * qh; // advance output pointer + // gather gradients from the key vectors for (int h = 0; h < kh; h++) { + // init the gradient to 0 + for (int i = 0; i < hd; i++) { dx[i] = 0.0f; } for (int n = 0; n < nrep; n++) { - for (int i = 0; i < hd; i++) { y[i] += x[i]; } - x += hd; // advance input pointer + for (int i = 0; i < hd; i++) { dx[i] += dy[i]; } + dy += hd; // advance output pointer } - y += hd; // advance output pointer + dx += hd; // advance input pointer } - // copy value vectors, and replicate them nrep times + // gather gradients from the value vectors for (int h = 0; h < vh; h++) { + // init the gradient to 0 + for (int i = 0; i < hd; i++) { dx[i] = 0.0f; } for (int n = 0; n < nrep; n++) { - for (int i = 0; i < hd; i++) { y[i] += x[i]; } - x += hd; // advance input pointer + for (int i = 0; i < hd; i++) { dx[i] += dy[i]; } + dy += hd; // advance output pointer } - y += hd; // advance output pointer + dx += hd; // advance input pointer } } } @@ -76,7 +83,7 @@ __global__ void repkv_backward_kernel1(floatX* dinp, const floatX* dout, dinp[dinp_idx] = __ldcs(&dout[dout_idx]); } else if (c == 1) { if (nh % replicate_factor == 0) { - float reduced_sum = 0; + float reduced_sum = 0.0f; for (int i = 0; i < replicate_factor; i++) { reduced_sum += __ldcs(&dout[dout_idx+HD*i]); } @@ -87,7 +94,7 @@ __global__ void repkv_backward_kernel1(floatX* dinp, const floatX* dout, } else { if (nh % replicate_factor == 0) { - float reduced_sum = 0; + float reduced_sum = 0.0f; for (int i = 0; i < replicate_factor; i++) { reduced_sum += __ldcs(&dout[dout_idx+HD*i]); } @@ -141,7 +148,6 @@ int main(int argc, char **argv) { // allocate (and fill) CPU memory float* dinp = (float*)malloc(B * T * Cin * sizeof(float)); - memset(dinp, 0, B * T * Cin * sizeof(float)); float* dout = make_random_float(B * T * Cout * sizeof(float)); // allocate GPU memory @@ -160,7 +166,7 @@ int main(int argc, char **argv) { printf("Using kernel %d\n", kernel_num); // CPU reference calculate - repkv_backward_cpu(dinp, dout, B, T, Cout, hd, qh, kh, vh); + repkv_backward_cpu(dinp, dout, B, T, Cin, hd, qh, kh, vh); // check the correctness of the kernel at all block sizes int block_sizes[] = {32, 64, 128, 256, 512, 1024}; diff --git a/llmc/repkv.cuh b/llmc/repkv.cuh index f4c517eaa..a70881402 100644 --- a/llmc/repkv.cuh +++ b/llmc/repkv.cuh @@ -74,7 +74,7 @@ __global__ void repkv_backward_kernel1(floatX* dinp, const floatX* dout, dinp[dinp_idx] = __ldcs(&dout[dout_idx]); } else if (c == 1) { if (nh % replicate_factor == 0) { - float reduced_sum = 0; + float reduced_sum = 0.0f; for (int i = 0; i < replicate_factor; i++) { reduced_sum += (float) __ldcs(&dout[dout_idx+HD*i]); } @@ -85,7 +85,7 @@ __global__ void repkv_backward_kernel1(floatX* dinp, const floatX* dout, } else { if (nh % replicate_factor == 0) { - float reduced_sum = 0; + float reduced_sum = 0.0f; for (int i = 0; i < replicate_factor; i++) { reduced_sum += (float) __ldcs(&dout[dout_idx+HD*i]); } diff --git a/train_llama3.cu b/train_llama3.cu index 2cc554a0b..65739d1c6 100644 --- a/train_llama3.cu +++ b/train_llama3.cu @@ -897,10 +897,31 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int rmsnorm_backward(dresidual, dl_ln2w, scratchF, dl_btc, l_residual2, l_ln2w, l_ln2_rstd, B, T, C, main_stream); matmul_backward(dl_btc, dl_attprojw, dl_attprojb, dresidual, l_atty, l_attprojw, scratchF, B, T, C, C, main_stream); + // <--- gradient here matches OK + + #ifdef ENABLE_CUDNN + printf("cuDNN path TODO\n"); exit(0); + float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor + attention_backward_cudnn(dl_bt4c, dl_btc, l_qkvr, l_atty, (float*)l_att, B, T, NH, C, main_stream); + #else + floatX* l_att = acts.att + l * B * NH * T * T; + // we need B x T x (4)C buffers. l_atty and l_fch aren't needed anymore at this point, so reuse their memory + floatX* buffer_a = l_atty; + floatX* buffer_b = l_fch_pre_gelu; // this is B x T x 4C, so even larger than what we need + attention_backward(dl_bt4c, buffer_b, scratchX, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH, main_stream); + #endif + // backward rope (this can be done in-place) + rope_backward_inplace(dl_bt4c, dl_bt4c, model->freqs_cis, B, T, NH, hd, main_stream); + // backward repkv (use scratchX as gradient buffer here) + repkv_backward(dl_bt4c2, dl_bt4c, B, T, NH, n_kv_head, hd); + + // <--- here the gradients don't match + // so there is an issue with one of attention, rope, or repkv, or how they are called + // ------------------------------------------------------------------------ // DEBUGGING: we only work until this point right now, so exit here // transfer the first 32 elements to CPU and print them - float* output = (float*)dl_btc; + float* output = (float*)dl_bt4c2; floatX* cpu = (floatX*)mallocCheck(32 * sizeof(floatX)); cudaCheck(cudaMemcpy(cpu, output, 32 * sizeof(floatX), cudaMemcpyDeviceToHost)); for (int i = 0; i < 32; i++) { @@ -909,7 +930,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int // write to .bin file // move output to cpu // int sz = B*T*qkv_channels; //B*T*C; - int sz = B*T*C; + int sz = B*T*qkv_channels; floatX* cpu_output = (floatX*)mallocCheck(sz * sizeof(floatX)); cudaCheck(cudaMemcpy(cpu_output, output, sz * sizeof(floatX), cudaMemcpyDeviceToHost)); FILE* f = fopen("out.bin", "wb"); @@ -918,24 +939,6 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int exit(0); // ------------------------------------------------------------------------ - #ifdef ENABLE_CUDNN - printf("cuDNN path TODO\n"); exit(0); - float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor - attention_backward_cudnn(dl_bt4c, dl_btc, l_qkvr, l_atty, (float*)l_att, B, T, NH, C, main_stream); - #else - floatX* l_att = acts.att + l * B * NH * T * T; - // we need B x T x (4)C buffers. l_atty and l_fch aren't needed anymore at this point, so reuse their memory - floatX* buffer_a = l_atty; - floatX* buffer_b = l_fch_pre_gelu; // this is B x T x 4C, so even larger than what we need - attention_backward(dl_bt4c, buffer_b, scratchX, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH, main_stream); - #endif - // backward rope (this can be done in-place) - rope_backward_inplace(dl_bt4c, dl_bt4c, model->freqs_cis, B, T, NH, hd, main_stream); - // backward repkv (use scratchX as gradient buffer here) - repkv_backward(dl_bt4c2, dl_bt4c, B, T, NH, n_kv_head, hd); - - // <--- here the gradients don't match, so there is an issue in between - // backward QKV projection if(model->recompute >= 2) { rmsnorm_forward(l_ln1, l_ln1_rstd, residual, l_ln1w, B, T, C, main_stream); diff --git a/train_llama3.py b/train_llama3.py index cd1549b42..b654d8d83 100644 --- a/train_llama3.py +++ b/train_llama3.py @@ -168,6 +168,12 @@ def forward(self, x, freqs_cis=None, start_pos=None, mask=None): B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) # calculate query, key, values for all heads in batch and move head forward to be the batch dim qkv = self.c_attn(x) + + DEBUG_POINT = qkv.detach() + DEBUG_POINT = DEBUG_POINT.requires_grad_(True) + self.DEBUG_POINT = DEBUG_POINT + qkv = DEBUG_POINT + q, k, v = qkv.split([self.n_head * self.hd, self.n_kv_head * self.hd, self.n_kv_head * self.hd], dim=-1) q, k, v = map(lambda t: t.view(B, T, -1, self.hd), (q, k, v)) # (B, T, NH, HD) q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis) # rotate QK (rope) <-- 1. difference compared to GPT-2 @@ -197,12 +203,6 @@ def forward(self, x, freqs_cis=None, start_pos=None, mask=None): att = F.softmax(scores.float(), dim=-1).type_as(q) y = att @ v # (B, NH, T, T) x (B, NH, T, HD) -> (B, NH, T, HD) y = y.transpose(1, 2).contiguous().view(B, T, C) - - DEBUG_POINT = y.detach() - DEBUG_POINT = DEBUG_POINT.requires_grad_(True) - self.DEBUG_POINT = DEBUG_POINT - y = DEBUG_POINT - y = self.c_proj(y) return y