-
Notifications
You must be signed in to change notification settings - Fork 183
/
d6520673-0f5f-4c28-898b-f52d056b257d.txt
2165 lines (2092 loc) · 134 KB
/
d6520673-0f5f-4c28-898b-f52d056b257d.txt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import os
import sys
# Read this script's own source before anything else, so the exact code of the
# run can be written verbatim into the log file later on.
with open(sys.argv[0]) as f:
    code = f.read() # read the code of this file ASAP, for logging
import uuid
import glob
import time
import contextlib
from dataclasses import dataclass
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
# inductor config; coordinate_descent_tuning is switched on through this name further down
import torch._inductor.config as config
from torch.nn.parallel import DistributedDataParallel as DDP
# Use of FlexAttention contributed by @KoszarskyB
from torch.nn.attention.flex_attention import flex_attention, create_block_mask
# Rebind both names to torch.compile'd wrappers (dynamic=False requests
# shape-specialized compilation); every later call in this file uses these.
flex_attention = torch.compile(flex_attention, dynamic=False)
create_block_mask = torch.compile(create_block_mask, dynamic=False)
# -----------------------------------------------------------------------------
# Muon optimizer
def zeropower_via_svd(G, steps=None):
    """
    Orthogonalize G exactly via its (reduced) singular value decomposition.

    For G = U diag(S) V^H this returns U V^H, i.e. G with every singular value
    replaced by one. Registered as the 'svd' backend for Muon.

    Args:
        G: a 2D tensor.
        steps: ignored; accepted so every backend can be called as fn(G, steps=...).

    Returns:
        A tensor with the same shape as G.
    """
    # torch.svd is deprecated; torch.linalg.svd returns V^H directly, which is
    # also the correct factor for complex inputs (V.T would miss the conjugate).
    U, _, Vh = torch.linalg.svd(G, full_matrices=False)
    return U @ Vh
@torch.compile
def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7):
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.

    Args:
        G: 2D tensor to orthogonalize. It is not modified.
        steps: number of quintic Newton-Schulz iterations.
        eps: added to the Frobenius norm to avoid dividing by zero.

    Returns:
        A bfloat16 tensor with the same shape as G.
    """
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    # Out-of-place on purpose: when G is already bfloat16, G.bfloat16() returns G
    # itself, and an in-place divide would silently rescale the caller's tensor.
    X = X / (X.norm() + eps) # ensure top singular value <= 1
    transposed = G.size(0) > G.size(1)
    if transposed:
        # work in the wide orientation so X @ X.T is the smaller Gram matrix
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A # adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X
    if transposed:
        X = X.T
    return X
# Muon's `backend` option -> orthogonalization routine; each is called as fn(G, steps=...).
zeropower_backends = {
    'svd': zeropower_via_svd,
    'newtonschulz5': zeropower_via_newtonschulz5,
}
class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    Some warnings:
    - This optimizer assumes that all parameters passed in are 2D.
    - It should not be used for the embedding layer, the final fully connected layer, or any {0,1}-D
    parameters; those should all be optimized by a standard method (e.g., AdamW).
    - To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
    - We believe it is unlikely to work well for training with small batch size.
    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.
    - We have not yet tried this optimizer for training scenarios larger than NanoGPT (124M).

    Distribution: parameter i of a group is orthogonalized only on the process with
    i % WORLD_SIZE == RANK (read from the environment; a single process is assumed when
    they are unset). Each process writes its results into one flat bfloat16 buffer, the
    buffers are summed with all_reduce, and every process applies the full update.

    Arguments:
        lr: The learning rate used by the internal SGD.
        momentum: The momentum used by the internal SGD.
        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
        backend: The chosen backend for the orthogonalization step. (recommended: 'newtonschulz5')
        backend_steps: The number of iteration steps to use in the backend, if it is iterative.
    """
    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True,
                 backend='newtonschulz5', backend_steps=5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, backend=backend, backend_steps=backend_steps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform one optimization step over every parameter group.

        Args:
            closure: optional callable that re-evaluates the model and returns the
                loss, following the torch.optim.Optimizer.step convention.

        Returns:
            The value returned by `closure`, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        # The process layout is fixed for the run: read it once, not once per parameter.
        world_size = int(os.environ.get('WORLD_SIZE', '1'))
        rank = int(os.environ.get('RANK', '0'))
        # Summing across processes is needed whenever the work is split; skip it only
        # for a plain single-process run without an initialized process group.
        distributed = world_size > 1 or (dist.is_available() and dist.is_initialized())
        for group in self.param_groups:
            params = group['params']
            if not params:
                continue
            lr = group['lr']
            momentum = group['momentum']
            zeropower_backend = zeropower_backends[group['backend']]
            # generate weight updates in distributed fashion
            total_params = sum(p.numel() for p in params)
            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)
            curr_idx = 0
            for i, p in enumerate(params):
                # luckily this will perfectly distribute a transformer with multiple of 4 layers to 8 GPUs
                if i % world_size == rank:
                    g = p.grad
                    assert g is not None
                    state = self.state[p]
                    if 'momentum_buffer' not in state:
                        state['momentum_buffer'] = torch.zeros_like(g)
                    buf = state['momentum_buffer']
                    buf.mul_(momentum).add_(g)
                    # Nesterov-style: step along g + momentum * buf rather than buf itself
                    g = g.add(buf, alpha=momentum) if group['nesterov'] else buf
                    g = zeropower_backend(g, steps=group['backend_steps'])
                    # tall matrices (rows > cols) get their update scaled up by sqrt(rows/cols)
                    g *= max(1, g.size(0)/g.size(1))**0.5
                    updates_flat[curr_idx:curr_idx+p.numel()] = g.flatten()
                # every parameter owns a slot in the flat buffer, computed here or not
                curr_idx += p.numel()
            # sync updates across devices. we are not memory-constrained so can do this simple deserialization
            if distributed:
                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
            # deserialize and apply updates
            curr_idx = 0
            for p in params:
                g = updates_flat[curr_idx:curr_idx+p.numel()].view_as(p.data).type_as(p.data)
                p.data.add_(g, alpha=-lr)
                curr_idx += p.numel()
        return loss
# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the GPT-2 model
def norm(x):
return F.rms_norm(x, (x.size(-1),))
class CastedLinear(nn.Linear):
    """
    Bias-free linear layer that casts its weight to the input's dtype on each call.

    The weight can therefore be stored at a different precision from the
    activations (the training script keeps these modules in float32 while the
    rest of the model is bfloat16), and the matmul runs in the input's dtype.
    """
    def __init__(self, in_features, out_features):
        super().__init__(in_features, out_features, bias=False)

    def forward(self, x):
        weight = self.weight.to(x.dtype)
        return F.linear(x, weight)
class Rotary(torch.nn.Module):
    """
    Rotary position embedding applied along the last dimension of x.

    forward reads the sequence length from dim 1 and splits the features into two
    halves along dim 3, so x is expected as (batch, seq, head, head_dim).
    """
    def __init__(self, dim, base=10000):
        super().__init__()
        # Kept as a registered buffer for state_dict compatibility. Module.bfloat16()
        # converts buffers too, so this may end up stored in bfloat16.
        self.register_buffer('inv_freq', (1 / base) ** (torch.arange(0, dim, 2) / dim))
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x):
        seq_len = x.shape[1]
        if seq_len != self.seq_len_cached or self.cos_cached.device != x.device:
            # Build the angle table in float32. torch.outer follows type promotion,
            # so an int64 arange times a bfloat16 inv_freq would be evaluated in
            # bfloat16, which cannot represent positions above 256 exactly (513
            # rounds to 512) and would give neighbouring positions the same rotation.
            t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
            freqs = torch.outer(t, self.inv_freq.float())
            self.seq_len_cached = seq_len
            self.cos_cached = freqs.cos()
            self.sin_cached = freqs.sin()
        cos, sin = self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]
        # apply_rotary_emb(x, cos, sin): rotate each (x1, x2) feature pair by its position's angle
        x1, x2 = x.chunk(2, dim=3)
        y1 = x1 * cos + x2 * sin
        y2 = x1 * (-sin) + x2 * cos
        return torch.cat((y1, y2), 3).type_as(x)
