-
Notifications
You must be signed in to change notification settings - Fork 0
/
engine.py
291 lines (240 loc) · 10 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
import time
import numpy as np
import torch
import torch.nn.parallel
import torch.nn.functional as F
from utils import Bar, AverageMeter, accuracy
from torch.autograd import Variable
def train(args, labeled_trainloader, unlabeled_trainloader, model, optimizer, ema_optimizer, criterion, epoch, use_cuda):
    """Run one epoch of semi-supervised training on labeled + unlabeled batches.

    The loop follows a MixMatch-style recipe: guess soft labels for the
    unlabeled batch by averaging predictions over its two augmented views,
    sharpen them with temperature ``args.T``, mix labeled and unlabeled
    samples with mixup, and optimise ``Lx + w * Lu``.

    Args:
        args: Namespace providing ``train_iteration`` (steps per epoch),
            ``T`` (sharpening temperature) and ``alpha`` (Beta parameter).
        labeled_trainloader: Loader yielding ``(inputs, targets)``.
        unlabeled_trainloader: Loader yielding ``((view_1, view_2), _)``.
        model: Network being trained.
        optimizer: Optimizer stepped on the combined loss.
        ema_optimizer: Object whose ``step()`` is called after every optimizer
            step (presumably maintains an EMA of the weights -- confirm).
        criterion: Callable returning ``(Lx, Lu, w)`` given the labeled
            logits/targets, unlabeled logits/targets and a fractional epoch.
        epoch: Current epoch index, used to build the fractional epoch.
        use_cuda: When true, batches are moved to the GPU.

    Returns:
        Tuple ``(avg_loss, avg_loss_x, avg_loss_u)`` over the epoch.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_x = AverageMeter()
    losses_u = AverageMeter()
    ws = AverageMeter()
    end = time.time()

    bar = Bar('Training', max=args.train_iteration)
    labeled_train_iter = iter(labeled_trainloader)
    unlabeled_train_iter = iter(unlabeled_trainloader)

    # Width of the one-hot label vectors; hard-coded, so it must match the
    # number of outputs produced by ``model``.
    number_of_classes = 7

    model.train()
    for batch_idx in range(args.train_iteration):
        # The number of steps may exceed the data available (e.g. 1024 train
        # iterations > 250 labeled samples), so restart a loader once it is
        # exhausted. Only StopIteration is handled: any other error raised by
        # the data pipeline must surface instead of being silently retried.
        try:
            inputs_x, targets_x = next(labeled_train_iter)
        except StopIteration:
            labeled_train_iter = iter(labeled_trainloader)
            inputs_x, targets_x = next(labeled_train_iter)

        try:
            # augmented inputs K = 2
            (inputs_u, inputs_u2), _ = next(unlabeled_train_iter)
        except StopIteration:
            unlabeled_train_iter = iter(unlabeled_trainloader)
            (inputs_u, inputs_u2), _ = next(unlabeled_train_iter)

        # measure data loading time
        data_time.update(time.time() - end)

        batch_size = inputs_x.size(0)

        # Transform label to one-hot
        targets_x = torch.zeros(batch_size, number_of_classes).scatter_(
            1, targets_x.view(-1, 1).long(), 1)

        if use_cuda:
            inputs_x, targets_x = inputs_x.cuda(), targets_x.cuda(non_blocking=True)
            inputs_u = inputs_u.cuda()
            inputs_u2 = inputs_u2.cuda()

        with torch.no_grad():
            # Guess labels for the unlabeled samples: average the class
            # probabilities predicted for the two augmented views.
            outputs_u = model(inputs_u)
            outputs_u2 = model(inputs_u2)
            p = (torch.softmax(outputs_u, dim=1) + torch.softmax(outputs_u2, dim=1)) / 2
            # Sharpening with temperature T, then renormalise per sample
            pt = p ** (1 / args.T)
            targets_u = pt / pt.sum(dim=1, keepdim=True)
            targets_u = targets_u.detach()

        # mixup
        all_inputs = torch.cat([inputs_x, inputs_u, inputs_u2], dim=0)
        all_targets = torch.cat([targets_x, targets_u, targets_u], dim=0)

        # Choose lambda from a Beta distribution; taking max(lam, 1 - lam)
        # keeps each mixed sample closer to its un-permuted source.
        lam = np.random.beta(args.alpha, args.alpha)
        lam = max(lam, 1 - lam)

        idx = torch.randperm(all_inputs.size(0))

        # Pair every sample (labeled or unlabeled) with a random partner
        # while keeping the original ordering on the "a" side.
        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]

        mixed_input = lam * input_a + (1 - lam) * input_b
        mixed_target = lam * target_a + (1 - lam) * target_b

        # interleave labeled and unlabeled samples between batches to get
        # correct batchnorm calculation
        mixed_input = list(torch.split(mixed_input, batch_size))
        mixed_input = interleave(mixed_input, batch_size)

        logits = [model(chunk) for chunk in mixed_input]

        # put interleaved samples back
        logits = interleave(logits, batch_size)
        logits_x = logits[0]
        logits_u = torch.cat(logits[1:], dim=0)

        # The last argument is the fractional epoch reached so far.
        Lx, Lu, w = criterion(logits_x, mixed_target[:batch_size],
                              logits_u, mixed_target[batch_size:],
                              epoch + batch_idx / args.train_iteration)

        loss = Lx + w * Lu

        # record loss
        losses.update(loss.item(), batch_size)
        losses_x.update(Lx.item(), batch_size)
        losses_u.update(Lu.item(), batch_size)
        ws.update(w, batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema_optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | Loss_x: {loss_x:.4f} | Loss_u: {loss_u:.4f} | W: {w:.4f}'.format(
            batch=batch_idx + 1,
            size=args.train_iteration,
            data=data_time.avg,
            bt=batch_time.avg,
            total=bar.elapsed_td,
            eta=bar.eta_td,
            loss=losses.avg,
            loss_x=losses_x.avg,
            loss_u=losses_u.avg,
            w=ws.avg,
        )
        bar.next()
        if batch_idx % 100 == 0:
            print(bar.suffix)
    bar.finish()

    return (losses.avg, losses_x.avg, losses_u.avg,)
def train_supervised(args, trainloader, model, optimizer, criterion, mixup, use_cuda):
    """Run one epoch of fully supervised training, optionally with mixup.

    Args:
        args: Namespace providing ``alpha`` (Beta parameter used by mixup).
        trainloader: Loader yielding ``(inputs, targets)`` batches.
        model: Network being trained.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable taking ``(outputs, targets)``.
        mixup: When true, each batch is mixed with ``mixup_data`` and the loss
            is computed with ``mixup_criterion``.
        use_cuda: When true, batches are moved to the GPU; when false they are
            left on the device the loader produced them on.

    Returns:
        The average loss over the epoch.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    end = time.time()

    num_batches = len(trainloader)
    bar = Bar('Training', max=num_batches)
    model.train()

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        # measure data loading time
        data_time.update(time.time() - end)

        if mixup:
            # Move the batch to the GPU *before* mixing: mixup_data builds its
            # permutation index on the GPU when use_cuda is set, and indexing
            # a CPU tensor with a CUDA index is a device mismatch.
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)
            inputs, targets_a, targets_b, lam = mixup_data(inputs, targets,
                                                           args.alpha, use_cuda)
            # compute output
            outputs = model(inputs)
            # NOTE: a focal loss that expects one-hot targets would need
            # F.one_hot(targets_a / targets_b, num_classes=7).float() here.
            loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
        else:  # normal method
            targets = targets.long()
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)
            # compute output
            outputs = model(inputs)
            # NOTE: a focal loss that expects one-hot targets would need
            # F.one_hot(targets, num_classes=7).float() here.
            loss = criterion(outputs, targets)

        losses.update(loss.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # plot progress
        bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f}'.format(
            batch=batch_idx + 1,
            size=num_batches,
            data=data_time.avg,
            bt=batch_time.avg,
            total=bar.elapsed_td,
            eta=bar.eta_td,
            loss=losses.avg,
        )
        bar.next()
        if batch_idx % 100 == 0:
            print(bar.suffix)
    bar.finish()

    return losses.avg
