Skip to content

Commit

Permalink
fix bug with qkvr sizing, has to be 3*C. Credit to @ademeure for finding this bug and bringing light to darkness and order to chaos. A true warrior in the fight against entropy.
Browse files Browse the repository at this point in the history
  • Loading branch information
karpathy committed Oct 1, 2024
1 parent 7d945e9 commit e6481b6
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions train_llama3.cu
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ void fill_in_activation_sizes(const ActivationTensors* data, TensorSpec (&tensor
tensors[14] = TENSOR_SPEC(data->lnf_mean, B * T);
tensors[15] = TENSOR_SPEC(data->lnf_rstd, B * T);
tensors[16] = TENSOR_SPEC(data->losses, B * T);
- tensors[17] = TENSOR_SPEC(data->qkvr, L * B * T * qkv_channels);
+ tensors[17] = TENSOR_SPEC(data->qkvr, L * B * T * 3*C); // 3*C is correct - this is QKV after replication of KV
tensors[18] = TENSOR_SPEC(data->output, B * T * max(qkv_channels, max(ffn_channels, max(NH*T, Vp))));
tensors[19] = TENSOR_SPEC(data->scratch_bt4c, B * T * ffn_channels);
tensors[20] = TENSOR_SPEC(data->scratch_btc, B * T * C);
Expand Down Expand Up @@ -678,7 +678,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {

// get the pointers of the activations for this layer
floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf;
- floatX* l_qkvr = acts.qkvr + l * B * T * qkv_channels;
+ floatX* l_qkvr = acts.qkvr + l * B * T * 3*C;
floatX* l_atty = acts.atty + l * B * T * C;
floatX* l_residual2 = acts.residual2 + l * B * T * C;
floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf;
Expand Down Expand Up @@ -862,7 +862,7 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int
// get the pointers of the activations for this layer
floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf;
float* l_ln1_rstd = acts.ln1_rstd + l * B * T;
- floatX* l_qkvr = acts.qkvr + l * B * T * qkv_channels;
+ floatX* l_qkvr = acts.qkvr + l * B * T * 3*C;
floatX* l_atty = acts.atty + l * B * T * C;
floatX* l_residual2 = acts.residual2 + l * B * T * C;
floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf;
Expand Down

0 comments on commit e6481b6

Please sign in to comment.