DecoderMaskedMultiHeadAttention CPU kernel. #22292

Merged: 14 commits, Oct 12, 2024
10 changes: 5 additions & 5 deletions docs/ContribOperators.md
@@ -1175,9 +1175,9 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>past_sequence_length</tt> (optional) : M</dt>
<dd>When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).Cross Attention doesn't need this input.</dd>
<dt><tt>beam_width</tt> (optional) : M</dt>
<dd>The beam width that is being used while decoding.If not provided, the beam width will be assumed to be 1.</dd>
<dd>The beam width that is being used while decoding. If not provided, the beam width will be assumed to be 1.</dd>
<dt><tt>cache_indirection</tt> (optional) : M</dt>
<dd>A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifieswhich beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration</dd>
<dd>A buffer of shape [batch_size, beam_width, max_output_length] where an `[i, j, k]` entry specifies which beam the `k`-th token came from for the `j`-th beam for batch `i` in the current iteration</dd>
<dt><tt>bias</tt> (optional) : T</dt>
<dd>Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection</dd>
</dl>
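
An editorial aside, not part of the diff: a minimal sketch of how a `cache_indirection` entry is typically consumed during beam-search decoding. To read past token `k` for beam `j` of batch `i`, the kernel first resolves which beam actually produced that token; the helper name and the flattened row-major layout below are illustrative assumptions.

```cpp
#include <cstdint>

// Hypothetical helper: resolve the source beam for a past token when reading
// a beam-search KV cache. cache_indir has shape
// [batch_size, beam_width, max_output_length], flattened row-major.
inline int32_t SourceBeam(const int32_t* cache_indir,
                          int batch, int beam, int token,
                          int beam_width, int max_output_length) {
  const int64_t idx =
      (static_cast<int64_t>(batch) * beam_width + beam) * max_output_length + token;
  return cache_indir[idx];
}
```

The key/value of that token is then read from the cache chunk of the returned beam, which is not necessarily chunk `j` itself.
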
@@ -1192,7 +1192,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>present_value</tt> (optional) : T</dt>
<dd>present state for value with shape (batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).</dd>
<dt><tt>qk</tt> (optional) : V</dt>
<dd>normalized Q * K, of shape (batch_size, num_heads, 1, head_size). </dd>
<dd>normalized Q * K, of shape (batch_size, num_heads, 1, total_sequence_length). </dd>
</dl>
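
For context on the corrected `qk` shape: during decoding each step carries a single query token, which is scored against all `total_sequence_length` cached keys, so the per-head QK row has `total_sequence_length` entries rather than `head_size`. A tiny illustrative size computation, with assumed names:

```cpp
#include <cstddef>

// Element count of the optional `qk` output: one query token per step,
// scored against every key in the (past + current) sequence.
size_t QkOutputElementCount(int batch_size, int num_heads, int total_sequence_length) {
  return static_cast<size_t>(batch_size) * num_heads * 1 /* query length */ *
         total_sequence_length;
}
```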

#### Type Constraints
@@ -1261,9 +1261,9 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>past_sequence_length</tt> : M</dt>
<dd>When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).</dd>
<dt><tt>beam_width</tt> (optional) : M</dt>
<dd>The beam width that is being used while decoding.If not provided, the beam width will be assumed to be 1.</dd>
<dd>The beam width that is being used while decoding. If not provided, the beam width will be assumed to be 1.</dd>
<dt><tt>cache_indirection</tt> (optional) : M</dt>
<dd>A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifieswhich beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration</dd>
<dd>A buffer of shape [batch_size, beam_width, max_output_length] where an `[i, j, k]` entry specifies which beam the `k`-th token came from for the `j`-th beam for batch `i` in the current iteration</dd>
</dl>

#### Outputs
39 changes: 39 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -79,6 +79,45 @@ struct AttentionParameters {
AttentionQkvFormat qkv_format;
};

struct DecoderMaskedMultiHeadAttentionParams : AttentionParameters {
int beam_width = 1;

// Only NeoX style rotary embedding is supported
int rotary_embedding_dim = 0;
int t_step = 0;

// Whether to use multi-head attention (excludes MatMul and bias)
bool is_mha = false;
bool is_cross_attention = false;
bool is_packed_qkv = false;

// Useful to better use global memory bandwidth on certain CUDA architectures.
// Turned off by default for now until we fully understand performance implications
// for all types of workloads.
// Can be turned on by appropriate environment variable (see attention_common.h).
bool kv_data_in_flight = false;

void* q = nullptr;
void* q_bias = nullptr;

void* k = nullptr;
void* k_bias = nullptr;

void* v = nullptr;
void* v_bias = nullptr;

void* attention_bias = nullptr;

void* k_cache = nullptr;
void* v_cache = nullptr;

void* out = nullptr;
void* out_qk = nullptr;

const int32_t* cache_indir = nullptr;
const int32_t* mask = nullptr; // [B, total_sequence_length]
};
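
As a reading aid (not code from this PR), a hedged sketch of how a caller might populate this struct for one self-attention decoding step. Only fields declared above are set; the base-class AttentionParameters fields (batch size, head counts, and so on) would also need to be filled and are omitted here, and every buffer pointer is assumed to be caller-provided.

```cpp
#include <cstdint>

// Hypothetical illustration only: fill the DMMHA parameters for one
// self-attention decode step. Buffer pointers come from the caller.
void FillDmmhaParamsExample(DecoderMaskedMultiHeadAttentionParams& params,
                            void* q, void* k, void* v,
                            void* k_cache, void* v_cache, void* out,
                            const int32_t* cache_indir, const int32_t* mask,
                            int beam_width, int step) {
  params.beam_width = beam_width;
  params.t_step = step;             // current decoding step
  params.is_mha = true;             // Q/K/V already projected; skip in-op MatMul/bias
  params.is_cross_attention = false;
  params.is_packed_qkv = false;
  params.q = q;                     // new query token
  params.k = k;                     // new key for this step
  params.v = v;                     // new value for this step
  params.k_cache = k_cache;         // shared past/present key buffer
  params.v_cache = v_cache;         // shared past/present value buffer
  params.cache_indir = cache_indir; // beam indirection, see docs above
  params.mask = mask;               // [B, total_sequence_length]
  params.out = out;
  params.out_qk = nullptr;          // optional QK output not requested
}
```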

// Parameters deduced from node attributes and inputs/outputs.
struct PackedAttentionParameters {
int batch_size;
103 changes: 71 additions & 32 deletions onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -36,18 +36,23 @@
int v_head_size, // head size of V (H_v)
int v_hidden_size, // hidden size of V (D_v)
const Tensor* attn_bias, // additive bias applied on scaled QK.
OpKernelContext* context) const {
OpKernelContext* context,
Tensor* scaled_qk = nullptr, // output buffer for QK (if needed)
int past_sequence_length = 0, // sequence length of past state
bool use_dmmha = false // whether used in DecoderMaskedMultiHeadAttention
) const {

[cpplint] warning (line 43): Closing ) should be moved to the previous line [whitespace/parens] [2]
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));

auto* tp = context->GetOperatorThreadPool();

int past_sequence_length = 0;
Tensor* present = nullptr;
if (present_key == nullptr && present_value == nullptr) {
present = GetPresent(context, past, batch_size, v_head_size, kv_sequence_length, past_sequence_length);
} else if (past_key != nullptr && past_value != nullptr) {
past_sequence_length = static_cast<int>(past_key->Shape().GetDims()[2]);
if (past_sequence_length == 0) {
if (present_key == nullptr && present_value == nullptr) {
present = GetPresent(context, past, batch_size, v_head_size, kv_sequence_length, past_sequence_length);
} else if (past_key != nullptr && past_value != nullptr) {
past_sequence_length = static_cast<int>(past_key->Shape().GetDims()[2]);
}
}

// Total sequence length including that of past state: T = P + L
@@ -71,7 +76,7 @@

if (mask_data != nullptr) {
// Convert mask from boolean (0/1) to float (mask_filter_value/0.0f).
// Merge padding mask with causual mask, and broadcast to 3D (BxSxT).
// Merge padding mask with causal mask, and broadcast to 3D (BxSxT).
PrepareMask(mask_index_data, mask_index_dims, static_cast<T*>(mask_data),
causal, batch_size, sequence_length, past_sequence_length, mask_filter_value_);
DUMP_CPU_TENSOR("Mask3D", static_cast<T*>(mask_data), batch_size, sequence_length, total_sequence_length);
@@ -85,19 +90,29 @@
T* present_key_data = present_key != nullptr ? present_key->MutableData<T>() : nullptr;
const T* past_value_data = past_value != nullptr ? past_value->Data<T>() : nullptr;
T* present_value_data = present_value != nullptr ? present_value->MutableData<T>() : nullptr;
T* scaled_qk_data = scaled_qk != nullptr ? scaled_qk->MutableData<T>() : nullptr;

const T* attn_bias_data = (attn_bias != nullptr) ? attn_bias->Data<T>() : nullptr;
auto attn_bias_dims = (attn_bias != nullptr) ? attn_bias->Shape().GetDims() : gsl::span<const int64_t>{};

// Used for DecoderMaskedMultiHeadAttention
int max_sequence_length = 0;
if (use_dmmha) {
ORT_ENFORCE(past_key != nullptr && past_value != nullptr);
max_sequence_length = static_cast<int>(past_key->Shape().GetDims()[2]);
}
std::cout << "==== max_sequence_length: " << max_sequence_length << std::endl;

[cpplint] warning (line 104): Add #include <iostream> for cout [build/include_what_you_use] [4]

// Compute the attention score.
size_t bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * total_sequence_length * sizeof(T);
auto attention_probs = allocator->Alloc(bytes);
BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator));
ComputeAttentionProbs<T>(static_cast<T*>(attention_probs), Q, K,
static_cast<T*>(mask_data),
batch_size, sequence_length, kv_sequence_length, past_sequence_length,
qk_head_size == 0 ? v_head_size : qk_head_size, past_data, past_key_data,
present_data, present_key_data, tp, scale, attn_bias_data, attn_bias_dims);
qk_head_size == 0 ? v_head_size : qk_head_size, past_data, past_key_data, present_data,
present_key_data, tp, scale, attn_bias_data, attn_bias_dims, scaled_qk_data, use_dmmha,
max_sequence_length);

// Compute the attentionScore * Value: out_tmp(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
auto out_tmp_data =
@@ -106,7 +121,8 @@

ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(out_tmp_data), static_cast<T*>(attention_probs),
V, batch_size, sequence_length, kv_sequence_length, past_sequence_length, v_head_size,
v_hidden_size, past_data, past_value_data, present_data, present_value_data, tp);
v_hidden_size, past_data, past_value_data, present_data, present_value_data, tp, use_dmmha,
max_sequence_length);

return Status::OK();
}
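
A note on the `use_dmmha` path above: when past and present share a buffer, `past_key`/`past_value` are allocated to the full `max_sequence_length` capacity, so the valid past length comes from the `past_sequence_length` argument while the stride per (batch, head) chunk is `max_sequence_length * head_size`. A small standalone sketch of that arithmetic, with made-up numbers:

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  // Illustrative values only; names mirror the constants computed in
  // ComputeAttentionProbs below.
  const int head_size = 64;             // H
  const int past_sequence_length = 17;  // P: tokens already in the cache
  const int kv_sequence_length = 1;     // L: one new token per decode step
  const int max_sequence_length = 2048; // M: capacity of the shared buffer

  const size_t past_chunk_length = static_cast<size_t>(past_sequence_length) * head_size;   // P x H
  const size_t kv_input_chunk_length = static_cast<size_t>(kv_sequence_length) * head_size; // L x H
  const size_t present_chunk_length = past_chunk_length + kv_input_chunk_length;            // T x H
  const size_t cache_chunk_length = static_cast<size_t>(max_sequence_length) * head_size;   // M x H

  // With a shared buffer, each (batch, head) chunk is strided by M x H and the
  // new token is written at offset P x H inside that chunk; without sharing,
  // past (P x H) and new (L x H) data are concatenated into a T x H chunk.
  std::printf("shared stride: %zu, write offset: %zu, concat chunk: %zu\n",
              cache_chunk_length, past_chunk_length, present_chunk_length);
  return 0;
}
```
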
@@ -117,29 +133,33 @@
// 1 x mask_data(B, N, S, T)
// attention_probs(B, N, S, T) = Softmax(attention_probs)
template <typename T>
void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT
const T* Q, // Q data. Its size is BxNxSxH
const T* K, // k data. Its size is BxNxLxH
T* mask_data, // buffer for mask data.
int batch_size, // batch size of self-attention
int sequence_length, // sequence length of self-attention (S)
int kv_sequence_length, // sequence length of cross-attention (L)
int past_sequence_length, // sequence length of past state
int head_size, // head size of self-attention
const T* past, // past state
const T* past_key, // past key only (if not using past state)
T* present, // present state
T* present_key, // present key only (if not using present state)
ThreadPool* tp, // thread pool
float scale, // scale factor
const T* attn_bias_data, // attention bias
gsl::span<const int64_t> attn_bias_dims // attention bias shape
void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT
const T* Q, // Q data. Its size is BxNxSxH
const T* K, // k data. Its size is BxNxLxH
T* mask_data, // buffer for mask data.
int batch_size, // batch size of self-attention
int sequence_length, // sequence length of self-attention (S)
int kv_sequence_length, // sequence length of cross-attention (L)
int past_sequence_length, // sequence length of past state
int head_size, // head size of self-attention
const T* past, // past state
const T* past_key, // past key only (if not using past state)
T* present, // present state
T* present_key, // present key only (if not using present state)
ThreadPool* tp, // thread pool
float scale, // scale factor
const T* attn_bias_data, // attention bias
gsl::span<const int64_t> attn_bias_dims, // attention bias shape
T* scaled_qk_data = nullptr, // scaled output QK buffer
bool use_dmmha = false, // whether used in DecoderMaskedMultiHeadAttention

[cpplint] warning (line 154): Lines should be <= 120 characters long [whitespace/line_length] [2]
int max_sequence_length = 0 // max sequence length of kv cache
) const {
const int total_sequence_length = past_sequence_length + kv_sequence_length; // T = P + L
const size_t past_chunk_length = static_cast<size_t>(past_sequence_length) * head_size; // P x H
const size_t q_input_chunk_length = static_cast<size_t>(sequence_length) * head_size; // S x H
const size_t kv_input_chunk_length = static_cast<size_t>(kv_sequence_length) * head_size; // L x H
const size_t present_chunk_length = past_chunk_length + kv_input_chunk_length; // T x H
const size_t cache_chunk_length = static_cast<size_t>(max_sequence_length) * head_size; // M x H

DUMP_CPU_TENSOR_INIT();
DUMP_CPU_TENSOR("Q", Q, batch_size, num_heads_, sequence_length, head_size);
@@ -164,7 +184,7 @@
}

if (present || present_key) {
double bytes_to_copy_key = static_cast<double>(sizeof(T) * present_chunk_length);
double bytes_to_copy_key = static_cast<double>(sizeof(T) * (use_dmmha ? kv_input_chunk_length : present_chunk_length));

[cpplint] warning (line 187): Lines should be <= 120 characters long [whitespace/line_length] [2] (marked fixed)
unit_cost.bytes_loaded += bytes_to_copy_key;
unit_cost.bytes_stored += bytes_to_copy_key;
}
@@ -214,7 +234,12 @@
// Concatenate past_K and K : (BxNx)PxH, (BxNx)LxH -> (BxNx)TxH
k = ConcatStateChunk(past, k, present, past_chunk_length, present_chunk_length, i);
} else if (nullptr != present_key) {
k = ConcatStateChunk(past_key, k, present_key, past_chunk_length, present_chunk_length, i);
if (use_dmmha) {
k = present_key + cache_chunk_length * i;
memcpy(const_cast<T*>(k) + past_chunk_length, K + head_size * i, head_size * sizeof(T));
} else {
k = ConcatStateChunk(past_key, k, present_key, past_chunk_length, present_chunk_length, i);
}
}

// Compute Q*K' + AttentionMask
@@ -230,6 +255,12 @@
});
}

if (scaled_qk_data != nullptr) {
// Output the scaled Q*K^T if needed.
memcpy(scaled_qk_data, attention_probs,
SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * total_sequence_length * sizeof(T));
}

DUMP_CPU_TENSOR("QK (scaled)", attention_probs, batch_size, num_heads_, sequence_length, total_sequence_length);

// attention_probs(B, N, S, T) = Softmax(attention_probs)
@@ -257,12 +288,15 @@
const T* past_value, // past value only (if not using past state)
T* present, // present state
T* present_value, // present value only (if not using present state)
ThreadPool* tp) const {
ThreadPool* tp,
bool use_dmmha = false,
int max_sequence_length = 0) const {
const int total_sequence_length = past_sequence_length + kv_sequence_length; // T = P + L
const ptrdiff_t past_chunk_length = SafeInt<ptrdiff_t>(past_sequence_length) * v_head_size; // P x H_v
const ptrdiff_t q_input_chunk_length = SafeInt<ptrdiff_t>(sequence_length) * v_head_size; // S x H_v
const ptrdiff_t kv_input_chunk_length = SafeInt<ptrdiff_t>(kv_sequence_length) * v_head_size; // L x H_v
const ptrdiff_t present_chunk_length = past_chunk_length + kv_input_chunk_length; // T x H_v
const ptrdiff_t cache_chunk_length = SafeInt<ptrdiff_t>(max_sequence_length) * v_head_size; // M x H_v

// Move the pointer of past and present to start of v values.
if (nullptr != past) {
Expand All @@ -281,7 +315,7 @@
unit_cost.bytes_stored = static_cast<double>(sequence_length * v_head_size * sizeof(T));

if (present || present_value) {
double bytes_to_copy_value = static_cast<double>(present_chunk_length * sizeof(T));
double bytes_to_copy_value = static_cast<double>((use_dmmha ? kv_input_chunk_length : present_chunk_length) * sizeof(T));

[cpplint] warning (line 318): Lines should be <= 120 characters long [whitespace/line_length] [2] (marked fixed)
unit_cost.bytes_loaded += bytes_to_copy_value;
unit_cost.bytes_stored += bytes_to_copy_value;
}
Expand All @@ -299,7 +333,12 @@
// Concatenate past_V and V: (BxNx)PxH_v, (BxNx)LxH_v -> (BxNx)TxH_v
v = ConcatStateChunk(past, v, present, past_chunk_length, present_chunk_length, i);
} else if (nullptr != present_value) {
v = ConcatStateChunk(past_value, v, present_value, past_chunk_length, present_chunk_length, i);
if (use_dmmha) {
v = present_value + cache_chunk_length * i;
memcpy(const_cast<T*>(v) + past_chunk_length, V + v_head_size * i, v_head_size * sizeof(T));
} else {
v = ConcatStateChunk(past_value, v, present_value, past_chunk_length, present_chunk_length, i);
}
}

T* current_tmp_data = reinterpret_cast<T*>(tmp_buffer) + q_input_chunk_length * i;
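
To summarize the `use_dmmha` branches above: instead of concatenating past and new K/V into a separate present buffer, the single new token is written in place into the shared cache chunk, and attention then reads the first `past_sequence_length + 1` rows of that chunk. A standalone sketch under the assumption of a row-major `[B * num_heads, max_sequence_length, head_size]` cache and one new token per step:

```cpp
#include <cstddef>
#include <cstring>

// Sketch of the in-place append performed when use_dmmha is true: copy the
// new key (head_size values) for batch-head chunk i into the shared cache at
// row past_sequence_length. K holds the new keys laid out as [B * N, 1, H].
const float* AppendKeyToSharedCache(float* present_key, const float* K, int i,
                                    int past_sequence_length,
                                    int max_sequence_length, int head_size) {
  const size_t cache_chunk_length = static_cast<size_t>(max_sequence_length) * head_size; // M x H
  const size_t past_chunk_length = static_cast<size_t>(past_sequence_length) * head_size; // P x H
  float* k = present_key + cache_chunk_length * static_cast<size_t>(i);
  std::memcpy(k + past_chunk_length, K + static_cast<size_t>(head_size) * i,
              sizeof(float) * head_size);
  return k;  // attention for this chunk reads rows [0, past_sequence_length] of k
}
```

The value-side append in ComputeVxAttentionScore follows the same pattern with `v_head_size` in place of `head_size`.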