-
Notifications
You must be signed in to change notification settings - Fork 16
/
engine.py
188 lines (145 loc) · 7.51 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
"""
Train and eval functions used in main.py
"""
import math
import sys
from typing import Iterable, Optional
import torch
from timm.data import Mixup
from timm.utils import accuracy, ModelEma
import kornia as K
from losses import DistillationLoss
import utils
import torch.nn as nn
import torch.nn.functional as F
def clamp(X, lower_limit, upper_limit):
    """Clamp ``X`` element-wise into ``[lower_limit, upper_limit]``.

    The bounds are tensors that broadcast against ``X``. Callers in this file
    pass per-channel limits of shape (3, 1, 1) as well as full-size tensors.
    """
    capped_above = torch.minimum(X, upper_limit)
    return torch.maximum(capped_above, lower_limit)
def PGDAttack(x, y, model, attack_epsilon, attack_alpha, lower_limit, loss_fn, upper_limit, max_iters, random_init):
    """Craft L-inf PGD adversarial examples for ``x``.

    Each step moves the images by ``attack_alpha * sign(grad)`` of ``loss_fn``.
    It then projects them back into the ``attack_epsilon`` box around ``x`` and
    into ``[lower_limit, upper_limit]``. A single step without random start is
    the FGSM attack.

    Args:
        x: clean input batch. The attack runs on ``x``'s device.
        y: targets passed to ``loss_fn``.
        model: model under attack. It is switched to eval mode for the attack,
            and its previous train/eval mode is restored afterwards.
        attack_epsilon: perturbation budget. Either a scalar or a tensor that
            broadcasts against ``x`` (callers in this file pass (3, 1, 1) tensors).
        attack_alpha: step size, with the same broadcasting rules as
            ``attack_epsilon``.
        lower_limit: lower bound of the valid input range, passed to ``clamp``.
        loss_fn: callable ``loss_fn(outputs, y)`` returning the scalar loss to
            ascend.
        upper_limit: upper bound of the valid input range, passed to ``clamp``.
        max_iters: number of gradient steps (converted with ``int``).
        random_init: if True, start from a uniform random point in the epsilon box.

    Returns:
        Detached tensor of adversarial inputs with the same shape as ``x``.
    """
    was_training = model.training
    model.eval()
    x = x.detach()

    if random_init:
        # Scaling a U(-1, 1) draw by epsilon gives a uniform start in
        # [-eps, eps]. Broadcasting handles per-channel budgets and scalars alike.
        delta = torch.empty_like(x).uniform_(-1.0, 1.0) * attack_epsilon
    else:
        delta = torch.zeros_like(x)
    adv_imgs = clamp(x + delta, lower_limit, upper_limit)

    try:
        with torch.enable_grad():
            for _ in range(int(max_iters)):
                # Start each step from a fresh leaf, so the projection ops of
                # earlier steps are not kept alive in the autograd graph.
                adv_imgs = adv_imgs.detach().requires_grad_(True)
                loss = loss_fn(model(adv_imgs), y)
                (grads,) = torch.autograd.grad(loss, adv_imgs)
                adv_imgs = adv_imgs.detach() + attack_alpha * grads.sign()
                # Project into the epsilon box around x, then into the valid range.
                adv_imgs = clamp(adv_imgs, x - attack_epsilon, x + attack_epsilon)
                adv_imgs = clamp(adv_imgs, lower_limit, upper_limit)
    finally:
        model.train(was_training)
    return adv_imgs.detach()
