Skip to content

Commit

Permalink
Enable Fwd and Backward
Browse files Browse the repository at this point in the history
Enable Fwd and Backward

Enable Fwd and Backward

Enable fwd and varlen_fwd on AMD  (#63)

* flash_attn_func works

Compress

This is a combination of 12 commits.

add scripts

save

add our kernel

import our kernel

round trip

use bshd layout

figure out segfault

fix

show backward failure with prints

save backward work

run forward only

test smallest config on everything

add test

fix

remove pre commit

install triton

skip dropout

pin d

32 factor d

just run power of 2

remove timeout

run serially

clean up

clean up 2

* Varlen works

This is a combination of 6 commits.

save

some tests passing

enable more

enable everything

move around

alibi works

* keep interface and kernel seperate

* clean up

enable flash_attn_with_kvcache (#68)

* Compress kvcache work

This is a combination of 11 commits.

kvcache work

This is a combination of 4 commits.

kvcache is not supported

save

save decode

save

clean up merge

save cases

save

save

save

save

key mask on triton side

fix q size issue

test combos

save

* fix causal. use cache_seqlens

* clean and test what works

* some configs work on new_kv but fails on 1,8

* cache overwrite correct

* new_kv works more or less

* test local

* work on paged kv attention

* prefill paged attention

* fix has_batch_idx and skip local and rotatary emb

* save

* save

* save

* save

* handle new_kv when paged kv cache

* all except has_batch_idx works

* major options are green

* test all

* add tests

* save

* clean up

* minor clean up

* simplest config

* save debug true

* save

* refactor slightly

* save work

* need key masking

* force hip

* use is_hip

* save

* fix cache_seq_len issue

* work on new_kv

* pass new_kv data

* save

* benchmark fwd only

* disable debug

* pandas pdf

* save

* set methods

* record number of heads

* use configs

* flexiable dim, n-heads, headofdim

* better benchmarking

* basic inplace update working

* works upto 64

* new_kv supported!

* test case for has_batch_idx

* has_batch_idx works!

* save

* save

* save

* save ref

* fix mqa and gqa by duplicating

* GQA and MQA working by kernel modifications

* fix new_kv with gqa

* cache index

* deal with nans on fwd_splitk

* save

* causal working on basic case

* causal works!

* alibi works!

* clean up

* clean prefill changes

* remove bwd stuff

* limit decode test to test_op_fwd

* add ref

* use bfloat

Fixes after rebase

Fixes after rebase

rebase fixes

deal with kvcache failure

new run for branch

cancel-in-progress

fix varlen_fwd bug

enable packed layouts and all configs (#72)

Clean up for Upstream (#81)

* Clean

Clean

This is a combination of 4 commits.

clean 1

clean 2

clean more

match main

typo fix

* use is_hip()

* clean up more

* skip odd d only

* fix bug

* skip randomly

* use Flag

* update readme

* remove quantization

* remove bwd

* minor

* print

* remove verbose print

* qunatize zero's out the d stride

Enable Vanilla Bwd and Refactor (#86)

* Vanilla BWD

Vanilla BWD

This is a combination of 79 commits.

save test_flash_attn_output

use impl functions

pass layout

add ref

move arround impls

fix stride issue

save oai kernel

add baseline impl

save bwd kernel working

remove old impl

remove block_ptrs from bwd

pass padded dmodel and apply masking. the old test cases work but cases with small d don't work

save

save

more prints

rename to M to L

save

add notes

add old_bwd back

fa failure fails in kernels too

isolate new bwd and keep old bwd in place

clean up

softmax_lse doesnot match refernce

LOG flag

softmax_lse with LN2

move qk_scale to loop

pass ln2 to fwd

just print kernel input

test softmax output from forward

test exp_scores_triton

save all the ref

create ref USE_EXP2 path

return scores

mask scores when returning them. Basic impl test passes

scores and output match

show max_diff

return score needs to be adjusted as we find new maxes

all good outputs. old style RCP2 example

prep bwd_impl test

save

try openai

save

fix softmax_lse bug

test_op_bwd_impl starting to work!

new kernel. exp2 works but exp is faliing

fix bwd exp2

add m and n masks. small cases still don't work

match old and new kernel prints

compare old and new

print inputs

save

old kernel match on dv

dq works

compare to pytorch including softmax in forward

fix bwd impl bug

small sizes in bwd impl work

old bwd test pass. Moving on to kernel tests

dq, dk and dv are filled in place if given. Need to match cast to match fa

fix non bug

fix dv mismatch. use_exp2 was set to true in fwd

fix case up 128

refactor and clean up a bit more

issue is that dq and dk are not zeros

dq must be zeroed out

ignore segfaults

fa ref and my ref match!

all tests run

use tolerance 1e-3

we need to figure out preprocessing

save

clean up

save

test delta diff

move old impl out

new preprocess function

preprocessing_use_o flag

working _bwd_preprocess_use_p

basic cases pass

all green

fwd exp2 usage is done right before exp

* refactor

* refactor 2

* refactor 3

* fix bug

* try ci

* add flag

* rename to utils

* skip test_op_fwd_decode_int4_kv

* reduce head size

* try again

* go back to old head sizes

* Use Strides

Use Strides

This is a combination of 11 commits.

use strides in bwd

add layout test in forward

fix shape layout function

smaller tests

save

fix varlen error

no headsize passed to bwd

deal with varlen layout

save

save

save

save

* use gen scripts

* varlen fwd passing

* core fwd ref impl

* fix minor bugs

* wrap varlen- launcher attention_forward_pytorch_ref_impl

* varlen backward ref added

* add offsets for varlen

* fix delta bug

* varlen bwd working

* save

* runs on Mi200

* just test basics

* save

* fix bug

* fix varlen in64 bug

* add ref

* test_impl working with causal

* fix qkvpacked issue

* qkvpacked run tests

* remove test_backward

* save

* just test output

* dump into tensors

* softmaxlse layout for varlen

* small cases working

* bwd thd green. although maybe some oom

* forward out and lse are good. Something wrong with backward ref

* make varlen ref work

* save work, ref is working mostly

* 91 failed, 6542 passed, 6336 skipped, 1 warning

* ref is all green

* debug flag in utils

* found bad softmax_lse in varlen fwd

* fix bug in softmax lse. strides in varlen werenot right

* add causal tests and 32*32 bwd doesnot have segfault

* save

* fix oom by reducing block size for small heads

* bwd ref with causal working

* test impl

* causal test passes

* causal working

* fix tests

* nicer bench

* fix qvpacked error

* fix varlen qvpacked bug

* fix minor bug

* bench prefill and prefill_old using the same script

* autotune configs for fwd

* autotune flag

* clean up decode impl

* clean up

* clean up more

* bench everything by default and return time

* clean up readmes

REBASE: fix interface changes in rebase

rename test to test_flash_attn_triton_amd

REBASE: fix unpad diffs

minor clean up in setup

FLASH_ATTENTION_TRITON_AMD flags

bench fwd and bwd

fix sequence_parallel
  • Loading branch information
micmelesse committed Oct 28, 2024
1 parent c1d146c commit 730d260
Show file tree
Hide file tree
Showing 18 changed files with 7,045 additions and 107 deletions.
79 changes: 79 additions & 0 deletions .github/workflows/amd_tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
name: AMD Perf Kernel Tests

on:
workflow_dispatch:
pull_request:
branches: [main_perf]
merge_group:
branches: [main_perf]
types: [checks_requested]
push:
branches: [main_perf, micmelesse/upstream_pr]

concurrency:
group: ${{ github.ref }}
cancel-in-progress: true

permissions: read-all

jobs:
Runner-Preparation-AMD:
runs-on: ubuntu-latest
timeout-minutes: 30
outputs:
matrix-HIP: ${{ steps.set-matrix.outputs.matrix-HIP }}
steps:
- name: Prepare runner matrix
id: set-matrix
run: |
if [ x"${{ github.repository }}" == x"ROCm/flash-attention" ]; then
echo '::set-output name=matrix-HIP::[["self-hosted", "rocm"]]'
else
echo '::set-output name=matrix-HIP::[["ubuntu-latest"]]'
fi
Integration-Tests-AMD:
needs: Runner-Preparation-AMD
if: needs.Runner-Preparation-AMD.outputs.matrix-HIP != ''
runs-on: ${{ matrix.runner }}
strategy:
matrix:
runner: ${{fromJson(needs.Runner-Preparation-AMD.outputs.matrix-HIP)}}
container:
image: rocm/pytorch:rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2
options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Triton
run: |
pip uninstall -y triton
pip install matplotlib pandas pytest
git clone https://github.com/triton-lang/triton
cd triton
git checkout 2e9f2c2d20601c24b91a4c32a7b97ad1f8a55d88
pip install --verbose -e python
cd ..
- name: Build
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
python setup.py install
- name: Flash Attention Tests Using Reference Impl
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
export FLASH_ATTENTION_TRITON_AMD_REF=1
pytest tests/test_flash_attn_triton_amd.py
- name: Flash Attention Tests
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
export FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0
pytest tests/test_flash_attn_triton_amd.py
- name: AMD Kernel Tests
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
export FLASH_ATTENTION_TRITON_AMD_AUTOTUNE=0
pytest -v -s flash_attn/flash_attn_triton_amd/test.py
- name: AMD Kernel Bench
run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
python flash_attn/flash_attn_triton_amd/bench.py
9 changes: 8 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,16 @@ var/
*.egg-info/
.installed.cfg
*.egg
.eggs

# IDE-related
.idea/

# Dev
venv
venv
scripts
*.log
core.*
*.csv
*.png
*.html
47 changes: 45 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ FlashAttention-2 with CUDA currently supports:
3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.

### AMD ROCm Support
ROCm version uses [composable_kernel](https://github.com/ROCm/composable_kernel) as the backend. It provides the implementation of FlashAttention-2.
ROCm version has two backends. There is [composable_kernel](https://github.com/ROCm/composable_kernel) (ck) which is the default backend and a [Triton](https://github.com/triton-lang/triton) backend. They provide an implementation of FlashAttention-2.

**Requirements:**
- ROCm 6.0 and above.
Expand All @@ -121,11 +121,54 @@ We recommend the
[Pytorch](https://hub.docker.com/r/rocm/pytorch)
container from ROCm, which has all the required tools to install FlashAttention.

FlashAttention-2 with ROCm currently supports:
#### Composable Kernel Backend
FlashAttention-2 ROCm CK backend currently supports:
1. MI200 or MI300 GPUs.
2. Datatype fp16 and bf16
3. Forward's head dimensions up to 256. Backward head dimensions up to 128.

#### Triton Backend
The Triton implementation of the [Flash Attention v2](https://tridao.me/publications/flash2/flash2.pdf) is currently a work in progress.

It supports AMD's CDNA (MI200, MI300) and RDNA GPU's using fp16, bf16 and fp32 datatypes.

These features are supported in Fwd and Bwd
1) Fwd and Bwd with causal masking
2) Variable sequence lengths
3) Arbitrary Q and KV sequence lengths
4) Arbitrary head sizes

These features are supported in Fwd for now. We will add them to backward soon.
1) Multi and grouped query attention
2) ALiBi and matrix bias

These features are in development
1) Paged Attention
2) Sliding Window
3) Rotary embeddings
4) Dropout
5) Performance Improvements

