train.py
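"""Train RetinaNet with focal loss on PASCAL VOC image lists.

The script fine-tunes a pre-trained RetinaNet (loaded from ./model/net.pth),
evaluates on the validation list after every epoch, and keeps the checkpoint
with the lowest validation loss in ./checkpoint/ckpt.pth. Adjust the dataset
root and list-file paths below to your own VOC images and lists.

Typical invocations:
    python train.py                # uses the default --lr 1e-3
    python train.py --resume       # continue from ./checkpoint/ckpt.pth
"""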
from __future__ import print_function
import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from loss import FocalLoss
from retinanet import RetinaNet
from datagen import ListDataset
from torch.autograd import Variable
parser = argparse.ArgumentParser(description='PyTorch RetinaNet Training')
parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
args = parser.parse_args()
assert torch.cuda.is_available(), 'Error: CUDA not found!'
best_loss = float('inf') # best test loss
start_epoch = 0 # start from epoch 0 or last epoch
# Data
print('==> Preparing data..')
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225))
])

trainset = ListDataset(root='/search/odin/liukuang/data/voc_all_images',
                       list_file='./data/voc12_train.txt', train=True, transform=transform, input_size=600)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True, num_workers=8, collate_fn=trainset.collate_fn)

testset = ListDataset(root='/search/odin/liukuang/data/voc_all_images',
                      list_file='./data/voc12_val.txt', train=False, transform=transform, input_size=600)
testloader = torch.utils.data.DataLoader(testset, batch_size=16, shuffle=False, num_workers=8, collate_fn=testset.collate_fn)
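# Note: ListDataset supplies the collate_fn used by both loaders above; it is
# expected to batch variable-size images and their encoded box/label targets
# into (inputs, loc_targets, cls_targets) tensors -- see datagen.py.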
# Model
net = RetinaNet()
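# Pre-trained weights are expected at ./model/net.pth (prepared separately,
# e.g. from an ImageNet-pretrained backbone); torch.load fails fast here if
# the file is missing.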
net.load_state_dict(torch.load('./model/net.pth'))
if args.resume:
    print('==> Resuming from checkpoint..')
    checkpoint = torch.load('./checkpoint/ckpt.pth')
    net.load_state_dict(checkpoint['net'])
    best_loss = checkpoint['loss']
    start_epoch = checkpoint['epoch']
net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
net.cuda()
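# FocalLoss implements the RetinaNet training objective (Lin et al., 2017):
# FL(p_t) = -alpha * (1 - p_t)**gamma * log(p_t), which down-weights easy
# (mostly background) anchors so training is not swamped by them; see loss.py
# for the exact alpha/gamma used here.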
criterion = FocalLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4)
# Training
def train(epoch):
    print('\nEpoch: %d' % epoch)
    net.train()
    net.module.freeze_bn()  # keep BatchNorm layers frozen (running stats fixed) while fine-tuning
    train_loss = 0
    for batch_idx, (inputs, loc_targets, cls_targets) in enumerate(trainloader):
        inputs = Variable(inputs.cuda())
        loc_targets = Variable(loc_targets.cuda())
        cls_targets = Variable(cls_targets.cuda())

        optimizer.zero_grad()
        loc_preds, cls_preds = net(inputs)
        loss = criterion(loc_preds, loc_targets, cls_preds, cls_targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.data[0]  # pre-0.4 PyTorch scalar access; loss.item() in later versions
        print('train_loss: %.3f | avg_loss: %.3f' % (loss.data[0], train_loss/(batch_idx+1)))
# Test
def test(epoch):
    print('\nTest')
    net.eval()
    test_loss = 0
    for batch_idx, (inputs, loc_targets, cls_targets) in enumerate(testloader):
        inputs = Variable(inputs.cuda(), volatile=True)  # volatile: no autograd graph (pre-0.4 API)
        loc_targets = Variable(loc_targets.cuda())
        cls_targets = Variable(cls_targets.cuda())

        loc_preds, cls_preds = net(inputs)
        loss = criterion(loc_preds, loc_targets, cls_preds, cls_targets)
        test_loss += loss.data[0]
        print('test_loss: %.3f | avg_loss: %.3f' % (loss.data[0], test_loss/(batch_idx+1)))

    # Save checkpoint when the average validation loss improves
    global best_loss
    test_loss /= len(testloader)
    if test_loss < best_loss:
        print('Saving..')
        state = {
            'net': net.module.state_dict(),
            'loss': test_loss,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.pth')
        best_loss = test_loss
for epoch in range(start_epoch, start_epoch+200):
    train(epoch)
    test(epoch)