def patch_level_aug(input1, patch_transform, upper_limit, lower_limit, patch_size=16):
    """Apply ``patch_transform`` independently to each non-overlapping image patch.

    The images are cut into ``patch_size`` x ``patch_size`` tiles. All tiles are
    passed through ``patch_transform`` as one batch and stitched back in place,
    and the result is clamped to ``[lower_limit, upper_limit]``.

    Args:
        input1: image batch of shape (B, C, H, W).
        patch_transform: callable mapping a (N, C, patch_size, patch_size) batch
            of tiles to a batch of the same shape.
        upper_limit: upper bound passed to ``clamp``. It may be a tensor that
            broadcasts against the images, e.g. per-channel (C, 1, 1) limits.
        lower_limit: lower bound passed to ``clamp``, with the same rules as
            ``upper_limit``.
        patch_size: tile side length. Defaults to 16, the previously
            hard-coded value.

    Returns:
        Tensor of shape (B, C, H, W). If H or W is not a multiple of
        ``patch_size``, the trailing rows/columns that do not fill a whole tile
        are copied from the input unchanged. ``F.fold`` with output size (H, W)
        would otherwise leave them at zero.
    """
    bs, channels, height, width = input1.shape
    covered_h = (height // patch_size) * patch_size
    covered_w = (width // patch_size) * patch_size

    # (B, C, H, W) -> (B, C, rows, cols, p, p) -> (B, rows, cols, C, p, p)
    # -> (B*rows*cols, C, p, p)
    patches = (input1.unfold(2, patch_size, patch_size)
               .unfold(3, patch_size, patch_size)
               .permute(0, 2, 3, 1, 4, 5)
               .contiguous()
               .reshape(-1, channels, patch_size, patch_size))
    patches = patch_transform(patches)
    # Back to F.fold's (B, C*p*p, L) layout, with the L = rows*cols tiles in
    # row-major order.
    patches = (patches.reshape(bs, -1, channels, patch_size, patch_size)
               .permute(0, 2, 3, 4, 1)
               .contiguous()
               .reshape(bs, channels * patch_size * patch_size, -1))
    stitched = F.fold(patches, (covered_h, covered_w), patch_size, stride=patch_size)

    if covered_h == height and covered_w == width:
        output_images = stitched
    else:
        # Keep the partial-tile border from the input instead of zeroing it.
        output_images = input1.to(stitched.dtype, copy=True)
        output_images[:, :, :covered_h, :covered_w] = stitched
    return clamp(output_images, lower_limit, upper_limit)
def train_one_epoch(args, model: torch.nn.Module, criterion: DistillationLoss,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    set_training_mode=True):
    """Run one training epoch and return the epoch-averaged metrics.

    For every batch, this function:

    * moves the batch to ``device`` and optionally applies ``mixup_fn``;
    * when ``args.use_patch_aug`` is set, back-propagates the loss of a
      patch-level augmented copy of the batch, scaled by
      ``loss_scaler._scaler``;
    * passes the clean-batch loss to ``loss_scaler(...)``.

    Gradients are cleared once per batch, before either backward pass. If
    ``loss_scaler`` does not clear them itself (assumed here, as with timm's
    NativeScaler — confirm), the step uses the gradients of both losses.

    Args:
        args: namespace read for ``use_patch_aug``.
        model: model to train. ``model.train(set_training_mode)`` is applied first.
        criterion: called as ``criterion(inputs, outputs, targets)``.
        data_loader: iterable of ``(samples, targets)`` batches.
        optimizer: optimizer whose grads are zeroed each batch. It is also read
            for ``param_groups[0]["lr"]`` and the optional ``is_second_order``
            flag.
        device: device the batches are moved to.
        epoch: epoch index, used only in the log header.
        loss_scaler: callable performing the scaled backward/step. It must also
            expose ``_scaler`` when patch augmentation is on.
        max_norm: forwarded to ``loss_scaler`` as ``clip_grad``.
        model_ema: if given, updated from ``model`` after every step.
        mixup_fn: optional batch mixing function.
        set_training_mode: mode passed to ``model.train``.

    Returns:
        Dict mapping each meter name in the metric logger to its global average.

    Raises:
        SystemExit: via ``sys.exit(1)`` when the clean-batch loss is not finite.
    """
    model.train(set_training_mode)
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    # Valid range of ImageNet-normalized inputs (pixels in [0, 1]).
    std_imagenet = torch.tensor((0.229, 0.224, 0.225)).view(3, 1, 1).to(device)
    mu_imagenet = torch.tensor((0.485, 0.456, 0.406)).view(3, 1, 1).to(device)
    upper_limit = ((1 - mu_imagenet) / std_imagenet)
    lower_limit = ((0 - mu_imagenet) / std_imagenet)

    # Loop-invariant setup, built once instead of on every batch.
    patch_transform = None
    if args.use_patch_aug:
        patch_transform = nn.Sequential(
            K.augmentation.RandomResizedCrop(size=(16,16), scale=(0.85,1.0), ratio=(1.0,1.0), p=0.1),
            K.augmentation.RandomGaussianNoise(mean=0., std=0.01, p=0.1),
            K.augmentation.RandomHorizontalFlip(p=0.1)
        )
    # this attribute is added by timm on one optimizer (adahessian)
    is_second_order = getattr(optimizer, 'is_second_order', False)

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        # Must run before any backward pass of this step. The patch-augmented
        # backward below accumulates into .grad, and zeroing after it would
        # discard those gradients.
        optimizer.zero_grad()

        if patch_transform is not None:
            aug_samples = patch_level_aug(samples, patch_transform, upper_limit, lower_limit)
            with torch.cuda.amp.autocast():
                outputs2 = model(aug_samples)
                aug_loss = criterion(aug_samples, outputs2, targets)
            # Backward runs outside autocast, as PyTorch AMP recommends. The loss is
            # scaled with loss_scaler's own GradScaler so that the accumulated
            # gradients share the scale applied below.
            loss_scaler._scaler.scale(aug_loss).backward(create_graph=is_second_order)

        with torch.cuda.amp.autocast():
            outputs = model(samples)
            loss = criterion(samples, outputs, targets)

        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        loss_scaler(loss, optimizer, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=is_second_order)

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
@torch.no_grad()
def evaluate(data_loader, model, device, mask=None, adv=None):
    """Evaluate ``model`` on ``data_loader``, optionally under an adversarial attack.

    Args:
        data_loader: iterable of ``(images, target)`` batches.
        model: classifier to evaluate. It is switched to eval mode.
        device: device the batches and attack constants are moved to.
        mask: optional index applied to the class dimension of the logits before
            computing top-1/top-5 accuracy. The loss uses the unmasked logits.
        adv: falsy for clean evaluation, ``'FGSM'`` for one 1/255 step, or
            ``'PGD'`` for 5 steps of 0.5/255 with a random start. Both attacks
            use an L-inf budget of 1/255 in [0, 1] pixel units, converted to the
            ImageNet-normalized input space.

    Returns:
        Dict mapping each meter name in the metric logger to its global average.

    Raises:
        ValueError: if ``adv`` is truthy but is not a supported attack name.
    """
    criterion = torch.nn.CrossEntropyLoss()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'

    # name -> (epsilon, step size, number of steps, random start), in pixel units
    attack_settings = {
        'FGSM': (1 / 255., 1 / 255., 1, False),
        'PGD': (1 / 255., 0.5 / 255., 5, True),
    }
    # Fail fast: an unknown name would otherwise leave no attacked input to score.
    if adv and adv not in attack_settings:
        raise ValueError("Unsupported adversarial attack {!r}; expected one of {}".format(
            adv, sorted(attack_settings)))

    if adv:
        # Batch-independent constants, built once on the evaluation device.
        std_imagenet = torch.tensor((0.229, 0.224, 0.225)).view(3, 1, 1).to(device)
        mu_imagenet = torch.tensor((0.485, 0.456, 0.406)).view(3, 1, 1).to(device)
        eps_px, alpha_px, attack_iters, attack_random_init = attack_settings[adv]
        attack_epsilon = eps_px / std_imagenet
        attack_alpha = alpha_px / std_imagenet
        upper_limit = ((1 - mu_imagenet) / std_imagenet)
        lower_limit = ((0 - mu_imagenet) / std_imagenet)

    # switch to evaluation mode
    model.eval()

    for images, target in metric_logger.log_every(data_loader, 10, header):
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        if adv:
            # PGDAttack enables grad internally despite the @torch.no_grad() above.
            model_input = PGDAttack(images, target, model, attack_epsilon, attack_alpha,
                                    lower_limit, criterion, upper_limit,
                                    max_iters=attack_iters, random_init=attack_random_init)
        else:
            model_input = images

        # compute output
        with torch.cuda.amp.autocast():
            output = model(model_input)
            loss = criterion(output, target)

        if mask is None:
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
        else:
            acc1, acc5 = accuracy(output[:,mask], target, topk=(1, 5))

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}