Shashank/seq id flash attn #738
Conversation
Pulling the latest commits from main fork
Pulling from the main repo
Pulling from mosaicml/llm-foundry main
Merging from mosaic main
Pulling from mosaic main
Pulling from mosaic main.
High level looks OK.
In the PR description, can you include a figure showing the MFU difference with and without masking, and also a figure showing the convergence difference with and without masking?
It would be good to train some models to show the equivalence of seq id masking with flash attention and with the other attention implementations.
LGTM, let's look into that slow test a bit before merging.
This PR does three flash attention-related things:
WandB link to the experiments: https://wandb.ai/mosaic-ml/seq_id_FA_final_tests
Loss and throughput curves for a 125M model trained to Chinchilla steps:
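For context, here is a minimal sketch of how per-token sequence ids can be turned into an attention mask so that tokens from different packed sequences do not attend to each other. This is illustrative only, not the PR's actual implementation; the function and tensor names are hypothetical, and it assumes a PyTorch-style setup.

```python
import torch

def sequence_id_attention_mask(sequence_id: torch.Tensor) -> torch.Tensor:
    """Build a (batch, 1, seq_len, seq_len) boolean mask from per-token sequence ids.

    Positions i and j may attend to each other only if they share the same sequence id.
    """
    # sequence_id: (batch, seq_len), e.g. [[0, 0, 0, 1, 1, 2], ...]
    same_seq = sequence_id.unsqueeze(-1) == sequence_id.unsqueeze(-2)  # (batch, S, S)
    return same_seq.unsqueeze(1)  # add a head dimension for broadcasting

# Example: three packed sequences of lengths 3, 2, and 1.
seq_id = torch.tensor([[0, 0, 0, 1, 1, 2]])
mask = sequence_id_attention_mask(seq_id)
print(mask.shape)  # torch.Size([1, 1, 6, 6])
```

An explicit mask like this grows quadratically with sequence length; flash-attention kernels typically achieve the same per-sequence isolation through variable-length inputs rather than a materialized mask, which is presumably why the MFU comparison requested above is of interest.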