-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathextract_features.py
180 lines (149 loc) · 7.95 KB
/
extract_features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import argparse
import os
import time
import pickle
import pdb
import numpy as np
import torch
from torch.utils.model_zoo import load_url
from torchvision import transforms
from cirtorch.networks.imageretrievalnet import init_network, extract_vectors
from cirtorch.datasets.datahelpers import cid2filename
from cirtorch.datasets.testdataset import configdataset
from cirtorch.utils.download import download_train, download_test
from cirtorch.utils.whiten import whitenlearn, whitenapply
from cirtorch.utils.evaluate import compute_map_and_print
from cirtorch.utils.general import get_data_root, htime
# Download URLs for pre-trained networks, keyed by the name accepted by --network-path.
PRETRAINED = {
    'retrievalSfM120k-vgg16-gem'     : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/retrievalSfM120k-vgg16-gem-b4dcdc6.pth',
    'retrievalSfM120k-resnet101-gem' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/retrievalSfM120k-resnet101-gem-b80fb85.pth',
    'rSfM120k-tl-resnet50-gem-w'     : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet50-gem-w-97bf910.pth',
    'rSfM120k-tl-resnet101-gem-w'    : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet101-gem-w-a155e54.pth',
    'rSfM120k-tl-resnet152-gem-w'    : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet152-gem-w-f39cada.pth',
    'gl18-tl-resnet50-gem-w'         : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet50-gem-w-83fdc30.pth',
    'gl18-tl-resnet101-gem-w'        : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet101-gem-w-a4d43db.pth',
    'gl18-tl-resnet152-gem-w'        : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet152-gem-w-21278d5.pth',
}

# Datasets this script knows how to configure and extract features for.
datasets_names = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k', 'retrieval-sfm-120k', 'retrieval-SfM-30k']

parser = argparse.ArgumentParser(description='Feature extractor for a given model and dataset.')

# network: exactly one of --network-path / --network-offtheshelf must be supplied
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('--network-path', '-npath', metavar='NETWORK',
                   help="pretrained network or network path (destination where network is saved)")
group.add_argument('--network-offtheshelf', '-noff', metavar='NETWORK',
                   help="off-the-shelf network, in the format 'ARCHITECTURE-POOLING' or 'ARCHITECTURE-POOLING-{reg-lwhiten-whiten}'," +
                        " examples: 'resnet101-gem' | 'resnet101-gem-reg' | 'resnet101-gem-whiten' | 'resnet101-gem-lwhiten' | 'resnet101-gem-reg-whiten'")
# BUG FIX: the help text previously claimed the default was 'oxford5k,paris6k',
# which contradicted the actual default of 'roxford5k,rparis6k'.
parser.add_argument('--datasets', '-d', metavar='DATASETS', default='roxford5k,rparis6k',
                    help="comma separated list of test datasets: " +
                         " | ".join(datasets_names) +
                         " (default: 'roxford5k,rparis6k')")
parser.add_argument('--image-size', '-imsize', default=1024, type=int, metavar='N',
                    help="maximum size of longer image side used for testing (default: 1024)")
parser.add_argument('--multiscale', '-ms', metavar='MULTISCALE', default='[1]',
                    help="use multiscale vectors for testing, " +
                         " examples: '[1]' | '[1, 1/2**(1/2), 1/2]' | '[1, 2**(1/2), 1/2**(1/2)]' (default: '[1]')")
parser.add_argument('--gpu-id', '-g', default='0', metavar='N',
                    help="gpu id used for testing (default: '0')")
def pil_loader(path):
    """Open the image file at *path* and return it as an RGB PIL image.

    Parameters
    ----------
    path : str
        Filesystem path to the image file.

    Returns
    -------
    PIL.Image.Image
        The loaded image converted to RGB mode.
    """
    # BUG FIX: PIL was never imported at module level, so the original body
    # raised NameError on Image/ImageFile as soon as it was called. A local
    # import keeps the fix self-contained.
    from PIL import Image, ImageFile

    # tolerate truncated (corrupted) files instead of crashing mid-decode
    ImageFile.LOAD_TRUNCATED_IMAGES = True
    # open via an explicit file object to avoid ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        # convert inside the 'with' so the pixel data is read before f closes
        return img.convert('RGB')
