Skip to content

Commit

Permalink
Allow dbs with temperature.
Browse files Browse the repository at this point in the history
  • Loading branch information
ruotianluo committed May 12, 2019
1 parent d2eef7f commit 8118670
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions models/CaptionModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logpr

# Start diverse_beam_search
opt = kwargs['opt']
temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs
beam_size = opt.get('beam_size', 10)
group_size = opt.get('group_size', 1)
diversity_lambda = opt.get('diversity_lambda', 0.5)
Expand Down Expand Up @@ -178,6 +179,7 @@ def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq, beam_seq_logpr

it = beam_seq_table[divm][t-divm]
logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it.cuda(), *(args[divm] + [state_table[divm]]))
logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1)

# all beams are sorted by their log-probabilities
done_beams_table = [sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash] for i in range(group_size)]
Expand Down

0 comments on commit 8118670

Please sign in to comment.