Avoid nan during sampling in generate() #17937
Conversation
I have some doubts here, as this will make all tokens have equal probability of being sampled. But with all scores being `-inf`, nothing can be sampled anyway.
@@ -1970,8 +1970,19 @@ def sample(
        else (outputs.hidden_states,)
    )

    # To avoid all `-inf` along the vocab dimension (dim -1), which gives `nan` after `softmax` and an error
    # in `torch.multinomial`.
    _next_token_scores = torch.max(
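For context, here is a minimal sketch of the clamping idea the new comment describes (the diff above is truncated, so the tensor shape and the exact clamping value are assumptions, not necessarily this PR's code): replacing an all-`-inf` row with the dtype's finite minimum makes `softmax` return a valid, uniform distribution.

```python
import torch

# Hypothetical illustration of the clamping idea in the truncated diff above;
# not necessarily the exact code from this PR.
next_token_scores = torch.full((1, 5), float("-inf"))  # all -inf along vocab dim

# Replace -inf with the dtype's finite minimum before softmax.
min_value = torch.finfo(next_token_scores.dtype).min
_next_token_scores = torch.max(next_token_scores, torch.tensor(min_value))

probs = torch.softmax(_next_token_scores, dim=-1)
print(probs)  # tensor([[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]) -- uniform
next_token = torch.multinomial(probs, num_samples=1)  # no longer raises
```

As the follow-up comments discuss, this trades the `nan` for a uniform distribution over the whole vocabulary.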
@ydshieh, softmax should be able to handle `-inf` correctly actually. You can try `torch.nn.functional.softmax(torch.tensor([0, float("-inf")]))`, which works as mathematically expected. It's only when all values are `-inf` that it doesn't work, in which case this fix won't help because the generation is broken.
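To make the two cases concrete, a small demo (assuming PyTorch):

```python
import torch

# A partial -inf mask is fine: softmax puts zero probability on masked entries.
print(torch.nn.functional.softmax(torch.tensor([0.0, float("-inf")]), dim=-1))
# tensor([1., 0.])

# An all--inf row is the broken case: softmax returns nan everywhere,
# and torch.multinomial then raises a RuntimeError on that distribution.
probs = torch.nn.functional.softmax(
    torch.tensor([float("-inf"), float("-inf")]), dim=-1
)
print(probs)  # tensor([nan, nan])
```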
This will fix the `nan` issue actually. The concern is that it doesn't really make sense, as it changes the probability to a uniform distribution along the vocab dim, while in the broken cases nothing can be sampled (all probabilities are 0, mathematically).
As explained here: https://github.com/huggingface/transformers/pull/17937/files#r910424357
this won't fix the problem. Also note that generation is used a lot, so every additional operation (`torch.max(...)`) leads to a tiny slowdown.
Usually if you get `nan`s after the softmax, it means that the generation is broken anyway, which can happen, and I think there is little we can do against it.
Yes, that happens only when all scores are `-inf`.
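Given the conclusion that an all-`-inf` row means generation is already broken, one alternative (a sketch only, not part of this PR; the helper name is hypothetical) would be to fail loudly with an explanatory error rather than clamping:

```python
import torch

# Hypothetical helper, not from this PR: surface a clear error instead of
# letting torch.multinomial fail on a nan probability distribution.
def sample_next_token(next_token_scores: torch.Tensor) -> torch.Tensor:
    probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
    if torch.isnan(probs).any():
        raise RuntimeError(
            "All next-token scores are -inf: the logits processors removed "
            "every candidate token, so sampling cannot proceed."
        )
    return torch.multinomial(probs, num_samples=1)
```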
What does this PR do?
Fix CI test error in
https://github.com/huggingface/transformers/runs/6959698965?check_suite_focus=true

The test `test_sample_generate` may still fail at `transformers/tests/generation/test_generation_utils.py` (line 711 in 8f40077) for some unknown reason. I think it is better to investigate this in another PR.