diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index f0543f2649205..c1d12a1d5cba6 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2310,12 +2310,12 @@ This version of the operator has been available since version 1 of the 'com.micr
Number of attention heads
-#### Inputs (2 - 6)
+#### Inputs (1 - 6)
query : T
-Query with shape (batch_size, sequence_length, hidden_size)
-key : T
+Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape (batch_size, kv_sequence_length, num_heads, 3, head_size)
+key (optional) : T
Key with shape (batch_size, kv_sequence_length, hidden_size), or packed KV with shape (batch_size, kv_sequence_length, num_heads, 2, head_size)
value (optional) : T
Value with shape (batch_size, kv_sequence_length, v_hidden_size)
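For reference, a minimal numpy sketch of how the three MultiHeadAttention input layouts above relate to each other. The dimensions (`batch_size`, `seq_len`, `num_heads`, `head_size`) are arbitrary example values, and the per-head stacking shown here only illustrates the documented shapes; it is not taken from the fusion scripts in this change.

```python
# Illustrative sketch of the MultiHeadAttention input layouts (shapes only).
import numpy as np

batch_size, seq_len, num_heads, head_size = 2, 4, 8, 64  # example dimensions
hidden_size = num_heads * head_size

q = np.random.randn(batch_size, seq_len, hidden_size).astype(np.float16)
k = np.random.randn(batch_size, seq_len, hidden_size).astype(np.float16)
v = np.random.randn(batch_size, seq_len, hidden_size).astype(np.float16)

# No packing: query (B, S, D), key (B, L, D), value (B, L, D_v) are passed separately.

# Packed KV: query stays (B, S, D); key carries K and V as (B, L, N, 2, H); value is omitted.
packed_kv = np.stack(
    [k.reshape(batch_size, seq_len, num_heads, head_size),
     v.reshape(batch_size, seq_len, num_heads, head_size)],
    axis=3,
)
assert packed_kv.shape == (batch_size, seq_len, num_heads, 2, head_size)

# Packed QKV: query carries Q, K and V as (B, S, N, 3, H); key and value are omitted.
packed_qkv = np.stack(
    [q.reshape(batch_size, seq_len, num_heads, head_size),
     k.reshape(batch_size, seq_len, num_heads, head_size),
     v.reshape(batch_size, seq_len, num_heads, head_size)],
    axis=3,
)
assert packed_qkv.shape == (batch_size, seq_len, num_heads, 3, head_size)
```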
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index ee1720b9f43bb..34a615a880594 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -22,53 +22,79 @@ Status CheckInputs(const T* query, int num_heads, float mask_filter_value, int max_threads_per_block) { + // key_padding_mask (K/V) : (B) or (B, L) or None + // relative_position_bias : (B, 1, S, L) + // When no packing for q/k/v: // query (Q) : (B, S, D) // key (K) : (B, L, D) // value (V) : (B, L, D_v) // bias (Q/K/V) : (D + D + D_v) - // key_padding_mask (K/V) : (B) or (B, L) or None - // relative_position_bias : (B, 1, S, L) // When packed kv is used: + // query (Q) : (B, S, D) // key (K) : (B, L, N, 2, H) // value (V) : None // bias (Q/K/V) : None + // When packed qkv is used: + // query (Q) : (B, L, N, 3, H) + // key (K) : None + // value (V) : None + // bias (Q/K/V) : None const auto& query_dims = query->Shape().GetDims(); - if (query_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", + if (query_dims.size() != 3 && query_dims.size() != 5) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 or 5 dimensions, got ", query_dims.size()); } - const auto& key_dims = key->Shape().GetDims(); - if (key_dims.size() != 3 && key_dims.size() != 5) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 or 5 dimensions, got ", - key_dims.size()); - } - if (query_dims[0] != key_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 0 (batch size)"); - } - int batch_size = static_cast(query_dims[0]); int sequence_length = static_cast(query_dims[1]); - int hidden_size = static_cast(query_dims[2]); + int hidden_size = query_dims.size() == 3 ? 
static_cast(query_dims[2]) : (num_heads * static_cast(query_dims[4])); int head_size = static_cast(hidden_size) / num_heads; - int kv_sequence_length = static_cast(key_dims[1]); + int kv_sequence_length = sequence_length; + + if (key != nullptr) { + if (query_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions when key is given, got ", + query_dims.size()); + } - if (key_dims.size() == 3) { - if (key_dims[2] != query_dims[2]) { + const auto& key_dims = key->Shape().GetDims(); + if (key_dims.size() != 3 && key_dims.size() != 5) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 or 5 dimensions, got ", + key_dims.size()); + } + if (query_dims[0] != key_dims[0]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 2 (hidden_size)"); + "Input 'query' and 'key' shall have same dim 0 (batch size)"); } - } else // if (key_dims.size() == 5) - { - if (static_cast(key_dims[2]) != num_heads || static_cast(key_dims[3]) != 2 || static_cast(key_dims[4]) != head_size) { + + if (key_dims.size() == 3) { + if (key_dims[2] != query_dims[2]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 2 (hidden_size)"); + } + } else // if (key_dims.size() == 5) + { + if (static_cast(key_dims[2]) != num_heads || static_cast(key_dims[3]) != 2 || static_cast(key_dims[4]) != head_size) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Expect 'key' shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv"); + } + if (value != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expect 'value' be none when 'key' has packed kv format."); + } + } + + kv_sequence_length = static_cast(key_dims[1]); + } else { // packed QKV + if (query_dims.size() != 5) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 5 dimensions when key is empty, got ", + query_dims.size()); + } + if (static_cast(query_dims[2]) != num_heads || static_cast(query_dims[3]) != 3) { return ORT_MAKE_STATUS( ONNXRUNTIME, INVALID_ARGUMENT, - "Expect 'key' shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv"); - } - if (value != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expect 'value' be none when 'key' has packed kv format."); + "Expect 'query' shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv"); } } @@ -82,7 +108,7 @@ Status CheckInputs(const T* query, // Currently, bias is not allowed for packed KV. This constraint can be removed later. // Here we assume that fusion tool will not include bias for packed KV. if (value == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "'bias' is not allowed for packed kv. "); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "'bias' is not allowed for packed qkv or kv. 
"); } } @@ -90,9 +116,9 @@ Status CheckInputs(const T* query, if (key_padding_mask != nullptr) { mask_type = AttentionMaskType::MASK_UNKNOWN; const auto& mask_dims = key_padding_mask->Shape().GetDims(); - if (mask_dims.size() == 1 && mask_dims[0] == key_dims[0]) { + if (mask_dims.size() == 1 && mask_dims[0] == static_cast(batch_size)) { mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN; - } else if (mask_dims.size() == 2 && mask_dims[0] == key_dims[0] && mask_dims[1] == key_dims[1]) { + } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && mask_dims[1] == static_cast(kv_sequence_length)) { mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; } @@ -115,7 +141,7 @@ Status CheckInputs(const T* query, "Input 'query' and 'value' shall have same dim 0 (batch_size)"); } - if (key_dims[1] != value_dims[1]) { + if (static_cast(kv_sequence_length) != value_dims[1]) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' and 'value' shall have same same dim 1 (kv_sequence_length)"); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 1ab89b525eae5..04cac1962f37b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -181,6 +181,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop)); constexpr size_t element_size = sizeof(T); + constexpr bool use_fused_cross_attention = false; size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, parameters.batch_size, parameters.num_heads, @@ -190,6 +191,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { parameters.kv_sequence_length, parameters.total_sequence_length, fused_runner, + use_fused_cross_attention, use_memory_efficient_attention); auto work_space = GetScratchBuffer(workSpaceSize, context->GetComputeStream()); @@ -204,12 +206,15 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { data.mask_index_dims = (nullptr == mask_index) ? gsl::span() : mask_index->Shape().GetDims(); data.past = (nullptr == past) ? nullptr : reinterpret_cast(past->Data()); data.relative_position_bias = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast(relative_position_bias->Data()); + data.has_qkv_workspace = true; data.workspace = reinterpret_cast(work_space.get()); data.output = reinterpret_cast(output->MutableData()); data.present = (nullptr == present) ? nullptr : reinterpret_cast(present->MutableData()); data.fused_runner = reinterpret_cast(fused_runner); data.fused_cross_attention_kernel = nullptr; data.use_memory_efficient_attention = use_memory_efficient_attention; + data.cumulated_sequence_length_q_cache = nullptr; + data.cumulated_sequence_length_kv_cache = nullptr; return QkvToContext(device_prop, cublas, Stream(context), parameters, data); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 972b2d4ade7d6..41f19f460e80c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -46,22 +46,47 @@ limitations under the License. 
using namespace onnxruntime::cuda; using namespace cub; -#define CHECK_CUDA(expr) CUDA_RETURN_IF_ERROR(expr) -#define CUDA_MEMORY_ALIGNMENT 256 - namespace onnxruntime { namespace contrib { namespace cuda { +constexpr size_t kMemoryAlignment = 256; + static size_t AlignTo(size_t a, size_t b) { return CeilDiv(a, b) * b; } size_t AlignSize(size_t bytes) { - const size_t bytesAligned = AlignTo(bytes, CUDA_MEMORY_ALIGNMENT); + const size_t bytesAligned = AlignTo(bytes, kMemoryAlignment); return bytesAligned; } +void CumulatedSequenceLengthCache::Initialize(int32_t sequence_length, cudaStream_t stream) { + if (this->sequence_length != sequence_length) { + ORT_ENFORCE(buffer.get() != nullptr && this->max_batch_size > 0); + LaunchTrtSequenceOffset(reinterpret_cast(buffer.get()), nullptr, this->max_batch_size, sequence_length, stream); + this->sequence_length = sequence_length; + } +} + +int* GetCumulatedSequenceLength(CumulatedSequenceLengthCache* cache, + const int* mask_index, + int batch_size, + int sequence_length, + cudaStream_t stream, + void* scratch_buffer) { + if (mask_index == nullptr && cache != nullptr) { + if (batch_size <= cache->max_batch_size) { + cache->Initialize(sequence_length, stream); + return reinterpret_cast(cache->buffer.get()); + } + } + + int* sequence_offset = reinterpret_cast(scratch_buffer); + LaunchTrtSequenceOffset(sequence_offset, mask_index, batch_size, sequence_length, stream); + return sequence_offset; +} + size_t GetAttentionScratchSize( size_t element_size, size_t batch_size, @@ -89,6 +114,7 @@ size_t GetAttentionWorkspaceSize( size_t kv_sequence_length, size_t total_sequence_length, void* fused_runner, + bool use_fused_cross_attention, bool use_memory_efficient_attention) { // Note that q, k and v might need alignment for fused attention kernels. const size_t qkv_bytes = element_size * batch_size * num_heads * @@ -108,8 +134,11 @@ size_t GetAttentionWorkspaceSize( #endif if (fused_runner != nullptr) { - size_t sequence_offset_bytes = GetSequenceOffsetSize(static_cast(batch_size), true); - return qkv_bytes + sequence_offset_bytes; + return qkv_bytes + GetSequenceOffsetSize(static_cast(batch_size), true); + } + + if (use_fused_cross_attention) { + return qkv_bytes + 2 * GetSequenceOffsetSize(static_cast(batch_size), true); } return qkv_bytes + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, @@ -264,7 +293,7 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, T* qkv = data.workspace; - bool use_fused_kernel = (nullptr != fused_runner && data.bias != nullptr && !parameters.is_unidirectional); + bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); // Default format for memory efficient attention. @@ -272,6 +301,7 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, DUMP_TENSOR_INIT(); if (nullptr != data.gemm_buffer) { if (data.bias == nullptr) { + assert(nullptr == fused_runner); // For quantized attention, bias has been added so only need transpose here. 
// gemm_buffer should be BxSx3xNxH => qkv: 3xBxNxSxH assert(qk_head_size == v_head_size); @@ -303,6 +333,31 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias, 3); } + } else if (data.key == nullptr) { // gemm_buffer == nullptr and packed qkv + assert(data.bias == nullptr); + assert(qk_head_size == v_head_size); + + DUMP_TENSOR_D("packed_qkv", data.query, batch_size * sequence_length, num_heads, 3, qk_head_size); + + if (use_memory_efficient_attention) { + // unpack qkv to BSNH. Note that there is no bias so we need not output query to q. + constexpr int format = 4; + T* qkv_add_bias = nullptr; + LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, qkv, + true, v_head_size, qkv_add_bias, 3); + DUMP_TENSOR_D("k(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } else { + if (!use_fused_kernel) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "packed QKV format is not implemented for current GPU. Please disable it in fusion options."); + } + + qkv_format = AttentionQkvFormat::QKV_BSN3H; + } } else if (data.value == nullptr) { // gemm_buffer == nullptr and packed kv // TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint. // CheckInputs verified this constraint. @@ -330,7 +385,7 @@ Status PrepareQkv(contrib::AttentionParameters& parameters, qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; } - } else { // gemm_buffer == nullptr and not packed kv + } else { // gemm_buffer == nullptr and not packed assert(data.query != nullptr && data.key != nullptr && data.value != nullptr && data.bias != nullptr); DUMP_TENSOR_D("query", data.query, batch_size * sequence_length, num_heads, qk_head_size); @@ -435,18 +490,25 @@ Status QkvToContext( assert(int(data.use_memory_efficient_attention) + int(fused_runner != nullptr) + int(data.fused_cross_attention_kernel != nullptr) <= 1); const int batches = batch_size * num_heads; - const int size_per_batch_q = sequence_length * qk_head_size; - const int size_per_batch_k = kv_sequence_length * qk_head_size; - const int size_per_batch_v = kv_sequence_length * v_head_size; - const size_t elements_q = static_cast(batches) * static_cast(size_per_batch_q); - const size_t elements_k = static_cast(batches) * static_cast(size_per_batch_k); - const size_t elements_v = static_cast(batches) * static_cast(size_per_batch_v); - - // Q, K and V pointers when fused attention is not used - T* qkv = data.workspace; - T* q = qkv; - T* k = q + elements_q; - T* v = k + elements_k; + + T* qkv = nullptr; + T* q = nullptr; + T* k = nullptr; + T* v = nullptr; + T* scratch1 = data.workspace; + if (data.has_qkv_workspace) { + const int size_per_batch_q = sequence_length * qk_head_size; + const int size_per_batch_k = kv_sequence_length * qk_head_size; + const int size_per_batch_v = kv_sequence_length * v_head_size; + const size_t elements_q = static_cast(batches) * static_cast(size_per_batch_q); + const size_t elements_k = static_cast(batches) * static_cast(size_per_batch_k); + const size_t elements_v = static_cast(batches) * static_cast(size_per_batch_v); + qkv = data.workspace; + q = qkv; + k = q + elements_q; + v = k + elements_k; + 
scratch1 = v + elements_v; + } bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); @@ -454,8 +516,6 @@ Status QkvToContext( AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH; ORT_RETURN_IF_ERROR(PrepareQkv(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format)); - T* scratch1 = qkv + elements_q + elements_k + elements_v; - int present_size_per_batch_k = 0; int present_size_per_batch_v = 0; if (!past_present_share_buffer) { @@ -482,6 +542,7 @@ Status QkvToContext( assert(!use_fused_kernel); assert(data.gemm_buffer != nullptr); assert(!data.use_memory_efficient_attention); + assert(data.has_qkv_workspace); if (data.present != data.past) { // For easy testing. Production should better avoid this path. @@ -507,18 +568,21 @@ Status QkvToContext( if (data.fused_cross_attention_kernel != nullptr) { assert(qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H); - int* q_sequence_offset = reinterpret_cast(scratch1); - LaunchTrtSequenceOffset(q_sequence_offset, nullptr, batch_size, sequence_length, stream); - CUDA_RETURN_IF_ERROR(cudaGetLastError()); - - DUMP_TENSOR_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1); // We only enable fused cross attention when there is no key padding mask. // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. assert(data.mask_index == nullptr); + int* q_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, + data.mask_index, batch_size, sequence_length, stream, + scratch1); + + DUMP_TENSOR_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1); + int* kv_sequence_offset = q_sequence_offset + (GetSequenceOffsetSize(batch_size, false) / sizeof(int)); - LaunchTrtSequenceOffset(kv_sequence_offset, data.mask_index, batch_size, kv_sequence_length, stream); + kv_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_kv_cache, + data.mask_index, batch_size, kv_sequence_length, stream, + kv_sequence_offset); CUDA_RETURN_IF_ERROR(cudaGetLastError()); DUMP_TENSOR_D("kv_sequence_offset", kv_sequence_offset, 1, batch_size + 1); @@ -526,7 +590,7 @@ Status QkvToContext( FusedMultiHeadCrossAttentionKernel const* cross_attention_kernel = reinterpret_cast(data.fused_cross_attention_kernel); - // When there is no bias, we can directly use q and packed kv from inputs. TODO: not need qkv in workspace. + // When there is no bias, we can directly use q and packed kv from inputs. void const* query = q; void const* packed_kv = k; if (data.value == nullptr && data.bias == nullptr) { @@ -558,7 +622,9 @@ Status QkvToContext( if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) { LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream); } else { - LaunchTrtSequenceOffset(sequence_offset, data.mask_index, batch_size, sequence_length, stream); + sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache, + data.mask_index, batch_size, sequence_length, stream, + sequence_offset); } CUDA_RETURN_IF_ERROR(cudaGetLastError()); @@ -573,7 +639,14 @@ Status QkvToContext( if (use_fused_kernel) { assert(qkv_format == AttentionQkvFormat::QKV_BSN3H); - fused_fp16_runner->run(qkv, sequence_offset, data.output, stream); + + // When there is no bias, we can directly use packed qkv from inputs. 
+ void const* packed_qkv = qkv; + if (data.query != nullptr && data.key == nullptr && data.bias == nullptr) { + packed_qkv = data.query; + } + + fused_fp16_runner->run(packed_qkv, sequence_offset, data.output, stream); DUMP_TENSOR("fused output", data.output, batch_size * sequence_length, num_heads, v_head_size); } else { assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); @@ -593,7 +666,9 @@ Status QkvToContext( const void* query = q; const void* key = k; const void* value = v; - if (data.gemm_buffer == nullptr && data.value == nullptr) { // packed KV + // For packed KV, we can use query input directly. + if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr) { + assert(data.bias == nullptr); query = data.query; } @@ -781,15 +856,15 @@ Status DecoderQkvToContext( if (has_layer_state) { if (use_past && static_kv) { - CHECK_CUDA(cudaMemcpyAsync(new_key_cache, key_cache, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); - CHECK_CUDA(cudaMemcpyAsync(new_value_cache, value_cache, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, key_cache, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, value_cache, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); } else { - CHECK_CUDA(cudaMemcpyAsync(new_key_cache, k, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); - CHECK_CUDA(cudaMemcpyAsync(new_value_cache, v, kv_sequence_length * BHN * sizeof(T), - cudaMemcpyDeviceToDevice, stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, k, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, v, kv_sequence_length * BHN * sizeof(T), + cudaMemcpyDeviceToDevice, stream)); } } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index 2ecda71479c52..ec7371db4c14d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -6,18 +6,33 @@ #include #include #include "contrib_ops/cpu/bert/attention_common.h" +#include "core/framework/allocator.h" namespace onnxruntime { namespace contrib { namespace cuda { -size_t GetAttentionScratchSize( +constexpr int kCumulatedSequenceLengthCacheMaxBatchSize = 128; + +struct CumulatedSequenceLengthCache { + onnxruntime::IAllocatorUniquePtr buffer; + int32_t max_batch_size; + int32_t sequence_length; + + CumulatedSequenceLengthCache() : max_batch_size(0), sequence_length(0) {} + void Initialize(int32_t sequence_length, cudaStream_t stream); +}; + +size_t +GetAttentionScratchSize( size_t element_size, size_t batch_size, size_t num_heads, size_t sequence_length, size_t all_sequence_length); +size_t GetSequenceOffsetSize(int batch_size, bool has_padding); + size_t GetAttentionWorkspaceSize( size_t element_size, size_t batchsize, @@ -28,7 +43,8 @@ size_t GetAttentionWorkspaceSize( size_t kv_sequence_length, size_t total_sequence_length, void* fused_runner, - bool use_memory_efficient_attention = false); + bool use_fused_cross_attention, + bool use_memory_efficient_attention); template struct AttentionData { @@ -43,7 +59,9 @@ struct AttentionData { const T* past; const T* relative_position_bias; + bool has_qkv_workspace; T* workspace; + T* output; T* present; @@ -51,6 +69,9 @@ struct AttentionData { 
const void* fused_cross_attention_kernel; bool use_memory_efficient_attention; + + mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache; + mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache; }; template diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 57a3a310a0dd6..321c2a1df0df2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -31,9 +31,11 @@ REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) template - MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) - : CudaKernel(info), fused_fp16_cross_attention_kernel_(nullptr) { + : CudaKernel(info), + fused_fp16_cross_attention_kernel_(nullptr), + cumulated_sequence_length_q_cache_(), + cumulated_sequence_length_kv_cache_() { int64_t num_heads = 0; ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); num_heads_ = static_cast(num_heads); @@ -52,7 +54,15 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) disable_memory_efficient_attention_ = true; #endif - disable_fused_cross_attention_ = sizeof(T) != 2 || ParseEnvironmentVariableWithDefault(attention::kDisableFusedCrossAttention, false); + disable_fused_cross_attention_ = sizeof(T) != 2 || + ParseEnvironmentVariableWithDefault(attention::kDisableFusedCrossAttention, false); + + // Allocate cache buffers + constexpr size_t cache_bytes = sizeof(int32_t) * (static_cast(kCumulatedSequenceLengthCacheMaxBatchSize) + 1); + cumulated_sequence_length_q_cache_.buffer = GetTransientScratchBuffer(cache_bytes); + cumulated_sequence_length_q_cache_.max_batch_size = kCumulatedSequenceLengthCacheMaxBatchSize; + cumulated_sequence_length_kv_cache_.buffer = GetTransientScratchBuffer(cache_bytes); + cumulated_sequence_length_kv_cache_.max_batch_size = kCumulatedSequenceLengthCacheMaxBatchSize; } template @@ -97,6 +107,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { bool use_fused_cross_attention = !disable_fused_cross_attention_ && nullptr == key_padding_mask && nullptr == relative_position_bias && + key != nullptr && (value != nullptr || bias == nullptr) && // TODO: new kernel for adding bias to packed KV parameters.hidden_size == parameters.v_hidden_size && has_fused_cross_attention_kernel(sm, parameters.head_size, @@ -116,7 +127,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { bool use_fused_runner = !disable_fused_runner_ && fused_cross_attention_kernel == nullptr && nullptr == relative_position_bias && - value != nullptr && // fused runner requires packed qkv instead of packed kv + (value != nullptr || key == nullptr) && (nullptr == key_padding_mask || is_mask_1d_seq_len) && parameters.hidden_size == parameters.v_hidden_size && parameters.sequence_length == parameters.kv_sequence_length && @@ -153,36 +164,52 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { constexpr bool use_memory_efficient_attention = false; #endif - constexpr size_t element_size = sizeof(T); - size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, - parameters.batch_size, - parameters.num_heads, - parameters.head_size, - parameters.v_head_size, - parameters.sequence_length, - parameters.kv_sequence_length, - parameters.total_sequence_length, - fused_runner, - use_memory_efficient_attention); - auto work_space = GetScratchBuffer(workSpaceSize, 
context->GetComputeStream()); + // When packed kv or packed qkv is used, there is no needed for add bias transpose thus no qkv workspace. + bool no_qkv_workspace = nullptr == value && + (use_fused_cross_attention || (nullptr != fused_runner && nullptr == key)) && + nullptr == key_padding_mask && + nullptr == bias; + + size_t workspace_bytes; + if (no_qkv_workspace) { + workspace_bytes = (parameters.batch_size > kCumulatedSequenceLengthCacheMaxBatchSize) ? 2 * GetSequenceOffsetSize(parameters.batch_size, true) : 0; + } else { + constexpr size_t element_size = sizeof(T); + workspace_bytes = GetAttentionWorkspaceSize(element_size, + parameters.batch_size, + parameters.num_heads, + parameters.head_size, + parameters.v_head_size, + parameters.sequence_length, + parameters.kv_sequence_length, + parameters.total_sequence_length, + fused_runner, + use_fused_cross_attention, + use_memory_efficient_attention); + } + + auto work_space = GetScratchBuffer(workspace_bytes, context->GetComputeStream()); typedef typename ToCudaType::MappedType CudaT; AttentionData data; data.gemm_buffer = nullptr; data.bias = (nullptr == bias) ? nullptr : reinterpret_cast(bias->Data()); data.query = reinterpret_cast(query->Data()); - data.key = reinterpret_cast(key->Data()); + data.key = (nullptr == key) ? nullptr : reinterpret_cast(key->Data()); data.value = (nullptr == value) ? nullptr : reinterpret_cast(value->Data()); data.mask_index = (nullptr == key_padding_mask) ? nullptr : key_padding_mask->Data(); data.mask_index_dims = (nullptr == key_padding_mask) ? gsl::span() : key_padding_mask->Shape().GetDims(); data.past = nullptr; data.relative_position_bias = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast(relative_position_bias->Data()); + data.has_qkv_workspace = !no_qkv_workspace; data.workspace = reinterpret_cast(work_space.get()); data.output = reinterpret_cast(output->MutableData()); data.present = nullptr; data.fused_runner = reinterpret_cast(fused_runner); data.fused_cross_attention_kernel = fused_cross_attention_kernel; data.use_memory_efficient_attention = use_memory_efficient_attention; + data.cumulated_sequence_length_q_cache = &(this->cumulated_sequence_length_q_cache_); + data.cumulated_sequence_length_kv_cache = &(this->cumulated_sequence_length_kv_cache_); cublasHandle_t cublas = GetCublasHandle(context); return QkvToContext( diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index b4ac7f19597ea..928dbd1c4a0f4 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -7,6 +7,7 @@ #include "core/providers/cuda/cuda_kernel.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h" #include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h" +#include "contrib_ops/cuda/bert/attention_impl.h" namespace onnxruntime { namespace contrib { @@ -29,6 +30,8 @@ class MultiHeadAttention final : public CudaKernel { bool disable_memory_efficient_attention_; mutable std::unique_ptr fused_fp16_runner_; mutable const FusedMultiHeadCrossAttentionKernel* fused_fp16_cross_attention_kernel_; + mutable CumulatedSequenceLengthCache cumulated_sequence_length_q_cache_; + mutable CumulatedSequenceLengthCache cumulated_sequence_length_kv_cache_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h 
b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h index 45a17a03d82f2..23bab06fe46ca 100644 --- a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h @@ -186,7 +186,7 @@ static Fused_multihead_attention_params_mhca getMHCAParams( void const* q_packed_d, void const* kv_packed_d, void* cu_seqlens_q_d, void* cu_seqlens_kv_d, void* o_packed_d) { Fused_multihead_attention_params_mhca params{}; - int32_t const d_padded = static_cast(std::pow(2, std::ceil(std::log(d) / std::log(2)))); + int32_t const d_padded = d <= 64 ? 64 : static_cast(std::pow(2, std::ceil(std::log(d) / std::log(2)))); // Set the pointers. params.o_ptr = o_packed_d; @@ -269,11 +269,11 @@ using FusedMHACrossKernelFactory = TSharedCubinKernelFactory min_head_size) && (head_size <= max_head_size) && - (kv_sequence_length <= 128); // TODO: shall we remove this constraint on kv_sequence_length? + (kv_sequence_length <= 128); // TODO: shall we remove this constraint on kv_sequence_length? } inline FusedMultiHeadCrossAttentionKernel const* get_fused_cross_attention_kernels(int32_t sm) { diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index 7cd717efc9fba..90ec1a35ac63a 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -174,6 +174,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { Tensor* present = context->Output(1, present_shape); void* fused_runner = nullptr; // TODO(tianleiwu): use fused kernel to speed up + bool use_fused_cross_attention = false; bool use_memory_efficient_attention = false; size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, @@ -184,6 +185,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { parameters.kv_sequence_length, parameters.total_sequence_length, fused_runner, + use_fused_cross_attention, use_memory_efficient_attention); auto work_space = GetScratchBuffer(workSpaceSize, context->GetComputeStream()); @@ -199,12 +201,15 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { data.mask_index_dims = (nullptr == mask_index) ? gsl::span() : mask_index->Shape().GetDims(); data.past = (nullptr == past_tensor) ? nullptr : reinterpret_cast(past_tensor->Data()); data.relative_position_bias = nullptr; // add_qk is not supported in quantized attention + data.has_qkv_workspace = true; data.workspace = reinterpret_cast(work_space.get()); data.output = reinterpret_cast(output->MutableData()); data.present = (nullptr == present) ? 
nullptr : reinterpret_cast(present->MutableData()); data.fused_runner = fused_runner; data.fused_cross_attention_kernel = nullptr; data.use_memory_efficient_attention = use_memory_efficient_attention; + data.cumulated_sequence_length_q_cache = nullptr; + data.cumulated_sequence_length_kv_cache = nullptr; return QkvToContext(GetDeviceProp(), cublas, Stream(context), parameters, data); } diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 6b00ac94bc10f..580a0a993454a 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -126,11 +126,23 @@ void RestorePaddingTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) } void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { - // Input 0 (query) has shape (batch_size, sequence_length, hidden_size) - // Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size) or (batch_size, kv_sequence_length, num_heads, 2, head_size) - // Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size) or nullptr // Output 0 has shape (batch_size, sequence_length, v_hidden_size) + // Q, K and V without packing: + // Input 0 (query) has shape (batch_size, sequence_length, hidden_size) + // Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size) + // Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size) + + // Packed KV: + // Input 0 (query) has shape (batch_size, sequence_length, hidden_size) + // Input 1 (batch_size, kv_sequence_length, num_heads, 2, head_size) + // Input 2 nullptr + + // Packed QKV: + // Input 0 (batch_size, sequence_length, num_heads, 3, head_size) + // Input 1 nullptr + // Input 2 nullptr + // Type inference ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); @@ -138,8 +150,18 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c if (hasInputShape(ctx, 0)) { auto& query_shape = getInputShape(ctx, 0); auto& query_dims = query_shape.dim(); - if (query_dims.size() != 3) { - fail_shape_inference("Inputs 0 (query) shall be 3 dimensions"); + + if (query_dims.size() != 3 && query_dims.size() != 5) { + fail_shape_inference("Inputs 0 (query) shall be 3 or 5 dimensions"); + } + + if (query_dims.size() == 5) { // packed QKV + ONNX_NAMESPACE::TensorShapeProto output_shape; + *output_shape.add_dim() = query_dims[0]; + *output_shape.add_dim() = query_dims[1]; + *output_shape.add_dim() = query_dims[2] * query_dims[4]; + updateOutputShape(ctx, 0, output_shape); + return; } if (hasInputShape(ctx, 2)) { @@ -154,11 +176,12 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c *output_shape.add_dim() = query_dims[1]; *output_shape.add_dim() = value_dims[2]; updateOutputShape(ctx, 0, output_shape); + return; } if (hasInputShape(ctx, 1)) { auto& key_shape = getInputShape(ctx, 1); - if (key_shape.dim().size() == 5) { + if (key_shape.dim().size() == 5) { // packed KV ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput(ctx); } } @@ -292,12 +315,13 @@ ONNX_MS_OPERATOR_SET_SCHEMA( AttributeProto::FLOAT, OPTIONAL_VALUE) .Input(0, "query", - "Query with shape (batch_size, sequence_length, hidden_size)", + "Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape (batch_size, kv_sequence_length, num_heads, 3, head_size)", "T") .Input(1, "key", "Key with shape (batch_size, kv_sequence_length, hidden_size), or packed KV with shape (batch_size, 
kv_sequence_length, num_heads, 2, head_size)", - "T") + "T", + OpSchema::Optional) .Input(2, "value", "Value with shape (batch_size, kv_sequence_length, v_hidden_size)", diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 53106715cd36c..bf7f7d86690ac 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -2025,28 +2025,43 @@ def _infer_BiasGelu(self, node): self._propagate_shape_and_type(node) def _infer_MultiHeadAttention(self, node): - # Input 0 (query) has shape (batch_size, sequence_length, hidden_size) - # Without packed KV: + # Output 0 has shape (batch_size, sequence_length, v_hidden_size) + # Q, K and V without packing: + # Input 0 (query) has shape (batch_size, sequence_length, hidden_size) # Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size) # Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size) - # With packed KV: - # Input 1 (key) has shape (batch_size, kv_sequence_length, num_heads, 2, head_size) - # Input 2 (value) is nullptr - # Output 0 has shape (batch_size, sequence_length, v_hidden_size) - query_shape = self._get_shape(node, 0) - key_shape = self._get_shape(node, 1) - if query_shape is not None and len(query_shape) == 3: + # Packed KV: + # Input 0 (query) has shape (batch_size, sequence_length, hidden_size) + # Input 1 (batch_size, kv_sequence_length, num_heads, 2, head_size) + # Input 2 nullptr + # Packed QKV: + # Input 0 (batch_size, sequence_length, num_heads, 3, head_size) + # Input 1 nullptr + # Input 2 nullptr - # By default, hidden size is same for Q/K/V. Only need check v_hidden_size when value is provided. - output_shape = query_shape - if key_shape and len(key_shape) == 3: - value_shape = self._get_shape(node, 2) - if value_shape and len(value_shape) == 3: - output_shape[2] = value_shape[2] + query_shape = self._get_shape(node, 0) + if query_shape is not None: + if len(query_shape) == 3: + key_shape = self._get_shape(node, 1) + # By default, hidden size is same for Q/K/V. Only need check v_hidden_size when value is provided. 
+ output_shape = query_shape + if key_shape and len(key_shape) == 3: + value_shape = self._get_shape(node, 2) + if value_shape and len(value_shape) == 3: + output_shape[2] = value_shape[2] + + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) + elif len(query_shape) == 5: + if isinstance(query_shape[2], int) and isinstance(query_shape[4], int): + output_shape = [query_shape[0], query_shape[1], query_shape[2] * query_shape[4]] + else: + output_shape = [query_shape[0], query_shape[1], f"{query_shape[2]}*{query_shape[4]}"] - output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) def _infer_FastGelu(self, node): self._propagate_shape_and_type(node) diff --git a/onnxruntime/python/tools/transformers/fusion_attention_unet.py b/onnxruntime/python/tools/transformers/fusion_attention_unet.py index 361baaedc4a95..dcbc640923e35 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention_unet.py +++ b/onnxruntime/python/tools/transformers/fusion_attention_unet.py @@ -20,12 +20,19 @@ class FusionAttentionUnet(Fusion): """ def __init__( - self, model: OnnxModel, hidden_size: int, num_heads: int, is_cross_attention: bool, enable_packed_kv: bool + self, + model: OnnxModel, + hidden_size: int, + num_heads: int, + is_cross_attention: bool, + enable_packed_qkv: bool, + enable_packed_kv: bool, ): super().__init__(model, "MultiHeadAttention" if is_cross_attention else "Attention", ["LayerNormalization"]) self.hidden_size = hidden_size self.num_heads = num_heads self.is_cross_attention = is_cross_attention + self.enable_packed_qkv = enable_packed_qkv self.enable_packed_kv = enable_packed_kv # Flags to show warning only once @@ -147,10 +154,6 @@ def create_attention_node( return None qw_in_size = qw.shape[0] - kw_in_size = kw.shape[0] - vw_in_size = vw.shape[0] - - assert qw_in_size == kw_in_size and kw_in_size == vw_in_size if hidden_size > 0 and hidden_size != qw_in_size: raise ValueError( @@ -161,21 +164,70 @@ def create_attention_node( # All the matrices can have the same shape or q, k matrics can have the same shape with v being different # For 2d weights, the shapes would be [in_size, out_size]. 
# For 3d weights, shape would be [in_size, a, b] where a*b = out_size - qw_out_size = np.prod(qw.shape[1:]) + qw_out_size = int(np.prod(qw.shape[1:])) - qkv_weight = np.stack((qw, kw, vw), axis=1) - qkv_weight_dim = 3 * qw_out_size + if self.enable_packed_qkv: + attention_node_name = self.model.create_node_name("MultiHeadAttention") - attention_node_name = self.model.create_node_name("Attention") + c = qw_in_size + n = num_heads + h = qw_out_size // num_heads - weight = helper.make_tensor( - name=attention_node_name + "_qkv_weight", - data_type=TensorProto.FLOAT, - dims=[qw_in_size, qkv_weight_dim], - vals=qkv_weight.flatten().tolist(), - ) + # Concat and interleave weights so that the output of fused KV GEMM has [B, S_kv, N, 3, H] shape + qkv_weight = np.dstack([qw.reshape(c, n, h), kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape( + c, n * 3 * h + ) - self.model.add_initializer(weight, self.this_graph_name) + matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_QKV") + weight = helper.make_tensor( + name=matmul_node_name + "_weight", + data_type=TensorProto.FLOAT, + dims=[qkv_weight.shape[0], qkv_weight.shape[1]], + vals=qkv_weight.flatten().tolist(), + ) + + self.model.add_initializer(weight, self.this_graph_name) + + matmul_node = helper.make_node( + "MatMul", + inputs=[k_matmul.input[0], matmul_node_name + "_weight"], + outputs=[matmul_node_name + "_out"], + name=matmul_node_name, + ) + self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name + + shape_tensor = helper.make_tensor( + name=matmul_node_name + "_reshape_shape", + data_type=TensorProto.INT64, + dims=[5], + vals=[0, 0, n, 3, h], + ) + self.model.add_initializer(shape_tensor, self.this_graph_name) + + reshape_node = helper.make_node( + "Reshape", + inputs=[matmul_node_name + "_out", matmul_node_name + "_reshape_shape"], + outputs=[attention_node_name + "_input"], + name=matmul_node_name + "_reshape", + ) + self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name + self.nodes_to_add.extend([matmul_node, reshape_node]) + self.nodes_to_remove.extend([q_matmul, k_matmul, v_matmul]) + + else: + qkv_weight = np.stack((qw, kw, vw), axis=1) + qkv_weight_dim = 3 * qw_out_size + + attention_node_name = self.model.create_node_name("Attention") + + weight = helper.make_tensor( + name=attention_node_name + "_qkv_weight", + data_type=TensorProto.FLOAT, + dims=[qw_in_size, qkv_weight_dim], + vals=qkv_weight.flatten().tolist(), + ) + + self.model.add_initializer(weight, self.this_graph_name) else: # cross attention attention_node_name = self.model.create_node_name("MultiHeadAttention") if self.enable_packed_kv: @@ -247,11 +299,14 @@ def create_attention_node( self.model.add_initializer(bias, self.this_graph_name) if is_self_attention: - attention_inputs = [ - input, - attention_node_name + "_qkv_weight", - attention_node_name + "_qkv_bias", - ] + if not self.enable_packed_qkv: + attention_inputs = [ + input, + attention_node_name + "_qkv_weight", + attention_node_name + "_qkv_bias", + ] + else: + attention_inputs = [attention_node_name + "_input"] else: if not self.enable_packed_kv: attention_inputs = [ @@ -267,7 +322,7 @@ def create_attention_node( ] attention_node = helper.make_node( - "Attention" if is_self_attention else "MultiHeadAttention", + "Attention" if (is_self_attention and not self.enable_packed_qkv) else "MultiHeadAttention", inputs=attention_inputs, outputs=[output], name=attention_node_name, @@ -277,9 +332,13 @@ def create_attention_node( counter_name = ( "Attention (self 
attention)" - if is_self_attention + if is_self_attention and not self.enable_packed_qkv else "MultiHeadAttention ({})".format( - "cross attention with packed kv" if self.enable_packed_kv else "cross attention" + "self attention with packed qkv" + if self.enable_packed_qkv + else "cross attention with packed kv" + if self.enable_packed_kv + else "cross attention" ) ) self.increase_counter(counter_name) diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index 7629a6886ab94..c0847a6a0c79a 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -51,10 +51,12 @@ def __init__(self, model_type): ) # options for stable diffusion - self.enable_group_norm = model_type in ["unet", "vae"] - self.enable_bias_splitgelu = model_type in ["unet"] - self.enable_packed_kv = model_type in ["unet"] - self.enable_bias_add = model_type in ["unet"] + if model_type in ["unet", "vae", "clip"]: + self.enable_group_norm = True + self.enable_bias_splitgelu = True + self.enable_packed_qkv = True + self.enable_packed_kv = True + self.enable_bias_add = True def use_raw_attention_mask(self, use_raw_mask=True): if use_raw_mask: @@ -96,10 +98,15 @@ def parse(args): options.use_raw_attention_mask(True) if args.no_attention_mask: options.disable_attention_mask() - if args.disable_group_norm: - options.enable_group_norm = False - if args.disable_packed_kv: - options.enable_packed_kv = False + + if args.model_type in ["unet", "vae", "clip"]: + if args.disable_group_norm: + options.enable_group_norm = False + if args.disable_packed_kv: + options.enable_packed_kv = False + if args.disable_packed_qkv: + options.enable_packed_qkv = False + return options @staticmethod @@ -222,7 +229,7 @@ def add_arguments(parser: ArgumentParser): "--disable_group_norm", required=False, action="store_true", - help="not fuse GroupNorm. Only works for model_type=unet", + help="not fuse GroupNorm. Only works for model_type=unet or vae", ) parser.set_defaults(disable_group_norm=False) @@ -230,6 +237,30 @@ def add_arguments(parser: ArgumentParser): "--disable_packed_kv", required=False, action="store_true", - help="not use packed kv in cross attention. Only works for model_type=unet", + help="not use packed kv for cross attention in MultiHeadAttention. Only works for model_type=unet", ) parser.set_defaults(disable_packed_kv=False) + + parser.add_argument( + "--disable_packed_qkv", + required=False, + action="store_true", + help="not use packed qkv for self attention in MultiHeadAttention. Only works for model_type=unet", + ) + parser.set_defaults(disable_packed_qkv=False) + + parser.add_argument( + "--disable_bias_add", + required=False, + action="store_true", + help="not fuse BiasAdd. Only works for model_type=unet", + ) + parser.set_defaults(disable_bias_add=False) + + parser.add_argument( + "--disable_bias_splitgelu", + required=False, + action="store_true", + help="not fuse BiasSplitGelu. 
Only works for model_type=unet", + ) + parser.set_defaults(disable_bias_splitgelu=False) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index 31b2a22c2f615..40b0d5d858cb2 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -130,12 +130,13 @@ def optimize_sd_pipeline( # Right now, onnxruntime does not save >2GB model so we use script to optimize unet instead. logger.info(f"Optimize {onnx_model_path}...") - # There are some optimizations that are not avaiable in v1.14 or older version - has_all_optimizations = version.parse(onnxruntime.__version__) > version.parse("1.14.0") - fusion_options = FusionOptions(model_type) - fusion_options.enable_packed_kv = float16 - fusion_options.enable_bias_add = has_all_optimizations + if model_type in ["unet"]: + # There are some optimizations that are not available in v1.14 or older version + has_all_optimizations = version.parse(onnxruntime.__version__) > version.parse("1.14.0") + fusion_options.enable_packed_kv = float16 + fusion_options.enable_packed_qkv = float16 and has_all_optimizations + fusion_options.enable_bias_add = has_all_optimizations m = optimize_model( str(onnx_model_path), diff --git a/onnxruntime/python/tools/transformers/onnx_model_unet.py b/onnxruntime/python/tools/transformers/onnx_model_unet.py index b460d7d8a7f8f..4873489770757 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_unet.py +++ b/onnxruntime/python/tools/transformers/onnx_model_unet.py @@ -91,12 +91,17 @@ def merge_adjacent_transpose(self): def fuse_attention(self, options: Optional[FusionOptions] = None): # Self Attention - self_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, False, False) + enable_packed_qkv = (options is None) or options.enable_packed_qkv + self_attention_fusion = FusionAttentionUnet( + self, self.hidden_size, self.num_heads, False, enable_packed_qkv, False + ) self_attention_fusion.apply() # Cross Attention enable_packed_kv = (options is None) or options.enable_packed_kv - cross_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, True, enable_packed_kv) + cross_attention_fusion = FusionAttentionUnet( + self, self.hidden_size, self.num_heads, True, False, enable_packed_kv + ) cross_attention_fusion.apply() def fuse_bias_add(self): @@ -152,7 +157,8 @@ def optimize(self, options: Optional[FusionOptions] = None): self.merge_adjacent_transpose() - self.fuse_bias_add() + if options is not None and options.enable_bias_add: + self.fuse_bias_add() self.postprocess() diff --git a/onnxruntime/test/contrib_ops/attention_op_test_helper.cc b/onnxruntime/test/contrib_ops/attention_op_test_helper.cc index 5c4c14ce137ed..802df2b1d4aa4 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test_helper.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test_helper.cc @@ -2164,8 +2164,7 @@ void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(AttentionTestData& data.skip_kernel_types = { AttentionKernelType::AttentionKernel_TrtFusedCrossAttention, AttentionKernelType::AttentionKernel_TrtFusedAttention, - AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention - }; + AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention}; { data.query_data = { @@ -2288,6 +2287,413 @@ void 
GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(AttentionTestData& data.fp16_output_data = data.fp32_output_data; } } + +void GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(AttentionTestData& data) { + data.hidden_size = 32; + data.v_hidden_size = 32; + data.num_heads = 1; + data.batch_size = 2; + data.sequence_length = 2; + data.kv_sequence_length = 3; + data.mask_type = AttentionMaskType::MASK_NONE; + // Packed KV format is only supported by TRT fused cross attention or memory efficient attention right now. + data.skip_kernel_types = { + AttentionKernelType::AttentionKernel_Unfused, + AttentionKernelType::AttentionKernel_TrtFusedAttention}; + { + data.query_data = { + -0.35420692f, 1.31206024f, -2.80201197f, 2.42258096f, -0.86031514f, -1.44535458f, -0.10832444f, -2.00132895f, + 1.62475216f, 0.10978927f, 1.84596729f, 0.48908550f, 1.44369888f, 0.87542874f, -1.16434252f, 0.52133209f, + 1.54848897f, -2.21174526f, -0.28574878f, 0.70815033f, 1.18327498f, 3.14097571f, -0.25795099f, 1.89341247f, + -0.11603792f, 0.38110194f, 0.40873206f, -1.14149106f, 0.79770875f, -0.98069525f, -1.53588808f, 0.50821728f, + -2.21641898f, 0.55090773f, 0.80901796f, -0.56089771f, 0.03574468f, -1.27940118f, -0.02213959f, -0.80698186f, + -0.82701880f, 1.72937381f, 1.56083691f, -0.30311784f, -0.25183848f, 0.24280515f, 0.29569417f, -0.31162494f, + 0.48996922f, 0.22795241f, 2.07125854f, 1.45823467f, 3.03750706f, 1.53734803f, 0.48668906f, -1.63703632f, + -0.14114749f, 1.85963213f, 1.20729232f, -0.28972962f, -0.80783498f, -1.16619551f, -0.60004634f, 0.02498829f, + + 3.50846076f, -2.50027657f, -2.59866142f, 1.58495271f, 2.21110034f, -2.74877763f, -1.00267041f, 0.62646407f, + 2.50227380f, -0.27291518f, -0.33037442f, 0.75840306f, 0.45437157f, -0.79876304f, 0.83509272f, 2.53716302f, + 0.01348384f, -2.16307616f, 2.01661849f, 2.10746121f, -1.70485222f, 1.35548759f, 1.39401650f, -0.99451691f, + -4.13484812f, 0.56262714f, -0.92725742f, -0.16389316f, -1.31260049f, 2.32357836f, -3.05251694f, -1.12570131f, + 1.87849474f, -1.80381167f, 0.52235699f, 2.38887334f, -1.58878529f, 0.69571090f, 1.65044296f, -0.27024290f, + 3.59580970f, -1.97888982f, 1.17034674f, 0.26716161f, -1.16770899f, 0.74609619f, 0.78886843f, 0.15717520f, + -0.93303132f, -0.84753871f, -4.32799959f, -1.94716609f, -1.16980326f, 1.62631667f, 2.41053247f, 3.78186774f, + 0.26432252f, -0.40396988f, 2.04414082f, 0.65150046f, 0.47777444f, -2.57569051f, 0.99004912f, 2.47947693f}; + } + { + data.key_data = { + 2.66554713f, 0.04116637f, -1.14599442f, -1.99071956f, 0.42523879f, 0.94212061f, 1.15597987f, -1.76809072f, + -1.89803648f, -0.74707657f, -0.71960962f, 0.67453432f, 3.31946969f, 1.06201041f, 2.29824829f, 0.23788756f, + 1.69329333f, 0.06745748f, -1.34720469f, 1.81031406f, -0.33143526f, -2.46566057f, -0.32179555f, 1.69001770f, + -0.39678648f, -0.91400242f, 1.56746745f, 0.36029303f, -1.01637018f, -1.84069777f, 0.15860040f, 1.35965717f, + 0.16654867f, 3.63810396f, 2.03763342f, 0.64186901f, -1.02682137f, 2.18480039f, -2.17365599f, -0.56225222f, + -2.48764873f, 1.94031644f, -1.13630998f, -2.51891637f, -1.29985571f, 0.23808026f, 2.95523596f, 1.06378591f, + -0.20339361f, -0.56349581f, 1.46587682f, 4.12142849f, 0.78908098f, -0.24000889f, -1.15510166f, 0.42653239f, + -1.98345447f, 1.06918168f, 2.98073006f, -2.94872737f, -0.67443597f, -0.96227646f, -1.94805872f, -0.96003568f, + 1.06492281f, 0.32333452f, -0.52869648f, -1.25258100f, 0.75479198f, -1.04409528f, -1.81722605f, 0.99018478f, + 1.83352923f, 1.02711058f, 0.31064227f, 2.44383168f, -1.80332434f, 1.57207584f, 
-0.41058558f, 0.20494992f, + -0.78399467f, -0.35703743f, -0.67568171f, -1.30091023f, -0.17390330f, 0.22340816f, -0.44613233f, 1.23870432f, + -0.16092014f, -1.22258115f, 0.60575533f, -0.17969827f, 1.87851882f, 1.13991237f, -0.81591004f, -1.68899822f, + + -1.72543812f, 0.63848293f, 0.87042624f, 0.39726460f, 0.62647510f, 1.73326159f, -0.55110240f, -1.26900804f, + 1.94843686f, -1.73077893f, 2.53475809f, 2.79892564f, -1.91852188f, 0.99826050f, -3.04680610f, 1.38900220f, + -1.17920876f, -2.07508063f, -0.34274688f, -0.24780962f, 1.75715542f, 1.27657294f, -1.15560341f, -2.69310951f, + 0.93523502f, 0.58213681f, -2.57009196f, 2.56376076f, 0.06911665f, 1.73962176f, 0.43307841f, -1.18240118f, + 1.52338290f, 1.02856898f, 0.40946901f, 1.57649779f, 1.22447217f, 0.85961932f, 0.30765539f, -2.66427660f, + -1.55998194f, -0.31161505f, -1.63090813f, -1.62476087f, 1.28381526f, -0.77024549f, 1.46711981f, -0.71657622f, + -0.51606011f, 0.87953311f, 0.26169056f, 1.03068113f, 0.41064253f, -1.56344402f, -1.53443003f, -0.03009570f, + -0.02123317f, -1.74375248f, 1.60988081f, 1.74488568f, 0.59155780f, -0.62032932f, 0.03105794f, 4.54175377f, + -2.08403850f, 0.22352570f, -0.17924348f, 0.65815634f, -0.59089363f, -1.66189861f, 0.75618476f, 0.03879535f, + 1.50222909f, -0.29873836f, -1.76075482f, -2.97067928f, 0.28112072f, 0.72105575f, 0.06761266f, -1.61681306f, + -0.80693424f, 2.40102959f, -2.91352296f, -1.21352315f, -1.62430143f, -1.60909438f, 0.53140688f, -0.28235722f, + 0.63271880f, 1.33791542f, -1.37593675f, -1.60502291f, 1.27470064f, -0.96280038f, 0.79614848f, 0.31894624f}; + } + + { + data.value_data = { + 0.67641568f, -1.44437671f, 0.57255918f, 0.11087912f, 0.73787844f, -1.36586773f, + 1.45507979f, -3.70331645f, -0.85970032f, -2.14545083f, 2.11107802f, -0.16663373f, + 1.47033095f, 0.49124131f, 1.99316287f, -2.68613410f, 0.23831765f, 0.90826637f, + 0.72628385f, 1.29567933f, -0.07918698f, 0.13999116f, 1.22531521f, 0.06399018f, + -2.24613571f, -1.08365369f, -0.68457615f, -0.25960952f, -0.88386559f, 0.46391147f, + 1.24469304f, 1.13121903f, + -0.21484625f, 3.49263334f, -1.35283577f, 0.38428289f, -4.29686069f, -4.34778786f, + -0.49574745f, -0.08637778f, -0.50855160f, -1.12334609f, -1.44851387f, 3.36797357f, + -0.91776383f, -0.98647243f, 1.45408130f, 0.29062888f, 0.24470398f, -1.28129590f, + 0.47530234f, 2.19562674f, 0.62674099f, -2.56222868f, -1.42671025f, 1.51795268f, + -1.92045701f, 1.20271325f, 2.53190184f, -0.37211552f, 0.92569226f, -1.11019444f, + 1.15402830f, -1.98479640f, + -0.49658760f, 1.62168694f, -1.71412969f, -1.26646388f, -1.37257946f, 1.53828073f, + -0.35583261f, 0.03810386f, 0.43514529f, 0.97525519f, -2.22109556f, 1.17547810f, + -0.28825673f, 0.91509271f, -1.19243717f, 1.09280133f, -0.51078367f, 0.63577116f, + -0.62186599f, -2.80234575f, -1.58007598f, 1.06965756f, -0.89327252f, -0.84735525f, + -0.46283475f, 0.77867299f, -0.07434830f, 1.44711912f, 1.07089376f, 0.78913736f, + 0.59053934f, -0.32160193f, + + 0.51273453f, 1.12628150f, 1.96404183f, 0.26380035f, 3.41526699f, 1.08249199f, + -1.70347631f, 0.42854923f, -1.98269284f, 1.97382474f, -0.12164606f, -1.41219604f, + 0.01819625f, 0.73082930f, -2.60845804f, 1.47046185f, 0.26324001f, 1.54259276f, + -1.18744254f, -1.77539694f, 1.76547086f, -1.57072937f, -1.83995926f, -0.05529352f, + 1.83544660f, 0.69575423f, -0.03345531f, -1.69629955f, 0.04713173f, 1.39800107f, + 0.24362923f, 0.12432972f, + -2.92895460f, -0.46070760f, 0.20383459f, 1.93618548f, -1.08026588f, 1.08253515f, + -0.48318014f, -2.34334373f, -2.69622159f, 0.00661799f, -1.10738027f, 0.03181311f, + 
0.32897863f, 1.89451993f, -0.01152946f, 0.17766151f, 2.46450090f, -0.64409554f, + 2.56058550f, 1.29339278f, 2.72114944f, 0.87801707f, -1.58970404f, 2.88365316f, + 0.46464550f, -1.71912467f, -1.90960062f, -3.13572145f, 0.19871379f, -0.28741950f, + -0.38167781f, -2.30705547f, + 0.64399612f, 0.32866889f, -3.49091625f, -0.02294427f, 1.60225844f, 1.83659923f, + 1.55193460f, -0.06712314f, 0.76592684f, 0.83479869f, 0.49627584f, 0.75736403f, + 0.75179487f, -0.32156041f, 1.36537170f, 0.57024354f, 0.36152276f, 0.93625057f, + -1.69728792f, -0.28833422f, 0.43304375f, 1.62640548f, -0.00187188f, 0.80429250f, + -0.77993584f, 1.37333393f, -1.16019452f, -0.91983509f, 0.20466281f, 1.09339333f, + -0.99191529f, 3.42685890f}; + } + + { + data.bias_data = { + // 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + // 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + // 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + // 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + // 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f + }; + } + + { + data.kv_data = { + 2.66554713f, 0.04116637f, -1.14599442f, -1.99071956f, 0.42523879f, 0.94212061f, + 1.15597987f, -1.76809072f, -1.89803648f, -0.74707657f, -0.71960962f, 0.67453432f, + 3.31946969f, 1.06201041f, 2.29824829f, 0.23788756f, 1.69329333f, 0.06745748f, + -1.34720469f, 1.81031406f, -0.33143526f, -2.46566057f, -0.32179555f, 1.69001770f, + -0.39678648f, -0.91400242f, 1.56746745f, 0.36029303f, -1.01637018f, -1.84069777f, + 0.15860040f, 1.35965717f, + 0.67641568f, -1.44437671f, 0.57255918f, 0.11087912f, 0.73787844f, -1.36586773f, + 1.45507979f, -3.70331645f, -0.85970032f, -2.14545083f, 2.11107802f, -0.16663373f, + 1.47033095f, 0.49124131f, 1.99316287f, -2.68613410f, 0.23831765f, 0.90826637f, + 0.72628385f, 1.29567933f, -0.07918698f, 0.13999116f, 1.22531521f, 0.06399018f, + -2.24613571f, -1.08365369f, -0.68457615f, -0.25960952f, -0.88386559f, 0.46391147f, + 1.24469304f, 1.13121903f, + + 0.16654867f, 3.63810396f, 2.03763342f, 0.64186901f, -1.02682137f, 2.18480039f, + -2.17365599f, -0.56225222f, -2.48764873f, 1.94031644f, -1.13630998f, -2.51891637f, + -1.29985571f, 0.23808026f, 2.95523596f, 1.06378591f, -0.20339361f, -0.56349581f, + 1.46587682f, 4.12142849f, 0.78908098f, -0.24000889f, -1.15510166f, 0.42653239f, + -1.98345447f, 1.06918168f, 2.98073006f, -2.94872737f, -0.67443597f, -0.96227646f, + -1.94805872f, -0.96003568f, + -0.21484625f, 3.49263334f, -1.35283577f, 0.38428289f, -4.29686069f, -4.34778786f, + -0.49574745f, -0.08637778f, -0.50855160f, -1.12334609f, -1.44851387f, 3.36797357f, + -0.91776383f, -0.98647243f, 1.45408130f, 0.29062888f, 0.24470398f, -1.28129590f, + 0.47530234f, 2.19562674f, 0.62674099f, -2.56222868f, -1.42671025f, 1.51795268f, + -1.92045701f, 1.20271325f, 2.53190184f, -0.37211552f, 0.92569226f, -1.11019444f, + 1.15402830f, -1.98479640f, + + 1.06492281f, 0.32333452f, -0.52869648f, -1.25258100f, 0.75479198f, -1.04409528f, + -1.81722605f, 0.99018478f, 1.83352923f, 1.02711058f, 0.31064227f, 2.44383168f, + -1.80332434f, 1.57207584f, -0.41058558f, 0.20494992f, -0.78399467f, -0.35703743f, + -0.67568171f, -1.30091023f, -0.17390330f, 0.22340816f, -0.44613233f, 1.23870432f, + -0.16092014f, -1.22258115f, 0.60575533f, -0.17969827f, 1.87851882f, 1.13991237f, + -0.81591004f, -1.68899822f, + -0.49658760f, 1.62168694f, -1.71412969f, 
-1.26646388f, -1.37257946f, 1.53828073f, + -0.35583261f, 0.03810386f, 0.43514529f, 0.97525519f, -2.22109556f, 1.17547810f, + -0.28825673f, 0.91509271f, -1.19243717f, 1.09280133f, -0.51078367f, 0.63577116f, + -0.62186599f, -2.80234575f, -1.58007598f, 1.06965756f, -0.89327252f, -0.84735525f, + -0.46283475f, 0.77867299f, -0.07434830f, 1.44711912f, 1.07089376f, 0.78913736f, + 0.59053934f, -0.32160193f, + + -1.72543812f, 0.63848293f, 0.87042624f, 0.39726460f, 0.62647510f, 1.73326159f, + -0.55110240f, -1.26900804f, 1.94843686f, -1.73077893f, 2.53475809f, 2.79892564f, + -1.91852188f, 0.99826050f, -3.04680610f, 1.38900220f, -1.17920876f, -2.07508063f, + -0.34274688f, -0.24780962f, 1.75715542f, 1.27657294f, -1.15560341f, -2.69310951f, + 0.93523502f, 0.58213681f, -2.57009196f, 2.56376076f, 0.06911665f, 1.73962176f, + 0.43307841f, -1.18240118f, + 0.51273453f, 1.12628150f, 1.96404183f, 0.26380035f, 3.41526699f, 1.08249199f, + -1.70347631f, 0.42854923f, -1.98269284f, 1.97382474f, -0.12164606f, -1.41219604f, + 0.01819625f, 0.73082930f, -2.60845804f, 1.47046185f, 0.26324001f, 1.54259276f, + -1.18744254f, -1.77539694f, 1.76547086f, -1.57072937f, -1.83995926f, -0.05529352f, + 1.83544660f, 0.69575423f, -0.03345531f, -1.69629955f, 0.04713173f, 1.39800107f, + 0.24362923f, 0.12432972f, + + 1.52338290f, 1.02856898f, 0.40946901f, 1.57649779f, 1.22447217f, 0.85961932f, + 0.30765539f, -2.66427660f, -1.55998194f, -0.31161505f, -1.63090813f, -1.62476087f, + 1.28381526f, -0.77024549f, 1.46711981f, -0.71657622f, -0.51606011f, 0.87953311f, + 0.26169056f, 1.03068113f, 0.41064253f, -1.56344402f, -1.53443003f, -0.03009570f, + -0.02123317f, -1.74375248f, 1.60988081f, 1.74488568f, 0.59155780f, -0.62032932f, + 0.03105794f, 4.54175377f, + -2.92895460f, -0.46070760f, 0.20383459f, 1.93618548f, -1.08026588f, 1.08253515f, + -0.48318014f, -2.34334373f, -2.69622159f, 0.00661799f, -1.10738027f, 0.03181311f, + 0.32897863f, 1.89451993f, -0.01152946f, 0.17766151f, 2.46450090f, -0.64409554f, + 2.56058550f, 1.29339278f, 2.72114944f, 0.87801707f, -1.58970404f, 2.88365316f, + 0.46464550f, -1.71912467f, -1.90960062f, -3.13572145f, 0.19871379f, -0.28741950f, + -0.38167781f, -2.30705547f, + + -2.08403850f, 0.22352570f, -0.17924348f, 0.65815634f, -0.59089363f, -1.66189861f, + 0.75618476f, 0.03879535f, 1.50222909f, -0.29873836f, -1.76075482f, -2.97067928f, + 0.28112072f, 0.72105575f, 0.06761266f, -1.61681306f, -0.80693424f, 2.40102959f, + -2.91352296f, -1.21352315f, -1.62430143f, -1.60909438f, 0.53140688f, -0.28235722f, + 0.63271880f, 1.33791542f, -1.37593675f, -1.60502291f, 1.27470064f, -0.96280038f, + 0.79614848f, 0.31894624f, + 0.64399612f, 0.32866889f, -3.49091625f, -0.02294427f, 1.60225844f, 1.83659923f, + 1.55193460f, -0.06712314f, 0.76592684f, 0.83479869f, 0.49627584f, 0.75736403f, + 0.75179487f, -0.32156041f, 1.36537170f, 0.57024354f, 0.36152276f, 0.93625057f, + -1.69728792f, -0.28833422f, 0.43304375f, 1.62640548f, -0.00187188f, 0.80429250f, + -0.77993584f, 1.37333393f, -1.16019452f, -0.91983509f, 0.20466281f, 1.09339333f, + -0.99191529f, 3.42685890f}; + } + + { + // Do not test fp32 + data.fp32_output_data = {}; + } + + { + data.fp16_output_data = { + -0.18665725f, 1.53655565f, -1.16219902f, -0.53553712f, -1.76899862f, -0.67172408f, + -0.03719823f, -0.73519617f, -0.08289805f, -0.22439885f, -1.15095568f, 1.52012229f, + -0.11608444f, 0.30267856f, 0.17237782f, 0.12366229f, -0.15282108f, 0.15652999f, + -0.05062571f, -0.60356319f, -0.67014134f, -0.12373877f, -0.62331146f, -0.00974876f, + -1.22021353f, 0.52888882f, 0.52984023f, 0.60431194f, 
0.64458221f, 0.19681633f, + 0.87637067f, -0.49721599f, + + -0.21368809f, 3.48006248f, -1.34989023f, 0.38081154f, -4.28221798f, -4.33166409f, + -0.49185523f, -0.09290559f, -0.50751436f, -1.12148952f, -1.44325578f, 3.35744357f, + -0.91217732f, -0.98030341f, 1.45034039f, 0.28651160f, 0.24333692f, -1.27377033f, + 0.47380278f, 2.18498206f, 0.62146491f, -2.55067039f, -1.42080343f, 1.51099622f, + -1.91845036f, 1.19768524f, 2.52122355f, -0.36864227f, 0.92257524f, -1.10384953f, + 1.15318692f, -1.97599709f, + + 0.25325853f, 0.99478984f, 1.75511229f, 0.38680723f, 3.04894686f, 1.09290445f, + -1.56589723f, 0.21126932f, -1.99892235f, 1.80875492f, -0.18795207f, -1.27262163f, + 0.05191655f, 0.80464834f, -2.35645056f, 1.35988820f, 0.43171301f, 1.36821294f, + -0.90993541f, -1.52189243f, 1.81963241f, -1.34069264f, -1.79558825f, 0.17969209f, + 1.69527614f, 0.52177316f, -0.19144230f, -1.79486036f, 0.06081408f, 1.26584184f, + 0.17910211f, -0.01467115f, + + 0.15915775f, 0.23589413f, -2.89726520f, 0.24662971f, 1.27141249f, 1.72167253f, + 1.22059381f, -0.36594886f, 0.25064397f, 0.74270636f, 0.26896626f, 0.62173104f, + 0.68196213f, -0.00399938f, 1.11046481f, 0.53283978f, 0.64384484f, 0.73332942f, + -1.11337614f, -0.10050645f, 0.76519096f, 1.46986043f, -0.24821334f, 1.07021677f, + -0.56646812f, 0.94391233f, -1.24186087f, -1.23258281f, 0.20112629f, 0.91218621f, + -0.88806105f, 2.59514260f}; + } +} + +void GetSelfAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedQKV(AttentionTestData& data) { + data.hidden_size = 32; + data.v_hidden_size = 32; + data.num_heads = 1; + data.batch_size = 2; + data.sequence_length = 2; + data.kv_sequence_length = 2; + data.mask_type = AttentionMaskType::MASK_NONE; + // Packed QKV format is only supported by TRT fused attention or memory efficient attention right now. 
+ data.skip_kernel_types = { + AttentionKernelType::AttentionKernel_Unfused, + AttentionKernelType::AttentionKernel_TrtFusedCrossAttention}; + + { + data.query_data = { + -0.35420692f, 1.31206024f, -2.80201197f, 2.42258096f, -0.86031514f, -1.44535458f, -0.10832444f, -2.00132895f, + 1.62475216f, 0.10978927f, 1.84596729f, 0.48908550f, 1.44369888f, 0.87542874f, -1.16434252f, 0.52133209f, + 1.54848897f, -2.21174526f, -0.28574878f, 0.70815033f, 1.18327498f, 3.14097571f, -0.25795099f, 1.89341247f, + -0.11603792f, 0.38110194f, 0.40873206f, -1.14149106f, 0.79770875f, -0.98069525f, -1.53588808f, 0.50821728f, + -2.21641898f, 0.55090773f, 0.80901796f, -0.56089771f, 0.03574468f, -1.27940118f, -0.02213959f, -0.80698186f, + -0.82701880f, 1.72937381f, 1.56083691f, -0.30311784f, -0.25183848f, 0.24280515f, 0.29569417f, -0.31162494f, + 0.48996922f, 0.22795241f, 2.07125854f, 1.45823467f, 3.03750706f, 1.53734803f, 0.48668906f, -1.63703632f, + -0.14114749f, 1.85963213f, 1.20729232f, -0.28972962f, -0.80783498f, -1.16619551f, -0.60004634f, 0.02498829f, + + 3.50846076f, -2.50027657f, -2.59866142f, 1.58495271f, 2.21110034f, -2.74877763f, -1.00267041f, 0.62646407f, + 2.50227380f, -0.27291518f, -0.33037442f, 0.75840306f, 0.45437157f, -0.79876304f, 0.83509272f, 2.53716302f, + 0.01348384f, -2.16307616f, 2.01661849f, 2.10746121f, -1.70485222f, 1.35548759f, 1.39401650f, -0.99451691f, + -4.13484812f, 0.56262714f, -0.92725742f, -0.16389316f, -1.31260049f, 2.32357836f, -3.05251694f, -1.12570131f, + 1.87849474f, -1.80381167f, 0.52235699f, 2.38887334f, -1.58878529f, 0.69571090f, 1.65044296f, -0.27024290f, + 3.59580970f, -1.97888982f, 1.17034674f, 0.26716161f, -1.16770899f, 0.74609619f, 0.78886843f, 0.15717520f, + -0.93303132f, -0.84753871f, -4.32799959f, -1.94716609f, -1.16980326f, 1.62631667f, 2.41053247f, 3.78186774f, + 0.26432252f, -0.40396988f, 2.04414082f, 0.65150046f, 0.47777444f, -2.57569051f, 0.99004912f, 2.47947693f}; + } + { + data.key_data = { + -0.04407793f, 1.29459429f, 1.05810797f, 1.92067695f, -0.65047157f, 0.99029726f, -1.69796586f, 1.15320420f, + -1.66444266f, 1.78305888f, 1.20582056f, 1.69975281f, 0.34572244f, -0.60833001f, 2.59864879f, -1.05330181f, + -1.16554165f, -0.03781542f, -1.13475525f, 0.71595150f, -0.91169560f, 1.26686060f, 1.60492957f, -0.53510487f, + -1.40180850f, 1.83253956f, 2.70238972f, -1.48750985f, 0.47105616f, -0.79477602f, -1.93152475f, 1.04042351f, + 1.21863425f, 1.20610654f, 0.69031805f, 2.60092020f, 1.43040228f, 0.60616529f, 0.47948456f, -1.15139377f, + 0.15641990f, -0.46933329f, 0.64774191f, 0.35970241f, -1.00424135f, 0.01247875f, 1.00281739f, -1.10514688f, + 0.30922988f, -0.82255656f, -1.23242986f, 0.90557313f, -0.38946581f, -0.21124774f, -2.37903309f, -1.42169905f, + -0.05935127f, 0.47488672f, -0.37083727f, 1.31585515f, -0.21577421f, -0.97746384f, -0.13380399f, 1.77390409f, + + -2.65206385f, 1.26134932f, -1.01682174f, 0.64366758f, 0.95474619f, 2.06720352f, 0.51750720f, -0.07041813f, + 0.53124994f, -3.26612782f, 1.37013340f, 0.13939659f, -0.57418114f, 0.80680281f, -3.40751696f, -0.15847699f, + 0.97837782f, -0.09121911f, 1.18452120f, 0.52711177f, -1.86135840f, -0.11258313f, 0.85863215f, -2.60261130f, + 0.72695309f, 1.44092011f, 0.43785980f, -1.63415265f, -1.05772328f, 0.12997569f, 0.07356137f, -0.62493324f, + -0.43267637f, -1.80009198f, 0.92961007f, 2.05127883f, -2.85521173f, -0.21652693f, -0.89153922f, 0.15524670f, + -2.16850328f, 1.46751809f, 2.51663852f, -0.49499366f, 0.19886012f, 0.77093124f, -1.14819765f, 1.47111738f, + 2.42824388f, 1.56369960f, 1.69934130f, -0.42460468f, 
-2.25951004f, -1.18074155f, 3.51091242f, -0.30183151f, + -1.83517075f, -0.56233191f, 2.35561657f, -3.63751698f, -3.20001125f, -1.66120780f, 3.23455381f, -1.86251283f}; + } + + { + data.value_data = { + -0.89167893f, 0.02633595f, -0.84866279f, 1.43489110f, -2.91941142f, -0.20650116f, 1.85965109f, 0.45669034f, + 0.07678832f, 0.04492294f, 0.67326981f, 0.97103029f, 1.53470886f, -1.10242307f, 0.86584085f, -0.34770033f, + -1.24311507f, -1.80293822f, -1.01317739f, -0.71518499f, 0.77814674f, -0.59236068f, -2.00310278f, 3.13277125f, + -1.20754123f, 2.01506066f, 0.82650810f, 2.06084490f, -0.46267471f, 1.56365979f, 4.31514502f, -1.03099275f, + -1.85462761f, 2.10100341f, 1.79686451f, 0.23871201f, 1.23598254f, -0.31959364f, 0.50101948f, -0.09527110f, + -1.02331078f, 0.16319990f, -0.54766160f, 0.41597658f, -0.52141404f, 1.71663237f, -0.00776333f, -0.68160462f, + 1.76272714f, -0.04465733f, 0.28247434f, 1.69360149f, 0.14144623f, 0.75038731f, -1.33337545f, 2.23457718f, + -0.07649468f, 1.97064841f, -1.85374629f, -1.59334683f, 0.32698441f, -0.16024286f, 2.02828407f, -0.96440399f, + + -2.11639142f, -1.50897706f, 1.63863683f, 2.32786226f, 1.32746494f, 0.75751448f, 0.57184196f, 0.86446053f, + -0.62406683f, 0.78861046f, 0.01044065f, 3.51772785f, -1.33701336f, 0.27977663f, -0.35464612f, 0.74973166f, + 0.03352100f, 1.55007398f, 0.69849420f, -2.47725606f, -1.89363778f, -1.79874682f, -0.56210291f, -1.75556040f, + 1.07565808f, -0.18023658f, 1.63777173f, 1.28198206f, 2.19431949f, 0.67998970f, -0.52531999f, -1.89906740f, + 1.35158050f, -2.21481490f, -0.11812399f, -1.74263430f, -0.57895988f, -0.04181165f, 0.78120053f, -2.22377038f, + -0.53264999f, -2.03721714f, 0.21023634f, 2.55751204f, -1.04522800f, 0.85386503f, 0.41594937f, -2.98181081f, + 1.14034331f, -1.41539204f, 0.13379651f, 3.47018123f, 1.53924727f, 1.50004411f, 2.87318921f, 1.62624204f, + 0.64942807f, -4.54302311f, -1.50294220f, -1.75212634f, 0.27900690f, -3.05124855f, 3.30960631f, -0.07991691f}; + } + + { + data.bias_data = { + // 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + // 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + // 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + // 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + // 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f + }; + } + + { + data.qkv_data = { + -0.35420692f, 1.31206024f, -2.80201197f, 2.42258096f, -0.86031514f, -1.44535458f, -0.10832444f, -2.00132895f, + 1.62475216f, 0.10978927f, 1.84596729f, 0.48908550f, 1.44369888f, 0.87542874f, -1.16434252f, 0.52133209f, + 1.54848897f, -2.21174526f, -0.28574878f, 0.70815033f, 1.18327498f, 3.14097571f, -0.25795099f, 1.89341247f, + -0.11603792f, 0.38110194f, 0.40873206f, -1.14149106f, 0.79770875f, -0.98069525f, -1.53588808f, 0.50821728f, + -0.04407793f, 1.29459429f, 1.05810797f, 1.92067695f, -0.65047157f, 0.99029726f, -1.69796586f, 1.15320420f, + -1.66444266f, 1.78305888f, 1.20582056f, 1.69975281f, 0.34572244f, -0.60833001f, 2.59864879f, -1.05330181f, + -1.16554165f, -0.03781542f, -1.13475525f, 0.71595150f, -0.91169560f, 1.26686060f, 1.60492957f, -0.53510487f, + -1.40180850f, 1.83253956f, 2.70238972f, -1.48750985f, 0.47105616f, -0.79477602f, -1.93152475f, 1.04042351f, + -0.89167893f, 0.02633595f, -0.84866279f, 1.43489110f, -2.91941142f, -0.20650116f, 1.85965109f, 0.45669034f, + 0.07678832f, 0.04492294f, 
0.67326981f, 0.97103029f, 1.53470886f, -1.10242307f, 0.86584085f, -0.34770033f, + -1.24311507f, -1.80293822f, -1.01317739f, -0.71518499f, 0.77814674f, -0.59236068f, -2.00310278f, 3.13277125f, + -1.20754123f, 2.01506066f, 0.82650810f, 2.06084490f, -0.46267471f, 1.56365979f, 4.31514502f, -1.03099275f, + + -2.21641898f, 0.55090773f, 0.80901796f, -0.56089771f, 0.03574468f, -1.27940118f, -0.02213959f, -0.80698186f, + -0.82701880f, 1.72937381f, 1.56083691f, -0.30311784f, -0.25183848f, 0.24280515f, 0.29569417f, -0.31162494f, + 0.48996922f, 0.22795241f, 2.07125854f, 1.45823467f, 3.03750706f, 1.53734803f, 0.48668906f, -1.63703632f, + -0.14114749f, 1.85963213f, 1.20729232f, -0.28972962f, -0.80783498f, -1.16619551f, -0.60004634f, 0.02498829f, + 1.21863425f, 1.20610654f, 0.69031805f, 2.60092020f, 1.43040228f, 0.60616529f, 0.47948456f, -1.15139377f, + 0.15641990f, -0.46933329f, 0.64774191f, 0.35970241f, -1.00424135f, 0.01247875f, 1.00281739f, -1.10514688f, + 0.30922988f, -0.82255656f, -1.23242986f, 0.90557313f, -0.38946581f, -0.21124774f, -2.37903309f, -1.42169905f, + -0.05935127f, 0.47488672f, -0.37083727f, 1.31585515f, -0.21577421f, -0.97746384f, -0.13380399f, 1.77390409f, + -1.85462761f, 2.10100341f, 1.79686451f, 0.23871201f, 1.23598254f, -0.31959364f, 0.50101948f, -0.09527110f, + -1.02331078f, 0.16319990f, -0.54766160f, 0.41597658f, -0.52141404f, 1.71663237f, -0.00776333f, -0.68160462f, + 1.76272714f, -0.04465733f, 0.28247434f, 1.69360149f, 0.14144623f, 0.75038731f, -1.33337545f, 2.23457718f, + -0.07649468f, 1.97064841f, -1.85374629f, -1.59334683f, 0.32698441f, -0.16024286f, 2.02828407f, -0.96440399f, + + 3.50846076f, -2.50027657f, -2.59866142f, 1.58495271f, 2.21110034f, -2.74877763f, -1.00267041f, 0.62646407f, + 2.50227380f, -0.27291518f, -0.33037442f, 0.75840306f, 0.45437157f, -0.79876304f, 0.83509272f, 2.53716302f, + 0.01348384f, -2.16307616f, 2.01661849f, 2.10746121f, -1.70485222f, 1.35548759f, 1.39401650f, -0.99451691f, + -4.13484812f, 0.56262714f, -0.92725742f, -0.16389316f, -1.31260049f, 2.32357836f, -3.05251694f, -1.12570131f, + -2.65206385f, 1.26134932f, -1.01682174f, 0.64366758f, 0.95474619f, 2.06720352f, 0.51750720f, -0.07041813f, + 0.53124994f, -3.26612782f, 1.37013340f, 0.13939659f, -0.57418114f, 0.80680281f, -3.40751696f, -0.15847699f, + 0.97837782f, -0.09121911f, 1.18452120f, 0.52711177f, -1.86135840f, -0.11258313f, 0.85863215f, -2.60261130f, + 0.72695309f, 1.44092011f, 0.43785980f, -1.63415265f, -1.05772328f, 0.12997569f, 0.07356137f, -0.62493324f, + -2.11639142f, -1.50897706f, 1.63863683f, 2.32786226f, 1.32746494f, 0.75751448f, 0.57184196f, 0.86446053f, + -0.62406683f, 0.78861046f, 0.01044065f, 3.51772785f, -1.33701336f, 0.27977663f, -0.35464612f, 0.74973166f, + 0.03352100f, 1.55007398f, 0.69849420f, -2.47725606f, -1.89363778f, -1.79874682f, -0.56210291f, -1.75556040f, + 1.07565808f, -0.18023658f, 1.63777173f, 1.28198206f, 2.19431949f, 0.67998970f, -0.52531999f, -1.89906740f, + + 1.87849474f, -1.80381167f, 0.52235699f, 2.38887334f, -1.58878529f, 0.69571090f, 1.65044296f, -0.27024290f, + 3.59580970f, -1.97888982f, 1.17034674f, 0.26716161f, -1.16770899f, 0.74609619f, 0.78886843f, 0.15717520f, + -0.93303132f, -0.84753871f, -4.32799959f, -1.94716609f, -1.16980326f, 1.62631667f, 2.41053247f, 3.78186774f, + 0.26432252f, -0.40396988f, 2.04414082f, 0.65150046f, 0.47777444f, -2.57569051f, 0.99004912f, 2.47947693f, + -0.43267637f, -1.80009198f, 0.92961007f, 2.05127883f, -2.85521173f, -0.21652693f, -0.89153922f, 0.15524670f, + -2.16850328f, 1.46751809f, 2.51663852f, -0.49499366f, 0.19886012f, 
0.77093124f, -1.14819765f, 1.47111738f, + 2.42824388f, 1.56369960f, 1.69934130f, -0.42460468f, -2.25951004f, -1.18074155f, 3.51091242f, -0.30183151f, + -1.83517075f, -0.56233191f, 2.35561657f, -3.63751698f, -3.20001125f, -1.66120780f, 3.23455381f, -1.86251283f, + 1.35158050f, -2.21481490f, -0.11812399f, -1.74263430f, -0.57895988f, -0.04181165f, 0.78120053f, -2.22377038f, + -0.53264999f, -2.03721714f, 0.21023634f, 2.55751204f, -1.04522800f, 0.85386503f, 0.41594937f, -2.98181081f, + 1.14034331f, -1.41539204f, 0.13379651f, 3.47018123f, 1.53924727f, 1.50004411f, 2.87318921f, 1.62624204f, + 0.64942807f, -4.54302311f, -1.50294220f, -1.75212634f, 0.27900690f, -3.05124855f, 3.30960631f, -0.07991691f}; + } + + { + // Do not test fp32 + data.fp32_output_data = {}; + } + + { + data.fp16_output_data = { + -1.30247164f, 0.91138631f, 0.27991560f, 0.92460269f, -1.14672589f, -0.25474626f, + 1.28006065f, 0.22122431f, -0.39251250f, 0.09537974f, 0.15242209f, 0.73424512f, + 0.65756959f, 0.10018224f, 0.49316248f, -0.49014348f, 0.03917319f, -1.05285788f, + -0.46045411f, 0.31240013f, 0.50653118f, -0.01954618f, -1.71739793f, 2.74960279f, + -0.72503829f, 1.99611449f, -0.31688485f, 0.50197154f, -0.12580720f, 0.82824522f, + 3.33957314f, -1.00258613f, + + -0.95444643f, 0.16156809f, -0.67622054f, 1.35692120f, -2.64855242f, -0.21387282f, + 1.77109206f, 0.42071211f, 0.00508105f, 0.05263254f, 0.59368640f, 0.93485051f, + 1.40068567f, -0.91866994f, 0.80889714f, -0.36946508f, -1.04718661f, -1.68832910f, + -0.92872351f, -0.55817413f, 0.73664504f, -0.50483698f, -1.95944834f, 3.07422471f, + -1.13381684f, 2.01216578f, 0.65180230f, 1.82265544f, -0.41120273f, 1.45129156f, + 4.16608191f, -1.02665234f, + + 0.21158576f, -1.98279130f, 0.45935997f, -0.40457720f, 0.04772174f, 0.22094353f, + 0.71238005f, -1.20860445f, -0.56270063f, -1.10830855f, 0.14455934f, 2.87315512f, + -1.14114404f, 0.66515017f, 0.16263856f, -1.75517499f, 0.77650774f, -0.44058144f, + 0.31942442f, 1.51513433f, 0.41078627f, 0.41566271f, 1.74393702f, 0.51457298f, + 0.78953874f, -3.10888410f, -0.47052401f, -0.75475156f, 0.90861011f, -1.82471263f, + 2.04898596f, -0.67790967f, + + 1.17194295f, -2.17825341f, -0.02712549f, -1.53178656f, -0.48020893f, -0.00040733f, + 0.77035600f, -2.06380320f, -0.53738528f, -1.89084208f, 0.19988713f, 2.60725045f, + -1.06034219f, 0.82412785f, 0.37603328f, -2.78852081f, 1.08301103f, -1.26178384f, + 0.16304730f, 3.16210985f, 1.36142719f, 1.32916999f, 2.69524455f, 1.45106804f, + 0.67150640f, -4.31703520f, -1.34025633f, -1.59496248f, 0.37821823f, -2.85797405f, + 3.11096096f, -0.17414713f}; + } +} #endif void GetCrossAttentionData_HeadSize16_8(AttentionTestData& data) { diff --git a/onnxruntime/test/contrib_ops/attention_op_test_helper.h b/onnxruntime/test/contrib_ops/attention_op_test_helper.h index 664fbb50aa6d7..3a00b08b45318 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test_helper.h +++ b/onnxruntime/test/contrib_ops/attention_op_test_helper.h @@ -5,38 +5,45 @@ #include "contrib_ops/cpu/bert/attention_common.h" namespace onnxruntime { -using contrib::AttentionMaskType; using contrib::AttentionKernelType; +using contrib::AttentionMaskType; namespace test { -struct AttentionTestData{ - int hidden_size; - int v_hidden_size; - int num_heads; - int batch_size; - int sequence_length; - int kv_sequence_length; - AttentionMaskType mask_type; - std::vector key_padding_mask_data; - std::vector query_data; - std::vector key_data; - std::vector value_data; - std::vector bias_data; - std::vector fp32_output_data; - std::vector fp16_output_data; - 
std::vector skip_kernel_types; // skip some kernels if they do not supported this test case. +struct AttentionTestData { + int hidden_size; + int v_hidden_size; + int num_heads; + int batch_size; + int sequence_length; + int kv_sequence_length; + AttentionMaskType mask_type; + std::vector key_padding_mask_data; + std::vector query_data; + std::vector key_data; + std::vector value_data; + + std::vector kv_data; + std::vector qkv_data; + + std::vector bias_data; + std::vector fp32_output_data; + std::vector fp16_output_data; + std::vector skip_kernel_types; // skip some kernels if they do not supported this test case. }; // Disable some tests in Windows since prefast build might crash with large test data. #ifndef _MSC_VER // Return packed weights and bias for input projection. -void GetAttentionWeight(std::vector& weight_data, int elements = 64 * 3 * 64, int offset = 0, int step=1); -void GetAttentionBias(std::vector& bias_data, int elements = 3 * 64, int offset = 0, int step=1); +void GetAttentionWeight(std::vector& weight_data, int elements = 64 * 3 * 64, int offset = 0, int step = 1); +void GetAttentionBias(std::vector& bias_data, int elements = 3 * 64, int offset = 0, int step = 1); void GetCrossAttentionData_HeadSize40(AttentionTestData& data); void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(AttentionTestData& data, bool is_mask_1d); void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(AttentionTestData& data); + +void GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(AttentionTestData& data); +void GetSelfAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedQKV(AttentionTestData& data); #endif void GetCrossAttentionData_HeadSize16_8(AttentionTestData& data); diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc index eddf9ea4c7c4e..83b9f628655ab 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc @@ -16,7 +16,9 @@ static void RunMultiHeadAttentionTest( const std::vector& query_data, // query: [batch_size, sequence_length, hidden_size] const std::vector& key_data, // key: [batch_size, kv_sequence_length, hidden_size] const std::vector& value_data, // value: [batch_size, kv_sequence_length, v_hidden_size] - const std::vector& bias_data, // bias: [hidden_size + hidden_size + v_hidden_size] + const std::vector& kv_data, // packed_kv: [batch_size, kv_sequence_length, num_heads, 2, head_size] + const std::vector& qkv_data, // packed_qkv: [batch_size, sequence_length, num_heads, 3, head_size] + const std::vector& bias_data, // bias: [hidden_size + hidden_size + v_hidden_size] or empty const std::vector& key_padding_mask_data, // key_padding_mask: see below AttentionMaskType mask_type, // 1 for [batch_size], 2 for [batch_size, kv_sequence_length] const std::vector& output_data, // output: [batch_size, sequence_length, v_hidden_size] @@ -33,7 +35,7 @@ static void RunMultiHeadAttentionTest( { kv_sequence_length = (kv_sequence_length == 0 ? sequence_length : kv_sequence_length); - int min_cuda_architecture = use_float16 ? 530 : 0; + int min_cuda_architecture = use_float16 ? 
750 : 0; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture) && !disable_cuda; bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()) && !disable_rocm; bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()) && !use_float16 && !disable_cpu; @@ -49,6 +51,23 @@ static void RunMultiHeadAttentionTest( std::vector bias_dims = {hidden_size + hidden_size + v_hidden_size}; std::vector output_dims = {batch_size, sequence_length, v_hidden_size}; + std::vector query = (qkv_data.size() > 0 ? qkv_data : query_data); + std::vector key; + std::vector value; + if (qkv_data.size() == 0) { + if (kv_data.size() > 0) { + ORT_ENFORCE(hidden_size == v_hidden_size); + key = kv_data; + key_dims = {batch_size, kv_sequence_length, num_heads, 2, hidden_size / num_heads}; + } else { + key = key_data; + value = value_data; + } + } else { + ORT_ENFORCE(sequence_length == kv_sequence_length && hidden_size == v_hidden_size); + query_dims = {batch_size, sequence_length, num_heads, 3, hidden_size / num_heads}; + } + std::vector mask_dims_1 = {batch_size}; std::vector mask_dims_2 = {batch_size, kv_sequence_length}; std::vector& key_padding_mask_dims = (mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN) @@ -56,9 +75,19 @@ static void RunMultiHeadAttentionTest( : mask_dims_2; if (use_float16) { - tester.AddInput("query", query_dims, ToFloat16(query_data)); - tester.AddInput("key", key_dims, ToFloat16(key_data)); - tester.AddInput("value", value_dims, ToFloat16(value_data)); + tester.AddInput("query", query_dims, ToFloat16(query)); + + if (key.size()) { + tester.AddInput("key", key_dims, ToFloat16(key)); + } else { + tester.AddOptionalInputEdge(); + } + + if (value.size()) { + tester.AddInput("value", value_dims, ToFloat16(value)); + } else { + tester.AddOptionalInputEdge(); + } if (bias_data.size()) { tester.AddInput("bias", bias_dims, ToFloat16(bias_data)); @@ -76,9 +105,19 @@ static void RunMultiHeadAttentionTest( constexpr float abs_error = 0.05f; tester.AddOutput("output", output_dims, ToFloat16(output_data), /*sort*/ false, rel_error, abs_error); } else { - tester.AddInput("query", query_dims, query_data); - tester.AddInput("key", key_dims, key_data); - tester.AddInput("value", value_dims, value_data); + tester.AddInput("query", query_dims, query); + + if (key.size()) { + tester.AddInput("key", key_dims, key); + } else { + tester.AddOptionalInputEdge(); + } + + if (value.size()) { + tester.AddInput("value", value_dims, value); + } else { + tester.AddOptionalInputEdge(); + } if (bias_data.size()) { tester.AddInput("bias", bias_dims, bias_data); @@ -121,6 +160,8 @@ static void RunMultiHeadAttentionKernel( const std::vector& query_data, // query: [batch_size, sequence_length, hidden_size] const std::vector& key_data, // key: [batch_size, kv_sequence_length, hidden_size] const std::vector& value_data, // value: [batch_size, kv_sequence_length, v_hidden_size] + const std::vector& kv_data, // packed_kv: [batch_size, kv_sequence_length, num_heads, 2, head_size] + const std::vector& qkv_data, // packed_qkv: [batch_size, sequence_length, num_heads, 3, head_size] const std::vector& bias_data, // bias: [hidden_size + hidden_size + v_hidden_size] const std::vector& key_padding_mask_data, // key_padding_mask: see below AttentionMaskType mask_type, // 1 for [batch_size], 2 for [batch_size, kv_sequence_length] @@ -144,7 +185,7 @@ static void RunMultiHeadAttentionKernel( {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"}, 
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}}}; RunMultiHeadAttentionTest( - query_data, key_data, value_data, bias_data, key_padding_mask_data, mask_type, output_data, + query_data, key_data, value_data, kv_data, qkv_data, bias_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, use_float16, disable_cpu, disable_cuda, disable_rocm); return; @@ -158,7 +199,7 @@ static void RunMultiHeadAttentionKernel( {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}}; RunMultiHeadAttentionTest( - query_data, key_data, value_data, bias_data, key_padding_mask_data, mask_type, output_data, + query_data, key_data, value_data, kv_data, qkv_data, bias_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, use_float16, disable_cpu, disable_cuda, disable_rocm); return; @@ -172,7 +213,7 @@ static void RunMultiHeadAttentionKernel( {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"}, {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}}; RunMultiHeadAttentionTest( - query_data, key_data, value_data, bias_data, key_padding_mask_data, mask_type, output_data, + query_data, key_data, value_data, kv_data, qkv_data, bias_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, use_float16, disable_cpu, disable_cuda, disable_rocm); return; @@ -187,7 +228,7 @@ static void RunMultiHeadAttentionKernel( {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}}}; RunMultiHeadAttentionTest( - query_data, key_data, value_data, bias_data, key_padding_mask_data, mask_type, output_data, + query_data, key_data, value_data, kv_data, qkv_data, bias_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, use_float16, disable_cpu, disable_cuda, disable_rocm); return; @@ -202,7 +243,7 @@ static void RunMultiHeadAttentionKernel( {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, {onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}}; RunMultiHeadAttentionTest( - query_data, key_data, value_data, bias_data, key_padding_mask_data, mask_type, output_data, + query_data, key_data, value_data, kv_data, qkv_data, bias_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size, use_float16, disable_cpu, disable_cuda, disable_rocm); } @@ -215,7 +256,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data) { AttentionKernelType kernel_type = AttentionKernelType::AttentionKernel_Unfused; if (!SkipAttentionKernel(data, kernel_type)) { RunMultiHeadAttentionKernel( - data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type, + data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16); } @@ -226,7 +267,7 @@ static void 
RunMultiHeadAttentionTests(AttentionTestData& data) { kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention; if (!SkipAttentionKernel(data, kernel_type)) { RunMultiHeadAttentionKernel( - data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type, + data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16); } @@ -235,7 +276,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data) { kernel_type = AttentionKernelType::AttentionKernel_Default; RunMultiHeadAttentionKernel( - data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type, + data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16); } @@ -245,7 +286,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data) { AttentionKernelType kernel_type = AttentionKernelType::AttentionKernel_TrtFusedCrossAttention; if (!SkipAttentionKernel(data, kernel_type)) { RunMultiHeadAttentionKernel( - data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type, + data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16); } @@ -253,7 +294,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data) { kernel_type = AttentionKernelType::AttentionKernel_TrtFusedAttention; if (!SkipAttentionKernel(data, kernel_type)) { RunMultiHeadAttentionKernel( - data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type, + data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16); } @@ -262,7 +303,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data) { kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention; if (!SkipAttentionKernel(data, kernel_type)) { RunMultiHeadAttentionKernel( - data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type, + data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16); } @@ -270,7 +311,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data) { kernel_type = AttentionKernelType::AttentionKernel_Default; RunMultiHeadAttentionKernel( - data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type, + data.query_data, data.key_data, 
data.value_data, data.kv_data, data.qkv_data, data.bias_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, data.v_hidden_size, kernel_type, use_float16); } @@ -302,6 +343,18 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize32_LeftSidePadding_Ma GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(data); RunMultiHeadAttentionTests(data); } + +TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_NoBias_NoMask_PackedKV) { + AttentionTestData data; + GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(data); + RunMultiHeadAttentionTests(data); +} + +TEST(MultiHeadAttentionTest, SelfAttention_Batch2_HeadSize32_NoBias_NoMask_PackedQKV) { + AttentionTestData data; + GetSelfAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedQKV(data); + RunMultiHeadAttentionTests(data); +} #endif // This tests qk_head_size != k_head_size diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test_data_gen.py b/onnxruntime/test/contrib_ops/multihead_attention_op_test_data_gen.py index 2d7bbc9b6d6c0..fcac5b3ab28eb 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test_data_gen.py +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test_data_gen.py @@ -156,6 +156,7 @@ def run_cross_attention( sequence_length, kv_sequence_length, key_padding_mask=None, + has_bias=True, ): seed = 123 torch.manual_seed(seed) @@ -170,10 +171,16 @@ def run_cross_attention( torch.nn.init.uniform_(mha.key.weight, -0.5, 0.5) torch.nn.init.uniform_(mha.value.weight, -0.5, 0.5) - torch.nn.init.uniform_(mha.query.bias, -0.5, 0.5) - torch.nn.init.uniform_(mha.key.bias, -0.5, 0.5) - torch.nn.init.uniform_(mha.value.bias, -0.5, 0.5) + if has_bias: + torch.nn.init.uniform_(mha.query.bias, -0.5, 0.5) + torch.nn.init.uniform_(mha.key.bias, -0.5, 0.5) + torch.nn.init.uniform_(mha.value.bias, -0.5, 0.5) + else: + torch.nn.init.zeros_(mha.query.bias) + torch.nn.init.zeros_(mha.key.bias) + torch.nn.init.zeros_(mha.value.bias) + # Here we simulate input projection with MatMul but no bias: w_q = nn.Linear(hidden_dim, num_heads * q_head_size).to(device).eval() w_k = nn.Linear(hidden_dim, num_heads * q_head_size).to(device).eval() w_v = nn.Linear(hidden_dim, num_heads * v_head_size).to(device).eval() @@ -195,13 +202,27 @@ def run_cross_attention( input_q = w_q(hidden_states.clone()) input_k = w_k(encoder_hidden_states.clone()) input_v = w_v(encoder_hidden_states.clone()) - input_bias = torch.concat([mha.query.bias, mha.key.bias, mha.value.bias]) print("input_q", input_q) print("input_k", input_k) print("input_v", input_v) + + input_bias = torch.concat([mha.query.bias, mha.key.bias, mha.value.bias]) print("input_bias", input_bias) + if not has_bias: + print("no bias!") + + # packed KV + if q_head_size == v_head_size: + packed_kv = torch.dstack( + ( + input_k.reshape(batch_size * kv_sequence_length, num_heads, q_head_size), + input_v.reshape(batch_size * kv_sequence_length, num_heads, v_head_size), + ) + ) + packed_kv = packed_kv.reshape(batch_size, kv_sequence_length, num_heads, 2, q_head_size) + print("packed_kv_5d", packed_kv) - output = mha.forward( + mha.forward( hidden_states, attention_mask=None, encoder_hidden_states=encoder_hidden_states, @@ -211,6 +232,88 @@ def run_cross_attention( ) +def run_self_attention( + hidden_dim, + q_head_size, + v_head_size, + num_heads, + batch_size, + sequence_length, + key_padding_mask=None, + has_bias=True, +): + seed = 123 + 
torch.manual_seed(seed) + np.random.seed(seed) + torch.use_deterministic_algorithms(True) + + device = torch.device("cuda:0") + mha = Attention(num_heads, hidden_dim, q_head_size, v_head_size, is_decoder=False).to(device).eval() + if key_padding_mask is not None: + key_padding_mask = key_padding_mask.to(device) + torch.nn.init.uniform_(mha.query.weight, -0.5, 0.5) + torch.nn.init.uniform_(mha.key.weight, -0.5, 0.5) + torch.nn.init.uniform_(mha.value.weight, -0.5, 0.5) + + if has_bias: + torch.nn.init.uniform_(mha.query.bias, -0.5, 0.5) + torch.nn.init.uniform_(mha.key.bias, -0.5, 0.5) + torch.nn.init.uniform_(mha.value.bias, -0.5, 0.5) + else: + torch.nn.init.zeros_(mha.query.bias) + torch.nn.init.zeros_(mha.key.bias) + torch.nn.init.zeros_(mha.value.bias) + + # Here we simulate input projection with MatMul but no bias: + w_q = nn.Linear(hidden_dim, num_heads * q_head_size).to(device).eval() + w_k = nn.Linear(hidden_dim, num_heads * q_head_size).to(device).eval() + w_v = nn.Linear(hidden_dim, num_heads * v_head_size).to(device).eval() + w_q.weight.copy_(mha.query.weight) + w_k.weight.copy_(mha.key.weight) + w_v.weight.copy_(mha.value.weight) + torch.nn.init.zeros_(w_q.bias) + torch.nn.init.zeros_(w_k.bias) + torch.nn.init.zeros_(w_v.bias) + + torch.set_printoptions(profile="full", precision=8, linewidth=120, sci_mode=False) + + hidden_states = torch.empty(batch_size, sequence_length, hidden_dim, device="cuda") + torch.nn.init.normal_(hidden_states) + + input_q = w_q(hidden_states.clone()) + input_k = w_k(hidden_states.clone()) + input_v = w_v(hidden_states.clone()) + print("input_q", input_q) + print("input_k", input_k) + print("input_v", input_v) + + input_bias = torch.concat([mha.query.bias, mha.key.bias, mha.value.bias]) + print("input_bias", input_bias) + if not has_bias: + print("no bias!") + + # packed QKV + if q_head_size == v_head_size: + packed_qkv = torch.dstack( + ( + input_q.reshape(batch_size * sequence_length, num_heads, q_head_size), + input_k.reshape(batch_size * sequence_length, num_heads, q_head_size), + input_v.reshape(batch_size * sequence_length, num_heads, v_head_size), + ) + ) + packed_qkv = packed_qkv.reshape(batch_size, sequence_length, num_heads, 3, q_head_size) + print("packed_qkv_5d", packed_qkv) + + mha.forward( + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=key_padding_mask, + past_key_value=None, + output_attentions=False, + ) + + def run_cross_batch2_headsize_40(): hidden_dim = 80 q_head_size = 40 @@ -293,7 +396,44 @@ def run_cross_batch1_headsize_32_left_side_padding(): ) -def create_cross_attention_test_data(): +def run_cross_batch2_headsize_32_packed_kv(): + hidden_dim = 32 + q_head_size = 32 + v_head_size = 32 + num_heads = 1 + batch_size = 2 + sequence_length = 2 + kv_sequence_length = 3 + key_padding_mask = None + has_bias = False + run_cross_attention( + hidden_dim, + q_head_size, + v_head_size, + num_heads, + batch_size, + sequence_length, + kv_sequence_length, + key_padding_mask, + has_bias, + ) + + +def run_self_batch2_headsize_32_packed_qkv(): + hidden_dim = 32 + q_head_size = 32 + v_head_size = 32 + num_heads = 1 + batch_size = 2 + sequence_length = 2 + key_padding_mask = None + has_bias = False + run_self_attention( + hidden_dim, q_head_size, v_head_size, num_heads, batch_size, sequence_length, key_padding_mask, has_bias + ) + + +def create_test_data(): """ Create test data used in attention_op_test_helper.cc and multihead_attention_op_test.cc """ @@ -312,6 +452,12 @@ def 
create_cross_attention_test_data(): print("CrossAttention_Batch1_HeadSize32_LeftSidePadding") run_cross_batch1_headsize_32_left_side_padding() + print("CrossAttention_Batch2_HeadSize32_PackedKV") + run_cross_batch2_headsize_32_packed_kv() + + print("SelfAttention_Batch2_HeadSize32_PackedQKV") + run_self_batch2_headsize_32_packed_qkv() + with torch.no_grad(): - create_cross_attention_test_data() + create_test_data()
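
As a reference for the packed formats exercised above, the following is a minimal standalone sketch of how the packed KV and packed QKV tensors are assembled from the projected Q/K/V. It mirrors the torch.dstack and reshape calls in the generator; the helper names pack_kv and pack_qkv are illustrative and not part of the patch.

# Minimal sketch of the packed KV / packed QKV layouts used by the new tests.
# pack_kv and pack_qkv are illustrative helpers, not functions from the patch.
import torch


def pack_kv(k, v, batch, kv_seq, num_heads, head_size):
    # k, v: (batch, kv_seq, num_heads * head_size) -> (batch, kv_seq, num_heads, 2, head_size)
    packed = torch.dstack((
        k.reshape(batch * kv_seq, num_heads, head_size),
        v.reshape(batch * kv_seq, num_heads, head_size),
    ))
    return packed.reshape(batch, kv_seq, num_heads, 2, head_size)


def pack_qkv(q, k, v, batch, seq, num_heads, head_size):
    # q, k, v: (batch, seq, num_heads * head_size) -> (batch, seq, num_heads, 3, head_size)
    packed = torch.dstack((
        q.reshape(batch * seq, num_heads, head_size),
        k.reshape(batch * seq, num_heads, head_size),
        v.reshape(batch * seq, num_heads, head_size),
    ))
    return packed.reshape(batch, seq, num_heads, 3, head_size)


if __name__ == "__main__":
    # Shapes match the new test cases: batch 2, sequence 2, 1 head, head_size 32.
    b, s, n, h = 2, 2, 1, 32
    q = torch.randn(b, s, n * h)
    k = torch.randn(b, s, n * h)
    v = torch.randn(b, s, n * h)
    assert pack_kv(k, v, b, s, n, h).shape == (b, s, n, 2, h)
    assert pack_qkv(q, k, v, b, s, n, h).shape == (b, s, n, 3, h)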