
Add memory efficient attention from CUTLASS #14343

Merged
tianleiwu merged 10 commits into main from tlwu/cutlass_memory_efficient_attention on Jan 20, 2023

Conversation

@tianleiwu (Contributor) commented Jan 18, 2023

### Description

Add memory efficient attention from CUTLASS.

TODO (in next pull request):
(1) Run performance tests on different GPUs, then add a sequence length threshold (only activate the fused kernel for long sequences).
(2) Merge the changes from NVIDIA/cutlass#773 once they are in the cutlass master branch.

Average latency (ms) of bert-base-cased with FP16 on T4 GPU

No attention mask is used in this test; latency may change if an attention mask is applied.

| Kernel | b1_s8 | b1_s16 | b1_s32 | b1_s64 | b1_s128 | b1_s256 | b1_s384 | b1_s512 | b8_s8 | b8_s16 | b8_s32 | b8_s64 | b8_s128 | b8_s256 | b8_s384 | b8_s512 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Unfused MHA | 1.40 | 1.38 | 1.50 | 1.73 | 2.30 | 3.73 | 5.93 | 8.09 | 1.68 | 2.41 | 3.39 | 5.78 | 11.09 | 24.42 | 42.64 | 59.31 |
| Cutlass FMHA | 1.42 | 1.37 | 1.42 | 1.63 | 2.13 | 3.21 | 4.36 | 5.94 | 1.66 | 2.13 | 3.08 | 5.26 | 10.13 | 20.99 | 31.75 | 45.42 |
| TRT FMHA | 1.36 | 1.31 | 1.36 | 1.58 | 2.10 | 3.22 | 4.41 | 5.83 | 1.59 | 2.09 | 3.03 | 5.25 | 10.01 | 20.34 | 30.33 | 46.82 |

| Kernel | b1_s7 | b1_s15 | b1_s31 | b1_s63 | b1_s127 | b1_s255 | b1_s383 | b1_s511 | b8_s7 | b8_s15 | b8_s31 | b8_s63 | b8_s127 | b8_s255 | b8_s383 | b8_s511 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Unfused MHA | 1.48 | 1.39 | 1.47 | 1.78 | 2.62 | 5.00 | 9.09 | 13.69 | 1.65 | 2.18 | 3.35 | 6.34 | 13.87 | 34.93 | 68.07 | 103.98 |
| Cutlass FMHA | 1.45 | 1.36 | 1.41 | 1.62 | 2.14 | 3.23 | 4.37 | 5.95 | 1.62 | 2.07 | 3.05 | 5.33 | 10.13 | 21.02 | 31.75 | 45.42 |
| TRT FMHA | 1.34 | 1.30 | 1.35 | 1.59 | 2.09 | 3.22 | 4.40 | 5.82 | 1.55 | 2.02 | 2.99 | 5.32 | 9.96 | 20.30 | 30.27 | 46.74 |

Using the TRT kernel is slightly faster than using the Cutlass kernel in most cases (except for the combination of large batch size and long sequence length). Note that this is end-to-end latency, so other parts (like add bias transpose) might also contribute to the gap in latency.

Average latency (ms) of bert-base-cased with FP32 on T4 GPU

| Kernel | b1_s8 | b1_s16 | b1_s32 | b1_s64 | b1_s128 | b1_s256 | b1_s384 | b1_s512 | b8_s8 | b8_s16 | b8_s32 | b8_s64 | b8_s128 | b8_s256 | b8_s384 | b8_s512 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Unfused MHA | 2.22 | 2.26 | 2.66 | 4.13 | 6.81 | 13.24 | 21.12 | 27.36 | 4.07 | 6.88 | 12.62 | 23.41 | 48.53 | 101.24 | 166.10 | 231.90 |
| Cutlass FMHA | 2.37 | 2.48 | 2.82 | 4.38 | 7.15 | 13.58 | 20.58 | 26.73 | 4.37 | 7.14 | 12.94 | 23.85 | 48.54 | 97.00 | 152.41 | 210.79 |
| Change | 6.8% | 9.7% | 6.0% | 6.1% | 5.0% | 2.6% | -2.6% | -2.3% | 7.4% | 3.8% | 2.5% | 1.9% | 0.0% | -4.2% | -8.2% | -9.1% |

When the sequence length is not a multiple of 16, the unfused kernel performs worse, except for very short sequences (<16). Proper input padding is important for the unfused kernel, while padding does not seem to help the fused kernel (see the sketch after the table below).

| Kernel | b1_s7 | b1_s15 | b1_s31 | b1_s63 | b1_s127 | b1_s255 | b1_s383 | b1_s511 | b8_s7 | b8_s15 | b8_s31 | b8_s63 | b8_s127 | b8_s255 | b8_s383 | b8_s511 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Unfused MHA | 2.23 | 2.41 | 2.85 | 4.47 | 7.35 | 14.25 | 22.51 | 29.33 | 4.29 | 7.12 | 13.05 | 24.37 | 49.81 | 102.57 | 167.56 | 235.94 |
| Cutlass FMHA | 2.32 | 2.47 | 2.82 | 4.39 | 7.16 | 13.58 | 20.59 | 26.73 | 4.42 | 7.12 | 12.89 | 23.86 | 48.45 | 97.08 | 152.75 | 211.63 |
| Change | 4.0% | 2.5% | -1.1% | -1.8% | -2.6% | -4.7% | -8.5% | -8.9% | 3.0% | 0.0% | -1.2% | -2.1% | -2.7% | -5.4% | -8.8% | -10.3% |
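As a small illustration of the padding remark above, here is a hedged sketch (not code from this PR; the helper name is made up) of rounding a sequence length up to the next multiple of 16:

```c++
// Illustrative helper only, not part of this PR: round a sequence length up to
// the next multiple of 16 (e.g. 127 -> 128, 255 -> 256) so the unfused kernel
// runs on aligned shapes, which is what the numbers above reward.
inline int PadSequenceLengthTo16(int sequence_length) {
  return (sequence_length + 15) / 16 * 16;
}
```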

For FP32, a good threshold seems to be: use the unfused kernel when the sequence length is below 256, and use cutlass FMHA for longer sequences. The threshold of 256 also takes other GPUs (such as V100 and RTX 1070) into account.
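Below is a minimal sketch of the kind of dispatch rule this paragraph suggests; the constant and function names are hypothetical and not the actual symbols added in the follow-up PR (#14401):

```c++
// Hypothetical dispatch sketch derived from the measurements above; the real
// logic added later may differ.
constexpr int kMinSeqLenForCutlassFmhaFp32 = 256;

bool UseCutlassFmha(bool is_fp16, int sequence_length, bool kernel_supported) {
  if (!kernel_supported) {
    return false;
  }
  // FP16 benefits at most sequence lengths; for FP32 the fused kernel only
  // pays off once the sequence is long enough.
  return is_fp16 || sequence_length >= kMinSeqLenForCutlassFmhaFp32;
}
```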

Average latency (ms) of bert-base-cased with FP32 on V100 GPU

| Kernel | b1_s8 | b1_s16 | b1_s32 | b1_s64 | b1_s128 | b1_s256 | b1_s384 | b1_s512 | b8_s8 | b8_s16 | b8_s32 | b8_s64 | b8_s128 | b8_s256 | b8_s384 | b8_s512 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Unfused MHA | 1.48 | 1.51 | 1.74 | 2.13 | 3.21 | 6.33 | 7.79 | 11.24 | 2.06 | 3.14 | 5.86 | 9.64 | 17.73 | 34.69 | 53.84 | 72.55 |
| Cutlass FMHA | 1.48 | 1.52 | 1.59 | 2.1 | 3.27 | 6.2 | 7.33 | 10.73 | 2.12 | 3.14 | 5.79 | 9.4 | 17.44 | 33.66 | 50.93 | 68.87 |
| Diff | 0.0% | 0.7% | -8.6% | -1.4% | 1.9% | -2.1% | -5.9% | -4.5% | 2.9% | 0.0% | -1.2% | -2.5% | -1.6% | -3.0% | -5.4% | -5.1% |

| Kernel | b1_s7 | b1_s15 | b1_s31 | b1_s63 | b1_s127 | b1_s255 | b1_s383 | b1_s511 | b8_s7 | b8_s15 | b8_s31 | b8_s63 | b8_s127 | b8_s255 | b8_s383 | b8_s511 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Unfused MHA | 1.47 | 1.5 | 1.64 | 2.11 | 3.24 | 6.36 | 7.92 | 11.48 | 2.02 | 3.14 | 5.88 | 9.77 | 17.87 | 35.05 | 54.53 | 73.28 |
| Cutlass FMHA | 1.49 | 1.52 | 1.57 | 2.11 | 3.29 | 6.22 | 7.36 | 10.74 | 2.11 | 3.18 | 5.82 | 9.41 | 17.46 | 33.68 | 50.92 | 68.42 |
| Diff | 1.4% | 1.3% | -4.3% | 0.0% | 1.5% | -2.2% | -7.1% | -6.4% | 4.5% | 1.3% | -1.0% | -3.7% | -2.3% | -3.9% | -6.6% | -6.6% |

Average latency (ms) of bert-base-cased with FP16 on V100 GPU

| Kernel | b1_s8 | b1_s16 | b1_s32 | b1_s64 | b1_s128 | b1_s256 | b1_s384 | b1_s512 | b8_s8 | b8_s16 | b8_s32 | b8_s64 | b8_s128 | b8_s256 | b8_s384 | b8_s512 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Unfused MHA | 1.43 | 1.43 | 1.50 | 1.60 | 1.63 | 2.04 | 2.93 | 3.71 | 1.47 | 1.56 | 1.93 | 3.05 | 4.45 | 8.79 | 14.33 | 19.54 |
| Cutlass FMHA | 1.27 | 1.28 | 1.27 | 1.27 | 1.50 | 1.98 | 2.67 | 3.43 | 1.24 | 1.53 | 1.83 | 2.90 | 4.15 | 7.92 | 11.31 | 15.71 |
| TRT FMHA | 1.26 | 1.24 | 1.22 | 1.27 | 1.53 | 2.19 | 3.39 | 3.19 | 1.27 | 1.48 | 1.84 | 2.93 | 4.12 | 7.87 | 11.93 | 15.25 |

| Kernel | b1_s7 | b1_s15 | b1_s31 | b1_s63 | b1_s127 | b1_s255 | b1_s383 | b1_s511 | b8_s7 | b8_s15 | b8_s31 | b8_s63 | b8_s127 | b8_s255 | b8_s383 | b8_s511 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Unfused MHA | 1.45 | 1.44 | 1.94 | 1.98 | 2.07 | 2.36 | 3.46 | 4.65 | 1.47 | 1.56 | 2.23 | 3.74 | 4.71 | 9.66 | 15.52 | 21.76 |
| Cutlass FMHA | 1.24 | 1.24 | 1.21 | 1.27 | 1.51 | 2.04 | 2.70 | 3.54 | 1.24 | 1.46 | 1.78 | 2.95 | 4.14 | 7.96 | 11.35 | 15.75 |
| TRT FMHA | 1.26 | 1.25 | 1.20 | 1.29 | 1.53 | 2.17 | 3.37 | 3.18 | 1.29 | 1.45 | 1.83 | 2.93 | 4.12 | 7.85 | 11.92 | 15.23 |

Average latency (ms) of bert-base-cased with FP32 on RTX 1070 GPU

| Kernel | b1_s8 | b1_s16 | b1_s32 | b1_s64 | b1_s128 | b1_s256 | b1_s384 | b1_s512 | b8_s8 | b8_s16 | b8_s32 | b8_s64 | b8_s128 | b8_s256 | b8_s384 | b8_s512 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Unfused MHA | 2.98 | 2.99 | 3.42 | 4.9 | 7.22 | 12.88 | 20.06 | 26.87 | 4.92 | 7.22 | 12.15 | 22.34 | 44.22 | 91.64 | 152 | 208.57 |
| Cutlass FMHA | 3.27 | 3.25 | 3.36 | 5.06 | 7.4 | 13.16 | 18.81 | 26.68 | 5.48 | 7.89 | 12.87 | 23.32 | 45.7 | 90.85 | 136.63 | 189.41 |
| Change | 9.7% | 8.7% | -1.8% | 3.3% | 2.5% | 2.2% | -6.2% | -0.7% | 11.4% | 9.3% | 5.9% | 4.4% | 3.3% | -0.9% | -10.1% | -9.2% |

| Kernel | b1_s7 | b1_s15 | b1_s31 | b1_s63 | b1_s127 | b1_s255 | b1_s383 | b1_s511 | b8_s7 | b8_s15 | b8_s31 | b8_s63 | b8_s127 | b8_s255 | b8_s383 | b8_s511 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Unfused MHA | 3.03 | 3.05 | 3.48 | 5.09 | 7.4 | 13.2 | 20.78 | 28.39 | 4.84 | 7.45 | 12.84 | 23.66 | 46.84 | 95.98 | 153.87 | 211.01 |
| Cutlass FMHA | 3.26 | 3.26 | 3.38 | 4.97 | 7.43 | 13.28 | 18.94 | 26.75 | 5.33 | 7.8 | 12.78 | 23.42 | 45.72 | 91.34 | 137.29 | 190.44 |
| Change | 7.6% | 6.9% | -2.9% | -2.4% | 0.4% | 0.6% | -8.9% | -5.8% | 10.1% | 4.7% | -0.5% | -1.0% | -2.4% | -4.8% | -10.8% | -9.7% |

Average latency (ms) of bert-base-cased with FP16 on RTX 1070 GPU

| Kernel | b1_s8 | b1_s16 | b1_s32 | b1_s64 | b1_s128 | b1_s256 | b1_s384 | b1_s512 | b8_s8 | b8_s16 | b8_s32 | b8_s64 | b8_s128 | b8_s256 | b8_s384 | b8_s512 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Unfused MHA | 3.18 | 3.26 | 3.6 | 5.5 | 8.21 | 14.97 | 23.26 | 31.88 | 5.25 | 8.17 | 14.68 | 27.5 | 53.17 | 109.36 | 183.23 | 248.98 |
| Cutlass FMHA | 3.42 | 3.55 | 3.5 | 5.36 | 8.43 | 15.23 | 22.13 | 30.97 | 5.88 | 8.75 | 14.93 | 27.52 | 52.59 | 106.39 | 169.52 | 232.67 |
| Change | 7.5% | 8.9% | -2.8% | -2.5% | 2.7% | 1.7% | -4.9% | -2.9% | 12.0% | 7.1% | 1.7% | 0.1% | -1.1% | -2.7% | -7.5% | -6.6% |

| Kernel | b1_s7 | b1_s15 | b1_s31 | b1_s63 | b1_s127 | b1_s255 | b1_s383 | b1_s511 | b8_s7 | b8_s15 | b8_s31 | b8_s63 | b8_s127 | b8_s255 | b8_s383 | b8_s511 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Unfused MHA | 3.15 | 3.25 | 3.59 | 5.51 | 8.39 | 15.19 | 23.53 | 32.37 | 5.06 | 8.13 | 14.8 | 27.55 | 53.28 | 110.15 | 183.8 | 250.95 |
| Cutlass FMHA | 3.46 | 3.52 | 3.41 | 5.22 | 8.06 | 14.74 | 21.11 | 29.24 | 5.55 | 8.26 | 14 | 25.63 | 49.73 | 100.4 | 164.56 | 229.96 |
| Change | 9.8% | 8.3% | -5.0% | -5.3% | -3.9% | -3.0% | -10.3% | -9.7% | 9.7% | 1.6% | -5.4% | -7.0% | -6.7% | -8.9% | -10.5% | -8.4% |

Average latency of the stable diffusion v1.5 pipeline (50 steps, 512x512 images) with FP16 on T4 GPU

| Kernel | Latency (s) |
|---|---|
| Attention not fused | 11.5 |
| Cutlass FMHA | 7.6 |
| TRT FMHA | 7.2 |

### Motivation and Context

@tianleiwu tianleiwu marked this pull request as draft January 18, 2023 19:25
@tianleiwu tianleiwu marked this pull request as ready for review January 19, 2023 18:08
@tianleiwu tianleiwu merged commit 414b012 into main Jan 20, 2023
@tianleiwu tianleiwu deleted the tlwu/cutlass_memory_efficient_attention branch January 20, 2023 20:33
adrianlizarraga added a commit that referenced this pull request Jan 24, 2023
…14404)

### Description
Fixes unused `use_memory_efficient_attention` variable in
contrib_ops/cuda/bert/attention_impl.cu.



### Motivation and Context
ORT with CUDA version < 11.6 fails to build for release configurations
due to an unused variable.

```shell
c:\...\onnxruntime\onnxruntime\contrib_ops\cuda\bert\attention_impl.cu(420): error : variable "use_memory_efficient_attention" was declared but never referenced [C:\...\onnxruntime\build\Windows\RelWithDebInfo\onnxruntime_providers_cuda.vcxproj]
            detected during instantiation of "onnxruntime::common::Status onnxruntime::contrib::cuda::QkvToContext(const cudaDeviceProp &, cublasHandle_t &, cudaStream_t, onnxruntime::contrib::AttentionParameters &, onnxruntime::contrib::cuda::AttentionData<T> &) [with T=float]"
  (923): here
```

This happens for CUDA < 11.6. Our cmake script turns off
onnxruntime_USE_FLASH_ATTENTION for CUDA < 11.6, which leaves the
aforementioned variable unused outside of asserts (which are removed in
release builds).

The USE_FLASH_ATTENTION option was added by
#14343
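
For context, here is a hedged sketch of the failure mode and one common remedy (an unconditional reference such as a void cast); this is illustrative only and not necessarily the exact fix applied in #14404:

```c++
#include <cassert>

void Sketch(bool flag_from_build_config) {
  // With the flash-attention path compiled out, the flag below is referenced
  // only inside asserts, which release builds remove, so nvcc reports it as
  // "declared but never referenced".
  bool use_memory_efficient_attention = flag_from_build_config;
  assert(use_memory_efficient_attention == flag_from_build_config);

  // Referencing the variable unconditionally silences the diagnostic.
  (void)use_memory_efficient_attention;
}
```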
@faxu faxu added the triage:approved Approved for cherrypicks for release label Jan 25, 2023
tianleiwu added a commit that referenced this pull request Jan 25, 2023
…s FMHA (#14401)

### Description
Add a sequence length threshold for triggering cutlass FMHA in FP32. See the performance test results in #14343 for how this threshold was selected.

Upgrade cutlass to v2.11 and update deps.txt and cgmanifest for nuget
pipeline build (test build:
https://aiinfra.visualstudio.com/Lotus/_build/results?buildId=268574&view=results)
rui-ren pushed a commit that referenced this pull request Jan 27, 2023
rui-ren pushed a commit that referenced this pull request Jan 27, 2023
rui-ren pushed a commit that referenced this pull request Jan 27, 2023
@faxu faxu removed the release:1.14 label Feb 1, 2023
rui-ren pushed a commit that referenced this pull request Feb 3, 2023
rui-ren pushed a commit that referenced this pull request Feb 3, 2023
snnn added a commit that referenced this pull request Apr 13, 2023
### Description

The following three lines are needed before including some cutlass header files, because cutlass uses the "and"/"or" keywords. Generally this should not be a problem without this header, but nvcc is not strictly compliant with the C++ standard.

```c++
#ifdef __cplusplus
#include <ciso646>
#endif
```

We did not hit this problem before because the above code exists in absl, and we always include absl headers first. However, abseil recently deleted it:
abseil/abseil-cpp#1246

The cutlass dependency was introduced in #14343, after we already had abseil.
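
As a small illustration (not code from cutlass), the alternative operator spellings that the guard above is meant to keep available look like this:

```c++
// Illustrative only: "and"/"or" are the standard alternative spellings of &&/||.
// Including <ciso646> first, as the guard above does, keeps non-conformant
// compiler front ends (e.g. nvcc with MSVC as host) happy with these tokens.
#ifdef __cplusplus
#include <ciso646>
#endif

bool InRange(int x, int lo, int hi) {
  return x >= lo and x <= hi;  // equivalent to (x >= lo) && (x <= hi)
}
```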