class CausalSelfAttention(nn.Module):
    """
    Multi-head self-attention with QK norm, rotary embeddings, a learned value
    residual mix, and a FlexAttention block mask supplied by the caller.
    """
    def __init__(self, dim, n_head):
        super().__init__()
        assert dim % n_head == 0
        self.n_head = n_head
        self.c_q = CastedLinear(dim, dim)
        self.c_k = CastedLinear(dim, dim)
        self.c_v = CastedLinear(dim, dim)
        # value residual lambda: weight given to the externally supplied values `vi` (@Grad62304977)
        self.lamb = nn.Parameter(torch.tensor(0.5))
        head_dim = dim // n_head
        self.rotary = Rotary(head_dim)
        # output projection, zero-initialized as suggested by @Grad62304977
        self.c_proj = CastedLinear(dim, dim)
        self.c_proj.weight.data.zero_()

    def forward(self, x, vi, block_mask):
        batch, seq_len = x.size(0), x.size(1)
        assert batch == 1, "Must use batch size = 1 for FlexAttention"
        head_shape = (batch, seq_len, self.n_head, -1)
        q = self.c_q(x).view(head_shape)
        k = self.c_k(x).view(head_shape)
        v = self.c_v(x).view(head_shape)
        # blend this layer's values with `vi` (@Grad62304977)
        v = (1 - self.lamb) * v + self.lamb * vi.view_as(v)
        # QK norm (@Grad62304977), then rotary position encoding
        q = self.rotary(norm(q))
        k = self.rotary(norm(k))
        # heads are moved to dim 1 for flex_attention
        y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask)
        # back to (batch, seq, head, head_dim), then lay all heads side by side
        merged = y.transpose(1, 2).contiguous().view_as(x)
        return self.c_proj(merged)
class MLP(nn.Module):
    """Feed-forward sublayer: 4x expansion, squared-ReLU, projection back to `dim`."""
    def __init__(self, dim):
        super().__init__()
        hidden = 4 * dim
        self.c_fc = CastedLinear(dim, hidden)
        self.c_proj = CastedLinear(hidden, dim)
        self.c_proj.weight.data.zero_() # zero init suggested by @Grad62304977

    def forward(self, x):
        # relu^2: https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
        activated = F.relu(self.c_fc(x)).square()
        return self.c_proj(activated)
class Block(nn.Module):
    """
    One transformer layer. The input is first re-mixed with the initial embedding
    x0 through two learned scalars; then come pre-norm attention and MLP residuals.
    """
    def __init__(self, config):
        super().__init__()
        self.attn = CausalSelfAttention(config.n_embd, config.n_head)
        self.mlp = MLP(config.n_embd)
        # [weight on x, weight on x0]; starts as the identity on x
        self.lambdas = nn.Parameter(torch.tensor([1., 0.]))

    def forward(self, x, vi, x0, block_mask):
        w_x, w_x0 = self.lambdas[0], self.lambdas[1]
        x = w_x * x + w_x0 * x0
        x = x + self.attn(norm(x), vi, block_mask)
        return x + self.mlp(norm(x))
# -----------------------------------------------------------------------------
# The main GPT-2 model
@dataclass
class GPTConfig:
    """Architecture hyperparameters for the GPT model."""
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 6  # head dim 128 suggested by @Grad62304977
    n_embd: int = 768
class GPT(nn.Module):
    """
    GPT-2-style language model with a U-Net skip layout, per-layer token value
    embeddings, and document-aware sliding-window FlexAttention.
    """
    def __init__(self, config):
        super().__init__()
        # U-net design by @brendanh0gan
        self.num_encoder_layers = config.n_layer // 2 # Half of the layers for encoder
        self.num_decoder_layers = config.n_layer - self.num_encoder_layers # Remaining for decoder
        # Add learnable skip connection weights for decoder layers
        self.skip_weights = nn.Parameter(torch.ones(self.num_decoder_layers))
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual learning
            # one n_embd-wide slice per layer (was hard-coded to 12 layers)
            vte = nn.Embedding(config.vocab_size, config.n_embd*config.n_layer),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
        ))
        self.lm_head = CastedLinear(config.n_embd, config.vocab_size)
        self.lm_head.weight.data.zero_() # @Grad62304977

    def forward(self, idx, target, attn_blocksize):
        """
        Return the mean cross-entropy of predicting `target` from `idx`.

        Args:
            idx: 1D tensor of token ids, treated as one packed sequence (a batch
                dimension of 1 is added internally).
            target: token ids with as many elements as idx.
            attn_blocksize: sliding-window width; a query only attends to keys
                fewer than this many positions behind it.

        Returns:
            Scalar float32 loss.
        """
        # docs[i] = number of 50256 (GPT-2 <|endoftext|>) tokens up to position i,
        # so equal values mean "same document"
        docs = (idx == 50256).cumsum(0)
        def document_causal_mask(b, h, q_idx, kv_idx):
            causal_mask = q_idx >= kv_idx
            document_mask = docs[q_idx] == docs[kv_idx]
            window_mask = q_idx - kv_idx < attn_blocksize
            return causal_mask & document_mask & window_mask
        S = len(idx)
        block_mask = create_block_mask(document_causal_mask, None, None, S, S, device=idx.device, _compile=True)
        # forward the GPT model itself
        x = self.transformer.wte(idx[None]) # token embeddings of shape (b, t, n_embd)
        x = norm(x) # @Grad62304977
        x0 = x
        # split the value embedding into one n_embd-wide slice per layer
        vi = self.transformer.vte(idx[None]).chunk(len(self.transformer.h), dim=-1)
        # Store outputs for U-Net skip connections
        skip_connections = []
        # Encoder pass - process only the first half of the blocks
        for i in range(self.num_encoder_layers):
            x = self.transformer.h[i](x, vi[i], x0, block_mask)
            skip_connections.append(x)
        # Decoder pass - process the remaining blocks with weighted skip connections.
        # pop() pairs the first decoder layer with the last encoder output.
        # NOTE(review): with an odd n_layer there is one more decoder layer than
        # encoder outputs, so pop() would run out; n_layer is assumed to be even.
        for i in range(self.num_decoder_layers):
            x = x + self.skip_weights[i] * skip_connections.pop()
            x = self.transformer.h[self.num_encoder_layers + i](x, vi[self.num_encoder_layers+i], x0, block_mask)
        x = norm(x)
        logits = self.lm_head(x)
        logits = 30 * torch.tanh(logits / 30) # @Grad62304977
        logits = logits.float()
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target.view(-1))
        return loss
