This repository has been archived by the owner on Sep 1, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 49
/
Copy pathget_flickr30k.py
101 lines (79 loc) · 3.48 KB
/
get_flickr30k.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
from six.moves import cPickle as pickle
import json
import os
import sys
from tqdm import tqdm
from PIL import Image
import torch
import torchvision
import spacy
# Command-line configuration, parsed once at import time and read by main().
parser = argparse.ArgumentParser(
    description='Extract ResNet-152 image features and tokenized captions for Flickr30k.')
# Required: without it the script would only fail deep inside main(), after
# the model had already been constructed.
parser.add_argument('--flickr30k_root', type=str, required=True,
                    help='root passed to ImageFolder (images live in sub-folders); '
                         'must also contain dataset_flickr30k.json')
parser.add_argument('--batch_size', type=int, default=16,
                    help='images per forward pass during feature extraction (default: 16)')
cfg = parser.parse_args()
# Per-channel normalization constants. These are the values torchvision
# documents for its ImageNet-pretrained models.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Image preprocessing for feature extraction:
# - resize the shorter side to 256,
# - take one 224x224 centre crop,
# - convert to a tensor and normalize with the constants above.
_T = torchvision.transforms
preprocess_1c = _T.Compose([
    _T.Resize(size=256),
    _T.CenterCrop(224),
    _T.ToTensor(),
    _T.Normalize(mean, std),
])
del _T  # temporary alias only; keep the module namespace unchanged
class MyImageFolder(torchvision.datasets.ImageFolder):
    """ImageFolder that also reports which file each sample was read from.

    ``__getitem__`` returns ``((sample, class_index), record)``. Here ``record``
    is ``self.imgs[index]``, ImageFolder's ``(path, class_index)`` entry for
    that index, so the file path is ``record[0]``.
    """

    def __getitem__(self, index):
        sample_and_target = super(MyImageFolder, self).__getitem__(index)
        record = self.imgs[index]
        return sample_and_target, record
def normf(t, p=2, d=1, eps=1e-12):
    """Scale ``t`` so every slice along dimension ``d`` has unit ``p``-norm.

    Args:
        t: tensor to normalize.
        p: order of the norm (default 2, i.e. Euclidean).
        d: dimension along which norms are taken (default 1).
        eps: lower bound applied to each norm. It keeps all-zero slices at
            zero instead of producing NaN from 0/0. Slices whose norm is at
            least ``eps`` are unaffected.

    Returns:
        A tensor with the same shape as ``t``.
    """
    norms = t.norm(p, d, keepdim=True).clamp(min=eps)
    return t / norms.expand_as(t)
def get_pil_img(root, img_name):
    """Load ``os.path.join(root, img_name)`` as an RGB PIL image.

    The file is opened in a ``with`` block and decoded eagerly, so the
    returned image owns its pixel data and no file handle is left open.
    """
    with open(os.path.join(root, img_name), 'rb') as fh:
        img_pil = Image.open(fh)
        # Image.open is lazy. convert() forces a full decode (and returns a
        # copy even when the mode is already RGB), detaching the result from fh.
        return img_pil.convert('RGB')
def _extract_image_features():
    """Embed every image under cfg.flickr30k_root with a headless ResNet-152.

    Returns a dict mapping each image's file name (basename) to its feature,
    normalized with normf and converted to a numpy array.
    """
    model = torchvision.models.resnet152(pretrained=True)
    # Drop the last child module (the classifier) and keep the rest.
    model = torch.nn.Sequential(*list(model.children())[:-1])
    model = torch.nn.DataParallel(model).cuda()
    # Inference only. In eval mode BatchNorm uses its stored running statistics
    # instead of per-batch ones, so a feature no longer depends on which other
    # images share its batch.
    model.eval()
    dataset = MyImageFolder(cfg.flickr30k_root, preprocess_1c)
    loader = torch.utils.data.DataLoader(dataset, batch_size=cfg.batch_size, shuffle=False,
                                         num_workers=16, pin_memory=True)
    print('Extracting image features...')
    img_features = {}
    progress = tqdm(loader, mininterval=1, leave=False, file=sys.stdout)
    # No gradients are needed, so skip building the autograd graph.
    with torch.no_grad():
        for (inputs, _target), files in progress:
            outputs = normf(model(inputs))
            # files is the collated batch of MyImageFolder's (path, class_index)
            # records, so files[0] holds the image paths.
            for ii, img_path in enumerate(files[0]):
                img_features[os.path.basename(img_path)] = outputs[ii, :].squeeze().cpu().numpy()
    return img_features


def _load_spacy_en():
    """Load an English spaCy pipeline, preferring the original 'en' shortcut.

    spaCy 3 removed shortcut links, so fall back to the en_core_web_sm package.
    """
    try:
        return spacy.load('en')
    except (IOError, OSError):  # IOError is distinct from OSError on Python 2
        return spacy.load('en_core_web_sm')


def _build_splits(img_features):
    """Group features and tokenized captions by the split in dataset_flickr30k.json.

    Returns ``(features, captions)``:
    - ``features[split]`` maps a per-split integer image id (enumeration order)
      to that image's feature;
    - ``captions[split]`` is a list of dicts with keys 'image_id', 'image_path'
      and 'caption' (a list of token strings).
    """
    all_splits = ['train', 'val', 'test']
    captions = {s: [] for s in all_splits}
    features = {s: {} for s in all_splits}
    # Only token texts are used, so run just the tokenizer. The statistical
    # pipeline components do not change tokenization and are far slower.
    tokenizer = _load_spacy_en().tokenizer
    with open(os.path.join(cfg.flickr30k_root, 'dataset_flickr30k.json')) as f:
        raw_data = json.load(f)
    for split in all_splits:
        images = [x for x in raw_data['images'] if x['split'] == split]
        for img_id, image_with_caption in enumerate(images):
            img_name = image_with_caption['filename']
            features[split][img_id] = img_features[img_name]
            for cap in image_with_caption['sentences']:
                tokens = [t.text for t in tokenizer(cap['raw'])]
                captions[split].append({'image_id': img_id, 'image_path': img_name, 'caption': tokens})
    return features, captions


def _save_splits(features, captions):
    """Pickle each split to data/datasets/flickr30k/<split>.pkl (relative to the CWD)."""
    output_root = os.path.join('data', 'datasets', 'flickr30k')
    if not os.path.exists(output_root):
        os.makedirs(output_root)
    for split in captions:
        print('Saving to disk (%s)...' % split)
        with open(os.path.join(output_root, '%s.pkl' % split), 'wb') as f:
            pickle.dump({'features': features[split], 'captions': captions[split]}, f,
                        protocol=pickle.HIGHEST_PROTOCOL)


def main():
    """Extract Flickr30k image features and captions and write per-split pickles.

    Reads the images and dataset_flickr30k.json from cfg.flickr30k_root. Writes
    train.pkl, val.pkl and test.pkl under data/datasets/flickr30k.
    """
    img_features = _extract_image_features()
    print('Generating dataset pkls...')
    features, captions = _build_splits(img_features)
    _save_splits(features, captions)
if __name__ == '__main__':
    # Run the extraction only when executed as a script, never on import.
    main()