-
Notifications
You must be signed in to change notification settings - Fork 5
/
train_fourgram.py
152 lines (135 loc) · 5.77 KB
/
train_fourgram.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from caption_model.att import *
from caption_model.fc import *
from mycider import *
from multiprocessing import Pool
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import ngram_opts
from tools import *
from dataloader import *
# Parse the n-gram training options. Every caption model other than the plain
# 'fc' one consumes attention features.
opts = ngram_opts.parse_opt()
opts.use_att = opts.caption_model != 'fc'
batch_size = opts.batch_size
loader = KKDataLoader(opts)
vocabs = loader.get_vocab()
# Position 0 is reserved for the '#END#' token. The loader's vocabulary is
# keyed by the 1-based word index as a string ('1', '2', ...), so word k ends
# up at position k of this list.
vocab = ['#END#'] + [vocabs[str(word_idx)] for word_idx in range(1, len(vocabs) + 1)]
# Checkpoints, logs and validation results are written to
# fourgram_cider_model/att_model or fourgram_cider_model/fc_model.
if not os.path.exists('fourgram_cider_model'):
    os.mkdir('fourgram_cider_model')
save_dir = 'fourgram_cider_model/' + ('att_model' if opts.use_att else 'fc_model')
if not os.path.exists(save_dir):
    os.mkdir(save_dir)
    print(save_dir + ' has been built')
# Network hyper-parameters for this training stage.
image_dim = 2048
# NOTE(review): presumably +1 for the '#END#' token prepended to vocab — confirm.
vocab_size = loader.vocab_size + 1
cell_size = 512
lr = 0.00005

# Choose the architecture and its matching checkpoint under warm_model/,
# then build the network once.
if opts.use_att:
    model_cls, warm_start = AttModel, 'warm_model/att_warm/model.init'
else:
    model_cls, warm_start = FCModel, 'warm_model/fc_warm/model.init'
model = model_cls(batch_size=batch_size, image_dim=image_dim, vocab_size=vocab_size,
                  cell_size=cell_size, lr=lr, ngram=4, on_gpu=True)
model.load(warm_start)
# Reference captions of the training split, used to score generated captions.
gts = transfer_json_to_cider_gts(osp.join('data/features', 'captions_train.json'))
cider_scorer = CiderScorer(refs=gts, n=4, sigma=6.0)


def cider_temp(res):
    """Score one generated caption with the module-level ``cider_scorer``.

    ``res`` is a dict with ``'image_id'`` and ``'caption'`` keys, i.e. one
    element of the result lists returned by ``model.inference`` /
    ``model.fourgram_inference``. Returns the first element of
    ``cider_scorer.compute_score()``.

    Defined at module level so ``multiprocessing.Pool.map`` can pickle it by
    name. Each pool worker process works on its own copy of ``cider_scorer``.
    """
    candidate = {res['image_id']: [res['caption']]}
    cider_scorer.cook_append_test(test=candidate)
    # NOTE(review): the name suggests cook_append_test *appends* to the
    # scorer's test set. If so, each worker's scorer grows with every call and
    # compute_score() may also cover earlier captions. Confirm in mycider.
    cider_value, _ = cider_scorer.compute_score()
    return cider_value
# Worker processes for parallel CIDEr scoring. This is created after
# cider_scorer / cider_temp are defined. Keep it that way so the workers
# inherit them (assumes the default fork start method).
pool = Pool(5)
logger = Logger(save_dir)
# Training-loop state: best validation score so far, iteration counter and
# iteration budget. (`iter` shadows the builtin; the name is kept because the
# training loop refers to it.)
best_score = -1
iter = 0
finish_iter = 100000
timer = Timer()
timer.tic()
# Each iteration trains on the highest-CIDEr sample out of NUM_SAMPLES draws per
# image. The reward is (sample CIDEr - greedy CIDEr), so greedy decoding acts as
# the baseline.

# Number of captions sampled per image; the best-scoring one is trained on.
NUM_SAMPLES = 20
# Caption length cap used for greedy decoding and for sampling.
MAX_DECODE_LENGTH = 16
# Number of validation images decoded per evaluation round.
NUM_VAL_IMAGES = 5000
# Evaluate, checkpoint and log every EVAL_EVERY iterations.
EVAL_EVERY = 300


def _batch_features(data):
    """Return ``(image_ids, feature)`` for one batch from ``loader.get_batch``.

    Only the first of each image's ``seq_per_img`` rows is kept. For the
    attention model, ``att_feats`` (rank 4, per the reshape below) has its
    axes 1 and 2 merged. Otherwise the fc features are used.
    """
    rows = np.arange(loader.batch_size) * loader.seq_per_img
    image_ids = [data['infos'][i]['id'] for i in range(opts.batch_size)]
    if opts.use_att:
        att_feats = data['att_feats'][rows]
        feature = att_feats.reshape(att_feats.shape[0],
                                    att_feats.shape[1] * att_feats.shape[2],
                                    att_feats.shape[3])
    else:
        feature = data['fc_feats'][rows]
    return image_ids, feature