#### Getting Started
To get started with the triton backend for AMD, follow the steps below.

First install the recommended Triton [commit](https://github.com/triton-lang/triton/commit/2e9f2c2d20601c24b91a4c32a7b97ad1f8a55d88).

```
git clone https://github.com/triton-lang/triton
cd triton
git checkout 2e9f2c2d20601c24b91a4c32a7b97ad1f8a55d88
pip install --verbose -e python
```
Then install and test Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`.

```
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
cd flash-attention
python setup.py install
pytest tests/test_flash_attn.py
```


## How to use FlashAttention

Expand Down
17 changes: 11 additions & 6 deletions flash_attn/flash_attn_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@

import torch
import torch.nn as nn
import os

# isort: off
# We need to import the CUDA kernels after importing torch
import flash_attn_2_cuda as flash_attn_cuda
USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
if USE_TRITON_ROCM:
from .flash_attn_triton_amd import interface_fa as flash_attn_gpu
else:
import flash_attn_2_cuda as flash_attn_gpu

# isort: on

Expand Down Expand Up @@ -88,7 +93,7 @@ def _flash_attn_forward(
return_softmax: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd(
q,
k,
v,
Expand Down Expand Up @@ -161,7 +166,7 @@ def _flash_attn_varlen_forward(
seqused_k: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.varlen_fwd(
q,
k,
v,
Expand Down Expand Up @@ -260,7 +265,7 @@ def _flash_attn_backward(
dk,
dv,
softmax_d,
) = flash_attn_cuda.bwd(
) = flash_attn_gpu.bwd(
dout,
q,
k,
Expand Down Expand Up @@ -356,7 +361,7 @@ def _flash_attn_varlen_backward(
dk,
dv,
softmax_d,
) = flash_attn_cuda.varlen_bwd(
) = flash_attn_gpu.varlen_bwd(
dout,
q,
k,
Expand Down Expand Up @@ -1544,7 +1549,7 @@ def flash_attn_with_kvcache(
cache_seqlens = maybe_contiguous(cache_seqlens)
cache_batch_idx = maybe_contiguous(cache_batch_idx)
block_table = maybe_contiguous(block_table)
out, softmax_lse = flash_attn_cuda.fwd_kvcache(
out, softmax_lse = flash_attn_gpu.fwd_kvcache(
q,
k_cache,
v_cache,
Expand Down
49 changes: 49 additions & 0 deletions flash_attn/flash_attn_triton_amd/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
Flash Attention Triton Kernel
===============

#### Introduction
The Triton implementation of the [Flash Attention v2](https://tridao.me/publications/flash2/flash2.pdf) is currently a work in progress.

It supports AMD's CDNA (MI200, MI300) and RDNA GPU's using fp16, bf16 and fp32 datatypes.

These features are supported in Fwd and Bwd
1) Fwd and Bwd with causal masking
2) Variable sequence lengths
3) Arbitrary Q and KV sequence lengths
4) Arbitrary head sizes

These features are supported in Fwd for now. We will add them to backward soon.
1) Multi and grouped query attention
2) ALiBi and matrix bias

These features are in development
1) Paged Attention
2) Sliding Window
3) Rotary embeddings
4) Dropout
5) Performance Improvements

#### Getting Started
To get started with the triton backend for AMD, follow the steps below.

First install the recommended Triton [commit](https://github.com/triton-lang/triton/commit/2e9f2c2d20601c24b91a4c32a7b97ad1f8a55d88).

```
git clone https://github.com/triton-lang/triton
cd triton
git checkout 2e9f2c2d20601c24b91a4c32a7b97ad1f8a55d88
pip install --verbose -e python
```
Then install and test Flash Attention with the flag `FLASH_ATTENTION_TRITON_AMD_ENABLE` set to `"TRUE"`.

```
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
cd flash-attention
python setup.py install
pytest tests/test_flash_attn.py
```

#### Credits
AMD Triton kernels team

OpenAI kernel team
Empty file.
Loading

0 comments on commit 730d260

Please sign in to comment.