diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index f0543f2649205..c1d12a1d5cba6 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2310,12 +2310,12 @@ This version of the operator has been available since version 1 of the 'com.micr
Number of attention heads
-#### Inputs (2 - 6)
+#### Inputs (1 - 6)
- query : T
-- Query with shape (batch_size, sequence_length, hidden_size)
-- key : T
+- Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape (batch_size, kv_sequence_length, num_heads, 3, head_size)
+- key (optional) : T
- Key with shape (batch_size, kv_sequence_length, hidden_size), or packed KV with shape (batch_size, kv_sequence_length, num_heads, 2, head_size)
- value (optional) : T
- Value with shape (batch_size, kv_sequence_length, v_hidden_size)
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
index ee1720b9f43bb..34a615a880594 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
@@ -22,53 +22,79 @@ Status CheckInputs(const T* query,
int num_heads,
float mask_filter_value,
int max_threads_per_block) {
+ // key_padding_mask (K/V) : (B) or (B, L) or None
+ // relative_position_bias : (B, 1, S, L)
+ // When no packing for q/k/v:
// query (Q) : (B, S, D)
// key (K) : (B, L, D)
// value (V) : (B, L, D_v)
// bias (Q/K/V) : (D + D + D_v)
- // key_padding_mask (K/V) : (B) or (B, L) or None
- // relative_position_bias : (B, 1, S, L)
// When packed kv is used:
+ // query (Q) : (B, S, D)
// key (K) : (B, L, N, 2, H)
// value (V) : None
// bias (Q/K/V) : None
+ // When packed qkv is used:
+ // query (Q) : (B, L, N, 3, H)
+ // key (K) : None
+ // value (V) : None
+ // bias (Q/K/V) : None
const auto& query_dims = query->Shape().GetDims();
- if (query_dims.size() != 3) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ",
+ if (query_dims.size() != 3 && query_dims.size() != 5) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 or 5 dimensions, got ",
query_dims.size());
}
- const auto& key_dims = key->Shape().GetDims();
- if (key_dims.size() != 3 && key_dims.size() != 5) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 or 5 dimensions, got ",
- key_dims.size());
- }
- if (query_dims[0] != key_dims[0]) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'query' and 'key' shall have same dim 0 (batch size)");
- }
-
int batch_size = static_cast<int>(query_dims[0]);
int sequence_length = static_cast<int>(query_dims[1]);
- int hidden_size = static_cast<int>(query_dims[2]);
+ int hidden_size = query_dims.size() == 3 ? static_cast<int>(query_dims[2]) : (num_heads * static_cast<int>(query_dims[4]));
int head_size = static_cast<int>(hidden_size) / num_heads;
- int kv_sequence_length = static_cast<int>(key_dims[1]);
+ int kv_sequence_length = sequence_length;
+
+ if (key != nullptr) {
+ if (query_dims.size() != 3) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions when key is given, got ",
+ query_dims.size());
+ }
- if (key_dims.size() == 3) {
- if (key_dims[2] != query_dims[2]) {
+ const auto& key_dims = key->Shape().GetDims();
+ if (key_dims.size() != 3 && key_dims.size() != 5) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 or 5 dimensions, got ",
+ key_dims.size());
+ }
+ if (query_dims[0] != key_dims[0]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'query' and 'key' shall have same dim 2 (hidden_size)");
+ "Input 'query' and 'key' shall have same dim 0 (batch size)");
}
- } else // if (key_dims.size() == 5)
- {
- if (static_cast<int>(key_dims[2]) != num_heads || static_cast<int>(key_dims[3]) != 2 || static_cast<int>(key_dims[4]) != head_size) {
+
+ if (key_dims.size() == 3) {
+ if (key_dims[2] != query_dims[2]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'query' and 'key' shall have same dim 2 (hidden_size)");
+ }
+ } else // if (key_dims.size() == 5)
+ {
+ if (static_cast<int>(key_dims[2]) != num_heads || static_cast<int>(key_dims[3]) != 2 || static_cast<int>(key_dims[4]) != head_size) {
+ return ORT_MAKE_STATUS(
+ ONNXRUNTIME, INVALID_ARGUMENT,
+ "Expect 'key' shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv");
+ }
+ if (value != nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expect 'value' be none when 'key' has packed kv format.");
+ }
+ }
+
+ kv_sequence_length = static_cast<int>(key_dims[1]);
+ } else { // packed QKV
+ if (query_dims.size() != 5) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 5 dimensions when key is empty, got ",
+ query_dims.size());
+ }
+ if (static_cast<int>(query_dims[2]) != num_heads || static_cast<int>(query_dims[3]) != 3) {
return ORT_MAKE_STATUS(
ONNXRUNTIME, INVALID_ARGUMENT,
- "Expect 'key' shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for packed kv");
- }
- if (value != nullptr) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Expect 'value' be none when 'key' has packed kv format.");
+ "Expect 'query' shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv");
}
}
@@ -82,7 +108,7 @@ Status CheckInputs(const T* query,
// Currently, bias is not allowed for packed KV. This constraint can be removed later.
// Here we assume that fusion tool will not include bias for packed KV.
if (value == nullptr) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "'bias' is not allowed for packed kv. ");
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "'bias' is not allowed for packed qkv or kv. ");
}
}
@@ -90,9 +116,9 @@ Status CheckInputs(const T* query,
if (key_padding_mask != nullptr) {
mask_type = AttentionMaskType::MASK_UNKNOWN;
const auto& mask_dims = key_padding_mask->Shape().GetDims();
- if (mask_dims.size() == 1 && mask_dims[0] == key_dims[0]) {
+ if (mask_dims.size() == 1 && mask_dims[0] == static_cast<int64_t>(batch_size)) {
mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
- } else if (mask_dims.size() == 2 && mask_dims[0] == key_dims[0] && mask_dims[1] == key_dims[1]) {
+ } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast<int64_t>(batch_size) && mask_dims[1] == static_cast<int64_t>(kv_sequence_length)) {
mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
}
@@ -115,7 +141,7 @@ Status CheckInputs(const T* query,
"Input 'query' and 'value' shall have same dim 0 (batch_size)");
}
- if (key_dims[1] != value_dims[1]) {
+ if (static_cast<int64_t>(kv_sequence_length) != value_dims[1]) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key' and 'value' shall have same same dim 1 (kv_sequence_length)");
}
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc
index 1ab89b525eae5..04cac1962f37b 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -181,6 +181,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
&zero, reinterpret_cast(gemm_buffer.get()), n, device_prop));
constexpr size_t element_size = sizeof(T);
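+ // Attention does not use the fused cross attention kernel, so no sequence offset workspace is reserved for it.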
+ constexpr bool use_fused_cross_attention = false;
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
parameters.batch_size,
parameters.num_heads,
@@ -190,6 +191,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
parameters.kv_sequence_length,
parameters.total_sequence_length,
fused_runner,
+ use_fused_cross_attention,
use_memory_efficient_attention);
auto work_space = GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());
@@ -204,12 +206,15 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
data.mask_index_dims = (nullptr == mask_index) ? gsl::span() : mask_index->Shape().GetDims();
data.past = (nullptr == past) ? nullptr : reinterpret_cast(past->Data());
data.relative_position_bias = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast(relative_position_bias->Data());
+ data.has_qkv_workspace = true;
data.workspace = reinterpret_cast(work_space.get());
data.output = reinterpret_cast(output->MutableData());
data.present = (nullptr == present) ? nullptr : reinterpret_cast(present->MutableData());
data.fused_runner = reinterpret_cast(fused_runner);
data.fused_cross_attention_kernel = nullptr;
data.use_memory_efficient_attention = use_memory_efficient_attention;
+ data.cumulated_sequence_length_q_cache = nullptr;
+ data.cumulated_sequence_length_kv_cache = nullptr;
return QkvToContext(device_prop, cublas, Stream(context), parameters, data);
}
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
index 972b2d4ade7d6..41f19f460e80c 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -46,22 +46,47 @@ limitations under the License.
using namespace onnxruntime::cuda;
using namespace cub;
-#define CHECK_CUDA(expr) CUDA_RETURN_IF_ERROR(expr)
-#define CUDA_MEMORY_ALIGNMENT 256
-
namespace onnxruntime {
namespace contrib {
namespace cuda {
+constexpr size_t kMemoryAlignment = 256;
+
static size_t AlignTo(size_t a, size_t b) {
return CeilDiv(a, b) * b;
}
size_t AlignSize(size_t bytes) {
- const size_t bytesAligned = AlignTo(bytes, CUDA_MEMORY_ALIGNMENT);
+ const size_t bytesAligned = AlignTo(bytes, kMemoryAlignment);
return bytesAligned;
}
+void CumulatedSequenceLengthCache::Initialize(int32_t sequence_length, cudaStream_t stream) {
+ if (this->sequence_length != sequence_length) {
+ ORT_ENFORCE(buffer.get() != nullptr && this->max_batch_size > 0);
+ LaunchTrtSequenceOffset(reinterpret_cast<int32_t*>(buffer.get()), nullptr, this->max_batch_size, sequence_length, stream);
+ this->sequence_length = sequence_length;
+ }
+}
+
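+// Returns a buffer of cumulated sequence lengths (0, L, 2L, ..., B*L). When there is no padding mask
+// and the batch fits in the cache, the cached buffer is reused; otherwise the offsets are computed into scratch_buffer.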
+int* GetCumulatedSequenceLength(CumulatedSequenceLengthCache* cache,
+ const int* mask_index,
+ int batch_size,
+ int sequence_length,
+ cudaStream_t stream,
+ void* scratch_buffer) {
+ if (mask_index == nullptr && cache != nullptr) {
+ if (batch_size <= cache->max_batch_size) {
+ cache->Initialize(sequence_length, stream);
+ return reinterpret_cast<int*>(cache->buffer.get());
+ }
+ }
+
+ int* sequence_offset = reinterpret_cast<int*>(scratch_buffer);
+ LaunchTrtSequenceOffset(sequence_offset, mask_index, batch_size, sequence_length, stream);
+ return sequence_offset;
+}
+
size_t GetAttentionScratchSize(
size_t element_size,
size_t batch_size,
@@ -89,6 +114,7 @@ size_t GetAttentionWorkspaceSize(
size_t kv_sequence_length,
size_t total_sequence_length,
void* fused_runner,
+ bool use_fused_cross_attention,
bool use_memory_efficient_attention) {
// Note that q, k and v might need alignment for fused attention kernels.
const size_t qkv_bytes = element_size * batch_size * num_heads *
@@ -108,8 +134,11 @@ size_t GetAttentionWorkspaceSize(
#endif
if (fused_runner != nullptr) {
- size_t sequence_offset_bytes = GetSequenceOffsetSize(static_cast<int>(batch_size), true);
- return qkv_bytes + sequence_offset_bytes;
+ return qkv_bytes + GetSequenceOffsetSize(static_cast<int>(batch_size), true);
+ }
+
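+ // Fused cross attention needs two sequence offset buffers: cumulated sequence lengths for Q and for KV.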
+ if (use_fused_cross_attention) {
+ return qkv_bytes + 2 * GetSequenceOffsetSize(static_cast<int>(batch_size), true);
}
return qkv_bytes + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length,
@@ -264,7 +293,7 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
T* qkv = data.workspace;
- bool use_fused_kernel = (nullptr != fused_runner && data.bias != nullptr && !parameters.is_unidirectional);
+ bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional);
bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional);
// Default format for memory efficient attention.
@@ -272,6 +301,7 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
DUMP_TENSOR_INIT();
if (nullptr != data.gemm_buffer) {
if (data.bias == nullptr) {
+ assert(nullptr == fused_runner);
// For quantized attention, bias has been added so only need transpose here.
// gemm_buffer should be BxSx3xNxH => qkv: 3xBxNxSxH
assert(qk_head_size == v_head_size);
@@ -303,6 +333,31 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
data.gemm_buffer, data.bias, qkv,
true, v_head_size, qkv_add_bias, 3);
}
+ } else if (data.key == nullptr) { // gemm_buffer == nullptr and packed qkv
+ assert(data.bias == nullptr);
+ assert(qk_head_size == v_head_size);
+
+ DUMP_TENSOR_D("packed_qkv", data.query, batch_size * sequence_length, num_heads, 3, qk_head_size);
+
+ if (use_memory_efficient_attention) {
+ // unpack qkv to BSNH. Note that there is no bias so we need not output query to q.
+ constexpr int format = 4;
+ T* qkv_add_bias = nullptr;
+ LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block,
+ batch_size, sequence_length, num_heads, qk_head_size,
+ data.query, data.bias, qkv,
+ true, v_head_size, qkv_add_bias, 3);
+ DUMP_TENSOR_D("k(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size);
+ qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
+ } else {
+ if (!use_fused_kernel) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "packed QKV format is not implemented for current GPU. Please disable it in fusion options.");
+ }
+
+ qkv_format = AttentionQkvFormat::QKV_BSN3H;
+ }
} else if (data.value == nullptr) { // gemm_buffer == nullptr and packed kv
// TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint.
// CheckInputs verified this constraint.
@@ -330,7 +385,7 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
}
- } else { // gemm_buffer == nullptr and not packed kv
+ } else { // gemm_buffer == nullptr and not packed
assert(data.query != nullptr && data.key != nullptr && data.value != nullptr && data.bias != nullptr);
DUMP_TENSOR_D("query", data.query, batch_size * sequence_length, num_heads, qk_head_size);
@@ -435,18 +490,25 @@ Status QkvToContext(
assert(int(data.use_memory_efficient_attention) + int(fused_runner != nullptr) + int(data.fused_cross_attention_kernel != nullptr) <= 1);
const int batches = batch_size * num_heads;
- const int size_per_batch_q = sequence_length * qk_head_size;
- const int size_per_batch_k = kv_sequence_length * qk_head_size;
- const int size_per_batch_v = kv_sequence_length * v_head_size;
- const size_t elements_q = static_cast<size_t>(batches) * static_cast<size_t>(size_per_batch_q);
- const size_t elements_k = static_cast<size_t>(batches) * static_cast<size_t>(size_per_batch_k);
- const size_t elements_v = static_cast<size_t>(batches) * static_cast<size_t>(size_per_batch_v);
-
- // Q, K and V pointers when fused attention is not used
- T* qkv = data.workspace;
- T* q = qkv;
- T* k = q + elements_q;
- T* v = k + elements_k;
+
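+ // When a QKV workspace is present it is partitioned as [Q | K | V | scratch];
+ // otherwise the whole workspace is scratch and the kernels read Q/K/V (or packed KV/QKV) directly from the inputs.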
+ T* qkv = nullptr;
+ T* q = nullptr;
+ T* k = nullptr;
+ T* v = nullptr;
+ T* scratch1 = data.workspace;
+ if (data.has_qkv_workspace) {
+ const int size_per_batch_q = sequence_length * qk_head_size;
+ const int size_per_batch_k = kv_sequence_length * qk_head_size;
+ const int size_per_batch_v = kv_sequence_length * v_head_size;
+ const size_t elements_q = static_cast<size_t>(batches) * static_cast<size_t>(size_per_batch_q);
+ const size_t elements_k = static_cast<size_t>(batches) * static_cast<size_t>(size_per_batch_k);
+ const size_t elements_v = static_cast<size_t>(batches) * static_cast<size_t>(size_per_batch_v);
+ qkv = data.workspace;
+ q = qkv;
+ k = q + elements_q;
+ v = k + elements_k;
+ scratch1 = v + elements_v;
+ }
bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional);
bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional);
@@ -454,8 +516,6 @@ Status QkvToContext(
AttentionQkvFormat qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
ORT_RETURN_IF_ERROR(PrepareQkv(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format));
- T* scratch1 = qkv + elements_q + elements_k + elements_v;
-
int present_size_per_batch_k = 0;
int present_size_per_batch_v = 0;
if (!past_present_share_buffer) {
@@ -482,6 +542,7 @@ Status QkvToContext(
assert(!use_fused_kernel);
assert(data.gemm_buffer != nullptr);
assert(!data.use_memory_efficient_attention);
+ assert(data.has_qkv_workspace);
if (data.present != data.past) {
// For easy testing. Production should better avoid this path.
@@ -507,18 +568,21 @@ Status QkvToContext(
if (data.fused_cross_attention_kernel != nullptr) {
assert(qkv_format == AttentionQkvFormat::Q_KV_BSNH_BSN2H);
- int* q_sequence_offset = reinterpret_cast<int*>(scratch1);
- LaunchTrtSequenceOffset(q_sequence_offset, nullptr, batch_size, sequence_length, stream);
- CUDA_RETURN_IF_ERROR(cudaGetLastError());
-
- DUMP_TENSOR_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1);
// We only enable fused cross attention when there is no key padding mask.
// Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query.
assert(data.mask_index == nullptr);
+ int* q_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache,
+ data.mask_index, batch_size, sequence_length, stream,
+ scratch1);
+
+ DUMP_TENSOR_D("q_sequence_offset", q_sequence_offset, 1, batch_size + 1);
+
int* kv_sequence_offset = q_sequence_offset + (GetSequenceOffsetSize(batch_size, false) / sizeof(int));
- LaunchTrtSequenceOffset(kv_sequence_offset, data.mask_index, batch_size, kv_sequence_length, stream);
+ kv_sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_kv_cache,
+ data.mask_index, batch_size, kv_sequence_length, stream,
+ kv_sequence_offset);
CUDA_RETURN_IF_ERROR(cudaGetLastError());
DUMP_TENSOR_D("kv_sequence_offset", kv_sequence_offset, 1, batch_size + 1);
@@ -526,7 +590,7 @@ Status QkvToContext(
FusedMultiHeadCrossAttentionKernel const* cross_attention_kernel =
reinterpret_cast<FusedMultiHeadCrossAttentionKernel const*>(data.fused_cross_attention_kernel);
- // When there is no bias, we can directly use q and packed kv from inputs. TODO: not need qkv in workspace.
+ // When there is no bias, we can directly use q and packed kv from inputs.
void const* query = q;
void const* packed_kv = k;
if (data.value == nullptr && data.bias == nullptr) {
@@ -558,7 +622,9 @@ Status QkvToContext(
if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) {
LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream);
} else {
- LaunchTrtSequenceOffset(sequence_offset, data.mask_index, batch_size, sequence_length, stream);
+ sequence_offset = GetCumulatedSequenceLength(data.cumulated_sequence_length_q_cache,
+ data.mask_index, batch_size, sequence_length, stream,
+ sequence_offset);
}
CUDA_RETURN_IF_ERROR(cudaGetLastError());
@@ -573,7 +639,14 @@ Status QkvToContext(
if (use_fused_kernel) {
assert(qkv_format == AttentionQkvFormat::QKV_BSN3H);
- fused_fp16_runner->run(qkv, sequence_offset, data.output, stream);
+
+ // When there is no bias, we can directly use packed qkv from inputs.
+ void const* packed_qkv = qkv;
+ if (data.query != nullptr && data.key == nullptr && data.bias == nullptr) {
+ packed_qkv = data.query;
+ }
+
+ fused_fp16_runner->run(packed_qkv, sequence_offset, data.output, stream);
DUMP_TENSOR("fused output", data.output, batch_size * sequence_length, num_heads, v_head_size);
} else {
assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);
@@ -593,7 +666,9 @@ Status QkvToContext(
const void* query = q;
const void* key = k;
const void* value = v;
- if (data.gemm_buffer == nullptr && data.value == nullptr) { // packed KV
+ // For packed KV, we can use query input directly.
+ if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr) {
+ assert(data.bias == nullptr);
query = data.query;
}
@@ -781,15 +856,15 @@ Status DecoderQkvToContext(
if (has_layer_state) {
if (use_past && static_kv) {
- CHECK_CUDA(cudaMemcpyAsync(new_key_cache, key_cache, kv_sequence_length * BHN * sizeof(T),
- cudaMemcpyDeviceToDevice, stream));
- CHECK_CUDA(cudaMemcpyAsync(new_value_cache, value_cache, kv_sequence_length * BHN * sizeof(T),
- cudaMemcpyDeviceToDevice, stream));
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, key_cache, kv_sequence_length * BHN * sizeof(T),
+ cudaMemcpyDeviceToDevice, stream));
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, value_cache, kv_sequence_length * BHN * sizeof(T),
+ cudaMemcpyDeviceToDevice, stream));
} else {
- CHECK_CUDA(cudaMemcpyAsync(new_key_cache, k, kv_sequence_length * BHN * sizeof(T),
- cudaMemcpyDeviceToDevice, stream));
- CHECK_CUDA(cudaMemcpyAsync(new_value_cache, v, kv_sequence_length * BHN * sizeof(T),
- cudaMemcpyDeviceToDevice, stream));
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_key_cache, k, kv_sequence_length * BHN * sizeof(T),
+ cudaMemcpyDeviceToDevice, stream));
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(new_value_cache, v, kv_sequence_length * BHN * sizeof(T),
+ cudaMemcpyDeviceToDevice, stream));
}
}
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
index 2ecda71479c52..ec7371db4c14d 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -6,18 +6,33 @@
#include
#include
#include "contrib_ops/cpu/bert/attention_common.h"
+#include "core/framework/allocator.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
-size_t GetAttentionScratchSize(
+constexpr int kCumulatedSequenceLengthCacheMaxBatchSize = 128;
+
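+// Cache of cumulated (prefix-sum) sequence lengths for up to max_batch_size sequences of equal length,
+// so the sequence offset buffer does not have to be recomputed on every run when there is no padding mask.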
+struct CumulatedSequenceLengthCache {
+ onnxruntime::IAllocatorUniquePtr<void> buffer;
+ int32_t max_batch_size;
+ int32_t sequence_length;
+
+ CumulatedSequenceLengthCache() : max_batch_size(0), sequence_length(0) {}
+ void Initialize(int32_t sequence_length, cudaStream_t stream);
+};
+
+size_t
+GetAttentionScratchSize(
size_t element_size,
size_t batch_size,
size_t num_heads,
size_t sequence_length,
size_t all_sequence_length);
+size_t GetSequenceOffsetSize(int batch_size, bool has_padding);
+
size_t GetAttentionWorkspaceSize(
size_t element_size,
size_t batchsize,
@@ -28,7 +43,8 @@ size_t GetAttentionWorkspaceSize(
size_t kv_sequence_length,
size_t total_sequence_length,
void* fused_runner,
- bool use_memory_efficient_attention = false);
+ bool use_fused_cross_attention,
+ bool use_memory_efficient_attention);
template <typename T>
struct AttentionData {
@@ -43,7 +59,9 @@ struct AttentionData {
const T* past;
const T* relative_position_bias;
+ bool has_qkv_workspace;
T* workspace;
+
T* output;
T* present;
@@ -51,6 +69,9 @@ struct AttentionData {
const void* fused_cross_attention_kernel;
bool use_memory_efficient_attention;
+
+ mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache;
+ mutable CumulatedSequenceLengthCache* cumulated_sequence_length_kv_cache;
};
template <typename T>
diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
index 57a3a310a0dd6..321c2a1df0df2 100644
--- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -31,9 +31,11 @@ REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
template <typename T>
-
MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info)
- : CudaKernel(info), fused_fp16_cross_attention_kernel_(nullptr) {
+ : CudaKernel(info),
+ fused_fp16_cross_attention_kernel_(nullptr),
+ cumulated_sequence_length_q_cache_(),
+ cumulated_sequence_length_kv_cache_() {
int64_t num_heads = 0;
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
num_heads_ = static_cast<int>(num_heads);
@@ -52,7 +54,15 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info)
disable_memory_efficient_attention_ = true;
#endif
- disable_fused_cross_attention_ = sizeof(T) != 2 || ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedCrossAttention, false);
+ disable_fused_cross_attention_ = sizeof(T) != 2 ||
+ ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedCrossAttention, false);
+
+ // Allocate cache buffers
+ constexpr size_t cache_bytes = sizeof(int32_t) * (static_cast<size_t>(kCumulatedSequenceLengthCacheMaxBatchSize) + 1);
+ cumulated_sequence_length_q_cache_.buffer = GetTransientScratchBuffer<void>(cache_bytes);
+ cumulated_sequence_length_q_cache_.max_batch_size = kCumulatedSequenceLengthCacheMaxBatchSize;
+ cumulated_sequence_length_kv_cache_.buffer = GetTransientScratchBuffer<void>(cache_bytes);
+ cumulated_sequence_length_kv_cache_.max_batch_size = kCumulatedSequenceLengthCacheMaxBatchSize;
}
template
@@ -97,6 +107,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const {
bool use_fused_cross_attention = !disable_fused_cross_attention_ &&
nullptr == key_padding_mask &&
nullptr == relative_position_bias &&
+ key != nullptr &&
(value != nullptr || bias == nullptr) && // TODO: new kernel for adding bias to packed KV
parameters.hidden_size == parameters.v_hidden_size &&
has_fused_cross_attention_kernel(sm, parameters.head_size,
@@ -116,7 +127,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const {
bool use_fused_runner = !disable_fused_runner_ &&
fused_cross_attention_kernel == nullptr &&
nullptr == relative_position_bias &&
- value != nullptr && // fused runner requires packed qkv instead of packed kv
+ (value != nullptr || key == nullptr) && // packed QKV (key is empty) or separate Q/K/V; packed KV is not supported by the fused runner
(nullptr == key_padding_mask || is_mask_1d_seq_len) &&
parameters.hidden_size == parameters.v_hidden_size &&
parameters.sequence_length == parameters.kv_sequence_length &&
@@ -153,36 +164,52 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const {
constexpr bool use_memory_efficient_attention = false;
#endif
- constexpr size_t element_size = sizeof(T);
- size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
- parameters.batch_size,
- parameters.num_heads,
- parameters.head_size,
- parameters.v_head_size,
- parameters.sequence_length,
- parameters.kv_sequence_length,
- parameters.total_sequence_length,
- fused_runner,
- use_memory_efficient_attention);
- auto work_space = GetScratchBuffer(workSpaceSize, context->GetComputeStream());
+ // When packed KV or packed QKV is used, there is no need for the add-bias transpose, so no QKV workspace is required.
+ bool no_qkv_workspace = nullptr == value &&
+ (use_fused_cross_attention || (nullptr != fused_runner && nullptr == key)) &&
+ nullptr == key_padding_mask &&
+ nullptr == bias;
+
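+ // Without a QKV workspace, only space for two sequence offset buffers may be needed; none at all when the batch fits the cumulated sequence length cache.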
+ size_t workspace_bytes;
+ if (no_qkv_workspace) {
+ workspace_bytes = (parameters.batch_size > kCumulatedSequenceLengthCacheMaxBatchSize) ? 2 * GetSequenceOffsetSize(parameters.batch_size, true) : 0;
+ } else {
+ constexpr size_t element_size = sizeof(T);
+ workspace_bytes = GetAttentionWorkspaceSize(element_size,
+ parameters.batch_size,
+ parameters.num_heads,
+ parameters.head_size,
+ parameters.v_head_size,
+ parameters.sequence_length,
+ parameters.kv_sequence_length,
+ parameters.total_sequence_length,
+ fused_runner,
+ use_fused_cross_attention,
+ use_memory_efficient_attention);
+ }
+
+ auto work_space = GetScratchBuffer<void>(workspace_bytes, context->GetComputeStream());
 typedef typename ToCudaType<T>::MappedType CudaT;
 AttentionData<CudaT> data;
data.gemm_buffer = nullptr;
data.bias = (nullptr == bias) ? nullptr : reinterpret_cast(bias->Data());
data.query = reinterpret_cast(query->Data());
- data.key = reinterpret_cast<const CudaT*>(key->Data<T>());
+ data.key = (nullptr == key) ? nullptr : reinterpret_cast<const CudaT*>(key->Data<T>());
data.value = (nullptr == value) ? nullptr : reinterpret_cast(value->Data());
data.mask_index = (nullptr == key_padding_mask) ? nullptr : key_padding_mask->Data();
data.mask_index_dims = (nullptr == key_padding_mask) ? gsl::span() : key_padding_mask->Shape().GetDims();
data.past = nullptr;
data.relative_position_bias = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast(relative_position_bias->Data());
+ data.has_qkv_workspace = !no_qkv_workspace;
data.workspace = reinterpret_cast(work_space.get());
data.output = reinterpret_cast(output->MutableData());
data.present = nullptr;
data.fused_runner = reinterpret_cast(fused_runner);
data.fused_cross_attention_kernel = fused_cross_attention_kernel;
data.use_memory_efficient_attention = use_memory_efficient_attention;
+ data.cumulated_sequence_length_q_cache = &(this->cumulated_sequence_length_q_cache_);
+ data.cumulated_sequence_length_kv_cache = &(this->cumulated_sequence_length_kv_cache_);
cublasHandle_t cublas = GetCublasHandle(context);
return QkvToContext<CudaT>(
diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h
index b4ac7f19597ea..928dbd1c4a0f4 100644
--- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h
+++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h
@@ -7,6 +7,7 @@
#include "core/providers/cuda/cuda_kernel.h"
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.h"
#include "contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h"
+#include "contrib_ops/cuda/bert/attention_impl.h"
namespace onnxruntime {
namespace contrib {
@@ -29,6 +30,8 @@ class MultiHeadAttention final : public CudaKernel {
bool disable_memory_efficient_attention_;
mutable std::unique_ptr<MHARunner> fused_fp16_runner_;
mutable const FusedMultiHeadCrossAttentionKernel* fused_fp16_cross_attention_kernel_;
+ mutable CumulatedSequenceLengthCache cumulated_sequence_length_q_cache_;
+ mutable CumulatedSequenceLengthCache cumulated_sequence_length_kv_cache_;
};
} // namespace cuda
diff --git a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h
index 45a17a03d82f2..23bab06fe46ca 100644
--- a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h
+++ b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/cross_attention/fmha_cross_attention.h
@@ -186,7 +186,7 @@ static Fused_multihead_attention_params_mhca getMHCAParams(
void const* q_packed_d, void const* kv_packed_d, void* cu_seqlens_q_d, void* cu_seqlens_kv_d, void* o_packed_d) {
Fused_multihead_attention_params_mhca params{};
- int32_t const d_padded = static_cast<int32_t>(std::pow(2, std::ceil(std::log(d) / std::log(2))));
+ int32_t const d_padded = d <= 64 ? 64 : static_cast<int32_t>(std::pow(2, std::ceil(std::log(d) / std::log(2))));
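+ // d_padded rounds the head size up to the next power of two, with a minimum of 64.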
// Set the pointers.
params.o_ptr = o_packed_d;
@@ -269,11 +269,11 @@ using FusedMHACrossKernelFactory = TSharedCubinKernelFactory<FusedMultiHeadCrossAttentionKernel>;
         (head_size > min_head_size) && (head_size <= max_head_size) &&
- (kv_sequence_length <= 128); // TODO: shall we remove this constraint on kv_sequence_length?
+ (kv_sequence_length <= 128); // TODO: shall we remove this constraint on kv_sequence_length?
}
inline FusedMultiHeadCrossAttentionKernel const* get_fused_cross_attention_kernels(int32_t sm) {
diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc
index 7cd717efc9fba..90ec1a35ac63a 100644
--- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc
+++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc
@@ -174,6 +174,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const {
Tensor* present = context->Output(1, present_shape);
void* fused_runner = nullptr; // TODO(tianleiwu): use fused kernel to speed up
+ bool use_fused_cross_attention = false;
bool use_memory_efficient_attention = false;
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
batch_size,
@@ -184,6 +185,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const {
parameters.kv_sequence_length,
parameters.total_sequence_length,
fused_runner,
+ use_fused_cross_attention,
use_memory_efficient_attention);
auto work_space = GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());
@@ -199,12 +201,15 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const {
data.mask_index_dims = (nullptr == mask_index) ? gsl::span() : mask_index->Shape().GetDims();
data.past = (nullptr == past_tensor) ? nullptr : reinterpret_cast(past_tensor->Data());
data.relative_position_bias = nullptr; // add_qk is not supported in quantized attention
+ data.has_qkv_workspace = true;
data.workspace = reinterpret_cast(work_space.get());
data.output = reinterpret_cast(output->MutableData());
data.present = (nullptr == present) ? nullptr : reinterpret_cast(present->MutableData());
data.fused_runner = fused_runner;
data.fused_cross_attention_kernel = nullptr;
data.use_memory_efficient_attention = use_memory_efficient_attention;
+ data.cumulated_sequence_length_q_cache = nullptr;
+ data.cumulated_sequence_length_kv_cache = nullptr;
return QkvToContext(GetDeviceProp(), cublas, Stream(context), parameters, data);
}
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index 6b00ac94bc10f..580a0a993454a 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -126,11 +126,23 @@ void RestorePaddingTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx)
}
void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) {
- // Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
- // Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size) or (batch_size, kv_sequence_length, num_heads, 2, head_size)
- // Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size) or nullptr
// Output 0 has shape (batch_size, sequence_length, v_hidden_size)
+ // Q, K and V without packing:
+ // Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
+ // Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size)
+ // Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size)
+
+ // Packed KV:
+ // Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
+ // Input 1 (batch_size, kv_sequence_length, num_heads, 2, head_size)
+ // Input 2 nullptr
+
+ // Packed QKV:
+ // Input 0 (batch_size, sequence_length, num_heads, 3, head_size)
+ // Input 1 nullptr
+ // Input 2 nullptr
+
// Type inference
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
@@ -138,8 +150,18 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c
if (hasInputShape(ctx, 0)) {
auto& query_shape = getInputShape(ctx, 0);
auto& query_dims = query_shape.dim();
- if (query_dims.size() != 3) {
- fail_shape_inference("Inputs 0 (query) shall be 3 dimensions");
+
+ if (query_dims.size() != 3 && query_dims.size() != 5) {
+ fail_shape_inference("Inputs 0 (query) shall be 3 or 5 dimensions");
+ }
+
+ if (query_dims.size() == 5) { // packed QKV
+ ONNX_NAMESPACE::TensorShapeProto output_shape;
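+ // For packed QKV, the output hidden size is num_heads * head_size (dims 2 and 4 of the query).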
+ *output_shape.add_dim() = query_dims[0];
+ *output_shape.add_dim() = query_dims[1];
+ *output_shape.add_dim() = query_dims[2] * query_dims[4];
+ updateOutputShape(ctx, 0, output_shape);
+ return;
}
if (hasInputShape(ctx, 2)) {
@@ -154,11 +176,12 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c
*output_shape.add_dim() = query_dims[1];
*output_shape.add_dim() = value_dims[2];
updateOutputShape(ctx, 0, output_shape);
+ return;
}
if (hasInputShape(ctx, 1)) {
auto& key_shape = getInputShape(ctx, 1);
- if (key_shape.dim().size() == 5) {
+ if (key_shape.dim().size() == 5) { // packed KV
ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput(ctx);
}
}
@@ -292,12 +315,13 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
AttributeProto::FLOAT, OPTIONAL_VALUE)
.Input(0,
"query",
- "Query with shape (batch_size, sequence_length, hidden_size)",
+ "Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape (batch_size, kv_sequence_length, num_heads, 3, head_size)",
"T")
.Input(1,
"key",
"Key with shape (batch_size, kv_sequence_length, hidden_size), or packed KV with shape (batch_size, kv_sequence_length, num_heads, 2, head_size)",
- "T")
+ "T",
+ OpSchema::Optional)
.Input(2,
"value",
"Value with shape (batch_size, kv_sequence_length, v_hidden_size)",
diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py
index 53106715cd36c..bf7f7d86690ac 100755
--- a/onnxruntime/python/tools/symbolic_shape_infer.py
+++ b/onnxruntime/python/tools/symbolic_shape_infer.py
@@ -2025,28 +2025,43 @@ def _infer_BiasGelu(self, node):
self._propagate_shape_and_type(node)
def _infer_MultiHeadAttention(self, node):
- # Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
- # Without packed KV:
+ # Output 0 has shape (batch_size, sequence_length, v_hidden_size)
+ # Q, K and V without packing:
+ # Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
# Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size)
# Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size)
- # With packed KV:
- # Input 1 (key) has shape (batch_size, kv_sequence_length, num_heads, 2, head_size)
- # Input 2 (value) is nullptr
- # Output 0 has shape (batch_size, sequence_length, v_hidden_size)
- query_shape = self._get_shape(node, 0)
- key_shape = self._get_shape(node, 1)
- if query_shape is not None and len(query_shape) == 3:
+ # Packed KV:
+ # Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
+ # Input 1 (batch_size, kv_sequence_length, num_heads, 2, head_size)
+ # Input 2 nullptr
+ # Packed QKV:
+ # Input 0 (batch_size, sequence_length, num_heads, 3, head_size)
+ # Input 1 nullptr
+ # Input 2 nullptr
- # By default, hidden size is same for Q/K/V. Only need check v_hidden_size when value is provided.
- output_shape = query_shape
- if key_shape and len(key_shape) == 3:
- value_shape = self._get_shape(node, 2)
- if value_shape and len(value_shape) == 3:
- output_shape[2] = value_shape[2]
+ query_shape = self._get_shape(node, 0)
+ if query_shape is not None:
+ if len(query_shape) == 3:
+ key_shape = self._get_shape(node, 1)
+ # By default, hidden size is same for Q/K/V. Only need check v_hidden_size when value is provided.
+ output_shape = query_shape
+ if key_shape and len(key_shape) == 3:
+ value_shape = self._get_shape(node, 2)
+ if value_shape and len(value_shape) == 3:
+ output_shape[2] = value_shape[2]
+
+ output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+ vi = self.known_vi_[node.output[0]]
+ vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+ elif len(query_shape) == 5:
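+ # Packed QKV: output hidden size is num_heads * head_size (dims 2 and 4 of the query shape).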
+ if isinstance(query_shape[2], int) and isinstance(query_shape[4], int):
+ output_shape = [query_shape[0], query_shape[1], query_shape[2] * query_shape[4]]
+ else:
+ output_shape = [query_shape[0], query_shape[1], f"{query_shape[2]}*{query_shape[4]}"]
- output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
- vi = self.known_vi_[node.output[0]]
- vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+ output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+ vi = self.known_vi_[node.output[0]]
+ vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
def _infer_FastGelu(self, node):
self._propagate_shape_and_type(node)
diff --git a/onnxruntime/python/tools/transformers/fusion_attention_unet.py b/onnxruntime/python/tools/transformers/fusion_attention_unet.py
index 361baaedc4a95..dcbc640923e35 100644
--- a/onnxruntime/python/tools/transformers/fusion_attention_unet.py
+++ b/onnxruntime/python/tools/transformers/fusion_attention_unet.py
@@ -20,12 +20,19 @@ class FusionAttentionUnet(Fusion):
"""
def __init__(
- self, model: OnnxModel, hidden_size: int, num_heads: int, is_cross_attention: bool, enable_packed_kv: bool
+ self,
+ model: OnnxModel,
+ hidden_size: int,
+ num_heads: int,
+ is_cross_attention: bool,
+ enable_packed_qkv: bool,
+ enable_packed_kv: bool,
):
super().__init__(model, "MultiHeadAttention" if is_cross_attention else "Attention", ["LayerNormalization"])
self.hidden_size = hidden_size
self.num_heads = num_heads
self.is_cross_attention = is_cross_attention
+ self.enable_packed_qkv = enable_packed_qkv
self.enable_packed_kv = enable_packed_kv
# Flags to show warning only once
@@ -147,10 +154,6 @@ def create_attention_node(
return None
qw_in_size = qw.shape[0]
- kw_in_size = kw.shape[0]
- vw_in_size = vw.shape[0]
-
- assert qw_in_size == kw_in_size and kw_in_size == vw_in_size
if hidden_size > 0 and hidden_size != qw_in_size:
raise ValueError(
@@ -161,21 +164,70 @@ def create_attention_node(
# All the matrices can have the same shape or q, k matrics can have the same shape with v being different
# For 2d weights, the shapes would be [in_size, out_size].
# For 3d weights, shape would be [in_size, a, b] where a*b = out_size
- qw_out_size = np.prod(qw.shape[1:])
+ qw_out_size = int(np.prod(qw.shape[1:]))
- qkv_weight = np.stack((qw, kw, vw), axis=1)
- qkv_weight_dim = 3 * qw_out_size
+ if self.enable_packed_qkv:
+ attention_node_name = self.model.create_node_name("MultiHeadAttention")
- attention_node_name = self.model.create_node_name("Attention")
+ c = qw_in_size
+ n = num_heads
+ h = qw_out_size // num_heads
- weight = helper.make_tensor(
- name=attention_node_name + "_qkv_weight",
- data_type=TensorProto.FLOAT,
- dims=[qw_in_size, qkv_weight_dim],
- vals=qkv_weight.flatten().tolist(),
- )
+ # Concat and interleave weights so that the output of the fused QKV GEMM can be reshaped to [B, S, N, 3, H]
+ qkv_weight = np.dstack([qw.reshape(c, n, h), kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(
+ c, n * 3 * h
+ )
- self.model.add_initializer(weight, self.this_graph_name)
+ matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_QKV")
+ weight = helper.make_tensor(
+ name=matmul_node_name + "_weight",
+ data_type=TensorProto.FLOAT,
+ dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
+ vals=qkv_weight.flatten().tolist(),
+ )
+
+ self.model.add_initializer(weight, self.this_graph_name)
+
+ matmul_node = helper.make_node(
+ "MatMul",
+ inputs=[k_matmul.input[0], matmul_node_name + "_weight"],
+ outputs=[matmul_node_name + "_out"],
+ name=matmul_node_name,
+ )
+ self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
+
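+ # Reshape the fused QKV GEMM output from (B, S, N*3*H) to the packed QKV layout (B, S, N, 3, H) expected by MultiHeadAttention.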
+ shape_tensor = helper.make_tensor(
+ name=matmul_node_name + "_reshape_shape",
+ data_type=TensorProto.INT64,
+ dims=[5],
+ vals=[0, 0, n, 3, h],
+ )
+ self.model.add_initializer(shape_tensor, self.this_graph_name)
+
+ reshape_node = helper.make_node(
+ "Reshape",
+ inputs=[matmul_node_name + "_out", matmul_node_name + "_reshape_shape"],
+ outputs=[attention_node_name + "_input"],
+ name=matmul_node_name + "_reshape",
+ )
+ self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
+ self.nodes_to_add.extend([matmul_node, reshape_node])
+ self.nodes_to_remove.extend([q_matmul, k_matmul, v_matmul])
+
+ else:
+ qkv_weight = np.stack((qw, kw, vw), axis=1)
+ qkv_weight_dim = 3 * qw_out_size
+
+ attention_node_name = self.model.create_node_name("Attention")
+
+ weight = helper.make_tensor(
+ name=attention_node_name + "_qkv_weight",
+ data_type=TensorProto.FLOAT,
+ dims=[qw_in_size, qkv_weight_dim],
+ vals=qkv_weight.flatten().tolist(),
+ )
+
+ self.model.add_initializer(weight, self.this_graph_name)
else: # cross attention
attention_node_name = self.model.create_node_name("MultiHeadAttention")
if self.enable_packed_kv:
@@ -247,11 +299,14 @@ def create_attention_node(
self.model.add_initializer(bias, self.this_graph_name)
if is_self_attention:
- attention_inputs = [
- input,
- attention_node_name + "_qkv_weight",
- attention_node_name + "_qkv_bias",
- ]
+ if not self.enable_packed_qkv:
+ attention_inputs = [
+ input,
+ attention_node_name + "_qkv_weight",
+ attention_node_name + "_qkv_bias",
+ ]
+ else:
+ attention_inputs = [attention_node_name + "_input"]
else:
if not self.enable_packed_kv:
attention_inputs = [
@@ -267,7 +322,7 @@ def create_attention_node(
]
attention_node = helper.make_node(
- "Attention" if is_self_attention else "MultiHeadAttention",
+ "Attention" if (is_self_attention and not self.enable_packed_qkv) else "MultiHeadAttention",
inputs=attention_inputs,
outputs=[output],
name=attention_node_name,
@@ -277,9 +332,13 @@ def create_attention_node(
counter_name = (
"Attention (self attention)"
- if is_self_attention
+ if is_self_attention and not self.enable_packed_qkv
else "MultiHeadAttention ({})".format(
- "cross attention with packed kv" if self.enable_packed_kv else "cross attention"
+ "self attention with packed qkv"
+ if self.enable_packed_qkv
+ else "cross attention with packed kv"
+ if self.enable_packed_kv
+ else "cross attention"
)
)
self.increase_counter(counter_name)
diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py
index 7629a6886ab94..c0847a6a0c79a 100644
--- a/onnxruntime/python/tools/transformers/fusion_options.py
+++ b/onnxruntime/python/tools/transformers/fusion_options.py
@@ -51,10 +51,12 @@ def __init__(self, model_type):
)
# options for stable diffusion
- self.enable_group_norm = model_type in ["unet", "vae"]
- self.enable_bias_splitgelu = model_type in ["unet"]
- self.enable_packed_kv = model_type in ["unet"]
- self.enable_bias_add = model_type in ["unet"]
+ if model_type in ["unet", "vae", "clip"]:
+ self.enable_group_norm = True
+ self.enable_bias_splitgelu = True
+ self.enable_packed_qkv = True
+ self.enable_packed_kv = True
+ self.enable_bias_add = True
def use_raw_attention_mask(self, use_raw_mask=True):
if use_raw_mask:
@@ -96,10 +98,15 @@ def parse(args):
options.use_raw_attention_mask(True)
if args.no_attention_mask:
options.disable_attention_mask()
- if args.disable_group_norm:
- options.enable_group_norm = False
- if args.disable_packed_kv:
- options.enable_packed_kv = False
+
+ if args.model_type in ["unet", "vae", "clip"]:
+ if args.disable_group_norm:
+ options.enable_group_norm = False
+ if args.disable_packed_kv:
+ options.enable_packed_kv = False
+ if args.disable_packed_qkv:
+ options.enable_packed_qkv = False
+
return options
@staticmethod
@@ -222,7 +229,7 @@ def add_arguments(parser: ArgumentParser):
"--disable_group_norm",
required=False,
action="store_true",
- help="not fuse GroupNorm. Only works for model_type=unet",
+ help="not fuse GroupNorm. Only works for model_type=unet or vae",
)
parser.set_defaults(disable_group_norm=False)
@@ -230,6 +237,30 @@ def add_arguments(parser: ArgumentParser):
"--disable_packed_kv",
required=False,
action="store_true",
- help="not use packed kv in cross attention. Only works for model_type=unet",
+ help="not use packed kv for cross attention in MultiHeadAttention. Only works for model_type=unet",
)
parser.set_defaults(disable_packed_kv=False)
+
+ parser.add_argument(
+ "--disable_packed_qkv",
+ required=False,
+ action="store_true",
+ help="not use packed qkv for self attention in MultiHeadAttention. Only works for model_type=unet",
+ )
+ parser.set_defaults(disable_packed_qkv=False)
+
+ parser.add_argument(
+ "--disable_bias_add",
+ required=False,
+ action="store_true",
+ help="not fuse BiasAdd. Only works for model_type=unet",
+ )
+ parser.set_defaults(disable_bias_add=False)
+
+ parser.add_argument(
+ "--disable_bias_splitgelu",
+ required=False,
+ action="store_true",
+ help="not fuse BiasSplitGelu. Only works for model_type=unet",
+ )
+ parser.set_defaults(disable_bias_splitgelu=False)
diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py
index 31b2a22c2f615..40b0d5d858cb2 100644
--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py
+++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py
@@ -130,12 +130,13 @@ def optimize_sd_pipeline(
# Right now, onnxruntime does not save >2GB model so we use script to optimize unet instead.
logger.info(f"Optimize {onnx_model_path}...")
- # There are some optimizations that are not avaiable in v1.14 or older version
- has_all_optimizations = version.parse(onnxruntime.__version__) > version.parse("1.14.0")
-
fusion_options = FusionOptions(model_type)
- fusion_options.enable_packed_kv = float16
- fusion_options.enable_bias_add = has_all_optimizations
+ if model_type in ["unet"]:
+ # There are some optimizations that are not available in v1.14 or older version
+ has_all_optimizations = version.parse(onnxruntime.__version__) > version.parse("1.14.0")
+ fusion_options.enable_packed_kv = float16
+ fusion_options.enable_packed_qkv = float16 and has_all_optimizations
+ fusion_options.enable_bias_add = has_all_optimizations
m = optimize_model(
str(onnx_model_path),
diff --git a/onnxruntime/python/tools/transformers/onnx_model_unet.py b/onnxruntime/python/tools/transformers/onnx_model_unet.py
index b460d7d8a7f8f..4873489770757 100644
--- a/onnxruntime/python/tools/transformers/onnx_model_unet.py
+++ b/onnxruntime/python/tools/transformers/onnx_model_unet.py
@@ -91,12 +91,17 @@ def merge_adjacent_transpose(self):
def fuse_attention(self, options: Optional[FusionOptions] = None):
# Self Attention
- self_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, False, False)
+ enable_packed_qkv = (options is None) or options.enable_packed_qkv
+ self_attention_fusion = FusionAttentionUnet(
+ self, self.hidden_size, self.num_heads, False, enable_packed_qkv, False
+ )
self_attention_fusion.apply()
# Cross Attention
enable_packed_kv = (options is None) or options.enable_packed_kv
- cross_attention_fusion = FusionAttentionUnet(self, self.hidden_size, self.num_heads, True, enable_packed_kv)
+ cross_attention_fusion = FusionAttentionUnet(
+ self, self.hidden_size, self.num_heads, True, False, enable_packed_kv
+ )
cross_attention_fusion.apply()
def fuse_bias_add(self):
@@ -152,7 +157,8 @@ def optimize(self, options: Optional[FusionOptions] = None):
self.merge_adjacent_transpose()
- self.fuse_bias_add()
+ if options is not None and options.enable_bias_add:
+ self.fuse_bias_add()
self.postprocess()
diff --git a/onnxruntime/test/contrib_ops/attention_op_test_helper.cc b/onnxruntime/test/contrib_ops/attention_op_test_helper.cc
index 5c4c14ce137ed..802df2b1d4aa4 100644
--- a/onnxruntime/test/contrib_ops/attention_op_test_helper.cc
+++ b/onnxruntime/test/contrib_ops/attention_op_test_helper.cc
@@ -2164,8 +2164,7 @@ void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(AttentionTestData&
data.skip_kernel_types = {
AttentionKernelType::AttentionKernel_TrtFusedCrossAttention,
AttentionKernelType::AttentionKernel_TrtFusedAttention,
- AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention
- };
+ AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention};
{
data.query_data = {
@@ -2288,6 +2287,413 @@ void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(AttentionTestData&
data.fp16_output_data = data.fp32_output_data;
}
}
+
+void GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(AttentionTestData& data) {
+ data.hidden_size = 32;
+ data.v_hidden_size = 32;
+ data.num_heads = 1;
+ data.batch_size = 2;
+ data.sequence_length = 2;
+ data.kv_sequence_length = 3;
+ data.mask_type = AttentionMaskType::MASK_NONE;
+ // Packed KV format is only supported by TRT fused cross attention or memory efficient attention right now.
+ data.skip_kernel_types = {
+ AttentionKernelType::AttentionKernel_Unfused,
+ AttentionKernelType::AttentionKernel_TrtFusedAttention};
+ {
+ data.query_data = {
+ -0.35420692f, 1.31206024f, -2.80201197f, 2.42258096f, -0.86031514f, -1.44535458f, -0.10832444f, -2.00132895f,
+ 1.62475216f, 0.10978927f, 1.84596729f, 0.48908550f, 1.44369888f, 0.87542874f, -1.16434252f, 0.52133209f,
+ 1.54848897f, -2.21174526f, -0.28574878f, 0.70815033f, 1.18327498f, 3.14097571f, -0.25795099f, 1.89341247f,
+ -0.11603792f, 0.38110194f, 0.40873206f, -1.14149106f, 0.79770875f, -0.98069525f, -1.53588808f, 0.50821728f,
+ -2.21641898f, 0.55090773f, 0.80901796f, -0.56089771f, 0.03574468f, -1.27940118f, -0.02213959f, -0.80698186f,
+ -0.82701880f, 1.72937381f, 1.56083691f, -0.30311784f, -0.25183848f, 0.24280515f, 0.29569417f, -0.31162494f,
+ 0.48996922f, 0.22795241f, 2.07125854f, 1.45823467f, 3.03750706f, 1.53734803f, 0.48668906f, -1.63703632f,
+ -0.14114749f, 1.85963213f, 1.20729232f, -0.28972962f, -0.80783498f, -1.16619551f, -0.60004634f, 0.02498829f,
+
+ 3.50846076f, -2.50027657f, -2.59866142f, 1.58495271f, 2.21110034f, -2.74877763f, -1.00267041f, 0.62646407f,
+ 2.50227380f, -0.27291518f, -0.33037442f, 0.75840306f, 0.45437157f, -0.79876304f, 0.83509272f, 2.53716302f,
+ 0.01348384f, -2.16307616f, 2.01661849f, 2.10746121f, -1.70485222f, 1.35548759f, 1.39401650f, -0.99451691f,
+ -4.13484812f, 0.56262714f, -0.92725742f, -0.16389316f, -1.31260049f, 2.32357836f, -3.05251694f, -1.12570131f,
+ 1.87849474f, -1.80381167f, 0.52235699f, 2.38887334f, -1.58878529f, 0.69571090f, 1.65044296f, -0.27024290f,
+ 3.59580970f, -1.97888982f, 1.17034674f, 0.26716161f, -1.16770899f, 0.74609619f, 0.78886843f, 0.15717520f,
+ -0.93303132f, -0.84753871f, -4.32799959f, -1.94716609f, -1.16980326f, 1.62631667f, 2.41053247f, 3.78186774f,
+ 0.26432252f, -0.40396988f, 2.04414082f, 0.65150046f, 0.47777444f, -2.57569051f, 0.99004912f, 2.47947693f};
+ }
+ {
+ data.key_data = {
+ 2.66554713f, 0.04116637f, -1.14599442f, -1.99071956f, 0.42523879f, 0.94212061f, 1.15597987f, -1.76809072f,
+ -1.89803648f, -0.74707657f, -0.71960962f, 0.67453432f, 3.31946969f, 1.06201041f, 2.29824829f, 0.23788756f,
+ 1.69329333f, 0.06745748f, -1.34720469f, 1.81031406f, -0.33143526f, -2.46566057f, -0.32179555f, 1.69001770f,
+ -0.39678648f, -0.91400242f, 1.56746745f, 0.36029303f, -1.01637018f, -1.84069777f, 0.15860040f, 1.35965717f,
+ 0.16654867f, 3.63810396f, 2.03763342f, 0.64186901f, -1.02682137f, 2.18480039f, -2.17365599f, -0.56225222f,
+ -2.48764873f, 1.94031644f, -1.13630998f, -2.51891637f, -1.29985571f, 0.23808026f, 2.95523596f, 1.06378591f,
+ -0.20339361f, -0.56349581f, 1.46587682f, 4.12142849f, 0.78908098f, -0.24000889f, -1.15510166f, 0.42653239f,
+ -1.98345447f, 1.06918168f, 2.98073006f, -2.94872737f, -0.67443597f, -0.96227646f, -1.94805872f, -0.96003568f,
+ 1.06492281f, 0.32333452f, -0.52869648f, -1.25258100f, 0.75479198f, -1.04409528f, -1.81722605f, 0.99018478f,
+ 1.83352923f, 1.02711058f, 0.31064227f, 2.44383168f, -1.80332434f, 1.57207584f, -0.41058558f, 0.20494992f,
+ -0.78399467f, -0.35703743f, -0.67568171f, -1.30091023f, -0.17390330f, 0.22340816f, -0.44613233f, 1.23870432f,
+ -0.16092014f, -1.22258115f, 0.60575533f, -0.17969827f, 1.87851882f, 1.13991237f, -0.81591004f, -1.68899822f,
+
+ -1.72543812f, 0.63848293f, 0.87042624f, 0.39726460f, 0.62647510f, 1.73326159f, -0.55110240f, -1.26900804f,
+ 1.94843686f, -1.73077893f, 2.53475809f, 2.79892564f, -1.91852188f, 0.99826050f, -3.04680610f, 1.38900220f,
+ -1.17920876f, -2.07508063f, -0.34274688f, -0.24780962f, 1.75715542f, 1.27657294f, -1.15560341f, -2.69310951f,
+ 0.93523502f, 0.58213681f, -2.57009196f, 2.56376076f, 0.06911665f, 1.73962176f, 0.43307841f, -1.18240118f,
+ 1.52338290f, 1.02856898f, 0.40946901f, 1.57649779f, 1.22447217f, 0.85961932f, 0.30765539f, -2.66427660f,
+ -1.55998194f, -0.31161505f, -1.63090813f, -1.62476087f, 1.28381526f, -0.77024549f, 1.46711981f, -0.71657622f,
+ -0.51606011f, 0.87953311f, 0.26169056f, 1.03068113f, 0.41064253f, -1.56344402f, -1.53443003f, -0.03009570f,
+ -0.02123317f, -1.74375248f, 1.60988081f, 1.74488568f, 0.59155780f, -0.62032932f, 0.03105794f, 4.54175377f,
+ -2.08403850f, 0.22352570f, -0.17924348f, 0.65815634f, -0.59089363f, -1.66189861f, 0.75618476f, 0.03879535f,
+ 1.50222909f, -0.29873836f, -1.76075482f, -2.97067928f, 0.28112072f, 0.72105575f, 0.06761266f, -1.61681306f,
+ -0.80693424f, 2.40102959f, -2.91352296f, -1.21352315f, -1.62430143f, -1.60909438f, 0.53140688f, -0.28235722f,
+ 0.63271880f, 1.33791542f, -1.37593675f, -1.60502291f, 1.27470064f, -0.96280038f, 0.79614848f, 0.31894624f};
+ }
+
+ {
+ data.value_data = {
+ 0.67641568f, -1.44437671f, 0.57255918f, 0.11087912f, 0.73787844f, -1.36586773f,
+ 1.45507979f, -3.70331645f, -0.85970032f, -2.14545083f, 2.11107802f, -0.16663373f,
+ 1.47033095f, 0.49124131f, 1.99316287f, -2.68613410f, 0.23831765f, 0.90826637f,
+ 0.72628385f, 1.29567933f, -0.07918698f, 0.13999116f, 1.22531521f, 0.06399018f,
+ -2.24613571f, -1.08365369f, -0.68457615f, -0.25960952f, -0.88386559f, 0.46391147f,
+ 1.24469304f, 1.13121903f,
+ -0.21484625f, 3.49263334f, -1.35283577f, 0.38428289f, -4.29686069f, -4.34778786f,
+ -0.49574745f, -0.08637778f, -0.50855160f, -1.12334609f, -1.44851387f, 3.36797357f,
+ -0.91776383f, -0.98647243f, 1.45408130f, 0.29062888f, 0.24470398f, -1.28129590f,
+ 0.47530234f, 2.19562674f, 0.62674099f, -2.56222868f, -1.42671025f, 1.51795268f,
+ -1.92045701f, 1.20271325f, 2.53190184f, -0.37211552f, 0.92569226f, -1.11019444f,
+ 1.15402830f, -1.98479640f,
+ -0.49658760f, 1.62168694f, -1.71412969f, -1.26646388f, -1.37257946f, 1.53828073f,
+ -0.35583261f, 0.03810386f, 0.43514529f, 0.97525519f, -2.22109556f, 1.17547810f,
+ -0.28825673f, 0.91509271f, -1.19243717f, 1.09280133f, -0.51078367f, 0.63577116f,
+ -0.62186599f, -2.80234575f, -1.58007598f, 1.06965756f, -0.89327252f, -0.84735525f,
+ -0.46283475f, 0.77867299f, -0.07434830f, 1.44711912f, 1.07089376f, 0.78913736f,
+ 0.59053934f, -0.32160193f,
+
+ 0.51273453f, 1.12628150f, 1.96404183f, 0.26380035f, 3.41526699f, 1.08249199f,
+ -1.70347631f, 0.42854923f, -1.98269284f, 1.97382474f, -0.12164606f, -1.41219604f,
+ 0.01819625f, 0.73082930f, -2.60845804f, 1.47046185f, 0.26324001f, 1.54259276f,
+ -1.18744254f, -1.77539694f, 1.76547086f, -1.57072937f, -1.83995926f, -0.05529352f,
+ 1.83544660f, 0.69575423f, -0.03345531f, -1.69629955f, 0.04713173f, 1.39800107f,
+ 0.24362923f, 0.12432972f,
+ -2.92895460f, -0.46070760f, 0.20383459f, 1.93618548f, -1.08026588f, 1.08253515f,
+ -0.48318014f, -2.34334373f, -2.69622159f, 0.00661799f, -1.10738027f, 0.03181311f,
+ 0.32897863f, 1.89451993f, -0.01152946f, 0.17766151f, 2.46450090f, -0.64409554f,
+ 2.56058550f, 1.29339278f, 2.72114944f, 0.87801707f, -1.58970404f, 2.88365316f,
+ 0.46464550f, -1.71912467f, -1.90960062f, -3.13572145f, 0.19871379f, -0.28741950f,
+ -0.38167781f, -2.30705547f,
+ 0.64399612f, 0.32866889f, -3.49091625f, -0.02294427f, 1.60225844f, 1.83659923f,
+ 1.55193460f, -0.06712314f, 0.76592684f, 0.83479869f, 0.49627584f, 0.75736403f,
+ 0.75179487f, -0.32156041f, 1.36537170f, 0.57024354f, 0.36152276f, 0.93625057f,
+ -1.69728792f, -0.28833422f, 0.43304375f, 1.62640548f, -0.00187188f, 0.80429250f,
+ -0.77993584f, 1.37333393f, -1.16019452f, -0.91983509f, 0.20466281f, 1.09339333f,
+ -0.99191529f, 3.42685890f};
+ }
+
+ {
+ data.bias_data = {
+        // No bias in this test case. A bias input would contain
+        // hidden_size + hidden_size + v_hidden_size = 96 zeros.
+ };
+ }
+
+ {
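+    // Packed KV with shape (batch_size, kv_sequence_length, num_heads, 2, head_size);
+    // for each token and head, the key vector is stored before the value vector.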
+ data.kv_data = {
+ 2.66554713f, 0.04116637f, -1.14599442f, -1.99071956f, 0.42523879f, 0.94212061f,
+ 1.15597987f, -1.76809072f, -1.89803648f, -0.74707657f, -0.71960962f, 0.67453432f,
+ 3.31946969f, 1.06201041f, 2.29824829f, 0.23788756f, 1.69329333f, 0.06745748f,
+ -1.34720469f, 1.81031406f, -0.33143526f, -2.46566057f, -0.32179555f, 1.69001770f,
+ -0.39678648f, -0.91400242f, 1.56746745f, 0.36029303f, -1.01637018f, -1.84069777f,
+ 0.15860040f, 1.35965717f,
+ 0.67641568f, -1.44437671f, 0.57255918f, 0.11087912f, 0.73787844f, -1.36586773f,
+ 1.45507979f, -3.70331645f, -0.85970032f, -2.14545083f, 2.11107802f, -0.16663373f,
+ 1.47033095f, 0.49124131f, 1.99316287f, -2.68613410f, 0.23831765f, 0.90826637f,
+ 0.72628385f, 1.29567933f, -0.07918698f, 0.13999116f, 1.22531521f, 0.06399018f,
+ -2.24613571f, -1.08365369f, -0.68457615f, -0.25960952f, -0.88386559f, 0.46391147f,
+ 1.24469304f, 1.13121903f,
+
+ 0.16654867f, 3.63810396f, 2.03763342f, 0.64186901f, -1.02682137f, 2.18480039f,
+ -2.17365599f, -0.56225222f, -2.48764873f, 1.94031644f, -1.13630998f, -2.51891637f,
+ -1.29985571f, 0.23808026f, 2.95523596f, 1.06378591f, -0.20339361f, -0.56349581f,
+ 1.46587682f, 4.12142849f, 0.78908098f, -0.24000889f, -1.15510166f, 0.42653239f,
+ -1.98345447f, 1.06918168f, 2.98073006f, -2.94872737f, -0.67443597f, -0.96227646f,
+ -1.94805872f, -0.96003568f,
+ -0.21484625f, 3.49263334f, -1.35283577f, 0.38428289f, -4.29686069f, -4.34778786f,
+ -0.49574745f, -0.08637778f, -0.50855160f, -1.12334609f, -1.44851387f, 3.36797357f,
+ -0.91776383f, -0.98647243f, 1.45408130f, 0.29062888f, 0.24470398f, -1.28129590f,
+ 0.47530234f, 2.19562674f, 0.62674099f, -2.56222868f, -1.42671025f, 1.51795268f,
+ -1.92045701f, 1.20271325f, 2.53190184f, -0.37211552f, 0.92569226f, -1.11019444f,
+ 1.15402830f, -1.98479640f,
+
+ 1.06492281f, 0.32333452f, -0.52869648f, -1.25258100f, 0.75479198f, -1.04409528f,
+ -1.81722605f, 0.99018478f, 1.83352923f, 1.02711058f, 0.31064227f, 2.44383168f,
+ -1.80332434f, 1.57207584f, -0.41058558f, 0.20494992f, -0.78399467f, -0.35703743f,
+ -0.67568171f, -1.30091023f, -0.17390330f, 0.22340816f, -0.44613233f, 1.23870432f,
+ -0.16092014f, -1.22258115f, 0.60575533f, -0.17969827f, 1.87851882f, 1.13991237f,
+ -0.81591004f, -1.68899822f,
+ -0.49658760f, 1.62168694f, -1.71412969f, -1.26646388f, -1.37257946f, 1.53828073f,
+ -0.35583261f, 0.03810386f, 0.43514529f, 0.97525519f, -2.22109556f, 1.17547810f,
+ -0.28825673f, 0.91509271f, -1.19243717f, 1.09280133f, -0.51078367f, 0.63577116f,
+ -0.62186599f, -2.80234575f, -1.58007598f, 1.06965756f, -0.89327252f, -0.84735525f,
+ -0.46283475f, 0.77867299f, -0.07434830f, 1.44711912f, 1.07089376f, 0.78913736f,
+ 0.59053934f, -0.32160193f,
+
+ -1.72543812f, 0.63848293f, 0.87042624f, 0.39726460f, 0.62647510f, 1.73326159f,
+ -0.55110240f, -1.26900804f, 1.94843686f, -1.73077893f, 2.53475809f, 2.79892564f,
+ -1.91852188f, 0.99826050f, -3.04680610f, 1.38900220f, -1.17920876f, -2.07508063f,
+ -0.34274688f, -0.24780962f, 1.75715542f, 1.27657294f, -1.15560341f, -2.69310951f,
+ 0.93523502f, 0.58213681f, -2.57009196f, 2.56376076f, 0.06911665f, 1.73962176f,
+ 0.43307841f, -1.18240118f,
+ 0.51273453f, 1.12628150f, 1.96404183f, 0.26380035f, 3.41526699f, 1.08249199f,
+ -1.70347631f, 0.42854923f, -1.98269284f, 1.97382474f, -0.12164606f, -1.41219604f,
+ 0.01819625f, 0.73082930f, -2.60845804f, 1.47046185f, 0.26324001f, 1.54259276f,
+ -1.18744254f, -1.77539694f, 1.76547086f, -1.57072937f, -1.83995926f, -0.05529352f,
+ 1.83544660f, 0.69575423f, -0.03345531f, -1.69629955f, 0.04713173f, 1.39800107f,
+ 0.24362923f, 0.12432972f,
+
+ 1.52338290f, 1.02856898f, 0.40946901f, 1.57649779f, 1.22447217f, 0.85961932f,
+ 0.30765539f, -2.66427660f, -1.55998194f, -0.31161505f, -1.63090813f, -1.62476087f,
+ 1.28381526f, -0.77024549f, 1.46711981f, -0.71657622f, -0.51606011f, 0.87953311f,
+ 0.26169056f, 1.03068113f, 0.41064253f, -1.56344402f, -1.53443003f, -0.03009570f,
+ -0.02123317f, -1.74375248f, 1.60988081f, 1.74488568f, 0.59155780f, -0.62032932f,
+ 0.03105794f, 4.54175377f,
+ -2.92895460f, -0.46070760f, 0.20383459f, 1.93618548f, -1.08026588f, 1.08253515f,
+ -0.48318014f, -2.34334373f, -2.69622159f, 0.00661799f, -1.10738027f, 0.03181311f,
+ 0.32897863f, 1.89451993f, -0.01152946f, 0.17766151f, 2.46450090f, -0.64409554f,
+ 2.56058550f, 1.29339278f, 2.72114944f, 0.87801707f, -1.58970404f, 2.88365316f,
+ 0.46464550f, -1.71912467f, -1.90960062f, -3.13572145f, 0.19871379f, -0.28741950f,
+ -0.38167781f, -2.30705547f,
+
+ -2.08403850f, 0.22352570f, -0.17924348f, 0.65815634f, -0.59089363f, -1.66189861f,
+ 0.75618476f, 0.03879535f, 1.50222909f, -0.29873836f, -1.76075482f, -2.97067928f,
+ 0.28112072f, 0.72105575f, 0.06761266f, -1.61681306f, -0.80693424f, 2.40102959f,
+ -2.91352296f, -1.21352315f, -1.62430143f, -1.60909438f, 0.53140688f, -0.28235722f,
+ 0.63271880f, 1.33791542f, -1.37593675f, -1.60502291f, 1.27470064f, -0.96280038f,
+ 0.79614848f, 0.31894624f,
+ 0.64399612f, 0.32866889f, -3.49091625f, -0.02294427f, 1.60225844f, 1.83659923f,
+ 1.55193460f, -0.06712314f, 0.76592684f, 0.83479869f, 0.49627584f, 0.75736403f,
+ 0.75179487f, -0.32156041f, 1.36537170f, 0.57024354f, 0.36152276f, 0.93625057f,
+ -1.69728792f, -0.28833422f, 0.43304375f, 1.62640548f, -0.00187188f, 0.80429250f,
+ -0.77993584f, 1.37333393f, -1.16019452f, -0.91983509f, 0.20466281f, 1.09339333f,
+ -0.99191529f, 3.42685890f};
+ }
+
+ {
+ // Do not test fp32
+ data.fp32_output_data = {};
+ }
+
+ {
+ data.fp16_output_data = {
+ -0.18665725f, 1.53655565f, -1.16219902f, -0.53553712f, -1.76899862f, -0.67172408f,
+ -0.03719823f, -0.73519617f, -0.08289805f, -0.22439885f, -1.15095568f, 1.52012229f,
+ -0.11608444f, 0.30267856f, 0.17237782f, 0.12366229f, -0.15282108f, 0.15652999f,
+ -0.05062571f, -0.60356319f, -0.67014134f, -0.12373877f, -0.62331146f, -0.00974876f,
+ -1.22021353f, 0.52888882f, 0.52984023f, 0.60431194f, 0.64458221f, 0.19681633f,
+ 0.87637067f, -0.49721599f,
+
+ -0.21368809f, 3.48006248f, -1.34989023f, 0.38081154f, -4.28221798f, -4.33166409f,
+ -0.49185523f, -0.09290559f, -0.50751436f, -1.12148952f, -1.44325578f, 3.35744357f,
+ -0.91217732f, -0.98030341f, 1.45034039f, 0.28651160f, 0.24333692f, -1.27377033f,
+ 0.47380278f, 2.18498206f, 0.62146491f, -2.55067039f, -1.42080343f, 1.51099622f,
+ -1.91845036f, 1.19768524f, 2.52122355f, -0.36864227f, 0.92257524f, -1.10384953f,
+ 1.15318692f, -1.97599709f,
+
+ 0.25325853f, 0.99478984f, 1.75511229f, 0.38680723f, 3.04894686f, 1.09290445f,
+ -1.56589723f, 0.21126932f, -1.99892235f, 1.80875492f, -0.18795207f, -1.27262163f,
+ 0.05191655f, 0.80464834f, -2.35645056f, 1.35988820f, 0.43171301f, 1.36821294f,
+ -0.90993541f, -1.52189243f, 1.81963241f, -1.34069264f, -1.79558825f, 0.17969209f,
+ 1.69527614f, 0.52177316f, -0.19144230f, -1.79486036f, 0.06081408f, 1.26584184f,
+ 0.17910211f, -0.01467115f,
+
+ 0.15915775f, 0.23589413f, -2.89726520f, 0.24662971f, 1.27141249f, 1.72167253f,
+ 1.22059381f, -0.36594886f, 0.25064397f, 0.74270636f, 0.26896626f, 0.62173104f,
+ 0.68196213f, -0.00399938f, 1.11046481f, 0.53283978f, 0.64384484f, 0.73332942f,
+ -1.11337614f, -0.10050645f, 0.76519096f, 1.46986043f, -0.24821334f, 1.07021677f,
+ -0.56646812f, 0.94391233f, -1.24186087f, -1.23258281f, 0.20112629f, 0.91218621f,
+ -0.88806105f, 2.59514260f};
+ }
+}
+
+void GetSelfAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedQKV(AttentionTestData& data) {
+ data.hidden_size = 32;
+ data.v_hidden_size = 32;
+ data.num_heads = 1;
+ data.batch_size = 2;
+ data.sequence_length = 2;
+ data.kv_sequence_length = 2;
+ data.mask_type = AttentionMaskType::MASK_NONE;
+  // Packed QKV format is currently only supported by the TRT fused attention and memory efficient attention kernels.
+ data.skip_kernel_types = {
+ AttentionKernelType::AttentionKernel_Unfused,
+ AttentionKernelType::AttentionKernel_TrtFusedCrossAttention};
+
+ {
+ data.query_data = {
+ -0.35420692f, 1.31206024f, -2.80201197f, 2.42258096f, -0.86031514f, -1.44535458f, -0.10832444f, -2.00132895f,
+ 1.62475216f, 0.10978927f, 1.84596729f, 0.48908550f, 1.44369888f, 0.87542874f, -1.16434252f, 0.52133209f,
+ 1.54848897f, -2.21174526f, -0.28574878f, 0.70815033f, 1.18327498f, 3.14097571f, -0.25795099f, 1.89341247f,
+ -0.11603792f, 0.38110194f, 0.40873206f, -1.14149106f, 0.79770875f, -0.98069525f, -1.53588808f, 0.50821728f,
+ -2.21641898f, 0.55090773f, 0.80901796f, -0.56089771f, 0.03574468f, -1.27940118f, -0.02213959f, -0.80698186f,
+ -0.82701880f, 1.72937381f, 1.56083691f, -0.30311784f, -0.25183848f, 0.24280515f, 0.29569417f, -0.31162494f,
+ 0.48996922f, 0.22795241f, 2.07125854f, 1.45823467f, 3.03750706f, 1.53734803f, 0.48668906f, -1.63703632f,
+ -0.14114749f, 1.85963213f, 1.20729232f, -0.28972962f, -0.80783498f, -1.16619551f, -0.60004634f, 0.02498829f,
+
+ 3.50846076f, -2.50027657f, -2.59866142f, 1.58495271f, 2.21110034f, -2.74877763f, -1.00267041f, 0.62646407f,
+ 2.50227380f, -0.27291518f, -0.33037442f, 0.75840306f, 0.45437157f, -0.79876304f, 0.83509272f, 2.53716302f,
+ 0.01348384f, -2.16307616f, 2.01661849f, 2.10746121f, -1.70485222f, 1.35548759f, 1.39401650f, -0.99451691f,
+ -4.13484812f, 0.56262714f, -0.92725742f, -0.16389316f, -1.31260049f, 2.32357836f, -3.05251694f, -1.12570131f,
+ 1.87849474f, -1.80381167f, 0.52235699f, 2.38887334f, -1.58878529f, 0.69571090f, 1.65044296f, -0.27024290f,
+ 3.59580970f, -1.97888982f, 1.17034674f, 0.26716161f, -1.16770899f, 0.74609619f, 0.78886843f, 0.15717520f,
+ -0.93303132f, -0.84753871f, -4.32799959f, -1.94716609f, -1.16980326f, 1.62631667f, 2.41053247f, 3.78186774f,
+ 0.26432252f, -0.40396988f, 2.04414082f, 0.65150046f, 0.47777444f, -2.57569051f, 0.99004912f, 2.47947693f};
+ }
+ {
+ data.key_data = {
+ -0.04407793f, 1.29459429f, 1.05810797f, 1.92067695f, -0.65047157f, 0.99029726f, -1.69796586f, 1.15320420f,
+ -1.66444266f, 1.78305888f, 1.20582056f, 1.69975281f, 0.34572244f, -0.60833001f, 2.59864879f, -1.05330181f,
+ -1.16554165f, -0.03781542f, -1.13475525f, 0.71595150f, -0.91169560f, 1.26686060f, 1.60492957f, -0.53510487f,
+ -1.40180850f, 1.83253956f, 2.70238972f, -1.48750985f, 0.47105616f, -0.79477602f, -1.93152475f, 1.04042351f,
+ 1.21863425f, 1.20610654f, 0.69031805f, 2.60092020f, 1.43040228f, 0.60616529f, 0.47948456f, -1.15139377f,
+ 0.15641990f, -0.46933329f, 0.64774191f, 0.35970241f, -1.00424135f, 0.01247875f, 1.00281739f, -1.10514688f,
+ 0.30922988f, -0.82255656f, -1.23242986f, 0.90557313f, -0.38946581f, -0.21124774f, -2.37903309f, -1.42169905f,
+ -0.05935127f, 0.47488672f, -0.37083727f, 1.31585515f, -0.21577421f, -0.97746384f, -0.13380399f, 1.77390409f,
+
+ -2.65206385f, 1.26134932f, -1.01682174f, 0.64366758f, 0.95474619f, 2.06720352f, 0.51750720f, -0.07041813f,
+ 0.53124994f, -3.26612782f, 1.37013340f, 0.13939659f, -0.57418114f, 0.80680281f, -3.40751696f, -0.15847699f,
+ 0.97837782f, -0.09121911f, 1.18452120f, 0.52711177f, -1.86135840f, -0.11258313f, 0.85863215f, -2.60261130f,
+ 0.72695309f, 1.44092011f, 0.43785980f, -1.63415265f, -1.05772328f, 0.12997569f, 0.07356137f, -0.62493324f,
+ -0.43267637f, -1.80009198f, 0.92961007f, 2.05127883f, -2.85521173f, -0.21652693f, -0.89153922f, 0.15524670f,
+ -2.16850328f, 1.46751809f, 2.51663852f, -0.49499366f, 0.19886012f, 0.77093124f, -1.14819765f, 1.47111738f,
+ 2.42824388f, 1.56369960f, 1.69934130f, -0.42460468f, -2.25951004f, -1.18074155f, 3.51091242f, -0.30183151f,
+ -1.83517075f, -0.56233191f, 2.35561657f, -3.63751698f, -3.20001125f, -1.66120780f, 3.23455381f, -1.86251283f};
+ }
+
+ {
+ data.value_data = {
+ -0.89167893f, 0.02633595f, -0.84866279f, 1.43489110f, -2.91941142f, -0.20650116f, 1.85965109f, 0.45669034f,
+ 0.07678832f, 0.04492294f, 0.67326981f, 0.97103029f, 1.53470886f, -1.10242307f, 0.86584085f, -0.34770033f,
+ -1.24311507f, -1.80293822f, -1.01317739f, -0.71518499f, 0.77814674f, -0.59236068f, -2.00310278f, 3.13277125f,
+ -1.20754123f, 2.01506066f, 0.82650810f, 2.06084490f, -0.46267471f, 1.56365979f, 4.31514502f, -1.03099275f,
+ -1.85462761f, 2.10100341f, 1.79686451f, 0.23871201f, 1.23598254f, -0.31959364f, 0.50101948f, -0.09527110f,
+ -1.02331078f, 0.16319990f, -0.54766160f, 0.41597658f, -0.52141404f, 1.71663237f, -0.00776333f, -0.68160462f,
+ 1.76272714f, -0.04465733f, 0.28247434f, 1.69360149f, 0.14144623f, 0.75038731f, -1.33337545f, 2.23457718f,
+ -0.07649468f, 1.97064841f, -1.85374629f, -1.59334683f, 0.32698441f, -0.16024286f, 2.02828407f, -0.96440399f,
+
+ -2.11639142f, -1.50897706f, 1.63863683f, 2.32786226f, 1.32746494f, 0.75751448f, 0.57184196f, 0.86446053f,
+ -0.62406683f, 0.78861046f, 0.01044065f, 3.51772785f, -1.33701336f, 0.27977663f, -0.35464612f, 0.74973166f,
+ 0.03352100f, 1.55007398f, 0.69849420f, -2.47725606f, -1.89363778f, -1.79874682f, -0.56210291f, -1.75556040f,
+ 1.07565808f, -0.18023658f, 1.63777173f, 1.28198206f, 2.19431949f, 0.67998970f, -0.52531999f, -1.89906740f,
+ 1.35158050f, -2.21481490f, -0.11812399f, -1.74263430f, -0.57895988f, -0.04181165f, 0.78120053f, -2.22377038f,
+ -0.53264999f, -2.03721714f, 0.21023634f, 2.55751204f, -1.04522800f, 0.85386503f, 0.41594937f, -2.98181081f,
+ 1.14034331f, -1.41539204f, 0.13379651f, 3.47018123f, 1.53924727f, 1.50004411f, 2.87318921f, 1.62624204f,
+ 0.64942807f, -4.54302311f, -1.50294220f, -1.75212634f, 0.27900690f, -3.05124855f, 3.30960631f, -0.07991691f};
+ }
+
+ {
+ data.bias_data = {
+        // No bias in this test case. A bias input would contain
+        // hidden_size + hidden_size + v_hidden_size = 96 zeros.
+ };
+ }
+
+ {
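+    // Packed QKV with shape (batch_size, sequence_length, num_heads, 3, head_size);
+    // for each token and head, the query, key and value vectors are stored consecutively.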
+ data.qkv_data = {
+ -0.35420692f, 1.31206024f, -2.80201197f, 2.42258096f, -0.86031514f, -1.44535458f, -0.10832444f, -2.00132895f,
+ 1.62475216f, 0.10978927f, 1.84596729f, 0.48908550f, 1.44369888f, 0.87542874f, -1.16434252f, 0.52133209f,
+ 1.54848897f, -2.21174526f, -0.28574878f, 0.70815033f, 1.18327498f, 3.14097571f, -0.25795099f, 1.89341247f,
+ -0.11603792f, 0.38110194f, 0.40873206f, -1.14149106f, 0.79770875f, -0.98069525f, -1.53588808f, 0.50821728f,
+ -0.04407793f, 1.29459429f, 1.05810797f, 1.92067695f, -0.65047157f, 0.99029726f, -1.69796586f, 1.15320420f,
+ -1.66444266f, 1.78305888f, 1.20582056f, 1.69975281f, 0.34572244f, -0.60833001f, 2.59864879f, -1.05330181f,
+ -1.16554165f, -0.03781542f, -1.13475525f, 0.71595150f, -0.91169560f, 1.26686060f, 1.60492957f, -0.53510487f,
+ -1.40180850f, 1.83253956f, 2.70238972f, -1.48750985f, 0.47105616f, -0.79477602f, -1.93152475f, 1.04042351f,
+ -0.89167893f, 0.02633595f, -0.84866279f, 1.43489110f, -2.91941142f, -0.20650116f, 1.85965109f, 0.45669034f,
+ 0.07678832f, 0.04492294f, 0.67326981f, 0.97103029f, 1.53470886f, -1.10242307f, 0.86584085f, -0.34770033f,
+ -1.24311507f, -1.80293822f, -1.01317739f, -0.71518499f, 0.77814674f, -0.59236068f, -2.00310278f, 3.13277125f,
+ -1.20754123f, 2.01506066f, 0.82650810f, 2.06084490f, -0.46267471f, 1.56365979f, 4.31514502f, -1.03099275f,
+
+ -2.21641898f, 0.55090773f, 0.80901796f, -0.56089771f, 0.03574468f, -1.27940118f, -0.02213959f, -0.80698186f,
+ -0.82701880f, 1.72937381f, 1.56083691f, -0.30311784f, -0.25183848f, 0.24280515f, 0.29569417f, -0.31162494f,
+ 0.48996922f, 0.22795241f, 2.07125854f, 1.45823467f, 3.03750706f, 1.53734803f, 0.48668906f, -1.63703632f,
+ -0.14114749f, 1.85963213f, 1.20729232f, -0.28972962f, -0.80783498f, -1.16619551f, -0.60004634f, 0.02498829f,
+ 1.21863425f, 1.20610654f, 0.69031805f, 2.60092020f, 1.43040228f, 0.60616529f, 0.47948456f, -1.15139377f,
+ 0.15641990f, -0.46933329f, 0.64774191f, 0.35970241f, -1.00424135f, 0.01247875f, 1.00281739f, -1.10514688f,
+ 0.30922988f, -0.82255656f, -1.23242986f, 0.90557313f, -0.38946581f, -0.21124774f, -2.37903309f, -1.42169905f,
+ -0.05935127f, 0.47488672f, -0.37083727f, 1.31585515f, -0.21577421f, -0.97746384f, -0.13380399f, 1.77390409f,
+ -1.85462761f, 2.10100341f, 1.79686451f, 0.23871201f, 1.23598254f, -0.31959364f, 0.50101948f, -0.09527110f,
+ -1.02331078f, 0.16319990f, -0.54766160f, 0.41597658f, -0.52141404f, 1.71663237f, -0.00776333f, -0.68160462f,
+ 1.76272714f, -0.04465733f, 0.28247434f, 1.69360149f, 0.14144623f, 0.75038731f, -1.33337545f, 2.23457718f,
+ -0.07649468f, 1.97064841f, -1.85374629f, -1.59334683f, 0.32698441f, -0.16024286f, 2.02828407f, -0.96440399f,
+
+ 3.50846076f, -2.50027657f, -2.59866142f, 1.58495271f, 2.21110034f, -2.74877763f, -1.00267041f, 0.62646407f,
+ 2.50227380f, -0.27291518f, -0.33037442f, 0.75840306f, 0.45437157f, -0.79876304f, 0.83509272f, 2.53716302f,
+ 0.01348384f, -2.16307616f, 2.01661849f, 2.10746121f, -1.70485222f, 1.35548759f, 1.39401650f, -0.99451691f,
+ -4.13484812f, 0.56262714f, -0.92725742f, -0.16389316f, -1.31260049f, 2.32357836f, -3.05251694f, -1.12570131f,
+ -2.65206385f, 1.26134932f, -1.01682174f, 0.64366758f, 0.95474619f, 2.06720352f, 0.51750720f, -0.07041813f,
+ 0.53124994f, -3.26612782f, 1.37013340f, 0.13939659f, -0.57418114f, 0.80680281f, -3.40751696f, -0.15847699f,
+ 0.97837782f, -0.09121911f, 1.18452120f, 0.52711177f, -1.86135840f, -0.11258313f, 0.85863215f, -2.60261130f,
+ 0.72695309f, 1.44092011f, 0.43785980f, -1.63415265f, -1.05772328f, 0.12997569f, 0.07356137f, -0.62493324f,
+ -2.11639142f, -1.50897706f, 1.63863683f, 2.32786226f, 1.32746494f, 0.75751448f, 0.57184196f, 0.86446053f,
+ -0.62406683f, 0.78861046f, 0.01044065f, 3.51772785f, -1.33701336f, 0.27977663f, -0.35464612f, 0.74973166f,
+ 0.03352100f, 1.55007398f, 0.69849420f, -2.47725606f, -1.89363778f, -1.79874682f, -0.56210291f, -1.75556040f,
+ 1.07565808f, -0.18023658f, 1.63777173f, 1.28198206f, 2.19431949f, 0.67998970f, -0.52531999f, -1.89906740f,
+
+ 1.87849474f, -1.80381167f, 0.52235699f, 2.38887334f, -1.58878529f, 0.69571090f, 1.65044296f, -0.27024290f,
+ 3.59580970f, -1.97888982f, 1.17034674f, 0.26716161f, -1.16770899f, 0.74609619f, 0.78886843f, 0.15717520f,
+ -0.93303132f, -0.84753871f, -4.32799959f, -1.94716609f, -1.16980326f, 1.62631667f, 2.41053247f, 3.78186774f,
+ 0.26432252f, -0.40396988f, 2.04414082f, 0.65150046f, 0.47777444f, -2.57569051f, 0.99004912f, 2.47947693f,
+ -0.43267637f, -1.80009198f, 0.92961007f, 2.05127883f, -2.85521173f, -0.21652693f, -0.89153922f, 0.15524670f,
+ -2.16850328f, 1.46751809f, 2.51663852f, -0.49499366f, 0.19886012f, 0.77093124f, -1.14819765f, 1.47111738f,
+ 2.42824388f, 1.56369960f, 1.69934130f, -0.42460468f, -2.25951004f, -1.18074155f, 3.51091242f, -0.30183151f,
+ -1.83517075f, -0.56233191f, 2.35561657f, -3.63751698f, -3.20001125f, -1.66120780f, 3.23455381f, -1.86251283f,
+ 1.35158050f, -2.21481490f, -0.11812399f, -1.74263430f, -0.57895988f, -0.04181165f, 0.78120053f, -2.22377038f,
+ -0.53264999f, -2.03721714f, 0.21023634f, 2.55751204f, -1.04522800f, 0.85386503f, 0.41594937f, -2.98181081f,
+ 1.14034331f, -1.41539204f, 0.13379651f, 3.47018123f, 1.53924727f, 1.50004411f, 2.87318921f, 1.62624204f,
+ 0.64942807f, -4.54302311f, -1.50294220f, -1.75212634f, 0.27900690f, -3.05124855f, 3.30960631f, -0.07991691f};
+ }
+
+ {
+ // Do not test fp32
+ data.fp32_output_data = {};
+ }
+
+ {
+ data.fp16_output_data = {
+ -1.30247164f, 0.91138631f, 0.27991560f, 0.92460269f, -1.14672589f, -0.25474626f,
+ 1.28006065f, 0.22122431f, -0.39251250f, 0.09537974f, 0.15242209f, 0.73424512f,
+ 0.65756959f, 0.10018224f, 0.49316248f, -0.49014348f, 0.03917319f, -1.05285788f,
+ -0.46045411f, 0.31240013f, 0.50653118f, -0.01954618f, -1.71739793f, 2.74960279f,
+ -0.72503829f, 1.99611449f, -0.31688485f, 0.50197154f, -0.12580720f, 0.82824522f,
+ 3.33957314f, -1.00258613f,
+
+ -0.95444643f, 0.16156809f, -0.67622054f, 1.35692120f, -2.64855242f, -0.21387282f,
+ 1.77109206f, 0.42071211f, 0.00508105f, 0.05263254f, 0.59368640f, 0.93485051f,
+ 1.40068567f, -0.91866994f, 0.80889714f, -0.36946508f, -1.04718661f, -1.68832910f,
+ -0.92872351f, -0.55817413f, 0.73664504f, -0.50483698f, -1.95944834f, 3.07422471f,
+ -1.13381684f, 2.01216578f, 0.65180230f, 1.82265544f, -0.41120273f, 1.45129156f,
+ 4.16608191f, -1.02665234f,
+
+ 0.21158576f, -1.98279130f, 0.45935997f, -0.40457720f, 0.04772174f, 0.22094353f,
+ 0.71238005f, -1.20860445f, -0.56270063f, -1.10830855f, 0.14455934f, 2.87315512f,
+ -1.14114404f, 0.66515017f, 0.16263856f, -1.75517499f, 0.77650774f, -0.44058144f,
+ 0.31942442f, 1.51513433f, 0.41078627f, 0.41566271f, 1.74393702f, 0.51457298f,
+ 0.78953874f, -3.10888410f, -0.47052401f, -0.75475156f, 0.90861011f, -1.82471263f,
+ 2.04898596f, -0.67790967f,
+
+ 1.17194295f, -2.17825341f, -0.02712549f, -1.53178656f, -0.48020893f, -0.00040733f,
+ 0.77035600f, -2.06380320f, -0.53738528f, -1.89084208f, 0.19988713f, 2.60725045f,
+ -1.06034219f, 0.82412785f, 0.37603328f, -2.78852081f, 1.08301103f, -1.26178384f,
+ 0.16304730f, 3.16210985f, 1.36142719f, 1.32916999f, 2.69524455f, 1.45106804f,
+ 0.67150640f, -4.31703520f, -1.34025633f, -1.59496248f, 0.37821823f, -2.85797405f,
+ 3.11096096f, -0.17414713f};
+ }
+}
#endif
void GetCrossAttentionData_HeadSize16_8(AttentionTestData& data) {
diff --git a/onnxruntime/test/contrib_ops/attention_op_test_helper.h b/onnxruntime/test/contrib_ops/attention_op_test_helper.h
index 664fbb50aa6d7..3a00b08b45318 100644
--- a/onnxruntime/test/contrib_ops/attention_op_test_helper.h
+++ b/onnxruntime/test/contrib_ops/attention_op_test_helper.h
@@ -5,38 +5,45 @@
#include "contrib_ops/cpu/bert/attention_common.h"
namespace onnxruntime {
-using contrib::AttentionMaskType;
using contrib::AttentionKernelType;
+using contrib::AttentionMaskType;
namespace test {
-struct AttentionTestData{
- int hidden_size;
- int v_hidden_size;
- int num_heads;
- int batch_size;
- int sequence_length;
- int kv_sequence_length;
- AttentionMaskType mask_type;
-  std::vector<int32_t> key_padding_mask_data;
-  std::vector<float> query_data;
-  std::vector<float> key_data;
-  std::vector<float> value_data;
-  std::vector<float> bias_data;
-  std::vector<float> fp32_output_data;
-  std::vector<float> fp16_output_data;
-  std::vector<AttentionKernelType> skip_kernel_types;  // skip some kernels if they do not supported this test case.
+struct AttentionTestData {
+ int hidden_size;
+ int v_hidden_size;
+ int num_heads;
+ int batch_size;
+ int sequence_length;
+ int kv_sequence_length;
+ AttentionMaskType mask_type;
+  std::vector<int32_t> key_padding_mask_data;
+  std::vector<float> query_data;
+  std::vector<float> key_data;
+  std::vector<float> value_data;
+
+  std::vector<float> kv_data;
+  std::vector<float> qkv_data;
+
+  std::vector<float> bias_data;
+  std::vector<float> fp32_output_data;
+  std::vector<float> fp16_output_data;
+  std::vector<AttentionKernelType> skip_kernel_types;  // skip some kernels that do not support this test case.
};
// Disable some tests in Windows since prefast build might crash with large test data.
#ifndef _MSC_VER
// Return packed weights and bias for input projection.
-void GetAttentionWeight(std::vector<float>& weight_data, int elements = 64 * 3 * 64, int offset = 0, int step=1);
-void GetAttentionBias(std::vector<float>& bias_data, int elements = 3 * 64, int offset = 0, int step=1);
+void GetAttentionWeight(std::vector<float>& weight_data, int elements = 64 * 3 * 64, int offset = 0, int step = 1);
+void GetAttentionBias(std::vector<float>& bias_data, int elements = 3 * 64, int offset = 0, int step = 1);
void GetCrossAttentionData_HeadSize40(AttentionTestData& data);
void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(AttentionTestData& data, bool is_mask_1d);
void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(AttentionTestData& data);
+
+void GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(AttentionTestData& data);
+void GetSelfAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedQKV(AttentionTestData& data);
#endif
void GetCrossAttentionData_HeadSize16_8(AttentionTestData& data);
diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc
index eddf9ea4c7c4e..83b9f628655ab 100644
--- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc
@@ -16,7 +16,9 @@ static void RunMultiHeadAttentionTest(
     const std::vector<float>& query_data,              // query: [batch_size, sequence_length, hidden_size]
     const std::vector<float>& key_data,                // key: [batch_size, kv_sequence_length, hidden_size]
     const std::vector<float>& value_data,              // value: [batch_size, kv_sequence_length, v_hidden_size]
-    const std::vector<float>& bias_data,               // bias: [hidden_size + hidden_size + v_hidden_size]
+    const std::vector<float>& kv_data,                 // packed_kv: [batch_size, kv_sequence_length, num_heads, 2, head_size]
+    const std::vector<float>& qkv_data,                // packed_qkv: [batch_size, sequence_length, num_heads, 3, head_size]
+    const std::vector<float>& bias_data,               // bias: [hidden_size + hidden_size + v_hidden_size] or empty
     const std::vector<int32_t>& key_padding_mask_data, // key_padding_mask: see below
AttentionMaskType mask_type, // 1 for [batch_size], 2 for [batch_size, kv_sequence_length]
     const std::vector<float>& output_data,             // output: [batch_size, sequence_length, v_hidden_size]
@@ -33,7 +35,7 @@ static void RunMultiHeadAttentionTest(
{
kv_sequence_length = (kv_sequence_length == 0 ? sequence_length : kv_sequence_length);
- int min_cuda_architecture = use_float16 ? 530 : 0;
+ int min_cuda_architecture = use_float16 ? 750 : 0;
bool enable_cuda = HasCudaEnvironment(min_cuda_architecture) && !disable_cuda;
bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()) && !disable_rocm;
bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()) && !use_float16 && !disable_cpu;
@@ -49,6 +51,23 @@ static void RunMultiHeadAttentionTest(
   std::vector<int64_t> bias_dims = {hidden_size + hidden_size + v_hidden_size};
   std::vector<int64_t> output_dims = {batch_size, sequence_length, v_hidden_size};
+  std::vector<float> query = (qkv_data.size() > 0 ? qkv_data : query_data);
+  std::vector<float> key;
+  std::vector<float> value;
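+  // Packed QKV is passed through the 'query' input; packed KV is passed through the 'key' input,
+  // leaving 'value' empty so that it becomes an optional input edge.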
+ if (qkv_data.size() == 0) {
+ if (kv_data.size() > 0) {
+ ORT_ENFORCE(hidden_size == v_hidden_size);
+ key = kv_data;
+ key_dims = {batch_size, kv_sequence_length, num_heads, 2, hidden_size / num_heads};
+ } else {
+ key = key_data;
+ value = value_data;
+ }
+ } else {
+ ORT_ENFORCE(sequence_length == kv_sequence_length && hidden_size == v_hidden_size);
+ query_dims = {batch_size, sequence_length, num_heads, 3, hidden_size / num_heads};
+ }
+
   std::vector<int64_t> mask_dims_1 = {batch_size};
   std::vector<int64_t> mask_dims_2 = {batch_size, kv_sequence_length};
   std::vector<int64_t>& key_padding_mask_dims = (mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN)
@@ -56,9 +75,19 @@ static void RunMultiHeadAttentionTest(
: mask_dims_2;
if (use_float16) {
- tester.AddInput("query", query_dims, ToFloat16(query_data));
- tester.AddInput("key", key_dims, ToFloat16(key_data));
- tester.AddInput("value", value_dims, ToFloat16(value_data));
+ tester.AddInput("query", query_dims, ToFloat16(query));
+
+ if (key.size()) {
+ tester.AddInput("key", key_dims, ToFloat16(key));
+ } else {
+      tester.AddOptionalInputEdge<MLFloat16>();
+ }
+
+ if (value.size()) {
+ tester.AddInput("value", value_dims, ToFloat16(value));
+ } else {
+      tester.AddOptionalInputEdge<MLFloat16>();
+ }
if (bias_data.size()) {
tester.AddInput("bias", bias_dims, ToFloat16(bias_data));
@@ -76,9 +105,19 @@ static void RunMultiHeadAttentionTest(
constexpr float abs_error = 0.05f;
tester.AddOutput("output", output_dims, ToFloat16(output_data), /*sort*/ false, rel_error, abs_error);
} else {
- tester.AddInput("query", query_dims, query_data);
- tester.AddInput("key", key_dims, key_data);
- tester.AddInput("value", value_dims, value_data);
+ tester.AddInput("query", query_dims, query);
+
+ if (key.size()) {
+ tester.AddInput("key", key_dims, key);
+ } else {
+      tester.AddOptionalInputEdge<float>();
+ }
+
+ if (value.size()) {
+ tester.AddInput("value", value_dims, value);
+ } else {
+      tester.AddOptionalInputEdge<float>();
+ }
if (bias_data.size()) {
tester.AddInput("bias", bias_dims, bias_data);
@@ -121,6 +160,8 @@ static void RunMultiHeadAttentionKernel(
     const std::vector<float>& query_data,              // query: [batch_size, sequence_length, hidden_size]
     const std::vector<float>& key_data,                // key: [batch_size, kv_sequence_length, hidden_size]
     const std::vector<float>& value_data,              // value: [batch_size, kv_sequence_length, v_hidden_size]
+    const std::vector<float>& kv_data,                 // packed_kv: [batch_size, kv_sequence_length, num_heads, 2, head_size]
+    const std::vector<float>& qkv_data,                // packed_qkv: [batch_size, sequence_length, num_heads, 3, head_size]
     const std::vector<float>& bias_data,               // bias: [hidden_size + hidden_size + v_hidden_size]
     const std::vector<int32_t>& key_padding_mask_data, // key_padding_mask: see below
AttentionMaskType mask_type, // 1 for [batch_size], 2 for [batch_size, kv_sequence_length]
@@ -144,7 +185,7 @@ static void RunMultiHeadAttentionKernel(
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"},
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}}};
RunMultiHeadAttentionTest(
- query_data, key_data, value_data, bias_data, key_padding_mask_data, mask_type, output_data,
+ query_data, key_data, value_data, kv_data, qkv_data, bias_data, key_padding_mask_data, mask_type, output_data,
num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size,
use_float16, disable_cpu, disable_cuda, disable_rocm);
return;
@@ -158,7 +199,7 @@ static void RunMultiHeadAttentionKernel(
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}};
RunMultiHeadAttentionTest(
- query_data, key_data, value_data, bias_data, key_padding_mask_data, mask_type, output_data,
+ query_data, key_data, value_data, kv_data, qkv_data, bias_data, key_padding_mask_data, mask_type, output_data,
num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size,
use_float16, disable_cpu, disable_cuda, disable_rocm);
return;
@@ -172,7 +213,7 @@ static void RunMultiHeadAttentionKernel(
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"},
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}};
RunMultiHeadAttentionTest(
- query_data, key_data, value_data, bias_data, key_padding_mask_data, mask_type, output_data,
+ query_data, key_data, value_data, kv_data, qkv_data, bias_data, key_padding_mask_data, mask_type, output_data,
num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size,
use_float16, disable_cpu, disable_cuda, disable_rocm);
return;
@@ -187,7 +228,7 @@ static void RunMultiHeadAttentionKernel(
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}}};
RunMultiHeadAttentionTest(
- query_data, key_data, value_data, bias_data, key_padding_mask_data, mask_type, output_data,
+ query_data, key_data, value_data, kv_data, qkv_data, bias_data, key_padding_mask_data, mask_type, output_data,
num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size,
use_float16, disable_cpu, disable_cuda, disable_rocm);
return;
@@ -202,7 +243,7 @@ static void RunMultiHeadAttentionKernel(
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}};
RunMultiHeadAttentionTest(
- query_data, key_data, value_data, bias_data, key_padding_mask_data, mask_type, output_data,
+ query_data, key_data, value_data, kv_data, qkv_data, bias_data, key_padding_mask_data, mask_type, output_data,
num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size,
use_float16, disable_cpu, disable_cuda, disable_rocm);
}
@@ -215,7 +256,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data) {
AttentionKernelType kernel_type = AttentionKernelType::AttentionKernel_Unfused;
if (!SkipAttentionKernel(data, kernel_type)) {
RunMultiHeadAttentionKernel(
- data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type,
+ data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, data.key_padding_mask_data, data.mask_type,
data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length,
data.hidden_size, data.v_hidden_size, kernel_type, use_float16);
}
@@ -226,7 +267,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data) {
kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention;
if (!SkipAttentionKernel(data, kernel_type)) {
RunMultiHeadAttentionKernel(
- data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type,
+ data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, data.key_padding_mask_data, data.mask_type,
data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length,
data.hidden_size, data.v_hidden_size, kernel_type, use_float16);
}
@@ -235,7 +276,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data) {
kernel_type = AttentionKernelType::AttentionKernel_Default;
RunMultiHeadAttentionKernel(
- data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type,
+ data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, data.key_padding_mask_data, data.mask_type,
data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length,
data.hidden_size, data.v_hidden_size, kernel_type, use_float16);
}
@@ -245,7 +286,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data) {
AttentionKernelType kernel_type = AttentionKernelType::AttentionKernel_TrtFusedCrossAttention;
if (!SkipAttentionKernel(data, kernel_type)) {
RunMultiHeadAttentionKernel(
- data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type,
+ data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, data.key_padding_mask_data, data.mask_type,
data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length,
data.hidden_size, data.v_hidden_size, kernel_type, use_float16);
}
@@ -253,7 +294,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data) {
kernel_type = AttentionKernelType::AttentionKernel_TrtFusedAttention;
if (!SkipAttentionKernel(data, kernel_type)) {
RunMultiHeadAttentionKernel(
- data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type,
+ data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, data.key_padding_mask_data, data.mask_type,
data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length,
data.hidden_size, data.v_hidden_size, kernel_type, use_float16);
}
@@ -262,7 +303,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data) {
kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention;
if (!SkipAttentionKernel(data, kernel_type)) {
RunMultiHeadAttentionKernel(
- data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type,
+ data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, data.key_padding_mask_data, data.mask_type,
data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length,
data.hidden_size, data.v_hidden_size, kernel_type, use_float16);
}
@@ -270,7 +311,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data) {
kernel_type = AttentionKernelType::AttentionKernel_Default;
RunMultiHeadAttentionKernel(
- data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type,
+ data.query_data, data.key_data, data.value_data, data.kv_data, data.qkv_data, data.bias_data, data.key_padding_mask_data, data.mask_type,
data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length,
data.hidden_size, data.v_hidden_size, kernel_type, use_float16);
}
@@ -302,6 +343,18 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize32_LeftSidePadding_Ma
GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(data);
RunMultiHeadAttentionTests(data);
}
+
+TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_NoBias_NoMask_PackedKV) {
+ AttentionTestData data;
+ GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(data);
+ RunMultiHeadAttentionTests(data);
+}
+
+TEST(MultiHeadAttentionTest, SelfAttention_Batch2_HeadSize32_NoBias_NoMask_PackedQKV) {
+ AttentionTestData data;
+ GetSelfAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedQKV(data);
+ RunMultiHeadAttentionTests(data);
+}
#endif
 // This tests qk_head_size != v_head_size
diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test_data_gen.py b/onnxruntime/test/contrib_ops/multihead_attention_op_test_data_gen.py
index 2d7bbc9b6d6c0..fcac5b3ab28eb 100644
--- a/onnxruntime/test/contrib_ops/multihead_attention_op_test_data_gen.py
+++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test_data_gen.py
@@ -156,6 +156,7 @@ def run_cross_attention(
sequence_length,
kv_sequence_length,
key_padding_mask=None,
+ has_bias=True,
):
seed = 123
torch.manual_seed(seed)
@@ -170,10 +171,16 @@ def run_cross_attention(
torch.nn.init.uniform_(mha.key.weight, -0.5, 0.5)
torch.nn.init.uniform_(mha.value.weight, -0.5, 0.5)
- torch.nn.init.uniform_(mha.query.bias, -0.5, 0.5)
- torch.nn.init.uniform_(mha.key.bias, -0.5, 0.5)
- torch.nn.init.uniform_(mha.value.bias, -0.5, 0.5)
+ if has_bias:
+ torch.nn.init.uniform_(mha.query.bias, -0.5, 0.5)
+ torch.nn.init.uniform_(mha.key.bias, -0.5, 0.5)
+ torch.nn.init.uniform_(mha.value.bias, -0.5, 0.5)
+ else:
+ torch.nn.init.zeros_(mha.query.bias)
+ torch.nn.init.zeros_(mha.key.bias)
+ torch.nn.init.zeros_(mha.value.bias)
+    # Here we simulate the input projection with MatMul but no bias:
w_q = nn.Linear(hidden_dim, num_heads * q_head_size).to(device).eval()
w_k = nn.Linear(hidden_dim, num_heads * q_head_size).to(device).eval()
w_v = nn.Linear(hidden_dim, num_heads * v_head_size).to(device).eval()
@@ -195,13 +202,27 @@ def run_cross_attention(
input_q = w_q(hidden_states.clone())
input_k = w_k(encoder_hidden_states.clone())
input_v = w_v(encoder_hidden_states.clone())
- input_bias = torch.concat([mha.query.bias, mha.key.bias, mha.value.bias])
print("input_q", input_q)
print("input_k", input_k)
print("input_v", input_v)
+
+ input_bias = torch.concat([mha.query.bias, mha.key.bias, mha.value.bias])
print("input_bias", input_bias)
+ if not has_bias:
+ print("no bias!")
+
+    # Pack K and V into shape (batch_size, kv_sequence_length, num_heads, 2, head_size) for the packed KV input.
+ if q_head_size == v_head_size:
+ packed_kv = torch.dstack(
+ (
+ input_k.reshape(batch_size * kv_sequence_length, num_heads, q_head_size),
+ input_v.reshape(batch_size * kv_sequence_length, num_heads, v_head_size),
+ )
+ )
+ packed_kv = packed_kv.reshape(batch_size, kv_sequence_length, num_heads, 2, q_head_size)
+ print("packed_kv_5d", packed_kv)
- output = mha.forward(
+ mha.forward(
hidden_states,
attention_mask=None,
encoder_hidden_states=encoder_hidden_states,
@@ -211,6 +232,88 @@ def run_cross_attention(
)
+def run_self_attention(
+ hidden_dim,
+ q_head_size,
+ v_head_size,
+ num_heads,
+ batch_size,
+ sequence_length,
+ key_padding_mask=None,
+ has_bias=True,
+):
+ seed = 123
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ torch.use_deterministic_algorithms(True)
+
+ device = torch.device("cuda:0")
+ mha = Attention(num_heads, hidden_dim, q_head_size, v_head_size, is_decoder=False).to(device).eval()
+ if key_padding_mask is not None:
+ key_padding_mask = key_padding_mask.to(device)
+ torch.nn.init.uniform_(mha.query.weight, -0.5, 0.5)
+ torch.nn.init.uniform_(mha.key.weight, -0.5, 0.5)
+ torch.nn.init.uniform_(mha.value.weight, -0.5, 0.5)
+
+ if has_bias:
+ torch.nn.init.uniform_(mha.query.bias, -0.5, 0.5)
+ torch.nn.init.uniform_(mha.key.bias, -0.5, 0.5)
+ torch.nn.init.uniform_(mha.value.bias, -0.5, 0.5)
+ else:
+ torch.nn.init.zeros_(mha.query.bias)
+ torch.nn.init.zeros_(mha.key.bias)
+ torch.nn.init.zeros_(mha.value.bias)
+
+    # Here we simulate the input projection with MatMul but no bias:
+ w_q = nn.Linear(hidden_dim, num_heads * q_head_size).to(device).eval()
+ w_k = nn.Linear(hidden_dim, num_heads * q_head_size).to(device).eval()
+ w_v = nn.Linear(hidden_dim, num_heads * v_head_size).to(device).eval()
+ w_q.weight.copy_(mha.query.weight)
+ w_k.weight.copy_(mha.key.weight)
+ w_v.weight.copy_(mha.value.weight)
+ torch.nn.init.zeros_(w_q.bias)
+ torch.nn.init.zeros_(w_k.bias)
+ torch.nn.init.zeros_(w_v.bias)
+
+ torch.set_printoptions(profile="full", precision=8, linewidth=120, sci_mode=False)
+
+ hidden_states = torch.empty(batch_size, sequence_length, hidden_dim, device="cuda")
+ torch.nn.init.normal_(hidden_states)
+
+ input_q = w_q(hidden_states.clone())
+ input_k = w_k(hidden_states.clone())
+ input_v = w_v(hidden_states.clone())
+ print("input_q", input_q)
+ print("input_k", input_k)
+ print("input_v", input_v)
+
+ input_bias = torch.concat([mha.query.bias, mha.key.bias, mha.value.bias])
+ print("input_bias", input_bias)
+ if not has_bias:
+ print("no bias!")
+
+    # Pack Q, K and V into shape (batch_size, sequence_length, num_heads, 3, head_size) for the packed QKV input.
+ if q_head_size == v_head_size:
+ packed_qkv = torch.dstack(
+ (
+ input_q.reshape(batch_size * sequence_length, num_heads, q_head_size),
+ input_k.reshape(batch_size * sequence_length, num_heads, q_head_size),
+ input_v.reshape(batch_size * sequence_length, num_heads, v_head_size),
+ )
+ )
+ packed_qkv = packed_qkv.reshape(batch_size, sequence_length, num_heads, 3, q_head_size)
+ print("packed_qkv_5d", packed_qkv)
+
+ mha.forward(
+ hidden_states,
+ attention_mask=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=key_padding_mask,
+ past_key_value=None,
+ output_attentions=False,
+ )
+
+
def run_cross_batch2_headsize_40():
hidden_dim = 80
q_head_size = 40
@@ -293,7 +396,44 @@ def run_cross_batch1_headsize_32_left_side_padding():
)
-def create_cross_attention_test_data():
+def run_cross_batch2_headsize_32_packed_kv():
+ hidden_dim = 32
+ q_head_size = 32
+ v_head_size = 32
+ num_heads = 1
+ batch_size = 2
+ sequence_length = 2
+ kv_sequence_length = 3
+ key_padding_mask = None
+ has_bias = False
+ run_cross_attention(
+ hidden_dim,
+ q_head_size,
+ v_head_size,
+ num_heads,
+ batch_size,
+ sequence_length,
+ kv_sequence_length,
+ key_padding_mask,
+ has_bias,
+ )
+
+
+def run_self_batch2_headsize_32_packed_qkv():
+ hidden_dim = 32
+ q_head_size = 32
+ v_head_size = 32
+ num_heads = 1
+ batch_size = 2
+ sequence_length = 2
+ key_padding_mask = None
+ has_bias = False
+ run_self_attention(
+ hidden_dim, q_head_size, v_head_size, num_heads, batch_size, sequence_length, key_padding_mask, has_bias
+ )
+
+
+def create_test_data():
"""
Create test data used in attention_op_test_helper.cc and multihead_attention_op_test.cc
"""
@@ -312,6 +452,12 @@ def create_cross_attention_test_data():
print("CrossAttention_Batch1_HeadSize32_LeftSidePadding")
run_cross_batch1_headsize_32_left_side_padding()
+ print("CrossAttention_Batch2_HeadSize32_PackedKV")
+ run_cross_batch2_headsize_32_packed_kv()
+
+ print("SelfAttention_Batch2_HeadSize32_PackedQKV")
+ run_self_batch2_headsize_32_packed_qkv()
+
with torch.no_grad():
- create_cross_attention_test_data()
+ create_test_data()