def _best_of_samples(all_caps, all_results, all_scores, batch_size):
    """For each image, keep the sampled caption with the highest score.

    ``all_scores`` is a (num_samples, batch) array. Ties go to the earliest
    sample, as with ``argmax``. Returns ``(caps, results, scores)``.
    """
    best = all_scores.argmax(axis=0)
    caps = [all_caps[best[n]][n] for n in range(batch_size)]
    results = [all_results[best[n]][n] for n in range(batch_size)]
    scores = all_scores[best[:batch_size], np.arange(batch_size)]
    return caps, results, scores


def _pad_captions(sample_caps, batch_size):
    """Pack variable-length token arrays into a time-major batch plus mask.

    ``caption`` is int32 of shape (max_len + 2, batch). Row 0 stays 0, and
    caption n occupies rows 1..L_n, followed by zeros (index 0 is '#END#').
    ``mask`` is float32 of shape (max_len + 1, batch), with ones on the first
    L_n + 1 rows of column n (the tokens plus the trailing 0).
    """
    max_length = max(cap.shape[0] for cap in sample_caps)
    caption = np.zeros([max_length + 2, batch_size], dtype=np.int32)
    mask = np.zeros([max_length + 1, batch_size], dtype=np.float32)
    for n in range(batch_size):
        length = sample_caps[n].shape[0]
        caption[1:length + 1, n] = sample_caps[n]
        mask[:length + 1, n] = 1
    return caption, mask


def _as_float(value):
    """Return ``value`` as a Python float, or None if it is not a scalar."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def _validate():
    """Greedy-decode the validation split and return ``evaluate(...)[-1]``."""
    results = []
    # Floor division: `/` would yield a float on Python 3 and break range().
    for _ in range(NUM_VAL_IMAGES // opts.batch_size):
        val_ids, feature_val = _batch_features(loader.get_batch('val'))
        _, greedy_res = model.inference(vocab, val_ids, feature_val,
                                        manner='greedy', max_length=MAX_DECODE_LENGTH)
        results += greedy_res
    result_file = osp.join(save_dir, 'result.json')
    # Close (and flush) the file before evaluate() reads it back.
    with open(result_file, 'w') as f:
        json.dump(results, f)
    gt_file = osp.join('data/features', 'captions_val.json')
    return evaluate(gt_file=gt_file, re_file=result_file)[-1]


train_losses = []
while iter < finish_iter:
    iter += 1
    image_id, feature = _batch_features(loader.get_batch('train'))

    # Baseline: greedy captions and their CIDEr scores.
    _, greedy_res = model.inference(vocab, image_id, feature,
                                    manner='greedy', max_length=MAX_DECODE_LENGTH)
    greedy_scores = np.array(pool.map(cider_temp, greedy_res))

    # Draw several sampled captions per image and score each draw.
    all_caps, all_results, all_scores = [], [], []
    for _ in range(NUM_SAMPLES):
        sample_caps, sample_results = model.fourgram_inference(
            vocab, image_id, feature, manner='sample', max_length=MAX_DECODE_LENGTH)
        all_caps.append(sample_caps)
        all_results.append(sample_results)
        all_scores.append(np.array(pool.map(cider_temp, sample_results)))

    sample_caps, sample_results, sample_scores = _best_of_samples(
        all_caps, all_results, np.array(all_scores), opts.batch_size)
    caption, mask = _pad_captions(sample_caps, opts.batch_size)
    reward = (sample_scores - greedy_scores).astype(np.float32)

    # Single-argument prints give the same output on Python 2 and Python 3.
    print(image_id[0])
    print('greedy:  %s %s' % (greedy_scores[0], greedy_res[0]['caption']))
    print('sample:  %s %s' % (sample_scores[0], sample_results[0]['caption']))

    loss_train = model.train_on_batch(feature, caption[1:, :], mask, reward)
    # Previously the loss was discarded and always logged as -1. Record it
    # whenever train_on_batch returns something scalar.
    loss_value = _as_float(loss_train)
    if loss_value is not None:
        train_losses.append(loss_value)

    if iter % EVAL_EVERY == 0:
        score = _validate()
        if score > best_score:
            best_score = score
            model.save(osp.join(save_dir, 'model.best'))
        model.save(osp.join(save_dir, 'model.ckpt'))
        # Mean training loss since the last log line (-1 if none was recorded).
        tr_loss = np.mean(train_losses) if train_losses else -1
        logger.info('[{}], tr_loss={:.5f}, score/best={:.3f}/{:.3f}, finish->{}, time={:.1f}sec'
                    .format(iter, tr_loss, score, best_score, finish_iter, timer.toc()))
        # Reset loss accumulator and timer for the next logging window.
        train_losses = []
        timer.tic()