Make MultiHeadAttention op return attention probabilities #23125

Status: Open. Wants to merge 2 commits into base: main.
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -335,7 +335,7 @@ Status Attention<T>::Compute(OpKernelContext* context) const {

// Compute the attention score and apply the score to V
return ApplyAttention(Q, K, V, mask_index, past, nullptr /* past_key */, nullptr /* past_value */,
output, nullptr /* present_key */, nullptr /* present_value */,
output, nullptr /* present_key */, nullptr /* present_value */, nullptr /* attn_probs */,
batch_size, sequence_length, sequence_length,
parameters.head_size, parameters.v_head_size, parameters.v_hidden_size,
attention_bias, context);
16 changes: 12 additions & 4 deletions onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -29,6 +29,7 @@ class AttentionCPUBase : public AttentionBase {
Tensor* output, // output tensor
Tensor* present_key, // present K output tensor (if separating present KV)
Tensor* present_value, // present V output tensor (if separating present KV)
Tensor* attn_probs, // attention probabilities output tensor (optional)
int batch_size, // batch size (B)
int sequence_length, // sequence length of Q (S)
int kv_sequence_length, // sequence length of K or V (L)
@@ -102,10 +103,17 @@
}

// Compute the attention score.
size_t bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * total_sequence_length * sizeof(T);
auto attention_probs = allocator->Alloc(bytes);
void* attention_probs = nullptr;
T* attn_probs_data = nullptr;
if (attn_probs == nullptr) {
size_t bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * total_sequence_length * sizeof(T);
attention_probs = allocator->Alloc(bytes);
@tianleiwu (Contributor), Dec 17, 2024:

There is no need to allocate extra space if we do not output it. You can follow the handling of output_qk (the temporary result of Q*K before softmax) in this function.

If we do not output both Q*K and softmax(Q*K) at the same time, we can consolidate them by using a boolean flag to indicate whether the output is taken before or after softmax.
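A minimal sketch of the consolidation suggested above. The names ScoreStage and SelectScoreBuffer are illustrative and not part of this PR; the sketch only assumes the optional output_qk / attn_probs tensors and a scratch workspace already sized for the B x N x S x T scores, as in the surrounding function:

```cpp
// Sketch only: pick a single destination for the attention scores instead of
// always allocating a dedicated buffer. One flag selects which stage is
// exposed, mirroring how output_qk (pre-softmax Q*K) is handled in this file.
enum class ScoreStage { kBeforeSoftmax, kAfterSoftmax };

template <typename T>
T* SelectScoreBuffer(ScoreStage stage_to_output,
                     Tensor* output_qk,    // optional pre-softmax graph output
                     Tensor* attn_probs,   // optional post-softmax graph output
                     T* scratch) {         // workspace sized for B*N*S*T scores
  if (stage_to_output == ScoreStage::kBeforeSoftmax && output_qk != nullptr) {
    return output_qk->MutableData<T>();
  }
  if (stage_to_output == ScoreStage::kAfterSoftmax && attn_probs != nullptr) {
    return attn_probs->MutableData<T>();
  }
  return scratch;  // nothing to expose: keep using the shared scratch buffer
}
```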

attn_probs_data = static_cast<T*>(attention_probs);
} else {
attn_probs_data = attn_probs->MutableData<T>();
}
BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator));
ComputeAttentionProbs<T>(static_cast<T*>(attention_probs), Q, K,
ComputeAttentionProbs<T>(attn_probs_data, Q, K,
static_cast<T*>(mask_data),
batch_size, sequence_length, kv_sequence_length, past_sequence_length,
qk_head_size == 0 ? v_head_size : qk_head_size, past_data, past_key_data, present_data,
@@ -117,7 +125,7 @@
allocator->Alloc(SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * v_head_size * sizeof(T));
BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(std::move(allocator)));

ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(out_tmp_data), static_cast<T*>(attention_probs),
ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(out_tmp_data), attn_probs_data,
V, batch_size, sequence_length, kv_sequence_length, past_sequence_length, v_head_size,
v_hidden_size, past_data, past_value_data, present_data, present_value_data, tp,
past_present_share_buffer, max_sequence_length);
@@ -189,7 +189,7 @@ Status DecoderMaskedMultiHeadAttention<T>::Compute(OpKernelContext* context) con
key->Data<T>(),
value->Data<T>(),
mask_index, nullptr /* past */, past_key, past_value, output, present_key, present_value,
batch_size, 1 /* sequence_length */, parameters.kv_sequence_length,
nullptr /* attn_probs */, batch_size, 1 /* sequence_length */, parameters.kv_sequence_length,
head_size, v_head_size, v_hidden_size, attention_bias, context, output_qk);
}

@@ -205,7 +205,7 @@ Status DecoderMaskedMultiHeadAttention<T>::Compute(OpKernelContext* context) con
K.GetMutable<Tensor>()->MutableData<T>(),
V.GetMutable<Tensor>()->MutableData<T>(),
mask_index, nullptr /* past */, past_key, past_value, output, present_key, present_value,
batch_size, 1 /* sequence_length */, parameters.kv_sequence_length,
nullptr /* attn_probs */, batch_size, 1 /* sequence_length */, parameters.kv_sequence_length,
head_size, v_head_size, v_hidden_size, attention_bias, context, output_qk,
parameters.past_sequence_length, true /* past_present_share_buffer */);
}
12 changes: 10 additions & 2 deletions onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -102,6 +102,13 @@
output_shape[2] = static_cast<int64_t>(parameters.v_hidden_size);
Tensor* output = context->Output(0, output_shape);

std::vector<int64_t> attn_probs_shape(4);
attn_probs_shape[0] = static_cast<int64_t>(batch_size);
attn_probs_shape[1] = static_cast<int64_t>(num_heads_);
attn_probs_shape[2] = static_cast<int64_t>(q_sequence_length);
attn_probs_shape[3] = static_cast<int64_t>(parameters.total_sequence_length);
Tensor* attn_probs = context->Output(3, attn_probs_shape);

constexpr int q_bias_offset = 0;
const int k_bias_offset = qk_hidden_size;
const int v_bias_offset = 2 * qk_hidden_size;
@@ -134,7 +141,7 @@
key->Data<T>(),
value->Data<T>(),
key_padding_mask, nullptr /* past */, past_key, past_value, output, present_k, present_v,
batch_size, q_sequence_length, kv_sequence_length,
attn_probs, batch_size, q_sequence_length, kv_sequence_length,
qk_head_size, v_head_size, v_hidden_size, attn_bias, context);
}

@@ -154,6 +161,7 @@
past_value == nullptr &&
present_k == nullptr &&
present_v == nullptr &&
attn_probs == nullptr && // TODO: can we support it?

Check warning (GitHub Actions / Optional Lint C++, cpplint via reviewdog) on line 164 of onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
l2_cache_size_ > 0) {
MlasFlashAttentionThreadedArgs args;
args.batch_size = batch_size;
@@ -214,7 +222,7 @@
K.GetMutable<Tensor>()->MutableData<T>(),
V.GetMutable<Tensor>()->MutableData<T>(),
key_padding_mask, nullptr /* past */, past_key, past_value, output, present_k, present_v,
batch_size, q_sequence_length, kv_sequence_length,
attn_probs, batch_size, q_sequence_length, kv_sequence_length,
qk_head_size, v_head_size, v_hidden_size, attn_bias, context);
}
} // namespace contrib
@@ -289,7 +289,7 @@ Status QAttention<T>::Compute(OpKernelContext* context) const {

// Compute the attention score and apply the score to V
return ApplyAttention(Q, K, V, mask_index, past_tensor, nullptr /* past_key */, nullptr /* past_value*/,
output, nullptr /* present_key */, nullptr /* present_value */,
output, nullptr /* present_key */, nullptr /* present_value */, nullptr /* attn_probs */,
batch_size, sequence_length, sequence_length,
head_size, head_size, hidden_size, nullptr /* rel_pos_bias */, context);
}
27 changes: 17 additions & 10 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -590,15 +590,22 @@ Status UnfusedAttention(

DUMP_TENSOR_D("QK", data.scratch, batch_size, num_heads, sequence_length, total_sequence_length);

constexpr size_t element_size = sizeof(T);
const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads,
sequence_length, total_sequence_length);
T* scratch2 = data.scratch + (bytes / element_size);
T* softmax_storage;
if (data.attn_probs == nullptr) {
constexpr size_t element_size = sizeof(T);
const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads,
sequence_length, total_sequence_length);
T* scratch2 = data.scratch + (bytes / element_size);
softmax_storage = scratch2;
} else {
softmax_storage = data.attn_probs;
}

const bool broadcast_attn_bias_dim_0 = parameters.broadcast_attn_bias_dim_0;
const bool broadcast_attn_bias_dim_1 = parameters.broadcast_attn_bias_dim_1;

// Apply softmax and store result R to scratch2: BxNxSxT
// Apply softmax and store result R to softmax_storage: BxNxSxT
if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask
const int mask_dimension = static_cast<int>(mask_index_dims.size());

@@ -612,7 +619,7 @@
ComputeSoftmaxWithRawMask<T>(
ort_stream, total_sequence_length, sequence_length, batch_size, num_heads,
mask_index, nullptr, data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
data.scratch, scratch2, parameters.is_unidirectional, scale, mask_dimension,
data.scratch, softmax_storage, parameters.is_unidirectional, scale, mask_dimension,
parameters.max_sequence_length, use_persistent_softmax, persistent_softmax_workspace,
parameters.mask_filter_value));
} else if (nullptr != mask_index) { // 1d mask index
@@ -622,24 +629,24 @@
ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D<T>(
stream, total_sequence_length, sequence_length, batch_size, num_heads,
mask_index, mask_start, data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
data.scratch, scratch2, parameters.is_unidirectional));
data.scratch, softmax_storage, parameters.is_unidirectional));
} else { // no mask
ORT_RETURN_IF_ERROR(
ComputeSoftmax<T>(
stream, total_sequence_length, sequence_length, batch_size, num_heads,
data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
data.scratch, scratch2, parameters.is_unidirectional));
data.scratch, softmax_storage, parameters.is_unidirectional));
}

DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length);
DUMP_TENSOR_D("Softmax", softmax_storage, batch_size, num_heads, sequence_length, total_sequence_length);

// compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v
T* temp_output = data.q;
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N,
v_head_size, sequence_length, total_sequence_length,
&one, data.v, v_head_size, present_size_per_batch_v,
scratch2, total_sequence_length, sequence_length * total_sequence_length,
softmax_storage, total_sequence_length, sequence_length * total_sequence_length,
&zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32));

// Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v
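For orientation, a rough sketch of the buffer selection this hunk performs, with the workspace layout spelled out. The offset arithmetic is an assumption based on how GetAttentionScratchSize is used earlier in this function:

```cpp
// Sketch (not part of the PR): layout of the unfused-attention workspace.
//   data.scratch                        : Q*K^T scores, B*N*S*T elements
//   data.scratch + qk_bytes / sizeof(T) : softmax(Q*K^T), a.k.a. "scratch2"
// When the graph requests attention probabilities, the softmax kernels write
// straight into the attn_probs output, so no extra device copy is needed.
const size_t qk_bytes = GetAttentionScratchSize(sizeof(T), batch_size, num_heads,
                                                sequence_length, total_sequence_length);
T* softmax_storage = (data.attn_probs != nullptr)
                         ? data.attn_probs                       // graph output
                         : data.scratch + qk_bytes / sizeof(T);  // internal scratch2
```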
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -81,6 +81,7 @@ struct AttentionData {
T* present = nullptr;
T* present_key = nullptr;
T* present_value = nullptr;
T* attn_probs = nullptr;

void* fused_runner = nullptr;
const void* fused_cross_attention_kernel = nullptr;
15 changes: 13 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -113,6 +113,13 @@
output_shape[2] = static_cast<int64_t>(parameters.v_hidden_size);
Tensor* output = context->Output(0, output_shape);

TensorShapeVector attn_probs_shape(4);
attn_probs_shape[0] = static_cast<int64_t>(parameters.batch_size);
attn_probs_shape[1] = static_cast<int64_t>(parameters.num_heads);
attn_probs_shape[2] = static_cast<int64_t>(sequence_length);
attn_probs_shape[3] = static_cast<int64_t>(parameters.total_sequence_length);
Tensor* attn_probs = context->Output(3, attn_probs_shape);

std::vector<int64_t> present_dims{
parameters.batch_size, parameters.num_heads, parameters.total_sequence_length, parameters.head_size};
TensorShape present_shape(present_dims);
@@ -172,6 +179,7 @@
parameters.past_sequence_length > 0 &&
nullptr == attention_bias &&
nullptr == key_padding_mask &&
nullptr == attn_probs && // TODO: support attn_probs

Check warning (GitHub Actions / Optional Lint C++, cpplint via reviewdog) on line 182 of onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
parameters.head_size == parameters.v_head_size &&
onnxruntime::lean::is_supported(device_prop,
parameters.head_size,
@@ -216,6 +224,7 @@
!disable_flash_attention_ &&
nullptr == attention_bias &&
nullptr == key_padding_mask &&
nullptr == attn_probs && // TODO: support attn_probs

Check warning (GitHub Actions / Optional Lint C++, cpplint via reviewdog) on line 227 of onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
parameters.head_size == parameters.v_head_size &&
onnxruntime::flash::is_supported(device_prop,
parameters.head_size,
@@ -280,7 +289,7 @@
!is_unidirectional_ &&
nullptr == key_padding_mask &&
nullptr == attention_bias &&
nullptr == past_key && nullptr == present_key &&
nullptr == past_key && nullptr == present_key && nullptr == attn_probs &&
(parameters.qkv_format == Q_K_V_BSNH || (parameters.qkv_format == Q_KV_BSNH_BSN2H && bias == nullptr)) &&
parameters.hidden_size == parameters.v_hidden_size &&
has_fused_cross_attention_kernel(sm, parameters.head_size, parameters.kv_sequence_length);
@@ -305,7 +314,7 @@
!is_unidirectional_ &&
nullptr == attention_bias &&
(parameters.qkv_format == Q_K_V_BSNH || parameters.qkv_format == QKV_BSN3H) &&
nullptr == past_key && nullptr == present_key &&
nullptr == past_key && nullptr == present_key && nullptr == attn_probs &&
is_mask_none_or_1d_k_len &&
parameters.hidden_size == parameters.v_hidden_size &&
parameters.sequence_length == parameters.kv_sequence_length && // self attention only for fused runner
@@ -339,6 +348,7 @@
kernel_type == AttentionKernelType::AttentionKernel_Default &&
!disable_memory_efficient_attention_ &&
is_long_sequence &&
nullptr == attn_probs && // TODO: support attn_probs

Check warning (GitHub Actions / Optional Lint C++, cpplint via reviewdog) on line 351 of onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
// Check whether the attention bias alignment is good for memory efficient attention.
(attention_bias == nullptr || parameters.sequence_length % (4 * sizeof(T)) == 0) &&
(nullptr == key_padding_mask || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) &&
@@ -369,6 +379,7 @@
data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast<CudaT*>(present_key->MutableData<T>());
data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast<CudaT*>(present_value->MutableData<T>());
data.attn_probs = (nullptr == attn_probs) ? nullptr : reinterpret_cast<CudaT*>(attn_probs->MutableData<T>());
data.fused_runner = reinterpret_cast<void*>(fused_runner);
data.fused_cross_attention_kernel = fused_cross_attention_kernel;
data.use_flash_attention = use_flash_attention;
29 changes: 29 additions & 0 deletions onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -226,6 +226,30 @@ void MultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& c
}
}
}

if (ctx.getNumOutputs() > 3) { // has attention_probs output
// Output 3 has shape (batch_size, num_heads, sequence_length, total_sequence_length)
if (hasInputShape(ctx, 0) && hasInputShape(ctx, past_key_index)) {
auto& query_shape = getInputShape(ctx, 0);
auto& key_shape = getInputShape(ctx, 1);
auto& key_seqlen_dim = key_shape.dim()[1];
auto& past_seqlen_dim = getInputShape(ctx, past_key_index).dim()[2];
if (key_seqlen_dim.has_dim_value() && past_seqlen_dim.has_dim_value()) {
Reviewer comment (Contributor): Add a condition of !past_present_share_buffer here.

auto kv_sequence_length = key_seqlen_dim.dim_value();
auto past_sequence_length = past_seqlen_dim.dim_value();
int64_t total_sequence_length = kv_sequence_length + past_sequence_length;
auto num_heads = getAttribute(ctx, "num_heads", 0);

ONNX_NAMESPACE::TensorShapeProto attention_probs_shape;
*attention_probs_shape.add_dim() = query_shape.dim()[0];
attention_probs_shape.add_dim()->set_dim_value(num_heads);
*attention_probs_shape.add_dim() = query_shape.dim()[1];
attention_probs_shape.add_dim()->set_dim_value(total_sequence_length);
updateOutputShape(ctx, 3, attention_probs_shape);
propagateElemTypeFromInputToOutput(ctx, 0, 3);
}
}
}
}
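A hedged sketch of the guard requested in the review comment above. It assumes the op exposes a past_present_share_buffer attribute (as DecoderMaskedMultiHeadAttention does); when past and present share a pre-allocated buffer, the past dimension reflects buffer capacity rather than the number of valid past tokens, so the simple addition would be wrong:

```cpp
// Sketch only: skip the static total_sequence_length computation when past and
// present share a buffer, because dim(past_key)[2] is then the maximum buffer
// length, not the actual past sequence length.
if (getAttribute(ctx, "past_present_share_buffer", 0) == 0 &&
    key_seqlen_dim.has_dim_value() && past_seqlen_dim.has_dim_value()) {
  int64_t total_sequence_length = key_seqlen_dim.dim_value() + past_seqlen_dim.dim_value();
  // ... build and set the attention_probs shape exactly as in the hunk above ...
}
```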

// Type and shape inference for group query attention and sparse attention.
@@ -1034,6 +1058,11 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"or present state for self attention value with shape (batch_size, num_heads, total_sequence_length, head_size)",
"T",
OpSchema::Optional)
.Output(3,
Reviewer comment (Contributor): You will need to update the documents (you can find the updated documents in the artifacts of the Windows GPU Doc Gen CI Pipeline for this PR).

"attention_probs",
"Attention probabilities with shape (batch_size, num_heads, sequence_length, total_sequence_length)",
"T",
OpSchema::Optional)
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output to float tensors.")
.TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to integer types")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
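Finally, a usage sketch showing how a client could read the new optional output once a model wires MultiHeadAttention output 3 to a graph output. The model file name, input/output names, and tensor shapes below are illustrative assumptions, not artifacts of this PR:

```cpp
#include <onnxruntime_cxx_api.h>

#include <array>
#include <vector>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "mha_probs_demo");
  // Hypothetical model whose MultiHeadAttention node exposes output 3 as "attention_probs".
  Ort::Session session(env, "model_with_attn_probs.onnx", Ort::SessionOptions{});

  // Illustrative dimensions: batch=1, sequence=8, hidden=64.
  std::array<int64_t, 3> qkv_shape{1, 8, 64};
  std::vector<float> query(1 * 8 * 64, 0.0f), key(query), value(query);

  Ort::MemoryInfo mem = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  std::array<Ort::Value, 3> inputs{
      Ort::Value::CreateTensor<float>(mem, query.data(), query.size(), qkv_shape.data(), qkv_shape.size()),
      Ort::Value::CreateTensor<float>(mem, key.data(), key.size(), qkv_shape.data(), qkv_shape.size()),
      Ort::Value::CreateTensor<float>(mem, value.data(), value.size(), qkv_shape.data(), qkv_shape.size())};

  const char* input_names[] = {"query", "key", "value"};
  const char* output_names[] = {"output", "attention_probs"};

  auto outputs = session.Run(Ort::RunOptions{nullptr}, input_names, inputs.data(), inputs.size(),
                             output_names, 2);

  // attention_probs: (batch_size, num_heads, sequence_length, total_sequence_length).
  std::vector<int64_t> probs_shape = outputs[1].GetTensorTypeAndShapeInfo().GetShape();
  const float* probs = outputs[1].GetTensorData<float>();
  (void)probs_shape;
  (void)probs;
  return 0;
}
```

Note: on Windows the Ort::Session constructor takes a wide-character model path.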