-
Notifications
You must be signed in to change notification settings - Fork 43
/
test.py
112 lines (87 loc) · 3.75 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import argparse
import os
import time
import torch
import torch.backends.cudnn as cudnn
import torchvision.datasets as datasets
from utils import accuracy, ProgressMeter, AverageMeter, val_preprocess
from convnet_utils import switch_deploy_flag, switch_conv_bn_impl, build_model
# Command-line interface for the evaluation script.
# The three positional arguments (dataset root, model mode, weights path) are
# all required; the remaining options have defaults.
parser = argparse.ArgumentParser(description='PyTorch ImageNet Test')
parser.add_argument('data', metavar='DIR', help='path to dataset')
# 'mode' is a required positional, so a `default=` would never take effect;
# the value always comes from the command line and must be one of `choices`.
parser.add_argument('mode', metavar='MODE', choices=['train', 'deploy'],
                    help='train or deploy')
parser.add_argument('weights', metavar='WEIGHTS', help='path to the weights file')
parser.add_argument('-a', '--arch', metavar='ARCH', default='ResNet-18',
                    help='model architecture (default: %(default)s)')
parser.add_argument('-t', '--blocktype', metavar='BLK', default='DBB',
                    choices=['DBB', 'ACB', 'base'],
                    help='block type (default: %(default)s)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('-b', '--batch-size', default=100, type=int,
                    metavar='N',
                    help='mini-batch size (default: 100) for test')
def test():
    """Build the model, load its weights and evaluate it on ``<data>/val``.

    Parses the command line with the module-level ``parser``. Before building
    ``args.arch`` with ``build_model``, it configures the block implementation
    via ``switch_deploy_flag`` / ``switch_conv_bn_impl``. It then loads
    ``args.weights`` and runs ``validate`` over an ``ImageFolder`` rooted at
    ``<data>/val``, preprocessed with ``val_preprocess(224)``.

    When CUDA is unavailable, everything runs on the CPU after a warning is
    printed. If ``args.weights`` is neither routed to the HDF5 loader nor an
    existing file, a message is printed and evaluation proceeds with the
    weights ``build_model`` produced.

    Returns:
        Whatever ``validate`` returns (its top-1 average). Existing callers
        that ignored the previous ``None`` return are unaffected.
    """
    args = parser.parse_args()
    # Configure the block implementation before the model is constructed
    # (presumably build_model reads these switches -- TODO confirm in convnet_utils).
    switch_deploy_flag(args.mode == 'deploy')
    switch_conv_bn_impl(args.blocktype)
    model = build_model(args.arch)

    # Define the loss function (criterion); move it and the model to the GPU
    # only when one is available.
    criterion = torch.nn.CrossEntropyLoss()
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        model = model.cuda()
        criterion = criterion.cuda()
    else:
        print('using CPU, this will be slow')

    if 'hdf5' in args.weights:
        # Any path containing the substring 'hdf5' is routed to the HDF5 loader.
        # Imported lazily so utils only needs it when HDF5 weights are used.
        from utils import model_load_hdf5
        model_load_hdf5(model, args.weights)
    elif os.path.isfile(args.weights):
        print("=> loading checkpoint '{}'".format(args.weights))
        # Without a GPU, map every tensor to the CPU. Otherwise a checkpoint
        # saved from CUDA tensors fails to deserialize.
        checkpoint = torch.load(args.weights,
                                map_location=None if use_gpu else 'cpu')
        if 'state_dict' in checkpoint:
            checkpoint = checkpoint['state_dict']
        # Strip only a *leading* 'module.' (presumably from saving a model
        # wrapped in DataParallel). str.replace would also corrupt inner
        # names such as 'submodule.weight'.
        prefix = 'module.'
        ckpt = {(k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in checkpoint.items()}
        model.load_state_dict(ckpt)
    else:
        print("=> no checkpoint found at '{}'".format(args.weights))

    cudnn.benchmark = True

    # Data loading code
    valdir = os.path.join(args.data, 'val')
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, val_preprocess(224)),
        batch_size=args.batch_size, shuffle=False,
        # Pinned host memory only helps host-to-GPU copies.
        num_workers=args.workers, pin_memory=use_gpu)
    return validate(val_loader, model, criterion, use_gpu)
def validate(val_loader, model, criterion, use_gpu):
    """Run ``model`` over every batch of ``val_loader`` and report accuracy.

    Args:
        val_loader: iterable of ``(images, target)`` batches. It must support
            ``len()``, which sizes the progress display.
        model: network to evaluate; it is put in eval mode before the loop.
        criterion: loss callable, invoked as ``criterion(output, target)``.
        use_gpu: if true, each batch is moved with ``.cuda(non_blocking=True)``
            before the forward pass.

    Returns:
        ``top1.avg``, the running top-1 average over all batches.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    tracked = [batch_time, losses, top1, top5]
    progress = ProgressMeter(len(val_loader), tracked, prefix='Test: ')

    # Evaluation mode; gradients are disabled for the whole pass below.
    model.eval()
    with torch.no_grad():
        tick = time.time()
        for batch_idx, (images, target) in enumerate(val_loader):
            if use_gpu:
                images = images.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)

            output = model(images)
            loss = criterion(output, target)

            # accuracy() yields one entry per k in topk; element [0] is the
            # value recorded (presumably a 1-element tensor -- confirm in utils).
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            n = images.size(0)
            losses.update(loss.item(), n)
            top1.update(acc1[0], n)
            top5.update(acc5[0], n)

            # Wall-clock time for this batch, including data loading.
            batch_time.update(time.time() - tick)
            tick = time.time()

            if batch_idx % 10 == 0:
                progress.display(batch_idx)

        summary = ' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
        print(summary.format(top1=top1, top5=top5))

    return top1.avg
if __name__ == '__main__':
    # Script entry point: parse the command line and run the evaluation.
    test()