# -----------------------------------------------------------------------------
# Our own simple Distributed Data Loader
def _peek_data_shard(filename):
# only reads the header, returns header data
with open(filename, "rb") as f:
# first read the header, which is 256 int32 integers (4 bytes each)
header = np.frombuffer(f.read(256*4), dtype=np.int32)
if header[0] != 20240520:
print("ERROR: magic number mismatch in the data .bin file!")
print("---> HINT: Are you passing in a correct file with --input_bin?")
print("---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README")
print("---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try")
exit(1)
assert header[1] == 1, "unsupported version"
ntok = header[2] # number of tokens (claimed)
return ntok # for now just return the number of tokens
def _load_data_shard(filename):
with open(filename, "rb") as f:
# first read the header, which is 256 int32 integers (4 bytes each)
header = np.frombuffer(f.read(256*4), dtype=np.int32)
assert header[0] == 20240520, "magic number mismatch in the data .bin file"
assert header[1] == 1, "unsupported version"
ntok = header[2] # number of tokens (claimed)
# the rest of it are tokens, stored as uint16
tokens = np.frombuffer(f.read(), dtype=np.uint16)
assert len(tokens) == ntok, "number of tokens read does not match header?"
return tokens
class DistributedDataLoader:
    """
    Stream fixed-length token windows from a sorted list of .bin shards, split across ranks.

    Within a shard, rank r reads the windows starting at r*T + k*(T*num_processes),
    so the ranks' windows interleave without overlapping. All ranks move to the
    next shard on the same call, wrapping around to the first shard after the last.
    """
    def __init__(self, filename_pattern, T, process_rank, num_processes):
        self.process_rank = process_rank
        self.num_processes = num_processes
        self.T = T
        # glob files that match the pattern
        self.files = sorted(glob.glob(filename_pattern))
        assert len(self.files) > 0, f"did not find any files that match the pattern {filename_pattern}"
        # load and validate all data shards, count number of tokens in total
        ntok_total = 0
        for fname in self.files:
            shard_ntok = _peek_data_shard(fname)
            # every shard must hold at least one full window for every rank
            assert shard_ntok >= num_processes * T + 1
            ntok_total += int(shard_ntok)
        self.ntok_total = ntok_total
        self.reset()

    def reset(self):
        """Restart from the beginning of the first shard."""
        self.current_shard = -1
        self.advance()

    def advance(self): # advance to next data shard
        self.current_shard = (self.current_shard + 1) % len(self.files)
        self.current_position = self.process_rank * self.T
        self.tokens = _load_data_shard(self.files[self.current_shard])

    def next_batch(self):
        """Return this rank's next (inputs, targets) pair as int64 CUDA tensors of length T."""
        batch_size = self.T * self.num_processes
        buf = self.tokens[self.current_position:self.current_position+self.T+1]
        buf = torch.tensor(buf.astype(np.int32), dtype=torch.long)
        x = buf[:-1] # inputs
        y = buf[1:] # targets
        # advance current position and load next shard if necessary
        self.current_position += batch_size
        # Decide from the rank-independent start of the next set of windows, so all
        # ranks switch shards on the same call. The next set spans
        # [window_start, window_start + batch_size + 1); this is exactly rank 0's
        # previous test. Testing current_position directly made higher ranks switch
        # shards earlier than lower ranks whenever the shard length was not aligned.
        window_start = self.current_position - self.process_rank * self.T
        if window_start + batch_size >= len(self.tokens):
            self.advance()
        return x.cuda(), y.cuda()
# -----------------------------------------------------------------------------
# int main
@dataclass
class Hyperparameters:
    """Run configuration for this training script."""
    # data hyperparams
    input_bin: str = 'data/fineweb10B/fineweb_train_*.bin'  # glob for the training shards
    input_val_bin: str = 'data/fineweb10B/fineweb_val_*.bin'  # glob for the validation shards
    # optimization hyperparams
    batch_size: int = 8  # batch size, in sequences, across all devices
    sequence_length: int = 64*1024  # sequence length, in tokens
    num_iterations: int = 1530  # number of iterations to run
    warmup_iters: int = 0  # length of the linear LR warmup, in iterations
    cooldown_iters: int = 600  # length of the final linear LR cooldown, in iterations
    weight_decay: float = 0  # not used by the optimizer setup in this script
    # evaluation and logging hyperparams
    val_loss_every: int = 125  # every how many steps to evaluate val loss? 0 for only at the end
    val_tokens: int = 10485760  # how many tokens of validation data? keep fixed for consistent comparisons
    save_every: int = 0  # every how many steps to save the checkpoint? 0 for only at the end
args = Hyperparameters()
# set up DDP (distributed data parallel). torchrun sets this env variable
assert torch.cuda.is_available()
dist.init_process_group(backend='nccl')
ddp_rank = int(os.environ['RANK'])
ddp_local_rank = int(os.environ['LOCAL_RANK'])
ddp_world_size = int(os.environ['WORLD_SIZE'])
# each process drives the GPU whose index is its LOCAL_RANK
device = f'cuda:{ddp_local_rank}'
torch.cuda.set_device(device)
print(f"using device: {device}")
master_process = (ddp_rank == 0) # this process will do logging, checkpointing etc.
# begin logging
# logfile stays None on non-master ranks; run_id/logdir only exist on the master,
# and the checkpoint code later in the file is likewise guarded by master_process.
logfile = None
if master_process:
    run_id = str(uuid.uuid4())
    logdir = 'logs/%s/' % run_id
    os.makedirs(logdir, exist_ok=True)
    logfile = 'logs/%s.txt' % run_id
    # create the log file
    with open(logfile, "w") as f:
        # begin the log by printing this file (the Python code)
        f.write(code)
        f.write('='*100 + '\n')
def print0(s, logonly=False):
    """
    Log `s` from the master process only: append it to the run's logfile and,
    unless `logonly` is set, also echo it to stdout. Does nothing on other ranks.
    """
    if not master_process:
        return
    with open(logfile, "a") as f:
        if not logonly:
            print(s)
        f.write(s + '\n')
