diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py
index a5a2c4abffd..cc07bfbcb6f 100644
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -1970,8 +1970,19 @@ def sample(
                         else (outputs.hidden_states,)
                     )
 
+            # To avoid all `-inf` along the vocab dimension (dim -1), which gives `nan` after `softmax` and error
+            # in `torch.multinomial`.
+            _next_token_scores = torch.max(
+                next_token_scores,
+                torch.tensor(
+                    torch.finfo(next_token_scores.dtype).min,
+                    dtype=next_token_scores.dtype,
+                    device=next_token_scores.device,
+                ),
+            )
+
             # sample
-            probs = nn.functional.softmax(next_token_scores, dim=-1)
+            probs = nn.functional.softmax(_next_token_scores, dim=-1)
             next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
 
             # finished sentences should have their next token be a padding token
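A minimal standalone sketch of the failure mode the new comment describes, and of the clamp the patch applies. The tensor shape and the way the row ends up all `-inf` are illustrative assumptions, not taken from the PR; only the `torch.max` / `torch.finfo(...).min` trick mirrors the diff above.

```python
import torch
import torch.nn.functional as F

# Hypothetical repro: one batch row where every vocab logit has been masked to -inf
# (e.g. by an over-aggressive logits processor).
scores = torch.full((1, 8), float("-inf"))

probs = F.softmax(scores, dim=-1)
print(probs)  # all NaN: softmax of an all--inf row divides 0 by 0
# torch.multinomial(probs, num_samples=1) errors out on this invalid distribution.

# The clamp from the patch: replace -inf with the smallest finite value of the dtype,
# so the row becomes a valid (uniform) distribution instead of NaN.
clamped = torch.max(
    scores,
    torch.tensor(torch.finfo(scores.dtype).min, dtype=scores.dtype, device=scores.device),
)
probs = F.softmax(clamped, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
print(probs)       # uniform over the vocab
print(next_token)  # a valid sampled index
```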