
Commit

Merge branch 'karpathy:master' into makefile_windows_update
rosslwheeler authored Apr 25, 2024
2 parents d8fb206 + 7a52a21, commit 1d1ea11
Showing 4 changed files with 41 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -35,7 +35,7 @@ jobs:
run: python prepro_tinyshakespeare.py

- name: Train model
run: python train_gpt2.py
run: python train_gpt2.py --device=cpu

- name: Compile training and testing program
run: make test_gpt2 train_gpt2
11 changes: 10 additions & 1 deletion README.md
@@ -135,7 +135,9 @@ I attached a very small tutorial here, in [doc/layernorm/layernorm.md](doc/layer

The full training loop is also implemented in pure CUDA in one file, but optimizations of the kernels are ongoing. Currently, we roughly match the speed of PyTorch. The way we organize code is that we have a growing collection of kernels of increasing complexity in the `dev/cuda` folder, see [dev/cuda/README.md](dev/cuda/README.md). We then copy paste the best kernels into the main training loop in the single training file `train_gpt2cu.cu`.

**WIP alert, April 23**. We merged the first version of the mixed precision training code. I checkpointed the fp32 version to separate files that include `_fp32` in their filename, and would like to preserve this version in the root of the repo because it 1) doesn't require the most up-to-date CUDA, is a lot more likely to compile, and is more portable, 2) is a lot simpler and acts as a reference. The "mainline" development of the CUDA version will from here on move mostly to the [train_gpt2.cu](train_gpt2.cu) file, which includes mixed precision. In the descriptions below I will default to using the fp32 version for now because it is currently more portable and stable, and at the end I will cover the new mixed precision version.
**WIP alert, April 23**. We merged the first version of the mixed precision training code. I checkpointed the fp32 version to separate files that include `_fp32` in their filename, and would like to preserve this version in the root of the repo because it 1) doesn't require the most up-to-date CUDA, is a lot more likely to compile, and is more portable, 2) is a lot simpler and acts as a reference. In fact, we'd like to diverge the fp32 version in the direction of being pure CUDA (e.g. not even calling cuBLAS by default), to be used as an educational reference, maybe even a kernel of a course on CUDA. The "mainline" development concerned with speed will from here on move to the [train_gpt2.cu](train_gpt2.cu) file, which includes mixed precision training.

In the descriptions below I will default to using the fp32 version for now because it is currently more portable and stable, and at the end I will cover the new mixed precision version.

**Correctness**. First, we can do 10 iterations of training and verify that our code exactly matches and reproduces the numbers from PyTorch:

@@ -269,6 +271,13 @@ Lastly, I will be a lot more sensitive to complexity in the root folder of the p

- Metal
- [llm.metal](https://github.com/regrettable-username/llm.metal) by @[regrettable-username](https://github.com/regrettable-username): LLM training in simple, raw C/Metal Shading Language

- Zig
  - [llm.zig](https://github.com/Saimirbaci/llm.zig/) by @[saimirbaci](https://github.com/Saimirbaci): a Zig port of this project

- Go
- [llm.go](https://github.com/joshcarp/llm.go) by @[joshcarp](https://github.com/joshcarp): a Go port of this project

## discussions

Ways of organizing development:
8 changes: 7 additions & 1 deletion train_gpt2.cu
@@ -96,6 +96,7 @@ void cublasCheck(cublasStatus_t status, const char *file, int line)
#define cublasCheck(status) { cublasCheck((status), __FILE__, __LINE__); }

// GPU helper functions for atomicAdd on smaller than 32-bit types
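// note: the 16-bit overloads below are compiled only when ENABLE_BF16 / ENABLE_FP16 is defined; each one
// aligns the address down to a 4-byte boundary and issues a single 32-bit atomicAdd of a 2-wide value
// whose unused lane is zero, so the add lands on the intended 16-bit element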
#ifdef ENABLE_BF16
__device__ void atomicAddX(__nv_bfloat16* addr, __nv_bfloat16 val) {
uintptr_t ptr_val = reinterpret_cast<uintptr_t>(addr);
__nv_bfloat162* ptr_bf16 = reinterpret_cast<__nv_bfloat162*>(ptr_val & ~uintptr_t(0x3));
@@ -105,6 +106,9 @@ __device__ void atomicAddX(__nv_bfloat16* addr, __nv_bfloat16 val) {
: __halves2bfloat162(val, __ushort_as_bfloat16(0));
atomicAdd(ptr_bf16, add_val);
}
#endif

#ifdef ENABLE_FP16
__device__ void atomicAddX(half* addr, half val) {
uintptr_t ptr_val = reinterpret_cast<uintptr_t>(addr);
half2* ptr_fp16 = reinterpret_cast<half2*>(ptr_val & ~uintptr_t(0x3));
@@ -114,6 +118,8 @@ __device__ void atomicAddX(half* addr, half val) {
: __halves2half2(val, __ushort_as_half(0));
atomicAdd(ptr_fp16, add_val);
}
#endif

__device__ void atomicAddX(float* addr, float val) {
atomicAdd(addr, val);
}
@@ -1666,7 +1672,7 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, int B, int T) {
fill_in_activation_sizes(model->act_sizes, B, T, model->config);
size_t num_activations = 0;
for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {
num_activations += model->act_sizes[i] * sizeof(floatX);
num_activations += model->act_sizes[i];
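// (act_sizes[i] is an element count, so num_activations now counts elements; the removed * sizeof(floatX) factor would have made it a byte count)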
}
model->num_activations = num_activations;
model->acts_memory = malloc_and_point_activations(&model->acts, model->act_sizes);
31 changes: 23 additions & 8 deletions train_gpt2.py
@@ -12,6 +12,7 @@
import os
import math
import struct
from contextlib import nullcontext
from dataclasses import dataclass

import numpy as np
@@ -321,6 +322,8 @@ def write_tokenizer(enc, filename):
parser = argparse.ArgumentParser()
parser.add_argument("--write_tensors", type=int, default=1, help="write tensors to disk")
parser.add_argument("--inference_only", type=int, default=0, help="only run inference")
parser.add_argument("--dtype", type=str, default="float32", help="float32|float16|bfloat16")
parser.add_argument("--device", type=str, default="", help="by default we autodetect, or set it here")
parser.add_argument("--compile", type=int, default=0, help="torch.compile the model")
parser.add_argument("--tensorcores", type=int, default=0, help="use tensorcores")
parser.add_argument("--num_iterations", type=int, default=10, help="number of iterations to run")
@@ -329,15 +332,24 @@ def write_tokenizer(enc, filename):
args = parser.parse_args()
B, T = args.batch_size, args.sequence_length
assert 1 <= T <= 1024
assert args.dtype in {"float32", "float16", "bfloat16"}

# select a reasonable device to run on
device = "cpu"
if torch.cuda.is_available():
device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
device = "mps"
# select the device
if args.device:
device = args.device
else:
# attempt to autodetect the device
device = "cpu"
if torch.cuda.is_available():
device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
device = "mps"
print(f"using device: {device}")

# create a context manager following the desired dtype and device
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
ctx = torch.amp.autocast(device_type="cuda", dtype=ptdtype) if device == "cuda" else nullcontext()

# seed the random number generators
torch.manual_seed(42)
if torch.cuda.is_available():
@@ -358,7 +370,8 @@ def write_tokenizer(enc, filename):
model.train()
model.to(device)
if args.compile:
config.coordinate_descent_tuning = True # suggested by @Chillee
if hasattr(config, "coordinate_descent_tuning"):
config.coordinate_descent_tuning = True # suggested by @Chillee
print("compiling the model...")
model = torch.compile(model)

@@ -398,6 +411,7 @@ def get_batch():

# do one forward pass to generate ground truth for our C tests
if not args.inference_only and args.write_tensors:
assert args.dtype == "float32", "right now can only write tensors in float32"
logits, loss = model(x, y)
loss.backward()
write_model(model, "gpt2_124M.bin")
@@ -410,7 +424,8 @@ def get_batch():
torch.cuda.reset_peak_memory_stats()
for i in range(args.num_iterations):
t0 = time.time()
logits, loss = model(x, y)
with ctx:
logits, loss = model(x, y)
if not args.inference_only:
optimizer.zero_grad()
del logits
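
Taken together, the changes in this file wire the new `--device` and `--dtype` flags into an autocast context around the forward pass. A condensed, hypothetical sketch of that pattern follows; the `dtype_arg`/`device_arg` values and the toy `Linear` model are stand-ins for illustration, not part of the actual script:

```python
from contextlib import nullcontext
import torch

# stand-ins for args.dtype / args.device from the argparse flags above
dtype_arg = "bfloat16"
device_arg = ""  # empty string means "autodetect", mirroring the --device default

# honor an explicit device, otherwise autodetect: cuda > mps > cpu
if device_arg:
    device = device_arg
else:
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
print(f"using device: {device}")

# map the dtype string to a torch dtype; autocast is only engaged on CUDA,
# on any other device the context is a no-op
ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype_arg]
ctx = torch.amp.autocast(device_type="cuda", dtype=ptdtype) if device == "cuda" else nullcontext()

# toy forward/backward showing where the context goes; the real script wraps
# the GPT-2 forward pass (logits, loss = model(x, y)) the same way
model = torch.nn.Linear(8, 8).to(device)
x = torch.randn(4, 8, device=device)
with ctx:
    loss = model(x).square().mean()
loss.backward()
```

This is also what makes the CI change above work: `python train_gpt2.py --device=cpu` forces the CPU path regardless of what autodetection would pick on the runner.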
