-
Notifications
You must be signed in to change notification settings - Fork 183
/
74cba1d4-da56-4334-9622-e0aa960dfe3f.txt
2165 lines (2092 loc) · 134 KB
/
74cba1d4-da56-4334-9622-e0aa960dfe3f.txt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import os
import sys
# Read this script's own source before anything else so the exact code that ran
# can later be written into the run log and saved inside checkpoints.
with open(sys.argv[0]) as f:
    code = f.read() # read the code of this file ASAP, for logging
import uuid
import glob
import time
import contextlib
from dataclasses import dataclass
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
# Inductor config module. Note: the name `config` is also reused as a constructor
# parameter name by the model classes below (local shadowing only).
import torch._inductor.config as config
from torch.nn.parallel import DistributedDataParallel as DDP
# Use of FlexAttention contributed by @KoszarskyB
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
# Rebind both names to torch.compile'd versions; dynamic=False asks the compiler to
# specialize on concrete shapes rather than emit shape-dynamic kernels.
flex_attention = torch.compile(flex_attention, dynamic=False)
create_block_mask = torch.compile(create_block_mask, dynamic=False)
# -----------------------------------------------------------------------------
# Muon optimizer
def zeropower_via_svd(G, steps=None):
    """Exactly orthogonalize a 2D matrix via SVD.

    For the reduced SVD G = U S V^T, return U V^T (the orthogonal polar factor
    of G).

    Args:
        G: 2D tensor.
        steps: unused; accepted so this backend can be called with the same
            ``steps=`` keyword as the iterative Newton-Schulz backend.
    Returns:
        Tensor with G's shape and dtype.
    """
    # torch.linalg.svd supersedes the deprecated Tensor.svd. full_matrices=False
    # gives the reduced factors, and it returns V^T (Vh) directly.
    U, _, Vh = torch.linalg.svd(G, full_matrices=False)
    return U @ Vh
@torch.compile
def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7):
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.

    Args:
        G: 2D tensor to orthogonalize. It is never modified.
        steps: number of quintic iterations.
        eps: added to the Frobenius norm to avoid division by zero.
    Returns:
        bfloat16 tensor with G's shape.
    """
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    # Out of place on purpose: when G is already bfloat16, G.bfloat16() returns G
    # itself, and an in-place `/=` would silently rescale the caller's tensor.
    X = X / (X.norm() + eps)  # Frobenius norm >= spectral norm, so top singular value <= 1
    transposed = G.size(0) > G.size(1)
    if transposed:
        # Iterate on the wide orientation so X @ X.T is the smaller Gram matrix.
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A  # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X
    if transposed:
        X = X.T
    return X
# Registry of orthogonalization routines, keyed by the name passed as Muon's `backend`.
zeropower_backends = {
    'svd': zeropower_via_svd,
    'newtonschulz5': zeropower_via_newtonschulz5,
}
class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz
    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.
    Some warnings:
    - This optimizer assumes that all parameters passed in are 2D.
    - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D
    parameters; those should all be optimized by a standard method (e.g., AdamW).
    - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
    - We believe it is unlikely to work well for training with small batch size.
    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
    - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M).
    - step() reads the RANK and WORLD_SIZE environment variables, stages updates in a bf16
      buffer on the current CUDA device, and calls torch.distributed.all_reduce, so an
      initialized process group is required.
    Arguments:
        lr: The learning rate used by the internal SGD.
        momentum: The momentum used by the internal SGD.
        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
        backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5')
        backend_steps: The number of iteration steps to use in the backend, if it is iterative.
    """
    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True,
                 backend='newtonschulz5', backend_steps=5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps)
        super().__init__(params, defaults)

    def step(self, closure=None):
        """Perform one optimization step.

        Parameters are assigned round-robin: the rank with ``i % WORLD_SIZE == RANK``
        computes the orthogonalized update for the i-th parameter of each group and
        writes it into a zero-initialized flat buffer. A SUM all-reduce then gives
        every rank the full set of updates, because each slot is written by exactly
        one rank. NOTE(review): this assumes every rank sees the same parameter order.

        Args:
            closure: optional callable that re-evaluates the model and returns the
                loss, per the ``torch.optim.Optimizer.step`` convention.
        Returns:
            The closure's loss, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        # Parse once per step instead of once per parameter.
        world_size = int(os.environ['WORLD_SIZE'])
        rank = int(os.environ['RANK'])
        for group in self.param_groups:
            lr = group['lr']
            momentum = group['momentum']
            zeropower_backend = zeropower_backends[group['backend']]
            # generate weight updates in distributed fashion
            total_params = sum(p.numel() for p in group['params'])
            updates_flat = torch.zeros(total_params, device='cuda', dtype=torch.bfloat16)
            curr_idx = 0
            for i, p in enumerate(group['params']):
                # luckily this will perfectly distribute a transformer with multiple of 4 layers to 8 GPUs
                if i % world_size == rank:
                    g = p.grad
                    assert g is not None
                    state = self.state[p]
                    if 'momentum_buffer' not in state:
                        state['momentum_buffer'] = torch.zeros_like(g)
                    buf = state['momentum_buffer']
                    buf.mul_(momentum).add_(g)
                    g = g.add(buf, alpha=momentum) if group['nesterov'] else buf
                    g = zeropower_backend(g, steps=group['backend_steps'])
                    # Scale tall matrices (rows > cols) by sqrt(rows / cols).
                    g *= max(1, g.size(0)/g.size(1))**0.5
                    updates_flat[curr_idx:curr_idx+p.numel()] = g.flatten()
                curr_idx += p.numel()
            # sync updates across devices. we are not memory-constrained so can do this simple deserialization
            dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
            # deserialize and apply updates
            curr_idx = 0
            for p in group['params']:
                g = updates_flat[curr_idx:curr_idx+p.numel()].view_as(p.data).type_as(p.data)
                p.data.add_(g, alpha=-lr)
                curr_idx += p.numel()
        return loss
# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the GPT-2 model
def norm(x):
return F.rms_norm(x, (x.size(-1),))
class CastedLinear(nn.Linear):
    """Bias-free linear layer whose weight is cast to the input's dtype at call time.

    This lets the stored weight keep a different precision from the activations
    (e.g. fp32 weights applied to bf16 inputs).
    """

    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features, bias=False)

    def forward(self, x):
        weight = self.weight.to(x.dtype)
        return F.linear(x, weight)
class Rotary(torch.nn.Module):
    """Rotary position embedding (RoPE) for 4-D tensors (sequence on dim 1, head_dim on dim 3).

    The cos/sin tables are cached per sequence length and device. They are always
    built in float32 from ``dim`` and ``base``, not from the ``inv_freq`` buffer.
    Module-wide casts such as ``.bfloat16()`` also convert buffers, and bf16 cannot
    represent integer positions above 256 exactly, which badly quantizes the
    rotation angles at long sequence lengths.
    """
    def __init__(self, dim, base=10000):
        super().__init__()
        self.dim = dim
        self.base = base
        # Kept as a (persistent) buffer so existing state_dicts still load; forward()
        # recomputes the frequencies in float32 instead of reading it.
        self.register_buffer('inv_freq', (1 / base) ** (torch.arange(0, dim, 2) / dim))
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def _inv_freq_fp32(self, device):
        """Return base ** (-2i / dim) for i in [0, dim/2), in float32 on ``device``."""
        exponents = torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim
        return (1 / self.base) ** exponents

    def forward(self, x):
        seq_len = x.shape[1]
        # Rebuild the tables when the length changes or x lives on a different device.
        if seq_len != self.seq_len_cached or self.cos_cached.device != x.device:
            t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
            freqs = torch.outer(t, self._inv_freq_fp32(x.device))
            self.seq_len_cached = seq_len
            self.cos_cached = freqs.cos()
            self.sin_cached = freqs.sin()
        cos, sin = self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]
        # apply_rotary_emb(x, cos, sin): rotate (first half, second half) pairs
        x1, x2 = x.chunk(2, dim=3)
        y1 = x1 * cos + x2 * sin
        y2 = x1 * (-sin) + x2 * cos
        return torch.cat((y1, y2), 3).type_as(x)
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention via FlexAttention, with QK-norm, rotary embeddings,
    and a learned blend between this layer's values and externally supplied values."""

    def __init__(self, dim, n_head):
        super().__init__()
        assert dim % n_head == 0
        self.n_head = n_head
        self.c_q = CastedLinear(dim, dim)
        self.c_k = CastedLinear(dim, dim)
        self.c_v = CastedLinear(dim, dim)
        # value-residual mixing coefficient, starts at 0.5 (@Grad62304977)
        self.lamb = nn.Parameter(torch.tensor(0.5))
        head_dim = dim // n_head
        self.rotary = Rotary(head_dim)
        # output projection; zero init suggested by @Grad62304977
        self.c_proj = CastedLinear(dim, dim)
        self.c_proj.weight.data.zero_()

    def forward(self, x, vi, block_mask):
        batch, seq = x.size(0), x.size(1)
        assert batch == 1, "Must use batch size = 1 for FlexAttention"
        heads = self.n_head
        q = self.c_q(x).view(batch, seq, heads, -1)
        k = self.c_k(x).view(batch, seq, heads, -1)
        v = self.c_v(x).view(batch, seq, heads, -1)
        mix = self.lamb
        v = (1 - mix) * v + mix * vi.view_as(v)  # @Grad62304977
        # QK norm suggested by @Grad62304977, then rotary position encoding
        q, k = norm(q), norm(k)
        q, k = self.rotary(q), self.rotary(k)
        # FlexAttention expects (batch, heads, seq, head_dim)
        attn_out = flex_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            block_mask=block_mask,
        )
        # back to (batch, seq, heads, head_dim), then merge heads into x's shape
        merged = attn_out.transpose(1, 2).contiguous().view_as(x)
        return self.c_proj(merged)
