From 414b012f42e3f9a343f827c363b30b94011dc3f7 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 20 Jan 2023 12:33:01 -0800 Subject: [PATCH] Add memory efficient attention from CUTLASS (#14343) ### Description Add memory efficient attention from CUTLASS. TODO (in next pull request): (1) Need performance tests on different GPUs, then add a sequence length threshold (only activate it for long sequence length). (2) Merge changes from https://github.com/NVIDIA/cutlass/pull/773 when it is in cutlass master. --- ThirdPartyNotices.txt | 31 + cmake/CMakeLists.txt | 25 +- cmake/external/cutlass.cmake | 12 + cmake/onnxruntime_providers.cmake | 6 + cmake/onnxruntime_rocm_hipify.cmake | 1 + .../contrib_ops/cpu/bert/attention_common.h | 10 +- .../cuda/bert/add_bias_transpose.cu | 118 +++ .../cuda/bert/add_bias_transpose.h | 3 + .../contrib_ops/cuda/bert/attention.cc | 35 +- onnxruntime/contrib_ops/cuda/bert/attention.h | 3 +- .../contrib_ops/cuda/bert/attention_impl.cu | 83 +- .../contrib_ops/cuda/bert/attention_impl.h | 5 +- .../bert/cutlass_fmha/fmha_launch_template.h | 116 +++ .../cuda/bert/cutlass_fmha/fmha_sm50.cu | 24 + .../cuda/bert/cutlass_fmha/fmha_sm70.cu | 24 + .../cuda/bert/cutlass_fmha/fmha_sm75.cu | 24 + .../cuda/bert/cutlass_fmha/fmha_sm80.cu | 24 + .../cuda/bert/cutlass_fmha/kernel_forward.h | 947 ++++++++++++++++++ .../memory_efficient_attention.cu | 30 + .../cutlass_fmha/memory_efficient_attention.h | 55 + .../cuda/bert/multihead_attention.cc | 31 +- .../cuda/bert/multihead_attention.h | 3 +- .../quantization/attention_quantization.cc | 5 +- .../tools/transformers/benchmark_helper.py | 3 +- .../test/contrib_ops/attention_op_test.cc | 6 +- .../contrib_ops/attention_op_test_helper.cc | 391 +++++++- .../contrib_ops/attention_op_test_helper.h | 4 + .../multihead_attention_op_test.cc | 95 +- .../multihead_attention_op_test_data_gen.py | 101 +- 29 files changed, 2140 insertions(+), 75 deletions(-) create mode 100644 cmake/external/cutlass.cmake create mode 100644 onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h create mode 100644 onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm50.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm70.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm75.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm80.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h create mode 100644 onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.cu create mode 100644 onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h diff --git a/ThirdPartyNotices.txt b/ThirdPartyNotices.txt index abfbd588cd853..c41c3cbefcb73 100644 --- a/ThirdPartyNotices.txt +++ b/ThirdPartyNotices.txt @@ -2753,6 +2753,37 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +_____ +nvidia/cutlass + +Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +2. 
Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + _____ Boost diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 0a00b0835f815..b315b346f7b05 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -64,6 +64,8 @@ option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to prov option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF) option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF) +option(onnxruntime_USE_FLASH_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON) + option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF) option(onnxruntime_USE_AVX "Use AVX instructions" OFF) option(onnxruntime_USE_AVX2 "Use AVX2 instructions" OFF) @@ -595,10 +597,31 @@ set(ONNXRUNTIME_PROVIDER_NAMES cpu) set(ORT_PROVIDER_FLAGS) set(ORT_PROVIDER_CMAKE_FLAGS) +if (onnxruntime_USE_CUDA) + enable_language(CUDA) + message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}") + + if (onnxruntime_DISABLE_CONTRIB_OPS) + set(onnxruntime_USE_FLASH_ATTENTION OFF) + endif() + if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6) + message( STATUS "Turn off flash attention since CUDA compiler version < 11.6") + set(onnxruntime_USE_FLASH_ATTENTION OFF) + endif() +else() + set(onnxruntime_USE_FLASH_ATTENTION OFF) +endif() + if (onnxruntime_USE_CUDA) list(APPEND ORT_PROVIDER_FLAGS -DUSE_CUDA=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_CUDA=1) list(APPEND ONNXRUNTIME_PROVIDER_NAMES cuda) + + if (onnxruntime_USE_FLASH_ATTENTION) + message( STATUS "Enable flash attention for CUDA EP") + list(APPEND ORT_PROVIDER_FLAGS -DUSE_FLASH_ATTENTION=1) + list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_FLASH_ATTENTION=1) + endif() endif() if (onnxruntime_USE_VITISAI) list(APPEND ORT_PROVIDER_FLAGS -DUSE_VITISAI=1) @@ -1234,8 +1257,6 @@ endif() if (onnxruntime_USE_CUDA) set(CMAKE_CUDA_RUNTIME_LIBRARY Shared) - enable_language(CUDA) - message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}") set(CMAKE_CUDA_STANDARD 17) if(onnxruntime_CUDNN_HOME) file(TO_CMAKE_PATH ${onnxruntime_CUDNN_HOME} onnxruntime_CUDNN_HOME) diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake new file mode 100644 index 0000000000000..c4d52cfd2c1a0 --- /dev/null +++ b/cmake/external/cutlass.cmake @@ -0,0 +1,12 @@ +if 
(onnxruntime_USE_FLASH_ATTENTION) + include(FetchContent) + FetchContent_Declare(cutlass + GIT_REPOSITORY https://github.com/nvidia/cutlass.git + GIT_TAG 8b42e751c63ba219755c8ed91af5f6ec1ecc1ee6 + ) + + FetchContent_GetProperties(cutlass) + if(NOT cutlass_POPULATED) + FetchContent_Populate(cutlass) + endif() +endif() diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 78b738cd29800..c07c60b486e4f 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -484,6 +484,12 @@ if (onnxruntime_USE_CUDA) if(onnxruntime_CUDNN_HOME) target_include_directories(onnxruntime_providers_cuda PRIVATE ${onnxruntime_CUDNN_HOME}/include) endif() + + if (onnxruntime_USE_FLASH_ATTENTION) + include(cutlass) + target_include_directories(onnxruntime_providers_cuda PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples) + endif() + target_include_directories(onnxruntime_providers_cuda PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found set_target_properties(onnxruntime_providers_cuda PROPERTIES LINKER_LANGUAGE CUDA) diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 525d112b24040..d3b8f5ebfcc26 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -23,6 +23,7 @@ set(contrib_ops_excluded_files "bert/skip_layer_norm.h" "bert/skip_layer_norm_impl.cu" "bert/skip_layer_norm_impl.h" + "bert/cutlass_fmha/*" "bert/tensorrt_fused_multihead_attention/*" "bert/transformer_common.h" "bert/transformer_common.cc" diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 849937d8cff8e..f45bbecfc71e0 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -19,7 +19,7 @@ enum AttentionMaskType { enum AttentionQkvFormat { Q_K_V_BNSH, // for unfused attention - Q_K_V_BSNH, // input format of query, key and value for MultiHeadAttention + Q_K_V_BSNH, // for memory efficient attention, or format of query, key and value for MultiHeadAttention QKV_BSN3H, // for TRT fused attention, qkv are packed Q_K_V_BNSH_QKV_BS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH) Q_KV_BSNH_BSN2H, // for TRT fused cross attention, kv are packed @@ -30,6 +30,7 @@ enum AttentionKernelType{ AttentionKernel_TrtFusedAttention, AttentionKernel_TrtFlashAttention, AttentionKernel_TrtFusedCrossAttention, + AttentionKernel_CutlassMemoryEfficientAttention, AttentionKernel_Default }; @@ -61,8 +62,11 @@ constexpr const char* kDisableFusedAttention = "ORT_DISABLE_FUSED_ATTENTION"; // Environment variable to enable or disable fused cross attention kernel. Default is 0 (enabled). constexpr const char* kDisableFusedCrossAttention = "ORT_DISABLE_FUSED_CROSS_ATTENTION"; -// Environment variable to enable or disable flash attention. Default is 0 (enabled). -constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION"; +// Environment variable to enable or disable TRT flash attention. Default is 0 (enabled). +constexpr const char* kDisableTrtFlashAttention = "ORT_DISABLE_TRT_FLASH_ATTENTION"; + +// Environment variable to enable or disable cutlass memory efficient attention. Default is 0 (enabled). 
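+// For example, launching a process with ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION=1
+// makes the CUDA EP skip this kernel and fall back to the other attention implementations.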
+constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION"; } // namespace attention diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu index c4f96df610625..b7eebb9d48785 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu @@ -320,6 +320,110 @@ __global__ void AddBiasTransposeQKVLarge(const int head_size, const T* input, co } } +template +__global__ void AddBiasTransposeCutlass(const T* input, const T* biases, T* output, int v_head_size) { + // Format 3 for cutlass memory efficient attention + // Input: BxSx(NxH + NxH + NxH_v) (Packed QKV where K and V has different hidden sizes) + // Output: BxNxSxH + BxNxSxH + BxNxSxH_v + // B is batch_size, S is sequence_length, N is num_heads, H is qk_head_size, H_v is v_head_size + int n = threadIdx.y; // head_num_id + int s = blockIdx.x; // sequence_id + int b = blockIdx.y; // batch_id + int m = blockIdx.z; // matrix id (Q=0, K=1, V=2) + const int h = threadIdx.x; // head_element_id + + const int qk_head_size = blockDim.x; + const int num_heads = blockDim.y; + + const int sequence_length = gridDim.x; + const int batch_size = gridDim.y; + + const int head_size = (m == 2 ? v_head_size : qk_head_size); + + const int total_head_size = num_heads * (qk_head_size + qk_head_size + v_head_size); + + int in_offset; + int out_offset; + int bias_offset; + in_offset = b * (total_head_size * sequence_length) + // B + s * (total_head_size) + // S + m * (qk_head_size * num_heads) + // M + n * head_size + // N + h; // H + + out_offset = m * (num_heads * qk_head_size * sequence_length * batch_size) + // M + b * (num_heads * head_size * sequence_length) + // B + s * (num_heads * head_size) + // S + n * (head_size) + // N + h; // H + + bias_offset = m * (num_heads * qk_head_size) + // M + n * (head_size) + // N + h; // H + + if (h < head_size) { + output[out_offset] = input[in_offset] + biases[bias_offset]; + } +} + +template +__global__ void AddBiasTransposeCutlass(int M, const T* input, const T* biases, T* output) { + // Format 3 for cutlass memory efficient attention + // Input: BxSxMxNxH + // Output: MxBxSxNxH + // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size + int n = threadIdx.y; + int s = blockIdx.x; + int b = blockIdx.y; + int m = blockIdx.z; // matrix id + + const int head_size = blockDim.x; + const int num_heads = blockDim.y; + + const int sequence_length = gridDim.x; + const int batch_size = gridDim.y; + const int H = head_size; + const int NH = num_heads * head_size; + const int NHS = NH * sequence_length; + + int in_offset = n * head_size + (m + s * M) * NH + b * NHS * M; + const int out_offset = n * head_size + s * NH + b * NHS + m * NHS * batch_size; + + const int h = threadIdx.x; + if (h < head_size) { + output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h]; + } +} + +template +__global__ void AddBiasTransposeCutlassLarge(const int head_size, const T* input, const T* biases, T* output, + const int M) { + // Format 3 for cutlass memory efficient attention + // Input: BxSxMxNxH (Packed QKV) + // Output: MxBxSxNxH + int n = threadIdx.y; + int s = blockIdx.x; + int b = blockIdx.y; + int m = blockIdx.z; // matrix id + + const int stride = blockDim.x; + const int num_heads = blockDim.y; + + const int sequence_length = gridDim.x; + const int batch_size = gridDim.y; + const int H = head_size; + const 
int NH = num_heads * H; + const int NHS = NH * sequence_length; + int in_offset = n * H + (m + s * M) * NH + b * NHS * M; + const int out_offset = n * H + s * NH + b * NHS + m * NHS * batch_size; + + int h = threadIdx.x; + while (h < H) { + output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h]; + h += stride; + } +} + template __global__ void AddBiasTranspose(const T* input, const T* biases, T* output) { // Format 0 for Separated Q, K, V (N*H <= 1024) @@ -395,6 +499,13 @@ void InvokeAddBiasTranspose( ORT_ENFORCE(total_matrix_count == 3); AddBiasTransposeQKV<<>>(input, biases, output, v_head_size); } + } else if (format == 3) { + if (v_head_size == -1 || qk_head_size == v_head_size) { + AddBiasTransposeCutlass<<>>(total_matrix_count, input, biases, output); + } else { + ORT_ENFORCE(total_matrix_count == 3); + AddBiasTransposeCutlass<<>>(input, biases, output, v_head_size); + } } else { // format == 0 AddBiasTranspose<<>>(input, biases, output); } @@ -410,6 +521,13 @@ void InvokeAddBiasTranspose( // It is rare for hidden size > 4096 (for half precision) and qk_head_size != v_head_size. ORT_THROW("AddBiasTranspose (format 1) not implemented for hidden_size > max_threads_per_block when qk_head_size != v_head_size"); } + } else if (format == 3) { + if (v_head_size == -1 || qk_head_size == v_head_size) { + AddBiasTransposeCutlassLarge<<>>(qk_head_size, input, biases, output, + total_matrix_count); + } else { + ORT_THROW("AddBiasTranspose (format 3) not implemented for hidden_size > max_threads_per_block when qk_head_size != v_head_size"); + } } else { // format 0 AddBiasTransposeLarge<<>>(qk_head_size, input, biases, output); } diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h index a0b5ac146ba7c..8cc36637054e7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h @@ -21,6 +21,9 @@ namespace cuda { // format 2: // input : (batch_size, sequence_length, num_matrices, num_heads, head_size) // output: (batch_size, sequence_length, num_heads, num_matrices, head_size) +// format 3: (requires sequence_length = kv_sequence_length and qk_head_size = v_head_size when num_matrices == 3) +// input: (batch_size, sequence_length, num_matrices, num_heads, head_size) +// output: (num_matrices, batch_size, sequence_length, num_heads, head_size) template void LaunchAddBiasTranspose( cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block, diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 3201ad1bcff27..3abacc935d8f0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -7,6 +7,7 @@ #include "contrib_ops/cuda/bert/attention_impl.h" #include "contrib_ops/cuda/bert/attention.h" #include "contrib_ops/cuda/bert/bert_padding.h" +#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" using namespace onnxruntime::cuda; using namespace ::onnxruntime::common; @@ -41,8 +42,14 @@ Attention::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionB disable_fused_runner_ = sizeof(T) != 2 || ParseEnvironmentVariableWithDefault(attention::kDisableFusedAttention, false); - enable_flash_attention_ = sizeof(T) == 2 && - !ParseEnvironmentVariableWithDefault(attention::kDisableFlashAttention, false); + enable_trt_flash_attention_ = sizeof(T) == 2 && + 
!ParseEnvironmentVariableWithDefault(attention::kDisableTrtFlashAttention, false); + +#if USE_FLASH_ATTENTION + disable_memory_efficient_attention_ = ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); +#else + disable_memory_efficient_attention_ = true; +#endif } template @@ -102,12 +109,12 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { parameters.hidden_size == parameters.v_hidden_size && parameters.sequence_length == parameters.kv_sequence_length && FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, - enable_flash_attention_, true); + enable_trt_flash_attention_, true); if (use_causal_fused_runner) { // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. if (nullptr == fused_fp16_runner_.get()) { fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(num_heads_, parameters.head_size, sm, is_unidirectional_, - enable_flash_attention_, parameters.scale)); + enable_trt_flash_attention_, parameters.scale)); } // Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check. @@ -122,13 +129,13 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { parameters.hidden_size == parameters.v_hidden_size && parameters.sequence_length == parameters.kv_sequence_length && FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, - enable_flash_attention_, false); + enable_trt_flash_attention_, false); if (use_fused_runner) { // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. if (nullptr == fused_fp16_runner_.get()) { fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(num_heads_, parameters.head_size, sm, is_unidirectional_, - enable_flash_attention_, parameters.scale)); + enable_trt_flash_attention_, parameters.scale)); } // In case some kernel not loaded due to shared memory limit, we need to double check here. @@ -139,6 +146,18 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { } } +#if USE_FLASH_ATTENTION + bool use_memory_efficient_attention = fused_runner == nullptr && + !disable_memory_efficient_attention_ && + nullptr == mask_index && // TODO: support 1D mask + nullptr == past && + nullptr == present && + nullptr == extra_add_qk && + has_memory_efficient_attention(sm, sizeof(T) == 2); +#else + constexpr bool use_memory_efficient_attention = false; +#endif + cublasHandle_t cublas = GetCublasHandle(context); typedef typename ToCudaType::MappedType CudaT; @@ -169,7 +188,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { parameters.sequence_length, parameters.kv_sequence_length, parameters.total_sequence_length, - fused_runner); + fused_runner, + use_memory_efficient_attention); auto work_space = GetScratchBuffer(workSpaceSize, context->GetComputeStream()); typedef typename ToCudaType::MappedType CudaT; @@ -188,6 +208,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { data.present = (nullptr == present) ? 
nullptr : reinterpret_cast<CudaT*>(present->MutableData());
   data.fused_runner = reinterpret_cast<void*>(fused_runner);
   data.fused_cross_attention_kernel = nullptr;
+  data.use_memory_efficient_attention = use_memory_efficient_attention;

   return QkvToContext<CudaT>(device_prop, cublas, Stream(context), parameters, data);
 }

diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.h b/onnxruntime/contrib_ops/cuda/bert/attention.h
index 7c95234337766..13b2019b21d0d 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention.h
+++ b/onnxruntime/contrib_ops/cuda/bert/attention.h
@@ -22,7 +22,8 @@ class Attention final : public CudaKernel, public AttentionBase {

  protected:
   bool disable_fused_runner_;
-  bool enable_flash_attention_;
+  bool enable_trt_flash_attention_;
+  bool disable_memory_efficient_attention_;
   mutable std::unique_ptr<MHARunner> fused_fp16_runner_;
 };

diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
index 091e8de13aa32..695cc2200fe7a 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -40,6 +40,7 @@ limitations under the License.
 #include "contrib_ops/cpu/bert/attention_base.h"
 #include "contrib_ops/cuda/bert/bert_padding.h"
 #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
+#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"

 using namespace onnxruntime::cuda;
 using namespace cub;
@@ -101,11 +102,25 @@ size_t GetAttentionWorkspaceSize(
     size_t sequence_length,
     size_t kv_sequence_length,
     size_t total_sequence_length,
-    void* fused_runner) {
+    void* fused_runner,
+    bool use_memory_efficient_attention) {
   // Note that q, k and v might need alignment for fused attention kernels.
   const size_t qkv_bytes = element_size * batch_size * num_heads *
                            ((sequence_length + kv_sequence_length) * qk_head_size + kv_sequence_length * v_head_size);

+#if USE_FLASH_ATTENTION
+  if (use_memory_efficient_attention) {
+    size_t fmha_buffer_bytes = 0;
+    if (MemoryEfficientAttentionParams::need_workspace(v_head_size, element_size == sizeof(float))) {
+      fmha_buffer_bytes = batch_size * sequence_length * num_heads * v_head_size * sizeof(float);
+    }
+
+    return qkv_bytes + fmha_buffer_bytes;
+  }
+#else
+  ORT_UNUSED_PARAMETER(use_memory_efficient_attention);
+#endif
+
   if (fused_runner != nullptr) {
     size_t sequence_offset_bytes = GetSequenceOffsetSize(static_cast<int>(batch_size), true);
     return qkv_bytes + sequence_offset_bytes;
@@ -259,12 +274,15 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
   const int v_head_size = parameters.v_head_size;
   const bool past_present_share_buffer = parameters.past_present_share_buffer;
   void* fused_runner = data.fused_runner;
+  bool use_memory_efficient_attention = data.use_memory_efficient_attention;

   T* qkv = data.workspace;

   bool use_fused_kernel = (nullptr != fused_runner && data.bias != nullptr && !parameters.is_unidirectional);
   bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional);

+  // The default format for memory efficient attention is Q_K_V_BSNH.
+  // When there is past state, the format shall be BxNxSxH, so we disable memory efficient attention when there is past.
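+  // In short, the dispatch below selects:
+  //   fused TRT attention          -> format 2 (QKV_BSN3H)
+  //   memory efficient attention   -> format 3 (Q_K_V_BSNH)
+  //   fused causal kernel          -> format 1 (Q_K_V_BNSH_QKV_BS3NH)
+  //   unfused kernel               -> format 1 (Q_K_V_BNSH)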
DUMP_ATTENTION_INIT();
   if (nullptr != data.gemm_buffer) {
     if (data.bias == nullptr) {
@@ -277,13 +295,16 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
       qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
     } else {
       // For fused TRT attention, transpose qkv to BxSxNx3xH (format 2)
+      // For memory efficient attention, transpose to 3xBxSxNxH (format 3)
       // For unfused kernel, transpose to 3xBxNxSxH (format 1)
-      // For fused causal kernel, use format 1 since we need have K and V in BNSH format to update present state,
-      // we also update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel.
-      const int format = (use_fused_kernel ? 2 : 1);
+      // For fused causal kernel, use format 1 since we need to have K and V to update the present state;
+      // at the same time, we update gemm_buffer (BxSx3xNxH) with bias, which is used as input for the fused causal kernel.
+      const int format = (use_fused_kernel ? 2 : (use_memory_efficient_attention ? 3 : 1));
       qkv_format = use_fused_kernel
                        ? AttentionQkvFormat::QKV_BSN3H
-                       : (use_fused_causal ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH : AttentionQkvFormat::Q_K_V_BNSH);
+                       : (use_memory_efficient_attention
+                              ? AttentionQkvFormat::Q_K_V_BSNH
+                              : (use_fused_causal ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH : AttentionQkvFormat::Q_K_V_BNSH));

       // For fused causal, we will update gemm_buffer with bias directly.
       T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr;
@@ -318,7 +339,21 @@
         data.bias, data.query, data.key, data.value, qkv, true, kv_sequence_length);
     qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
-  } else if (use_fused_kernel) {
+  }
+#if USE_FLASH_ATTENTION
+  else if (use_memory_efficient_attention) {
+    LaunchAddBias(stream, max_threads_per_block,
+                  batch_size, sequence_length, kv_sequence_length,
+                  num_heads, qk_head_size, v_head_size,
+                  data.bias, data.query, data.key, data.value, q, k, v);
+
+    DUMP_ATTENTION_D("q(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size);
+    DUMP_ATTENTION_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size);
+    DUMP_ATTENTION_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size);
+    qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
+  }
+#endif
+  else if (use_fused_kernel) {
     assert(qk_head_size == v_head_size);

     // Q (BxSxNxH), K (BxSxNxH), V (BxSxNxH) => BxSxNx(H + H + H)
@@ -382,9 +417,10 @@ Status QkvToContext(
   const bool past_present_share_buffer = parameters.past_present_share_buffer;
   const float mask_filter_value = parameters.mask_filter_value;
   void* fused_runner = data.fused_runner;
+  bool use_memory_efficient_attention = data.use_memory_efficient_attention;

   // At most one fused kernel is enabled.
-  assert(int(fused_runner != nullptr) + int(data.fused_cross_attention_kernel != nullptr) <= 1);
+  assert(int(use_memory_efficient_attention) + int(fused_runner != nullptr) + int(data.fused_cross_attention_kernel != nullptr) <= 1);

   const int batches = batch_size * num_heads;
   const int size_per_batch_q = sequence_length * qk_head_size;
@@ -433,6 +469,7 @@
   assert(data.fused_cross_attention_kernel == nullptr);
   assert(!use_fused_kernel);
   assert(data.gemm_buffer != nullptr);
+  assert(!use_memory_efficient_attention);

   if (data.present != data.past) {
     // For easy testing. Production should better avoid this path.
@@ -526,6 +563,38 @@ Status QkvToContext(
     return Status::OK();
   }

+#if USE_FLASH_ATTENTION
+  if (use_memory_efficient_attention) {
+    // We only enable memory efficient attention when there is no key padding mask.
+    // Otherwise, key would have effective batch size 2 * batch_size, which is different from the batch size of query.
+    assert(data.mask_index == nullptr);
+    assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
+
+    MemoryEfficientAttentionParams p;
+    p.sm = device_prop.major * 10 + device_prop.minor;
+    p.is_half = sizeof(T) == 2;
+    p.batch_size = data.mask_index == nullptr ? parameters.batch_size : 2 * parameters.batch_size;
+    p.num_heads = parameters.num_heads;
+    p.sequence_length = parameters.sequence_length;
+    p.kv_sequence_length = parameters.total_sequence_length;
+    p.qk_head_size = parameters.head_size;
+    p.v_head_size = parameters.v_head_size;
+    p.causal = parameters.is_unidirectional;
+    p.cu_seqlens_q = nullptr;
+    p.cu_seqlens_k = nullptr;
+    p.query = q;
+    p.key = k;
+    p.value = v;
+    p.output = data.output;
+    p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) ? scratch1 : nullptr;
+    p.stream = stream;
+    run_memory_efficient_attention(p);
+
+    DUMP_ATTENTION("cutlass output", data.output, batch_size * sequence_length, num_heads, v_head_size);
+    return Status::OK();
+  }
+#endif
+
+  // The following is the unfused attention path.
   assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH);
   const int* mask_index = data.mask_index;
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
index e6f8a3c518ed9..d98a0380c479b 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -27,7 +27,8 @@ size_t GetAttentionWorkspaceSize(
     size_t sequence_length,
     size_t kv_sequence_length,
     size_t total_sequence_length,
-    void* fused_runner);
+    void* fused_runner,
+    bool use_memory_efficient_attention = false);

 template <typename T>
 struct AttentionData {
@@ -48,6 +49,8 @@ struct AttentionData {

   void* fused_runner;
   const void* fused_cross_attention_kernel;
+
+  bool use_memory_efficient_attention;
 };

 template <typename T>
diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
new file mode 100644
index 0000000000000..17f4665a80f77
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
@@ -0,0 +1,116 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
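+//
+// This header instantiates the CUTLASS fused multi-head attention kernel
+// (AttentionKernel from kernel_forward.h) for one element type, SM architecture
+// and alignment/block-size configuration; the per-architecture fmha_sm*.cu
+// translation units dispatch into it.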
+ +#if USE_FLASH_ATTENTION + +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif + +#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" +#include "contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) { + using Attention = AttentionKernel; + typename Attention::Params p; + { // set parameters + p.query_ptr = const_cast(reinterpret_cast(params.query)); + p.key_ptr = const_cast(reinterpret_cast(params.key)); + p.value_ptr = const_cast(reinterpret_cast(params.value)); + p.cu_seqlens_q_ptr = params.cu_seqlens_q; + p.cu_seqlens_k_ptr = params.cu_seqlens_k; + + p.logsumexp_ptr = nullptr; // [num_heads, num_queries] for backward or nullptr for forward + p.output_ptr = reinterpret_cast(params.output); + if (Attention::kNeedsOutputAccumulatorBuffer) { + using Acc = typename Attention::accum_t; + // workspace size: batch_size * sequence_length * num_heads * v_head_size * sizeof(float) + ORT_ENFORCE(params.workspace != nullptr, "Need output accumulator buffer but no workspace provided"); + p.output_accum_ptr = reinterpret_cast(params.workspace); + } else { + p.output_accum_ptr = nullptr; + } + p.num_heads = params.num_heads; + p.num_batches = params.batch_size; + p.head_dim = params.qk_head_size; + p.head_dim_value = params.v_head_size; + + // When params.cu_seqlens_q is provided, num_queries is max_seq_q and num_keys will be set inside the kernel + p.num_queries = params.sequence_length; + p.num_keys = params.kv_sequence_length; + + p.causal = params.causal; + + // Input format is BxSxNxH, output is BxSxNxH + p.q_strideH = params.qk_head_size; + p.k_strideH = params.qk_head_size; + p.v_strideH = params.v_head_size; + p.o_strideH = params.v_head_size; + + p.q_strideM = params.num_heads * params.qk_head_size; + p.k_strideM = params.num_heads * params.qk_head_size; + p.v_strideM = params.num_heads * params.v_head_size; + + p.q_strideB = static_cast(p.q_strideM) * params.sequence_length; + p.k_strideB = static_cast(p.k_strideM) * params.kv_sequence_length; + p.v_strideB = static_cast(p.v_strideM) * params.kv_sequence_length; + p.o_strideB = static_cast(params.num_heads) * params.v_head_size * params.sequence_length; + + p.causal = params.causal; + } + + constexpr auto kernel_fn = attention_kernel_batched_impl; + int smem_bytes = sizeof(typename Attention::SharedStorage); + if (smem_bytes > 0xc000) { + ORT_ENFORCE(params.sm >= 70, "This kernel requires too much shared memory on this machine!"); + static bool once = [&]() { + cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + return true; + }(); + } + + ORT_ENFORCE(Attention::check_supported(p)); + kernel_fn<<>>(p); +} + +template +void DispatchIsAligned(const MemoryEfficientAttentionParams& params) { + using AlignedAK = AttentionKernel; + + // Run a more efficient kernel with `isAligned=True` when memory is correctly aligned. 
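+  // For example, fp16 inputs with qk_head_size = v_head_size = 64 typically pass
+  // these checks, since the kAlignment values come from the CUTLASS default GEMM
+  // configuration (a small power-of-two number of elements).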
+ bool is_aligned = params.qk_head_size % AlignedAK::kAlignmentQ == 0 && + params.qk_head_size % AlignedAK::kAlignmentK == 0 && + params.v_head_size % AlignedAK::kAlignmentV == 0; + + DISPATCH_BOOL(is_aligned, kIsAligned, ([&]() { + LaunchCutlassFmha(params); + })); +} + +template +void DispatchBlockSize(const MemoryEfficientAttentionParams& params) { + if (params.v_head_size <= 64) { + DispatchIsAligned(params); + } else if (params.v_head_size <= 128) { + DispatchIsAligned(params); + } else { + DispatchIsAligned(params); + } +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif + +#endif // USE_FLASH_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm50.cu b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm50.cu new file mode 100644 index 0000000000000..237f7ea8c9c42 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm50.cu @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +void run_memory_efficient_attention_sm50(const MemoryEfficientAttentionParams& params) { + if (params.is_half) { + DispatchBlockSize(params); + } else { + DispatchBlockSize(params); + } +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime + +#endif // USE_FLASH_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm70.cu b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm70.cu new file mode 100644 index 0000000000000..941ea87baa398 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm70.cu @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +void run_memory_efficient_attention_sm70(const MemoryEfficientAttentionParams& params) { + if (params.is_half) { + DispatchBlockSize(params); + } else { + DispatchBlockSize(params); + } +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime + +#endif // USE_FLASH_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm75.cu b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm75.cu new file mode 100644 index 0000000000000..5a0e7c9ed5b7a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm75.cu @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
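+//
+// One translation unit per SM architecture instantiates the launch template;
+// this file covers Sm75 (fmha_sm50/70/80.cu cover the others).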
+ +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +void run_memory_efficient_attention_sm75(const MemoryEfficientAttentionParams& params) { + if (params.is_half) { + DispatchBlockSize(params); + } else { + DispatchBlockSize(params); + } +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime + +#endif // USE_FLASH_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm80.cu new file mode 100644 index 0000000000000..d0775a29c4cf1 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm80.cu @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +void run_memory_efficient_attention_sm80(const MemoryEfficientAttentionParams& params) { + if (params.is_half) { + DispatchBlockSize(params); + } else { + DispatchBlockSize(params); + } +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime + +#endif // USE_FLASH_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h new file mode 100644 index 0000000000000..7885983f99ea6 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h @@ -0,0 +1,947 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holdvr nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#if USE_FLASH_ATTENTION + +#include +#include + +#include "cutlass/bfloat16.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" + +#include "41_fused_multi_head_attention/attention_scaling_coefs_updater.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/platform/platform.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "41_fused_multi_head_attention/debug_utils.h" +#include "41_fused_multi_head_attention/epilogue_pipelined.h" +#include "41_fused_multi_head_attention/epilogue_rescale_output.h" +#include "41_fused_multi_head_attention/find_default_mma.h" +#include "41_fused_multi_head_attention/gemm_kernel_utils.h" +#include "41_fused_multi_head_attention/mma_from_smem.h" + +#include + +using namespace gemm_kernel_utils; + +namespace { +template +constexpr int getWarpsPerSm() { + return ( + Arch::kMinComputeCapability >= 80 && + !cutlass::platform::is_same::value + ? 
16 + : 12); +} +} // namespace + +template < + // The datatype of Q/K/V + typename scalar_t_, + // Architecture we are targeting (eg `cutlass::arch::Sm80`) + typename ArchTag, + // If Q/K/V are correctly aligned in memory and we can run a fast kernel + bool isAligned_, + int kQueriesPerBlock, + int kKeysPerBlock, + bool kSingleValueIteration // = `value.shape[-1] <= kKeysPerBlock` + > +struct AttentionKernel { + using scalar_t = scalar_t_; + using accum_t = float; + using lse_scalar_t = float; + using output_t = scalar_t; + // Accumulator between 2 iterations + // Using `accum_t` improves perf on f16 at the cost of + // numerical errors + using output_accum_t = accum_t; + static constexpr bool kIsAligned = isAligned_; + static constexpr int32_t kAlignLSE = 32; // block size of backward + static constexpr bool kPreloadV = ArchTag::kMinComputeCapability >= 80 && + cutlass::sizeof_bits::value == 16; + static constexpr bool kKeepOutputInRF = kSingleValueIteration; + static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF && + !cutlass::platform::is_same::value; + + static_assert(kQueriesPerBlock % 32 == 0, ""); + static_assert(kKeysPerBlock % 32 == 0, ""); + static constexpr int kNumWarpsPerBlock = + kQueriesPerBlock * kKeysPerBlock / (32 * 32); + static constexpr int kWarpSize = 32; + + // Launch bounds + static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock; + static constexpr int kMinBlocksPerSm = + getWarpsPerSm() / kNumWarpsPerBlock; + + struct Params { + // Input tensors + scalar_t* query_ptr; // [num_queries, num_heads, head_dim] + scalar_t* key_ptr; // [num_keys, num_heads, head_dim] + scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value] + int32_t* cu_seqlens_q_ptr = nullptr; + int32_t* cu_seqlens_k_ptr = nullptr; + + // Output tensors + output_t* output_ptr; // [num_queries, num_heads, head_dim_value] + output_accum_t* + output_accum_ptr; // [num_queries, num_heads, head_dim_value] + lse_scalar_t* logsumexp_ptr; // [num_heads, num_queries] - can be null + + // Dimensions/strides + int32_t head_dim; + int32_t head_dim_value; + int32_t num_queries; + int32_t num_keys; + + bool causal; + + int32_t q_strideM; + int32_t k_strideM; + int32_t v_strideM; + + // Everything below is only used in `advance_to_block` + // and shouldn't use registers + int32_t q_strideH; + int32_t k_strideH; + int32_t v_strideH; + int32_t o_strideH; + int64_t q_strideB; + int64_t k_strideB; + int64_t v_strideB; + int64_t o_strideB; + int32_t num_batches; + int32_t num_heads; + + // https://github.com/NVIDIA/cutlass/issues/771 + CUTLASS_HOST_DEVICE int32_t o_strideM() const { + return head_dim_value * num_heads; + } + + // Moves pointers to what we should process + // Returns "false" if there is no work to do + CUTLASS_DEVICE bool advance_to_block() { + auto batch_id = blockIdx.z; + auto head_id = blockIdx.y; + auto query_start = blockIdx.x * kQueriesPerBlock; + + auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE; + + int64_t q_start, k_start; + // Advance to current batch - in case of different sequence lengths + if (cu_seqlens_q_ptr != nullptr) { + assert(cu_seqlens_k_ptr != nullptr); + cu_seqlens_q_ptr += batch_id; + cu_seqlens_k_ptr += batch_id; + q_start = cu_seqlens_q_ptr[0]; + k_start = cu_seqlens_k_ptr[0]; + int64_t q_next_start = cu_seqlens_q_ptr[1]; + int64_t k_next_start = cu_seqlens_k_ptr[1]; + num_queries = q_next_start - q_start; + num_keys = k_next_start - k_start; + + if (query_start >= num_queries) { + return false; + } + } else { + 
query_ptr += batch_id * q_strideB; + key_ptr += batch_id * k_strideB; + value_ptr += batch_id * v_strideB; + output_ptr += batch_id * o_strideB; + if (output_accum_ptr != nullptr) { + output_accum_ptr += batch_id * o_strideB; + } + q_start = 0; + k_start = 0; + } + + // Advance to the current batch / head / query_start + query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH; + key_ptr += k_start * k_strideM + head_id * k_strideH; + value_ptr += k_start * v_strideM + head_id * v_strideH; + output_ptr += int64_t(q_start + query_start) * o_strideM() + + head_id * o_strideH; + + if (output_accum_ptr != nullptr) { + output_accum_ptr += int64_t(q_start + query_start) * o_strideM() + + head_id * o_strideH; + } else { + // Accumulate directly in the destination buffer (eg for f32) + output_accum_ptr = (accum_t*)output_ptr; + } + if (logsumexp_ptr != nullptr) { + // lse[batch_id, head_id, query_start] + logsumexp_ptr += + batch_id * lse_dim * num_heads + head_id * lse_dim + query_start; + } + + num_queries -= query_start; + if (causal) { + num_keys = cutlass::fast_min( + int32_t(query_start + kQueriesPerBlock), num_keys); + } + num_batches = 0; // no longer used after + + // Make sure the compiler knows these variables are the same on all + // the threads of the warp. + query_ptr = warp_uniform(query_ptr); + key_ptr = warp_uniform(key_ptr); + value_ptr = warp_uniform(value_ptr); + output_ptr = warp_uniform(output_ptr); + output_accum_ptr = warp_uniform(output_accum_ptr); + logsumexp_ptr = warp_uniform(logsumexp_ptr); + num_queries = warp_uniform(num_queries); + num_keys = warp_uniform(num_keys); + head_dim = warp_uniform(head_dim); + head_dim_value = warp_uniform(head_dim_value); + return true; + } + + __host__ dim3 getBlocksGrid() const { + return dim3( + ceil_div(num_queries, (int32_t)kQueriesPerBlock), + num_heads, + num_batches); + } + __host__ dim3 getThreadsGrid() const { + return dim3(kWarpSize, kNumWarpsPerBlock, 1); + } + }; + + struct MM0 { + /* + In this first matmul, we compute a block of `Q @ K.T`. + While the calculation result is still hot in registers, we update + `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value + into a shared-memory ("AccumulatorSharedStorage") that is used later as + operand A for the second matmul (see MM1) + */ + using GemmType = DefaultGemmType; + + using OpClass = typename GemmType::OpClass; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + scalar_t, + scalar_t, + scalar_t, // ElementC + accum_t // ElementAccumulator + >; + static constexpr int kAlignmentA = + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment; + static constexpr int kAlignmentB = + kIsAligned ? 
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + using ThreadblockShape = cutlass::gemm:: + GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + kAlignmentA, + scalar_t, // ElementB, + cutlass::layout::ColumnMajor, // LayoutB, + kAlignmentB, + accum_t, + cutlass::layout::RowMajor, // LayoutC, + OpClass, + ArchTag, // ArchTag + ThreadblockShape, // ThreadblockShape + WarpShape, // WarpShape + typename GemmType::InstructionShape, // InstructionShape + DefaultConfig::kStages, // Should use `DefaultConfig::kStages`, but that + // uses too much smem + typename GemmType::Operator // Operator + >::DefaultMma; + using MmaCore = typename DefaultMma::MmaCore; + using IteratorA = typename DefaultMma::IteratorA; + using IteratorB = typename DefaultMma::IteratorB; + using Mma = typename DefaultMma::ThreadblockMma; + using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater< + typename Mma::Operator::IteratorC, + accum_t, + kWarpSize>::Updater; + static_assert( + MmaCore::WarpCount::kM * MmaCore::WarpCount::kN * + MmaCore::WarpCount::kK == + kNumWarpsPerBlock, + ""); + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< + typename Mma::Operator::IteratorC, + typename Mma::Operator, + scalar_t, + WarpShape, + ThreadblockShape>; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MM1 { + /** + Second matmul: perform `attn @ V` where `attn` is the attention (not + normalized) and stored in shared memory + */ + using GemmType = DefaultGemmType; + + using OpClass = typename GemmType::OpClass; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + scalar_t, + scalar_t, + output_accum_t, // ElementC + accum_t // ElementAccumulator + >; + static constexpr int kAlignmentA = DefaultConfig::kAlignmentA; // from smem + static constexpr int kAlignmentB = + kIsAligned ? 
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + using ThreadblockShape = cutlass::gemm:: + GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using LayoutB = cutlass::layout::RowMajor; + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + kAlignmentA, + scalar_t, // ElementB, + LayoutB, // LayoutB, + kAlignmentB, + output_accum_t, + cutlass::layout::RowMajor, // LayoutC, + accum_t, + OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + typename DefaultConfig::EpilogueOutputOp, + void, // ThreadblockSwizzle - not used + DefaultConfig::kStages, + false, // SplitKSerial + typename GemmType::Operator>; + + using DefaultMmaFromSmem = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + typename MM0::AccumulatorSharedStorage>; + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + static_assert( + WarpCount::kM * WarpCount::kN * WarpCount::kK == kNumWarpsPerBlock, + ""); + + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_t>; + using OutputTileIteratorAccum = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_accum_t>; + + struct SharedStorageMM1 { + typename Mma::SharedStorage mm; + }; + }; + + static constexpr int64_t kAlignmentQ = MM0::kAlignmentA; + static constexpr int64_t kAlignmentK = MM0::kAlignmentB; + static constexpr int64_t kAlignmentV = 1; + + // Shared storage - depends on kernel params + struct ScalingCoefs { + cutlass::Array m_prime; + cutlass::Array s_prime; + cutlass::Array mi; + }; + + struct SharedStorageEpilogueAtEnd : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + typename MM0::AccumulatorSharedStorage si; + typename MM1::SharedStorageMM1 mm1; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& + epilogue_shared_storage() { + return epilogue; + } + }; + + struct SharedStorageEpilogueInLoop : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + typename MM0::AccumulatorSharedStorage si; + typename MM1::SharedStorageMM1 mm1; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& + epilogue_shared_storage() { + return after_mm0.epilogue; + } + }; + + using SharedStorage = typename cutlass::platform::conditional< + kSingleValueIteration || kKeepOutputInRF, + SharedStorageEpilogueAtEnd, + SharedStorageEpilogueInLoop>::type; + + static bool __host__ check_supported(Params const& p) { + CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ); + CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK); + CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV); + XFORMERS_CHECK( + p.q_strideM % kAlignmentQ == 0, "query is not correctly aligned"); + XFORMERS_CHECK( + p.k_strideM % kAlignmentK == 
0, "key is not correctly aligned"); + XFORMERS_CHECK( + p.v_strideM % kAlignmentV == 0, "value is not correctly aligned"); + XFORMERS_CHECK( + p.q_strideH % kAlignmentQ == 0, "query is not correctly aligned"); + XFORMERS_CHECK( + p.k_strideH % kAlignmentK == 0, "key is not correctly aligned"); + XFORMERS_CHECK( + p.v_strideH % kAlignmentV == 0, "value is not correctly aligned"); + return true; + } + + static void CUTLASS_DEVICE attention_kernel(Params& p) { + // In this block, we will only ever: + // - read query[query_start:query_end, :] + // - write to output[query_start:query_end, :] + + extern __shared__ char smem_buffer[]; + SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); + auto& m_prime = shared_storage.m_prime; + auto& s_prime = shared_storage.s_prime; + auto& mi = shared_storage.mi; + + static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); + if (thread_id() < kQueriesPerBlock) { + s_prime[thread_id()] = accum_t(0); + m_prime[thread_id()] = + -cutlass::platform::numeric_limits::infinity(); + mi[thread_id()] = -cutlass::platform::numeric_limits::infinity(); + } + typename MM1::Mma::FragmentC accum_o; + accum_o.clear(); + + auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator { + using OutputTileIterator = typename MM1::OutputTileIterator; + return OutputTileIterator( + typename OutputTileIterator::Params{(int32_t)p.o_strideM()}, + p.output_ptr, + typename OutputTileIterator::TensorCoord{ + p.num_queries, p.head_dim_value}, + thread_id(), + {0, col}); + }; + + auto createOutputAccumIter = [&](int col) -> + typename MM1::OutputTileIteratorAccum { + using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum; + return OutputTileIteratorAccum( + typename OutputTileIteratorAccum::Params{(int32_t)p.o_strideM()}, + p.output_accum_ptr, + typename OutputTileIteratorAccum::TensorCoord{ + p.num_queries, p.head_dim_value}, + thread_id(), + {0, col}); + }; + + // Iterate through keys + for (int32_t iter_key_start = 0; iter_key_start < p.num_keys; + iter_key_start += kKeysPerBlock) { + int32_t problem_size_0_m = + cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries); + int32_t problem_size_0_n = cutlass::fast_min( + int32_t(kKeysPerBlock), p.num_keys - iter_key_start); + int32_t const& problem_size_0_k = p.head_dim; + int32_t const& problem_size_1_n = p.head_dim_value; + int32_t const& problem_size_1_k = problem_size_0_n; + + auto prologueV = [&](int blockN) { + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)}, + p.value_ptr + iter_key_start * p.v_strideM, + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + MM1::Mma::prologue( + shared_storage.after_mm0.mm1.mm, + iterator_V, + thread_id(), + problem_size_1_k); + }; + + __syncthreads(); // Need to have shared memory initialized, and `m_prime` + // updated from end of prev iter + // + // MATMUL: Q.K_t + // + // Computes the block-matrix product of: + // (a) query[query_start:query_end, :] + // with + // (b) key[iter_key_start:iter_key_start + kKeysPerBlock] + // and stores that into `shared_storage.si` + // + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0}; + + cutlass::MatrixCoord tb_offset_A{ + tb_tile_offset.m() * MM0::Mma::Shape::kM, tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{ + tb_tile_offset.k(), tb_tile_offset.n() * MM0::Mma::Shape::kN}; + + // Construct iterators to A and B operands + typename 
MM0::IteratorA iterator_A( + typename MM0::IteratorA::Params( + typename MM0::MmaCore::LayoutA(p.q_strideM)), + p.query_ptr, + {problem_size_0_m, problem_size_0_k}, + thread_id(), + tb_offset_A); + + typename MM0::IteratorB iterator_B( + typename MM0::IteratorB::Params( + typename MM0::MmaCore::LayoutB(p.k_strideM)), + p.key_ptr + iter_key_start * p.k_strideM, + {problem_size_0_k, problem_size_0_n}, + thread_id(), + tb_offset_B); + + auto my_warp_id = warp_id(); + auto my_lane_id = lane_id(); + + // Construct thread-scoped matrix multiply + typename MM0::Mma mma( + shared_storage.mm0, thread_id(), my_warp_id, my_lane_id); + + typename MM0::Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = + (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + __syncthreads(); + + if (kPreloadV) { + prologueV(0); + } + + typename MM0::Mma::Operator::IteratorC::TensorCoord + iteratorC_tile_offset = { + (tb_tile_offset.m() * MM0::Mma::WarpCount::kM) + + (my_warp_id % MM0::Mma::WarpCount::kM), + (tb_tile_offset.n() * MM0::Mma::WarpCount::kN) + + (my_warp_id / MM0::Mma::WarpCount::kM)}; + + // Mask out last if causal + if (p.causal && p.num_keys - iter_key_start <= kKeysPerBlock) { + auto query_start = blockIdx.x * kQueriesPerBlock; + auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset( + lane_id(), warp_id(), iteratorC_tile_offset); + int32_t last_col; + MM0::ScalingCoefsUpdater::iterateRows( + lane_offset, + [&](int accum_m) { + last_col = query_start + accum_m - iter_key_start; + }, + [&](int accum_m, int accum_n, int idx) { + if (accum_n > last_col) { + accum[idx] = + -cutlass::platform::numeric_limits::infinity(); + } + }, + [&](int accum_m) {}); + } + DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] { + DISPATCH_BOOL( + p.num_keys - iter_key_start >= kKeysPerBlock, + kFullColumns, + ([&] { + // Update `mi` from accum stored in registers + // Also updates `accum` with accum[i] <- + // exp(accum[i] * scale + // - mi) + MM0::ScalingCoefsUpdater::update< + kQueriesPerBlock, + kFullColumns, + kIsFirst, + kKeepOutputInRF>( + accum_o, + accum, + mi, + m_prime, + s_prime, + lane_id(), + thread_id(), + warp_id(), + p.num_keys - iter_key_start, + iteratorC_tile_offset, + 1.0f / cutlass::fast_sqrt(float(p.head_dim))); + })); + })); + + // Output results to shared-memory + int warp_idx_mn_0 = my_warp_id % + (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN); + auto output_tile_coords = cutlass::MatrixCoord{ + warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM, + warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM}; + + MM0::B2bGemm::accumToSmem( + shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords); + + __syncthreads(); + + // + // MATMUL: Attn . V + // Run the matmul `attn @ V` for a block of attn and V. + // `attn` is read from shared memory (in `shared_storage_si`) + // `V` is read from global memory (with iterator_B) + // + + const int64_t nBlockN = kSingleValueIteration + ? 
+      const int64_t nBlockN = kSingleValueIteration
+          ? 1
+          : ceil_div(
+                (int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN));
+      for (int blockN = 0; blockN < nBlockN; ++blockN) {
+        int gemm_k_iterations =
+            (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK;
+
+        // Compute threadblock-scoped matrix multiply-add and store it in accum
+        // (in registers)
+        if (!kPreloadV) {
+          __syncthreads(); // we share shmem between mma and epilogue
+        }
+
+        typename MM1::Mma::IteratorB iterator_V(
+            typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)},
+            p.value_ptr + iter_key_start * p.v_strideM,
+            {problem_size_1_k, problem_size_1_n},
+            thread_id(),
+            cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
+        typename MM1::Mma mma_pv(
+            shared_storage.after_mm0.mm1.mm,
+            shared_storage.after_mm0.si,
+            (int)thread_id(),
+            (int)warp_id(),
+            (int)lane_id(),
+            (int)problem_size_1_k);
+        mma_pv.set_prologue_done(kPreloadV);
+        if (!kKeepOutputInRF) {
+          accum_o.clear();
+        }
+        mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o);
+        __syncthreads();
+
+        if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) {
+          prologueV(blockN + 1);
+        }
+
+        if (!kKeepOutputInRF) {
+          DISPATCH_BOOL(
+              iter_key_start == 0, kIsFirst, ([&] {
+                DISPATCH_BOOL(
+                    (iter_key_start + kKeysPerBlock) >= p.num_keys,
+                    kIsLast,
+                    ([&] {
+                      using DefaultEpilogue = typename MM1::DefaultEpilogue;
+                      using DefaultOp =
+                          typename MM1::DefaultConfig::EpilogueOutputOp;
+                      using ElementCompute = typename DefaultOp::ElementCompute;
+                      using EpilogueOutputOp = typename cutlass::epilogue::
+                          thread::MemoryEfficientAttentionNormalize<
+                              typename cutlass::platform::conditional<
+                                  kIsLast,
+                                  output_t,
+                                  output_accum_t>::type,
+                              output_accum_t,
+                              DefaultOp::kCount,
+                              typename DefaultOp::ElementAccumulator,
+                              ElementCompute,
+                              kIsFirst,
+                              kIsLast,
+                              cutlass::Array<ElementCompute, kQueriesPerBlock>>;
+                      using Epilogue = typename cutlass::epilogue::threadblock::
+                          EpiloguePipelined<
+                              typename DefaultEpilogue::Shape,
+                              typename MM1::Mma::Operator,
+                              DefaultEpilogue::kPartitionsK,
+                              typename cutlass::platform::conditional<
+                                  kIsLast,
+                                  typename MM1::OutputTileIterator,
+                                  typename MM1::OutputTileIteratorAccum>::type,
+                              typename DefaultEpilogue::
+                                  AccumulatorFragmentIterator,
+                              typename DefaultEpilogue::WarpTileIterator,
+                              typename DefaultEpilogue::SharedLoadIterator,
+                              EpilogueOutputOp,
+                              typename DefaultEpilogue::Padding,
+                              DefaultEpilogue::kFragmentsPerIteration,
+                              true, // IterationsUnroll
+                              typename MM1::OutputTileIteratorAccum // Read
+                                                                    // iterator
+                              >;
+
+                      int col = blockN * MM1::Mma::Shape::kN;
+                      auto source_iter = createOutputAccumIter(col);
+                      auto dest_iter = call_conditional<
+                          kIsLast,
+                          decltype(createOutputIter),
+                          decltype(createOutputAccumIter)>::
+                          apply(createOutputIter, createOutputAccumIter, col);
+                      EpilogueOutputOp rescale(s_prime, m_prime);
+                      Epilogue epilogue(
+                          shared_storage.epilogue_shared_storage(),
+                          thread_id(),
+                          warp_id(),
+                          lane_id());
+                      epilogue(rescale, dest_iter, accum_o, source_iter);
+                    }));
+              }));
+          if (!kSingleValueIteration) {
+            __syncthreads();
+          }
+        }
+      }
+      __syncthreads(); // we modify `m_prime` after
+    }
+
+    if (kKeepOutputInRF) {
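+      // All key blocks have been consumed, and accum_o still holds the
+      // unnormalized row sums in registers, so a single epilogue pass
+      // rescales by 1 / s_prime and writes the final output tile directly.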
+      constexpr bool kIsFirst = true;
+      constexpr bool kIsLast = true;
+      using DefaultEpilogue = typename MM1::DefaultEpilogue;
+      using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
+      using ElementCompute = typename DefaultOp::ElementCompute;
+      using EpilogueOutputOp =
+          typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize<
+              output_t, // output
+              output_accum_t, // source
+              DefaultOp::kCount,
+              typename DefaultOp::ElementAccumulator, // accum
+              output_accum_t, // compute
+              kIsFirst,
+              kIsLast,
+              cutlass::Array<ElementCompute, kQueriesPerBlock>>;
+      using Epilogue =
+          typename cutlass::epilogue::threadblock::EpiloguePipelined<
+              typename DefaultEpilogue::Shape,
+              typename MM1::Mma::Operator,
+              DefaultEpilogue::kPartitionsK,
+              typename MM1::OutputTileIterator, // destination
+              typename DefaultEpilogue::AccumulatorFragmentIterator,
+              typename DefaultEpilogue::WarpTileIterator,
+              typename DefaultEpilogue::SharedLoadIterator,
+              EpilogueOutputOp,
+              typename DefaultEpilogue::Padding,
+              DefaultEpilogue::kFragmentsPerIteration,
+              true, // IterationsUnroll
+              typename MM1::OutputTileIteratorAccum // source tile
+              >;
+      auto dest_iter = createOutputIter(0);
+      EpilogueOutputOp rescale(s_prime, m_prime);
+      Epilogue epilogue(
+          shared_storage.epilogue_shared_storage(),
+          thread_id(),
+          warp_id(),
+          lane_id());
+      epilogue(rescale, dest_iter, accum_o);
+    }
+
+    // 7. Calculate logsumexp
+    // To make the backward pass easier, we pad logsumexp with `inf`;
+    // this avoids a few bound checks and is not more expensive during fwd
+    static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
+    if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) {
+      auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
+      if (thread_id() < p.num_queries) {
+        p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()]) +
+            cutlass::fast_log(accum_t(s_prime[thread_id()]));
+      } else if (thread_id() < lse_dim) {
+        p.logsumexp_ptr[thread_id()] =
+            cutlass::platform::numeric_limits<accum_t>::infinity();
+      }
+    }
+  }
+
+  static CUTLASS_DEVICE int8_t lane_id() {
+    return threadIdx.x;
+  }
+  static CUTLASS_DEVICE int8_t warp_id() {
+    return threadIdx.y;
+  }
+  static CUTLASS_DEVICE int16_t thread_id() {
+    return threadIdx.x + threadIdx.y * blockDim.x;
+  }
+};
+
+template <typename AK>
+__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
+    attention_kernel_batched_impl(typename AK::Params p) {
+  if (!p.advance_to_block()) {
+    return;
+  }
+  AK::attention_kernel(p);
+}
+
+template <typename AK>
+__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
+    attention_kernel_batched(typename AK::Params params);
+
+#define _ATTENTION_KERNEL_FORWARD_BEGIN(...)
\ + template <> \ + __global__ void __launch_bounds__( \ + __VA_ARGS__::kNumThreads, __VA_ARGS__::kMinBlocksPerSm) \ + attention_kernel_batched<__VA_ARGS__>(typename __VA_ARGS__::Params p) { \ + using Kernel = __VA_ARGS__; +#define _ATTENTION_KERNEL_FORWARD_END() } + +#ifdef __CUDA_ARCH__ +#define __CUDA_ARCH_OR_ZERO__ __CUDA_ARCH__ +#else +#define __CUDA_ARCH_OR_ZERO__ 0 +#endif + +#define INSTANTIATE_ATTENTION_KERNEL_FORWARD( \ + ARCH, \ + SCALAR_T, \ + IS_ALIGNED, \ + QUERIES_PER_BLOCK, \ + KEYS_PER_BLOCK, \ + SINGLE_VALUE_ITER) \ + _ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel< \ + SCALAR_T, \ + cutlass::arch::Sm##ARCH, \ + IS_ALIGNED, \ + QUERIES_PER_BLOCK, \ + KEYS_PER_BLOCK, \ + SINGLE_VALUE_ITER>) \ + if (!p.advance_to_block()) { \ + return; \ + } \ + Kernel::attention_kernel(p); \ + _ATTENTION_KERNEL_FORWARD_END(); + +#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED( \ + ARCH, \ + SCALAR_T, \ + IS_ALIGNED, \ + QUERIES_PER_BLOCK, \ + KEYS_PER_BLOCK, \ + SINGLE_VALUE_ITER) \ + _ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel< \ + SCALAR_T, \ + cutlass::arch::Sm##ARCH, \ + IS_ALIGNED, \ + QUERIES_PER_BLOCK, \ + KEYS_PER_BLOCK, \ + SINGLE_VALUE_ITER>) \ + printf( \ + "FATAL: this function is for sm%d, but was built for sm%d\n", \ + int(ARCH), \ + int(__CUDA_ARCH_OR_ZERO__)); \ + _ATTENTION_KERNEL_FORWARD_END(); + +// All kernels are disabled by default +#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \ + INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(50, __VA_ARGS__) +#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \ + INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(70, __VA_ARGS__) +#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \ + INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(75, __VA_ARGS__) +#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \ + INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(80, __VA_ARGS__) + +// Enable the right one based on __CUDA_ARCH__ +#ifndef __CUDA_ARCH__ +#elif __CUDA_ARCH__ < 500 +//#error "Need cuda arch at least 5.0" +#elif __CUDA_ARCH__ < 700 +#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50 +#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \ + INSTANTIATE_ATTENTION_KERNEL_FORWARD(50, __VA_ARGS__) +#elif __CUDA_ARCH__ < 750 +#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70 +#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \ + INSTANTIATE_ATTENTION_KERNEL_FORWARD(70, __VA_ARGS__) +#elif __CUDA_ARCH__ < 800 +#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75 +#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \ + INSTANTIATE_ATTENTION_KERNEL_FORWARD(75, __VA_ARGS__) +#elif __CUDA_ARCH__ >= 800 +#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80 +#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \ + INSTANTIATE_ATTENTION_KERNEL_FORWARD(80, __VA_ARGS__) +#endif + +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.cu b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.cu new file mode 100644 index 0000000000000..284211f96514d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.cu @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
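+
+// Illustrative usage sketch (not part of the kernels below; the names `q`,
+// `k`, `v`, `out`, `accum_buffer`, `device_prop` and `stream` are
+// hypothetical). A caller fills MemoryEfficientAttentionParams, declared in
+// the header included below, then invokes the dispatcher:
+//
+//   MemoryEfficientAttentionParams params;
+//   params.sm = device_prop.major * 10 + device_prop.minor;
+//   params.is_half = true;                // fp16 Q/K/V
+//   params.batch_size = 2;                // B
+//   params.num_heads = 12;                // N
+//   params.sequence_length = 512;         // S
+//   params.kv_sequence_length = 512;      // L
+//   params.qk_head_size = 64;             // H
+//   params.v_head_size = 64;              // H_v
+//   params.causal = false;
+//   params.cu_seqlens_q = nullptr;        // no variable-length batching
+//   params.cu_seqlens_k = nullptr;
+//   params.query = q;                     // [B, S, N, H]
+//   params.key = k;                       // [B, L, N, H]
+//   params.value = v;                     // [B, L, N, H_v]
+//   params.output = out;                  // [B, S, N, H_v]
+//   params.workspace = MemoryEfficientAttentionParams::need_workspace(
+//                          params.v_head_size, !params.is_half)
+//                          ? accum_buffer
+//                          : nullptr;
+//   params.stream = stream;
+//   run_memory_efficient_attention(params);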
+#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +void run_memory_efficient_attention(const MemoryEfficientAttentionParams& params) { + const int32_t& sm = params.sm; + if (sm >= 80) { + run_memory_efficient_attention_sm80(params); + } else if (sm >= 75) { + run_memory_efficient_attention_sm75(params); + } else if (sm >= 70) { + run_memory_efficient_attention_sm70(params); + } else if (sm >= 50) { + run_memory_efficient_attention_sm50(params); + } else { + assert(false); // shall not reach here. + } +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime + +#endif // USE_FLASH_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h new file mode 100644 index 0000000000000..d4484628b6f32 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#if USE_FLASH_ATTENTION + +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cpu/bert/attention_common.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +struct MemoryEfficientAttentionParams { + int32_t sm; + bool is_half; + int32_t batch_size; + int32_t num_heads; + int32_t sequence_length; + int32_t kv_sequence_length; + int32_t qk_head_size; + int32_t v_head_size; + bool causal; + + int32_t* cu_seqlens_q; + int32_t* cu_seqlens_k; + + const void* query; // [B, S, N, H] + const void* key; // [B, L, N, H], where L is kv_sequence_length + const void* value; // [B, L, N, H_v] + void* output; // [B, S, N, H_v] + void* workspace; // [B, S, N, H_v] when kNeedsOutputAccumulatorBuffer, nullptr otherwise + cudaStream_t stream; + + static bool need_workspace(size_t v_head_size, bool is_float) { + return (v_head_size > 128 && !is_float); + } +}; + +void run_memory_efficient_attention(const MemoryEfficientAttentionParams& params); + +inline bool has_memory_efficient_attention(int32_t sm, bool is_half) { + return sm >= (is_half ? 
53 : 50); +} + +void run_memory_efficient_attention_sm80(const MemoryEfficientAttentionParams& params); +void run_memory_efficient_attention_sm75(const MemoryEfficientAttentionParams& params); +void run_memory_efficient_attention_sm70(const MemoryEfficientAttentionParams& params); +void run_memory_efficient_attention_sm50(const MemoryEfficientAttentionParams& params); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime + +#endif // USE_FLASH_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index c9baf66d7089c..363a901858692 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -6,6 +6,7 @@ #include "contrib_ops/cuda/bert/attention_impl.h" #include "contrib_ops/cuda/bert/multihead_attention.h" #include "contrib_ops/cpu/bert/multihead_attention_helper.h" +#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" using namespace onnxruntime::cuda; using namespace ::onnxruntime::common; @@ -42,8 +43,14 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) disable_fused_runner_ = sizeof(T) != 2 || ParseEnvironmentVariableWithDefault(attention::kDisableFusedAttention, false); - enable_flash_attention_ = sizeof(T) == 2 && - !ParseEnvironmentVariableWithDefault(attention::kDisableFlashAttention, false); + enable_trt_flash_attention_ = sizeof(T) == 2 && + !ParseEnvironmentVariableWithDefault(attention::kDisableTrtFlashAttention, false); + +#if USE_FLASH_ATTENTION + disable_memory_efficient_attention_ = ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); +#else + disable_memory_efficient_attention_ = true; +#endif disable_fused_cross_attention_ = sizeof(T) != 2 || ParseEnvironmentVariableWithDefault(attention::kDisableFusedCrossAttention, false); } @@ -85,7 +92,6 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN; - bool use_fused_cross_attention = !disable_fused_cross_attention_ && nullptr == key_padding_mask && parameters.hidden_size == parameters.v_hidden_size && @@ -104,17 +110,18 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { } bool use_fused_runner = !disable_fused_runner_ && + fused_cross_attention_kernel == nullptr && (nullptr == key_padding_mask || is_mask_1d_seq_len) && parameters.hidden_size == parameters.v_hidden_size && parameters.sequence_length == parameters.kv_sequence_length && FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, - enable_flash_attention_, false); + enable_trt_flash_attention_, false); if (use_fused_runner) { // Here we assume that num_heads and head_size does not change for a MultiHeadAttention node. if (nullptr == fused_fp16_runner_.get()) { constexpr bool is_unidirectional = false; fused_fp16_runner_.reset(new FusedMHARunnerFP16v2( - num_heads_, parameters.head_size, sm, is_unidirectional, enable_flash_attention_, parameters.scale)); + num_heads_, parameters.head_size, sm, is_unidirectional, enable_trt_flash_attention_, parameters.scale)); } // In case some kernel not loaded due to shared memory limit, we need to double check here. 
@@ -124,6 +131,16 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
     }
   }
 
+#if USE_FLASH_ATTENTION
+  bool use_memory_efficient_attention = fused_runner == nullptr &&
+                                        fused_cross_attention_kernel == nullptr &&
+                                        !disable_memory_efficient_attention_ &&
+                                        nullptr == key_padding_mask &&  // TODO: support 1D mask
+                                        has_memory_efficient_attention(sm, sizeof(T) == 2);
+#else
+  constexpr bool use_memory_efficient_attention = false;
+#endif
+
   constexpr size_t element_size = sizeof(T);
   size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
                                                    parameters.batch_size,
@@ -133,7 +150,8 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
                                                    parameters.sequence_length,
                                                    parameters.kv_sequence_length,
                                                    parameters.total_sequence_length,
-                                                   fused_runner);
+                                                   fused_runner,
+                                                   use_memory_efficient_attention);
   auto work_space = GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());
 
   typedef typename ToCudaType<T>::MappedType CudaT;
@@ -152,6 +170,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
   data.present = nullptr;
   data.fused_runner = reinterpret_cast<void*>(fused_runner);
   data.fused_cross_attention_kernel = fused_cross_attention_kernel;
+  data.use_memory_efficient_attention = use_memory_efficient_attention;
 
   cublasHandle_t cublas = GetCublasHandle(context);
   return QkvToContext<CudaT>(
diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h
index 56a1b7885802f..b4ac7f19597ea 100644
--- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h
+++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h
@@ -24,8 +24,9 @@ class MultiHeadAttention final : public CudaKernel {
   int num_heads_;  // number of attention heads
   float mask_filter_value_;
   bool disable_fused_runner_;
-  bool enable_flash_attention_;
+  bool enable_trt_flash_attention_;
   bool disable_fused_cross_attention_;
+  bool disable_memory_efficient_attention_;
   mutable std::unique_ptr<MHARunner> fused_fp16_runner_;
   mutable const FusedMultiHeadCrossAttentionKernel* fused_fp16_cross_attention_kernel_;
 };
diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc
index db77ac575e320..e5ea47a6a2a5b 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc
+++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc
@@ -174,6 +174,7 @@ Status QAttention<T>::ComputeInternal(OpKernelContext* context) const {
   Tensor* present = context->Output(1, present_shape);
 
   void* fused_runner = nullptr;  // TODO(tianleiwu): use fused kernel to speed up
+  bool use_memory_efficient_attention = false;
   size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
                                                    batch_size,
                                                    parameters.num_heads,
@@ -182,7 +183,8 @@ Status QAttention<T>::ComputeInternal(OpKernelContext* context) const {
                                                    sequence_length,
                                                    parameters.kv_sequence_length,
                                                    parameters.total_sequence_length,
-                                                   fused_runner);
+                                                   fused_runner,
+                                                   use_memory_efficient_attention);
 
   auto work_space = GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());
 
@@ -202,6 +204,7 @@ Status QAttention<T>::ComputeInternal(OpKernelContext* context) const {
   data.present = (nullptr == present) ?
nullptr : reinterpret_cast(present->MutableData()); data.fused_runner = fused_runner; data.fused_cross_attention_kernel = nullptr; + data.use_memory_efficient_attention = use_memory_efficient_attention; return QkvToContext(GetDeviceProp(), cublas, Stream(context), parameters, data); } diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py index dae65df8fe14e..2b5c53b867257 100644 --- a/onnxruntime/python/tools/transformers/benchmark_helper.py +++ b/onnxruntime/python/tools/transformers/benchmark_helper.py @@ -544,7 +544,8 @@ def get_ort_environment_variables(): env_names = [ "ORT_DISABLE_FUSED_ATTENTION", "ORT_DISABLE_FUSED_CROSS_ATTENTION", - "ORT_DISABLE_FLASH_ATTENTION", + "ORT_DISABLE_TRT_FLASH_ATTENTION", + "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION", "ORT_TRANSFORMER_OPTIONS", "ORT_CUDA_GEMM_OPTIONS", ] diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index 743c05050c73b..fb1d8fcfe451a 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -929,7 +929,7 @@ TEST(AttentionTest, Causal_EmptyPastState) { { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ - {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, + {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedAttention, "1"}}}; RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, @@ -940,7 +940,7 @@ TEST(AttentionTest, Causal_EmptyPastState) { { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ - {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, + {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedAttention, "0"}}}; RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, @@ -951,7 +951,7 @@ TEST(AttentionTest, Causal_EmptyPastState) { { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ - {onnxruntime::contrib::attention::kDisableFlashAttention, "0"}, + {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"}, {onnxruntime::contrib::attention::kDisableFusedAttention, "0"}}}; RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional, diff --git a/onnxruntime/test/contrib_ops/attention_op_test_helper.cc b/onnxruntime/test/contrib_ops/attention_op_test_helper.cc index bb859ff5e0d8b..5c4c14ce137ed 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test_helper.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test_helper.cc @@ -1902,6 +1902,392 @@ void GetCrossAttentionData_HeadSize40(AttentionTestData& data) { 1.2402344f, 2.2792969f, 0.33398438f, 2.2519531f, 0.67041016f, -0.55957031f, 0.20666504f, 1.3583984f, -1.9716797f, 2.6074219f, 2.2832031f, -2.0546875f, -2.4335938f, 0.53515625f, -0.15100098f, 1.9599609f, -0.51513672f, 0.31030273f, -0.49169922f, 1.4677734f, 2.234375f, 0.87451172f, 0.54736328f, -1.8681641f, -4.2265625f, -0.97509766f, -7.296875f, -1.3486328f, 1.3769531f, -1.8427734f, 3.1601562f, -2.4238281f, -0.82421875f, -2.7324219f, -0.52734375f, 2.2089844f, 0.66796875f, -0.42236328f, -3.03125f, -0.047302246f}; 
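+  // The fixtures below exercise key padding masks: a 1D mask stores the
+  // number of valid keys per batch entry, while a 2D mask stores a 0/1 flag
+  // per key position; kernels that cannot consume a given mask layout are
+  // listed in skip_kernel_types so the test driver skips them.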
} } + +void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(AttentionTestData& data, bool is_mask_1d) { + data.hidden_size = 64; + data.v_hidden_size = 64; + data.num_heads = 2; + data.batch_size = 2; + data.sequence_length = 2; + data.kv_sequence_length = 3; + + if (is_mask_1d) { + data.mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN; + data.key_padding_mask_data = {1, 2}; + } else { + data.mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; + data.key_padding_mask_data = {1, 0, 0, + 1, 1, 0}; + } + + data.skip_kernel_types = {AttentionKernelType::AttentionKernel_TrtFusedCrossAttention, + AttentionKernelType::AttentionKernel_TrtFusedAttention, + AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention}; + + { + data.query_data = { + 0.66417468f, -2.82039404f, 1.66603971f, 4.84341049f, -1.63285708f, 3.61133432f, + -1.07151258f, -0.41698062f, -1.38491797f, -3.79137778f, 1.34514475f, -2.97253704f, + 2.12579250f, -0.02954102f, 2.30081463f, 0.21410012f, 1.84038579f, 0.46486610f, + -4.49463224f, 0.69027799f, 1.01090157f, 0.04715919f, -1.60957003f, 0.10730582f, + -5.77672052f, 0.37593889f, 2.04825425f, -1.00890708f, -3.88195300f, -2.69047785f, + 1.15699422f, -1.13536406f, -0.42816854f, 3.12039518f, 3.21898699f, -0.51998949f, + -4.72336435f, -0.78055519f, -0.72722042f, 3.17147565f, -1.31066322f, -3.09425855f, + -3.54743338f, -0.07284085f, 1.10525322f, 1.82087338f, -2.03681397f, -4.27978802f, + 0.26408362f, 0.58637118f, -2.07128787f, -3.48036027f, -0.03049034f, -1.99293542f, + -0.67289937f, 1.17342246f, -4.84998703f, -2.43558168f, 1.16422236f, 0.26511097f, + -1.98199308f, -1.86423326f, 1.61366916f, -0.35201707f, + -1.43554640f, -1.37493825f, 2.32563400f, -1.31762123f, -1.46716797f, 0.18536982f, + 0.85819042f, -3.11506653f, -1.25773919f, 1.30177450f, 0.58314162f, -1.72039497f, + -4.55264997f, 0.02031951f, -2.83490133f, 2.69835496f, -0.07102034f, -2.05412841f, + -1.26518285f, 3.30601740f, -4.54173231f, 0.80148667f, -1.36685658f, -2.26921320f, + -0.94192690f, -2.77439642f, 0.43918809f, 1.44727242f, 1.53386545f, 2.67014980f, + 3.30231142f, -1.60745978f, -1.26032567f, 1.27801156f, 0.31288767f, 3.04471421f, + -1.09798527f, -2.76303077f, -1.68329728f, -4.78179169f, -0.86371553f, -1.57159030f, + -1.06435764f, 3.61700702f, 0.71459293f, -0.25048330f, 1.31865597f, -1.83117080f, + -1.10344386f, 2.94894052f, -1.33930528f, 1.94855583f, -1.94283628f, -0.64020038f, + 2.24100995f, 1.06447530f, -0.03809617f, 3.47241497f, -2.55227089f, 0.12048072f, + 2.88777542f, -1.73300576f, 3.10077643f, -0.37158102f, + + -0.76705527f, -1.27237630f, 3.55744553f, 0.84103155f, -2.37726879f, 0.20218298f, + -3.41723180f, 1.26160014f, 1.45791709f, -1.47226799f, -2.36974764f, 1.49916458f, + 1.68845606f, -1.33727181f, -2.18113089f, -0.64312577f, -1.06002951f, -0.98938328f, + 1.95285964f, 3.08321524f, 1.28492856f, 2.28907299f, 1.14324796f, -0.11273877f, + -5.96574259f, -1.80337310f, 3.86340094f, -2.42390299f, -1.29642844f, 0.14276078f, + -1.23373103f, -0.51519167f, -1.04046988f, 0.60624832f, -0.93274558f, 2.46919179f, + -0.58201206f, -3.43382907f, 1.63227773f, 1.92112875f, -0.17216301f, 2.79771209f, + 2.67759442f, 1.73900354f, -0.00557053f, -0.63086307f, -0.37115061f, 0.82691956f, + 1.81370568f, -0.48766607f, -1.05545425f, -2.79009533f, -7.64374399f, -2.65407372f, + -0.84429693f, 1.35677493f, -1.25277543f, 2.26928639f, -1.77852845f, 2.31752825f, + -1.28869593f, -2.97340727f, -2.87103486f, 2.17401385f, + 0.20970306f, -1.19119942f, 1.11263359f, 0.21227169f, -5.30872822f, -2.15851903f, + 0.63067430f, -0.49583313f, 
3.05784941f, 0.09588236f, 0.76925617f, 1.18900692f, + 0.35771871f, -0.97235727f, 1.14949071f, -1.25595427f, 2.37192512f, -0.32522821f, + 1.42988098f, -0.38017935f, 2.49831486f, -0.30629224f, 1.08675146f, -1.02598715f, + -0.17971759f, -0.55683851f, 1.04535389f, 1.54741859f, -0.05179391f, 0.73957652f, + 0.54304504f, 1.95280874f, -1.19504929f, -1.19528544f, 1.33258319f, 0.13532166f, + -1.87509251f, 0.99605685f, 2.69439840f, 1.03421521f, 1.79539657f, 0.15001571f, + 0.55184591f, -0.84038037f, -2.08177447f, -1.43082356f, -1.52199960f, 1.69448102f, + 2.12475252f, -2.64191580f, 0.10776700f, -4.01538181f, 1.15558016f, -0.09849232f, + 0.33533198f, 3.34633803f, -2.89805937f, -2.51580763f, 0.94939411f, 1.36254668f, + 0.47172806f, 4.40817642f, -0.11368597f, -2.70789719f}; + } + { + data.key_data = { + 1.18319833f, -0.20700163f, -0.64873743f, 3.88316822f, -2.82827115f, 4.12166834f, 0.84225285f, -1.11044288f, + -1.75086212f, -1.66724730f, 2.22730064f, -3.22617316f, -0.14071584f, 0.58066225f, 3.04375815f, -1.43881261f, + -2.39294887f, 1.03637624f, -0.98744214f, 1.13576865f, -0.23876363f, 0.27395499f, -0.51450062f, -2.23614597f, + -2.12345290f, -0.68864471f, 2.56223369f, -1.14069867f, -2.14457107f, -1.32647824f, -1.20575166f, -0.98427975f, + 0.43083039f, -1.72496212f, 0.89925444f, -0.33879194f, -1.01836991f, 0.06260723f, -4.40405083f, 1.51136112f, + -1.57057071f, -2.49242449f, -0.37187487f, -3.55319405f, 1.50083232f, 0.37271553f, 1.00157571f, -0.50416815f, + 1.28753221f, -0.82453167f, -1.13294256f, -1.49514699f, 0.11243388f, 1.89696264f, -1.46173263f, 3.32755566f, + -0.54521537f, -2.61305809f, -0.43132567f, -0.33066380f, -0.47485363f, 3.62707257f, -0.61352783f, 2.21147466f, + -2.39673638f, 0.89925957f, -2.58643913f, -0.81968069f, 3.34945726f, 0.73745269f, -1.62732553f, -4.55126476f, + 2.78017616f, 0.33757699f, 2.50468874f, -4.14928627f, 0.20017165f, 3.62233806f, -4.17984772f, 2.60447359f, + 2.16826940f, 1.70457518f, 1.03199887f, 2.66712570f, 0.50808340f, -3.47132921f, -2.60008478f, 1.03852415f, + -0.53876096f, 3.36212158f, -5.49142551f, 1.69825470f, -2.98179603f, -3.39561105f, -2.33971524f, 1.23642313f, + 2.13283253f, -0.56307364f, -2.49120903f, 2.97641850f, -1.28758216f, 3.43342829f, 2.49575281f, 0.09292871f, + -0.46469527f, -3.95696974f, 2.16474032f, -2.15254521f, -2.24547267f, 2.34235692f, -1.02470589f, 3.97816467f, + 3.60425544f, 1.87994969f, -2.46964216f, 1.47802746f, -1.81441534f, -1.56946301f, 0.56189334f, -1.69905055f, + -1.83049631f, 4.64296293f, 3.36173010f, 1.17065477f, 0.62365234f, 1.23748016f, 0.63865232f, -2.90434527f, + 1.80253839f, 3.11227179f, -3.96782875f, -2.78780794f, 3.76587057f, -1.66908360f, 1.83301187f, -1.74414611f, + -2.83874130f, -2.00238085f, -6.45539570f, 0.56152177f, 2.52830791f, -4.32480669f, 1.40038610f, 0.83278954f, + 0.16065764f, -0.13457650f, 2.17216778f, -4.28218699f, 0.75475001f, -0.67497885f, -0.95346600f, 3.29623652f, + 1.84325528f, 1.18348145f, -0.23741919f, 2.49520302f, 0.88820332f, 1.15528166f, 0.75733638f, 2.09371948f, + -1.16427231f, 1.36415648f, -1.17721760f, 0.19180456f, -3.83617687f, -0.22694540f, 5.14728260f, -0.43242604f, + -2.59039426f, -1.40904129f, 0.58194822f, -2.59625196f, -3.60205126f, 1.45633197f, 3.66319609f, -4.45727873f, + 3.95457315f, -0.17875004f, 2.43404126f, 2.83592010f, 0.87342203f, 1.24538708f, 3.10003138f, 2.63025975f, + 4.57258415f, -5.20645714f, -2.55821514f, 0.60136455f, -4.13579988f, -2.04082966f, 2.21142578f, -1.05740535f, + + 1.78609943f, -3.10438013f, -0.13040465f, -3.02957106f, 0.91924584f, 0.45405358f, -1.90627027f, -1.05065346f, + 
-1.21743047f, -1.65989709f, -0.51138550f, 2.04327297f, 0.65217698f, 0.77914226f, 1.86315429f, 0.75791669f, + -0.55304748f, -1.23857486f, 2.63207936f, -0.51371288f, 5.48993397f, -2.35509205f, -2.30255723f, 3.88706803f, + -1.93575382f, 0.03364474f, -1.61156952f, -2.74172544f, 1.64667726f, 0.04652762f, 2.88130736f, -2.00066185f, + 0.74907655f, -3.35894132f, -1.85703170f, 1.78695405f, 0.16497552f, 0.94382036f, 3.04452896f, -4.42404556f, + -1.67239439f, 0.93356639f, 0.08288032f, -0.11422639f, -3.94759631f, 0.35302341f, -1.20778334f, -1.92491865f, + -1.86599326f, -1.29324412f, -1.12795746f, 0.24268979f, -0.50242394f, 2.26449108f, 0.91289425f, -2.48235416f, + -1.12685704f, -0.32806787f, 3.28139257f, 3.19231367f, 0.99441254f, -1.86975384f, -3.57600951f, 0.07424650f, + -0.45312887f, 5.02197504f, -3.93365264f, -3.30742884f, -1.48101401f, 1.03335130f, 2.79531693f, -3.71739435f, + 1.58574414f, -4.52857542f, 1.99908066f, 1.53755212f, 1.60631371f, -2.46801257f, -1.85840714f, 5.07508087f, + 1.69143867f, -1.04688716f, -3.17096090f, -4.08357859f, -0.02436948f, -1.26299214f, 1.55509603f, 3.11954260f, + 3.55844116f, 0.10080734f, -0.57031679f, 2.01342750f, -0.66671205f, -1.89724469f, 2.52388906f, 3.71421099f, + 0.77953398f, -1.63364959f, -1.90900147f, -3.60591793f, 1.17604601f, -1.69456589f, -1.62096381f, -1.44886708f, + -1.09821022f, -1.27646899f, 2.73696446f, -2.21802664f, -0.22022307f, 1.76918471f, -1.55524099f, 0.27310926f, + -0.56175643f, -0.59620953f, 2.34752941f, -0.74946308f, -2.33520174f, 1.37984359f, -1.82466078f, -0.04973821f, + -4.77387571f, -0.85034770f, 3.39579129f, -2.82413197f, -2.37980723f, 0.10482252f, 0.10614476f, 0.38176090f, + -0.03948998f, -3.33898020f, 0.33013302f, -0.24926627f, 1.82249093f, 0.57584983f, -0.68790460f, -0.62760007f, + 0.17052543f, -0.54540014f, 1.66043472f, -0.29917845f, 3.31803465f, 0.86704284f, -0.26854402f, 2.23795938f, + -0.65058500f, -2.01540327f, -2.32472515f, -2.85143948f, -3.76564598f, -0.25596800f, -2.08064461f, -0.60812098f, + 3.64154029f, -2.58636141f, -0.25312662f, -2.22530699f, -1.24763203f, -3.08458424f, 0.69228125f, -1.84211481f, + 1.09744453f, -1.35679579f, 1.68044925f, 0.89537722f, 3.56465936f, -0.64790231f, -1.42140329f, -2.85126376f, + 0.88302374f, -0.77923191f, -0.61865216f, -3.08081675f, 0.87791818f, -0.27943787f, 0.46918952f, 1.50163293f, + 3.43236423f, 1.99953759f, -2.42805409f, 4.97383118f, -2.13942194f, 1.45409000f, -1.14207470f, 0.63804722f, + -4.23801470f, 1.23076391f, 2.71176004f, 1.13607812f, 2.27742863f, 1.64165723f, 1.20048785f, -0.66269439f}; + } + + { + data.value_data = { + 2.52855659f, 1.00436294f, 0.83871710f, 0.97005701f, 1.33615291f, -2.07353282f, 0.14190522f, -1.42923164f, + -0.05781263f, -3.81081843f, 1.15263164f, 0.62601233f, -0.93824124f, 1.21525323f, -0.17992918f, 2.08717370f, + 3.61659431f, -0.16836943f, 2.17779160f, -0.63968349f, 0.32170480f, 1.74428463f, -0.46570981f, -0.07432288f, + -0.21569058f, 0.65559602f, 3.58669281f, 0.40837619f, 2.40912223f, 1.31780922f, -4.45945454f, 0.64903581f, + -1.10752177f, -1.79390311f, 0.89312351f, -1.84512544f, -1.13948750f, 3.87221098f, -2.74163318f, 2.90849519f, + -0.31782085f, 3.12108278f, 0.80056298f, 1.02164125f, -0.07995117f, -0.96148860f, 3.49803638f, -4.48321056f, + -1.50024915f, -2.58987570f, 0.61711067f, 4.13532829f, -4.38111591f, -2.48988461f, -0.43977243f, -3.93134618f, + -2.67314148f, 2.64455128f, 0.11041284f, 1.26786041f, -0.24446392f, -0.86178148f, 2.35680771f, -1.69236851f, + -1.22143269f, 1.99185669f, 2.99625540f, -2.32311869f, -2.26162481f, 3.13980794f, 0.37014920f, 
3.22335911f, + 2.55935216f, 2.19479871f, 4.89236355f, 1.76135564f, -2.74285603f, 1.39842391f, -0.25135490f, -4.76257038f, + -0.80362052f, -1.75548995f, -4.70487833f, 1.72763062f, 3.14491320f, 3.97562551f, -0.64091396f, -0.49683607f, + 1.09094775f, -0.04886785f, -0.20181555f, 2.22182846f, 3.00734067f, -0.52149582f, -1.55592132f, 4.41542721f, + 4.68795204f, -1.03364658f, 1.12266266f, -1.50595415f, -4.82583904f, -0.65535200f, -1.44525290f, -0.24540535f, + -0.44778955f, 2.32284093f, 1.60033488f, 0.12583408f, -4.42107201f, -1.32412672f, -1.84733653f, -1.53440499f, + 3.21279287f, -0.37051341f, 0.26685789f, 2.25037003f, 0.01608747f, 1.66141725f, -0.53394145f, 1.35017800f, + 1.35997009f, -2.73341703f, 5.47488451f, 5.49519920f, -1.90401053f, 3.37626982f, -1.97467375f, 1.91208827f, + -0.39609963f, -3.46037388f, -1.47946858f, 3.59935665f, 2.36377144f, -2.32310963f, 1.95714176f, -3.10615826f, + -1.72878003f, 0.37169266f, -5.95610952f, -1.32819366f, -1.24326205f, 0.17746472f, 2.59834385f, 1.83808351f, + 2.94952321f, 3.01939392f, 1.37281823f, 2.67180538f, -0.32547897f, 1.11373281f, -0.26456773f, 0.30103314f, + -1.05465972f, -1.74858260f, 4.66243505f, -0.58474910f, 1.26216507f, 1.28856802f, 0.30135399f, -3.24127388f, + 1.57217860f, -3.84659171f, 1.52000761f, -0.57999939f, 7.80852032f, 2.83661318f, -1.72516418f, 0.70036685f, + 5.33224869f, 3.27205563f, 0.22613347f, 1.27628899f, 0.63828707f, 0.60137266f, 2.23047280f, -3.12771320f, + -0.03023779f, 0.80765182f, -2.25078392f, -2.55701947f, -1.01789987f, -4.81986141f, 5.08153057f, -1.74439597f, + -2.12658811f, -0.01458025f, -2.19556737f, 0.66254830f, -0.97602153f, -0.09858370f, -2.05090475f, -3.57909155f, + + 4.57896709f, -1.96923888f, -3.86827421f, 3.18770289f, -5.16361237f, 1.42594528f, -1.43490076f, 1.62748218f, + 0.91413617f, -0.27147734f, 0.89311242f, 0.39315015f, 1.18184900f, 4.30172014f, -2.32771754f, 1.61144018f, + 1.31702828f, 1.47999883f, -0.20565452f, 0.75846130f, -0.13237280f, -2.10059071f, 0.12025893f, -0.58277643f, + 1.93927395f, -3.11170292f, 0.84666562f, 0.08490577f, -0.36315954f, -3.13071823f, 0.12070303f, -0.10385191f, + -2.37523723f, 2.28944397f, 0.12518460f, -1.10043252f, -1.94665289f, 3.44240570f, 1.14374518f, 3.27769613f, + 1.40222466f, 0.68902296f, 2.48193359f, 1.85469973f, 0.53099388f, -2.16307211f, 0.67865700f, -0.05084896f, + 0.09825261f, 1.40057099f, -0.74452353f, 0.81515837f, 1.51540780f, -1.30754757f, -1.50317430f, -2.04524612f, + -0.49154273f, 0.75809133f, -0.25134420f, 0.36961895f, -0.01882899f, -1.72547066f, 1.12012851f, -6.72828960f, + 1.76177442f, 1.19128907f, -0.77717477f, -1.97159290f, -2.30860472f, 2.01583147f, 5.43375349f, 2.58655977f, + 0.71099019f, 0.71843386f, 3.10709906f, 1.48128355f, 0.22561067f, -4.27442265f, -2.49249840f, 4.71605539f, + 2.19818974f, -1.96133125f, 0.41619009f, 0.66834581f, -3.74457240f, -0.48215276f, -1.28305256f, -1.83142948f, + -0.72452945f, -1.97440028f, -0.14068973f, 0.11765432f, 0.49793118f, 0.40227121f, -1.34390569f, 0.92099732f, + -1.21718168f, -1.95382285f, 1.37468243f, -0.72062874f, 2.66714525f, 1.06695974f, -2.86761045f, 1.34743905f, + 3.30500460f, -0.91894615f, -0.09608981f, -4.09408808f, -2.57941151f, -0.36501098f, 1.93333972f, 1.54577386f, + -2.96415496f, -2.09494066f, 1.63500857f, -1.51829720f, -0.98314112f, -1.89401948f, -0.54314089f, -3.68928242f, + 1.07439506f, 1.70869648f, 0.86973846f, 1.71959770f, 1.78241849f, -4.29455566f, -1.55857742f, -3.32966399f, + 0.20903873f, 1.40176547f, -6.08825064f, 2.12755013f, 3.84799123f, -0.83979988f, -1.64312506f, -0.69876713f, + 4.00779629f, 
-2.85212469f, 0.09145057f, 1.72984874f, -0.77233994f, 1.21815240f, -1.75377214f, 4.08561277f, + -1.20909250f, -1.24881196f, 4.37579060f, 4.27434301f, -2.01065826f, 2.96602201f, 3.07406378f, 1.22374272f, + 0.06376281f, -1.60328245f, -1.32239270f, 1.00765312f, 1.27593243f, -2.14843464f, -3.47884607f, -0.32401958f, + -2.52805567f, -1.01782882f, 0.74270618f, 1.47170806f, -2.56010485f, -1.49985540f, 0.92767721f, 3.42378139f, + 5.23711205f, 0.47062784f, -0.26747131f, -2.06014609f, -0.20237172f, -1.60944867f, -2.51956654f, 0.59529293f, + 2.63805699f, 0.43868792f, -5.84081888f, 3.25271368f, -4.44406748f, -3.80642724f, -1.59846020f, -2.59634686f, + 0.11074528f, 2.04441738f, -1.51878321f, -2.59639883f, 2.23697233f, 0.07920718f, 1.31056094f, -8.10540771f}; + } + + { + data.bias_data = { + -0.38124341f, 0.02696526f, -0.11914945f, -0.43795273f, -0.34948170f, -0.19608477f, 0.19725692f, 0.39987487f, + 0.04772711f, -0.03419551f, -0.30606642f, 0.42656231f, -0.23178342f, -0.13692456f, -0.04889601f, 0.48739988f, + -0.25891554f, 0.13431972f, 0.22861153f, 0.06360734f, 0.48096961f, -0.47906545f, 0.43613154f, -0.23511401f, + -0.10595283f, -0.42839217f, 0.28931111f, -0.13180739f, -0.45826656f, 0.23286396f, -0.43407962f, 0.40754890f, + 0.23778325f, 0.34850210f, -0.01385659f, 0.32141626f, -0.27738628f, 0.27683002f, 0.31886810f, -0.24781504f, + -0.25476855f, -0.46742713f, -0.12478521f, 0.39731556f, -0.12087554f, 0.40822440f, 0.13202906f, -0.23747686f, + 0.30502868f, 0.27182943f, -0.03640261f, -0.39626551f, -0.22411832f, 0.17324352f, -0.49959660f, -0.49318257f, + 0.31363028f, 0.05469471f, -0.00390345f, -0.46100286f, -0.27253938f, 0.17251462f, 0.46564627f, 0.21038425f, + 0.27079183f, 0.42074734f, -0.40314156f, -0.43726659f, 0.27376485f, -0.38174152f, -0.43700469f, 0.38040614f, + -0.40546918f, 0.06927037f, 0.16979086f, 0.41458064f, 0.07120579f, -0.08055863f, 0.12095112f, -0.27988660f, + 0.06004709f, -0.05600315f, -0.25510073f, 0.41887105f, -0.19016314f, 0.47241372f, 0.12890404f, -0.24272856f, + 0.21106839f, -0.40523255f, 0.10336459f, -0.11084765f, 0.42408967f, -0.15285304f, -0.28945464f, -0.25714916f, + 0.40978593f, -0.09138483f, -0.02013114f, -0.39042589f, -0.19557095f, 0.07540411f, 0.33955890f, 0.41873980f, + -0.27744853f, -0.33097768f, -0.44587523f, -0.01648277f, 0.34952271f, -0.48838940f, -0.17273578f, 0.37286615f, + -0.10157353f, -0.08097187f, 0.23243034f, 0.25516337f, -0.45793599f, 0.08089012f, 0.17673731f, 0.03000754f, + 0.48834521f, 0.35069120f, -0.32989410f, 0.20729345f, 0.24406803f, 0.35393929f, -0.16146761f, 0.04258209f, + -0.10567203f, 0.26791072f, -0.08976898f, 0.31341976f, 0.06027532f, 0.14307594f, 0.31587386f, 0.16180152f, + 0.34785229f, 0.00531715f, -0.35168743f, -0.11641458f, 0.39196932f, 0.44535065f, 0.43545735f, 0.15593112f, + 0.06171834f, -0.42181283f, -0.41170910f, 0.40969193f, -0.01510030f, 0.07973170f, -0.18156880f, 0.21522856f, + 0.03915739f, -0.20913908f, -0.47068381f, 0.35633272f, -0.35124153f, 0.36624825f, -0.05567622f, -0.35343069f, + 0.12821168f, 0.35526341f, -0.23420528f, -0.46328634f, -0.21994811f, -0.27556795f, 0.01653767f, 0.42626363f, + 0.23239774f, 0.39632857f, 0.32416028f, -0.48494491f, -0.05365932f, -0.10860911f, 0.06893444f, 0.46116674f, + 0.34345043f, -0.02719739f, -0.39574289f, -0.39339882f, 0.23044002f, -0.06155324f, 0.23292047f, 0.39775699f, + 0.12789404f, -0.44719657f, 0.12020230f, 0.26871282f, -0.10917315f, -0.29244915f, 0.09059817f, -0.19613290f}; + } + + { + data.fp32_output_data = { + 2.42288446f, 1.27227366f, 0.74894810f, 1.28347683f, 1.39642823f, -1.93045688f, 
0.45777908f, -1.26743007f, + 0.29003966f, -3.80550122f, 0.80094421f, 0.50959778f, -0.54627192f, 1.66060388f, 0.25552815f, 2.24310493f, + 3.67831278f, -0.59018224f, 1.76608253f, -0.22999156f, 0.30660450f, 1.82401633f, -0.64727861f, 0.14090568f, + -0.17653319f, 0.44645694f, 3.11600900f, 0.76470888f, 2.05788064f, 1.68405747f, -4.51513100f, 0.29560512f, + -0.97931010f, -1.43863964f, 0.65891826f, -2.30841184f, -1.35943556f, 3.59664297f, -2.72509551f, 3.33475876f, + -0.08542311f, 3.51741123f, 1.12472320f, 0.53669631f, -0.13361049f, -1.07009768f, 3.56697083f, -4.02204370f, + -1.15679872f, -2.61707306f, 0.22136778f, 3.74192953f, -4.15067577f, -2.55143785f, -0.20685196f, -3.53358912f, + -2.54524755f, 2.19735479f, 0.23061514f, 1.53657317f, -0.35363707f, -1.15423059f, 2.44740582f, -1.88850141f, + + 2.42288446f, 1.27227366f, 0.74894810f, 1.28347683f, 1.39642823f, -1.93045688f, 0.45777908f, -1.26743007f, + 0.29003966f, -3.80550122f, 0.80094421f, 0.50959778f, -0.54627192f, 1.66060388f, 0.25552815f, 2.24310493f, + 3.67831278f, -0.59018224f, 1.76608253f, -0.22999156f, 0.30660450f, 1.82401633f, -0.64727861f, 0.14090568f, + -0.17653319f, 0.44645694f, 3.11600900f, 0.76470888f, 2.05788064f, 1.68405747f, -4.51513100f, 0.29560512f, + -0.97931010f, -1.43863964f, 0.65891826f, -2.30841184f, -1.35943556f, 3.59664297f, -2.72509551f, 3.33475876f, + -0.08542311f, 3.51741123f, 1.12472320f, 0.53669631f, -0.13361049f, -1.07009768f, 3.56697083f, -4.02204370f, + -1.15679872f, -2.61707306f, 0.22136778f, 3.74192953f, -4.15067577f, -2.55143785f, -0.20685196f, -3.53358912f, + -2.54524755f, 2.19735479f, 0.23061514f, 1.53657317f, -0.35363707f, -1.15423059f, 2.44740582f, -1.88850141f, + + 4.47329473f, -1.70132744f, -3.95804238f, 3.50112128f, -5.10333633f, 1.56902146f, -1.11902511f, 1.78928399f, + 1.26198828f, -0.26615992f, 0.54142559f, 0.27673587f, 1.57381809f, 4.74706888f, -1.89226031f, 1.76737213f, + 1.37874687f, 1.05818522f, -0.61736351f, 1.16815329f, -0.14747408f, -2.02085853f, -0.06131025f, -0.36754823f, + 1.97843063f, -3.32084179f, 0.37598154f, 0.44123849f, -0.71440083f, -2.76446915f, 0.06502641f, -0.45728233f, + -1.93884647f, 1.51549935f, 0.22349268f, -1.46264625f, -0.93878794f, 2.53468966f, 0.09279048f, 3.19028425f, + 2.14098549f, 0.65744257f, 2.12003636f, -0.21332240f, -0.35039914f, -1.79318547f, 1.08148456f, 0.83520722f, + -0.37325758f, 0.44315636f, -0.50703102f, -0.19921407f, 1.08093989f, -1.52517128f, -1.01477206f, -2.08499599f, + 0.05307493f, 0.56386751f, 0.16719794f, 0.99758488f, 0.35134155f, -2.70159864f, 0.49787593f, -6.01998806f, + + 1.88393891f, 1.20359635f, -1.11693203f, -1.24092197f, -2.47922421f, 2.11120105f, 5.19413376f, 2.67079711f, + 1.07527149f, 0.64369327f, 2.57635832f, 1.27686763f, 0.69491446f, -3.13548803f, -2.04371452f, 4.62090492f, + 2.18864536f, -2.10483122f, -0.04580984f, 1.08532572f, -3.46754074f, -0.53330994f, -1.35113037f, -1.51521778f, + -0.46994060f, -2.27551699f, -0.53152251f, 0.47133854f, 0.07705012f, 0.48279381f, -1.28113365f, 0.48468336f, + -1.54411674f, 0.06915778f, 0.64939111f, -1.33318806f, 0.63385141f, 1.72500539f, -1.27450287f, 2.53234506f, + 2.78955889f, 0.10935718f, 1.24130249f, -2.24100065f, -1.41059852f, -1.18030620f, 1.50915027f, 1.37942517f, + -1.41709673f, -0.74830860f, 0.30404601f, -0.99458563f, 0.22929534f, -1.72507358f, -0.68753922f, -2.64537501f, + 0.58683372f, 0.88788664f, 0.54932535f, 1.45773280f, 0.96530700f, -3.57728553f, -0.41517627f, -4.86154747f}; + } + + { + data.fp16_output_data = data.fp32_output_data; + } +} + +void 
GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(AttentionTestData& data) { + data.hidden_size = 32; + data.v_hidden_size = 32; + data.num_heads = 1; + data.batch_size = 2; + data.sequence_length = 2; + data.kv_sequence_length = 3; + data.mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; + data.key_padding_mask_data = {0, 1, 1, // first key sequence has one padding on the left + 0, 0, 1}; // second key sequence has two paddings on the left + + data.skip_kernel_types = { + AttentionKernelType::AttentionKernel_TrtFusedCrossAttention, + AttentionKernelType::AttentionKernel_TrtFusedAttention, + AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention + }; + + { + data.query_data = { + 2.88765883f, 1.27536213f, -0.57580215f, 2.73696542f, 2.19016314f, 0.42629790f, 1.55081677f, -2.01307678f, + -0.80203497f, -1.23206115f, 1.78565156f, -2.09875321f, -2.22730732f, -0.98120236f, -0.25774139f, 0.75868356f, + -2.87585187f, -0.41810805f, -2.11528730f, 0.50642025f, -0.29446256f, -3.69675803f, -2.73721838f, -1.51089072f, + 0.74300194f, 0.27352047f, -0.88251829f, 2.82622814f, 0.73837662f, -2.14588642f, 0.37608737f, -0.06190044f, + -1.97659302f, -2.22348428f, 2.25573063f, -2.24459195f, -2.28073978f, -0.52412349f, -0.57297325f, 3.29259396f, + 1.35617173f, -0.83082151f, 0.03767079f, 1.82568312f, 0.88193995f, 1.15579486f, 1.87845564f, -0.15923920f, + 2.37435389f, 1.49093378f, 1.95134592f, -1.67609048f, -0.45959851f, 1.63960719f, 3.44909906f, -0.23531833f, + -0.57074630f, 1.38279045f, 0.58870834f, 0.85297751f, -1.44973445f, 1.56243801f, -0.67229253f, -0.16198707f, + + -0.23966503f, -0.15329531f, -3.22765136f, 0.60538405f, -0.33244422f, -1.34865439f, -0.24373266f, -1.78808010f, + -1.53090763f, 1.75037694f, -0.71890754f, 0.12527336f, 1.26654553f, -0.86477917f, -1.49822962f, 1.67973542f, + 0.99763191f, -0.07183220f, 1.55289185f, 1.62626481f, -0.04283767f, -2.55072594f, -1.95238030f, 0.60994428f, + -2.53714681f, 1.54605150f, 0.05900350f, 1.42194426f, 0.33801061f, 1.25557244f, 0.67291188f, -1.36867523f, + 1.86936152f, -1.19588101f, 0.75778806f, 1.85271311f, 0.02081686f, 2.65807819f, 0.78890860f, -1.07388866f, + 4.18109226f, 0.06373940f, 2.86840463f, 0.90427721f, -0.09531648f, -0.40835506f, 1.60812938f, -1.61683714f, + -0.45421624f, -2.25537109f, -1.35910070f, -0.25111723f, -0.71782172f, 0.62597942f, -0.42838976f, 0.23198499f, + 1.29250073f, -2.01550317f, 0.14619158f, -0.03868395f, -0.74211842f, -3.17291188f, -1.90475547f, 2.02544284f}; + } + { + data.key_data = { + 1.14242256f, 1.08148384f, -0.00962424f, -1.62719429f, 0.86478198f, 0.16862091f, 1.01692820f, -1.15278327f, + -1.13622630f, 1.78038371f, 0.58222097f, 0.39166588f, 1.75063372f, -1.20408881f, 0.75154918f, 0.58156419f, + -0.98975772f, -0.82555556f, -0.72656512f, -2.42399549f, 2.19217968f, 2.18518472f, -1.72216129f, 1.35098433f, + -0.34989786f, -0.69064844f, -0.98365444f, 3.10148478f, 0.64813483f, 1.78129303f, -0.47006512f, 2.53122735f, + 0.09757380f, 0.04077591f, -0.81791472f, -0.19737752f, 1.13775492f, -1.51351953f, 0.59109330f, 2.86624002f, + -0.09282493f, -1.69204521f, 1.27087700f, 3.53944731f, 0.59776509f, -0.90838081f, -0.15813766f, -1.86199224f, + 0.18734205f, -0.76110429f, -0.02243887f, -0.94068182f, 1.32443166f, 0.03512055f, -0.13194422f, -1.50401211f, + 0.92001319f, 0.20918207f, -1.34839189f, 1.56431675f, -0.61030018f, 2.39562368f, -1.56722510f, -0.96874726f, + -0.48726845f, -1.41476154f, -1.45116997f, 0.53907454f, -2.14415288f, 1.14340270f, -0.21846619f, -2.72349358f, + 2.99664998f, -2.38684058f, 0.95269018f, 0.04208702f, 
-1.75080788f, 1.24652982f, -1.76879966f, 3.10814905f, + 2.48754454f, -0.62601894f, 1.41356945f, 0.10340121f, 1.09059846f, -0.78241473f, -0.61477584f, -0.19339988f, + -0.48253334f, -2.41782594f, 1.04690075f, 0.14725411f, -0.20820639f, -1.95920563f, 0.96303236f, -1.20068836f, + + -1.71051037f, -1.90946770f, -2.07985783f, 2.35042953f, 0.35059446f, -0.44228595f, 4.08558750f, -0.60121447f, + 0.78836018f, 0.35280651f, 0.23129070f, -0.21523762f, 0.12277550f, 0.12348226f, -1.62759030f, -2.78246498f, + 4.04853964f, 0.29263157f, -0.38621908f, -1.07599223f, -1.99170423f, 1.41409016f, 2.19121861f, -3.53451037f, + 3.63692737f, 0.68270516f, 2.51469731f, 2.57543731f, -2.39040112f, -3.97164130f, 1.28371549f, 1.64144099f, + -0.70385075f, 2.55361128f, 1.60707259f, 0.84735453f, -2.07756495f, -1.99240303f, -3.60991144f, 2.87136865f, + 2.31296396f, 2.30251813f, -1.05624914f, -2.43777156f, -0.27048296f, 2.39037871f, -2.04504776f, 1.65183067f, + -0.38970214f, 0.16808379f, -1.30286717f, 1.90201700f, -2.71696734f, -0.66445369f, 1.27085483f, -0.60816145f, + 1.81054437f, -1.55584621f, -2.19360781f, -4.52794456f, -0.90534067f, 0.94724411f, 2.40401077f, -2.94815230f, + -3.19650269f, 2.50638890f, 1.02038431f, 1.50519919f, 0.47196171f, -1.89026380f, -1.86559379f, 0.82210326f, + 0.10818237f, 1.45290673f, 1.62321615f, -0.61283481f, -1.42501950f, 2.10349464f, -1.65715265f, 0.30090189f, + -3.81919909f, -2.44903922f, -1.20557833f, -0.69951278f, -1.31475580f, -3.73842764f, 1.49299407f, -0.70933276f, + -1.49021530f, 0.71776378f, -1.23052382f, -2.13119912f, -1.20718014f, 2.30572701f, 1.78386402f, -1.57122159f}; + } + + { + data.value_data = { + 1.79297853f, 0.96909231f, 1.23087275f, -0.61933923f, -0.56477690f, 1.47813499f, 0.51474279f, -3.44743419f, + 0.95816678f, -0.20553169f, -0.76906109f, -4.60927439f, 0.40629998f, 0.91934747f, -1.09594405f, -1.45653892f, + -0.59282207f, 0.05621797f, -2.26634383f, -1.30799258f, 1.22072279f, -3.60811162f, 1.70111597f, 0.47336632f, + -1.43857694f, -0.13917151f, -1.34617388f, 1.07960105f, -1.77342618f, 0.31946269f, 1.19137061f, 2.59346104f, + -1.82395399f, 0.73557752f, 2.32600021f, -0.22650969f, -0.48526058f, 1.40349376f, -0.33553454f, 0.45531431f, + 0.73859257f, 0.37798560f, 0.85344458f, -1.30447221f, 1.23349071f, -0.26439479f, 1.18636096f, -0.33328748f, + -0.50939041f, 0.53500950f, 1.33486223f, -1.54447496f, -2.88690519f, -0.06809106f, -0.00597921f, -1.07510388f, + 0.62182164f, 0.50033569f, -0.88293070f, 2.56142712f, 0.37708595f, 1.59349704f, -1.17139614f, 0.89580274f, + 0.69456708f, 2.91441655f, -0.25431669f, -1.20305562f, 2.06701255f, -0.86700624f, -2.23615170f, 0.13303493f, + -2.97540593f, 0.08654684f, 1.40381706f, 3.54294443f, -2.07661867f, -1.33181918f, 2.24228764f, 1.79975545f, + 2.14695477f, 1.40222490f, -0.29813689f, 1.94485068f, 1.99623775f, 1.53450203f, 0.28755581f, -0.67934704f, + -0.92102510f, -1.52764773f, 1.11267352f, -3.90122724f, 0.22128634f, 0.14945325f, -4.38529491f, -1.58423281f, + + -2.45574522f, -1.91599977f, 5.05240345f, 2.24617362f, 3.99182248f, 0.92924285f, -0.39660916f, -0.08696688f, + 0.24855530f, 0.71378094f, 0.92413902f, 1.73599064f, 1.03852975f, 2.44676781f, 0.35013664f, 0.98107171f, + 1.62946916f, 0.41239718f, -1.41385484f, 2.49293518f, 2.32976985f, 2.89612579f, 2.66875219f, 1.47379971f, + 1.31164551f, -1.82183075f, -5.15272474f, 0.28575048f, 0.16861364f, -0.47264135f, 0.22565089f, -0.37727535f, + -1.13935280f, 0.38051969f, -2.38735437f, -2.80645251f, 0.18637873f, 2.13938355f, 2.92260599f, -0.38653925f, + 0.58366799f, -1.67636371f, -2.29396892f, -1.31527638f, 
2.39795637f, 0.39815575f, -0.98530269f, -1.29227996f, + 0.14452982f, -0.38186538f, -1.71267688f, 0.18121701f, -2.26441002f, -0.94511753f, 0.27371156f, -2.44858527f, + -0.21510160f, -2.65228534f, -2.16755104f, 0.86151361f, 0.77589297f, -1.06628847f, 0.73745233f, 1.15778029f, + -0.73659700f, 0.74325305f, -1.97666430f, -1.07301974f, 0.17534591f, -1.66584718f, 1.21820331f, 0.67675018f, + -1.08938253f, 1.78010321f, 0.39817584f, -0.02914053f, 1.13571596f, -0.44081455f, 1.70561552f, -2.12085509f, + -0.69322622f, -1.87331009f, -2.15000772f, 2.08436966f, 1.70494926f, -3.69169927f, -1.22119129f, -1.60190558f, + -2.09093666f, -1.02816033f, -1.78743768f, 2.34501553f, 2.79939008f, 1.82245076f, 1.47408092f, 1.10063124f}; + } + + { + data.bias_data = { + -0.38124341f, 0.02696526f, -0.11914945f, -0.43795273f, -0.34948170f, -0.19608477f, 0.19725692f, 0.39987487f, + 0.04772711f, -0.03419551f, -0.30606642f, 0.42656231f, -0.23178342f, -0.13692456f, -0.04889601f, 0.48739988f, + -0.25891554f, 0.13431972f, 0.22861153f, 0.06360734f, 0.48096961f, -0.47906545f, 0.43613154f, -0.23511401f, + -0.10595283f, -0.42839217f, 0.28931111f, -0.13180739f, -0.45826656f, 0.23286396f, -0.43407962f, 0.40754890f, + 0.27079183f, 0.42074734f, -0.40314156f, -0.43726659f, 0.27376485f, -0.38174152f, -0.43700469f, 0.38040614f, + -0.40546918f, 0.06927037f, 0.16979086f, 0.41458064f, 0.07120579f, -0.08055863f, 0.12095112f, -0.27988660f, + 0.06004709f, -0.05600315f, -0.25510073f, 0.41887105f, -0.19016314f, 0.47241372f, 0.12890404f, -0.24272856f, + 0.21106839f, -0.40523255f, 0.10336459f, -0.11084765f, 0.42408967f, -0.15285304f, -0.28945464f, -0.25714916f, + -0.10567203f, 0.26791072f, -0.08976898f, 0.31341976f, 0.06027532f, 0.14307594f, 0.31587386f, 0.16180152f, + 0.34785229f, 0.00531715f, -0.35168743f, -0.11641458f, 0.39196932f, 0.44535065f, 0.43545735f, 0.15593112f, + 0.06171834f, -0.42181283f, -0.41170910f, 0.40969193f, -0.01510030f, 0.07973170f, -0.18156880f, 0.21522856f, + 0.03915739f, -0.20913908f, -0.47068381f, 0.35633272f, -0.35124153f, 0.36624825f, -0.05567622f, -0.35343069f}; + } + + { + data.fp32_output_data = { + 0.23503941f, 2.87619758f, 0.01845241f, -0.75242990f, 1.76869011f, -0.40492195f, -1.65323853f, 0.34011719f, + -2.10573196f, 0.13281155f, 0.97480160f, 2.74546146f, -1.21957457f, -0.73649400f, 2.52938581f, 1.65599120f, + 1.83545303f, 0.85856718f, -0.48040742f, 1.86428785f, 1.29504943f, 1.38906729f, 0.06474495f, -0.51972288f, + -0.66509569f, -1.45185244f, 0.36160457f, -2.63688278f, -0.10806514f, 0.71859169f, -3.98941422f, -1.58921516f, + + -1.89806330f, 1.03079379f, 2.20389438f, 0.07467184f, -0.39299977f, 1.51811528f, -0.04347950f, 0.61307698f, + 1.03990030f, 0.37965038f, 0.50865448f, -1.36013806f, 1.58397710f, 0.16757873f, 1.63505113f, -0.15062472f, + -0.41438234f, 0.12406474f, 0.90268815f, -1.09105420f, -2.84080887f, 0.03172458f, -0.18386938f, -0.85491556f, + 0.64164376f, 0.26578158f, -1.32860518f, 2.83676863f, 0.02389192f, 1.94164813f, -1.26734924f, 0.51129180f, + + -0.84226906f, 1.01116371f, -2.06643319f, -0.75959998f, 0.23562123f, -1.52277124f, 1.53407717f, 0.83855170f, + -0.74153024f, 1.78542042f, 0.04648840f, -0.14555511f, 1.52768528f, 0.00453609f, 2.14107275f, -1.96492398f, + -0.63150787f, -2.29512286f, -2.56171679f, 2.49406147f, 1.68984890f, -3.61196756f, -1.40276003f, -1.38667703f, + -2.05177927f, -1.23729944f, -2.25812149f, 2.70134830f, 2.44814849f, 2.18869901f, 1.41840470f, 0.74720055f, + + -0.84226906f, 1.01116371f, -2.06643319f, -0.75959998f, 0.23562123f, -1.52277124f, 1.53407717f, 0.83855170f, + -0.74153024f, 
1.78542042f, 0.04648840f, -0.14555511f, 1.52768528f, 0.00453609f, 2.14107275f, -1.96492398f, + -0.63150787f, -2.29512286f, -2.56171679f, 2.49406147f, 1.68984890f, -3.61196756f, -1.40276003f, -1.38667703f, + -2.05177927f, -1.23729944f, -2.25812149f, 2.70134830f, 2.44814849f, 2.18869901f, 1.41840470f, 0.74720055f}; + } + + { + data.fp16_output_data = data.fp32_output_data; + } +} #endif void GetCrossAttentionData_HeadSize16_8(AttentionTestData& data) { @@ -1913,9 +2299,8 @@ void GetCrossAttentionData_HeadSize16_8(AttentionTestData& data) { data.kv_sequence_length = 3; data.mask_type = AttentionMaskType::MASK_NONE; data.skip_kernel_types = { - AttentionKernelType::AttentionKernel_TrtFusedCrossAttention, - AttentionKernelType::AttentionKernel_TrtFusedAttention - }; + AttentionKernelType::AttentionKernel_TrtFusedCrossAttention, + AttentionKernelType::AttentionKernel_TrtFusedAttention}; { data.query_data = { diff --git a/onnxruntime/test/contrib_ops/attention_op_test_helper.h b/onnxruntime/test/contrib_ops/attention_op_test_helper.h index d10eed07d8f9d..664fbb50aa6d7 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test_helper.h +++ b/onnxruntime/test/contrib_ops/attention_op_test_helper.h @@ -28,11 +28,15 @@ struct AttentionTestData{ std::vector skip_kernel_types; // skip some kernels if they do not supported this test case. }; +// Disable some tests in Windows since prefast build might crash with large test data. #ifndef _MSC_VER // Return packed weights and bias for input projection. void GetAttentionWeight(std::vector& weight_data, int elements = 64 * 3 * 64, int offset = 0, int step=1); void GetAttentionBias(std::vector& bias_data, int elements = 3 * 64, int offset = 0, int step=1); + void GetCrossAttentionData_HeadSize40(AttentionTestData& data); +void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(AttentionTestData& data, bool is_mask_1d); +void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(AttentionTestData& data); #endif void GetCrossAttentionData_HeadSize16_8(AttentionTestData& data); diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc index 745f1e41991c6..26d832c64a6a6 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc @@ -74,7 +74,7 @@ static void RunMultiHeadAttentionTest( constexpr float rel_error = 0.0f; constexpr float abs_error = 0.05f; - tester.AddOutput("output", output_dims, ToFloat16(output_data), /*sort*/false, rel_error, abs_error); + tester.AddOutput("output", output_dims, ToFloat16(output_data), /*sort*/ false, rel_error, abs_error); } else { tester.AddInput("query", query_dims, query_data); tester.AddInput("key", key_dims, key_data); @@ -94,7 +94,7 @@ static void RunMultiHeadAttentionTest( constexpr float rel_error = 0.0f; constexpr float abs_error = 0.02f; - tester.AddOutput("output", output_dims, output_data, /*sort*/false, rel_error, abs_error); + tester.AddOutput("output", output_dims, output_data, /*sort*/ false, rel_error, abs_error); } if (enable_cuda) { @@ -124,7 +124,7 @@ static void RunMultiHeadAttentionKernel( const std::vector& bias_data, // bias: [hidden_size + hidden_size + v_hidden_size] const std::vector& key_padding_mask_data, // key_padding_mask: see below AttentionMaskType mask_type, // 1 for [batch_size], 2 for [batch_size, kv_sequence_length] - const std::vector& output_data, // output: [batch_size, sequence_length, v_hidden_size] + const std::vector& 
+    const std::vector<float>& output_data,  // output: [batch_size, sequence_length, v_hidden_size]
     int num_heads,
     int batch_size,
     int sequence_length,
@@ -136,13 +136,13 @@ static void RunMultiHeadAttentionKernel(
     bool disable_cpu = true,  // not supported on CPU right now
     bool disable_cuda = false,
     bool disable_rocm = true) {
-
   if (kernel_type == AttentionKernelType::AttentionKernel_Default) {
     ScopedEnvironmentVariables scoped_env_vars{
         EnvVarMap{
-            {onnxruntime::contrib::attention::kDisableFlashAttention, "0"},
+            {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"},
             {onnxruntime::contrib::attention::kDisableFusedAttention, "0"},
-            {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"}}};
+            {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"},
+            {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}}};
     RunMultiHeadAttentionTest(
         query_data, key_data, value_data, bias_data, key_padding_mask_data, mask_type, output_data,
         num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size,
@@ -150,13 +150,13 @@ static void RunMultiHeadAttentionKernel(
     return;
   }
 
-  if (kernel_type == AttentionKernelType::AttentionKernel_Unfused)
-  {
+  if (kernel_type == AttentionKernelType::AttentionKernel_Unfused) {
     ScopedEnvironmentVariables scoped_env_vars{
         EnvVarMap{
-            {onnxruntime::contrib::attention::kDisableFlashAttention, "1"},
+            {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
             {onnxruntime::contrib::attention::kDisableFusedAttention, "1"},
-            {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}}};
+            {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
+            {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}};
     RunMultiHeadAttentionTest(
         query_data, key_data, value_data, bias_data, key_padding_mask_data, mask_type, output_data,
         num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size,
@@ -164,13 +164,13 @@ static void RunMultiHeadAttentionKernel(
     return;
   }
 
-  if (kernel_type == AttentionKernelType::AttentionKernel_TrtFusedCrossAttention)
-  {
+  if (kernel_type == AttentionKernelType::AttentionKernel_TrtFusedCrossAttention) {
     ScopedEnvironmentVariables scoped_env_vars{
         EnvVarMap{
-            {onnxruntime::contrib::attention::kDisableFlashAttention, "1"},
+            {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
             {onnxruntime::contrib::attention::kDisableFusedAttention, "1"},
-            {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"}}};
+            {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"},
+            {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}};
     RunMultiHeadAttentionTest(
         query_data, key_data, value_data, bias_data, key_padding_mask_data, mask_type, output_data,
         num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size,
@@ -178,13 +178,29 @@ static void RunMultiHeadAttentionKernel(
     return;
   }
 
-  if (kernel_type == AttentionKernelType::AttentionKernel_TrtFusedAttention)
-  {
+#if USE_FLASH_ATTENTION
+  if (kernel_type == AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention) {
     ScopedEnvironmentVariables scoped_env_vars{
         EnvVarMap{
-            {onnxruntime::contrib::attention::kDisableFlashAttention, "0"},
+            {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
+            {onnxruntime::contrib::attention::kDisableFusedAttention, "1"},
+            {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
+            {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}}};
+    
RunMultiHeadAttentionTest( + query_data, key_data, value_data, bias_data, key_padding_mask_data, mask_type, output_data, + num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, + use_float16, disable_cpu, disable_cuda, disable_rocm); + return; + } +#endif + + if (kernel_type == AttentionKernelType::AttentionKernel_TrtFusedAttention) { + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{ + {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"}, {onnxruntime::contrib::attention::kDisableFusedAttention, "0"}, - {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}}}; + {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, + {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}}; RunMultiHeadAttentionTest( query_data, key_data, value_data, bias_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, @@ -204,11 +220,21 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data) { data.hidden_size, data.v_hidden_size, kernel_type, use_float16); } - kernel_type = AttentionKernelType::AttentionKernel_Default; - RunMultiHeadAttentionKernel( +#if USE_FLASH_ATTENTION + kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention; + if (!SkipAttentionKernel(data, kernel_type)) { + RunMultiHeadAttentionKernel( data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16); + } +#endif + + kernel_type = AttentionKernelType::AttentionKernel_Default; + RunMultiHeadAttentionKernel( + data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type, + data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, + data.hidden_size, data.v_hidden_size, kernel_type, use_float16); } if (data.fp16_output_data.size() > 0) { @@ -229,6 +255,16 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data) { data.hidden_size, data.v_hidden_size, kernel_type, use_float16); } +#if USE_FLASH_ATTENTION + kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention; + if (!SkipAttentionKernel(data, kernel_type)) { + RunMultiHeadAttentionKernel( + data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type, + data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, + data.hidden_size, data.v_hidden_size, kernel_type, use_float16); + } +#endif + kernel_type = AttentionKernelType::AttentionKernel_Default; RunMultiHeadAttentionKernel( data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type, @@ -245,6 +281,24 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize40) { GetCrossAttentionData_HeadSize40(data); RunMultiHeadAttentionTests(data); } + +TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask1D) { + AttentionTestData data; + GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, true); + RunMultiHeadAttentionTests(data); +} + +TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask2D) { + AttentionTestData data; + GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, 
false);
+  RunMultiHeadAttentionTests(data);
+}
+
+TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize32_LeftSidePadding_Mask2D) {
+  AttentionTestData data;
+  GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(data);
+  RunMultiHeadAttentionTests(data);
+}
 #endif
 
 // This tests qk_head_size != v_head_size
@@ -260,6 +314,5 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize16) {
   RunMultiHeadAttentionTests(data);
 }
 
-
 } // namespace test
 } // namespace onnxruntime
diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test_data_gen.py b/onnxruntime/test/contrib_ops/multihead_attention_op_test_data_gen.py
index 761bba59772f3..2d7bbc9b6d6c0 100644
--- a/onnxruntime/test/contrib_ops/multihead_attention_op_test_data_gen.py
+++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test_data_gen.py
@@ -36,6 +36,8 @@ def __init__(
         self.value = nn.Linear(hidden_dim, self.v_hidden_size)
         self.is_decoder = is_decoder
 
+        # Do not reshape output, for pretty printing.
+        self.reshape_output = False
         self.verbose = False
 
     def transpose_for_scores(self, x: torch.Tensor, head_size) -> torch.Tensor:
@@ -43,7 +45,7 @@ def transpose_for_scores(self, x: torch.Tensor, head_size) -> torch.Tensor:
         x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)
 
-    def get_extended_attention_mask(self, attention_mask: Tensor) -> Tensor:
+    def get_extended_attention_mask(self, attention_mask: Tensor, dtype: torch.dtype) -> Tensor:
         assert attention_mask.dim() == 2 or attention_mask.dim() == 3
         extended_attention_mask = (
             attention_mask[:, None, :, :] if attention_mask.dim() == 3 else attention_mask[:, None, None, :]
         )
@@ -120,7 +122,7 @@ def forward(
 
         if attention_mask is not None:
             # Apply the attention mask (precomputed for all layers in BertModel forward() function)
-            attention_scores = attention_scores + self.get_extended_attention_mask(attention_mask)
+            attention_scores = attention_scores + self.get_extended_attention_mask(attention_mask, hidden_states.dtype)
 
         # Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1) @@ -131,8 +133,9 @@ def forward( context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - # new_context_layer_shape = context_layer.size()[:-2] + (self.v_hidden_size,) - # context_layer = context_layer.view(new_context_layer_shape) + if self.reshape_output: + new_context_layer_shape = context_layer.size()[:-2] + (self.v_hidden_size,) + context_layer = context_layer.view(new_context_layer_shape) print("output", context_layer) @@ -144,7 +147,7 @@ def forward( return outputs -def generate_test_data( +def run_cross_attention( hidden_dim, q_head_size, v_head_size, @@ -161,7 +164,8 @@ def generate_test_data( device = torch.device("cuda:0") mha = Attention(num_heads, hidden_dim, q_head_size, v_head_size, is_decoder=False).to(device).eval() - + if key_padding_mask is not None: + key_padding_mask = key_padding_mask.to(device) torch.nn.init.uniform_(mha.query.weight, -0.5, 0.5) torch.nn.init.uniform_(mha.key.weight, -0.5, 0.5) torch.nn.init.uniform_(mha.value.weight, -0.5, 0.5) @@ -205,10 +209,9 @@ def generate_test_data( past_key_value=None, output_attentions=False, ) - print("output", output) -def CrossAttention_Batch2_HeadSize40(): +def run_cross_batch2_headsize_40(): hidden_dim = 80 q_head_size = 40 v_head_size = 40 @@ -216,10 +219,12 @@ def CrossAttention_Batch2_HeadSize40(): batch_size = 2 sequence_length = 3 kv_sequence_length = 5 - generate_test_data(hidden_dim, q_head_size, v_head_size, num_heads, batch_size, sequence_length, kv_sequence_length) + run_cross_attention( + hidden_dim, q_head_size, v_head_size, num_heads, batch_size, sequence_length, kv_sequence_length + ) -def CrossAttention_Batch1_HeadSize16(): +def run_cross_batch1_headsize_16(): hidden_dim = 32 q_head_size = 16 v_head_size = 16 @@ -227,10 +232,12 @@ def CrossAttention_Batch1_HeadSize16(): batch_size = 1 sequence_length = 2 kv_sequence_length = 3 - generate_test_data(hidden_dim, q_head_size, v_head_size, num_heads, batch_size, sequence_length, kv_sequence_length) + run_cross_attention( + hidden_dim, q_head_size, v_head_size, num_heads, batch_size, sequence_length, kv_sequence_length + ) -def CrossAttention_Batch2_HeadSize16_8(): +def run_cross_batch2_headsize_16_8(): hidden_dim = 32 q_head_size = 16 v_head_size = 8 @@ -238,15 +245,73 @@ def CrossAttention_Batch2_HeadSize16_8(): batch_size = 2 sequence_length = 1 kv_sequence_length = 3 - generate_test_data(hidden_dim, q_head_size, v_head_size, num_heads, batch_size, sequence_length, kv_sequence_length) + run_cross_attention( + hidden_dim, q_head_size, v_head_size, num_heads, batch_size, sequence_length, kv_sequence_length + ) -with torch.no_grad(): +def run_cross_batch2_headsize_32_right_side_padding(): + hidden_dim = 64 + q_head_size = 32 + v_head_size = 32 + num_heads = 2 + batch_size = 2 + sequence_length = 2 + kv_sequence_length = 3 + key_padding_mask = torch.tensor([[1, 0, 0], [1, 1, 0]], dtype=torch.int32).cuda() + + run_cross_attention( + hidden_dim, + q_head_size, + v_head_size, + num_heads, + batch_size, + sequence_length, + kv_sequence_length, + key_padding_mask, + ) + + +def run_cross_batch1_headsize_32_left_side_padding(): + hidden_dim = 32 + q_head_size = 32 + v_head_size = 32 + num_heads = 1 + batch_size = 2 + sequence_length = 2 + kv_sequence_length = 3 + key_padding_mask = torch.tensor([[0, 1, 1], [0, 0, 1]], dtype=torch.int32).cuda() + run_cross_attention( + hidden_dim, + q_head_size, + v_head_size, + num_heads, + batch_size, + sequence_length, + kv_sequence_length, + key_padding_mask, 
+ ) + + +def create_cross_attention_test_data(): + """ + Create test data used in attention_op_test_helper.cc and multihead_attention_op_test.cc + """ print("CrossAttention_Batch2_HeadSize40") - CrossAttention_Batch2_HeadSize40() + run_cross_batch2_headsize_40() - rint("CrossAttention_Batch1_HeadSize16") - CrossAttention_Batch1_HeadSize16() + print("CrossAttention_Batch1_HeadSize16") + run_cross_batch1_headsize_16() print("CrossAttention_Batch2_HeadSize16_8") - CrossAttention_Batch2_HeadSize16_8() + run_cross_batch2_headsize_16_8() + + print("CrossAttention_Batch2_HeadSize32_RightSidePadding") + run_cross_batch2_headsize_32_right_side_padding() + + print("CrossAttention_Batch1_HeadSize32_LeftSidePadding") + run_cross_batch1_headsize_32_left_side_padding() + + +with torch.no_grad(): + create_cross_attention_test_data()
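
Note on the mask convention used by the generator above: get_extended_attention_mask follows the usual additive-mask scheme, where a [batch_size, kv_sequence_length] key padding mask (1 = keep, 0 = pad) is broadcast to [batch_size, 1, 1, kv_sequence_length] and converted so that padded keys add a large negative bias to the attention scores before softmax. The stand-alone sketch below illustrates that scheme under those assumptions; it is not part of this patch, and the helper name extend_key_padding_mask is hypothetical.

# Minimal sketch (illustrative only, not part of this patch): additive
# key-padding-mask semantics assumed by the test data generator above.
import torch


def extend_key_padding_mask(mask_2d: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # [batch, kv_len] with 1 = keep, 0 = pad -> broadcastable [batch, 1, 1, kv_len].
    extended = mask_2d[:, None, None, :].to(dtype)
    # Kept positions contribute bias 0; padded positions contribute the dtype minimum.
    return (1.0 - extended) * torch.finfo(dtype).min


if __name__ == "__main__":
    torch.manual_seed(0)
    batch, heads, q_len, kv_len = 2, 2, 2, 3
    scores = torch.randn(batch, heads, q_len, kv_len)
    # Same right-side padding pattern as run_cross_batch2_headsize_32_right_side_padding.
    key_padding_mask = torch.tensor([[1, 0, 0], [1, 1, 0]], dtype=torch.int32)
    probs = torch.softmax(scores + extend_key_padding_mask(key_padding_mask, scores.dtype), dim=-1)
    # Padded keys should receive (near-)zero attention probability.
    assert probs[0, :, :, 1:].max().item() < 1e-6
    assert probs[1, :, :, 2:].max().item() < 1e-6
    print(probs)

A left-side padding mask (e.g. [[0, 1, 1], [0, 0, 1]] from run_cross_batch1_headsize_32_left_side_padding) works the same way, which is why both padding layouts exercise the identical additive-mask path in the kernels under test.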