-
Notifications
You must be signed in to change notification settings - Fork 2
/
caption_utils.py
90 lines (68 loc) · 3.54 KB
/
caption_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
################################################################################
# CSE 253: Programming Assignment 4
# Code snippet by Ajit Kumar, Savyasachi
# Fall 2020
################################################################################
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import torch
# See this for input references - https://www.nltk.org/api/nltk.translate.html#nltk.translate.bleu_score.sentence_bleu
# A Caption should be a list of strings.
# Reference Captions are list of actual captions - list(list(str))
# Predicted Caption is the string caption based on your model's output - list(str)
# Make sure to process your captions before evaluating bleu scores -
# Converting to lower case, Removing tokens like <start>, <end>, padding etc.
def bleu1(reference_captions, predicted_caption):
    """Return the BLEU-1 score (unigrams only) scaled to the 0-100 range.

    reference_captions: list of reference captions, each a list of word strings.
    predicted_caption: the predicted caption as a list of word strings.
    """
    smoother = SmoothingFunction().method1
    score = sentence_bleu(
        reference_captions,
        predicted_caption,
        weights=(1, 0, 0, 0),
        smoothing_function=smoother,
    )
    return score * 100
def bleu4(reference_captions, predicted_caption):
    """Return the BLEU-4 score (uniform 1- to 4-gram weights) scaled to 0-100.

    reference_captions: list of reference captions, each a list of word strings.
    predicted_caption: the predicted caption as a list of word strings.
    """
    smoother = SmoothingFunction().method1
    score = sentence_bleu(
        reference_captions,
        predicted_caption,
        weights=(0.25, 0.25, 0.25, 0.25),
        smoothing_function=smoother,
    )
    return score * 100
def clean_caption(indx, tokenizer):
    """Convert batches of vocab-index sequences into cleaned word lists.

    Every special token (anything containing '[', ']' or '##' — e.g. [CLS],
    [PAD], wordpiece continuations) is dropped; a [SEP] token additionally
    terminates the sentence early.

    indx: iterable of sequences of token ids (one sequence per sentence).
    tokenizer: object exposing ``_convert_id_to_token(int) -> str``.

    Returns a pair ``(clean_text, clean_text_joined)`` where ``clean_text``
    is a list of word lists and ``clean_text_joined`` holds the same
    sentences joined with single spaces.
    """
    clean_text = []
    clean_text_joined = []
    for sentence_ids in indx:
        words = []
        for token_id in sentence_ids:
            word = tokenizer._convert_id_to_token(int(token_id))
            if '[' not in word and ']' not in word and '##' not in word:
                words.append(word)
            elif '[SEP]' in word:
                # End of sentence reached — stop decoding this sequence.
                break
        clean_text.append(words)
        clean_text_joined.append(" ".join(words))
    return clean_text, clean_text_joined
def get_captions(img_ids, cocoTest):
    """Look up the ground-truth captions for a batch of COCO image ids.

    img_ids: sequence of image ids from the dataloader.
    cocoTest: COCO object whose ``imgToAnns`` maps an image id to a list of
        annotation dicts with (at least) a 'caption' key.

    Returns a nested list shaped (batch_size, captions_per_image, words):
    each caption is lower-cased and split on whitespace.
    """
    all_caps = []
    for img_id in img_ids:
        annotations = cocoTest.imgToAnns[img_id]
        all_caps.append(
            [ann['caption'].lower().split() for ann in annotations]
        )
    return all_caps
def calculate_bleu(bleu_func, clean_preds_text, clean_targets_text):
    """Sum a BLEU metric over a batch of (prediction, references) pairs.

    bleu_func: callable like ``bleu1``/``bleu4`` taking
        (reference_captions, predicted_caption).
    clean_preds_text: predicted captions, one word list per example
        (batch_size, max_length).
    clean_targets_text: reference captions per example
        (batch_size, n_refs, max_length).

    Returns the aggregate (summed) score across the batch.
    """
    return sum(
        bleu_func(references, prediction)
        for prediction, references in zip(clean_preds_text, clean_targets_text)
    )
def stochastic_generation(outputs_raw, temperature):
    """Sample token indices from temperature-scaled model logits.

    outputs_raw: raw logits tensor; softmax is taken over dim 2
        (assumes shape like (batch, seq_len, vocab) — the vocab axis).
    temperature: scalar divisor; lower values sharpen the distribution,
        higher values flatten it.

    Returns a tensor of sampled indices with the logits' leading dims.
    """
    # Temperature-weighted softmax over the vocabulary dimension.
    scaled_logits = outputs_raw / temperature
    probabilities = torch.softmax(scaled_logits, dim=2)
    # Draw one index per position from the categorical distribution.
    sampler = torch.distributions.Categorical(probabilities)
    return sampler.sample()