
Enable Vanilla Bwd and Refactor #86

Merged
merged 67 commits into main_perf from micmelesse/enable_bwd on Oct 28, 2024

Conversation

micmelesse
Collaborator

@micmelesse micmelesse commented Oct 14, 2024

This PR

  • enables the vanilla backward pass
  • refactors our code into several files
  • fixes a bug with the softmax lse returned by the forward pass
  • adds an alternate mode that uses exp instead of exp2. This was useful for debugging issues in both the forward and backward passes
  • creates interfaces for fa and pytorch that use implementation functions with explicit parameters
  • adds a PyTorch reference implementation that mimics our Triton kernels for testing
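As a rough illustration (not code from this PR), the exp/exp2 alternate mode comes down to the identity exp(x) == 2**(x * log2(e)). A minimal numpy sketch of the two softmax/lse paths, with hypothetical function names:

```python
import numpy as np

LOG2_E = 1.4426950408889634  # log2(e), often called RCP_LN2 in kernels

def softmax_lse_exp(scores):
    # Reference path: natural exp with the usual max-subtraction trick.
    m = scores.max(axis=-1, keepdims=True)
    e = np.exp(scores - m)
    s = e.sum(axis=-1, keepdims=True)
    lse = m + np.log(s)  # log-sum-exp per query row
    return e / s, lse.squeeze(-1)

def softmax_lse_exp2(scores):
    # exp2 path: exp(x) == 2**(x * log2(e)); exp2 is typically the
    # cheaper hardware instruction on GPUs.
    m = scores.max(axis=-1, keepdims=True)
    e = np.exp2((scores - m) * LOG2_E)
    s = e.sum(axis=-1, keepdims=True)
    lse = m + np.log(s)
    return e / s, lse.squeeze(-1)

scores = np.random.default_rng(0).standard_normal((4, 8))
p1, lse1 = softmax_lse_exp(scores)
p2, lse2 = softmax_lse_exp2(scores)
assert np.allclose(p1, p2) and np.allclose(lse1, lse2)
```

Having both paths produce bit-comparable results is what makes the exp mode useful for isolating kernel bugs from scaling bugs.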

Vanilla BWD

This is a combination of 79 commits.

save test_flash_attn_output

use impl functions

pass layout

add ref

move around impls

fix stride issue

save oai kernel

add baseline impl

save bwd kernel working

remove old impl

remove block_ptrs from bwd

pass padded dmodel and apply masking. the old test cases work but cases with small d don't work

save

save

more prints

rename M to L

save

add notes

add old_bwd back

fa failure fails in kernels too

isolate new bwd and keep old bwd in place

clean up

softmax_lse does not match reference

LOG flag

softmax_lse with LN2

move qk_scale to loop

pass ln2 to fwd

just print kernel input

test softmax output from forward

test exp_scores_triton

save all the ref

create ref USE_EXP2 path

return scores

mask scores when returning them. Basic impl test passes

scores and output match

show max_diff

return score needs to be adjusted as we find new maxes

all good outputs. old style RCP2 example

prep bwd_impl test

save

try openai

save

fix softmax_lse bug

test_op_bwd_impl starting to work!

new kernel. exp2 works but exp is failing

fix bwd exp2

add m and n masks. small cases still don't work

match old and new kernel prints

compare old and new

print inputs

save

old kernel match on dv

dq works

compare to pytorch including softmax in forward

fix bwd impl bug

small sizes in bwd impl work

old bwd test pass. Moving on to kernel tests

dq, dk and dv are filled in place if given. Need to match cast to match fa

fix non bug

fix dv mismatch. use_exp2 was set to true in fwd

fix case up 128

refactor and clean up a bit more

issue is that dq and dk are not zeros

dq must be zeroed out

ignore segfaults

fa ref and my ref match!

all tests run

use tolerance 1e-3

we need to figure out preprocessing

save

clean up

save

test delta diff

move old impl out

new preprocess function

preprocessing_use_o flag

working _bwd_preprocess_use_p

basic cases pass

all green

fwd exp2 usage is done right before exp
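The preprocessing commits above (`preprocessing_use_o flag`, `_bwd_preprocess_use_p`) refer to the standard FlashAttention backward preprocess step: computing a per-row delta before the main dq/dk/dv kernels. As a hedged numpy sketch (names are illustrative, not this repo's API), delta can be computed from O and dO, and equals the P-based form rowsum(dP * P):

```python
import numpy as np

def bwd_preprocess_use_o(o, do):
    # delta_i = sum_d dO[i, d] * O[i, d]: one scalar per query row,
    # reused by the backward kernels when forming dS = P * (dP - delta).
    return (o * do).sum(axis=-1)

# Sanity check against the P-based form: delta = rowsum(dP * P).
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((6, 4)) for _ in range(3))
do = rng.standard_normal((6, 4))
s = q @ k.T
p = np.exp(s - s.max(axis=-1, keepdims=True))
p /= p.sum(axis=-1, keepdims=True)
o = p @ v                      # forward output
dp = do @ v.T                  # gradient w.r.t. scores-softmax
assert np.allclose(bwd_preprocess_use_o(o, do), (dp * p).sum(axis=-1))
```

The O-based form is cheaper in the kernel because it never materializes P, which is presumably why a flag to switch between the two helps when the two disagree.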
@micmelesse
Collaborator Author

micmelesse commented Oct 15, 2024

The kernel tests pass on MI300, but it seems the MI200 CI runners have issues.

@micmelesse micmelesse force-pushed the micmelesse/enable_bwd branch from c6c9559 to 0bd8120 Compare October 16, 2024 16:18
Use Strides

This is a combination of 11 commits.

use strides in bwd

add layout test in forward

fix shape layout function

smaller tests

save

fix varlen error

no headsize passed to bwd

deal with varlen layout

save

save

save

save
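The "Use Strides" change amounts to passing explicit element strides to the kernels instead of assuming a contiguous layout, so one kernel can serve both bshd and bhsd tensors. A minimal sketch (the helper name and canonical (batch, head, seq, dim) ordering are assumptions, not this repo's exact interface):

```python
import numpy as np

def get_strides(x, layout):
    # Return element strides in canonical (batch, head, seq, dim) order,
    # whatever the memory layout, so a kernel can index any layout
    # uniformly. numpy strides are in bytes; Triton wants elements.
    elem = tuple(s // x.itemsize for s in x.strides)
    if layout == "bhsd":        # (batch, heads, seqlen, head_dim)
        return elem
    if layout == "bshd":        # (batch, seqlen, heads, head_dim)
        b, s, h, d = elem
        return (b, h, s, d)
    raise ValueError(f"unknown layout {layout!r}")

x_bshd = np.zeros((2, 128, 8, 64), dtype=np.float16)
x_bhsd = np.transpose(x_bshd, (0, 2, 1, 3))  # same memory, bhsd view
assert get_strides(x_bshd, "bshd") == get_strides(x_bhsd, "bhsd")
```

Because only strides change between layouts, a transposed view needs no copy, which is also what makes the varlen (thd) layout tractable with the same kernel.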
@micmelesse micmelesse force-pushed the micmelesse/enable_bwd branch from 8d1c2f7 to a168999 Compare October 16, 2024 16:40
@micmelesse micmelesse marked this pull request as ready for review October 26, 2024 04:11
@micmelesse micmelesse merged commit b2a2dff into main_perf Oct 28, 2024
2 checks passed
micmelesse added a commit that referenced this pull request Oct 28, 2024
Enable fwd and varlen_fwd on AMD  (#63)

* flash_attn_func works

Compress

This is a combination of 12 commits.

add scripts

save

add our kernel

import our kernel

round trip

use bshd layout

figure out segfault

fix

show backward failure with prints

save backward work

run forward only

test smallest config on everything

add test

fix

remove pre commit

install triton

skip dropout

pin d

32 factor d

just run power of 2

remove timeout

run serially

clean up

clean up 2

* Varlen works

This is a combination of 6 commits.

save

some tests passing

enable more

enable everything

move around

alibi works

* keep interface and kernel separate

* clean up

enable flash_attn_with_kvcache (#68)

* Compress kvcache work

This is a combination of 11 commits.

kvcache work

This is a combination of 4 commits.

kvcache is not supported

save

save decode

save

clean up merge

save cases

save

save

save

save

key mask on triton side

fix q size issue

test combos

save

* fix causal. use cache_seqlens

* clean and test what works

* some configs work on new_kv but fail on 1,8

* cache overwrite correct

* new_kv works more or less

* test local

* work on paged kv attention

* prefill paged attention

* fix has_batch_idx and skip local and rotary emb

* save

* save

* save

* save

* handle new_kv when paged kv cache

* all except has_batch_idx works

* major options are green

* test all

* add tests

* save

* clean up

* minor clean up

* simplest config

* save debug true

* save

* refactor slightly

* save work

* need key masking

* force hip

* use is_hip

* save

* fix cache_seq_len issue

* work on new_kv

* pass new_kv data

* save

* benchmark fwd only

* disable debug

* pandas pdf

* save

* set methods

* record number of heads

* use configs

* flexible dim, n-heads, head dim

* better benchmarking

* basic inplace update working

* works up to 64

* new_kv supported!

* test case for has_batch_idx

* has_batch_idx works!

* save

* save

* save

* save ref

* fix mqa and gqa by duplicating

* GQA and MQA working by kernel modifications

* fix new_kv with gqa

* cache index

* deal with nans on fwd_splitk

* save

* causal working on basic case

* causal works!

* alibi works!

* clean up

* clean prefill changes

* remove bwd stuff

* limit decode test to test_op_fwd

* add ref

* use bfloat
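The "fix mqa and gqa by duplicating" commit refers to the simple fallback for grouped-query attention: repeat each KV head so every query head sees its shared KV head (later replaced by kernel modifications). A hedged numpy sketch with an illustrative helper name:

```python
import numpy as np

def expand_kv_heads(kv, n_q_heads):
    # kv: (batch, seqlen, n_kv_heads, head_dim). Repeat each KV head
    # group-wise so the head count matches the query tensor.
    # GQA when 1 < n_kv_heads < n_q_heads; MQA when n_kv_heads == 1.
    n_kv_heads = kv.shape[2]
    assert n_q_heads % n_kv_heads == 0
    return np.repeat(kv, n_q_heads // n_kv_heads, axis=2)

k = np.random.default_rng(0).standard_normal((2, 16, 2, 8))
k_full = expand_kv_heads(k, n_q_heads=8)
assert k_full.shape == (2, 16, 8, 8)
# query heads 0-3 share kv head 0, heads 4-7 share kv head 1
assert np.array_equal(k_full[:, :, 0], k[:, :, 0])
```

Duplication is correct but wastes memory bandwidth, which motivates doing the head-group indexing inside the kernel instead.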

Fixes after rebase

Fixes after rebase

rebase fixes

deal with kvcache failure

new run for branch

cancel-in-progress

fix varlen_fwd bug

enable packed layouts and all configs (#72)

Clean up for Upstream (#81)

* Clean

Clean

This is a combination of 4 commits.

clean 1

clean 2

clean more

match main

typo fix

* use is_hip()

* clean up more

* skip odd d only

* fix bug

* skip randomly

* use Flag

* update readme

* remove quantization

* remove bwd

* minor

* print

* remove verbose print

* quantize zeroes out the d stride

Enable Vanilla Bwd and Refactor (#86)

* Vanilla BWD

* refactor

* refactor 2

* refactor 3

* fix bug

* try ci

* add flag

* rename to utils

* skip test_op_fwd_decode_int4_kv

* reduce head size

* try again

* go back to old head sizes

* Use Strides

* use gen scripts

* varlen fwd passing

* core fwd ref impl

* fix minor bugs

* wrap varlen- launcher attention_forward_pytorch_ref_impl

* varlen backward ref added

* add offsets for varlen

* fix delta bug

* varlen bwd working

* save

* runs on MI200

* just test basics

* save

* fix bug

* fix varlen int64 bug

* add ref

* test_impl working with causal

* fix qkvpacked issue

* qkvpacked run tests

* remove test_backward

* save

* just test output

* dump into tensors

* softmaxlse layout for varlen

* small cases working

* bwd thd green. although maybe some oom

* forward out and lse are good. Something wrong with backward ref

* make varlen ref work

* save work, ref is working mostly

* 91 failed, 6542 passed, 6336 skipped, 1 warning

* ref is all green

* debug flag in utils

* found bad softmax_lse in varlen fwd

* fix bug in softmax lse. strides in varlen were not right

* add causal tests and 32*32 bwd does not have a segfault

* save

* fix oom by reducing block size for small heads

* bwd ref with causal working

* test impl

* causal test passes

* causal working

* fix tests

* nicer bench

* fix qkvpacked error

* fix varlen qkvpacked bug

* fix minor bug

* bench prefill and prefill_old using the same script

* autotune configs for fwd

* autotune flag

* clean up decode impl

* clean up

* clean up more

* bench everything by default and return time

* clean up readmes
micmelesse added a commit that referenced this pull request Oct 28, 2024
Enable Fwd and Backward


REBASE: fix interface changes in rebase

rename test to test_flash_attn_triton_amd

REBASE: fix unpad diffs

minor clean up in setup
micmelesse added a commit that referenced this pull request Oct 28, 2024
Enable Fwd and Backward


FLASH_ATTENTION_TRITON_AMD flags

bench fwd and bwd

fix sequence_parallel