Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix +/-inf in LSE returned by forward #978

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

sgrigory
Copy link
Contributor

@sgrigory sgrigory commented Jun 3, 2024

Forward op was returning +inf in LSE for queries which have no keys to attend to, e.g. when K/V length happens to be 0. This diverges from the definition of LSE = log(exp(L1) + ... exp(L2)) which would give log(0) = -inf.
This PR fixes it, which allows feeding the output LSE directly into ops like merge_attentions without postprocessing.

pytest tests/test_flash_attn.py
...
======================================================================================== 268004 passed, 152064 skipped in 4404.00s (1:13:23) =========================================================================================

@tridao
Copy link
Contributor

tridao commented Jun 27, 2024

One issue I can see is that in the backward pass, if lse = +inf then exp(qk - lse) returns 0, which is what we want. If lse = -inf then exp would blow up.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants