Skip to content

Commit

Permalink
Fix: Raise informative exception when prefix_allowed_tokens_fn retu…
Browse files Browse the repository at this point in the history
…rn empty set of tokens (#27797)

Co-authored-by: Arthur <[email protected]>
  • Loading branch information
Saibo-creator and ArthurZucker authored Dec 8, 2023
1 parent 307a7d0 commit 56be5e8
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
9 changes: 8 additions & 1 deletion src/transformers/generation/logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -1229,7 +1229,14 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
mask = torch.full_like(scores, -math.inf)
for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
for beam_id, sent in enumerate(beam_sent):
mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0
prefix_allowed_tokens = self._prefix_allowed_tokens_fn(batch_id, sent)
if len(prefix_allowed_tokens) == 0:
raise ValueError(
f"`prefix_allowed_tokens_fn` returned an empty list for batch ID {batch_id}."
f"This means that the constraint is unsatisfiable. Please check your implementation"
f"of `prefix_allowed_tokens_fn` "
)
mask[batch_id * self._num_beams + beam_id, prefix_allowed_tokens] = 0

return scores + mask

Expand Down
7 changes: 7 additions & 0 deletions tests/generation/test_logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,13 @@ def prefix_allowed_tokens_fn(batch_id, inputs_ids):
torch.isinf(filtered_scores).tolist(), [[False, False, True, True, True], [True, True, False, False, True]]
)

def empty_prefix_allowed_tokens_fn(batch_id, inputs_ids):
return []

prefix_constrained_logits_proc = PrefixConstrainedLogitsProcessor(empty_prefix_allowed_tokens_fn, 1)

self.assertRaises(ValueError, prefix_constrained_logits_proc, input_ids, scores.clone())

def test_hamming_diversity(self):
vocab_size = 4
num_beams = 2
Expand Down

0 comments on commit 56be5e8

Please sign in to comment.