fp8 backward #119

Open
wants to merge 31 commits into
base: main_perf

micmelesse
Collaborator

@micmelesse micmelesse commented Jan 24, 2025

add fp8 backward

@micmelesse micmelesse changed the title add backward test case fp8 backward Jan 24, 2025
@micmelesse micmelesse force-pushed the micmelesse/fp8_bwd branch 4 times, most recently from 6b691eb to 297742b Compare February 3, 2025 09:24
@micmelesse micmelesse marked this pull request as ready for review February 4, 2025 13:37

@brunomazzottiamd brunomazzottiamd left a comment

I'm approving the PR because I can't see anything wrong with it. I just left some questions and cleanup suggestions.

@micmelesse micmelesse force-pushed the micmelesse/fp8_bwd branch 2 times, most recently from b725cdc to e6a67b3 Compare February 6, 2025 14:33
@micmelesse micmelesse marked this pull request as draft February 7, 2025 19:27

@brunomazzottiamd brunomazzottiamd left a comment

Great job, Michael! Kudos for introducing the compute_fp8_scaling_factors Triton function; it's really useful for avoiding code repetition.

@micmelesse micmelesse force-pushed the micmelesse/fp8_bwd branch 2 times, most recently from ac33ced to 567c282 Compare February 20, 2025 16:47
@micmelesse micmelesse force-pushed the micmelesse/fp8_bwd branch 3 times, most recently from 84c3259 to 5df3f94 Compare February 25, 2025 23:13
This is a combination of 21 commits.

fp8 BWD

Enable BWD fp8 with split kernel

Enable BWD fp8 with per block scale factors for p and ds

This is a combination of 9 commits.

Enable BWD fp8

This is a combination of 12 commits.

add backward test case

save clean up

disable ci

lse is good

dv matches

reduce diff

use do fp8 for dv

kinda working

group size is a constexpr

clean up a bit

everything except mqa/gqa works

skip mqa cases

20 cases have nan on dropout

save what you have

disable tests

failing

enable tests

per block descale_p and descale_ds

use max(abs())

clean up tests a bit more

fix bug

disable ci for now

pass variables

add flags

add alternate path. Still need to load descale factors

dv working

dk works

save

add type info for backward

fix DEBUG flag bug

fix bug with backward. Normal forward works with dropout. Segfault with causal. Varlen has some issues. Might be related to strides.

pass descale strides

test causal

fix causal compiler assert. min head should be 32

remove descale_p

save

explicit name as causal

isolate bad case

just run fp8 tests

bench with autotune

min changes

cast_fp8 helper

cast_varlen_to_fp8

save

minor

highlight failing configs

increase test cases

mark failing

recategorize misc tests

group failing gqa configs

add more tests

add vis code

min ci changes

dump folder

single image per tensors

add tensor comparison

gen varlen tensor

vis varlen tensors

varlen diff

nice varlen vis

vis function

show seqlen in varlen

add vis_tensors function

simplify

add color bars

rm vis from test

set canvas size.

descale values are optional

add ck tests

add flag to build ck

rm ck test

assert requires grad

ensure q, k, and v require gradients

split vis

rm interp, 8k and 300 dpi

slice per page

disable ci for now

add more vis code

tensor per image is better

for vis_close, don't vis if no error. also vis all failing varlen tests

varlen failures due to different seqlens

rm vis code
fix minor things

match readme

decast fp8 for ref input, use fp16 as input
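
The commit notes above mention per-block descale factors and a switch to max(abs()). The following is a minimal plain-Python sketch of that general idea, not the PR's Triton code: the names scaling_factors and per_block_factors are hypothetical, and the value 448.0 assumes the e4m3fn fp8 format (other fp8 variants have a different maximum, so it is left as a parameter).

```python
from typing import List, Tuple

# Assumption: 448.0 is the largest finite value of the e4m3fn fp8 format.
# Other fp8 variants use a different maximum, so it is a parameter below.
FP8_MAX = 448.0


def scaling_factors(block: List[float], fp8_max: float = FP8_MAX,
                    eps: float = 1e-9) -> Tuple[float, float]:
    """Return (scale, descale) for one non-empty block from max(abs(block))."""
    amax = max(max(abs(v) for v in block), eps)  # eps guards all-zero blocks
    # scale maps the block into the fp8 range; descale undoes it afterwards.
    return fp8_max / amax, amax / fp8_max


def per_block_factors(values: List[float],
                      block_size: int) -> List[Tuple[float, float]]:
    """Compute one (scale, descale) pair per contiguous block of values."""
    return [scaling_factors(values[i:i + block_size])
            for i in range(0, len(values), block_size)]


# Two blocks of three values each, so two (scale, descale) pairs.
factors = per_block_factors([0.5, -2.0, 1.0, 0.25, 8.0, -4.0], block_size=3)
```

Multiplying a block by its scale before casting keeps the largest magnitude at the top of the fp8 range, and multiplying the result by descale restores the original magnitude. Computing the pair per block rather than per tensor keeps one outlier from shrinking the resolution of every other block.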
@micmelesse micmelesse marked this pull request as ready for review February 26, 2025 22:21
2 participants