class MLP(nn.Module):
    """Feed-forward block dim -> 4*dim -> dim with a squared-ReLU activation."""

    def __init__(self, dim):
        super().__init__()
        hidden = 4 * dim
        self.c_fc = CastedLinear(dim, hidden)
        self.c_proj = CastedLinear(hidden, dim)
        # zero init suggested by @Grad62304977: the block outputs zeros at initialization
        self.c_proj.weight.data.zero_()

    def forward(self, x):
        # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
        hidden = F.relu(self.c_fc(x)).square()
        return self.c_proj(hidden)
class Block(nn.Module):
    """One transformer layer: blend in the initial embedding x0, then pre-norm
    attention and MLP residual updates."""

    def __init__(self, config):
        super().__init__()
        self.attn = CausalSelfAttention(config.n_embd, config.n_head)
        self.mlp = MLP(config.n_embd)
        # lambdas[0] weights the running residual stream, lambdas[1] weights x0
        self.lambdas = nn.Parameter(torch.tensor([1., 0.]))

    def forward(self, x, vi, x0, block_mask):
        keep, inject = self.lambdas[0], self.lambdas[1]
        x = keep * x + inject * x0
        x = x + self.attn(norm(x), vi, block_mask)
        return x + self.mlp(norm(x))
# -----------------------------------------------------------------------------
# The main GPT-2 model
@dataclass
class GPTConfig:
    """Architecture hyperparameters for the GPT model."""
    vocab_size: int = 50304  # 50257 GPT-2 tokens padded up to a multiple of 128
    n_layer: int = 12
    n_head: int = 6  # 768 / 6 gives head dim 128, suggested by @Grad62304977
    n_embd: int = 768
class GPT(nn.Module):
    """GPT language model with a U-Net layer layout and per-layer token value embeddings.

    The first ``n_layer // 2`` blocks act as an encoder whose outputs are stored.
    The remaining blocks consume those outputs in reverse order, scaled by
    learned skip weights. ``vte`` holds one ``n_embd``-wide value embedding per
    block.
    """
    def __init__(self, config):
        super().__init__()
        self.n_layer = config.n_layer
        # U-net design by @brendanh0gan
        self.num_encoder_layers = config.n_layer // 2  # Half of the layers for encoder
        self.num_decoder_layers = config.n_layer - self.num_encoder_layers  # Remaining for decoder
        # One learnable skip weight per stored encoder output. For even n_layer this equals
        # num_decoder_layers (the original shape). For odd n_layer the last decoder block has
        # no encoder output left to consume, and a weight for it would never get a gradient.
        self.skip_weights = nn.Parameter(torch.ones(self.num_encoder_layers))
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual learning
            # (one n_embd-wide slice per block; previously hard-coded to 12 blocks)
            vte = nn.Embedding(config.vocab_size, config.n_embd*config.n_layer),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
        ))
        self.lm_head = CastedLinear(config.n_embd, config.vocab_size)
        self.lm_head.weight.data.zero_() # @Grad62304977

    def forward(self, idx, target, attn_blocksize):
        """Return the mean next-token cross-entropy loss.

        Args:
            idx: token ids, batched as ``idx[None]``
                (NOTE(review): presumably 1-D of length S — confirm with caller).
            target: target ids, flattened with ``target.view(-1)``.
            attn_blocksize: sliding-window width; a query only attends to keys
                fewer than this many positions behind it.
        """
        # Each position's document id is the number of 50256 tokens at or before it
        # (presumably GPT-2's <|endoftext|>); attention never crosses documents.
        docs = (idx == 50256).cumsum(0)
        def document_causal_mask(b, h, q_idx, kv_idx):
            causal_mask = q_idx >= kv_idx
            document_mask = docs[q_idx] == docs[kv_idx]
            window_mask = q_idx - kv_idx < attn_blocksize
            return causal_mask & document_mask & window_mask
        S = len(idx)
        block_mask = create_block_mask(document_causal_mask, None, None, S, S, device="cuda", _compile=True)
        # forward the GPT model itself
        x = self.transformer.wte(idx[None]) # token embeddings of shape (b, t, n_embd)
        x = norm(x) # @Grad62304977
        x0 = x
        # one value-embedding slice per block
        vi = self.transformer.vte(idx[None]).chunk(self.n_layer, dim=-1)
        # Store outputs for U-Net skip connections
        skip_connections = []
        # Encoder pass - process only the first half of the blocks
        for i in range(self.num_encoder_layers):
            x = self.transformer.h[i](x, vi[i], x0, block_mask)
            skip_connections.append(x)
        # Decoder pass - process the remaining blocks with weighted skip connections
        for i in range(self.num_decoder_layers):
            # With odd n_layer the decoder has one more block than the encoder; once the
            # stored outputs are used up, continue without a skip instead of popping an empty list.
            if skip_connections:
                x = x + self.skip_weights[i] * skip_connections.pop()
            layer = self.num_encoder_layers + i
            x = self.transformer.h[layer](x, vi[layer], x0, block_mask)
        x = norm(x)
        logits = self.lm_head(x)
        logits = 30 * torch.tanh(logits / 30) # soft-cap logits to (-30, 30), @Grad62304977
        logits = logits.float()
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target.view(-1))
        return loss
