# [CUDA] Lean Attention (microsoft#22352)
### Description
Add [Lean Attention](https://arxiv.org/abs/2405.10480) and integrate it with the MultiHeadAttention operator for LLM inference on GPU.

LeanAttention speeds up self-attention for the token-generation phase
(decode-phase) of decoder-only transformer models, especially on long
context lengths.
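
Roughly, Lean Attention partitions the attention computation over the KV-cache length, lets each partition produce a partial output plus a log-sum-exp (LSE) statistic, and merges the partials exactly by LSE rescaling; the `softmax_lse`, `softmax_lse_accum`, and `out_accum` buffers added in this change hold such intermediates. The NumPy sketch below illustrates only that split-and-merge reduction for a single decode query, under assumed shapes and chunking; it is not the CUDA kernel or its stream-K scheduling.

```python
import numpy as np

def chunked_decode_attention(q, K, V, num_chunks=4, scale=None):
    """Single-query attention over a KV cache, computed in chunks and merged
    exactly via log-sum-exp rescaling. q: [head_size]; K, V: [seq_len, head_size]."""
    if scale is None:
        scale = 1.0 / np.sqrt(q.shape[-1])
    partial_outs, partial_lses = [], []
    for K_c, V_c in zip(np.array_split(K, num_chunks), np.array_split(V, num_chunks)):
        s = (K_c @ q) * scale                       # partial attention scores
        lse = np.logaddexp.reduce(s)                # chunk log-sum-exp
        partial_outs.append(np.exp(s - lse) @ V_c)  # chunk-normalized partial output
        partial_lses.append(lse)
    lse_all = np.logaddexp.reduce(partial_lses)     # global log-sum-exp
    # Rescale each partial output by its share of the global softmax mass.
    return sum(np.exp(l - lse_all) * o for l, o in zip(partial_lses, partial_outs))

# Reference check against unchunked softmax attention.
rng = np.random.default_rng(0)
q = rng.standard_normal(64)
K = rng.standard_normal((512, 64))
V = rng.standard_normal((512, 64))
s = (K @ q) / np.sqrt(64.0)
reference = np.exp(s - np.logaddexp.reduce(s)) @ V
assert np.allclose(chunked_decode_attention(q, K, V), reference)
```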

- [x] Initial implementation of Lean Attention (by Srikant Bharadwaj)
- [x] Integration with MultiHeadAttention operator
- [x] Add parity tests
- [x] Add benchmark

#### Implementation Details

(1) Lean Attention is enabled in the build for Linux and disabled for Windows.
(2) Lean Attention is disabled by default. Enable it through the CUDA provider option `sdpa_kernel`, or set the environment variable `ORT_ENABLE_LEAN_ATTENTION=1` (see the sketch below).
(3) It only works for token generation (sequence_length == 1, past_sequence_length > 0).
(4) Like flash attention, it only works on Ampere or newer GPUs.

We can revisit (1) and (2) after comparing with the DecoderMaskedMultiHeadAttention and XQA kernels.
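
For illustration, here is a minimal Python sketch of the two enabling paths in (2). The model path is a placeholder; the value 256 is the `LEAN_ATTENTION` bit of the `AttentionBackend` enum added in this PR, and passing the provider option as a string is an assumption about typical onnxruntime Python usage rather than something taken from this change.

```python
import os
import onnxruntime as ort

# Option A: opt in via environment variable (read when the CUDA EP initializes).
os.environ["ORT_ENABLE_LEAN_ATTENTION"] = "1"

# Option B: select SDPA backends explicitly via the cuda provider option sdpa_kernel.
# 256 is the LEAN_ATTENTION bit of the AttentionBackend enum; bits can be OR-ed with
# other backends (e.g. 16 = MATH) to keep a fallback available.
sess = ort.InferenceSession(
    "model.onnx",  # placeholder path
    providers=[("CUDAExecutionProvider", {"sdpa_kernel": "256"})],
)
```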

#### Benchmark

```
cd onnxruntime/test/python/transformers 
/bin/bash benchmark_mha.sh lean
```

Example output on an H100 GPU:

Note that past and present KV caches do not share a buffer in MHA for now, so the measured TFLOPS are low. The relative ratios will change once buffer sharing is enabled, but we expect the ordering (kernel A faster than kernel B) to stay the same.

Common settings `sequence_length=1; causal=True; attn_bias=None; cuda_graph=False` are omitted from the table below.

batch_size | past_sequence_length | num_heads | head_size | average_latency | tflops | kernel
-- | -- | -- | -- | -- | -- | --
1 | 512 | 16 | 64 | 0.000059 | 0.0178 | ort:flash
1 | 512 | 16 | 64 | 0.000068 | 0.0155 | ort:efficient
1 | 512 | 16 | 64 | 0.000065 | 0.0161 | ort:math
1 | 512 | 16 | 64 | 0.000060 | 0.0176 | ort:lean
1 | 512 | 32 | 128 | 0.000062 | 0.0674 | ort:flash
1 | 512 | 32 | 128 | 0.000064 | 0.0661 | ort:efficient
1 | 512 | 32 | 128 | 0.000067 | 0.0625 | ort:math
1 | 512 | 32 | 128 | 0.000062 | 0.0678 | ort:lean
1 | 1024 | 16 | 64 | 0.000061 | 0.0345 | ort:flash
1 | 1024 | 16 | 64 | 0.000086 | 0.0244 | ort:efficient
1 | 1024 | 16 | 64 | 0.000065 | 0.0322 | ort:math
1 | 1024 | 16 | 64 | 0.000063 | 0.0332 | ort:lean
1 | 1024 | 32 | 128 | 0.000075 | 0.1125 | ort:flash
1 | 1024 | 32 | 128 | 0.000088 | 0.0951 | ort:efficient
1 | 1024 | 32 | 128 | 0.000079 | 0.1068 | ort:math
1 | 1024 | 32 | 128 | 0.000072 | 0.1171 | ort:lean
1 | 2048 | 16 | 64 | 0.000069 | 0.0606 | ort:flash
1 | 2048 | 16 | 64 | 0.000125 | 0.0336 | ort:efficient
1 | 2048 | 16 | 64 | 0.000064 | 0.0655 | ort:lean
1 | 2048 | 32 | 128 | 0.000098 | 0.1720 | ort:flash
1 | 2048 | 32 | 128 | 0.000132 | 0.1270 | ort:efficient
1 | 2048 | 32 | 128 | 0.000092 | 0.1828 | ort:lean
1 | 4096 | 16 | 64 | 0.000076 | 0.1097 | ort:flash
1 | 4096 | 16 | 64 | 0.000207 | 0.0406 | ort:efficient
1 | 4096 | 16 | 64 | 0.000069 | 0.1209 | ort:lean
1 | 4096 | 32 | 128 | 0.000140 | 0.2394 | ort:flash
1 | 4096 | 32 | 128 | 0.000213 | 0.1575 | ort:efficient
1 | 4096 | 32 | 128 | 0.000139 | 0.2419 | ort:lean
1 | 8192 | 16 | 64 | 0.000104 | 0.1609 | ort:flash
1 | 8192 | 16 | 64 | 0.000392 | 0.0428 | ort:efficient
1 | 8192 | 16 | 64 | 0.000093 | 0.1809 | ort:lean
1 | 8192 | 32 | 128 | 0.000212 | 0.3160 | ort:flash
1 | 8192 | 32 | 128 | 0.000360 | 0.1866 | ort:efficient
1 | 8192 | 32 | 128 | 0.000212 | 0.3162 | ort:lean
1 | 16384 | 16 | 64 | 0.000139 | 0.2410 | ort:flash
1 | 16384 | 16 | 64 | 0.000731 | 0.0459 | ort:efficient
1 | 16384 | 16 | 64 | 0.000136 | 0.2465 | ort:lean
1 | 16384 | 32 | 128 | 0.000361 | 0.3722 | ort:flash
1 | 16384 | 32 | 128 | 0.000667 | 0.2014 | ort:efficient
1 | 16384 | 32 | 128 | 0.000357 | 0.3765 | ort:lean
1 | 32768 | 16 | 64 | 0.000210 | 0.3194 | ort:flash
1 | 32768 | 16 | 64 | 0.001428 | 0.0470 | ort:efficient
1 | 32768 | 16 | 64 | 0.000209 | 0.3211 | ort:lean
1 | 32768 | 32 | 128 | 0.000659 | 0.4074 | ort:flash
1 | 32768 | 32 | 128 | 0.001270 | 0.2114 | ort:efficient
1 | 32768 | 32 | 128 | 0.000651 | 0.4123 | ort:lean
1 | 65536 | 16 | 64 | 0.000355 | 0.3785 | ort:flash
1 | 65536 | 16 | 64 | 0.002736 | 0.0491 | ort:efficient
1 | 65536 | 16 | 64 | 0.000349 | 0.3845 | ort:lean
1 | 65536 | 32 | 128 | 0.001251 | 0.4290 | ort:flash
1 | 65536 | 32 | 128 | 0.002480 | 0.2165 | ort:efficient
1 | 65536 | 32 | 128 | 0.001239 | 0.4333 | ort:lean
4 | 512 | 16 | 64 | 0.000063 | 0.0665 | ort:flash
4 | 512 | 16 | 64 | 0.000069 | 0.0607 | ort:efficient
4 | 512 | 16 | 64 | 0.000066 | 0.0634 | ort:math
4 | 512 | 16 | 64 | 0.000062 | 0.0674 | ort:lean
4 | 512 | 32 | 128 | 0.000100 | 0.1677 | ort:flash
4 | 512 | 32 | 128 | 0.000099 | 0.1703 | ort:efficient
4 | 512 | 32 | 128 | 0.000108 | 0.1557 | ort:math
4 | 512 | 32 | 128 | 0.000092 | 0.1818 | ort:lean
4 | 1024 | 16 | 64 | 0.000077 | 0.1094 | ort:flash
4 | 1024 | 16 | 64 | 0.000099 | 0.0850 | ort:efficient
4 | 1024 | 16 | 64 | 0.000081 | 0.1038 | ort:math
4 | 1024 | 16 | 64 | 0.000072 | 0.1161 | ort:lean
4 | 1024 | 32 | 128 | 0.000143 | 0.2343 | ort:flash
4 | 1024 | 32 | 128 | 0.000137 | 0.2447 | ort:efficient
4 | 1024 | 32 | 128 | 0.000150 | 0.2245 | ort:math
4 | 1024 | 32 | 128 | 0.000135 | 0.2496 | ort:lean
4 | 2048 | 16 | 64 | 0.000096 | 0.1757 | ort:flash
4 | 2048 | 16 | 64 | 0.000156 | 0.1078 | ort:efficient
4 | 2048 | 16 | 64 | 0.000089 | 0.1892 | ort:lean
4 | 2048 | 32 | 128 | 0.000223 | 0.3010 | ort:flash
4 | 2048 | 32 | 128 | 0.000217 | 0.3101 | ort:efficient
4 | 2048 | 32 | 128 | 0.000209 | 0.3209 | ort:lean
4 | 4096 | 16 | 64 | 0.000137 | 0.2448 | ort:flash
4 | 4096 | 16 | 64 | 0.000256 | 0.1312 | ort:efficient
4 | 4096 | 16 | 64 | 0.000133 | 0.2530 | ort:lean
4 | 4096 | 32 | 128 | 0.000389 | 0.3450 | ort:flash
4 | 4096 | 32 | 128 | 0.000376 | 0.3574 | ort:efficient
4 | 4096 | 32 | 128 | 0.000354 | 0.3794 | ort:lean
4 | 8192 | 16 | 64 | 0.000210 | 0.3198 | ort:flash
4 | 8192 | 16 | 64 | 0.000453 | 0.1480 | ort:efficient
4 | 8192 | 16 | 64 | 0.000206 | 0.3260 | ort:lean
4 | 8192 | 32 | 128 | 0.000725 | 0.3705 | ort:flash
4 | 8192 | 32 | 128 | 0.000693 | 0.3874 | ort:efficient
4 | 8192 | 32 | 128 | 0.000653 | 0.4114 | ort:lean
4 | 16384 | 16 | 64 | 0.000355 | 0.3782 | ort:flash
4 | 16384 | 16 | 64 | 0.000849 | 0.1581 | ort:efficient
4 | 16384 | 16 | 64 | 0.000346 | 0.3874 | ort:lean
4 | 16384 | 32 | 128 | 0.001395 | 0.3848 | ort:flash
4 | 16384 | 32 | 128 | 0.001337 | 0.4017 | ort:efficient
4 | 16384 | 32 | 128 | 0.001252 | 0.4288 | ort:lean
4 | 32768 | 16 | 64 | 0.000647 | 0.4146 | ort:flash
4 | 32768 | 16 | 64 | 0.001649 | 0.1628 | ort:efficient
4 | 32768 | 16 | 64 | 0.000639 | 0.4204 | ort:lean
4 | 32768 | 32 | 128 | 0.002721 | 0.3947 | ort:flash
4 | 32768 | 32 | 128 | 0.002601 | 0.4128 | ort:efficient
4 | 32768 | 32 | 128 | 0.002434 | 0.4411 | ort:lean
4 | 65536 | 16 | 64 | 0.001231 | 0.4361 | ort:flash
4 | 65536 | 16 | 64 | 0.003238 | 0.1658 | ort:efficient
4 | 65536 | 16 | 64 | 0.001217 | 0.4412 | ort:lean
4 | 65536 | 32 | 128 | 0.005357 | 0.4009 | ort:flash
4 | 65536 | 32 | 128 | 0.005118 | 0.4196 | ort:efficient
4 | 65536 | 32 | 128 | 0.004781 | 0.4492 | ort:lean
16 | 512 | 16 | 64 | 0.000098 | 0.1724 | ort:flash
16 | 512 | 16 | 64 | 0.000104 | 0.1616 | ort:efficient
16 | 512 | 16 | 64 | 0.000118 | 0.1420 | ort:math
16 | 512 | 16 | 64 | 0.000087 | 0.1926 | ort:lean
16 | 512 | 32 | 128 | 0.000220 | 0.3062 | ort:flash
16 | 512 | 32 | 128 | 0.000208 | 0.3237 | ort:efficient
16 | 512 | 32 | 128 | 0.000237 | 0.2838 | ort:math
16 | 512 | 32 | 128 | 0.000209 | 0.3216 | ort:lean
16 | 1024 | 16 | 64 | 0.000136 | 0.2465 | ort:flash
16 | 1024 | 16 | 64 | 0.000150 | 0.2235 | ort:efficient
16 | 1024 | 16 | 64 | 0.000148 | 0.2266 | ort:math
16 | 1024 | 16 | 64 | 0.000129 | 0.2611 | ort:lean
16 | 1024 | 32 | 128 | 0.000367 | 0.3663 | ort:flash
16 | 1024 | 32 | 128 | 0.000351 | 0.3829 | ort:efficient
16 | 1024 | 32 | 128 | 0.000400 | 0.3357 | ort:math
16 | 1024 | 32 | 128 | 0.000349 | 0.3853 | ort:lean
16 | 2048 | 16 | 64 | 0.000209 | 0.3206 | ort:flash
16 | 2048 | 16 | 64 | 0.000243 | 0.2762 | ort:efficient
16 | 2048 | 16 | 64 | 0.000201 | 0.3338 | ort:lean
16 | 2048 | 32 | 128 | 0.000671 | 0.4002 | ort:flash
16 | 2048 | 32 | 128 | 0.000645 | 0.4163 | ort:efficient
16 | 2048 | 32 | 128 | 0.000642 | 0.4185 | ort:lean
16 | 4096 | 16 | 64 | 0.000360 | 0.3732 | ort:flash
16 | 4096 | 16 | 64 | 0.000425 | 0.3162 | ort:efficient
16 | 4096 | 16 | 64 | 0.000341 | 0.3933 | ort:lean
16 | 4096 | 32 | 128 | 0.001292 | 0.4156 | ort:flash
16 | 4096 | 32 | 128 | 0.001251 | 0.4291 | ort:efficient
16 | 4096 | 32 | 128 | 0.001241 | 0.4327 | ort:lean
16 | 8192 | 16 | 64 | 0.000666 | 0.4030 | ort:flash
16 | 8192 | 16 | 64 | 0.000804 | 0.3339 | ort:efficient
16 | 8192 | 16 | 64 | 0.000627 | 0.4283 | ort:lean
16 | 8192 | 32 | 128 | 0.002541 | 0.4226 | ort:flash
16 | 8192 | 32 | 128 | 0.002454 | 0.4376 | ort:efficient
16 | 8192 | 32 | 128 | 0.002438 | 0.4405 | ort:lean
16 | 16384 | 16 | 64 | 0.001292 | 0.4156 | ort:flash
16 | 16384 | 16 | 64 | 0.001571 | 0.3417 | ort:efficient
16 | 16384 | 16 | 64 | 0.001217 | 0.4411 | ort:lean
16 | 16384 | 32 | 128 | 0.005042 | 0.4260 | ort:flash
16 | 16384 | 32 | 128 | 0.004859 | 0.4420 | ort:efficient
16 | 16384 | 32 | 128 | 0.004827 | 0.4449 | ort:lean
16 | 32768 | 16 | 64 | 0.002537 | 0.4233 | ort:flash
16 | 32768 | 16 | 64 | 0.003103 | 0.3461 | ort:efficient
16 | 32768 | 16 | 64 | 0.002385 | 0.4501 | ort:lean
16 | 32768 | 32 | 128 | 0.009961 | 0.4312 | ort:flash
16 | 32768 | 32 | 128 | 0.009605 | 0.4472 | ort:efficient
16 | 32768 | 32 | 128 | 0.009524 | 0.4510 | ort:lean
16 | 65536 | 16 | 64 | 0.005019 | 0.4279 | ort:flash
16 | 65536 | 16 | 64 | 0.006133 | 0.3502 | ort:efficient
16 | 65536 | 16 | 64 | 0.004703 | 0.4566 | ort:lean
16 | 65536 | 32 | 128 | 0.019746 | 0.4350 | ort:flash
16 | 65536 | 32 | 128 | 0.019027 | 0.4515 | ort:efficient
16 | 65536 | 32 | 128 | 0.018864 | 0.4554 | ort:lean
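
As a rough cross-check of the table, the tflops column is consistent with counting about `2 * batch_size * num_heads * past_sequence_length * head_size` floating-point operations per decode step. This flop accounting is inferred from the numbers above, not taken from `benchmark_mha.py`:

```python
# Sanity-check sketch for the table above; the flop count is an assumption,
# not extracted from the benchmark script.
def approx_tflops(batch_size, past_sequence_length, num_heads, head_size, latency_s):
    flops = 2.0 * batch_size * num_heads * past_sequence_length * head_size
    return flops / latency_s / 1e12

# Last ort:lean row: 16 | 65536 | 32 | 128 | 0.018864 | 0.4554
print(f"{approx_tflops(16, 65536, 32, 128, 0.018864):.4f}")  # ~0.4554, matching the table
```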

### Motivation and Context
tianleiwu authored Oct 14, 2024
1 parent 87e8a5d commit de93f40
Showing 27 changed files with 3,578 additions and 68 deletions.
17 changes: 17 additions & 0 deletions cmake/CMakeLists.txt
@@ -106,6 +106,7 @@ option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF)
option(onnxruntime_USE_VSINPU "Build with VSINPU support" OFF)

cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF)
cmake_dependent_option(onnxruntime_USE_LEAN_ATTENTION "Build lean attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA; NOT WIN32" OFF)
option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON)

option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF)
@@ -751,21 +752,30 @@ if (onnxruntime_USE_CUDA)

if (onnxruntime_DISABLE_CONTRIB_OPS)
set(onnxruntime_USE_FLASH_ATTENTION OFF)
set(onnxruntime_USE_LEAN_ATTENTION OFF)
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
endif()

if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6)
message( STATUS "Turn off flash attention since CUDA compiler version < 11.6")
set(onnxruntime_USE_FLASH_ATTENTION OFF)
set(onnxruntime_USE_LEAN_ATTENTION OFF)
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
elseif(WIN32 AND CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12)
message( STATUS "Flash-Attention unsupported in Windows with CUDA compiler version < 12.0")
set(onnxruntime_USE_FLASH_ATTENTION OFF)
endif()

if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.4)
message( FATAL_ERROR "Failed build due to CUDA compiler version < 11.4")
endif()
if (WIN32)
message( STATUS "Lean Attention unsupported in Windows")
set(onnxruntime_USE_LEAN_ATTENTION OFF)
endif()
else()
set(onnxruntime_USE_FLASH_ATTENTION OFF)
set(onnxruntime_USE_LEAN_ATTENTION OFF)
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
endif()

@@ -779,6 +789,13 @@ if (onnxruntime_USE_CUDA)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_FLASH_ATTENTION=1)
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_FLASH_ATTENTION=1)
endif()

if (onnxruntime_USE_LEAN_ATTENTION)
message( STATUS "Enable lean attention for CUDA EP")
list(APPEND ORT_PROVIDER_FLAGS -DUSE_LEAN_ATTENTION=1)
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_LEAN_ATTENTION=1)
endif()

if (onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION)
message( STATUS "Enable memory efficient attention for CUDA EP")
list(APPEND ORT_PROVIDER_FLAGS -DUSE_MEMORY_EFFICIENT_ATTENTION=1)
10 changes: 8 additions & 2 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -48,6 +48,7 @@ enum AttentionKernelType {
AttentionKernel_CutlassMemoryEfficientAttention,
AttentionKernel_FlashAttention,
AttentionKernel_CudnnFlashAttention,
AttentionKernel_LeanAttention,
AttentionKernel_Default
};

@@ -65,7 +66,6 @@ struct AttentionParameters {
int v_hidden_size; // hidden size of V
int v_head_size; // hidden size per head of V
int num_heads;
int num_splits;
int rotary_embedding;
bool is_unidirectional;
bool past_present_share_buffer;
@@ -208,10 +208,13 @@ enum class AttentionBackend : int {
CUDNN_FLASH_ATTENTION = 8, // reserved for cuDNN flash attention.
MATH = 16, // unfused kernel cannot be disabled right now.

// The following kernels might be deprecated in the future.
// The following TRT kernels might be deprecated in the future.
TRT_FLASH_ATTENTION = 32,
TRT_CROSS_ATTENTION = 64,
TRT_CAUSAL_ATTENTION = 128,

// Experimental kernels
LEAN_ATTENTION = 256,
};

// Environment variable to enable debug information of attention kernel to be printed. Default is 0 (disabled).
@@ -239,6 +242,9 @@ constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFF
// Environment variable to enable or disable flash attention. Default is 0 (enabled).
constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION";

// Environment variable to enable or disable lean attention. Default is 0 (disabled).
constexpr const char* kEnableLeanAttention = "ORT_ENABLE_LEAN_ATTENTION";

// Minimum sequence length to prefer memory efficient attention when data type is float32
constexpr const char* kMinSeqLenForEfficientAttentionFp32 = "ORT_MIN_SEQ_LEN_EFFICIENT_ATTENTION_FP32";

18 changes: 15 additions & 3 deletions onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -102,6 +102,9 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
const int sm = device_prop.major * 10 + device_prop.minor;
const bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN;

typedef typename ToCudaType<T>::MappedType CudaT;
AttentionData<CudaT> data;

#if USE_FLASH_ATTENTION
bool use_flash_attention = !disable_flash_attention_ &&
(nullptr == attention_bias) &&
@@ -118,21 +121,26 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
use_flash_attention = false;
}
// Allocate buffers
size_t softmax_lse_bytes = 0;
size_t softmax_lse_accum_bytes = 0;
size_t out_accum_bytes = 0;
if (use_flash_attention) {
softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(sequence_length, batch_size, parameters.num_heads);

using namespace std;
auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes(
parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, parameters.num_heads,
parameters.head_size, device_prop.multiProcessorCount);
parameters.num_splits = static_cast<int>(num_splits);
data.num_splits = static_cast<int>(num_splits);
softmax_lse_accum_bytes = slse_accum_bytes;
out_accum_bytes = o_accum_bytes;
}
auto softmax_lse_buffer = GetScratchBuffer<void>(softmax_lse_bytes, context->GetComputeStream());
auto softmax_lse_accum_buffer = GetScratchBuffer<void>(softmax_lse_accum_bytes, context->GetComputeStream());
auto out_accum_buffer = GetScratchBuffer<void>(out_accum_bytes, context->GetComputeStream());
#else
constexpr bool use_flash_attention = false;
auto softmax_lse_buffer = GetScratchBuffer<void>(0, context->GetComputeStream());
auto softmax_lse_accum_buffer = GetScratchBuffer<void>(0, context->GetComputeStream()); // nullptr
auto out_accum_buffer = GetScratchBuffer<void>(0, context->GetComputeStream()); // nullptr
#endif
@@ -247,6 +255,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
constexpr size_t element_size = sizeof(T);
constexpr bool use_fused_cross_attention = false;
constexpr bool use_cudnn_flash_attention = false;
constexpr bool use_lean_attention = false;
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
parameters.batch_size,
parameters.num_heads,
@@ -257,14 +266,13 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
parameters.total_sequence_length,
fused_runner,
use_flash_attention,
use_lean_attention,
use_fused_cross_attention,
use_memory_efficient_attention,
use_cudnn_flash_attention,
false);
IAllocatorUniquePtr<void> work_space = IAllocator::MakeUniquePtr<void>(allocator, workSpaceSize, false, context->GetComputeStream());

typedef typename ToCudaType<T>::MappedType CudaT;
AttentionData<CudaT> data;
data.gemm_buffer = reinterpret_cast<CudaT*>(gemm_buffer.get());
if (nullptr != bias) {
data.bias = reinterpret_cast<const CudaT*>(bias->Data<T>());
@@ -289,6 +297,10 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
data.fused_runner = reinterpret_cast<void*>(fused_runner);
data.use_flash_attention = use_flash_attention;
data.use_memory_efficient_attention = use_memory_efficient_attention;
if (softmax_lse_buffer != nullptr) {
data.softmax_lse = reinterpret_cast<CudaT*>(softmax_lse_buffer.get());
}

if (softmax_lse_accum_buffer != nullptr) {
data.softmax_lse_accum = reinterpret_cast<CudaT*>(softmax_lse_accum_buffer.get());
}
96 changes: 93 additions & 3 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -39,6 +39,7 @@ limitations under the License.
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
#include "contrib_ops/cuda/bert/cudnn_fmha/cudnn_flash_attention.h"
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
#include "contrib_ops/cuda/bert/lean_attention/lean_api.h"
#include "contrib_ops/cuda/bert/attention_impl.h"

using namespace onnxruntime::cuda;
@@ -108,6 +109,7 @@ size_t GetAttentionWorkspaceSize(
size_t total_sequence_length,
void* fused_runner,
bool use_flash_attention,
bool use_lean_attention,
bool use_fused_cross_attention,
bool use_memory_efficient_attention,
bool use_cudnn_flash_attention,
@@ -119,12 +121,20 @@ size_t GetAttentionWorkspaceSize(

#if USE_FLASH_ATTENTION
if (use_flash_attention) {
return qkv_bytes + onnxruntime::flash::get_softmax_lse_size(sequence_length, batch_size, num_heads);
return qkv_bytes;
}
#else
ORT_UNUSED_PARAMETER(use_flash_attention);
#endif

#if USE_LEAN_ATTENTION
if (use_lean_attention) {
return qkv_bytes;
}
#else
ORT_UNUSED_PARAMETER(use_lean_attention);
#endif

#if USE_MEMORY_EFFICIENT_ATTENTION
if (use_memory_efficient_attention) {
size_t fmha_buffer_bytes = 0;
@@ -301,10 +311,10 @@ Status FlashAttention(

constexpr bool is_bf16 = false;
ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
device_prop, stream, data.q, data.k, data.v, data.output, reinterpret_cast<void*>(data.scratch),
device_prop, stream, data.q, data.k, data.v, data.output, reinterpret_cast<void*>(data.softmax_lse),
parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size,
parameters.sequence_length, parameters.total_sequence_length, scale, 0.0, parameters.is_unidirectional, is_bf16,
false, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
false, data.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
reinterpret_cast<void*>(data.out_accum), data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH));

return Status::OK();
@@ -326,6 +336,81 @@ Status FlashAttention(
}
#endif

#if USE_LEAN_ATTENTION
template <typename T>
Status LeanAttention(
const cudaDeviceProp& device_prop,
cudaStream_t stream,
contrib::AttentionParameters& parameters,
AttentionData<T>& data,
float scale) {
assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH ||
data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH);
assert(nullptr == data.mask_index);
assert(nullptr == data.attention_bias);
assert(parameters.head_size == parameters.v_head_size);

constexpr bool is_bf16 = false;

ORT_RETURN_IF_ERROR(onnxruntime::lean::mha_fwd_kvcache(
device_prop, stream,
data.q,
data.k, // k_cache
data.v, // v_cache
nullptr, // new_k (we have appended new_k to k_cache)
nullptr, // new_v (we have appended new_v to k_cache)
data.output,
reinterpret_cast<void*>(data.softmax_lse),
nullptr, // seqlens_k
nullptr, // cos_cache
nullptr, // sin_cache
nullptr, // block_table
parameters.batch_size,
parameters.num_heads,
parameters.num_heads, // num_heads_k
parameters.head_size,
parameters.sequence_length, // seqlen_q
parameters.total_sequence_length, // seqlen_k
0, // seqlen_k_new
0, // rotary_dim
scale, // softmax_scale
parameters.is_unidirectional,
is_bf16,
false, // past_bsnh
data.num_splits,
data.grid_dim_z,
data.max_tiles_per_tb,
data.high_load_tbs,
data.tiles_per_head,
reinterpret_cast<void*>(data.softmax_lse_accum),
reinterpret_cast<void*>(data.out_accum),
data.lean_sync_flag,
-1, // local_window_size
false, // is_rotary_interleaved
false // is_packed_qkv
));

return Status::OK();
}

template <>
Status LeanAttention(
const cudaDeviceProp& device_prop,
cudaStream_t stream,
contrib::AttentionParameters& parameters,
AttentionData<float>& data,
float scale) {
ORT_UNUSED_PARAMETER(device_prop);
ORT_UNUSED_PARAMETER(stream);
ORT_UNUSED_PARAMETER(parameters);
ORT_UNUSED_PARAMETER(data);
ORT_UNUSED_PARAMETER(scale);
return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "lean attention does not support float tensor");
}
#endif



template <typename T>
Status CudnnFlashAttention(
cudnnHandle_t cudnn_handle,
@@ -641,6 +726,11 @@ Status QkvToContext(
// For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation.
const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(qk_head_size))
: parameters.scale;
#if USE_LEAN_ATTENTION
if (data.use_lean_attention) {
return LeanAttention(device_prop, stream, parameters, data, scale);
}
#endif

#if USE_FLASH_ATTENTION
if (data.use_flash_attention) {
15 changes: 15 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -53,6 +53,7 @@ size_t GetAttentionWorkspaceSize(
size_t total_sequence_length,
void* fused_runner,
bool use_flash_attention,
bool use_lean_attention,
bool use_fused_cross_attention,
bool use_memory_efficient_attention,
bool use_cudnn_flash_attention,
@@ -102,6 +103,19 @@ struct AttentionData {
T* softmax_lse_accum = nullptr;
T* out_accum = nullptr;

// Flash Attention and Lean Attention
int num_splits;

// Lean Attention
bool use_lean_attention = false;
#if USE_LEAN_ATTENTION
int grid_dim_z = 0;
int max_tiles_per_tb = 0;
int high_load_tbs = 0;
int tiles_per_head = 0;
int* lean_sync_flag = nullptr;
#endif

// For Debugging
size_t workspace_bytes = 0;
bool allow_debug_info = false;
@@ -115,6 +129,7 @@ struct AttentionData {

void PrintDebugInfo() const {
std::cout << "flash=" << use_flash_attention
<< ", lean=" << use_lean_attention
<< ", efficient=" << use_memory_efficient_attention
<< ", fused_runner=" << (fused_runner != nullptr)
<< ", fused_cross=" << (fused_cross_attention_kernel != nullptr)
17 changes: 17 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc
@@ -17,6 +17,9 @@ namespace onnxruntime {
void AttentionKernelOptions::Initialize(int value, bool use_build_flag, bool check_cudnn_version) {
if (value > 0) {
use_flash_attention_ = (value & static_cast<int>(AttentionBackend::FLASH_ATTENTION)) > 0;
#if USE_LEAN_ATTENTION
use_lean_attention_ = (value & static_cast<int>(AttentionBackend::LEAN_ATTENTION)) > 0;
#endif
use_efficient_attention_ = (value & static_cast<int>(AttentionBackend::EFFICIENT_ATTENTION)) > 0;
use_trt_fused_attention_ = (value & static_cast<int>(AttentionBackend::TRT_FUSED_ATTENTION)) > 0;
use_cudnn_flash_attention_ = (value & static_cast<int>(AttentionBackend::CUDNN_FLASH_ATTENTION)) > 0;
@@ -26,6 +29,9 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag, bool che
use_trt_causal_attention_ = (value & static_cast<int>(AttentionBackend::TRT_CAUSAL_ATTENTION)) > 0;
} else {
use_flash_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFlashAttention, false);
#if USE_LEAN_ATTENTION
use_lean_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableLeanAttention, false);
#endif
use_efficient_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableMemoryEfficientAttention, false);
use_trt_fused_attention_ = !ParseEnvironmentVariableWithDefault<bool>(kDisableFusedSelfAttention, false);
use_cudnn_flash_attention_ = ParseEnvironmentVariableWithDefault<bool>(kEnableCudnnFlashAttention, false);
@@ -61,6 +67,10 @@ void AttentionKernelOptions::Initialize(int value, bool use_build_flag, bool che
use_flash_attention_ = false;
#endif

#ifndef USE_LEAN_ATTENTION
use_lean_attention_ = false;
#endif

#ifndef USE_MEMORY_EFFICIENT_ATTENTION
use_efficient_attention_ = false;
#endif
@@ -81,6 +91,9 @@ void AttentionKernelOptions::Print() const {
std::stringstream sstream;
sstream << "AttentionKernelOptions:";
sstream << " FLASH_ATTENTION=" << int(use_flash_attention_);
#if USE_LEAN_ATTENTION
sstream << " LEAN_ATTENTION=" << int(use_lean_attention_);
#endif
sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention_);
sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention_);
sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention_);
@@ -131,6 +144,10 @@ void AttentionKernelDebugInfo::Print(const char* operator_name,
sstream << " SdpaKernel=";
if (use_flash_attention.has_value() && use_flash_attention.value()) {
sstream << "FLASH_ATTENTION";
#if USE_LEAN_ATTENTION
} else if (use_lean_attention.has_value() && use_lean_attention.value()) {
sstream << "LEAN_ATTENTION";
#endif
} else if (use_efficient_attention.has_value() && use_efficient_attention.value()) {
sstream << "EFFICIENT_ATTENTION";
} else if (use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) {