
[attention] Extend attention to fuse transpose #669

Closed
antiagainst opened this issue May 10, 2024 · 8 comments
Assignees
Labels
sdxl-int8: Issues related to SDXL quantized model support

Comments

@antiagainst

No description provided.

@antiagainst antiagainst converted this from a draft issue May 10, 2024
@antiagainst antiagainst changed the title from "KERNEL: Extend attention to fuse transpose" to "[attention] Extend attention to fuse transpose" on May 11, 2024
@antiagainst
Author

antiagainst commented May 22, 2024

Update 5/22: patch iree-org/iree#17408 is out; needs review.
Update 5/23: working on decomposition and tiling; patch should be out today or so.
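
For context on what the decomposition step targets (a plain torch reference, not IREE's actual lowering), attention decomposes into the three ops below, which tiling then operates on:

```python
import torch

def decomposed_attention(q, k, v, scale):
    s = scale * (q @ k.transpose(-1, -2))  # 1) Q @ K^T matmul
    p = torch.softmax(s, dim=-1)           # 2) row-wise softmax
    return p @ v                           # 3) P @ V matmul
```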

@Groverkss
Contributor

Plan to finish it this week (before Jun 7):

4 Jun: Land online attention (iree-org/iree#17536); see the online-softmax sketch after this list
5 Jun: Create transform script using online_attention for MFMA
6 Jun: Add indexing_maps to attention op
7 Jun: Fusions
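
A minimal sketch of the online-softmax rewrite behind online attention (a plain torch reference; the function name, shapes, and blocking factor are illustrative, not taken from the PR):

```python
import torch

def online_attention(q, k, v, scale, block=128):
    # One-pass attention over tiles of the K2 (sequence) dimension:
    # keep a running row max, softmax denominator, and P @ V accumulator,
    # rescaling the old partials whenever a new tile raises the max.
    m = torch.full(q.shape[:-1], float("-inf"))    # running row max
    l = torch.zeros(q.shape[:-1])                  # running softmax denominator
    acc = torch.zeros(*q.shape[:-1], v.shape[-1])  # running P @ V accumulator
    for start in range(0, k.shape[-2], block):
        kb = k[..., start:start + block, :]
        vb = v[..., start:start + block, :]
        s = scale * (q @ kb.transpose(-1, -2))
        m_new = torch.maximum(m, s.max(dim=-1).values)
        p = torch.exp(s - m_new[..., None])
        corr = torch.exp(m - m_new)                # rescale old partials
        l = corr * l + p.sum(dim=-1)
        acc = corr[..., None] * acc + p @ vb
        m = m_new
    return acc / l[..., None]
```

The point of the rewrite is that the K2 dimension becomes tileable in a single pass, which is the property a transform script for MFMA can then build on.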

@antiagainst antiagainst added the sdxl-int8 label (Issues related to SDXL quantized model support) on Jul 12, 2024
@raikonenfnu
Member

Hey guys, quick update:

  1. Indexing attention ([LinalgExt] Adding IndexingMaps to linalg_ext.attentionOp, iree-org/iree#17864) has landed.
  2. CastTypeToFitMMA support for the TD pipeline ([LLVMGPU] Support CastTypeToFitMMA on TransformDialect script, iree-org/iree#17884) is up.
  3. transfer_write distribution for non-contiguous indexing maps ([LLVMGPU][VectorDist] Enable support to distribute vector.transfer_write with non-contiguous dims, iree-org/iree#17895) is up.

Once 2, 3, and iree-org/iree@d2ca774 land on main, we should be able to handle/compile the fused attention-transpose; see the einsum sketch below for the indexing-map intuition.
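
For intuition on why indexing maps enable the transpose fusion (an einsum analogy with hypothetical shapes, not the actual op syntax): folding a transpose into an operand amounts to permuting that operand's indexing map rather than materializing the transposed tensor:

```python
import torch

p = torch.randn(2, 128, 256)   # attention probabilities (B, M, K2); hypothetical shapes
v_t = torch.randn(2, 64, 256)  # V stored transposed: (B, N, K2)

unfused = p @ v_t.transpose(-1, -2)           # materialize the transpose, then matmul
fused = torch.einsum("bmk,bnk->bmn", p, v_t)  # transpose folded into the index map
assert torch.allclose(unfused, fused, atol=1e-4)
```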

@antiagainst
Author

Awesome. All 3 pull requests are in. Can you send out the last piece?

@raikonenfnu
Member

> Awesome. All 3 pull requests are in. Can you send out the last piece?

Hey Lei, I think @MaheshRavishankar is en route to pushing that one in! :)

@MaheshRavishankar

I can send it in early next week.

@raikonenfnu
Member

I also pushed up/updated the spec MLIR to find k2 correctly (link). I tested compiling the fusion-preprocessing test MLIR (here) and was able to get a vmfb out.

The gist above differs slightly from the test in that we make the scale a constant here; vector distribution fails if the scale is not constant.

compile command:

```shell
~/nod/iree-build-notrace/tools/iree-compile constant_transpose_fusion.mlir \
  --iree-hal-target-backends=rocm \
  --iree-rocm-target-chip=gfx942 \
  --iree-global-opt-propagate-transposes=true \
  --iree-opt-outer-dim-concat=true \
  --iree-opt-const-eval=false \
  --iree-opt-data-tiling=false \
  --iree-rocm-waves-per-eu=2 \
  --iree-vm-target-truncate-unsupported-floats \
  --iree-codegen-llvmgpu-use-vector-distribution \
  --iree-codegen-gpu-native-math-precision=true \
  --iree-flow-enable-aggressive-fusion \
  -o attention.vmfb \
  --iree-codegen-transform-dialect-library=attention_and_matmul_spec.mlir
```
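
For completeness, a hypothetical invocation of the resulting vmfb with iree-run-module; the function name, input spec, and HAL device name are assumptions (the ROCm device was exposed as `rocm` in builds of that era):

```shell
# Hypothetical: --function and --input depend on the test MLIR's entry point.
~/nod/iree-build-notrace/tools/iree-run-module \
  --module=attention.vmfb \
  --device=rocm \
  --function=main \
  --input=@inputs.npy
```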

@raikonenfnu
Member

FYI I also tested the attention-transpose-fusion vmfb numerics with normally distributed random inputs (mean 0.0, std 1.0) against torch; the numerics look good there :)
(screenshot: numerics_test)

Starting IR, the compile command, and the data generator can all be found in https://gist.github.com/raikonenfnu/973b4d91e4378702ce4b4496d732cb57

I needed to change the shapes slightly from the original fusion-preprocessing test, since the fastest dim of Q, K, and V needs to be the same to run on PyTorch; a sketch of the reference check is below.
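
A minimal sketch of such a reference check (shapes, tolerances, and the readback step are illustrative, not taken from the gist):

```python
import torch

torch.manual_seed(0)
B, M, K1, K2, N = 2, 1024, 64, 1024, 64  # fastest dims of Q, K, V kept equal (K1 == N)
q = torch.randn(B, M, K1)
k = torch.randn(B, K2, K1)
v = torch.randn(B, K2, N)
scale = K1 ** -0.5

# Torch reference for the attention math; where the fused transpose sits
# (inputs vs. output) follows the test IR in the gist.
ref = torch.softmax(scale * (q @ k.transpose(-1, -2)), dim=-1) @ v

# iree_out = ...  # hypothetical: result read back from the compiled vmfb
# assert torch.allclose(iree_out, ref, atol=1e-2, rtol=1e-2)
```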

@github-project-automation github-project-automation bot moved this from In progress to Done in Turbine: SDXL on CDNA Jul 27, 2024