
Fix flash attention GQA bug to use the dynamic size of the key/value tensors - used for eval/inference #756

Merged — 8 commits merged on Nov 21, 2023

Conversation

sashaDoubov
Contributor

This issue shows up with flash attention during inference or when running ICL eval tasks: the `.view()` call uses the wrong size for `key_unpad` and `value_unpad`, causing a runtime error. Thanks to @ShashankMosaicML for flagging!
Example:
RuntimeError: shape '[2, 1024, 5, -1]' is invalid for input of size 192000
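The failure mode can be sketched as follows. This is a minimal illustration, not the repo's actual code: the variable names (`kv_n_heads`, `head_dim`, etc.) and shapes are assumptions. The core issue is that with GQA during inference, the key/value sequence length (e.g. with a KV cache) can differ from the query sequence length, so reshaping the unpadded key/value tensors with a size derived from the queries fails; the sequence dimension should instead be taken from the key/value tensors themselves (e.g. via `-1`):

```python
import torch

# Illustrative shapes only (not the PR's exact values): 5 KV heads (GQA),
# and a KV sequence that is longer than the current query step.
batch, kv_n_heads, head_dim = 2, 5, 64
query_seq_len = 1024   # padded query length used by the buggy reshape
kv_seq_len = 300       # dynamic key/value length (e.g. from a KV cache)

# Unpadded keys, flattened the way flash attention consumes them.
key_unpad = torch.randn(batch * kv_seq_len, kv_n_heads * head_dim)

# Buggy pattern: reshaping with the query-derived length fails because
# 2 * 1024 * 5 does not divide the tensor's 192000 elements:
#   key_unpad.view(batch, query_seq_len, kv_n_heads, -1)
#   -> RuntimeError: shape '[2, 1024, 5, -1]' is invalid for input of size 192000

# Fixed pattern: let the sequence dimension be inferred from the tensor.
key_unpad = key_unpad.view(batch, -1, kv_n_heads, head_dim)
print(key_unpad.shape)  # torch.Size([2, 300, 5, 64])
```

The same change applies to `value_unpad`; deriving the dimension dynamically keeps training behavior identical while making inference and eval work.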

Tested 3B models: the loss is the same with and without the fix:
[image: loss curves with the fix]
[image: loss curves without the fix]
Eval only works with the fix:
[image: eval results]

@sashaDoubov sashaDoubov merged commit 1793c36 into mosaicml:main Nov 21, 2023
10 checks passed