fp8 backward #119
base: main_perf
6b691eb to 297742b
I'm approving the PR because I can't see anything wrong with it. I just left some questions and cleanup suggestions.
b725cdc to e6a67b3
Great job Michael! Kudos for introducing the compute_fp8_scaling_factors Triton function; it's really useful for avoiding code repetition.
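For readers who have not opened the diff, here is a minimal plain-Python sketch of the kind of computation a helper like compute_fp8_scaling_factors might perform, based only on the "per block" and "use max(abs(())" notes in the commit log. It is not the PR's Triton code: the function name `fp8_scaling_factors`, its parameters, the epsilon guard, and the constant 448.0 (the largest finite value of the e4m3fn format; other fp8 variants have a different maximum) are all assumptions made for illustration.

```python
# Hypothetical sketch, not the PR's Triton implementation.
# Assumption: the target format is float8 e4m3fn, whose largest finite value is 448.
FP8_E4M3_MAX = 448.0


def fp8_scaling_factors(block, fp8_max=FP8_E4M3_MAX, eps=1e-12):
    """Return (scale, descale) for one block of values.

    scale maps the block's largest magnitude onto fp8_max before the cast;
    descale is its reciprocal, applied after the low-precision matmul.
    """
    # Largest absolute value in the block (0.0 for an empty block).
    amax = max((abs(v) for v in block), default=0.0)
    # Guard against an all-zero block so the division below is defined.
    amax = max(amax, eps)
    scale = fp8_max / amax
    descale = amax / fp8_max
    return scale, descale


block = [0.5, -2.0, 1.25]
scale, descale = fp8_scaling_factors(block)
# After scaling, the largest magnitude in the block sits at the fp8 maximum.
scaled = [v * scale for v in block]
```

Keeping the scale and descale computation in one helper means every tensor that is cast (for example p and ds in the backward pass) goes through the same max-abs logic, which is the code-repetition benefit the review comment points to.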
ac33ced to 567c282
37dd11e to 6d5691f
84c3259 to 5df3f94
This is a combination of 21 commits.
fp8 BWD
Enable BWD fp8 with split kernel
Enable BWD fp8 with per block scale factors for p and ds
This is a combination of 9 commits.
Enable BWD fp8
This is a combination of 12 commits.
add backward test case save clean up disable ci lse is good dv matches reduce diff use do fp8 for dv kinda working group size is a constexpr clean up a bit everything except mqa/gqa works skip mqa cases 20 cases have nan on dropout save what you have disable tests failing enable tests per block descale_p and descale_ds use max(abs(()) clean up tests a bit more fix bug disable ci for now pass variables add flags add alternate path. Still need to load descale factors dv working dk works save add type info for backward fix DEBUG flag bug fix bug with backward. Normal forward works with dropout. Segfault with causal. Varlen has some issues. Might be related to strides. pass descale strides test causal fix causal compiler assert. min head should be 32 remove descale_p save explict name as causal isolate bad case just run fp8 tests bench with autotune min changes cast_fp8 helper cast_varlen_to_fp8 save minor highlight failing configs increase test cases mark failing recategorize misc tests group failing gqa configs add more tests add vis code min ci changes dump folder single image per tensors add tensor comparison gen varlen tensor vis varlen tensors varlen diff nice varlen vis vis function show seqlen in varlen add vis_tensors function simplify add color bars rm vis from test set canvas size. descale values are optional add ck tests add flag to build ck rm ck test assert requires grad ensure q, k, and v require gradients split vis rm interp, 8k and 300 dpi slice per page disable ci for now add more vis code tensor per image is better for vis_close, don't vis if no error. also vis all failing varlen tests varlen failures due to different seqlens rm vis code
fix minor things match readme decast fp8 for ref input, use fp16 as input
5df3f94 to 592e69b
accumlating fp32
This is a combination of 5 commits.
extend ci time clean more minimize difference add types ZERO_TENSORS and ACCUMLATE_FP32 flags
add fp8 backward