train.py (forked from karpathy/nanoGPT)
"""
This training script can be run both on a single gpu in debug mode,
and also in a larger training run with distributed data parallel (ddp).
To run on a single GPU, example:
$ python train.py --batch_size=32 --compile=False
To run with DDP on 4 gpus on 1 node, example:
$ torchrun --standalone --nproc_per_node=4 train.py
To run with DDP on 4 gpus across 2 nodes, example:
- Run on the first (master) node with example IP 123.456.123.456:
$ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py
- Run on the worker node:
$ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py
(If your cluster does not have Infiniband interconnect prepend NCCL_IB_DISABLE=1)
"""
import yaml

from trainers.trainer import Trainer
# load the training configuration from config/config.yaml
with open('config/config.yaml') as f:
    conf = yaml.load(f, Loader=yaml.FullLoader)
# flatten the nested YAML sections into a single flat config dict
config = {}
for k, v in conf.items():
    for k2, v2 in v.items():
        config[k2] = v2
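# For illustration only: a hypothetical config/config.yaml with the nested
# section -> key layout the loop above expects (these section and key names
# are assumptions, not the actual contents of the file):
#
#   model:
#     n_layer: 12
#     n_head: 12
#   training:
#     batch_size: 32
#     learning_rate: 3.0e-4
#
# which flattens to:
#   {'n_layer': 12, 'n_head': 12, 'batch_size': 32, 'learning_rate': 0.0003}
#
# Note that section names are discarded, so a key appearing in more than one
# section would silently overwrite the earlier value.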
# print the flattened config so the run's hyperparameters are visible in the logs
print(config)
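# hand the flat config to the Trainer and start the training loop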
trainer = Trainer(config)
trainer.train()