Windows Makefile review comments addressed. Remove extra commands and CI changes. #256

Closed
39 commits
- 09cd67e: code to load bf16 weights directly, and also re-wire the position of … (karpathy, Apr 27, 2024)
- 09d935c: i think i am making things cleaner, but i am not fixing the problem (karpathy, Apr 27, 2024)
- d4a642b: i think github copilot betrayed me on this index here, i cant remember (karpathy, Apr 27, 2024)
- e067a27: fix dumb bug. i'll blame github copilot but i can't remember (karpathy, Apr 27, 2024)
- 9d6fd30: tweak the tolerances until we pass lol (karpathy, Apr 27, 2024)
- a58b8d5: print more in the comparison (karpathy, Apr 27, 2024)
- 2954d90: Enable multithreading in nvcc (ChrisDryden, Apr 27, 2024)
- 5ed4364: Addressed review comments. Remove extra commands and CI changes. (rosslwheeler, Apr 26, 2024)
- 25240ec: Resolving conflicts in Makefile (rosslwheeler, Apr 27, 2024)
- 0062707: fix a really bad bug in how i was checking the gradients, where i loa… (karpathy, Apr 27, 2024)
- 9a91b40: bring back original ordering. i also had to bump the thresholds by 3X… (karpathy, Apr 28, 2024)
- 82d7907: adjust comment (karpathy, Apr 28, 2024)
- cfccd82: Include the float4 kernel (lancerts, Apr 28, 2024)
- 61c5c05: amend the float4 kernel (lancerts, Apr 28, 2024)
- 1c7d23a: amend the float4 kernel (lancerts, Apr 28, 2024)
- 4f7d8d9: allow user to make different precisions, add prints and error handlin… (karpathy, Apr 28, 2024)
- a3f5ad9: reshuffle the ifdefs to make bf16 the default if no PRECISION is requ… (karpathy, Apr 28, 2024)
- 9d70d9a: profile and test only use bf16. but the train script can be run with … (karpathy, Apr 28, 2024)
- d95b8d8: Merge pull request #265 from karpathy/feature/load_bf16 (karpathy, Apr 28, 2024)
- 835060e: padded vocab change. touched a lot of code. very stressful and error … (karpathy, Apr 28, 2024)
- b7972ff: make padded vocab fixes in the .c code as well, i missed it in the pr… (karpathy, Apr 28, 2024)
- 4b6a532: Merge branch 'encoder_forward-float4' of https://github.com/lancerts/… (karpathy, Apr 28, 2024)
- 327eef3: incorporate faster encoder_forward kernel to fp32 CUDA version (karpathy, Apr 28, 2024)
- 4b6f68a: Merge branch 'lancerts-encoder_forward-float4' (karpathy, Apr 28, 2024)
- 10aa24e: Merge pull request #269 from ChrisDryden/patch-3 (karpathy, Apr 28, 2024)
- b522333: Updating the CI to build different precisions (Ricardicus, Apr 28, 2024)
- ca48791: as promised, cleanup enabled by padding :) (ngc92, Apr 28, 2024)
- 4c295c7: and even more cleanup (ngc92, Apr 28, 2024)
- 49228b0: add small comment on -t=0 (karpathy, Apr 28, 2024)
- 5185656: Merge branch 'cleanup' of https://github.com/ngc92/llm.c into ngc92-c… (karpathy, Apr 28, 2024)
- 66a92c6: Merge branch 'ngc92-cleanup' (karpathy, Apr 28, 2024)
- 18b41b4: Merge pull request #279 from Ricardicus/prec-ci (karpathy, Apr 28, 2024)
- 4a3c278: moved checked helper functions into a separate file (ngc92, Apr 28, 2024)
- b1c80e9: Merge branch 'split-file' of https://github.com/ngc92/llm.c into ngc9… (karpathy, Apr 28, 2024)
- c20497c: add comments pointing to the definition of the utils functions (karpathy, Apr 28, 2024)
- 50acc12: Merge branch 'ngc92-split-file' (karpathy, Apr 28, 2024)
- dc0295c: Addressed review comments. Remove extra commands and CI changes. (rosslwheeler, Apr 26, 2024)
- f265f60: Resolving conflicts in Makefile (rosslwheeler, Apr 27, 2024)
- dd9bb3e: Merge branch 'makefile_windows_fixes' of https://github.com/rosslwhee… (rosslwheeler, Apr 29, 2024)
31 changes: 29 additions & 2 deletions .github/workflows/ci.yml
@@ -52,11 +52,38 @@ jobs:
build-with-cuda-fp32:
runs-on: ubuntu-latest # Host OS, Docker will run on top of this
container:
image: nvidia/cuda:11.2.2-devel-ubuntu20.04 # Example CUDA development image with nvcc
image: nvidia/cuda:12.4.1-devel-ubuntu22.04

steps:
- name: Checkout code
uses: actions/checkout@v3

- name: Build project
- name: Build FP32 checkpoint
run: make train_gpt2fp32cu test_gpt2fp32cu

- name: Build FP32 precision
run: PRECISION=FP32 make train_gpt2cu test_gpt2cu profile_gpt2cu

build-with-cuda-bf16:
runs-on: ubuntu-latest # Host OS, Docker will run on top of this
container:
image: nvidia/cuda:12.4.1-devel-ubuntu22.04

steps:
- name: Checkout code
uses: actions/checkout@v3

- name: Build project
run: PRECISION=BF16 make test_gpt2cu train_gpt2cu profile_gpt2cu

build-with-cuda-fp16:
runs-on: ubuntu-latest # Host OS, Docker will run on top of this
container:
image: nvidia/cuda:12.4.1-devel-ubuntu22.04

steps:
- name: Checkout code
uses: actions/checkout@v3

- name: Build project
run: PRECISION=FP16 make test_gpt2cu train_gpt2cu profile_gpt2cu
1 change: 1 addition & 0 deletions .gitignore
@@ -15,6 +15,7 @@ test_gpt2fp32cu
train_gpt2
train_gpt2cu
train_gpt2fp32cu
profile_gpt2cu
dev/cuda/*_forward
dev/cuda/*_backward
dev/cuda/classifier_fused
96 changes: 73 additions & 23 deletions Makefile
@@ -7,23 +7,57 @@ CFLAGS_COND = -march=native

# Find nvcc
NVCC := $(shell which nvcc 2>/dev/null)
SHELL_UNAME = $(shell uname)
REMOVE_FILES = rm -f
OUTPUT_FILE = -o $@
CUDA_OUTPUT_FILE = -o $@

# NVCC flags
NVCC_FLAGS = -O3 --use_fast_math
# -t=0 is short for --threads, 0 = number of CPUs on the machine
NVCC_FLAGS = -O3 -t=0 --use_fast_math
NVCC_LDFLAGS = -lcublas -lcublasLt
NCLL_INCLUDES =
NVCC_LDLIBS =

# Function to test if the compiler accepts a given flag.
define check_and_add_flag
ifneq ($(OS), Windows_NT)
NVCC := $(shell which nvcc 2>/dev/null)

# Function to test if the compiler accepts a given flag.
define check_and_add_flag
$(eval FLAG_SUPPORTED := $(shell printf "int main() { return 0; }\n" | $(CC) $(1) -x c - -o /dev/null 2>/dev/null && echo 'yes'))
ifeq ($(FLAG_SUPPORTED),yes)
CFLAGS += $(1)
endif
endef
endef

# Check each flag and add it if supported
$(foreach flag,$(CFLAGS_COND),$(eval $(call check_and_add_flag,$(flag))))
# Check each flag and add it if supported
$(foreach flag,$(CFLAGS_COND),$(eval $(call check_and_add_flag,$(flag))))
else
CFLAGS :=
REMOVE_FILES = del *.exe,*.obj,*.lib,*.exp,*.pdb && del
SHELL_UNAME := Windows
ifneq ($(shell where nvcc 2> nul),"")
NVCC := nvcc
else
NVCC :=
endif
CC := cl
CFLAGS = /Idev /Zi /nologo /Wall /WX- /diagnostics:column /sdl /O2 /Oi /Ot /GL /D _DEBUG /D _CONSOLE /D _UNICODE /D UNICODE /Gm- /EHsc /MD /GS /Gy /fp:fast /Zc:wchar_t /Zc:forScope /Zc:inline /permissive- \
/external:W3 /Gd /TP /wd4996 /[email protected] /FC /openmp:llvm
LDFLAGS :=
LDLIBS :=
INCLUDES :=
NVCC_FLAGS += -I"dev"
ifeq ($(WIN_CI_BUILD),1)
$(info Windows CI build)
OUTPUT_FILE = /link /OUT:$@
CUDA_OUTPUT_FILE = -o $@
else
$(info Windows local build)
OUTPUT_FILE = /link /OUT:$@ && copy /Y $@ [email protected]
CUDA_OUTPUT_FILE = -o $@ && copy /Y [email protected] $@
endif
endif

# Check if OpenMP is available
# This is done by attempting to compile an empty file with OpenMP flags
@@ -36,7 +70,7 @@ ifeq ($(NO_OMP), 1)
$(info OpenMP is manually disabled)
else
# Detect if running on macOS or Linux
ifeq ($(shell uname), Darwin)
ifeq ($(SHELL_UNAME), Darwin)
# Check for Homebrew's libomp installation in different common directories
ifeq ($(shell [ -d /opt/homebrew/opt/libomp/lib ] && echo "exists"), exists)
# macOS with Homebrew on ARM (Apple Silicon)
@@ -56,13 +90,15 @@ else
$(warning OpenMP not found, skipping OpenMP support)
endif
else
# Check for OpenMP support in GCC or Clang on Linux
ifeq ($(shell echo | $(CC) -fopenmp -x c -E - > /dev/null 2>&1; echo $$?), 0)
CFLAGS += -fopenmp -DOMP
LDLIBS += -lgomp
$(info OpenMP found, compiling with OpenMP support)
else
$(warning OpenMP not found, skipping OpenMP support)
ifneq ($(OS), Windows_NT)
# Check for OpenMP support in GCC or Clang on Linux
ifeq ($(shell echo | $(CC) -fopenmp -x c -E - > /dev/null 2>&1; echo $$?), 0)
CFLAGS += -fopenmp -DOMP
LDLIBS += -lgomp
$(info OpenMP found, compiling with OpenMP support)
else
$(warning OpenMP not found, skipping OpenMP support)
endif
endif
endif
endif
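(The probe above only verifies that the compiler accepts an OpenMP flag on an empty input; it does not exercise the runtime. A standalone smoke test in C, not part of this PR and with the -fopenmp flag assumed for GCC/Clang, is one way to confirm threads actually spawn:)

```c
/* omp_probe.c: standalone OpenMP smoke test, separate from the Makefile probe.
   Build (GCC/Clang assumed): cc -fopenmp omp_probe.c -o omp_probe */
#include <omp.h>
#include <stdio.h>

int main(void) {
    #pragma omp parallel
    {
        /* with OpenMP active, this prints one line per thread */
        printf("thread %d of %d\n", omp_get_thread_num(), omp_get_num_threads());
    }
    return 0;
}
```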
@@ -85,8 +121,22 @@ else
endif
endif

# Precision settings, default to bf16 but ability to override
PRECISION ?= BF16
VALID_PRECISIONS := FP32 FP16 BF16
ifeq ($(filter $(PRECISION),$(VALID_PRECISIONS)),)
$(error Invalid precision $(PRECISION), valid precisions are $(VALID_PRECISIONS))
endif
ifeq ($(PRECISION), FP32)
PFLAGS = -DENABLE_FP32
else ifeq ($(PRECISION), FP16)
PFLAGS = -DENABLE_FP16
else
PFLAGS = -DENABLE_BF16
endif

# PHONY means these targets will always be executed
.PHONY: all train_gpt2 test_gpt2 train_gpt2cu test_gpt2cu train_gpt2fp32cu test_gpt2fp32cu
.PHONY: all train_gpt2 test_gpt2 train_gpt2cu test_gpt2cu train_gpt2fp32cu test_gpt2fp32cu profile_gpt2cu

# Add targets
TARGETS = train_gpt2 test_gpt2
@@ -102,25 +152,25 @@ endif
all: $(TARGETS)

train_gpt2: train_gpt2.c
$(CC) $(CFLAGS) $(INCLUDES) $(LDFLAGS) $< $(LDLIBS) -o $@
$(CC) $(CFLAGS) $(INCLUDES) $(LDFLAGS) $< $(LDLIBS) $(OUTPUT_FILE)

test_gpt2: test_gpt2.c
$(CC) $(CFLAGS) $(INCLUDES) $(LDFLAGS) $< $(LDLIBS) -o $@
$(CC) $(CFLAGS) $(INCLUDES) $(LDFLAGS) $< $(LDLIBS) $(OUTPUT_FILE)

train_gpt2cu: train_gpt2.cu
$(NVCC) $(NVCC_FLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(NVCC_LDFLAGS) -o $@
$(NVCC) $(NVCC_FLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(NVCC_LDFLAGS) $(CUDA_OUTPUT_FILE)

train_gpt2fp32cu: train_gpt2_fp32.cu
$(NVCC) $(NVCC_FLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(NVCC_LDFLAGS) -o $@
$(NVCC) $(NVCC_FLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(NVCC_LDFLAGS) $(CUDA_OUTPUT_FILE)

test_gpt2cu: test_gpt2.cu
$(NVCC) $(NVCC_FLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(NVCC_LDFLAGS) -o $@
$(NVCC) $(NVCC_FLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(NVCC_LDFLAGS) $(CUDA_OUTPUT_FILE)

test_gpt2fp32cu: test_gpt2_fp32.cu
$(NVCC) $(NVCC_FLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(NVCC_LDFLAGS) -o $@
$(NVCC) $(NVCC_FLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(NVCC_LDFLAGS) $(CUDA_OUTPUT_FILE)

profile_gpt2cu: profile_gpt2.cu
$(NVCC) $(NVCC_FLAGS) -lineinfo $< $(NVCC_LDFLAGS) -o $@
$(NVCC) $(NVCC_FLAGS) -lineinfo $< $(NVCC_LDFLAGS) $(CUDA_OUTPUT_FILE)

clean:
rm -f train_gpt2 test_gpt2 train_gpt2cu train_gpt2fp32cu test_gpt2cu test_gpt2fp32cu
$(REMOVE_FILES) $(TARGETS)
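(The PRECISION block above reduces to a single compiler define in PFLAGS, and the CI jobs exercise all three settings; locally, `PRECISION=FP16 make test_gpt2cu train_gpt2cu profile_gpt2cu` mirrors one of them. As a minimal sketch of what such a define typically selects on the CUDA side, assuming PFLAGS reaches the nvcc command line: the floatX alias and exact typedefs here are illustrative assumptions, not lines quoted from train_gpt2.cu.)

```cuda
// Illustrative sketch (assumed names, not quoted from this PR): how the
// -DENABLE_FP32 / -DENABLE_FP16 / -DENABLE_BF16 defines set by PFLAGS
// could select the parameter/activation type in the CUDA sources.
#include <cuda_fp16.h>   // half
#include <cuda_bf16.h>   // __nv_bfloat16

#if defined(ENABLE_FP32)
typedef float floatX;
#elif defined(ENABLE_FP16)
typedef half floatX;
#else
typedef __nv_bfloat16 floatX;  // default, matching PRECISION ?= BF16
#endif
```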
40 changes: 40 additions & 0 deletions dev/cuda/encoder_forward.cu
@@ -9,12 +9,16 @@ version 1 is naive port from CPU code to kernel: parallelizes over B,T, loops ov

version 2 is more optimized, parallelizes over all of B,T,C
./encoder_forward 2

version 3 is like version 2 but uses float4 reads/writes
./encoder_forward 3
*/

#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>
#include "common.h"
#include <cassert>

// ----------------------------------------------------------------------------
// CPU code reference
@@ -81,6 +85,28 @@ __global__ void encoder_forward_kernel2(float* out,
}
}

__device__ inline float4 add_float4(const float4& a, const float4& b) {
return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
}

// use of float4 leads to using 128-bit LDG / STG instructions in SASS,
// very helpful in memory-bound kernels like encoder_forward
__global__ void encoder_forward_kernel3(float4* out,
const int* inp, const float4* wte, const float4* wpe,
int B, int T, int C) {
int C4 = C / 4;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int N = B * T * C4;
if (idx < N) {
int bt = idx / C4;
int b = bt / T;
int t = bt % T;
int c4 = idx % C4;
int ix = inp[b * T + t];
out[b * T * C4 + t * C4 + c4] = add_float4(wte[ix * C4 + c4], wpe[t * C4 + c4]);
}
}

// ----------------------------------------------------------------------------
// kernel launcher

@@ -104,6 +130,17 @@ void encoder_forward2(float* out,
cudaCheck(cudaGetLastError());
}

void encoder_forward3(float* out,
const int* inp, const float* wte, const float* wpe,
int B, int T, int C,
const int block_size) {
assert(C % 4 == 0);
const int N = B * T * C;
const int grid_size = ceil_div(N / 4, block_size);
encoder_forward_kernel3<<<grid_size, block_size>>>((float4*) out, inp, (float4*) wte, (float4*) wpe, B, T, C);
cudaCheck(cudaGetLastError());
}

// kernel version dispatch
void encoder_forward(int kernel_num,
float* out,
@@ -117,6 +154,9 @@ void encoder_forward(int kernel_num,
case 2:
encoder_forward2(out, inp, wte, wpe, B, T, C, block_size);
break;
case 3:
encoder_forward3(out, inp, wte, wpe, B, T, C, block_size);
break;
default:
printf("Invalid kernel number\n");
exit(1);
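(The index arithmetic in encoder_forward_kernel3 is compact and easy to get wrong, as a couple of the commit messages above attest. This host-only C sketch, independent of the PR, checks that the (b, t, c4) decomposition is the identity on thread indices whenever C % 4 == 0, which is what the launcher's assert guarantees:)

```c
// Sanity check for encoder_forward_kernel3's index math: each thread index
// idx in [0, B*T*C4) must map to the output slot b*T*C4 + t*C4 + c4 == idx.
#include <assert.h>
#include <stdio.h>

int main(void) {
    int B = 4, T = 64, C = 768;  // toy sizes; C % 4 == 0 as the launcher asserts
    int C4 = C / 4;
    for (int idx = 0; idx < B * T * C4; idx++) {
        int bt = idx / C4;       // which (b, t) position
        int b = bt / T;
        int t = bt % T;
        int c4 = idx % C4;       // which float4 within the channel dimension
        assert(b * T * C4 + t * C4 + c4 == idx);  // decomposition round-trips
    }
    printf("index decomposition OK for B=%d T=%d C=%d\n", B, T, C);
    return 0;
}
```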
4 changes: 2 additions & 2 deletions profile_gpt2.cu
@@ -51,7 +51,7 @@ int main() {

// build the GPT-2 model from a checkpoint
GPT2 model;
gpt2_build_from_checkpoint(&model, "gpt2_124M.bin");
gpt2_build_from_checkpoint(&model, "gpt2_124M_bf16.bin");

int B = 4;
int T = 1024;
@@ -80,4 +80,4 @@ int main() {
cublasCheck(cublasLtDestroy(cublaslt_handle));

return 0;
}
}
38 changes: 25 additions & 13 deletions test_gpt2.c
@@ -44,6 +44,7 @@ int main(int argc, char *argv[]) {

int C = model.config.channels;
int V = model.config.vocab_size;
int Vp = model.config.padded_vocab_size;
int maxT = model.config.max_seq_len;
int L = model.config.num_layers;

@@ -52,8 +53,12 @@
if (state_file == NULL) { printf("Error opening state file\n"); return 1; }
int state_header[256];
fread(state_header, sizeof(int), 256, state_file);
if (state_header[0] != 20240327) { printf("Bad magic state file"); return 1; }
if (state_header[1] != 1) { printf("Bad version in state file"); return 1; }
if (state_header[0] != 20240327) { printf("Bad magic state file\n"); return 1; }
if (state_header[1] != 2) {
printf("Bad version in state file\n");
printf("---> HINT: try to re-run `python train_gpt2.py`\n");
return 1;
}
int B = state_header[2]; // batch size, e.g. 4
int T = state_header[3]; // time / sequence length (e.g. 64, up to maxT)
printf("[State]\n");
@@ -107,22 +112,29 @@

if (step == 0) {
// error checking at step 0 for reference activations/gradients

// at this point, target should be equal to expected_logits, let's compare
int logits_ok = 1;
for (int i=0; i<B*T*V; i++) {
if(i < 3) {
printf("%f %f\n", expected_logits[i], model.acts.logits[i]);
}
if (fabsf(expected_logits[i] - model.acts.logits[i]) >= 1e-2) {
printf("MISMATCH AT INDEX %d: ", i);
printf("%f %f\n", expected_logits[i],model.acts.logits[i]);
logits_ok = 0;
break;
float* calculated_logits = model.acts.logits;
float max_diff = 0.0f;
for (int bt = 0; bt < B*T; bt++) {
for (int v = 0; v < V; v++) { // note we only loop to V (ignoring padding)
int i = bt * Vp + v; // linearized index, using Vp
if (i < 10) {
printf("%f, %f\n", expected_logits[i], calculated_logits[i]);
}
float diff = fabsf(expected_logits[bt*V + v] - calculated_logits[i]);
max_diff = fmaxf(max_diff, diff);
if (diff >= 1e-2f) {
printf("MISMATCH AT INDEX %d,%d: ", bt, v);
printf("%f %f\n", expected_logits[bt*V + v], calculated_logits[i]);
logits_ok = 0;
bt = B*T; // to break out of both loops
break;
}
}
}
if(!logits_ok) { printf("NOT "); }
printf("OK (LOGITS)\n");
printf("OK (LOGITS), max_diff = %e\n", max_diff);
allok = allok && logits_ok;

// compare the achieved loss
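(The V/Vp split exists because the padded-vocab change, commit 835060e above, rounds the vocabulary up so matmul shapes stay aligned: the model writes logits with row stride Vp, while the reference file remains dense with stride V, and only the first V columns of each row carry meaning. A small C sketch of that layout follows; the round-up multiple of 128 is an assumed illustration of the padding rule, not something stated in this diff.)

```c
// Sketch of the padded-vocab logits layout checked in test_gpt2.c:
// reference logits are dense [B*T, V]; model logits are strided [B*T, Vp].
#include <stdio.h>

// round V up to the next multiple of m (m = 128 is an assumed choice)
static int pad_vocab(int V, int m) { return ((V + m - 1) / m) * m; }

int main(void) {
    int V = 50257;               // GPT-2 vocabulary size
    int Vp = pad_vocab(V, 128);  // 50304 under the assumed multiple
    printf("V = %d, Vp = %d\n", V, Vp);
    // the logit for token v at flattened sequence position bt lives at:
    int bt = 3, v = 100;
    printf("reference index bt*V + v  = %d\n", bt * V + v);
    printf("padded model index bt*Vp + v = %d\n", bt * Vp + v);
    return 0;
}
```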