def main():
    """Load a retrieval network, extract global descriptors for each requested
    dataset, and save the vectors plus the ordered image list under
    <data_root>/features/.

    Raises
    ------
    ValueError
        If any name in --datasets is not in `datasets_names`.
    """
    args = parser.parse_args()

    # check if there are unknown datasets
    for dataset in args.datasets.split(','):
        if dataset not in datasets_names:
            raise ValueError('Unsupported or unknown dataset: {}!'.format(dataset))

    # check if test datasets are downloaded, and download them if they are not
    download_train(get_data_root())
    download_test(get_data_root())

    # setting up the visible GPU
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id

    # loading network from path (pretrained name or local checkpoint)
    if args.network_path is not None:
        print(">> Loading network:\n>>>> '{}'".format(args.network_path))
        if args.network_path in PRETRAINED:
            # pretrained networks (downloaded automatically)
            state = load_url(PRETRAINED[args.network_path], model_dir=os.path.join(get_data_root(), 'networks'))
        else:
            # fine-tuned network from path
            state = torch.load(args.network_path)

        # parsing net params from meta: architecture, pooling, mean, std are
        # required; the rest fall back to defaults when absent in older checkpoints
        net_params = {
            'architecture': state['meta']['architecture'],
            'pooling': state['meta']['pooling'],
            'local_whitening': state['meta'].get('local_whitening', False),
            'regional': state['meta'].get('regional', False),
            'whitening': state['meta'].get('whitening', False),
            'mean': state['meta']['mean'],
            'std': state['meta']['std'],
            'pretrained': False,
        }

        # load network
        net = init_network(net_params)
        net.load_state_dict(state['state_dict'])

        print(">>>> loaded network: ")
        print(net.meta_repr())

    # loading off-the-shelf network, e.g. 'resnet101-gem-whiten'
    elif args.network_offtheshelf is not None:
        offtheshelf = args.network_offtheshelf.split('-')
        net_params = {
            'architecture': offtheshelf[0],
            'pooling': offtheshelf[1],
            'local_whitening': 'lwhiten' in offtheshelf[2:],
            'regional': 'reg' in offtheshelf[2:],
            'whitening': 'whiten' in offtheshelf[2:],
            'pretrained': True,
        }

        print(">> Loading off-the-shelf network:\n>>>> '{}'".format(args.network_offtheshelf))
        net = init_network(net_params)
        print(">>>> loaded network: ")
        print(net.meta_repr())

    # setting up the multi-scale parameters
    # NOTE(security): eval() on a CLI string is acceptable for a local tool,
    # but never expose this argument to untrusted input.
    ms = list(eval(args.multiscale))
    if len(ms) > 1 and net.meta['pooling'] == 'gem' and not net.meta['regional'] and not net.meta['whitening']:
        msp = net.pool.p.item()
        print(">> Set-up multiscale:")
        print(">>>> ms: {}".format(ms))
        print(">>>> msp: {}".format(msp))
    else:
        msp = 1

    # moving network to gpu and eval mode
    net.cuda()
    net.eval()

    # set up the transform using the normalization stats stored in the net meta
    normalize = transforms.Normalize(
        mean=net.meta['mean'],
        std=net.meta['std']
    )
    transform = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])

    # BUG FIX: the original tagged output files with args.network_path, which is
    # None when --network-offtheshelf is used ('dataset_None' file names). Use
    # whichever network argument was actually supplied.
    net_name = args.network_path if args.network_path is not None else args.network_offtheshelf
    # BUG FIX: strip directory components so a filesystem checkpoint path still
    # yields a flat, valid file name under the features directory.
    net_name = os.path.basename(net_name)

    # extract and save features for each requested dataset
    for dataset in args.datasets.split(','):
        start = time.time()
        print('>> {}: Extracting...'.format(dataset))

        data_root = get_data_root()
        cfg = configdataset(dataset, os.path.join(data_root, 'datasets'))
        images = [cfg['im_fname'](cfg, i) for i in range(cfg['n'])]
        vecs = extract_vectors(net, images, args.image_size, transform, ms=ms, msp=msp)

        # BUG FIX: ensure the output directory exists — np.save/open do not create it
        feat_dir = os.path.join(data_root, 'features')
        os.makedirs(feat_dir, exist_ok=True)

        out_path = os.path.join(feat_dir, '%s_%s' % (dataset, net_name))
        np.save(out_path, vecs)

        # save the ordered list of image files the saved vectors correspond to
        out_path_list = os.path.join(feat_dir, '%s_%s_img_list.txt' % (dataset, net_name))
        with open(out_path_list, 'w') as opl:
            for x in images:
                opl.write(x + '\n')

        print('>> {}: elapsed time: {}'.format(dataset, htime(time.time() - start)))
# Run the CLI entry point only when executed as a script (not on import).
if __name__ == '__main__':
    main()