# -----------------------------------------------------------------------------
# Our own simple Distributed Data Loader
def _peek_data_shard(filename):
# only reads the header, returns header data
with open(filename, "rb") as f:
# first read the header, which is 256 int32 integers (4 bytes each)
header = np.frombuffer(f.read(256*4), dtype=np.int32)
if header[0] != 20240520:
print("ERROR: magic number mismatch in the data .bin file!")
print("---> HINT: Are you passing in a correct file with --input_bin?")
print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README")
print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try")
exit(1)
assert header[1] == 1, "unsupported version"
ntok = header[2] # number of tokens (claimed)
return ntok # for now just return the number of tokens
def _load_data_shard(filename):
with open(filename, "rb") as f:
# first read the header, which is 256 int32 integers (4 bytes each)
header = np.frombuffer(f.read(256*4), dtype=np.int32)
assert header[0] == 20240520, "magic number mismatch in the data .bin file"
assert header[1] == 1, "unsupported version"
ntok = header[2] # number of tokens (claimed)
# the rest of it are tokens, stored as uint16
tokens = np.frombuffer(f.read(), dtype=np.uint16)
assert len(tokens) == ntok, "number of tokens read does not match header?"
return tokens
class DistributedDataLoader:
    """Sequential reader over token shards, split across ranks.

    Each call to next_batch() consumes one global batch of ``T * num_processes``
    tokens from the current shard. This rank takes the ``T + 1`` tokens starting
    at ``process_rank * T`` within that batch, giving T inputs and T shifted
    targets. Shards are visited in sorted filename order and cycled forever.
    """
    def __init__(self, filename_pattern, T, process_rank, num_processes):
        self.process_rank = process_rank
        self.num_processes = num_processes
        self.T = T
        # glob files that match the pattern
        self.files = sorted(glob.glob(filename_pattern))
        assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"
        # validate all data shards' headers and count the total number of tokens
        ntok_total = 0
        for fname in self.files:
            shard_ntok = _peek_data_shard(fname)
            # every shard must hold at least one full global batch plus the final target
            assert shard_ntok >= num_processes * T + 1
            ntok_total += int(shard_ntok)
        self.ntok_total = ntok_total
        self.reset()

    def reset(self):
        """Restart from the first shard."""
        self.current_shard = -1
        self.advance()

    def advance(self): # advance to next data shard
        """Load the next shard (wrapping around) and rewind to this rank's offset."""
        self.current_shard = (self.current_shard + 1) % len(self.files)
        self.current_position = self.process_rank * self.T
        self.tokens = _load_data_shard(self.files[self.current_shard])

    def next_batch(self):
        """Return ``(x, y)``: T input ids and T next-token targets, as CUDA int64 tensors."""
        batch_size = self.T * self.num_processes
        buf = self.tokens[self.current_position:self.current_position+self.T+1]
        buf = torch.tensor(buf.astype(np.int32), dtype=torch.long)
        x = buf[:-1] # inputs
        y = buf[1:] # targets
        # advance current position and load next shard if necessary
        self.current_position += batch_size
        # Decide using the rank-independent start of the next global batch, so all ranks
        # switch shards on the same step. (Using this rank's own offset let higher ranks
        # switch one step early.) The last rank's slice ends at global_start + batch_size + 1.
        global_start = self.current_position - self.process_rank * self.T
        if global_start + batch_size + 1 > len(self.tokens):
            self.advance()
        return x.cuda(), y.cuda()
# -----------------------------------------------------------------------------
# int main
@dataclass
class Hyperparameters:
    """Run configuration: data locations, optimization schedule, eval/checkpoint cadence."""
    # --- data ---
    input_bin: str = 'data/fineweb10B/fineweb_train_*.bin'  # glob of training shards
    input_val_bin: str = 'data/fineweb10B/fineweb_val_*.bin'  # glob of validation shards
    # --- optimization ---
    batch_size: int = 8  # global batch size in sequences, summed over all devices
    sequence_length: int = 64*1024  # tokens per sequence
    num_iterations: int = 1530  # total optimizer steps
    warmup_iters: int = 0  # length of the linear LR warmup, in steps
    cooldown_iters: int = 600  # length of the final linear LR cooldown, in steps
    weight_decay: float = 0  # NOTE(review): not read by the training code in this script
    # --- evaluation and logging ---
    val_loss_every: int = 125  # evaluate val loss every N steps; 0 = only at the end
    val_tokens: int = 10485760  # validation tokens; keep fixed for comparable results
    save_every: int = 0  # checkpoint every N steps; 0 = only at the end
args = Hyperparameters()
# set up DDP (distributed data parallel). torchrun sets this env variable
# (RANK / LOCAL_RANK / WORLD_SIZE are read below; a missing one raises KeyError).
assert torch.cuda.is_available()
dist.init_process_group(backend='nccl')
ddp_rank = int(os.environ['RANK'])
ddp_local_rank = int(os.environ['LOCAL_RANK'])
ddp_world_size = int(os.environ['WORLD_SIZE'])
# one process per GPU: bind this process to the GPU matching its local rank
device = f'cuda:{ddp_local_rank}'
torch.cuda.set_device(device)
print(f"using device: {device}")  # printed by every rank
master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc.
# begin logging
logfile = None  # stays None on non-master ranks; print0 only touches it on the master
if master_process:
    run_id = str(uuid.uuid4())
    # per-run directory; checkpoints are later saved under logs/<run_id>/
    logdir = 'logs/%s/' % run_id
    os.makedirs(logdir, exist_ok=True)
    logfile = 'logs/%s.txt' % run_id
    # create the log file
    with open(logfile, "w") as f:
        # begin the log by printing this file (the Python code)
        f.write(code)
        f.write('='*100 + '\n')
def print0(s, logonly=False):
    """On the master rank, append ``s`` to the run log and echo it to stdout unless ``logonly``.

    On every other rank this does nothing.
    """
    if not master_process:
        return
    with open(logfile, "a") as log_fh:
        if not logonly:
            print(s)
        log_fh.write(s + '\n')
