forked from KellerJordan/modded-nanogpt
-
Notifications
You must be signed in to change notification settings - Fork 1
/
train_gpt2.py
executable file
·499 lines (427 loc) · 20.7 KB
/
train_gpt2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
import os
import sys
import uuid
import glob
from dataclasses import dataclass
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
import torch._inductor.config as config
from torch.nn.parallel import DistributedDataParallel as DDP
from soap import SOAP
# Read this script's own source text into `code`; the training loop later stores
# it in each checkpoint dict (the `code=code` entry handed to torch.save).
with open(sys.argv[0]) as f:
    code = f.read()
# -----------------------------------------------------------------------------
# OrthogonalNesterov optimizer
class OrthogonalNesterov(torch.optim.Optimizer):
    """
    Momentum optimizer whose update direction is orthogonalized before it is applied.

    For every parameter with a gradient, a momentum buffer is maintained
    (buf <- momentum * buf + grad). The direction is either the buffer itself or,
    with nesterov=True, grad + momentum * buf. That direction is passed through
    zeroth_power_via_newtonschulz5 (so gradients must be 2-D) and the result is
    subtracted from the parameter, scaled by lr.
    """

    def __init__(self, params, lr=0.02, momentum=0.9, nesterov=True, zeropower_iters=5):
        super().__init__(
            params,
            dict(lr=lr, momentum=momentum, nesterov=nesterov, zeropower_iters=zeropower_iters),
        )

    def step(self):
        """Apply one update to every parameter that currently has a gradient."""
        for group in self.param_groups:
            step_size, beta = group['lr'], group['momentum']
            for param in group['params']:
                grad = param.grad
                if grad is None:
                    # parameter did not take part in the backward pass
                    continue
                param_state = self.state[param]
                param_state['steps'] = param_state.get('steps', 0) + 1
                if 'momentum_buffer' not in param_state:
                    param_state['momentum_buffer'] = torch.zeros_like(grad)
                velocity = param_state['momentum_buffer']
                velocity.mul_(beta).add_(grad)
                if group['nesterov']:
                    # look-ahead direction: gradient plus the decayed, updated buffer
                    direction = grad.add(velocity, alpha=beta)
                else:
                    direction = velocity
                orthogonal = zeroth_power_via_newtonschulz5(direction, steps=group['zeropower_iters'])
                param.data.add_(orthogonal, alpha=-step_size)
@torch.compile
def zeroth_power_via_newtonschulz5(G, steps=5, eps=1e-7):
    """
    Approximate the zeroth power (orthogonalization) of the 2-D tensor G with a quintic
    Newton-Schulz iteration run in bfloat16.

    The quintic's coefficients were picked to maximize its slope at zero. Empirically it
    pays off to keep raising that slope even past the point where repeated application
    stops converging to exactly one everywhere, provided the result stays reasonably
    close to 1 over the interval. Using Newton-Schulz for orthogonalization follows
    Bernstein & Newhouse (2024), https://arxiv.org/abs/2409.20325, who proposed it for
    computing the preconditioners of Shampoo.

    Returns a tensor with the same shape and dtype as G.
    """
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    # tall inputs are iterated in transposed form so X @ X.T is the smaller Gram matrix
    is_tall = G.size(0) > G.size(1)
    # dividing by the (Frobenius) norm keeps the top singular value <= 1
    X = G.bfloat16() / (G.norm() + eps)
    X = X.T if is_tall else X
    for _ in range(steps):
        gram = X @ X.T
        gram_x = gram @ X
        X = a * X + b * gram_x + c * gram @ gram_x
    X = X.T if is_tall else X
    return X.to(G.dtype)
class CombinedOptimizer:
    """
    Thin wrapper that drives several optimizers as one.

    Every wrapped optimizer must have exactly one param group. The learning rate
    each one was constructed with is remembered so scale_lrs can set
    lr = base_lr * lr_scale without compounding across calls.
    """

    def __init__(self, optimizers):
        assert all(len(opt.param_groups) == 1 for opt in optimizers)
        self.optimizers = optimizers
        groups, starting_lrs = [], []
        for member in self.optimizers:
            groups.extend(member.param_groups)
            starting_lrs.append(member.param_groups[0]['lr'])
        self.param_groups = groups
        self.base_lrs = starting_lrs

    def step(self):
        """Step every wrapped optimizer, in the order they were given."""
        for member in self.optimizers:
            member.step()

    def zero_grad(self, **kwargs):
        """Forward zero_grad (and any keyword arguments) to every wrapped optimizer."""
        for member in self.optimizers:
            member.zero_grad(**kwargs)

    def scale_lrs(self, lr_scale):
        """Set each optimizer's lr to its original lr multiplied by lr_scale."""
        for index, member in enumerate(self.optimizers):
            member.param_groups[0]['lr'] = self.base_lrs[index] * lr_scale

    def state_dict(self):
        """Return the wrapped optimizers' state dicts as a list, in order."""
        return [member.state_dict() for member in self.optimizers]
# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the GPT-2 model
class Rotary(torch.nn.Module):
    """
    Produces the cos/sin tables used for rotary position embeddings.

    The tables depend only on the sequence length (x.shape[1]) and are cached,
    so they are rebuilt only when a different sequence length is seen.
    """

    def __init__(self, dim, base=10000):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents))
        # cache, keyed by the last sequence length seen
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x):
        """Return (cos, sin), each of shape (1, seq_len, 1, len(inv_freq))."""
        seq_len = x.shape[1]
        if self.seq_len_cached != seq_len:
            self.seq_len_cached = seq_len
            positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            # angle for every (position, frequency) pair
            angles = torch.outer(positions, self.inv_freq).to(x.device)
            self.cos_cached = angles.cos()
            self.sin_cached = angles.sin()
        cos, sin = self.cos_cached, self.sin_cached
        # add singleton batch and head axes so the tables broadcast
        return cos[None, :, None, :], sin[None, :, None, :]
def apply_rotary_emb(x, cos, sin):
    """
    Rotate the two halves of x's last dimension by the angles encoded in cos/sin.

    x must be 4-D; its last axis is split at the midpoint into (x1, x2) and the
    result is concat(x1*cos + x2*sin, -x1*sin + x2*cos) along that axis.
    """
    assert x.ndim == 4 # multihead attention
    half = x.shape[3] // 2
    first, second = x[..., :half], x[..., half:]
    rotated = (
        first * cos + second * sin,
        first * (-sin) + second * cos,
    )
    return torch.cat(rotated, 3)
def rmsnorm(x0, eps=1e-6):
    """
    Scale x0 by the reciprocal root-mean-square of its last dimension.

    The computation is done in float32 and the result is cast back to x0's dtype.
    There is no learnable gain.
    """
    x = x0.float()
    mean_sq = x.pow(2).mean(-1, keepdim=True)
    normed = x * torch.rsqrt(mean_sq + eps)
    return normed.type_as(x0)
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary embeddings applied to q and k."""

    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        # one fused projection yields queries, keys and values for every head
        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False)
        # projection applied to the concatenated head outputs
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.rotary = Rotary(self.head_dim)

    def forward(self, x):
        """Map x of shape (B, T, n_embd) to an output of the same shape."""
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        per_head = (B, T, self.n_head, self.head_dim)
        # split the fused projection into q, k, v and expose the head axis
        q, k, v = [part.view(per_head) for part in self.c_attn(x).split(self.n_embd, dim=2)]
        cos, sin = self.rotary(q)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)
        # scaled_dot_product_attention wants the head axis ahead of the time axis
        attended = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
        # put the heads back side by side: (B, T, n_embd)
        merged = attended.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(merged)
class MLP(nn.Module):
    """Bias-free feed-forward block: expand to 4 * n_embd, GELU, project back."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(width, hidden, bias=False)
        self.c_proj = nn.Linear(hidden, width, bias=False)

    def forward(self, x):
        return self.c_proj(F.gelu(self.c_fc(x)))
class Block(nn.Module):
    """
    One transformer layer: pre-normalized attention and MLP, each added residually.

    The attention branch is multiplied by 1 / sqrt(2 * n_layer) before being
    added; the MLP branch is added unscaled.
    """

    def __init__(self, config):
        super().__init__()
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)
        self.attn_scale = (1 / (2 * config.n_layer)**0.5)

    def forward(self, x):
        attn_out = self.attn(rmsnorm(x))
        x = x + self.attn_scale * attn_out
        return x + self.mlp(rmsnorm(x))
# -----------------------------------------------------------------------------
# The main GPT-2 model
@dataclass
class GPTConfig:
    """Size hyperparameters read by GPT, Block, CausalSelfAttention and MLP."""
    # number of token ids: rows of the embedding table and width of the output logits
    vocab_size: int = 50257
    # number of Block layers stacked in the model
    n_layer: int = 12
    # attention heads per layer; n_embd must be divisible by this
    n_head: int = 12
    # width of the token embeddings / residual stream
    n_embd: int = 768