# log information about the hardware/software environment this is running on
# and print the full `nvidia-smi` to file
print0(f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:")
import subprocess
result = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
print0(f'{result.stdout}', logonly=True)
print0('='*100, logonly=True)
# convenience variables
T = args.sequence_length
# calculate the number of steps to take in the val loop.
# Each val step consumes T tokens on each of the ddp_world_size ranks.
assert args.val_tokens % (T * ddp_world_size) == 0
val_steps = args.val_tokens // (T * ddp_world_size)
# calculate the steps of gradient accumulation required to attain the desired global batch size.
# Each micro-step feeds one T-token sequence per rank.
assert args.batch_size % (ddp_world_size) == 0
train_accumulation_steps = args.batch_size // ddp_world_size
# load tokens
train_loader = DistributedDataLoader(args.input_bin, T, ddp_rank, ddp_world_size)
val_loader = DistributedDataLoader(args.input_val_bin, T, ddp_rank, ddp_world_size)
print0(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
print0(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")
print0('='*100, logonly=True)
# Prefetch the first training batch; the training loop fetches the next one
# right after each forward pass.
x, y = train_loader.next_batch()
# there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. suggested to me by @Grad62304977.
# this originates from Karpathy's experiments.
num_vocab = 50304
model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=12, n_head=6, n_embd=768))
# Cast the whole model (parameters and buffers) to bfloat16, then move the
# CastedLinear weights back to float32; those layers cast to the input dtype per call.
model = model.cuda().bfloat16()
for m in model.modules():
    if isinstance(m, CastedLinear):
        m.float()
if hasattr(config, "coordinate_descent_tuning"):
    config.coordinate_descent_tuning = True # suggested by @Chillee
model = torch.compile(model)
# here we wrap model into DDP container
model = DDP(model, device_ids=[ddp_local_rank])
# NOTE(review): model.module is the torch.compile wrapper around GPT; attribute
# access such as raw_model.transformer below is forwarded to the GPT instance.
raw_model = model.module # always contains the "raw" unwrapped model
# init the optimizer(s)
# Adam: token/value embeddings (optimizer1) and lm_head (optimizer2).
# Muon: every 2D parameter inside the transformer blocks (optimizer3).
# Adam: the blocks' scalar/vector parameters plus the U-Net skip weights (optimizer4).
optimizer1 = torch.optim.Adam([raw_model.transformer.wte.weight, raw_model.transformer.vte.weight], lr=0.6, betas=(0.8, 0.95), fused=True)
optimizer2 = torch.optim.Adam([raw_model.lm_head.weight], lr=0.008, betas=(0.8, 0.95), fused=True)
params = list(raw_model.transformer.h.parameters())
matrix_params = [p for p in params if p.ndim == 2]
scalar_params = [p for p in params if p.ndim < 2] + [raw_model.skip_weights]
optimizer3 = Muon(matrix_params, lr=0.05, momentum=0.95)
optimizer4 = torch.optim.Adam(scalar_params, lr=0.04, betas=(0.8, 0.95), fused=True) # note that this learning rate is neither sensitive nor tuned
optimizers = [optimizer1, optimizer2, optimizer3, optimizer4]
# learning rate decay scheduler (linear warmup and cooldown)
def get_lr(it):
    """
    LR multiplier for iteration `it`; used as the lr_lambda of every LambdaLR below.

    Linear warmup over args.warmup_iters, then 1.0, then a linear decay that
    reaches 0 at it == args.num_iterations over the final args.cooldown_iters.
    Either phase may have length 0.
    """
    assert it <= args.num_iterations
    # 1) linear warmup for warmup_iters steps
    if it < args.warmup_iters:
        return (it+1) / args.warmup_iters
    # 2) constant lr for a while
    elif it < args.num_iterations - args.cooldown_iters:
        return 1.0
    # 3) linear cooldown
    elif args.cooldown_iters > 0:
        decay_ratio = (args.num_iterations - it) / args.cooldown_iters
        return decay_ratio
    # No cooldown configured: only reached at it == num_iterations (the
    # scheduler's final step), which previously divided by zero.
    else:
        return 1.0
schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers]
# Start training loop
training_time_ms = 0
# start the clock
torch.cuda.synchronize()
t0 = time.time()
# begin training
# Runs num_iterations + 1 times; the extra final pass (last_step) only does
# validation and checkpointing, then breaks before the training section.
for step in range(args.num_iterations + 1):
    last_step = (step == args.num_iterations)
    # This effectively ignores timing first 10 steps, which are slower for weird reasons.
    # Alternately, and slightly more correctly in terms of benchmarking, we could do 10
    # steps with dummy data first, and then re-initialize the model and reset the loader.
    if step == 10:
        training_time_ms = 0
        t0 = time.time()
    # Count of timed training steps (from step 10 on) completed by the end of this
    # step, used for the step_avg reports; NaN while still in the untimed steps.
    timed_steps = float('nan') if step <= 11 else (step - 10) + 1 # <= 11 to avoid bug in val
    # Set the attention blocksize for the current step, in chunks of 64. By @fernbear.bsky.social
    # Grows linearly from 64 at step 0 to 1792 at the last step, floored to a
    # multiple of 64; the model uses it as its sliding attention window.
    attn_blocksize = torch.tensor(64*((step/args.num_iterations * (1792 - 64) + 64)//64), dtype=torch.int, device='cuda')
    # once in a while evaluate the validation dataset
    if (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)):
        # stop the clock
        torch.cuda.synchronize()
        training_time_ms += 1000 * (time.time() - t0)
        # run validation batches
        model.eval()
        val_loader.reset()
        val_loss = 0.0
        for _ in range(val_steps):
            with torch.no_grad():
                x_val, y_val = val_loader.next_batch()
                val_loss += model(x_val, y_val, attn_blocksize=attn_blocksize)
        # Average across ranks, then over this rank's val_steps batches.
        # NOTE(review): with val_steps == 0, val_loss would still be the Python float
        # 0.0 here and all_reduce would fail, so val_tokens must be positive.
        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
        val_loss /= val_steps
        # log val loss to console and to logfile
        print0(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms')
        # start the clock again
        torch.cuda.synchronize()
        t0 = time.time()
    # checkpointing happens on the master process only (run_id is only defined there)
    if master_process and (last_step or (args.save_every > 0 and step % args.save_every == 0)):
        # stop the clock
        torch.cuda.synchronize()
        training_time_ms += 1000 * (time.time() - t0)
        # save the state of the training process
        log = dict(step=step, code=code, model=raw_model.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
        torch.save(log, 'logs/%s/state_step%06d.pt' % (run_id, step))
        # start the clock again
        torch.cuda.synchronize()
        t0 = time.time()
    # bit confusing: we want to make sure to eval on 0th iteration
    # but also after the very last iteration. so we loop for step <= num_iterations
    # instead of just < num_iterations (one extra due to <=), only to do
    # the validation/sampling one last time, and then we break right here as we're done.
    if last_step:
        break
    # --------------- TRAINING SECTION BEGIN -----------------
    model.train()
    # Gradient accumulation: DDP's gradient sync is suppressed (no_sync) on all but
    # the last micro-step, so gradients are all-reduced once per optimizer step.
    for i in range(1, train_accumulation_steps+1):
        ctx = model.no_sync() if i < train_accumulation_steps else contextlib.nullcontext()
        with ctx: # there's no need to sync gradients every accumulation step
            # forward pass
            loss = model(x, y, attn_blocksize=attn_blocksize)
            # advance the dataset for the next batch
            x, y = train_loader.next_batch()
            # backward pass
            loss.backward()
        # only the last micro-batch's loss is kept for logging
        train_loss = loss.detach()
    # backward() summed gradients over the micro-steps; turn the sum into a mean
    for p in model.parameters():
        p.grad /= train_accumulation_steps
    # momentum warmup for Muon
    # linear from 0.85 at step 0 to 0.95 at step 300, constant afterwards
    frac = min(step/300, 1)
    optimizer3.param_groups[0]['momentum'] = (1 - frac) * 0.85 + frac * 0.95
    # step the optimizers and schedulers
    for opt, sched in zip(optimizers, schedulers):
        opt.step()
        sched.step()
    # null the gradients
    model.zero_grad(set_to_none=True)
    # --------------- TRAINING SECTION END -------------------
    # everything that follows now is just diagnostics, prints, logging, etc.
    #dist.all_reduce(train_loss, op=dist.ReduceOp.AVG) # all-reducing the training loss would be more correct in terms of logging, but slower
    approx_time = training_time_ms + 1000 * (time.time() - t0)
    print0(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms")
if master_process:
    print(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")
# -------------------------------------------------------------------------
# clean up nice
dist.destroy_process_group()
====================================================================================================
Running pytorch 2.6.0.dev20241203+cu124 compiled for CUDA 12.4
nvidia-smi:
Thu Dec 5 02:19:36 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.06 Driver Version: 535.183.06 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA H100 80GB HBM3 On | 00000000:19:00.0 Off | 0 |
| N/A 38C P0 75W / 700W | 3MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 1 NVIDIA H100 80GB HBM3 On | 00000000:3B:00.0 Off | 0 |
| N/A 30C P0 115W / 700W | 529MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 2 NVIDIA H100 80GB HBM3 On | 00000000:4C:00.0 Off | 0 |
| N/A 31C P0 118W / 700W | 529MiB / 81559MiB | 1% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 3 NVIDIA H100 80GB HBM3 On | 00000000:5D:00.0 Off | 0 |
| N/A 38C P0 118W / 700W | 529MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 4 NVIDIA H100 80GB HBM3 On | 00000000:9B:00.0 Off | 0 |
| N/A 39C P0 123W / 700W | 529MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 5 NVIDIA H100 80GB HBM3 On | 00000000:BB:00.0 Off | 0 |
| N/A 29C P0 110W / 700W | 529MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 6 NVIDIA H100 80GB HBM3 On | 00000000:CB:00.0 Off | 0 |
| N/A 39C P0 127W / 700W | 529MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 7 NVIDIA H100 80GB HBM3 On | 00000000:DB:00.0 Off | 0 |
| N/A 30C P0 118W / 700W | 529MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
+---------------------------------------------------------------------------------------+
====================================================================================================
Training DataLoader: total number of tokens: 1100000000 across 11 files
Validation DataLoader: total number of tokens: 100000000 across 1 files
====================================================================================================
step:0/1530 val_loss:10.8258 train_time:0ms step_avg:nanms
step:1/1530 train_loss:10.8258 train_time:31485ms step_avg:nanms
step:2/1530 train_loss:10.0780 train_time:31596ms step_avg:nanms
step:3/1530 train_loss:8.3799 train_time:31755ms step_avg:nanms
step:4/1530 train_loss:7.5771 train_time:31916ms step_avg:nanms
step:5/1530 train_loss:7.4605 train_time:32077ms step_avg:nanms
step:6/1530 train_loss:6.9770 train_time:32237ms step_avg:nanms
step:7/1530 train_loss:7.1970 train_time:32399ms step_avg:nanms
step:8/1530 train_loss:6.7437 train_time:32559ms step_avg:nanms
step:9/1530 train_loss:6.6149 train_time:32719ms step_avg:nanms
step:10/1530 train_loss:6.4913 train_time:32879ms step_avg:nanms
step:11/1530 train_loss:6.4531 train_time:114ms step_avg:nanms
step:12/1530 train_loss:6.3716 train_time:276ms step_avg:nanms
step:13/1530 train_loss:6.2501 train_time:436ms step_avg:145.23ms
step:14/1530 train_loss:6.2664 train_time:596ms step_avg:148.92ms
step:15/1530 train_loss:6.1678 train_time:756ms step_avg:151.28ms
step:16/1530 train_loss:6.1338 train_time:917ms step_avg:152.77ms
step:17/1530 train_loss:6.1632 train_time:1075ms step_avg:153.64ms
step:18/1530 train_loss:5.9654 train_time:1237ms step_avg:154.60ms
step:19/1530 train_loss:5.9700 train_time:1397ms step_avg:155.20ms
step:20/1530 train_loss:5.7142 train_time:1556ms step_avg:155.61ms
step:21/1530 train_loss:5.9572 train_time:1717ms step_avg:156.13ms
step:22/1530 train_loss:6.1560 train_time:1878ms step_avg:156.51ms
step:23/1530 train_loss:5.8443 train_time:2038ms step_avg:156.80ms
step:24/1530 train_loss:6.0306 train_time:2198ms step_avg:156.99ms
step:25/1530 train_loss:5.6815 train_time:2359ms step_avg:157.25ms
step:26/1530 train_loss:5.5955 train_time:2520ms step_avg:157.48ms
step:27/1530 train_loss:5.7527 train_time:2679ms step_avg:157.61ms
step:28/1530 train_loss:5.4114 train_time:2841ms step_avg:157.82ms
step:29/1530 train_loss:5.6747 train_time:3001ms step_avg:157.93ms
step:30/1530 train_loss:5.4714 train_time:3162ms step_avg:158.09ms
step:31/1530 train_loss:5.4282 train_time:3322ms step_avg:158.20ms
step:32/1530 train_loss:5.2931 train_time:3484ms step_avg:158.38ms
step:33/1530 train_loss:5.5779 train_time:3643ms step_avg:158.41ms
step:34/1530 train_loss:5.4965 train_time:3803ms step_avg:158.47ms
step:35/1530 train_loss:5.6091 train_time:3964ms step_avg:158.57ms
step:36/1530 train_loss:5.5537 train_time:4125ms step_avg:158.64ms
step:37/1530 train_loss:5.4458 train_time:4286ms step_avg:158.74ms
step:38/1530 train_loss:5.3028 train_time:4447ms step_avg:158.82ms
step:39/1530 train_loss:5.3212 train_time:4608ms step_avg:158.88ms
step:40/1530 train_loss:5.2566 train_time:4768ms step_avg:158.93ms
step:41/1530 train_loss:5.2313 train_time:4928ms step_avg:158.98ms
step:42/1530 train_loss:5.1684 train_time:5088ms step_avg:159.01ms
step:43/1530 train_loss:5.2618 train_time:5248ms step_avg:159.04ms
step:44/1530 train_loss:5.2456 train_time:5408ms step_avg:159.07ms
step:45/1530 train_loss:5.3883 train_time:5569ms step_avg:159.12ms
step:46/1530 train_loss:5.1657 train_time:5729ms step_avg:159.13ms
step:47/1530 train_loss:5.0603 train_time:5889ms step_avg:159.17ms
step:48/1530 train_loss:5.2192 train_time:6049ms step_avg:159.19ms
step:49/1530 train_loss:5.1393 train_time:6209ms step_avg:159.21ms
step:50/1530 train_loss:5.2391 train_time:6370ms step_avg:159.25ms
step:51/1530 train_loss:5.1286 train_time:6530ms step_avg:159.26ms
step:52/1530 train_loss:5.0282 train_time:6689ms step_avg:159.27ms
step:53/1530 train_loss:5.1771 train_time:6849ms step_avg:159.29ms
step:54/1530 train_loss:5.0327 train_time:7010ms step_avg:159.31ms
step:55/1530 train_loss:5.4132 train_time:7170ms step_avg:159.34ms
step:56/1530 train_loss:5.0209 train_time:7330ms step_avg:159.34ms
step:57/1530 train_loss:4.8637 train_time:7490ms step_avg:159.36ms
step:58/1530 train_loss:5.0346 train_time:7650ms step_avg:159.37ms
step:59/1530 train_loss:5.0077 train_time:7811ms step_avg:159.40ms
step:60/1530 train_loss:5.1231 train_time:7971ms step_avg:159.41ms
step:61/1530 train_loss:4.8402 train_time:8130ms step_avg:159.42ms
step:62/1530 train_loss:4.9702 train_time:8290ms step_avg:159.42ms
step:63/1530 train_loss:4.9658 train_time:8450ms step_avg:159.44ms
step:64/1530 train_loss:4.9399 train_time:8610ms step_avg:159.44ms
step:65/1530 train_loss:4.7817 train_time:8770ms step_avg:159.45ms
step:66/1530 train_loss:4.9064 train_time:8930ms step_avg:159.46ms
step:67/1530 train_loss:4.8141 train_time:9090ms step_avg:159.47ms
step:68/1530 train_loss:5.1025 train_time:9250ms step_avg:159.48ms
step:69/1530 train_loss:4.7342 train_time:9411ms step_avg:159.51ms
step:70/1530 train_loss:4.8741 train_time:9571ms step_avg:159.52ms
step:71/1530 train_loss:4.9626 train_time:9731ms step_avg:159.52ms
step:72/1530 train_loss:4.8734 train_time:9890ms step_avg:159.52ms
step:73/1530 train_loss:4.7499 train_time:10051ms step_avg:159.54ms
step:74/1530 train_loss:4.8983 train_time:10211ms step_avg:159.54ms
step:75/1530 train_loss:4.8382 train_time:10371ms step_avg:159.56ms
step:76/1530 train_loss:4.7848 train_time:10532ms step_avg:159.57ms
step:77/1530 train_loss:4.9067 train_time:10692ms step_avg:159.58ms
step:78/1530 train_loss:5.1151 train_time:10852ms step_avg:159.58ms
step:79/1530 train_loss:4.7755 train_time:11011ms step_avg:159.58ms
step:80/1530 train_loss:4.8312 train_time:11174ms step_avg:159.63ms
step:81/1530 train_loss:4.6272 train_time:11335ms step_avg:159.65ms
step:82/1530 train_loss:4.8113 train_time:11495ms step_avg:159.66ms
step:83/1530 train_loss:4.7512 train_time:11657ms step_avg:159.68ms
step:84/1530 train_loss:4.7470 train_time:11817ms step_avg:159.69ms
step:85/1530 train_loss:4.6131 train_time:11978ms step_avg:159.71ms
step:86/1530 train_loss:4.8217 train_time:12137ms step_avg:159.70ms
step:87/1530 train_loss:4.7306 train_time:12298ms step_avg:159.72ms
step:88/1530 train_loss:4.7478 train_time:12459ms step_avg:159.73ms
step:89/1530 train_loss:4.6856 train_time:12620ms step_avg:159.74ms
step:90/1530 train_loss:4.6290 train_time:12781ms step_avg:159.76ms
step:91/1530 train_loss:4.6200 train_time:12941ms step_avg:159.77ms
step:92/1530 train_loss:4.7790 train_time:13102ms step_avg:159.78ms
step:93/1530 train_loss:4.5871 train_time:13262ms step_avg:159.79ms
step:94/1530 train_loss:4.6217 train_time:13422ms step_avg:159.79ms
step:95/1530 train_loss:4.6599 train_time:13585ms step_avg:159.82ms
step:96/1530 train_loss:4.5822 train_time:13744ms step_avg:159.81ms
step:97/1530 train_loss:4.6175 train_time:13904ms step_avg:159.82ms
step:98/1530 train_loss:4.5844 train_time:14065ms step_avg:159.83ms
step:99/1530 train_loss:4.6472 train_time:14224ms step_avg:159.82ms
step:100/1530 train_loss:4.6633 train_time:14386ms step_avg:159.84ms
step:101/1530 train_loss:4.5219 train_time:14546ms step_avg:159.85ms
step:102/1530 train_loss:4.6828 train_time:14707ms step_avg:159.86ms
step:103/1530 train_loss:4.5541 train_time:14868ms step_avg:159.87ms
step:104/1530 train_loss:4.5388 train_time:15028ms step_avg:159.87ms
step:105/1530 train_loss:4.5397 train_time:15188ms step_avg:159.87ms
step:106/1530 train_loss:4.5801 train_time:15348ms step_avg:159.88ms
step:107/1530 train_loss:4.4976 train_time:15509ms step_avg:159.88ms
step:108/1530 train_loss:4.3599 train_time:15669ms step_avg:159.89ms
step:109/1530 train_loss:4.4850 train_time:15829ms step_avg:159.89ms
step:110/1530 train_loss:4.4688 train_time:15988ms step_avg:159.88ms
step:111/1530 train_loss:4.4056 train_time:16149ms step_avg:159.89ms
step:112/1530 train_loss:4.5775 train_time:16310ms step_avg:159.90ms
step:113/1530 train_loss:4.4786 train_time:16470ms step_avg:159.90ms
step:114/1530 train_loss:4.3497 train_time:16630ms step_avg:159.90ms
step:115/1530 train_loss:4.4838 train_time:16792ms step_avg:159.93ms
step:116/1530 train_loss:4.4512 train_time:16956ms step_avg:159.96ms
step:117/1530 train_loss:4.3605 train_time:17121ms step_avg:160.01ms
step:118/1530 train_loss:4.5897 train_time:17286ms step_avg:160.06ms
step:119/1530 train_loss:4.4455 train_time:17448ms step_avg:160.08ms
step:120/1530 train_loss:4.3207 train_time:17612ms step_avg:160.11ms
step:121/1530 train_loss:4.2856 train_time:17777ms step_avg:160.15ms
step:122/1530 train_loss:4.4393 train_time:17942ms step_avg:160.19ms
step:123/1530 train_loss:4.2840 train_time:18105ms step_avg:160.22ms
step:124/1530 train_loss:4.5766 train_time:18269ms step_avg:160.25ms
step:125/1530 train_loss:4.4369 train_time:18432ms step_avg:160.28ms
step:125/1530 val_loss:4.3846 train_time:18479ms step_avg:160.69ms
step:126/1530 train_loss:4.3967 train_time:18599ms step_avg:160.34ms
step:127/1530 train_loss:4.4196 train_time:18767ms step_avg:160.40ms
step:128/1530 train_loss:4.3731 train_time:18930ms step_avg:160.42ms
step:129/1530 train_loss:4.6663 train_time:19095ms step_avg:160.46ms
step:130/1530 train_loss:4.3577 train_time:19261ms step_avg:160.51ms
step:131/1530 train_loss:4.4040 train_time:19425ms step_avg:160.54ms
step:132/1530 train_loss:4.3380 train_time:19588ms step_avg:160.56ms
step:133/1530 train_loss:4.4403 train_time:19753ms step_avg:160.59ms
step:134/1530 train_loss:4.2630 train_time:19916ms step_avg:160.61ms
step:135/1530 train_loss:4.4398 train_time:20080ms step_avg:160.64ms
step:136/1530 train_loss:4.2051 train_time:20244ms step_avg:160.67ms
step:137/1530 train_loss:4.3692 train_time:20408ms step_avg:160.69ms
step:138/1530 train_loss:4.2874 train_time:20571ms step_avg:160.71ms
step:139/1530 train_loss:4.3754 train_time:20736ms step_avg:160.74ms
step:140/1530 train_loss:4.4704 train_time:20901ms step_avg:160.78ms
step:141/1530 train_loss:4.3072 train_time:21066ms step_avg:160.81ms
step:142/1530 train_loss:4.3003 train_time:21229ms step_avg:160.83ms
step:143/1530 train_loss:4.2576 train_time:21393ms step_avg:160.85ms
step:144/1530 train_loss:4.3520 train_time:21556ms step_avg:160.86ms
step:145/1530 train_loss:4.2977 train_time:21720ms step_avg:160.89ms
step:146/1530 train_loss:4.1693 train_time:21884ms step_avg:160.91ms
step:147/1530 train_loss:4.3227 train_time:22047ms step_avg:160.93ms
step:148/1530 train_loss:4.3587 train_time:22210ms step_avg:160.94ms
step:149/1530 train_loss:4.2975 train_time:22373ms step_avg:160.96ms
step:150/1530 train_loss:4.4318 train_time:22536ms step_avg:160.97ms
step:151/1530 train_loss:4.2592 train_time:22702ms step_avg:161.00ms
step:152/1530 train_loss:4.2585 train_time:22866ms step_avg:161.03ms
step:153/1530 train_loss:4.3607 train_time:23029ms step_avg:161.04ms
step:154/1530 train_loss:4.3630 train_time:23194ms step_avg:161.07ms
step:155/1530 train_loss:4.2546 train_time:23358ms step_avg:161.09ms
step:156/1530 train_loss:4.3404 train_time:23522ms step_avg:161.11ms
step:157/1530 train_loss:4.4012 train_time:23685ms step_avg:161.12ms
step:158/1530 train_loss:4.2447 train_time:23849ms step_avg:161.14ms
step:159/1530 train_loss:4.2979 train_time:24012ms step_avg:161.15ms
step:160/1530 train_loss:4.1254 train_time:24175ms step_avg:161.17ms
step:161/1530 train_loss:4.3445 train_time:24340ms step_avg:161.19ms
step:162/1530 train_loss:4.3591 train_time:24504ms step_avg:161.21ms
step:163/1530 train_loss:4.3401 train_time:24669ms step_avg:161.24ms
step:164/1530 train_loss:4.1875 train_time:24833ms step_avg:161.26ms
step:165/1530 train_loss:4.2801 train_time:24998ms step_avg:161.28ms
step:166/1530 train_loss:4.3407 train_time:25163ms step_avg:161.30ms
step:167/1530 train_loss:4.1992 train_time:25326ms step_avg:161.31ms
step:168/1530 train_loss:4.2770 train_time:25490ms step_avg:161.33ms
step:169/1530 train_loss:4.1533 train_time:25655ms step_avg:161.35ms
step:170/1530 train_loss:4.0173 train_time:25819ms step_avg:161.37ms
step:171/1530 train_loss:4.1946 train_time:25981ms step_avg:161.37ms
step:172/1530 train_loss:4.2070 train_time:26144ms step_avg:161.39ms
step:173/1530 train_loss:4.2673 train_time:26307ms step_avg:161.39ms
step:174/1530 train_loss:4.4161 train_time:26469ms step_avg:161.40ms
step:175/1530 train_loss:4.2442 train_time:26633ms step_avg:161.41ms
step:176/1530 train_loss:4.0916 train_time:26796ms step_avg:161.42ms
step:177/1530 train_loss:4.0610 train_time:26960ms step_avg:161.44ms
step:178/1530 train_loss:4.1835 train_time:27123ms step_avg:161.45ms
step:179/1530 train_loss:4.1218 train_time:27286ms step_avg:161.45ms
step:180/1530 train_loss:4.1070 train_time:27449ms step_avg:161.46ms
step:181/1530 train_loss:4.2940 train_time:27611ms step_avg:161.47ms
step:182/1530 train_loss:4.1520 train_time:27773ms step_avg:161.47ms
step:183/1530 train_loss:4.1251 train_time:27938ms step_avg:161.49ms
step:184/1530 train_loss:4.1148 train_time:28101ms step_avg:161.50ms
step:185/1530 train_loss:4.2031 train_time:28264ms step_avg:161.51ms
step:186/1530 train_loss:4.1682 train_time:28426ms step_avg:161.51ms
step:187/1530 train_loss:4.2338 train_time:28590ms step_avg:161.53ms
step:188/1530 train_loss:4.1702 train_time:28885ms step_avg:162.27ms
step:189/1530 train_loss:4.1059 train_time:29209ms step_avg:163.18ms
step:190/1530 train_loss:4.1935 train_time:29372ms step_avg:163.18ms
step:191/1530 train_loss:4.0716 train_time:29534ms step_avg:163.17ms
step:192/1530 train_loss:4.0258 train_time:29698ms step_avg:163.17ms
step:193/1530 train_loss:4.2489 train_time:29862ms step_avg:163.18ms
step:194/1530 train_loss:4.1691 train_time:30025ms step_avg:163.18ms
step:195/1530 train_loss:4.3511 train_time:30188ms step_avg:163.18ms
step:196/1530 train_loss:4.1796 train_time:30350ms step_avg:163.17ms
step:197/1530 train_loss:4.0370 train_time:30512ms step_avg:163.17ms
step:198/1530 train_loss:4.1766 train_time:30675ms step_avg:163.16ms
step:199/1530 train_loss:4.0351 train_time:30839ms step_avg:163.17ms
step:200/1530 train_loss:4.1080 train_time:31002ms step_avg:163.17ms
step:201/1530 train_loss:4.0135 train_time:31164ms step_avg:163.16ms
step:202/1530 train_loss:4.2597 train_time:31327ms step_avg:163.16ms
step:203/1530 train_loss:4.0655 train_time:31489ms step_avg:163.16ms
step:204/1530 train_loss:4.1902 train_time:31652ms step_avg:163.15ms
step:205/1530 train_loss:4.2406 train_time:31814ms step_avg:163.15ms
step:206/1530 train_loss:3.9479 train_time:31977ms step_avg:163.15ms
step:207/1530 train_loss:4.0864 train_time:32141ms step_avg:163.15ms
step:208/1530 train_loss:4.1000 train_time:32303ms step_avg:163.15ms
step:209/1530 train_loss:4.2384 train_time:32466ms step_avg:163.15ms
step:210/1530 train_loss:4.1747 train_time:32629ms step_avg:163.15ms
step:211/1530 train_loss:4.0612 train_time:32792ms step_avg:163.14ms
step:212/1530 train_loss:4.1259 train_time:32956ms step_avg:163.15ms
step:213/1530 train_loss:4.0523 train_time:33120ms step_avg:163.15ms
step:214/1530 train_loss:4.1236 train_time:33282ms step_avg:163.15ms
step:215/1530 train_loss:3.9809 train_time:33445ms step_avg:163.15ms
step:216/1530 train_loss:4.0027 train_time:33607ms step_avg:163.14ms
step:217/1530 train_loss:4.0111 train_time:33770ms step_avg:163.14ms
step:218/1530 train_loss:4.0884 train_time:33932ms step_avg:163.14ms
step:219/1530 train_loss:4.0805 train_time:34096ms step_avg:163.14ms
step:220/1530 train_loss:4.0805 train_time:34260ms step_avg:163.14ms
step:221/1530 train_loss:4.0905 train_time:34423ms step_avg:163.14ms
step:222/1530 train_loss:3.9911 train_time:34585ms step_avg:163.14ms
step:223/1530 train_loss:3.9949 train_time:34749ms step_avg:163.14ms
step:224/1530 train_loss:4.3047 train_time:34911ms step_avg:163.13ms
step:225/1530 train_loss:3.9339 train_time:35074ms step_avg:163.14ms
step:226/1530 train_loss:3.9892 train_time:35237ms step_avg:163.13ms
step:227/1530 train_loss:3.9762 train_time:35400ms step_avg:163.13ms
step:228/1530 train_loss:4.1465 train_time:35565ms step_avg:163.14ms
step:229/1530 train_loss:3.9286 train_time:35730ms step_avg:163.15ms
step:230/1530 train_loss:4.0360 train_time:35895ms step_avg:163.16ms
step:231/1530 train_loss:3.9031 train_time:36064ms step_avg:163.18ms
step:232/1530 train_loss:3.9767 train_time:36229ms step_avg:163.19ms
step:233/1530 train_loss:4.0925 train_time:36394ms step_avg:163.20ms
step:234/1530 train_loss:4.0274 train_time:36562ms step_avg:163.22ms
step:235/1530 train_loss:3.8989 train_time:36729ms step_avg:163.24ms
step:236/1530 train_loss:4.0801 train_time:36894ms step_avg:163.25ms
step:237/1530 train_loss:4.0769 train_time:37060ms step_avg:163.26ms
step:238/1530 train_loss:3.9423 train_time:37225ms step_avg:163.27ms
step:239/1530 train_loss:4.0775 train_time:37391ms step_avg:163.28ms
step:240/1530 train_loss:4.1131 train_time:37558ms step_avg:163.29ms
step:241/1530 train_loss:3.9611 train_time:37723ms step_avg:163.30ms
step:242/1530 train_loss:4.1399 train_time:37889ms step_avg:163.31ms
step:243/1530 train_loss:4.0074 train_time:38055ms step_avg:163.33ms
step:244/1530 train_loss:4.0782 train_time:38223ms step_avg:163.35ms
step:245/1530 train_loss:4.1402 train_time:38389ms step_avg:163.36ms
step:246/1530 train_loss:4.0608 train_time:38554ms step_avg:163.37ms
step:247/1530 train_loss:4.0028 train_time:38723ms step_avg:163.39ms
step:248/1530 train_loss:4.0931 train_time:38888ms step_avg:163.40ms
step:249/1530 train_loss:3.9191 train_time:39053ms step_avg:163.40ms
step:250/1530 train_loss:3.9809 train_time:39219ms step_avg:163.41ms
step:250/1530 val_loss:4.0102 train_time:39267ms step_avg:163.61ms
step:251/1530 train_loss:4.0834 train_time:39388ms step_avg:163.44ms
step:252/1530 train_loss:4.1614 train_time:39554ms step_avg:163.45ms
step:253/1530 train_loss:3.9375 train_time:39721ms step_avg:163.46ms
step:254/1530 train_loss:3.8776 train_time:39888ms step_avg:163.48ms
step:255/1530 train_loss:4.0747 train_time:40054ms step_avg:163.48ms
step:256/1530 train_loss:3.9951 train_time:40220ms step_avg:163.50ms
step:257/1530 train_loss:3.9874 train_time:40386ms step_avg:163.50ms
step:258/1530 train_loss:3.9878 train_time:40551ms step_avg:163.51ms
step:259/1530 train_loss:4.0364 train_time:40717ms step_avg:163.52ms
step:260/1530 train_loss:4.0661 train_time:40886ms step_avg:163.54ms
step:261/1530 train_loss:4.0300 train_time:41053ms step_avg:163.56ms
step:262/1530 train_loss:3.9974 train_time:41219ms step_avg:163.57ms
step:263/1530 train_loss:3.8970 train_time:41385ms step_avg:163.58ms
step:264/1530 train_loss:3.9855 train_time:41551ms step_avg:163.59ms
step:265/1530 train_loss:3.8699 train_time:41718ms step_avg:163.60ms
step:266/1530 train_loss:3.9175 train_time:41883ms step_avg:163.61ms
step:267/1530 train_loss:3.9257 train_time:42049ms step_avg:163.62ms
step:268/1530 train_loss:3.9573 train_time:42215ms step_avg:163.62ms
step:269/1530 train_loss:3.8553 train_time:42380ms step_avg:163.63ms
step:270/1530 train_loss:4.0960 train_time:42546ms step_avg:163.64ms
step:271/1530 train_loss:3.9683 train_time:42713ms step_avg:163.65ms
step:272/1530 train_loss:3.9323 train_time:42878ms step_avg:163.66ms
step:273/1530 train_loss:3.9458 train_time:43043ms step_avg:163.66ms
step:274/1530 train_loss:4.0381 train_time:43210ms step_avg:163.68ms
step:275/1530 train_loss:4.0583 train_time:43376ms step_avg:163.68ms
step:276/1530 train_loss:4.2298 train_time:43542ms step_avg:163.69ms
step:277/1530 train_loss:4.0372 train_time:43709ms step_avg:163.70ms
step:278/1530 train_loss:4.0920 train_time:43873ms step_avg:163.71ms
step:279/1530 train_loss:4.0050 train_time:44039ms step_avg:163.71ms
step:280/1530 train_loss:4.1826 train_time:44208ms step_avg:163.73ms
step:281/1530 train_loss:3.9704 train_time:44374ms step_avg:163.74ms
step:282/1530 train_loss:3.9368 train_time:44540ms step_avg:163.75ms
step:283/1530 train_loss:3.9085 train_time:44706ms step_avg:163.76ms
step:284/1530 train_loss:4.0426 train_time:44872ms step_avg:163.77ms
step:285/1530 train_loss:4.0613 train_time:45038ms step_avg:163.77ms
step:286/1530 train_loss:4.0900 train_time:45203ms step_avg:163.78ms
step:287/1530 train_loss:3.9184 train_time:45368ms step_avg:163.78ms
step:288/1530 train_loss:4.0175 train_time:45533ms step_avg:163.79ms
step:289/1530 train_loss:3.8646 train_time:45698ms step_avg:163.79ms
step:290/1530 train_loss:3.8625 train_time:45863ms step_avg:163.80ms
step:291/1530 train_loss:3.9090 train_time:46029ms step_avg:163.80ms
step:292/1530 train_loss:3.8691 train_time:46193ms step_avg:163.81ms
step:293/1530 train_loss:3.9071 train_time:46358ms step_avg:163.81ms
step:294/1530 train_loss:3.9422 train_time:46524ms step_avg:163.82ms
step:295/1530 train_loss:3.8393 train_time:46688ms step_avg:163.82ms
step:296/1530 train_loss:3.8597 train_time:46853ms step_avg:163.82ms
step:297/1530 train_loss:3.8660 train_time:47018ms step_avg:163.83ms
step:298/1530 train_loss:3.9691 train_time:47184ms step_avg:163.83ms
step:299/1530 train_loss:3.8266 train_time:47349ms step_avg:163.84ms
step:300/1530 train_loss:3.9718 train_time:47514ms step_avg:163.84ms
step:301/1530 train_loss:3.9647 train_time:47678ms step_avg:163.84ms
step:302/1530 train_loss:3.9374 train_time:47843ms step_avg:163.85ms
step:303/1530 train_loss:3.9838 train_time:48009ms step_avg:163.85ms
step:304/1530 train_loss:3.9695 train_time:48173ms step_avg:163.85ms
step:305/1530 train_loss:4.4543 train_time:48338ms step_avg:163.86ms
step:306/1530 train_loss:3.9365 train_time:48504ms step_avg:163.86ms
step:307/1530 train_loss:3.8366 train_time:48668ms step_avg:163.87ms
step:308/1530 train_loss:3.9820 train_time:48833ms step_avg:163.87ms
step:309/1530 train_loss:3.8642 train_time:48997ms step_avg:163.87ms
step:310/1530 train_loss:4.0891 train_time:49161ms step_avg:163.87ms
step:311/1530 train_loss:3.9238 train_time:49328ms step_avg:163.88ms
step:312/1530 train_loss:3.8615 train_time:49493ms step_avg:163.88ms
step:313/1530 train_loss:3.9430 train_time:49658ms step_avg:163.89ms
step:314/1530 train_loss:4.0654 train_time:49824ms step_avg:163.89ms
step:315/1530 train_loss:3.9431 train_time:49988ms step_avg:163.90ms
step:316/1530 train_loss:3.7950 train_time:50153ms step_avg:163.90ms
step:317/1530 train_loss:3.8732 train_time:50319ms step_avg:163.90ms
step:318/1530 train_loss:3.9190 train_time:50483ms step_avg:163.91ms
step:319/1530 train_loss:3.8828 train_time:50649ms step_avg:163.91ms
step:320/1530 train_loss:4.0169 train_time:50814ms step_avg:163.92ms
step:321/1530 train_loss:3.9631 train_time:50978ms step_avg:163.92ms
step:322/1530 train_loss:3.9383 train_time:51143ms step_avg:163.92ms
step:323/1530 train_loss:4.0124 train_time:51309ms step_avg:163.93ms
step:324/1530 train_loss:3.9493 train_time:51473ms step_avg:163.93ms
step:325/1530 train_loss:4.0269 train_time:51638ms step_avg:163.93ms
step:326/1530 train_loss:3.8902 train_time:51806ms step_avg:163.94ms
step:327/1530 train_loss:4.3940 train_time:51970ms step_avg:163.94ms
step:328/1530 train_loss:4.0702 train_time:52136ms step_avg:163.95ms
step:329/1530 train_loss:3.7934 train_time:52301ms step_avg:163.95ms
step:330/1530 train_loss:3.7543 train_time:52467ms step_avg:163.96ms
step:331/1530 train_loss:3.9733 train_time:52632ms step_avg:163.96ms
step:332/1530 train_loss:3.9121 train_time:52796ms step_avg:163.96ms
step:333/1530 train_loss:3.8895 train_time:52961ms step_avg:163.97ms
step:334/1530 train_loss:3.8393 train_time:53128ms step_avg:163.98ms
step:335/1530 train_loss:4.0138 train_time:53293ms step_avg:163.98ms
step:336/1530 train_loss:3.9527 train_time:53458ms step_avg:163.98ms
step:337/1530 train_loss:4.4280 train_time:53623ms step_avg:163.99ms
step:338/1530 train_loss:3.9429 train_time:53789ms step_avg:163.99ms
step:339/1530 train_loss:3.8712 train_time:53953ms step_avg:163.99ms
step:340/1530 train_loss:3.9297 train_time:54118ms step_avg:164.00ms
step:341/1530 train_loss:3.8585 train_time:54285ms step_avg:164.00ms
step:342/1530 train_loss:3.8156 train_time:54453ms step_avg:164.01ms
step:343/1530 train_loss:3.8380 train_time:54620ms step_avg:164.02ms
step:344/1530 train_loss:3.9944 train_time:54789ms step_avg:164.04ms
step:345/1530 train_loss:3.8194 train_time:54958ms step_avg:164.05ms
step:346/1530 train_loss:3.7704 train_time:55127ms step_avg:164.07ms
step:347/1530 train_loss:3.8022 train_time:55295ms step_avg:164.08ms
step:348/1530 train_loss:3.8601 train_time:55463ms step_avg:164.09ms
step:349/1530 train_loss:3.8315 train_time:55632ms step_avg:164.11ms
step:350/1530 train_loss:3.5683 train_time:55799ms step_avg:164.11ms
step:351/1530 train_loss:3.8266 train_time:55967ms step_avg:164.13ms
step:352/1530 train_loss:4.1802 train_time:56134ms step_avg:164.14ms
step:353/1530 train_loss:3.6566 train_time:56301ms step_avg:164.14ms
step:354/1530 train_loss:3.9294 train_time:56469ms step_avg:164.15ms
step:355/1530 train_loss:3.7894 train_time:56637ms step_avg:164.17ms
step:356/1530 train_loss:3.8861 train_time:56806ms step_avg:164.18ms
step:357/1530 train_loss:3.7558 train_time:56973ms step_avg:164.19ms
step:358/1530 train_loss:3.8600 train_time:57142ms step_avg:164.20ms
step:359/1530 train_loss:3.7919 train_time:57313ms step_avg:164.22ms
step:360/1530 train_loss:3.4346 train_time:57482ms step_avg:164.23ms
step:361/1530 train_loss:4.0236 train_time:57650ms step_avg:164.25ms
step:362/1530 train_loss:3.9197 train_time:57818ms step_avg:164.26ms
step:363/1530 train_loss:3.8412 train_time:57986ms step_avg:164.27ms
step:364/1530 train_loss:3.7469 train_time:58154ms step_avg:164.28ms
step:365/1530 train_loss:3.9169 train_time:58324ms step_avg:164.29ms
step:366/1530 train_loss:3.8626 train_time:58491ms step_avg:164.30ms
step:367/1530 train_loss:3.8627 train_time:58658ms step_avg:164.31ms
step:368/1530 train_loss:3.8562 train_time:58826ms step_avg:164.32ms
step:369/1530 train_loss:3.7490 train_time:58994ms step_avg:164.33ms
step:370/1530 train_loss:3.8784 train_time:59162ms step_avg:164.34ms
step:371/1530 train_loss:3.7351 train_time:59331ms step_avg:164.35ms
step:372/1530 train_loss:3.7022 train_time:59498ms step_avg:164.36ms
step:373/1530 train_loss:3.9077 train_time:59665ms step_avg:164.37ms
step:374/1530 train_loss:3.8299 train_time:59833ms step_avg:164.38ms
step:375/1530 train_loss:3.8068 train_time:60001ms step_avg:164.39ms
step:375/1530 val_loss:3.8271 train_time:60049ms step_avg:164.52ms