add llama 3 support to llm.c #754

Draft · wants to merge 48 commits into base: master

Changes from all commits · 48 commits
09b47a7
llama3 starting point is at gpt-2 exact copy paste for both train/tes…
karpathy Sep 13, 2024
01bc4c6
first set of changes to match up the .py and the .cu version. default…
karpathy Sep 13, 2024
b883560
change the export code of Llama 3 to be very GPT-2 friendly, using a …
karpathy Sep 13, 2024
8866308
adapt the sizes of all the parameter tensors and load them from file.…
karpathy Sep 16, 2024
45026f6
make llama3cu phony
karpathy Sep 16, 2024
77e1d7a
add support for dataloader to serve uint32_t tokens, as necessary in …
karpathy Sep 16, 2024
72e6f1a
add new Encoder that does not use positional embeddings, like in llam…
karpathy Sep 16, 2024
234de31
introduce rmsnorm, unfused, forward
karpathy Sep 16, 2024
508c474
move debugging into fp32, so python has to write the fp32 version, an…
karpathy Sep 17, 2024
685617f
make fp32 path in .py code work correctly
karpathy Sep 17, 2024
56f956c
add repkv kernel to replicate K,V heads after the QKV projection
karpathy Sep 21, 2024
45401b4
DRAFT: Adding backward kernel for repkv
insop Sep 22, 2024
080e57f
CPU version tested
insop Sep 22, 2024
6c68657
Put cuda kernel caller placeholder
insop Sep 22, 2024
ad46043
WIP updating cuda kernel
insop Sep 22, 2024
42d09e8
minor clean up
insop Sep 22, 2024
fcc3466
Add minor change
insop Sep 22, 2024
de9c817
wip
insop Sep 24, 2024
76b40e4
integrate the repkv kernel with minor changes. use the bt4c buffer fo…
karpathy Sep 24, 2024
026e4ed
add RoPE PyTorch and C reference code
karpathy Sep 24, 2024
8336d2a
Merge remote-tracking branch 'upstream/llama3' into insop/llama3
insop Sep 25, 2024
2ebf8f6
Add rmsnorm fused kernel
gordicaleksa Sep 25, 2024
52c7254
add the finished RoPE forward pass
karpathy Sep 25, 2024
6538df6
Merge pull request #769 from gordicaleksa/fused_rmsnorm
karpathy Sep 25, 2024
bb3c92d
integrate the fused rmsnorm forward
karpathy Sep 25, 2024
1826752
add swigul yaygit add -u!
karpathy Sep 25, 2024
0731b39
forward pass matchesgit add train_llama3.cu train_llama3.py ! losses …
karpathy Sep 25, 2024
8874c2c
Merge remote-tracking branch 'upstream/llama3' into insop/llama3
insop Sep 25, 2024
3e5134d
Merge branch 'insop/llama3_wip' into insop/llama3
insop Sep 25, 2024
d1f2f64
Updated repkv_backward cuda kernel
insop Sep 26, 2024
31be5e7
add rmsnorm backward in dev/cuda, it seems to work surprisingly, and …
karpathy Sep 26, 2024
a2b66f1
Merge remote-tracking branch 'upstream/llama3' into insop/llama3
insop Sep 26, 2024
102067f
oops i think i accidentally forgot to include swiglu.cuh
karpathy Sep 26, 2024
2c4b3cc
integrate our rmsnorm backward and move the other rmsnorm functions i…
karpathy Sep 26, 2024
cbf53e3
Merge remote-tracking branch 'upstream/llama3' into insop/llama3
insop Sep 26, 2024
01c2895
Update RoPE naming
insop Sep 26, 2024
1b54612
i can backward through MLP block. Attention block is next
karpathy Sep 27, 2024
c8b348e
Merge pull request #764 from insop/insop/llama3
karpathy Sep 27, 2024
28e4a7f
small fixes, but still not too happy with this kernel, it wastes thre…
karpathy Sep 27, 2024
075e430
just pushing what i have. it's epsilon away from working sigh. basica…
karpathy Sep 27, 2024
8d49062
add backward kernel to dev/cuda for rope, to ensure correctness. but …
karpathy Sep 27, 2024
7d945e9
reshuffle repkv a bit, i wrote it from scratch. the kernel is still c…
karpathy Sep 27, 2024
e6481b6
fix bug with qkvr sizing, has to be 3*C. Credit to @ademeure for find…
karpathy Oct 1, 2024
9099a0a
ok the full backward now shows max abs diff of 3e-3, except for the e…
karpathy Oct 1, 2024
c746e06
take out debugging stuff. we can now run training loop for both model…
karpathy Oct 1, 2024
2602b46
BF16 opt state (m/v) with stochastic rounding, seems to work really w…
ademeure Oct 1, 2024
d808d78
Merge pull request #772 from ademeure/llama3_arun_new
karpathy Oct 1, 2024
2c5ced6
fix bug due to bf16 adamw mv
karpathy Oct 1, 2024
10 changes: 9 additions & 1 deletion Makefile
@@ -243,8 +243,13 @@ else
PFLAGS = -DENABLE_BF16
endif

# Optimizer precision settings, enable to allow BF16 for AdamW m/v state (also affects state file)
ifeq ($(OPTIMIZER_LOW_PRECISION), 1)
PFLAGS += -DOPTIMIZER_LOW_PRECISION
endif

# PHONY means these targets will always be executed
-.PHONY: all train_gpt2 test_gpt2 train_gpt2cu test_gpt2cu train_gpt2fp32cu test_gpt2fp32cu profile_gpt2cu
+.PHONY: all train_llama3cu train_gpt2 test_gpt2 train_gpt2cu test_gpt2cu train_gpt2fp32cu test_gpt2fp32cu profile_gpt2cu

# Add targets
TARGETS = train_gpt2 test_gpt2
@@ -285,6 +290,9 @@ test_gpt2fp32cu: test_gpt2_fp32.cu
profile_gpt2cu: profile_gpt2.cu $(NVCC_CUDNN)
$(NVCC) $(NVCC_FLAGS) $(PFLAGS) -lineinfo $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE)

train_llama3cu: train_llama3.cu $(NVCC_CUDNN)
$(NVCC) $(NVCC_FLAGS) $(PFLAGS) $^ $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(CUDA_OUTPUT_FILE)

clean:
$(REMOVE_FILES) $(TARGETS)
$(REMOVE_BUILD_OBJECT_FILES)
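
With this, the new target should build like the existing GPT-2 ones, e.g. `make train_llama3cu`, and passing the new flag (e.g. `make train_llama3cu OPTIMIZER_LOW_PRECISION=1`) opts the AdamW m/v state into BF16, which per the comment above also changes the optimizer state file format.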
13 changes: 13 additions & 0 deletions dev/cbridge/README.md
@@ -0,0 +1,13 @@
# cbridge

We'll use this directory for the PyTorch -> C bridge: we have some PyTorch code and we'd like an equivalent C implementation (which in turn usually serves as the reference for the CUDA kernels later).

For starters we have RoPE: generate the reference with PyTorch, then match it in C:

```bash
python rope.py
gcc -o rope rope.c -lm
./rope
```

The .py file writes a `rope.bin` file with the intermediate tensors.
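
For reference, `rope.c` expects `rope.bin` to contain, in order: a 16-int32 header (magic number `20240924`, then `B`, `T`, `n_embd`, `n_head`, `use_scaled_rope`), a 16-float32 header (`rope_theta`), and then the float32 tensors `inp`, `freqs_cis`, `out`, `wei` (the fixed loss weights), and `inp.grad`.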
101 changes: 101 additions & 0 deletions dev/cbridge/rmsnorm.py
@@ -0,0 +1,101 @@
"""
An RMSNorm PyTorch reference implementation.
This script then runs forward/backward and writes everything to file so we can
develop the CPU version, and eventually the GPU kernel as well.
"""

import math
import torch
import numpy as np
import torch.nn as nn
from torch.nn import functional as F

# -----------------------------------------------------------------------------

class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))

def _norm(self, x):
mean_sq = x.pow(2).mean(dim=-1, keepdim=True) + self.eps
rstd = torch.rsqrt(mean_sq)
norm = x * rstd
return norm

def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight

def rmsnorm_backward(x, w, dout, eps):
# recompute the rstd, norm (or we could cache it in the forward pass)
mean_sq = x.pow(2).mean(dim=-1, keepdim=True) + eps # (B, T, 1)
rstd = torch.rsqrt(mean_sq) # (B, T, 1)
norm = x * rstd # (B, T, C)
# gradients for weights
dw = (dout * norm).sum((0, 1)) # (C)
# gradients for input
dnorm = dout * w # (B, T, C)
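    # derivation sketch: with norm = x * rstd and rstd = (mean(x^2) + eps)^(-1/2),
    # d(rstd)/dx_k = -rstd^3 * x_k / C, which gives
    # dx_k = rstd * (dnorm_k - norm_k * mean_j(dnorm_j * norm_j))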
dx = dnorm - norm * (dnorm * norm).mean(dim=-1, keepdim=True)
dx *= rstd
return dx, dw

# -----------------------------------------------------------------------------

# seed the rng
torch.manual_seed(42)

B = 4
T = 64
C = 256
eps = 1e-5

inp = torch.randn(B, T, C, dtype=torch.float32)
inp.requires_grad = True

# rmsnorm
m = RMSNorm(C, eps=eps)
out = m(inp)

# loss can just be a weighted sum, with some fixed weights
wei = torch.randn_like(out, dtype=torch.float32)
loss = (out * wei).sum()
loss.backward()

# let's now do the backward pass manually
# backprop starts with the output gradient, which is exactly wei because of the loss function
dx, dw = rmsnorm_backward(inp, m.weight, wei, eps)
# let's assert that the gradients match
assert torch.allclose(dx, inp.grad, atol=1e-6)
assert torch.allclose(dw, m.weight.grad, atol=1e-6)
print("RMSNorm gradients match")
print("first 5 elements of dx comparison:")
print(dx.view(-1)[:5].tolist())
print(inp.grad.view(-1)[:5].tolist())
print("first 5 elements of dw comparison:")
print(dw.view(-1)[:5].tolist())
print(m.weight.grad.view(-1)[:5].tolist())
print("dx error:", (inp.grad.view(-1) - dx.view(-1)).abs().max().item())
print("dw error:", (m.weight.grad.view(-1) - dw.view(-1)).abs().max().item())

# save to .bin file so we can check correctness in C land
int_header = np.zeros(16, dtype=np.int32) # for ints
float_header = np.zeros(16, dtype=np.float32) # for floats
int_header[0] = 20240925 # magic number
int_header[1] = B
int_header[2] = T
int_header[3] = C
float_header[0] = eps

# write the hyperparameters, inputs, output, and input gradients to file
results_file = "rmsnorm.bin"
with open(results_file, "wb") as f:
f.write(int_header.tobytes()) # 16 int32
f.write(float_header.tobytes()) # 16 float32
f.write(inp.detach().cpu().numpy().tobytes()) # B * T * C
f.write(out.detach().cpu().numpy().tobytes()) # B * T * C
f.write(wei.detach().cpu().numpy().tobytes()) # B * T * C
f.write(inp.grad.detach().cpu().numpy().tobytes()) # B * T * C
f.write(m.weight.grad.detach().cpu().numpy().tobytes()) # C
print("Saved results to %s" % results_file)
246 changes: 246 additions & 0 deletions dev/cbridge/rope.c
@@ -0,0 +1,246 @@
/*
Our goal here is to load the .bin files generated by rope.py and match
the implementation in C and get the same results as in rope.py.

Compile and run simply with:

gcc -o rope rope.c -lm
./rope
*/

#include <stdio.h>
#include <stdlib.h>
#include <string.h> // for memset (dinp must be zeroed before accumulation)
#include <math.h>
#include <assert.h>

// ----------------------------------------------------------------------------
// a few utils for safety
extern inline void fread_check(void *ptr, size_t size, size_t nmemb, FILE *stream, const char *file, int line) {
size_t result = fread(ptr, size, nmemb, stream);
if (result != nmemb) {
if (feof(stream)) {
fprintf(stderr, "Error: Unexpected end of file at %s:%d\n", file, line);
} else if (ferror(stream)) {
fprintf(stderr, "Error: File read error at %s:%d\n", file, line);
} else {
fprintf(stderr, "Error: Partial read at %s:%d. Expected %zu elements, read %zu\n",
file, line, nmemb, result);
}
fprintf(stderr, "Error details:\n");
fprintf(stderr, " File: %s\n", file);
fprintf(stderr, " Line: %d\n", line);
fprintf(stderr, " Expected elements: %zu\n", nmemb);
fprintf(stderr, " Read elements: %zu\n", result);
exit(EXIT_FAILURE);
}
}
#define freadCheck(ptr, size, nmemb, stream) fread_check(ptr, size, nmemb, stream, __FILE__, __LINE__)

extern inline void *malloc_check(size_t size, const char *file, int line) {
void *ptr = malloc(size);
if (ptr == NULL) {
fprintf(stderr, "Error: Memory allocation failed at %s:%d\n", file, line);
fprintf(stderr, "Error details:\n");
fprintf(stderr, " File: %s\n", file);
fprintf(stderr, " Line: %d\n", line);
fprintf(stderr, " Size: %zu bytes\n", size);
exit(EXIT_FAILURE);
}
return ptr;
}

#define mallocCheck(size) malloc_check(size, __FILE__, __LINE__)

int compare_arrays(const float *arr1, const float *arr2, size_t size, const char *name, float epsilon) {
for (size_t i = 0; i < size; i++) {
// print 10 elements that are equally spaced out, for qualitative check
if (i % (size / 10) == 0) {
printf("arr1[%zu] = %f, arr2[%zu] = %f\n", i, arr1[i], i, arr2[i]);
}
if (fabsf(arr1[i] - arr2[i]) > epsilon) {
printf("Error: %s[%zu] = %f, expected %f (diff: %f)\n",
name, i, arr1[i], arr2[i], fabsf(arr1[i] - arr2[i]));
return 0;
}
}
printf("OK: %s\n", name);
return 1;
}

// ----------------------------------------------------------------------------
// all the functions we need

void precompute_freqs_cis(float *freqs_cis, int dim, int end, float theta, int use_scaled) {
// same as precompute_freqs_cis_real in rope.py
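    // freqs_cis is laid out as (end, dim) with cos/sin interleaved per pair:
    // freqs_cis[t * dim + 2*i] = cos(t * freq_i), freqs_cis[t * dim + 2*i + 1] = sin(t * freq_i)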
for (int i = 0; i < dim / 2; i++) {

// calculate the frequency for the (i, i+1)th dimension
float freq = 1.0f / powf(theta, (float)(2 * i) / dim);
if (use_scaled) {
const int scale_factor = 8;
const int low_freq_factor = 1;
const int high_freq_factor = 4;
const int old_context_len = 8192; // original llama3 length
const float low_freq_wavelen = (float)old_context_len / low_freq_factor;
const float high_freq_wavelen = (float)old_context_len / high_freq_factor;
float wavelen = 2.0f * M_PI / freq;
if (wavelen < high_freq_wavelen) {
// skip; keep freq as is
} else if (wavelen > low_freq_wavelen) {
// scale down by scale_factor
freq /= scale_factor;
} else {
// smooth transition between scaled and unscaled
float smooth = ((float)old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor);
freq = (1.0f - smooth) * freq / scale_factor + smooth * freq;
}
}

// iterate over all time steps, calculate the angle, and store the cos/sin
for (int t = 0; t < end; t++) {
float angle = (float)t * freq;
freqs_cis[t * dim + 2 * i] = cosf(angle); // real part
freqs_cis[t * dim + 2 * i + 1] = sinf(angle); // imaginary part
}
}
}

void apply_rotary_emb_forward(float *out, const float *inp, const float *freqs_cis, int B, int T, int n_head, int head_dim) {
// same as apply_rotary_emb_real in rope.py
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
int idx_bt = b * (T * n_head * head_dim) + t * (n_head * head_dim);
for (int h = 0; h < n_head; h++) {
int idx_bth = idx_bt + h * head_dim;
for (int d = 0; d < head_dim / 2; d++) {
// fetch a tuple of activations, which we imagine as a complex number
int idx = idx_bth + 2 * d;
float x_real = inp[idx];
float x_imag = inp[idx + 1];
// fetch the angle from freqs_cis
int freqs_idx = t * head_dim + 2 * d;
float freqs_cos = freqs_cis[freqs_idx];
float freqs_sin = freqs_cis[freqs_idx + 1];
// apply the rotation
out[idx] = x_real * freqs_cos - x_imag * freqs_sin;
out[idx + 1] = x_real * freqs_sin + x_imag * freqs_cos;
}
}
}
}
}

void apply_rotary_emb_backward(float *dinp, const float *dout, const float *inp, const float *freqs_cis, int B, int T, int n_head, int head_dim) {
// backward pass of the RoPE embedding
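    // the forward pass rotates each (real, imag) pair by theta via a 2x2 rotation matrix;
    // the Jacobian of a rotation is the rotation itself, so its transpose (a rotation by
    // -theta) maps output gradients back to input gradients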
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
int idx_bt = b * (T * n_head * head_dim) + t * (n_head * head_dim);
for (int h = 0; h < n_head; h++) {
int idx_bth = idx_bt + h * head_dim;
for (int d = 0; d < head_dim / 2; d++) {
// fetch the angle from freqs_cis
int freqs_idx = t * head_dim + 2 * d;
float freqs_cos = freqs_cis[freqs_idx];
float freqs_sin = freqs_cis[freqs_idx + 1];
// and the input index we'll be updating
int idx = idx_bth + 2 * d;
// backward pass is simple because freqs_cis is just scaling by a constant
dinp[idx] += dout[idx] * freqs_cos + dout[idx + 1] * freqs_sin;
dinp[idx + 1] += -dout[idx] * freqs_sin + dout[idx + 1] * freqs_cos;
}
}
}
}
}

// ----------------------------------------------------------------------------

int main() {

// load the .bin file
FILE *file = fopen("rope.bin", "rb");
if (file == NULL) {
printf("Error: Could not open file.\n");
return 1;
}
// read the header
int int_header[16];
float float_header[16];
freadCheck(int_header, sizeof(int), 16, file);
freadCheck(float_header, sizeof(float), 16, file);
// check the magic number
if (int_header[0] != 20240924) {
printf("Error: Invalid magic number.\n");
fclose(file);
return 1;
}
// extract the hyperparameters
int B = int_header[1];
int T = int_header[2];
int n_embd = int_header[3];
int n_head = int_header[4];
int use_scaled_rope = int_header[5];
float rope_theta = float_header[0];
int head_dim = n_embd / n_head;
// read the inputs
float *inp = (float *)mallocCheck(B * T * n_head * head_dim * sizeof(float));
freadCheck(inp, sizeof(float), B * T * n_head * head_dim, file);
// read the freqs_cis
float *freqs_cis_target = (float *)mallocCheck(T * head_dim * sizeof(float));
freadCheck(freqs_cis_target, sizeof(float), T * head_dim, file);
// read the output
float *out_target = (float *)mallocCheck(B * T * n_head * head_dim * sizeof(float));
freadCheck(out_target, sizeof(float), B * T * n_head * head_dim, file);
// read the weights for the loss function
float *wei = (float *)mallocCheck(B * T * n_head * head_dim * sizeof(float));
freadCheck(wei, sizeof(float), B * T * n_head * head_dim, file);
// read the input gradients
float *inp_grad_target = (float *)mallocCheck(B * T * n_head * head_dim * sizeof(float));
freadCheck(inp_grad_target, sizeof(float), B * T * n_head * head_dim, file);
// ensure we exactly exhausted the file
long current_position = ftell(file);
// Get the file size
fseek(file, 0, SEEK_END);
long file_size = ftell(file);
// check if we read the whole file
if (current_position != file_size) {
printf("Error: File was not read properly; %ld bytes left unread.\n", file_size - current_position);
fclose(file);
return 1;
}
fclose(file);

// print the hyperparameters
printf("B: %d, T: %d, n_embd: %d, n_head: %d, use_scaled_rope: %d, rope_theta: %f\n",
B, T, n_embd, n_head, use_scaled_rope, rope_theta);

// Step 1) Calculate freqs_cis in C and compare with the Python one
float *freqs_cis = (float *)mallocCheck(T * head_dim * sizeof(float));
precompute_freqs_cis(freqs_cis, head_dim, T, rope_theta, use_scaled_rope);
if (!compare_arrays(freqs_cis, freqs_cis_target, T * head_dim, "freqs_cis", 1e-6f)) { return 1; }

// Step 2) Apply the RoPE embedding in C and compare with the Python one
float *out = (float *)mallocCheck(B * T * n_head * head_dim * sizeof(float));
apply_rotary_emb_forward(out, inp, freqs_cis, B, T, n_head, head_dim);
if (!compare_arrays(out, out_target, B * T * n_head * head_dim, "out", 1e-6f)) { return 1; }

// Step 3) Calculate the loss and gradients in C and compare with the Python one
float *dout = wei; // wei is dout because the loss is just a dot product of out and wei
    float *dinp = (float *)mallocCheck(B * T * n_head * head_dim * sizeof(float));
    memset(dinp, 0, B * T * n_head * head_dim * sizeof(float)); // must zero: the backward accumulates into dinp with +=
apply_rotary_emb_backward(dinp, dout, inp, freqs_cis, B, T, n_head, head_dim);
if (!compare_arrays(dinp, inp_grad_target, B * T * n_head * head_dim, "dinp", 1e-6f)) { return 1; }

printf("✅ ALL OK\n");

// clean up
free(inp);
free(freqs_cis_target);
free(out_target);
free(wei);
free(inp_grad_target);
free(freqs_cis);
free(out);
free(dinp);

return 0;
}
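
A successful run prints `OK: freqs_cis`, `OK: out`, and `OK: dinp` (along with a few sampled elements of each array) and ends with `✅ ALL OK`.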