class GPT(nn.Module):
    """
    Decoder-only transformer: token embedding, a stack of Block layers, a final
    rmsnorm and a linear head whose weight is tied to the token embedding.

    No positional embedding table is used; position information enters through
    the rotary embeddings inside each attention layer.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

    def forward(self, idx, targets=None, return_logits=True):
        """
        Run the model on token ids idx of shape (b, t).

        With targets given, logits are computed for every position and the
        cross-entropy loss is returned (target value -1 is ignored). Without
        targets, only the last position's logits are computed and loss is None.
        Returns (logits, loss); logits is None when return_logits is False.
        """
        # the unpack doubles as a check that idx is 2-D; no position index tensor
        # is built here because nothing in the model consumes one
        b, t = idx.size()
        # forward the GPT model itself
        x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        for block in self.transformer.h:
            x = block(x)
        x = rmsnorm(x)
        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            logits = logits.float() # use tf32/fp32 for logits
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
            logits = logits.float() # use tf32/fp32 for logits
            loss = None
        # there are performance reasons why not returning logits is prudent, if not needed
        if not return_logits:
            logits = None
        return logits, loss

    def configure_optimizers(self, weight_decay, learning_rate, betas):
        """
        Build the CombinedOptimizer used for training.

        The tied lm_head/embedding weight is trained by AdamW with a fixed
        lr of 0.0018 and the given betas; the transformer blocks are trained
        by SOAP with lr=learning_rate. NOTE: weight_decay is accepted but not
        used -- both optimizers are created with weight_decay=0.
        """
        optimizer = CombinedOptimizer([
            torch.optim.AdamW(self.lm_head.parameters(), lr=0.0018, betas=betas, weight_decay=0),
            SOAP(self.transformer.h.parameters(), lr=learning_rate, betas=(.95, .95), weight_decay=0, precondition_frequency=10)
            #OrthogonalNesterov(self.transformer.h.parameters(), lr=10 * learning_rate, momentum=0.95)
        ])
        return optimizer
# -----------------------------------------------------------------------------
# Our own simple Distributed Data Loader
def _peek_data_shard(filename):
    """
    Read only the header of a token shard and return its claimed token count.

    The header is 256 int32 values: header[0] is the magic number 20240520,
    header[1] the format version (must be 1) and header[2] the token count.
    On a magic-number mismatch an explanation is printed and the process exits
    with status 1; an unsupported version raises AssertionError.
    """
    with open(filename, "rb") as f:
        # first read the header, which is 256 int32 integers (4 bytes each)
        header = np.frombuffer(f.read(256*4), dtype=np.int32)
    if header[0] != 20240520:
        print("ERROR: magic number mismatch in the data .bin file!")
        print("---> HINT: Are you passing in a correct file with --input_bin?")
        print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README")
        print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try")
        # sys.exit rather than the builtin exit(): the latter is injected by the
        # `site` module for interactive use and is not guaranteed to exist
        sys.exit(1)
    assert header[1] == 1, "unsupported version"
    ntok = header[2] # number of tokens (claimed)
    return ntok # for now just return the number of tokens
def _load_data_shard(filename):
    """
    Load every token of a shard file and return them as a uint16 numpy array.

    The file starts with a 256-value int32 header (magic 20240520, version 1,
    token count); the remainder is the tokens. AssertionError is raised when the
    magic number, the version, or the token count does not match.
    """
    header_bytes = 256*4 # 256 int32 integers, 4 bytes each
    with open(filename, "rb") as f:
        header = np.frombuffer(f.read(header_bytes), dtype=np.int32)
        assert header[0] == 20240520, "magic number mismatch in the data .bin file"
        assert header[1] == 1, "unsupported version"
        claimed = header[2] # number of tokens the header says follow
        # everything after the header is the token stream, stored as uint16
        payload = f.read()
    tokens = np.frombuffer(payload, dtype=np.uint16)
    assert len(tokens) == claimed, "number of tokens read does not match header?"
    return tokens
class DistributedDataLoader:
    """
    Streams (B, T) token batches from a sorted list of shard files.

    Each process starts at offset process_rank * B * T inside a shard and moves
    forward by B * T * num_processes per batch, so the processes read
    interleaved, non-overlapping windows. Shards are visited in order and wrap
    around after the last one.
    """

    def __init__(self, filename_pattern, B, T, process_rank, num_processes):
        self.process_rank = process_rank
        self.num_processes = num_processes
        self.B = B
        self.T = T
        # every file matching the glob pattern is a shard
        self.files = sorted(glob.glob(filename_pattern))
        assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"
        # check each shard can serve at least one batch per process, and total up the tokens
        min_tokens = num_processes * B * T + 1
        running_total = 0
        for path in self.files:
            shard_ntok = _peek_data_shard(path)
            assert shard_ntok >= min_tokens
            running_total += int(shard_ntok)
        self.ntok_total = running_total
        # start reading from the first shard
        self.reset()

    def _enter_shard(self, shard_index):
        # make shard_index current, rewind to this rank's first window and load its tokens
        self.current_shard = shard_index
        self.current_position = self.process_rank * self.B * self.T
        self.tokens = _load_data_shard(self.files[self.current_shard])

    def reset(self):
        """Go back to the beginning of the first shard."""
        self._enter_shard(0)

    def advance(self): # advance to next data shard
        """Move to the next shard, wrapping to the first after the last."""
        self._enter_shard((self.current_shard + 1) % len(self.files))

    def next_batch(self):
        """Return the next (inputs, targets) pair as CUDA long tensors of shape (B, T)."""
        span = self.B * self.T
        start = self.current_position
        # one extra token so the targets can be the inputs shifted by one
        window = self.tokens[start : start + span + 1]
        window = torch.tensor(window.astype(np.int32), dtype=torch.long)
        x = window[:-1].view(self.B, self.T) # inputs
        y = window[1:].view(self.B, self.T) # targets
        # step over the windows owned by the other processes
        stride = span * self.num_processes
        self.current_position += stride
        # switch shards when another full stride (plus the extra token) would not fit
        if self.current_position + (stride + 1) > len(self.tokens):
            self.advance()
        return x.cuda(), y.cuda()
# -----------------------------------------------------------------------------
# int main
def print0(*args, **kwargs):
    """
    print() that is silent on every process except the one with RANK 0.

    RANK is read from the environment; when it is unset (a non-distributed
    run) it is treated as 0, so this behaves like a plain print.
    """
    rank = int(os.environ.get("RANK", 0))
    if rank != 0:
        return
    print(*args, **kwargs)
if __name__ == "__main__":
    # Training entry point. Expects to be launched by torchrun (RANK, LOCAL_RANK
    # and WORLD_SIZE are read from the environment) on machines with CUDA.
    import time
    import argparse
    print0(f"Running pytorch {torch.version.__version__}")

    parser = argparse.ArgumentParser()
    # file system input / output
    parser.add_argument("--input_bin", type=str, help="input .bin to train on")
    parser.add_argument("--input_val_bin", type=str, help="input .bin to eval validation loss on")
    parser.add_argument("--model", type=str, default="d12", help="d12|d24|d36|d48")
    # token layout for each step of the optimization
    parser.add_argument("--batch_size", type=int, default=4, help="batch size, in units of #batch dimensions")
    parser.add_argument("--accumulation", type=int, default=1)
    parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")
    # workload (number of steps)
    parser.add_argument("--num_iterations", type=int, default=10, help="number of iterations to run")
    # optimization
    # (help text fixed: it used to be a copy of the --warmup_iters description)
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="learning rate")
    parser.add_argument("--warmup_iters", type=int, default=0, help="learning rate warmup iterations")
    parser.add_argument("--warmdown_iters", type=int, default=0, help="learning rate warmdown iterations")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay")
    # evaluation
    parser.add_argument("--val_loss_every", type=int, default=0, help="every how many steps to evaluate val loss?")
    parser.add_argument("--val_max_steps", type=int, default=20, help="how many batches of val to average?")
    parser.add_argument("--save_every", type=int, default=0, help="every how many steps to save the checkpoint")
    args = parser.parse_args()

    # args error checking and convenience variables
    B, T = args.batch_size, args.sequence_length
    assert args.model in {"d12", "d24", "d36", "d48"}

    # set up DDP (distributed data parallel). torchrun sets this env variable
    assert torch.cuda.is_available()
    dist.init_process_group(backend='nccl')
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    print(f"using device: {device}")
    master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc.

    # load tokens
    train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size)
    print0(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
    val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size)
    print0(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")
    # prefetch the first training batch; the loop fetches the next one after each forward
    x, y = train_loader.next_batch()

    # init the model from scratch
    num_vocab = 50257
    model_config = {
        "d12": GPTConfig(vocab_size=num_vocab, n_layer=12, n_head=12, n_embd=768),
        "d24": GPTConfig(vocab_size=num_vocab, n_layer=24, n_head=16, n_embd=1024),
        "d36": GPTConfig(vocab_size=num_vocab, n_layer=36, n_head=20, n_embd=1280),
        "d48": GPTConfig(vocab_size=num_vocab, n_layer=48, n_head=25, n_embd=1600),
    }[args.model]
    model = GPT(model_config)
    model = model.cuda()
    if hasattr(config, "coordinate_descent_tuning"):
        config.coordinate_descent_tuning = True # suggested by @Chillee
    print0("compiling the model...")
    model = torch.compile(model)

    # here we wrap model into DDP container
    model = DDP(model, device_ids=[ddp_local_rank])
    raw_model = model.module # always contains the "raw" unwrapped model

    # set up a context manager following the desired dtype and device
    ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)

    # init the optimizer
    optimizer = raw_model.configure_optimizers(weight_decay=args.weight_decay,
                                               learning_rate=args.learning_rate, betas=(0.9, 0.95))

    # learning rate decay scheduler (linear warmup and warmdown)
    # returns a multiplier that optimizer.scale_lrs applies to each base lr
    def get_lr(it):
        assert it <= args.num_iterations
        # 1) linear warmup for warmup_iters steps
        if it < args.warmup_iters:
            return (it+1) / args.warmup_iters
        # 2) constant lr for a while
        elif it < args.num_iterations - args.warmdown_iters:
            return 1.0
        # 3) linear warmdown
        else:
            decay_ratio = (args.num_iterations - it) / args.warmdown_iters
            return decay_ratio

    run_id = str(uuid.uuid4())
    if master_process:
        os.makedirs('logs/%s' % run_id, exist_ok=True)
        logfile = 'logs/%s/log.txt' % run_id
        # create the empty log file
        with open(logfile, "w") as f:
            pass

    for step in range(args.num_iterations + 1):
        last_step = (step == args.num_iterations)

        # once in a while evaluate the validation dataset
        if (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)):
            model.eval()
            val_loader.reset()
            val_loss = 0.0
            for _ in range(args.val_max_steps):
                with torch.no_grad(): # I want to use ctx here but it causes a torch.compile error
                    x_val, y_val = val_loader.next_batch()
                    _, loss = model(x_val, y_val, return_logits=False)
                    val_loss += loss
            dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
            val_loss /= args.val_max_steps
            # log val loss to console and to logfile
            print0(f"val loss {val_loss}")
            if master_process and logfile is not None:
                with open(logfile, "a") as f:
                    f.write("s:%d tel:%f\n" % (step, val_loss))

        # save the state of the training process
        if master_process and (last_step or (args.save_every > 0 and step % args.save_every == 0)):
            log = dict(step=step, args=args.__dict__, code=code, model=raw_model.state_dict(), optimizer=optimizer.state_dict())
            torch.save(log, 'logs/%s/state_step%06d.pt' % (run_id, step))

        # bit confusing: we want to make sure to eval on 0th iteration
        # but also after the very last iteration. so we loop for step <= num_iterations
        # instead of just < num_iterations (one extra due to <=), only to do
        # the validation/sampling one last time, and then we break right here as we're done.
        if last_step:
            break

        torch.cuda.synchronize()
        t0 = time.time()
        # --------------- TRAINING SECTION BEGIN -----------------
        model.train()
        for _ in range(args.accumulation):
            # forward pass
            with ctx:
                _, loss = model(x, y, return_logits=False)
                train_loss = loss.detach()
            # advance the dataset for the next batch
            x, y = train_loader.next_batch()
            # backward pass
            loss.backward()
        # gradients were summed over the micro-batches; turn the sum into a mean
        for p in model.parameters():
            p.grad /= args.accumulation
        # determine and set the learning rate for this iteration
        lr_scale = get_lr(step)
        optimizer.scale_lrs(lr_scale)
        # step the optimizer
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        # --------------- TRAINING SECTION END -------------------
        # everything that follows now is just diagnostics, prints, logging, etc.

        torch.cuda.synchronize()
        t1 = time.time()
        # note: train_loss is the loss of the last micro-batch only, averaged across ranks
        dist.all_reduce(train_loss, op=dist.ReduceOp.AVG)
        # every rank processes `accumulation` micro-batches of B*T tokens per step
        tokens_per_second = args.accumulation * ddp_world_size * B * T / (t1 - t0)
        print0(f"step {step+1:4d}/{args.num_iterations} | train loss {train_loss.item():.4f} | lr_scale {lr_scale:.2e} | ({(t1-t0)*1000:.2f} ms | {tokens_per_second:.0f} tok/s)")
        # log training loss to logfile
        if master_process:
            with open(logfile, "a") as f:
                f.write("s:%d trl:%f\n" % (step, train_loss.item()))

    print0(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")

    # -------------------------------------------------------------------------
    # clean up nice
    dist.destroy_process_group()