def validate(valloader, model, criterion, epoch, use_cuda, mode):
    """Evaluate ``model`` on ``valloader`` without tracking gradients.

    Args:
        valloader: Loader yielding ``(inputs, targets)`` batches.
        model: Network to evaluate; it is switched to eval mode.
        criterion: Loss callable taking ``(outputs, targets)``.
        epoch: Unused here; kept so existing callers keep working.
        use_cuda: When true, batches are moved to the GPU.
        mode: Label shown as the progress-bar title.

    Returns:
        Tuple ``(avg_loss, avg_top1)``. Top-5 accuracy is tracked and shown in
        the progress line but is not returned.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    num_batches = len(valloader)
    bar = Bar(f'{mode}', max=num_batches)
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(valloader):
            # measure data loading time
            data_time.update(time.time() - end)

            # Cast labels to long on every path, not only when running on the
            # GPU, so CPU evaluation hands the criterion the same dtype.
            targets = targets.long()
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)

            # compute output
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(outputs, targets, topk=(1, 5))
            n_samples = inputs.size(0)
            losses.update(loss.item(), n_samples)
            top1.update(prec1.item(), n_samples)
            top5.update(prec5.item(), n_samples)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                batch=batch_idx + 1,
                size=num_batches,
                data=data_time.avg,
                bt=batch_time.avg,
                total=bar.elapsed_td,
                eta=bar.eta_td,
                loss=losses.avg,
                top1=top1.avg,
                top5=top5.avg,
            )
            bar.next()
            # Echo the progress line periodically and on the final batch.
            if batch_idx % 50 == 0 or batch_idx == num_batches - 1:
                print(bar.suffix)
        bar.finish()
    return losses.avg, top1.avg
def interleave_offsets(batch, nu):
    """Return the boundaries that cut ``batch`` items into ``nu + 1`` groups.

    Every group receives ``batch // (nu + 1)`` items; whatever is left over is
    handed out one item at a time to the trailing groups, last group first.
    The result is a list of ``nu + 2`` cumulative offsets that starts at 0 and
    ends at ``batch``.
    """
    n_groups = nu + 1
    sizes = [batch // n_groups] * n_groups

    # Spread the remainder over the groups at the end of the list.
    leftover = batch - sum(sizes)
    for back in range(1, leftover + 1):
        sizes[-back] += 1

    # Turn the group sizes into cumulative boundaries.
    offsets = [0]
    running = 0
    for size in sizes:
        running += size
        offsets.append(running)

    assert offsets[-1] == batch
    return offsets
def interleave(xy, batch):
    """Exchange slices between the first tensor in ``xy`` and each other one.

    Every tensor in ``xy`` is cut along dim 0 at the boundaries given by
    ``interleave_offsets(batch, len(xy) - 1)``. For each ``k >= 1`` the k-th
    slice of the first tensor is swapped with the k-th slice of the k-th
    tensor, and the slices of each tensor are concatenated back together.
    Returns a new list of tensors; the input list is not modified.
    """
    last = len(xy) - 1
    bounds = interleave_offsets(batch, last)
    spans = list(zip(bounds[:-1], bounds[1:]))

    # One row of slices per input tensor.
    chunked = [[tensor[lo:hi] for lo, hi in spans] for tensor in xy]

    head = chunked[0]
    for k in range(1, last + 1):
        other = chunked[k]
        head[k], other[k] = other[k], head[k]

    return [torch.cat(pieces, dim=0) for pieces in chunked]
def mixup_data(x, y, alpha=1.0, use_cuda=True):
    """Return mixed inputs, pairs of targets, and lambda.

    Args:
        x: Input batch; samples are indexed along dim 0.
        y: Targets aligned with ``x`` along dim 0.
        alpha: Parameter of the ``Beta(alpha, alpha)`` distribution lambda is
            drawn from. With ``alpha <= 0`` lambda is fixed to 1 (no mixing).
        use_cuda: When true *and* ``x`` lives on a CUDA device, the permutation
            index is moved to that device. For CPU inputs the index stays on
            the CPU, so the default value also works on hosts without a GPU.

    Returns:
        Tuple ``(mixed_x, y_a, y_b, lam)`` where
        ``mixed_x = lam * x + (1 - lam) * x[perm]``, ``y_a`` is ``y`` and
        ``y_b`` is ``y[perm]``.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    batch_size = x.size(0)
    index = torch.randperm(batch_size)
    # Only place the index on the GPU when the data is actually there;
    # indexing a CPU tensor with a CUDA index is a device mismatch.
    if use_cuda and x.is_cuda:
        index = index.to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Blend the loss against both target sets using the mixup weight.

    ``criterion`` is evaluated on ``(pred, y_a)`` and then on ``(pred, y_b)``;
    the two results are combined as ``lam * loss_a + (1 - lam) * loss_b``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b