Fix warning; rename to output_qk
mindest committed Oct 12, 2024
1 parent f9f4ff3 commit 5b5e791
Showing 2 changed files with 21 additions and 20 deletions.
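The "warning" in the title appears to come from the beam-offset arithmetic in ComputeAttentionProbsWithBeams and ComputeVxAttentionScoreWithBeams: a chain of int multiplications was evaluated entirely in 32-bit and only then implicitly converted to std::ptrdiff_t. The fix widens the first operand so the whole product is computed as std::ptrdiff_t. A minimal sketch of that pattern, with an illustrative helper name and values that are not part of the repository:

#include <cstddef>

// Hypothetical stand-in for the offset computation touched by this commit.
std::ptrdiff_t beam_offset(int beam_index, int num_heads,
                           int max_sequence_length, int head_size) {
  // Without the cast, the product is evaluated entirely in int and only then
  // converted to std::ptrdiff_t; compilers flag the implicit conversion, and the
  // intermediate can overflow for long sequences (e.g. 4 * 64 * 131072 * 128 > INT_MAX).
  return static_cast<std::ptrdiff_t>(beam_index) * num_heads *
         max_sequence_length * head_size;
}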
12 changes: 6 additions & 6 deletions onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -37,7 +37,7 @@ class AttentionCPUBase : public AttentionBase {
int v_hidden_size, // hidden size of V (D_v)
const Tensor* attn_bias, // additive bias applied on scaled QK.
OpKernelContext* context,
Tensor* scaled_qk = nullptr, // output buffer for QK (if needed)
Tensor* output_qk = nullptr, // output buffer for QK (if needed)
int past_sequence_length = 0, // sequence length of past state
bool past_present_share_buffer = false) const {
AllocatorPtr allocator;
@@ -89,7 +89,7 @@ class AttentionCPUBase : public AttentionBase {
T* present_key_data = present_key != nullptr ? present_key->MutableData<T>() : nullptr;
const T* past_value_data = past_value != nullptr ? past_value->Data<T>() : nullptr;
T* present_value_data = present_value != nullptr ? present_value->MutableData<T>() : nullptr;
T* scaled_qk_data = scaled_qk != nullptr ? scaled_qk->MutableData<T>() : nullptr;
T* output_qk_data = output_qk != nullptr ? output_qk->MutableData<T>() : nullptr;

const T* attn_bias_data = (attn_bias != nullptr) ? attn_bias->Data<T>() : nullptr;
auto attn_bias_dims = (attn_bias != nullptr) ? attn_bias->Shape().GetDims() : gsl::span<const int64_t>{};
@@ -109,7 +109,7 @@ class AttentionCPUBase : public AttentionBase {
static_cast<T*>(mask_data),
batch_size, sequence_length, kv_sequence_length, past_sequence_length,
qk_head_size == 0 ? v_head_size : qk_head_size, past_data, past_key_data, present_data,
present_key_data, tp, scale, attn_bias_data, attn_bias_dims, scaled_qk_data,
present_key_data, tp, scale, attn_bias_data, attn_bias_dims, output_qk_data,
past_present_share_buffer, max_sequence_length);

// Compute the attentionScore * Value: out_tmp(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
@@ -148,7 +148,7 @@ class AttentionCPUBase : public AttentionBase {
float scale, // scale factor
const T* attn_bias_data, // attention bias
gsl::span<const int64_t> attn_bias_dims, // attention bias shape
T* scaled_qk_data = nullptr, // scaled output QK buffer
T* output_qk_data = nullptr, // scaled output QK buffer
bool past_present_share_buffer = false,
int max_sequence_length = 0) const {
const int total_sequence_length = past_sequence_length + kv_sequence_length; // T = P + L
@@ -253,9 +253,9 @@ class AttentionCPUBase : public AttentionBase {
});
}

if (scaled_qk_data != nullptr) {
if (output_qk_data != nullptr) {
// Output the scaled Q*K^T if needed.
memcpy(scaled_qk_data, attention_probs,
memcpy(output_qk_data, attention_probs,
SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * total_sequence_length * sizeof(T));
}

29 changes: 15 additions & 14 deletions onnxruntime/contrib_ops/cpu/bert/decoder_masked_multihead_attention.cc
@@ -125,7 +125,7 @@ Status DecoderMaskedMultiHeadAttention<T>::Compute(OpKernelContext* context) con
TensorShape present_shape(present_dims);
Tensor* present_key = context->Output(kPresentOutputIndex, present_shape);
Tensor* present_value = context->Output(kPresentOutputIndex + 1, present_shape);
Tensor* cross_qk = nullptr;
Tensor* output_qk = nullptr;

// Decoder cross-attention
if (past_key == nullptr && present_key == nullptr) {
@@ -160,8 +160,7 @@ Status DecoderMaskedMultiHeadAttention<T>::Compute(OpKernelContext* context) con
if (output_qk_) {
int64_t qk_dims[] = {parameters.batch_size, parameters.num_heads, 1, parameters.total_sequence_length};
TensorShape qk_shape(&qk_dims[0], sizeof(qk_dims) / sizeof(qk_dims[0]));
cross_qk = context->Output(kQKOutputIndex, qk_shape);
parameters.out_qk = cross_qk->MutableData<T>();
output_qk = context->Output(kQKOutputIndex, qk_shape);
}

// Beam width (in case we are using this op inside BeamSearch)
@@ -191,7 +190,7 @@ Status DecoderMaskedMultiHeadAttention<T>::Compute(OpKernelContext* context) con
value->Data<T>(),
mask_index, nullptr /* past */, past_key, past_value, output, present_key, present_value,
batch_size, 1 /* sequence_length */, parameters.kv_sequence_length,
head_size, v_head_size, v_hidden_size, attention_bias, context, cross_qk);
head_size, v_head_size, v_hidden_size, attention_bias, context, output_qk);
}

OrtValue K, V;
@@ -207,7 +206,7 @@ Status DecoderMaskedMultiHeadAttention<T>::Compute(OpKernelContext* context) con
V.GetMutable<Tensor>()->MutableData<T>(),
mask_index, nullptr /* past */, past_key, past_value, output, present_key, present_value,
batch_size, 1 /* sequence_length */, parameters.kv_sequence_length,
head_size, v_head_size, v_hidden_size, attention_bias, context, cross_qk,
head_size, v_head_size, v_hidden_size, attention_bias, context, output_qk,
parameters.past_sequence_length, true /* past_present_share_buffer */);
}

@@ -219,7 +218,7 @@ Status DecoderMaskedMultiHeadAttention<T>::Compute(OpKernelContext* context) con
batch_size, parameters.past_sequence_length, parameters.max_sequence_length,
head_size, v_head_size, attention_bias, parameters.broadcast_attn_bias_dim_0,
parameters.broadcast_attn_bias_dim_1, cache_indir, context,
beam_width_value, cross_qk);
beam_width_value, output_qk);
}

template <typename T>
@@ -244,7 +243,7 @@ Status DecoderMaskedMultiHeadAttention<T>::ApplyAttentionWithBeams(
const Tensor* cache_indir,
OpKernelContext* context,
int beam_width,
Tensor* scaled_qk) const {
Tensor* output_qk) const {
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));

@@ -255,15 +254,15 @@ Status DecoderMaskedMultiHeadAttention<T>::ApplyAttentionWithBeams(
auto attention_probs = allocator->Alloc(bytes);
BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator));

T* scaled_qk_data = (scaled_qk != nullptr) ? scaled_qk->MutableData<T>() : nullptr;
T* output_qk_data = (output_qk != nullptr) ? output_qk->MutableData<T>() : nullptr;

const int32_t* mask_index_data = mask_index != nullptr ? mask_index->Data<int32_t>() : nullptr;
const T* attn_bias_data = attn_bias != nullptr ? attn_bias->Data<T>() : nullptr;

ComputeAttentionProbsWithBeams(static_cast<T*>(attention_probs), Q, K, mask_index_data, batch_size,
past_sequence_length, max_sequence_length, head_size, past_key->Data<T>(),
present_key->MutableData<T>(), tp, attn_bias_data, broadcast_attn_bias_dim_0,
broadcast_attn_bias_dim_1, cache_indir->Data<int32_t>(), beam_width, scaled_qk_data);
broadcast_attn_bias_dim_1, cache_indir->Data<int32_t>(), beam_width, output_qk_data);

// Compute the attentionScore * Value: out_tmp(B, N, 1, H_v) = attention_probs(B, N, 1, T) x V(B, N, T, H_v)
auto out_tmp_data = allocator->Alloc(SafeInt<size_t>(batch_size) * num_heads_ * v_head_size * sizeof(T));
@@ -295,7 +294,7 @@ void DecoderMaskedMultiHeadAttention<T>::ComputeAttentionProbsWithBeams(
bool broadcast_attn_bias_dim_1,
const int32_t* cache_indir_data,
int beam_width,
T* scaled_qk_data) const {
T* output_qk_data) const {
float scale = scale_ == 0.0f ? 1.0f / sqrt(static_cast<float>(head_size)) : scale_;

TensorOpCost unit_cost;
@@ -356,7 +355,8 @@ void DecoderMaskedMultiHeadAttention<T>::ComputeAttentionProbsWithBeams(
// Calculate the rest of the attention_probs
for (std::ptrdiff_t j = 0; j < past_sequence_length; ++j) {
const int* beam_indices = &cache_indir_data[batch_index * max_sequence_length];
const std::ptrdiff_t beam_offset = beam_indices[j] * num_heads_ * max_sequence_length * head_size;
const std::ptrdiff_t beam_offset = static_cast<std::ptrdiff_t>(beam_indices[j]) * num_heads_ *
max_sequence_length * head_size;
const std::ptrdiff_t beam_batch_offset = (beam_batch_index * beam_width * num_heads_ + head_index) *
max_sequence_length * head_size;
const T* past_k_vec = past_key_data + beam_batch_offset + beam_offset + j * head_size;
Expand All @@ -379,9 +379,9 @@ void DecoderMaskedMultiHeadAttention<T>::ComputeAttentionProbsWithBeams(
}
});

if (scaled_qk_data != nullptr) {
if (output_qk_data != nullptr) {
// Output the scaled Q*K^T if needed.
memcpy(scaled_qk_data, attention_probs,
memcpy(output_qk_data, attention_probs,
SafeInt<size_t>(batch_size) * num_heads_ * total_sequence_length * sizeof(T));
}

@@ -440,7 +440,8 @@ void DecoderMaskedMultiHeadAttention<T>::ComputeVxAttentionScoreWithBeams(
{
for (std::ptrdiff_t j = 0; j < past_sequence_length; ++j) {
const int* beam_indices = &cache_indir_data[batch_index * max_sequence_length];
const std::ptrdiff_t beam_offset = beam_indices[j] * num_heads_ * max_sequence_length * v_head_size;
const std::ptrdiff_t beam_offset = static_cast<std::ptrdiff_t>(beam_indices[j]) * num_heads_ *
max_sequence_length * v_head_size;
const std::ptrdiff_t beam_batch_offset = (beam_batch_index * beam_width * num_heads_ + head_index) *
max_sequence_length * v_head_size;
const T* past_value_vec = past_value_data + beam_offset + beam_batch_offset;
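The rest of the commit is the rename: the buffer previously called scaled_qk in AttentionCPUBase and cross_qk in DecoderMaskedMultiHeadAttention is now uniformly output_qk, and the optional QK output tensor is passed down as a Tensor* with MutableData<T>() taken where the copy happens, instead of being stashed in parameters.out_qk up front. A trimmed-down sketch of that copy step, using simplified types that are not the actual onnxruntime signatures:

#include <cstddef>
#include <cstring>

// Illustrative only: how the optional output_qk buffer ends up holding the
// scaled Q*K^T values when the caller requested a QK output.
template <typename T>
void CopyScaledQK(T* output_qk_data, const T* attention_probs,
                  std::size_t batch_size, std::size_t num_heads,
                  std::size_t sequence_length, std::size_t total_sequence_length) {
  if (output_qk_data == nullptr) {
    return;  // No QK output requested; nothing to copy.
  }
  std::memcpy(output_qk_data, attention_probs,
              batch_size * num_heads * sequence_length * total_sequence_length * sizeof(T));
}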
