Commit

Allow returning attention probs from MultiHeadAttention
amancini-N committed Dec 11, 2024
1 parent b14b4ec commit 3147d51
Showing 10 changed files with 188 additions and 24 deletions.
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -335,7 +335,7 @@ Status Attention<T>::Compute(OpKernelContext* context) const {

// Compute the attention score and apply the score to V
return ApplyAttention(Q, K, V, mask_index, past, nullptr /* past_key */, nullptr /* past_value */,
output, nullptr /* present_key */, nullptr /* present_value */,
output, nullptr /* present_key */, nullptr /* present_value */, nullptr /* attn_probs */,
batch_size, sequence_length, sequence_length,
parameters.head_size, parameters.v_head_size, parameters.v_hidden_size,
attention_bias, context);
16 changes: 12 additions & 4 deletions onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -29,6 +29,7 @@ class AttentionCPUBase : public AttentionBase {
Tensor* output, // output tensor
Tensor* present_key, // present K output tensor (if separating present KV)
Tensor* present_value, // present V output tensor (if separating present KV)
Tensor* attn_probs, // attention probabilities output tensor (optional)
int batch_size, // batch size (B)
int sequence_length, // sequence length of Q (S)
int kv_sequence_length, // sequence length of K or V (L)
@@ -102,10 +103,17 @@
}

// Compute the attention score.
size_t bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * total_sequence_length * sizeof(T);
auto attention_probs = allocator->Alloc(bytes);
void* attention_probs = nullptr;
T* attn_probs_data = nullptr;
if (attn_probs == nullptr) {
size_t bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * total_sequence_length * sizeof(T);
attention_probs = allocator->Alloc(bytes);
attn_probs_data = static_cast<T*>(attention_probs);
} else {
attn_probs_data = attn_probs->MutableData<T>();
}
BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator));
ComputeAttentionProbs<T>(static_cast<T*>(attention_probs), Q, K,
ComputeAttentionProbs<T>(attn_probs_data, Q, K,
static_cast<T*>(mask_data),
batch_size, sequence_length, kv_sequence_length, past_sequence_length,
qk_head_size == 0 ? v_head_size : qk_head_size, past_data, past_key_data, present_data,
@@ -117,7 +125,7 @@
allocator->Alloc(SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * v_head_size * sizeof(T));
BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(std::move(allocator)));

ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(out_tmp_data), static_cast<T*>(attention_probs),
ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(out_tmp_data), attn_probs_data,
V, batch_size, sequence_length, kv_sequence_length, past_sequence_length, v_head_size,
v_hidden_size, past_data, past_value_data, present_data, present_value_data, tp,
past_present_share_buffer, max_sequence_length);
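
The hunk above follows a simple allocate-or-reuse pattern: when the caller requests the optional attn_probs output, ComputeAttentionProbs writes straight into that tensor's buffer; otherwise a temporary scratch buffer is allocated exactly as before, and the BufferUniquePtr (holding nullptr in the first case) keeps ownership handling unchanged. The following standalone C++ sketch illustrates the same pattern with simplified stand-ins; Buffer, compute_probs and apply_attention are illustrative names, not ONNX Runtime types or functions.

// Hedged sketch of the allocate-or-reuse pattern above; Buffer, compute_probs and
// apply_attention are simplified stand-ins, not ONNX Runtime types or functions.
#include <cstddef>
#include <memory>
#include <vector>

// Stand-in for an optional, caller-owned output tensor's storage.
struct Buffer {
  std::vector<float> data;
};

// Placeholder for ComputeAttentionProbs: fills the buffer with a uniform distribution.
void compute_probs(float* probs, std::size_t count) {
  for (std::size_t i = 0; i < count; ++i) probs[i] = 1.0f / static_cast<float>(count);
}

void apply_attention(Buffer* attn_probs_out, std::size_t batch, std::size_t heads,
                     std::size_t seq_len, std::size_t total_seq_len) {
  const std::size_t count = batch * heads * seq_len * total_seq_len;

  std::unique_ptr<float[]> scratch;  // owns memory only when no output is requested
  float* probs_data = nullptr;
  if (attn_probs_out == nullptr) {
    scratch = std::make_unique<float[]>(count);  // plays the role of allocator->Alloc
    probs_data = scratch.get();
  } else {
    attn_probs_out->data.resize(count);  // plays the role of Tensor::MutableData<T>()
    probs_data = attn_probs_out->data.data();
  }

  compute_probs(probs_data, count);
  // probs_data would then feed the probs x V multiplication, as in the kernel above.
}

int main() {
  Buffer probs_out;
  apply_attention(&probs_out, /*batch=*/1, /*heads=*/2, /*seq_len=*/3, /*total_seq_len=*/4);
  return probs_out.data.size() == 24 ? 0 : 1;
}
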
@@ -189,7 +189,7 @@ Status DecoderMaskedMultiHeadAttention<T>::Compute(OpKernelContext* context) con
key->Data<T>(),
value->Data<T>(),
mask_index, nullptr /* past */, past_key, past_value, output, present_key, present_value,
batch_size, 1 /* sequence_length */, parameters.kv_sequence_length,
nullptr /* attn_probs */, batch_size, 1 /* sequence_length */, parameters.kv_sequence_length,
head_size, v_head_size, v_hidden_size, attention_bias, context, output_qk);
}

@@ -205,7 +205,7 @@ Status DecoderMaskedMultiHeadAttention<T>::Compute(OpKernelContext* context) con
K.GetMutable<Tensor>()->MutableData<T>(),
V.GetMutable<Tensor>()->MutableData<T>(),
mask_index, nullptr /* past */, past_key, past_value, output, present_key, present_value,
batch_size, 1 /* sequence_length */, parameters.kv_sequence_length,
nullptr /* attn_probs */, batch_size, 1 /* sequence_length */, parameters.kv_sequence_length,
head_size, v_head_size, v_hidden_size, attention_bias, context, output_qk,
parameters.past_sequence_length, true /* past_present_share_buffer */);
}
12 changes: 10 additions & 2 deletions onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -102,6 +102,13 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
output_shape[2] = static_cast<int64_t>(parameters.v_hidden_size);
Tensor* output = context->Output(0, output_shape);

std::vector<int64_t> attn_probs_shape(4);
attn_probs_shape[0] = static_cast<int64_t>(batch_size);
attn_probs_shape[1] = static_cast<int64_t>(num_heads_);
attn_probs_shape[2] = static_cast<int64_t>(q_sequence_length);
attn_probs_shape[3] = static_cast<int64_t>(parameters.total_sequence_length);
Tensor* attn_probs = context->Output(3, attn_probs_shape);

constexpr int q_bias_offset = 0;
const int k_bias_offset = qk_hidden_size;
const int v_bias_offset = 2 * qk_hidden_size;
@@ -134,7 +141,7 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
key->Data<T>(),
value->Data<T>(),
key_padding_mask, nullptr /* past */, past_key, past_value, output, present_k, present_v,
batch_size, q_sequence_length, kv_sequence_length,
attn_probs, batch_size, q_sequence_length, kv_sequence_length,
qk_head_size, v_head_size, v_hidden_size, attn_bias, context);
}

@@ -154,6 +161,7 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
past_value == nullptr &&
present_k == nullptr &&
present_v == nullptr &&
attn_probs == nullptr && // TODO: can we support it?
l2_cache_size_ > 0) {
MlasFlashAttentionThreadedArgs args;
args.batch_size = batch_size;
@@ -214,7 +222,7 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
K.GetMutable<Tensor>()->MutableData<T>(),
V.GetMutable<Tensor>()->MutableData<T>(),
key_padding_mask, nullptr /* past */, past_key, past_value, output, present_k, present_v,
batch_size, q_sequence_length, kv_sequence_length,
attn_probs, batch_size, q_sequence_length, kv_sequence_length,
qk_head_size, v_head_size, v_hidden_size, attn_bias, context);
}
} // namespace contrib
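
Two details of the CPU kernel change above are worth noting: the new optional output 3 has shape (batch_size, num_heads, q_sequence_length, total_sequence_length), and the MLAS flash-attention fast path is skipped whenever that output is requested (the added attn_probs == nullptr condition), since flash attention computes softmax block-wise and never materializes the full probability matrix. A hedged sketch of both points follows; MhaParams, AttnProbsShape and CanUseFlashAttention are illustrative names, not ONNX Runtime APIs.

// Hedged sketch: shape of the optional attention_probs output and the fast-path
// gating; MhaParams, AttnProbsShape and CanUseFlashAttention are illustrative names.
#include <array>
#include <cstdint>
#include <iostream>

struct MhaParams {
  int64_t batch_size;
  int64_t num_heads;
  int64_t q_sequence_length;
  int64_t kv_sequence_length;
  int64_t past_sequence_length;  // 0 when there is no past state
};

// Shape of the optional attention_probs output (output index 3).
std::array<int64_t, 4> AttnProbsShape(const MhaParams& p) {
  const int64_t total_seq = p.past_sequence_length + p.kv_sequence_length;
  return {p.batch_size, p.num_heads, p.q_sequence_length, total_seq};
}

// Flash attention never stores the full probability matrix, so the fast path is
// only eligible when the caller did not request attention_probs (simplified; the
// real kernel also checks past/present state and other conditions).
bool CanUseFlashAttention(bool attn_probs_requested, bool has_past_or_present) {
  return !attn_probs_requested && !has_past_or_present;
}

int main() {
  MhaParams p{2, 8, 4, 6, 3};
  const auto shape = AttnProbsShape(p);
  std::cout << shape[0] << "x" << shape[1] << "x" << shape[2] << "x" << shape[3] << "\n";  // 2x8x4x9
  std::cout << CanUseFlashAttention(/*attn_probs_requested=*/true,
                                    /*has_past_or_present=*/false) << "\n";  // 0
  return 0;
}
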
@@ -289,7 +289,7 @@ Status QAttention<T>::Compute(OpKernelContext* context) const {

// Compute the attention score and apply the score to V
return ApplyAttention(Q, K, V, mask_index, past_tensor, nullptr /* past_key */, nullptr /* past_value*/,
output, nullptr /* present_key */, nullptr /* present_value */,
output, nullptr /* present_key */, nullptr /* present_value */, nullptr /* attn_probs */,
batch_size, sequence_length, sequence_length,
head_size, head_size, hidden_size, nullptr /* rel_pos_bias */, context);
}
29 changes: 29 additions & 0 deletions onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -226,6 +226,30 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c
}
}
}

if (ctx.getNumOutputs() > 3) { // has attention_probs output
// Output 3 has shape (batch_size, num_heads, sequence_length, total_sequence_length)
if (hasInputShape(ctx, 0) && hasInputShape(ctx, past_key_index)) {
auto& query_shape = getInputShape(ctx, 0);
auto& key_shape = getInputShape(ctx, 1);
auto& key_seqlen_dim = key_shape.dim()[1];
auto& past_seqlen_dim = getInputShape(ctx, past_key_index).dim()[2];
if (key_seqlen_dim.has_dim_value() && past_seqlen_dim.has_dim_value()) {
auto kv_sequence_length = key_seqlen_dim.dim_value();
auto past_sequence_length = past_seqlen_dim.dim_value();
int64_t total_sequence_length = kv_sequence_length + past_sequence_length;
auto num_heads = getAttribute(ctx, "num_heads", 0);

ONNX_NAMESPACE::TensorShapeProto attention_probs_shape;
*attention_probs_shape.add_dim() = query_shape.dim()[0];
attention_probs_shape.add_dim()->set_dim_value(num_heads);
*attention_probs_shape.add_dim() = query_shape.dim()[1];
attention_probs_shape.add_dim()->set_dim_value(total_sequence_length);
updateOutputShape(ctx, 3, attention_probs_shape);
propagateElemTypeFromInputToOutput(ctx, 0, 3);
}
}
}
}

// Type and shape inference for group query attention and sparse attention.
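
The inference added above resolves output 3 only when both the key sequence length (key shape, dim 1) and the past sequence length (past_key shape, dim 2) carry static values; with symbolic dimensions the output shape is left unannotated. A small sketch of that rule is shown here, with std::optional standing in for ONNX dimensions; InferAttnProbsShape is an illustrative name, not the real inference helper.

// Hedged sketch of the attention_probs shape-inference rule; std::optional stands
// in for an ONNX dimension (nullopt == symbolic, i.e. no dim_value).
#include <cstdint>
#include <optional>
#include <vector>

using Dim = std::optional<int64_t>;

// Returns (B, num_heads, S, total_sequence_length), or an empty vector when the
// total sequence length cannot be resolved statically.
std::vector<Dim> InferAttnProbsShape(const std::vector<Dim>& query_shape,     // (B, S, hidden)
                                     const std::vector<Dim>& key_shape,       // (B, L, hidden)
                                     const std::vector<Dim>& past_key_shape,  // (B, N, P, head)
                                     int64_t num_heads) {
  const Dim& kv_seq = key_shape[1];
  const Dim& past_seq = past_key_shape[2];
  if (!kv_seq.has_value() || !past_seq.has_value()) {
    return {};  // mirror the inference function: leave the shape unset for symbolic dims
  }
  return {query_shape[0], num_heads, query_shape[1], *kv_seq + *past_seq};
}

int main() {
  std::vector<Dim> query{2, 4, 512}, key{2, 6, 512}, past_key{2, 8, 3, 64};
  const auto shape = InferAttnProbsShape(query, key, past_key, /*num_heads=*/8);
  return shape.size() == 4 ? 0 : 1;  // resolves to (2, 8, 4, 9)
}

For example, query (2, 4, 512), key (2, 6, 512) and past_key (2, 8, 3, 64) with num_heads = 8 give an attention_probs shape of (2, 8, 4, 9).
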
@@ -1034,6 +1058,11 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"or present state for self attention value with shape (batch_size, num_heads, total_sequence_length, head_size)",
"T",
OpSchema::Optional)
.Output(3,
"attention_probs",
"Attention probabilities with shape (batch_size, num_heads, sequence_length, total_sequence_length)",
"T",
OpSchema::Optional)
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output to float tensors.")
.TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to integer types")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
15 changes: 15 additions & 0 deletions onnxruntime/test/contrib_ops/attention_op_test_helper.cc
@@ -76,13 +76,15 @@ void GetCrossAttentionData_HeadSize40(AttentionTestData& data) {
LoadTensor("CrossAttentionData_HeadSize40.bias_data", data.bias_data);
LoadTensor("CrossAttentionData_HeadSize40.fp32_output_data", data.fp32_output_data);
LoadTensor("CrossAttentionData_HeadSize40.fp16_output_data", data.fp16_output_data);
LoadTensor("CrossAttentionData_HeadSize40.attention_probs_data", data.attention_probs_data);
}

void GetCrossAttentionData_HeadSize40_NoBias(AttentionTestData& data) {
GetCrossAttentionData_HeadSize40(data);
data.bias_data.clear();
LoadTensor("CrossAttentionData_HeadSize40_NoBias.fp32_output_data", data.fp32_output_data);
data.fp16_output_data = data.fp32_output_data;
LoadTensor("CrossAttentionData_HeadSize40_NoBias.attention_probs_data", data.attention_probs_data);
}

void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(AttentionTestData& data, bool is_mask_1d) {
@@ -113,6 +115,7 @@ void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(AttentionTestData&
LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding.bias_data", data.bias_data);
LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding.fp32_output_data", data.fp32_output_data);
data.fp16_output_data = data.fp32_output_data;
LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding.attention_probs_data", data.attention_probs_data);
}

void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(AttentionTestData& data, bool is_mask_1d) {
@@ -121,6 +124,7 @@ void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(AttentionTe

LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias.fp32_output_data", data.fp32_output_data);
data.fp16_output_data = data.fp32_output_data;
LoadTensor("CrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias.attention_probs_data", data.attention_probs_data);
}

void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(AttentionTestData& data) {
@@ -145,13 +149,15 @@ void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(AttentionTestData&
LoadTensor("CrossAttentionData_Batch1_HeadSize32_LeftSidePadding.bias_data", data.bias_data);
LoadTensor("CrossAttentionData_Batch1_HeadSize32_LeftSidePadding.fp32_output_data", data.fp32_output_data);
data.fp16_output_data = data.fp32_output_data;
LoadTensor("CrossAttentionData_Batch1_HeadSize32_LeftSidePadding.attention_probs_data", data.attention_probs_data);
}

void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias(AttentionTestData& data) {
GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(data);
data.bias_data.clear();
LoadTensor("CrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias.fp32_output_data", data.fp32_output_data);
data.fp16_output_data = data.fp32_output_data;
LoadTensor("CrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias.attention_probs_data", data.attention_probs_data);
}

void GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(AttentionTestData& data) {
@@ -174,6 +180,7 @@ void GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(AttentionTes
// Do not test fp32
data.fp32_output_data = {};
LoadTensor("CrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV.fp16_output_data", data.fp16_output_data);
LoadTensor("CrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV.attention_probs_data", data.attention_probs_data);
}

void GetSelfAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedQKV(AttentionTestData& data) {
@@ -217,13 +224,15 @@ void GetCrossAttentionData_HeadSize16_8(AttentionTestData& data) {
LoadTensor("CrossAttentionData_HeadSize16_8.bias_data", data.bias_data);
LoadTensor("CrossAttentionData_HeadSize16_8.fp32_output_data", data.fp32_output_data);
data.fp16_output_data = data.fp32_output_data;
LoadTensor("CrossAttentionData_HeadSize16_8.attention_probs_data", data.attention_probs_data);
}

void GetCrossAttentionData_HeadSize16_8_NoBias(AttentionTestData& data) {
GetCrossAttentionData_HeadSize16_8(data);
data.bias_data.clear();
LoadTensor("CrossAttentionData_HeadSize16_8_NoBias.fp32_output_data", data.fp32_output_data);
data.fp16_output_data = data.fp32_output_data;
LoadTensor("CrossAttentionData_HeadSize16_8_NoBias.attention_probs_data", data.attention_probs_data);
}

void GetCrossAttentionData_HeadSize16(AttentionTestData& data) {
@@ -241,13 +250,15 @@ void GetCrossAttentionData_HeadSize16(AttentionTestData& data) {
LoadTensor("CrossAttentionData_HeadSize16.bias_data", data.bias_data);
LoadTensor("CrossAttentionData_HeadSize16.fp32_output_data", data.fp32_output_data);
data.fp16_output_data = data.fp32_output_data;
LoadTensor("CrossAttentionData_HeadSize16.attention_probs_data", data.attention_probs_data);
}

void GetCrossAttentionData_HeadSize16_NoBias(AttentionTestData& data) {
GetCrossAttentionData_HeadSize16(data);
data.bias_data.clear();
LoadTensor("CrossAttentionData_HeadSize16_NoBias.fp32_output_data", data.fp32_output_data);
data.fp16_output_data = data.fp32_output_data;
LoadTensor("CrossAttentionData_HeadSize16_NoBias.attention_probs_data", data.attention_probs_data);
}

void GetCrossAttentionData_HeadSize8(AttentionTestData& data) {
@@ -265,13 +276,15 @@ void GetCrossAttentionData_HeadSize8(AttentionTestData& data) {
LoadTensor("CrossAttention_Batch1_HeadSize8.bias_data", data.bias_data);
LoadTensor("CrossAttention_Batch1_HeadSize8.output", data.fp32_output_data);
data.fp16_output_data = data.fp32_output_data;
LoadTensor("CrossAttention_Batch1_HeadSize8.attention_probs_data", data.attention_probs_data);
}

void GetCrossAttentionData_HeadSize8_NoBias(AttentionTestData& data) {
GetCrossAttentionData_HeadSize8(data);
data.bias_data.clear();
LoadTensor("CrossAttention_Batch1_HeadSize8_NoBias.output", data.fp32_output_data);
data.fp16_output_data = data.fp32_output_data;
LoadTensor("CrossAttention_Batch1_HeadSize8_NoBias.attention_probs_data", data.attention_probs_data);
}

void GetCrossAttentionDataWithPast(AttentionTestData& data) {
@@ -406,6 +419,7 @@ void GetCrossAttentionData_DiffSequenceLengths_HeadSize8(AttentionTestData& data
LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8.present_key_data", data.present_key_data);
LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8.present_value_data", data.present_value_data);
data.is_static_kv = true;
LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8.attention_probs_data", data.attention_probs_data);
}

void GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(AttentionTestData& data) {
@@ -416,6 +430,7 @@ void GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(AttentionTestDat
LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias.present_key_data", data.present_key_data);
LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias.present_value_data", data.present_value_data);
data.is_static_kv = true;
LoadTensor("CrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias.attention_probs_data", data.attention_probs_data);
}

void GetSelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias(AttentionTestData& data) {
1 change: 1 addition & 0 deletions onnxruntime/test/contrib_ops/attention_op_test_helper.h
@@ -38,6 +38,7 @@ struct BaseAttentionTestData {

std::vector<float> present_key_data;
std::vector<float> present_value_data;
std::vector<float> attention_probs_data;

std::vector<AttentionKernelType> skip_kernel_types; // skip some kernels if they do not supported this test case.
};
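
The new attention_probs_data field stores the expected probabilities as a flat std::vector<float> covering the (batch_size, num_heads, sequence_length, total_sequence_length) output. As a side check that such data is well formed, a test could verify that every softmax row sums to one; the sketch below only illustrates that idea and is not code from this change (ProbRowsSumToOne is an illustrative helper).

// Hedged sketch: a sanity check a test could apply to expected attention
// probabilities of shape (B, N, S, T) stored as a flat row-major vector.
// Not ONNX Runtime test code; ProbRowsSumToOne is an illustrative helper.
#include <cmath>
#include <cstddef>
#include <vector>

bool ProbRowsSumToOne(const std::vector<float>& probs, std::size_t total_sequence_length,
                      float tol = 1e-3f) {
  if (total_sequence_length == 0 || probs.size() % total_sequence_length != 0) return false;
  for (std::size_t row = 0; row < probs.size(); row += total_sequence_length) {
    float sum = 0.0f;
    for (std::size_t i = 0; i < total_sequence_length; ++i) sum += probs[row + i];
    if (std::fabs(sum - 1.0f) > tol) return false;
  }
  return true;
}

int main() {
  const std::vector<float> probs{0.25f, 0.25f, 0.25f, 0.25f, 0.7f, 0.1f, 0.1f, 0.1f};
  return ProbRowsSumToOne(probs, /*total_sequence_length=*/4) ? 0 : 1;
}
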