diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py
index 7f4415dd0dbe84..4b9b91cd8068d9 100644
--- a/src/transformers/generation/logits_process.py
+++ b/src/transformers/generation/logits_process.py
@@ -1229,7 +1229,14 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
         mask = torch.full_like(scores, -math.inf)
         for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
             for beam_id, sent in enumerate(beam_sent):
-                mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0
+                prefix_allowed_tokens = self._prefix_allowed_tokens_fn(batch_id, sent)
+                # An empty allowed-token list would mask the entire vocabulary, making
+                # generation impossible — fail loudly instead of producing all -inf scores.
+                if len(prefix_allowed_tokens) == 0:
+                    raise ValueError(
+                        f"`prefix_allowed_tokens_fn` returned an empty list for batch ID {batch_id}. "
+                        "This means that the constraint is unsatisfiable. Please check your implementation "
+                        "of `prefix_allowed_tokens_fn`."
+                    )
+                mask[batch_id * self._num_beams + beam_id, prefix_allowed_tokens] = 0
 
         return scores + mask
 
diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py
index 9e5ccd16eb7d12..b1b3602c927dba 100644
--- a/tests/generation/test_logits_process.py
+++ b/tests/generation/test_logits_process.py
@@ -610,6 +610,13 @@ def prefix_allowed_tokens_fn(batch_id, inputs_ids):
             torch.isinf(filtered_scores).tolist(), [[False, False, True, True, True], [True, True, False, False, True]]
         )
 
+        def empty_prefix_allowed_tokens_fn(batch_id, inputs_ids):
+            return []
+
+        prefix_constrained_logits_proc = PrefixConstrainedLogitsProcessor(empty_prefix_allowed_tokens_fn, 1)
+
+        self.assertRaises(ValueError, prefix_constrained_logits_proc, input_ids, scores.clone())
+
     def test_hamming_diversity(self):
         vocab_size = 4
         num_beams = 2