-
Notifications
You must be signed in to change notification settings - Fork 4
/
train_vision.py
110 lines (85 loc) · 3.21 KB
/
train_vision.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import argparse
import logging
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from configs import configs
from input_pipeline import get_loaders
from models import VisiongMLP
logger = logging.getLogger(__name__)
def train(args, model, train_loader):
    """Run the full training loop for *model* over *train_loader*.

    Args:
        args: Namespace carrying ``device``, ``n_gpu``, ``num_train_epochs``,
            ``per_gpu_train_batch_size``, ``logging_steps`` and (optionally)
            ``lr``; ``train_batch_size`` is derived and written back onto it.
        model: Network to optimize. Assumed to already live on ``args.device``.
        train_loader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        int: Total number of optimizer steps performed (``global_step``).
    """
    tb_writer = SummaryWriter()
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    criterion = nn.CrossEntropyLoss()
    # FIX: honor the --lr command-line flag instead of a hard-coded 3e-4.
    # getattr keeps backward compatibility with callers whose args lack `lr`.
    optimizer = optim.Adam(model.parameters(), lr=getattr(args, "lr", 3e-4))

    global_step = 1
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    for epoch in range(args.num_train_epochs):
        loop = tqdm(train_loader)
        losses = []
        for images, labels in loop:
            images = images.to(args.device)
            # FIX: labels must be on the same device as the model's output,
            # otherwise CrossEntropyLoss raises on CUDA runs.
            labels = labels.to(args.device)
            preds = model(images)
            loss = criterion(preds, labels)
            if args.n_gpu > 1:
                # DataParallel returns one loss per GPU; reduce to a scalar.
                loss = loss.mean()
            loss.backward()
            tr_loss += loss.item()
            optimizer.step()
            model.zero_grad()
            global_step += 1
            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                # Log the mean loss over the last `logging_steps` updates.
                tb_writer.add_scalar(
                    "loss", (tr_loss - logging_loss) / args.logging_steps, global_step
                )
                logging_loss = tr_loss
            losses.append(loss.item())
            loop.set_description(
                f"Epoch: {epoch+1}/{args.num_train_epochs} | Epoch loss: {sum(losses)/len(losses)}"
            )
    tb_writer.close()
    return global_step
def main():
    """Parse command-line arguments, build the model, and launch training."""
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=False,
        help="Model type selected in the list: " + ", ",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=False,
        help="The output directory where the model checkpoints will be saved",
    )
    # Other parameters
    parser.add_argument("--lr", default=3e-4, type=float, help="Learning rate.")
    parser.add_argument(
        "--num_train_epochs", default=10, type=int, help="Total number of epochs for training."
    )
    parser.add_argument(
        "--per_gpu_train_batch_size", default=4, type=int, help="Batch size per GPU/CPU when training."
    )
    parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available.")
    parser.add_argument("--logging_steps", default=100, type=int, help="Log every X updates steps.")
    args = parser.parse_args()

    # FIX: the old code raised NotImplementedError unless --no_cuda was set,
    # which made GPU training impossible and forced CPU-only runs. The device
    # expression below already covers both cases, so compute it unconditionally.
    args.device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
    )
    args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()

    # TODO load model with user command line selected config
    model = VisiongMLP(**configs["Ti"], prob_0_L=[1, 0.5]).to(args.device)
    train_loader, eval_loader, test_loader = get_loaders(args.per_gpu_train_batch_size, eval_split=0.15)
    print("*** Training ***")
    train(args, model, train_loader)
if __name__ == "__main__":
main()