Add memory efficient attention from CUTLASS (#14343)

### Description
Add memory efficient attention from CUTLASS.

TODO (in the next pull request):
(1) Run performance tests on different GPUs, then add a sequence-length threshold (only activate the kernel for long sequences).
(2) Merge the changes from NVIDIA/cutlass#773 once they are in cutlass master.
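
For context (not part of the diff below): the new kernel is compiled in when the onnxruntime_USE_FLASH_ATTENTION CMake option is ON, and it can be switched off at runtime through the ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION environment variable added in attention_common.h. A minimal sketch of opting out from application code, assuming the variable is set before the first CUDA Attention node is constructed:

```cpp
// Sketch only: opt out of the new CUTLASS memory efficient attention kernel.
// "1" disables it; unset or "0" (the default) leaves it enabled.
#include <cstdlib>

int main() {
  setenv("ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION", "1", /*overwrite=*/1);  // POSIX; use _putenv_s on Windows
  // ... create the ONNX Runtime session and run inference as usual ...
  return 0;
}
```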
tianleiwu authored Jan 20, 2023
1 parent e64f357 commit 414b012
Showing 29 changed files with 2,140 additions and 75 deletions.
31 changes: 31 additions & 0 deletions ThirdPartyNotices.txt
@@ -2753,6 +2753,37 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

_____
nvidia/cutlass

Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

_____
Boost

25 changes: 23 additions & 2 deletions cmake/CMakeLists.txt
@@ -64,6 +64,8 @@ option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to prov
option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF)
option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF)

option(onnxruntime_USE_FLASH_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON)

option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF)
option(onnxruntime_USE_AVX "Use AVX instructions" OFF)
option(onnxruntime_USE_AVX2 "Use AVX2 instructions" OFF)
@@ -595,10 +597,31 @@ set(ONNXRUNTIME_PROVIDER_NAMES cpu)
set(ORT_PROVIDER_FLAGS)
set(ORT_PROVIDER_CMAKE_FLAGS)

if (onnxruntime_USE_CUDA)
enable_language(CUDA)
message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}")

if (onnxruntime_DISABLE_CONTRIB_OPS)
set(onnxruntime_USE_FLASH_ATTENTION OFF)
endif()
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6)
message( STATUS "Turn off flash attention since CUDA compiler version < 11.6")
set(onnxruntime_USE_FLASH_ATTENTION OFF)
endif()
else()
set(onnxruntime_USE_FLASH_ATTENTION OFF)
endif()

if (onnxruntime_USE_CUDA)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_CUDA=1)
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_CUDA=1)
list(APPEND ONNXRUNTIME_PROVIDER_NAMES cuda)

if (onnxruntime_USE_FLASH_ATTENTION)
message( STATUS "Enable flash attention for CUDA EP")
list(APPEND ORT_PROVIDER_FLAGS -DUSE_FLASH_ATTENTION=1)
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_FLASH_ATTENTION=1)
endif()
endif()
if (onnxruntime_USE_VITISAI)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_VITISAI=1)
@@ -1234,8 +1257,6 @@ endif()

if (onnxruntime_USE_CUDA)
set(CMAKE_CUDA_RUNTIME_LIBRARY Shared)
enable_language(CUDA)
message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}")
set(CMAKE_CUDA_STANDARD 17)
if(onnxruntime_CUDNN_HOME)
file(TO_CMAKE_PATH ${onnxruntime_CUDNN_HOME} onnxruntime_CUDNN_HOME)
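
For orientation, the option above reaches the C++/CUDA sources as the USE_FLASH_ATTENTION compile definition (via -DUSE_FLASH_ATTENTION=1 in ORT_PROVIDER_FLAGS), and the provider code later tests it with #if USE_FLASH_ATTENTION. A standalone illustration of that guard, not taken from the ORT sources:

```cpp
// Illustration only: USE_FLASH_ATTENTION comes from the CMake option above.
#include <cstdio>

int main() {
#if defined(USE_FLASH_ATTENTION) && USE_FLASH_ATTENTION
  std::printf("CUTLASS memory efficient attention is compiled in\n");
#else
  std::printf("CUTLASS memory efficient attention is disabled at build time\n");
#endif
  return 0;
}
```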
12 changes: 12 additions & 0 deletions cmake/external/cutlass.cmake
@@ -0,0 +1,12 @@
if (onnxruntime_USE_FLASH_ATTENTION)
include(FetchContent)
FetchContent_Declare(cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
GIT_TAG 8b42e751c63ba219755c8ed91af5f6ec1ecc1ee6
)

FetchContent_GetProperties(cutlass)
if(NOT cutlass_POPULATED)
FetchContent_Populate(cutlass)
endif()
endif()
6 changes: 6 additions & 0 deletions cmake/onnxruntime_providers.cmake
@@ -484,6 +484,12 @@ if (onnxruntime_USE_CUDA)
if(onnxruntime_CUDNN_HOME)
target_include_directories(onnxruntime_providers_cuda PRIVATE ${onnxruntime_CUDNN_HOME}/include)
endif()

if (onnxruntime_USE_FLASH_ATTENTION)
include(cutlass)
target_include_directories(onnxruntime_providers_cuda PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples)
endif()

target_include_directories(onnxruntime_providers_cuda PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
# ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found
set_target_properties(onnxruntime_providers_cuda PROPERTIES LINKER_LANGUAGE CUDA)
1 change: 1 addition & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -23,6 +23,7 @@ set(contrib_ops_excluded_files
"bert/skip_layer_norm.h"
"bert/skip_layer_norm_impl.cu"
"bert/skip_layer_norm_impl.h"
"bert/cutlass_fmha/*"
"bert/tensorrt_fused_multihead_attention/*"
"bert/transformer_common.h"
"bert/transformer_common.cc"
10 changes: 7 additions & 3 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -19,7 +19,7 @@ enum AttentionMaskType {

enum AttentionQkvFormat {
Q_K_V_BNSH, // for unfused attention
Q_K_V_BSNH, // input format of query, key and value for MultiHeadAttention
Q_K_V_BSNH, // for memory efficient attention, or format of query, key and value for MultiHeadAttention
QKV_BSN3H, // for TRT fused attention, qkv are packed
Q_K_V_BNSH_QKV_BS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH)
Q_KV_BSNH_BSN2H, // for TRT fused cross attention, kv are packed
@@ -30,6 +30,7 @@ enum AttentionKernelType{
AttentionKernel_TrtFusedAttention,
AttentionKernel_TrtFlashAttention,
AttentionKernel_TrtFusedCrossAttention,
AttentionKernel_CutlassMemoryEfficientAttention,
AttentionKernel_Default
};

@@ -61,8 +62,11 @@ constexpr const char* kDisableFusedAttention = "ORT_DISABLE_FUSED_ATTENTION";
// Environment variable to enable or disable fused cross attention kernel. Default is 0 (enabled).
constexpr const char* kDisableFusedCrossAttention = "ORT_DISABLE_FUSED_CROSS_ATTENTION";

// Environment variable to enable or disable flash attention. Default is 0 (enabled).
constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION";
// Environment variable to enable or disable TRT flash attention. Default is 0 (enabled).
constexpr const char* kDisableTrtFlashAttention = "ORT_DISABLE_TRT_FLASH_ATTENTION";

// Environment variable to enable or disable cutlass memory efficient attention. Default is 0 (enabled).
constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION";

} // namespace attention
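
As a reading aid for the format names in AttentionQkvFormat above (B = batch_size, S = sequence_length, N = num_heads, H = head_size), the following is an illustrative sketch of the row-major index arithmetic the two unpacked layouts imply; it is not ORT code:

```cpp
// Illustrative index helpers for the Q_K_V_BSNH and Q_K_V_BNSH layouts.
#include <cstdint>

// Element (b, n, s, h) in a BSNH tensor: dimensions ordered B, S, N, H.
inline int64_t IndexBSNH(int64_t b, int64_t s, int64_t n, int64_t h,
                         int64_t S, int64_t N, int64_t H) {
  return ((b * S + s) * N + n) * H + h;
}

// Same element in a BNSH tensor: dimensions ordered B, N, S, H.
inline int64_t IndexBNSH(int64_t b, int64_t s, int64_t n, int64_t h,
                         int64_t S, int64_t N, int64_t H) {
  return ((b * N + n) * S + s) * H + h;
}
```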

118 changes: 118 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
@@ -320,6 +320,110 @@ __global__ void AddBiasTransposeQKVLarge(const int head_size, const T* input, co
}
}

template <typename T>
__global__ void AddBiasTransposeCutlass(const T* input, const T* biases, T* output, int v_head_size) {
// Format 3 for cutlass memory efficient attention
// Input: BxSx(NxH + NxH + NxH_v) (packed QKV, where K and V have different head sizes)
// Output: BxNxSxH + BxNxSxH + BxNxSxH_v
// B is batch_size, S is sequence_length, N is num_heads, H is qk_head_size, H_v is v_head_size
int n = threadIdx.y; // head_num_id
int s = blockIdx.x; // sequence_id
int b = blockIdx.y; // batch_id
int m = blockIdx.z; // matrix id (Q=0, K=1, V=2)
const int h = threadIdx.x; // head_element_id

const int qk_head_size = blockDim.x;
const int num_heads = blockDim.y;

const int sequence_length = gridDim.x;
const int batch_size = gridDim.y;

const int head_size = (m == 2 ? v_head_size : qk_head_size);

const int total_head_size = num_heads * (qk_head_size + qk_head_size + v_head_size);

int in_offset;
int out_offset;
int bias_offset;
in_offset = b * (total_head_size * sequence_length) + // B
s * (total_head_size) + // S
m * (qk_head_size * num_heads) + // M
n * head_size + // N
h; // H

out_offset = m * (num_heads * qk_head_size * sequence_length * batch_size) + // M
b * (num_heads * head_size * sequence_length) + // B
s * (num_heads * head_size) + // S
n * (head_size) + // N
h; // H

bias_offset = m * (num_heads * qk_head_size) + // M
n * (head_size) + // N
h; // H

if (h < head_size) {
output[out_offset] = input[in_offset] + biases[bias_offset];
}
}

template <typename T>
__global__ void AddBiasTransposeCutlass(int M, const T* input, const T* biases, T* output) {
// Format 3 for cutlass memory efficient attention
// Input: BxSxMxNxH
// Output: MxBxSxNxH
// B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z; // matrix id

const int head_size = blockDim.x;
const int num_heads = blockDim.y;

const int sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int H = head_size;
const int NH = num_heads * head_size;
const int NHS = NH * sequence_length;

int in_offset = n * head_size + (m + s * M) * NH + b * NHS * M;
const int out_offset = n * head_size + s * NH + b * NHS + m * NHS * batch_size;

const int h = threadIdx.x;
if (h < head_size) {
output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
}
}

template <typename T>
__global__ void AddBiasTransposeCutlassLarge(const int head_size, const T* input, const T* biases, T* output,
const int M) {
// Format 3 for cutlass memory efficient attention
// Input: BxSxMxNxH (Packed QKV)
// Output: MxBxSxNxH
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z; // matrix id

const int stride = blockDim.x;
const int num_heads = blockDim.y;

const int sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int H = head_size;
const int NH = num_heads * H;
const int NHS = NH * sequence_length;
int in_offset = n * H + (m + s * M) * NH + b * NHS * M;
const int out_offset = n * H + s * NH + b * NHS + m * NHS * batch_size;

int h = threadIdx.x;
while (h < H) {
output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
h += stride;
}
}

template <typename T>
__global__ void AddBiasTranspose(const T* input, const T* biases, T* output) {
// Format 0 for Separated Q, K, V (N*H <= 1024)
@@ -395,6 +499,13 @@ void InvokeAddBiasTranspose(
ORT_ENFORCE(total_matrix_count == 3);
AddBiasTransposeQKV<T><<<grid, block, 0, stream>>>(input, biases, output, v_head_size);
}
} else if (format == 3) {
if (v_head_size == -1 || qk_head_size == v_head_size) {
AddBiasTransposeCutlass<T><<<grid, block, 0, stream>>>(total_matrix_count, input, biases, output);
} else {
ORT_ENFORCE(total_matrix_count == 3);
AddBiasTransposeCutlass<T><<<grid, block, 0, stream>>>(input, biases, output, v_head_size);
}
} else { // format == 0
AddBiasTranspose<T><<<grid, block, 0, stream>>>(input, biases, output);
}
@@ -410,6 +521,13 @@ void InvokeAddBiasTranspose(
// It is rare for hidden size > 4096 (for half precision) and qk_head_size != v_head_size.
ORT_THROW("AddBiasTranspose (format 1) not implemented for hidden_size > max_threads_per_block when qk_head_size != v_head_size");
}
} else if (format == 3) {
if (v_head_size == -1 || qk_head_size == v_head_size) {
AddBiasTransposeCutlassLarge<T><<<grid, block, 0, stream>>>(qk_head_size, input, biases, output,
total_matrix_count);
} else {
ORT_THROW("AddBiasTranspose (format 3) not implemented for hidden_size > max_threads_per_block when qk_head_size != v_head_size");
}
} else { // format 0
AddBiasTransposeLarge<T><<<grid, block, 0, stream>>>(qk_head_size, input, biases, output);
}
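
To make the format 3 layout concrete, here is a plain CPU reference for the transform AddBiasTransposeCutlass performs in the equal-head-size case: input (B, S, M, N, H) plus an (M, N, H) bias produces output (M, B, S, N, H). This is an illustrative sketch for checking the index math, not code from this commit:

```cpp
// CPU reference (sketch) for the "format 3" add-bias-transpose above.
#include <vector>

template <typename T>
void AddBiasTransposeFormat3Ref(const std::vector<T>& input,
                                const std::vector<T>& biases,
                                std::vector<T>& output,
                                int B, int S, int M, int N, int H) {
  output.resize(static_cast<size_t>(M) * B * S * N * H);
  for (int b = 0; b < B; ++b)
    for (int s = 0; s < S; ++s)
      for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n)
          for (int h = 0; h < H; ++h) {
            size_t in = ((((size_t)b * S + s) * M + m) * N + n) * H + h;   // B, S, M, N, H
            size_t out = ((((size_t)m * B + b) * S + s) * N + n) * H + h;  // M, B, S, N, H
            size_t bias = ((size_t)m * N + n) * H + h;                     // M, N, H
            output[out] = input[in] + biases[bias];
          }
}
```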
3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h
@@ -21,6 +21,9 @@ namespace cuda {
// format 2:
// input : (batch_size, sequence_length, num_matrices, num_heads, head_size)
// output: (batch_size, sequence_length, num_heads, num_matrices, head_size)
// format 3: (requires sequence_length = kv_sequence_length and qk_head_size = v_head_size when num_matrices == 3)
// input: (batch_size, sequence_length, num_matrices, num_heads, head_size)
// output: (num_matrices, batch_size, sequence_length, num_heads, head_size)
template <typename T>
void LaunchAddBiasTranspose(
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
35 changes: 28 additions & 7 deletions onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -7,6 +7,7 @@
#include "contrib_ops/cuda/bert/attention_impl.h"
#include "contrib_ops/cuda/bert/attention.h"
#include "contrib_ops/cuda/bert/bert_padding.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"

using namespace onnxruntime::cuda;
using namespace ::onnxruntime::common;
@@ -41,8 +42,14 @@ Attention<T>::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionB
disable_fused_runner_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedAttention, false);

enable_flash_attention_ = sizeof(T) == 2 &&
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
enable_trt_flash_attention_ = sizeof(T) == 2 &&
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableTrtFlashAttention, false);

#if USE_FLASH_ATTENTION
disable_memory_efficient_attention_ = ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
#else
disable_memory_efficient_attention_ = true;
#endif
}

template <typename T>
@@ -102,12 +109,12 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
parameters.hidden_size == parameters.v_hidden_size &&
parameters.sequence_length == parameters.kv_sequence_length &&
FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
enable_flash_attention_, true);
enable_trt_flash_attention_, true);
if (use_causal_fused_runner) {
// Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node.
if (nullptr == fused_fp16_runner_.get()) {
fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(num_heads_, parameters.head_size, sm, is_unidirectional_,
enable_flash_attention_, parameters.scale));
enable_trt_flash_attention_, parameters.scale));
}

// Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check.
@@ -122,13 +129,13 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
parameters.hidden_size == parameters.v_hidden_size &&
parameters.sequence_length == parameters.kv_sequence_length &&
FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
enable_flash_attention_, false);
enable_trt_flash_attention_, false);

if (use_fused_runner) {
// Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node.
if (nullptr == fused_fp16_runner_.get()) {
fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(num_heads_, parameters.head_size, sm, is_unidirectional_,
enable_flash_attention_, parameters.scale));
enable_trt_flash_attention_, parameters.scale));
}

// In case some kernel not loaded due to shared memory limit, we need to double check here.
@@ -139,6 +146,18 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
}
}

#if USE_FLASH_ATTENTION
bool use_memory_efficient_attention = fused_runner == nullptr &&
!disable_memory_efficient_attention_ &&
nullptr == mask_index && // TODO: support 1D mask
nullptr == past &&
nullptr == present &&
nullptr == extra_add_qk &&
has_memory_efficient_attention(sm, sizeof(T) == 2);
#else
constexpr bool use_memory_efficient_attention = false;
#endif

cublasHandle_t cublas = GetCublasHandle(context);

typedef typename ToCudaType<T>::MappedType CudaT;
@@ -169,7 +188,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
parameters.sequence_length,
parameters.kv_sequence_length,
parameters.total_sequence_length,
fused_runner);
fused_runner,
use_memory_efficient_attention);
auto work_space = GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());

typedef typename ToCudaType<T>::MappedType CudaT;
@@ -188,6 +208,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
data.present = (nullptr == present) ? nullptr : reinterpret_cast<CudaT*>(present->MutableData<T>());
data.fused_runner = reinterpret_cast<void*>(fused_runner);
data.fused_cross_attention_kernel = nullptr;
data.use_memory_efficient_attention = use_memory_efficient_attention;

return QkvToContext<CudaT>(device_prop, cublas, Stream(context), parameters, data);
}
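
Taken together, the gating above implies roughly this kernel precedence for the Attention op (a simplified sketch, not a verbatim extract; fused_cross_attention_kernel is explicitly nullptr for this op):

```cpp
// Sketch of the dispatch order implied by ComputeInternal above. The enum is a
// trimmed-down version of AttentionKernelType from attention_common.h.
enum AttentionKernelType {
  AttentionKernel_TrtFusedAttention,
  AttentionKernel_CutlassMemoryEfficientAttention,  // new in this commit
  AttentionKernel_Default,
};

AttentionKernelType ChooseKernel(bool has_fused_runner,             // a TRT fused (causal or not) kernel was accepted
                                 bool memory_efficient_eligible) {  // no mask/past/present/extra_add_qk, supported SM
  if (has_fused_runner) {
    return AttentionKernel_TrtFusedAttention;                 // checked first; wins when available
  }
  if (memory_efficient_eligible) {
    return AttentionKernel_CutlassMemoryEfficientAttention;   // use_memory_efficient_attention above
  }
  return AttentionKernel_Default;                             // existing unfused path
}
```

The sequence-length threshold planned in the PR description would further tighten the memory_efficient_eligible condition.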