# log information about the hardware/software environment this is running on
# and print the full `nvidia-smi` to file
print0(f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:")
import subprocess
result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
print0(f'{result.stdout}', logonly=True)
print0('='*100, logonly=True)
# convenience variables
T = args.sequence_length
# calculate the number of steps to take in the val loop.
# Each val step reads one T-token sequence per rank, so val_tokens must split evenly.
assert args.val_tokens % (T * ddp_world_size) == 0
val_steps = args.val_tokens // (T * ddp_world_size)
# calculate the steps of gradient accumulation required to attain the desired global batch size.
# Each forward processes a single T-token sequence per rank (the attention module asserts batch 1).
assert args.batch_size % (ddp_world_size) == 0
train_accumulation_steps = args.batch_size // ddp_world_size
# load tokens
train_loader = DistributedDataLoader(args.input_bin, T, ddp_rank, ddp_world_size)
val_loader = DistributedDataLoader(args.input_val_bin, T, ddp_rank, ddp_world_size)
print0(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
print0(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")
print0('='*100, logonly=True)
# prefetch the first training batch; the training loop fetches the next one after each forward
x, y = train_loader.next_batch()
# there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. suggested to me by @Grad62304977.
# this originates from Karpathy's experiments.
num_vocab = 50304
model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=12, n_head=6, n_embd=768))
# Module.bfloat16() casts all floating-point parameters AND buffers to bf16...
model = model.cuda().bfloat16()
# ...then CastedLinear layers are cast back so their weights are stored in fp32
# (they cast to the activation dtype at call time).
for m in model.modules():
    if isinstance(m, CastedLinear):
        m.float()
if hasattr(config, "coordinate_descent_tuning"):
    config.coordinate_descent_tuning = True # suggested by @Chillee
model = torch.compile(model)
# here we wrap model into DDP container
model = DDP(model, device_ids=[ddp_local_rank])
raw_model = model.module # always contains the "raw" unwrapped model
# init the optimizer(s)
# - Adam on the token and value embeddings (lr 0.6)
optimizer1 = torch.optim.Adam([raw_model.transformer.wte.weight, raw_model.transformer.vte.weight], lr=0.6, betas=(0.8, 0.95), fused=True)
# - Adam on the output head (lr 0.008)
optimizer2 = torch.optim.Adam([raw_model.lm_head.weight], lr=0.008, betas=(0.8, 0.95), fused=True)
# - transformer-block parameters split by rank: 2D matrices go to Muon,
#   0-D/1-D scalars (plus the U-Net skip weights) go to Adam
params = list(raw_model.transformer.h.parameters())
matrix_params = [p for p in params if p.ndim == 2]
scalar_params = [p for p in params if p.ndim < 2] + [raw_model.skip_weights]
optimizer3 = Muon(matrix_params, lr=0.05, momentum=0.95)
optimizer4 = torch.optim.Adam(scalar_params, lr=0.04, betas=(0.8, 0.95), fused=True) # note that this learning rate is neither sensitive nor tuned
# the training loop adjusts optimizer3's momentum by index, so keep this order
optimizers = [optimizer1, optimizer2, optimizer3, optimizer4]
# learning rate decay scheduler (linear warmup and cooldown)
def get_lr(it):
    """Return the LR multiplier for iteration ``it`` (used as a LambdaLR lambda).

    Schedule: linear warmup over ``args.warmup_iters`` steps, then a constant
    1.0, then a linear decay to 0.0 over the final ``args.cooldown_iters``
    steps. ``it == args.num_iterations`` always yields 0.0.
    """
    assert it <= args.num_iterations
    # 1) linear warmup for warmup_iters steps
    if it < args.warmup_iters:
        return (it+1) / args.warmup_iters
    # 2) constant lr for a while
    elif it < args.num_iterations - args.cooldown_iters:
        return 1.0
    # 3) linear cooldown
    else:
        # max(..., 1) prevents ZeroDivisionError when cooldown_iters == 0. In that case this
        # branch is only reached at it == num_iterations (LambdaLR evaluates it on the final
        # scheduler step), where the numerator is 0 anyway.
        decay_ratio = (args.num_iterations - it) / max(args.cooldown_iters, 1)
        return decay_ratio
schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers]
# Start training loop
training_time_ms = 0
# start the clock
torch.cuda.synchronize()
t0 = time.time()
# begin training
for step in range(args.num_iterations + 1):
last_step = (step == args.num_iterations)
# This effectively ignores timing first 10 steps, which are slower for weird reasons.
# Alternately, and slightly more correctly in terms of benchmarking, we could do 10
# steps with dummy data first, and then re-initialize the model and reset the loader.
if step == 10:
training_time_ms = 0
t0 = time.time()
timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val
# Set the attention blocksize for the current step, in chunks of 64. By @fernbear.bsky.social
attn_blocksize = torch.tensor(64*((step/args.num_iterations * (1792 - 64) + 64)//64), dtype=torch.int, device='cuda')
# once in a while evaluate the validation dataset
if (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)):
# stop the clock
torch.cuda.synchronize()
training_time_ms += 1000 * (time.time() - t0)
# run validation batches
model.eval()
val_loader.reset()
val_loss = 0.0
for _ in range(val_steps):
with torch.no_grad():
x_val, y_val = val_loader.next_batch()
val_loss += model(x_val, y_val, attn_blocksize=attn_blocksize)
dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
val_loss /= val_steps
# log val loss to console and to logfile
print0(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms')
# start the clock again
torch.cuda.synchronize()
t0 = time.time()
if master_process and (last_step or (args.save_every > 0 and step % args.save_every == 0)):
# stop the clock
torch.cuda.synchronize()
training_time_ms += 1000 * (time.time() - t0)
# save the state of the training process
log = dict(step=step, code=code, model=raw_model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
torch.save(log, 'logs/%s/state_step%06d.pt' % (run_id, step))
# start the clock again
torch.cuda.synchronize()
t0 = time.time()
# bit confusing: we want to make sure to eval on 0th iteration
# but also after the very last iteration. so we loop for step <= num_iterations
# instead of just < num_iterations (one extra due to <=), only to do
# the validation/sampling one last time, and then we break right here as we're done.
if last_step:
break
# --------------- TRAINING SECTION BEGIN -----------------
model.train()
for i in range(1, train_accumulation_steps+1):
ctx = model.no_sync() if i < train_accumulation_steps else contextlib.nullcontext()
with ctx: # there's no need to sync gradients every accumulation step
# forward pass
loss = model(x, y, attn_blocksize=attn_blocksize)
# advance the dataset for the next batch
x, y = train_loader.next_batch()
# backward pass
loss.backward()
train_loss = loss.detach()
for p in model.parameters():
p.grad /= train_accumulation_steps
# momentum warmup for Muon
frac = min(step/300, 1)
optimizer3.param_groups[0]['momentum'] = (1 - frac) * 0.85 + frac * 0.95
# step the optimizers and schedulers
for opt, sched in zip(optimizers, schedulers):
opt.step()
sched.step()
# null the gradients
model.zero_grad(set_to_none=True)
# --------------- TRAINING SECTION END -------------------
# everything that follows now is just diagnostics, prints, logging, etc.
#dist.all_reduce(train_loss, op=dist.ReduceOp.AVG) # all-reducing the training loss would be more correct in terms of logging, but slower
approx_time = training_time_ms + 1000 * (time.time() - t0)
print0(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms")
# Report the peak CUDA memory allocated by tensors in this process, converted from bytes to MiB.
# Printed only on the master process (presumably to avoid one duplicate line per rank — confirm).
if master_process:
print(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")
# -------------------------------------------------------------------------
# clean up: tear down the torch.distributed process group before the script exits
dist.destroy_process_group()
====================================================================================================
Running pytorch 2.6.0.dev20241203+cu124 compiled for CUDA 12.4
nvidia-smi:
Thu Dec 5 01:22:55 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.06 Driver Version: 535.183.06 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 |
| N/A 39C P0 76W / 700W | 3MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 |
| N/A 31C P0 115W / 700W | 115MiB / 81559MiB | 1% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 |
| N/A 32C P0 91W / 700W | 22MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 |
| N/A 39C P0 119W / 700W | 529MiB / 81559MiB | 1% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 |
| N/A 40C P0 124W / 700W | 529MiB / 81559MiB | 1% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 |
| N/A 30C P0 110W / 700W | 529MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 |
| N/A 39C P0 115W / 700W | 22MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 |
| N/A 31C P0 119W / 700W | 529MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
+---------------------------------------------------------------------------------------+
====================================================================================================
Training DataLoader: total number of tokens: 1100000000 across 11 files
Validation DataLoader: total number of tokens: 100000000 across 1 files
====================================================================================================
step:0/1530 val_loss:10.8258 train_time:0ms step_avg:nanms
step:1/1530 train_loss:10.8258 train_time:31693ms step_avg:nanms
step:2/1530 train_loss:10.0716 train_time:31805ms step_avg:nanms
step:3/1530 train_loss:8.3640 train_time:31962ms step_avg:nanms
step:4/1530 train_loss:7.6072 train_time:32123ms step_avg:nanms
step:5/1530 train_loss:7.4646 train_time:32283ms step_avg:nanms
step:6/1530 train_loss:6.9752 train_time:32444ms step_avg:nanms
step:7/1530 train_loss:7.2159 train_time:32604ms step_avg:nanms
step:8/1530 train_loss:6.7513 train_time:32764ms step_avg:nanms
step:9/1530 train_loss:6.6330 train_time:32924ms step_avg:nanms
step:10/1530 train_loss:6.5159 train_time:33083ms step_avg:nanms
step:11/1530 train_loss:6.4850 train_time:114ms step_avg:nanms
step:12/1530 train_loss:6.3626 train_time:273ms step_avg:nanms
step:13/1530 train_loss:6.2707 train_time:432ms step_avg:144.12ms
step:14/1530 train_loss:6.2085 train_time:592ms step_avg:148.09ms
step:15/1530 train_loss:6.1819 train_time:752ms step_avg:150.47ms
step:16/1530 train_loss:6.1151 train_time:912ms step_avg:152.01ms
step:17/1530 train_loss:6.1863 train_time:1071ms step_avg:153.04ms
step:18/1530 train_loss:5.9722 train_time:1232ms step_avg:153.96ms
step:19/1530 train_loss:6.0243 train_time:1392ms step_avg:154.63ms
step:20/1530 train_loss:5.6759 train_time:1552ms step_avg:155.17ms
step:21/1530 train_loss:5.9576 train_time:1712ms step_avg:155.60ms
step:22/1530 train_loss:6.1905 train_time:1871ms step_avg:155.94ms
step:23/1530 train_loss:5.8890 train_time:2031ms step_avg:156.23ms
step:24/1530 train_loss:6.0436 train_time:2191ms step_avg:156.50ms
step:25/1530 train_loss:5.6864 train_time:2351ms step_avg:156.72ms
step:26/1530 train_loss:5.6168 train_time:2511ms step_avg:156.94ms
step:27/1530 train_loss:5.8335 train_time:2671ms step_avg:157.10ms
step:28/1530 train_loss:5.4210 train_time:2831ms step_avg:157.27ms
step:29/1530 train_loss:5.6753 train_time:2991ms step_avg:157.41ms
step:30/1530 train_loss:5.4833 train_time:3150ms step_avg:157.52ms
step:31/1530 train_loss:5.4632 train_time:3311ms step_avg:157.66ms
step:32/1530 train_loss:5.2912 train_time:3471ms step_avg:157.77ms
step:33/1530 train_loss:5.5987 train_time:3631ms step_avg:157.88ms
step:34/1530 train_loss:5.5084 train_time:3791ms step_avg:157.95ms
step:35/1530 train_loss:5.6359 train_time:3950ms step_avg:158.01ms
step:36/1530 train_loss:5.5582 train_time:4111ms step_avg:158.11ms
step:37/1530 train_loss:5.4651 train_time:4271ms step_avg:158.19ms
step:38/1530 train_loss:5.3144 train_time:4431ms step_avg:158.25ms
step:39/1530 train_loss:5.3393 train_time:4591ms step_avg:158.32ms
step:40/1530 train_loss:5.2474 train_time:4751ms step_avg:158.37ms
step:41/1530 train_loss:5.2377 train_time:4911ms step_avg:158.42ms
step:42/1530 train_loss:5.1697 train_time:5071ms step_avg:158.46ms
step:43/1530 train_loss:5.2533 train_time:5231ms step_avg:158.51ms
step:44/1530 train_loss:5.2163 train_time:5391ms step_avg:158.57ms
step:45/1530 train_loss:5.3696 train_time:5551ms step_avg:158.60ms
step:46/1530 train_loss:5.1700 train_time:5711ms step_avg:158.65ms
step:47/1530 train_loss:5.0684 train_time:5872ms step_avg:158.70ms
step:48/1530 train_loss:5.2080 train_time:6032ms step_avg:158.72ms
step:49/1530 train_loss:5.1580 train_time:6191ms step_avg:158.74ms
step:50/1530 train_loss:5.2668 train_time:6351ms step_avg:158.77ms
step:51/1530 train_loss:5.1550 train_time:6512ms step_avg:158.82ms
step:52/1530 train_loss:5.0508 train_time:6671ms step_avg:158.83ms
step:53/1530 train_loss:5.2036 train_time:6831ms step_avg:158.87ms
step:54/1530 train_loss:5.0361 train_time:6992ms step_avg:158.91ms
step:55/1530 train_loss:5.4162 train_time:7151ms step_avg:158.91ms
step:56/1530 train_loss:5.0331 train_time:7312ms step_avg:158.95ms
step:57/1530 train_loss:4.9016 train_time:7472ms step_avg:158.97ms
step:58/1530 train_loss:5.0534 train_time:7631ms step_avg:158.99ms
step:59/1530 train_loss:5.0364 train_time:7791ms step_avg:159.00ms
step:60/1530 train_loss:5.1545 train_time:7951ms step_avg:159.02ms
step:61/1530 train_loss:4.8509 train_time:8111ms step_avg:159.05ms
step:62/1530 train_loss:4.9762 train_time:8271ms step_avg:159.07ms
step:63/1530 train_loss:4.9827 train_time:8432ms step_avg:159.10ms
step:64/1530 train_loss:4.9842 train_time:8592ms step_avg:159.12ms
step:65/1530 train_loss:4.8165 train_time:8752ms step_avg:159.13ms
step:66/1530 train_loss:4.9254 train_time:8912ms step_avg:159.14ms
step:67/1530 train_loss:4.8172 train_time:9072ms step_avg:159.16ms
step:68/1530 train_loss:5.0865 train_time:9233ms step_avg:159.18ms
step:69/1530 train_loss:4.7258 train_time:9393ms step_avg:159.20ms
step:70/1530 train_loss:4.8372 train_time:9553ms step_avg:159.21ms
step:71/1530 train_loss:4.9604 train_time:9713ms step_avg:159.22ms
step:72/1530 train_loss:4.8785 train_time:9873ms step_avg:159.23ms
step:73/1530 train_loss:4.7724 train_time:10032ms step_avg:159.24ms
step:74/1530 train_loss:4.9081 train_time:10192ms step_avg:159.25ms
step:75/1530 train_loss:4.8682 train_time:10352ms step_avg:159.26ms
step:76/1530 train_loss:4.7937 train_time:10512ms step_avg:159.27ms
step:77/1530 train_loss:4.9216 train_time:10672ms step_avg:159.29ms
step:78/1530 train_loss:5.1221 train_time:10832ms step_avg:159.29ms
step:79/1530 train_loss:4.8564 train_time:10992ms step_avg:159.30ms
step:80/1530 train_loss:4.8712 train_time:11151ms step_avg:159.30ms
step:81/1530 train_loss:4.6424 train_time:11312ms step_avg:159.32ms
step:82/1530 train_loss:4.8056 train_time:11471ms step_avg:159.33ms
step:83/1530 train_loss:4.7831 train_time:11631ms step_avg:159.34ms
step:84/1530 train_loss:4.7706 train_time:11792ms step_avg:159.35ms
step:85/1530 train_loss:4.6206 train_time:11952ms step_avg:159.35ms
step:86/1530 train_loss:4.8249 train_time:12112ms step_avg:159.37ms
step:87/1530 train_loss:4.7597 train_time:12272ms step_avg:159.37ms
step:88/1530 train_loss:4.7753 train_time:12432ms step_avg:159.38ms
step:89/1530 train_loss:4.7057 train_time:12593ms step_avg:159.40ms
step:90/1530 train_loss:4.6340 train_time:12752ms step_avg:159.40ms
step:91/1530 train_loss:4.6276 train_time:12912ms step_avg:159.41ms
step:92/1530 train_loss:4.7967 train_time:13073ms step_avg:159.42ms
step:93/1530 train_loss:4.6290 train_time:13233ms step_avg:159.43ms
step:94/1530 train_loss:4.6368 train_time:13392ms step_avg:159.43ms
step:95/1530 train_loss:4.7006 train_time:13552ms step_avg:159.44ms
step:96/1530 train_loss:4.6010 train_time:13713ms step_avg:159.46ms
step:97/1530 train_loss:4.6431 train_time:13872ms step_avg:159.45ms
step:98/1530 train_loss:4.5800 train_time:14032ms step_avg:159.46ms
step:99/1530 train_loss:4.6785 train_time:14193ms step_avg:159.47ms
step:100/1530 train_loss:4.6969 train_time:14353ms step_avg:159.47ms
step:101/1530 train_loss:4.5520 train_time:14512ms step_avg:159.47ms
step:102/1530 train_loss:4.7105 train_time:14672ms step_avg:159.48ms
step:103/1530 train_loss:4.5959 train_time:14833ms step_avg:159.49ms
step:104/1530 train_loss:4.5335 train_time:14993ms step_avg:159.50ms
step:105/1530 train_loss:4.5574 train_time:15152ms step_avg:159.49ms
step:106/1530 train_loss:4.6214 train_time:15313ms step_avg:159.51ms
step:107/1530 train_loss:4.5140 train_time:15473ms step_avg:159.51ms
step:108/1530 train_loss:4.3939 train_time:15632ms step_avg:159.51ms
step:109/1530 train_loss:4.5198 train_time:15793ms step_avg:159.53ms
step:110/1530 train_loss:4.4962 train_time:15953ms step_avg:159.53ms
step:111/1530 train_loss:4.4384 train_time:16113ms step_avg:159.54ms
step:112/1530 train_loss:4.5991 train_time:16273ms step_avg:159.54ms
step:113/1530 train_loss:4.4899 train_time:16433ms step_avg:159.54ms
step:114/1530 train_loss:4.3604 train_time:16594ms step_avg:159.55ms
step:115/1530 train_loss:4.5019 train_time:16755ms step_avg:159.58ms
step:116/1530 train_loss:4.4721 train_time:16921ms step_avg:159.63ms
step:117/1530 train_loss:4.3718 train_time:17086ms step_avg:159.68ms
step:118/1530 train_loss:4.5885 train_time:17249ms step_avg:159.72ms
step:119/1530 train_loss:4.4511 train_time:17414ms step_avg:159.76ms
step:120/1530 train_loss:4.3327 train_time:17577ms step_avg:159.79ms
step:121/1530 train_loss:4.2884 train_time:17740ms step_avg:159.82ms
step:122/1530 train_loss:4.4532 train_time:17905ms step_avg:159.86ms
step:123/1530 train_loss:4.3069 train_time:18068ms step_avg:159.89ms
step:124/1530 train_loss:4.5917 train_time:18232ms step_avg:159.93ms
step:125/1530 train_loss:4.4524 train_time:18396ms step_avg:159.96ms
step:125/1530 val_loss:4.4079 train_time:18442ms step_avg:160.37ms
step:126/1530 train_loss:4.4161 train_time:18560ms step_avg:160.00ms
step:127/1530 train_loss:4.4308 train_time:18726ms step_avg:160.05ms
step:128/1530 train_loss:4.3783 train_time:18890ms step_avg:160.09ms
step:129/1530 train_loss:4.6765 train_time:19054ms step_avg:160.12ms
step:130/1530 train_loss:4.3682 train_time:19218ms step_avg:160.15ms
step:131/1530 train_loss:4.3884 train_time:19382ms step_avg:160.18ms
step:132/1530 train_loss:4.3347 train_time:19547ms step_avg:160.22ms
step:133/1530 train_loss:4.4381 train_time:19711ms step_avg:160.25ms
step:134/1530 train_loss:4.2677 train_time:19874ms step_avg:160.28ms
step:135/1530 train_loss:4.4457 train_time:20038ms step_avg:160.30ms
step:136/1530 train_loss:4.2155 train_time:20202ms step_avg:160.33ms
step:137/1530 train_loss:4.3767 train_time:20365ms step_avg:160.36ms
step:138/1530 train_loss:4.2963 train_time:20529ms step_avg:160.38ms
step:139/1530 train_loss:4.3863 train_time:20694ms step_avg:160.42ms
step:140/1530 train_loss:4.4768 train_time:20856ms step_avg:160.43ms
step:141/1530 train_loss:4.3081 train_time:21019ms step_avg:160.45ms
step:142/1530 train_loss:4.3061 train_time:21184ms step_avg:160.48ms
step:143/1530 train_loss:4.2598 train_time:21347ms step_avg:160.50ms
step:144/1530 train_loss:4.3511 train_time:21511ms step_avg:160.53ms
step:145/1530 train_loss:4.3062 train_time:21674ms step_avg:160.55ms
step:146/1530 train_loss:4.1704 train_time:21837ms step_avg:160.56ms
step:147/1530 train_loss:4.3176 train_time:22001ms step_avg:160.59ms
step:148/1530 train_loss:4.3627 train_time:22166ms step_avg:160.62ms
step:149/1530 train_loss:4.3044 train_time:22330ms step_avg:160.65ms
step:150/1530 train_loss:4.4486 train_time:22494ms step_avg:160.67ms
step:151/1530 train_loss:4.2752 train_time:22657ms step_avg:160.69ms
step:152/1530 train_loss:4.2786 train_time:22820ms step_avg:160.70ms
step:153/1530 train_loss:4.3546 train_time:22985ms step_avg:160.73ms
step:154/1530 train_loss:4.3672 train_time:23149ms step_avg:160.76ms
step:155/1530 train_loss:4.2634 train_time:23314ms step_avg:160.78ms
step:156/1530 train_loss:4.3444 train_time:23477ms step_avg:160.80ms
step:157/1530 train_loss:4.3973 train_time:23641ms step_avg:160.82ms
step:158/1530 train_loss:4.2389 train_time:23803ms step_avg:160.83ms
step:159/1530 train_loss:4.3096 train_time:23967ms step_avg:160.85ms
step:160/1530 train_loss:4.1350 train_time:24132ms step_avg:160.88ms
step:161/1530 train_loss:4.3477 train_time:24295ms step_avg:160.90ms
step:162/1530 train_loss:4.3435 train_time:24458ms step_avg:160.91ms
step:163/1530 train_loss:4.3335 train_time:24621ms step_avg:160.92ms
step:164/1530 train_loss:4.1900 train_time:24785ms step_avg:160.94ms
step:165/1530 train_loss:4.2863 train_time:24949ms step_avg:160.96ms
step:166/1530 train_loss:4.3292 train_time:25113ms step_avg:160.98ms
step:167/1530 train_loss:4.1965 train_time:25276ms step_avg:160.99ms
step:168/1530 train_loss:4.2826 train_time:25439ms step_avg:161.01ms
step:169/1530 train_loss:4.1588 train_time:25603ms step_avg:161.03ms
step:170/1530 train_loss:4.0147 train_time:25766ms step_avg:161.04ms
step:171/1530 train_loss:4.1926 train_time:25930ms step_avg:161.05ms
step:172/1530 train_loss:4.2028 train_time:26093ms step_avg:161.07ms
step:173/1530 train_loss:4.2730 train_time:26255ms step_avg:161.07ms
step:174/1530 train_loss:4.4148 train_time:26418ms step_avg:161.08ms
step:175/1530 train_loss:4.2281 train_time:26581ms step_avg:161.10ms
step:176/1530 train_loss:4.0900 train_time:26742ms step_avg:161.09ms
step:177/1530 train_loss:4.0697 train_time:26906ms step_avg:161.11ms
step:178/1530 train_loss:4.1810 train_time:27069ms step_avg:161.12ms
step:179/1530 train_loss:4.1124 train_time:27232ms step_avg:161.14ms
step:180/1530 train_loss:4.1076 train_time:27395ms step_avg:161.15ms
step:181/1530 train_loss:4.2882 train_time:27557ms step_avg:161.15ms
step:182/1530 train_loss:4.1546 train_time:27719ms step_avg:161.16ms
step:183/1530 train_loss:4.1144 train_time:27882ms step_avg:161.17ms
step:184/1530 train_loss:4.1210 train_time:28045ms step_avg:161.18ms
step:185/1530 train_loss:4.1971 train_time:28207ms step_avg:161.18ms
step:186/1530 train_loss:4.1669 train_time:28371ms step_avg:161.20ms
step:187/1530 train_loss:4.2269 train_time:28534ms step_avg:161.21ms
step:188/1530 train_loss:4.1578 train_time:28833ms step_avg:161.99ms
step:189/1530 train_loss:4.1083 train_time:29156ms step_avg:162.88ms
step:190/1530 train_loss:4.1989 train_time:29317ms step_avg:162.87ms
step:191/1530 train_loss:4.0734 train_time:29481ms step_avg:162.88ms
step:192/1530 train_loss:4.0245 train_time:29644ms step_avg:162.88ms
step:193/1530 train_loss:4.2496 train_time:29808ms step_avg:162.88ms
step:194/1530 train_loss:4.1719 train_time:29971ms step_avg:162.88ms
step:195/1530 train_loss:4.3403 train_time:30134ms step_avg:162.89ms
step:196/1530 train_loss:4.1600 train_time:30297ms step_avg:162.88ms
step:197/1530 train_loss:4.0429 train_time:30460ms step_avg:162.89ms
step:198/1530 train_loss:4.1756 train_time:30622ms step_avg:162.88ms
step:199/1530 train_loss:4.0249 train_time:30786ms step_avg:162.89ms
step:200/1530 train_loss:4.1021 train_time:30948ms step_avg:162.88ms
step:201/1530 train_loss:3.9988 train_time:31112ms step_avg:162.89ms
step:202/1530 train_loss:4.2392 train_time:31274ms step_avg:162.89ms
step:203/1530 train_loss:4.0533 train_time:31436ms step_avg:162.88ms
step:204/1530 train_loss:4.1802 train_time:31599ms step_avg:162.88ms
step:205/1530 train_loss:4.2389 train_time:31761ms step_avg:162.88ms
step:206/1530 train_loss:3.9357 train_time:31922ms step_avg:162.87ms
step:207/1530 train_loss:4.0711 train_time:32087ms step_avg:162.88ms
step:208/1530 train_loss:4.0972 train_time:32250ms step_avg:162.88ms
step:209/1530 train_loss:4.2231 train_time:32413ms step_avg:162.88ms
step:210/1530 train_loss:4.1745 train_time:32575ms step_avg:162.88ms
step:211/1530 train_loss:4.0506 train_time:32737ms step_avg:162.87ms
step:212/1530 train_loss:4.1086 train_time:32899ms step_avg:162.87ms
step:213/1530 train_loss:4.0432 train_time:33062ms step_avg:162.86ms
step:214/1530 train_loss:4.1023 train_time:33225ms step_avg:162.87ms
step:215/1530 train_loss:3.9453 train_time:33389ms step_avg:162.87ms
step:216/1530 train_loss:3.9916 train_time:33550ms step_avg:162.87ms
step:217/1530 train_loss:3.9972 train_time:33713ms step_avg:162.87ms
step:218/1530 train_loss:4.0750 train_time:33875ms step_avg:162.86ms
step:219/1530 train_loss:4.0617 train_time:34038ms step_avg:162.86ms
step:220/1530 train_loss:4.0759 train_time:34202ms step_avg:162.86ms
step:221/1530 train_loss:4.0832 train_time:34365ms step_avg:162.87ms
step:222/1530 train_loss:3.9901 train_time:34527ms step_avg:162.86ms
step:223/1530 train_loss:3.9881 train_time:34691ms step_avg:162.87ms
step:224/1530 train_loss:4.3029 train_time:34853ms step_avg:162.86ms
step:225/1530 train_loss:3.9249 train_time:35016ms step_avg:162.86ms
step:226/1530 train_loss:3.9741 train_time:35178ms step_avg:162.86ms
step:227/1530 train_loss:3.9717 train_time:35340ms step_avg:162.86ms
step:228/1530 train_loss:4.1305 train_time:35506ms step_avg:162.87ms
step:229/1530 train_loss:3.9229 train_time:35672ms step_avg:162.89ms
step:230/1530 train_loss:4.0389 train_time:35837ms step_avg:162.90ms
step:231/1530 train_loss:3.8957 train_time:36004ms step_avg:162.91ms
step:232/1530 train_loss:3.9514 train_time:36170ms step_avg:162.93ms
step:233/1530 train_loss:4.0729 train_time:36336ms step_avg:162.94ms
step:234/1530 train_loss:4.0229 train_time:36502ms step_avg:162.96ms
step:235/1530 train_loss:3.8927 train_time:36669ms step_avg:162.97ms
step:236/1530 train_loss:4.0773 train_time:36835ms step_avg:162.99ms
step:237/1530 train_loss:4.0790 train_time:37001ms step_avg:163.00ms
step:238/1530 train_loss:3.9373 train_time:37168ms step_avg:163.02ms
step:239/1530 train_loss:4.0670 train_time:37334ms step_avg:163.03ms
step:240/1530 train_loss:4.1039 train_time:37500ms step_avg:163.04ms
step:241/1530 train_loss:3.9594 train_time:37664ms step_avg:163.05ms
step:242/1530 train_loss:4.1303 train_time:37832ms step_avg:163.07ms
step:243/1530 train_loss:4.0034 train_time:37997ms step_avg:163.08ms
step:244/1530 train_loss:4.0735 train_time:38164ms step_avg:163.09ms
step:245/1530 train_loss:4.1341 train_time:38330ms step_avg:163.11ms
step:246/1530 train_loss:4.0452 train_time:38496ms step_avg:163.12ms
step:247/1530 train_loss:3.9924 train_time:38662ms step_avg:163.13ms
step:248/1530 train_loss:4.0897 train_time:38828ms step_avg:163.14ms
step:249/1530 train_loss:3.9111 train_time:38994ms step_avg:163.15ms
step:250/1530 train_loss:3.9650 train_time:39159ms step_avg:163.16ms
step:250/1530 val_loss:3.9935 train_time:39206ms step_avg:163.36ms
step:251/1530 train_loss:4.0702 train_time:39326ms step_avg:163.18ms
step:252/1530 train_loss:4.1490 train_time:39495ms step_avg:163.20ms
step:253/1530 train_loss:3.9214 train_time:39662ms step_avg:163.22ms
step:254/1530 train_loss:3.8744 train_time:39828ms step_avg:163.23ms
step:255/1530 train_loss:4.0621 train_time:39993ms step_avg:163.24ms
step:256/1530 train_loss:3.9781 train_time:40160ms step_avg:163.25ms
step:257/1530 train_loss:3.9858 train_time:40326ms step_avg:163.26ms
step:258/1530 train_loss:3.9811 train_time:40492ms step_avg:163.27ms
step:259/1530 train_loss:4.0285 train_time:40658ms step_avg:163.28ms
step:260/1530 train_loss:4.0518 train_time:40825ms step_avg:163.30ms
step:261/1530 train_loss:4.0129 train_time:40991ms step_avg:163.31ms
step:262/1530 train_loss:3.9803 train_time:41158ms step_avg:163.33ms
step:263/1530 train_loss:3.8802 train_time:41324ms step_avg:163.33ms
step:264/1530 train_loss:3.9745 train_time:41491ms step_avg:163.35ms
step:265/1530 train_loss:3.8607 train_time:41657ms step_avg:163.36ms
step:266/1530 train_loss:3.9084 train_time:41823ms step_avg:163.37ms
step:267/1530 train_loss:3.9173 train_time:41989ms step_avg:163.38ms
step:268/1530 train_loss:3.9563 train_time:42154ms step_avg:163.39ms
step:269/1530 train_loss:3.8436 train_time:42320ms step_avg:163.40ms
step:270/1530 train_loss:4.0859 train_time:42486ms step_avg:163.41ms
step:271/1530 train_loss:3.9648 train_time:42653ms step_avg:163.42ms
step:272/1530 train_loss:3.9224 train_time:42819ms step_avg:163.43ms
step:273/1530 train_loss:3.9327 train_time:42984ms step_avg:163.44ms
step:274/1530 train_loss:4.0263 train_time:43151ms step_avg:163.45ms
step:275/1530 train_loss:4.0563 train_time:43316ms step_avg:163.46ms
step:276/1530 train_loss:4.2169 train_time:43481ms step_avg:163.46ms
step:277/1530 train_loss:4.0343 train_time:43647ms step_avg:163.47ms
step:278/1530 train_loss:4.0843 train_time:43813ms step_avg:163.48ms
step:279/1530 train_loss:4.0018 train_time:43978ms step_avg:163.49ms
step:280/1530 train_loss:4.1768 train_time:44146ms step_avg:163.50ms
step:281/1530 train_loss:3.9688 train_time:44312ms step_avg:163.51ms
step:282/1530 train_loss:3.9345 train_time:44479ms step_avg:163.53ms
step:283/1530 train_loss:3.9044 train_time:44644ms step_avg:163.53ms
step:284/1530 train_loss:4.0406 train_time:44810ms step_avg:163.54ms
step:285/1530 train_loss:4.0518 train_time:44975ms step_avg:163.55ms
step:286/1530 train_loss:4.0710 train_time:45140ms step_avg:163.55ms
step:287/1530 train_loss:3.9003 train_time:45306ms step_avg:163.56ms
step:288/1530 train_loss:4.0044 train_time:45471ms step_avg:163.56ms
step:289/1530 train_loss:3.8646 train_time:45635ms step_avg:163.57ms
step:290/1530 train_loss:3.8510 train_time:45799ms step_avg:163.57ms
step:291/1530 train_loss:3.9038 train_time:45966ms step_avg:163.58ms
step:292/1530 train_loss:3.8552 train_time:46131ms step_avg:163.58ms
step:293/1530 train_loss:3.8967 train_time:46296ms step_avg:163.59ms
step:294/1530 train_loss:3.9291 train_time:46461ms step_avg:163.59ms
step:295/1530 train_loss:3.8329 train_time:46626ms step_avg:163.60ms
step:296/1530 train_loss:3.8492 train_time:46791ms step_avg:163.61ms
step:297/1530 train_loss:3.8608 train_time:46957ms step_avg:163.61ms
step:298/1530 train_loss:3.9719 train_time:47120ms step_avg:163.61ms
step:299/1530 train_loss:3.8185 train_time:47286ms step_avg:163.62ms
step:300/1530 train_loss:3.9591 train_time:47452ms step_avg:163.63ms
step:301/1530 train_loss:3.9473 train_time:47615ms step_avg:163.63ms
step:302/1530 train_loss:3.9221 train_time:47780ms step_avg:163.63ms
step:303/1530 train_loss:3.9698 train_time:47945ms step_avg:163.64ms
step:304/1530 train_loss:3.9640 train_time:48109ms step_avg:163.64ms
step:305/1530 train_loss:4.4427 train_time:48274ms step_avg:163.64ms
step:306/1530 train_loss:3.9296 train_time:48439ms step_avg:163.65ms
step:307/1530 train_loss:3.8316 train_time:48605ms step_avg:163.65ms
step:308/1530 train_loss:3.9696 train_time:48770ms step_avg:163.66ms
step:309/1530 train_loss:3.8685 train_time:48935ms step_avg:163.66ms
step:310/1530 train_loss:4.0751 train_time:49098ms step_avg:163.66ms
step:311/1530 train_loss:3.9110 train_time:49266ms step_avg:163.67ms
step:312/1530 train_loss:3.8545 train_time:49431ms step_avg:163.68ms
step:313/1530 train_loss:3.9254 train_time:49596ms step_avg:163.68ms
step:314/1530 train_loss:4.0545 train_time:49763ms step_avg:163.69ms
step:315/1530 train_loss:3.9325 train_time:49928ms step_avg:163.70ms
step:316/1530 train_loss:3.7946 train_time:50093ms step_avg:163.70ms
step:317/1530 train_loss:3.8661 train_time:50259ms step_avg:163.71ms
step:318/1530 train_loss:3.9089 train_time:50424ms step_avg:163.71ms
step:319/1530 train_loss:3.8760 train_time:50590ms step_avg:163.72ms
step:320/1530 train_loss:4.0017 train_time:50755ms step_avg:163.73ms
step:321/1530 train_loss:3.9500 train_time:50920ms step_avg:163.73ms
step:322/1530 train_loss:3.9217 train_time:51084ms step_avg:163.73ms
step:323/1530 train_loss:3.9936 train_time:51250ms step_avg:163.74ms
step:324/1530 train_loss:3.9304 train_time:51414ms step_avg:163.74ms
step:325/1530 train_loss:4.0030 train_time:51579ms step_avg:163.74ms
step:326/1530 train_loss:3.8894 train_time:51745ms step_avg:163.75ms
step:327/1530 train_loss:4.3859 train_time:51910ms step_avg:163.75ms
step:328/1530 train_loss:4.0590 train_time:52075ms step_avg:163.76ms
step:329/1530 train_loss:3.7858 train_time:52238ms step_avg:163.76ms
step:330/1530 train_loss:3.7308 train_time:52404ms step_avg:163.76ms
step:331/1530 train_loss:3.9703 train_time:52569ms step_avg:163.77ms
step:332/1530 train_loss:3.8985 train_time:52735ms step_avg:163.77ms
step:333/1530 train_loss:3.8763 train_time:52900ms step_avg:163.78ms
step:334/1530 train_loss:3.8348 train_time:53066ms step_avg:163.78ms
step:335/1530 train_loss:3.9978 train_time:53230ms step_avg:163.79ms
step:336/1530 train_loss:3.9434 train_time:53395ms step_avg:163.79ms
step:337/1530 train_loss:4.4323 train_time:53561ms step_avg:163.79ms
step:338/1530 train_loss:3.9343 train_time:53726ms step_avg:163.80ms
step:339/1530 train_loss:3.8598 train_time:53892ms step_avg:163.81ms
step:340/1530 train_loss:3.9263 train_time:54057ms step_avg:163.81ms
step:341/1530 train_loss:3.8470 train_time:54224ms step_avg:163.82ms
step:342/1530 train_loss:3.8051 train_time:54392ms step_avg:163.83ms
step:343/1530 train_loss:3.8293 train_time:54560ms step_avg:163.84ms
step:344/1530 train_loss:3.9842 train_time:54728ms step_avg:163.86ms
step:345/1530 train_loss:3.8134 train_time:54896ms step_avg:163.87ms
step:346/1530 train_loss:3.7619 train_time:55065ms step_avg:163.88ms
step:347/1530 train_loss:3.7804 train_time:55233ms step_avg:163.90ms
step:348/1530 train_loss:3.8479 train_time:55401ms step_avg:163.91ms
step:349/1530 train_loss:3.8232 train_time:55570ms step_avg:163.92ms
step:350/1530 train_loss:3.5644 train_time:55737ms step_avg:163.93ms
step:351/1530 train_loss:3.8199 train_time:55905ms step_avg:163.95ms
step:352/1530 train_loss:4.1850 train_time:56074ms step_avg:163.96ms
step:353/1530 train_loss:3.6461 train_time:56242ms step_avg:163.97ms
step:354/1530 train_loss:3.9173 train_time:56409ms step_avg:163.98ms
step:355/1530 train_loss:3.7740 train_time:56577ms step_avg:163.99ms
step:356/1530 train_loss:3.8705 train_time:56745ms step_avg:164.00ms
step:357/1530 train_loss:3.7512 train_time:56913ms step_avg:164.01ms
step:358/1530 train_loss:3.8622 train_time:57080ms step_avg:164.02ms
step:359/1530 train_loss:3.7631 train_time:57250ms step_avg:164.04ms
step:360/1530 train_loss:3.4240 train_time:57418ms step_avg:164.05ms
step:361/1530 train_loss:4.0075 train_time:57587ms step_avg:164.06ms
step:362/1530 train_loss:3.9071 train_time:57755ms step_avg:164.08ms
step:363/1530 train_loss:3.8250 train_time:57921ms step_avg:164.08ms
step:364/1530 train_loss:3.7357 train_time:58089ms step_avg:164.09ms
step:365/1530 train_loss:3.9091 train_time:58257ms step_avg:164.10ms
step:366/1530 train_loss:3.8521 train_time:58426ms step_avg:164.12ms
step:367/1530 train_loss:3.8548 train_time:58594ms step_avg:164.13ms
step:368/1530 train_loss:3.8476 train_time:58762ms step_avg:164.14ms
step:369/1530 train_loss:3.7372 train_time:58930ms step_avg:164.15ms
step:370/1530 train_loss:3.8735 train_time:59096ms step_avg:164.16ms
step:371/1530 train_loss:3.7266 train_time:59264ms step_avg:164.17ms
step:372/1530 train_loss:3.6844 train_time:59431ms step_avg:164.18ms
step:373/1530 train_loss:3.9020 train_time:59597ms step_avg:164.18ms
step:374/1530 train_loss:3.8188 train_time:59768ms step_avg:164.20ms
step:375/1530 train_loss:3.7880 train_time:59935ms step_avg:164.21ms
step:375/1530 val_loss:3.8158 train_time:59983ms step_avg:164.34ms