optimize gqa cpu (microsoft#20598)
### Description
Optimize the GQA implementation on CPU. The main optimizations are:
1. compute attention over the real total sequence length instead of the maximum
sequence length when past/present share the same buffer
2. remove the mask
3. remove the transpose after the attention x value matmul

This improves the Phi-3 model example
https://github.com/microsoft/onnxruntime-genai/blob/main/examples/python/phi3-qa.py
with a max sequence length of 2k/4k from 10 tps to 20 tps.
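
For illustration, here is a minimal standalone sketch of what items 1 and 3 amount to for a single head (hypothetical helper with naive loops, not the actual kernel or its GEMM): the probabilities-times-V product only iterates over the real `total_seqlen`, and each output row is written directly into the final (B, S, N*H) buffer with a `hidden_size` stride, so the separate transpose pass disappears.

```cpp
// Hypothetical sketch, not the ONNX Runtime kernel: attention_probs x V for one
// head, written straight into the (B, S, N*H) output buffer.
#include <cstddef>

void probs_times_v_one_head(const float* probs,      // S x present_buffer_len, row-major
                            const float* v,          // present_buffer_len x head_size
                            float* out_head,         // &output[b][0][n][0] inside (B, S, N*H)
                            int sequence_length,     // S
                            int total_seqlen,        // real past + new tokens (<= present_buffer_len)
                            int present_buffer_len,  // allocated KV buffer length
                            int head_size,           // H
                            int hidden_size) {       // N * H, row stride of the final output
  for (int s = 0; s < sequence_length; ++s) {
    // Writing with a hidden_size stride lands each row in its final location,
    // which is what makes the old out_tmp buffer and memcpy-transpose unnecessary.
    float* out_row = out_head + static_cast<std::ptrdiff_t>(s) * hidden_size;
    for (int h = 0; h < head_size; ++h) out_row[h] = 0.0f;
    // Only the first total_seqlen columns of probs are meaningful, so the loop
    // stops there instead of running over the whole present buffer.
    for (int t = 0; t < total_seqlen; ++t) {
      const float p = probs[static_cast<std::ptrdiff_t>(s) * present_buffer_len + t];
      const float* v_row = v + static_cast<std::ptrdiff_t>(t) * head_size;
      for (int h = 0; h < head_size; ++h) out_row[h] += p * v_row[h];
    }
  }
}
```

This is the same shape as the `math::GemmEx` call the diff below introduces, where passing `hidden_size` as the output leading dimension does the work the explicit transpose loop used to do.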

### Motivation and Context
yufenglee authored and poweiw committed Jun 25, 2024
1 parent 1d75078 commit 3a7a6f6
Showing 1 changed file with 12 additions and 23 deletions.
35 changes: 12 additions & 23 deletions onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -105,7 +105,6 @@ class GQAAttentionBase : public AttentionBase {
     const size_t past_buff_chunk_length = static_cast<size_t>(past_buffer_sequence_length) * head_size;        // L x H
     const size_t present_buff_chunk_length = static_cast<size_t>(present_buffer_sequence_length) * head_size;  // T x H

-    PrepareMaskGQA(mask_data, batch_size, sequence_length, present_buffer_sequence_length, local_window_size_, seqlens_k);
     if (!past_present_share_buffer) {
       memset(present_key, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
     }
@@ -139,11 +138,6 @@ class GQAAttentionBase : public AttentionBase {
         const int output_offset = static_cast<int>(i) * sequence_length * present_buffer_sequence_length;
         T* output = attention_probs + output_offset;

-        // Broadcast mask data: (Bx)SxT -> (BxNx)SxT
-        memcpy(output,
-               mask_data + mask_offset,
-               probs_matrix_bytes);
-
         const T* k;
         if (packed_qkv) {
           k = K + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor);
@@ -216,8 +210,7 @@ class GQAAttentionBase : public AttentionBase {
     const bool is_prompt = sequence_length != 1;
     const int packed_batch_stride = packed_qkv ? (num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : 0;
     const int kv_num_heads_factor = num_heads_ / kv_num_heads_;
-    const size_t q_input_chunk_length = static_cast<size_t>(sequence_length) * head_size;   // S x H
-    const size_t kv_input_chunk_length = static_cast<size_t>(sequence_length) * head_size;  // L x H
+    const int kv_input_chunk_length = sequence_length * head_size;  // L x H
     const size_t past_buff_chunk_length = static_cast<size_t>(past_buffer_sequence_length) * head_size;        // L x H
     const size_t present_buff_chunk_length = static_cast<size_t>(present_buffer_sequence_length) * head_size;  // T x H
@@ -262,21 +255,17 @@ class GQAAttentionBase : public AttentionBase {
                                               i / kv_num_heads_factor);
         }

-        T* current_tmp_data = reinterpret_cast<T*>(tmp_buffer) + q_input_chunk_length * static_cast<int>(i);
-        const int attention_probs_offset = sequence_length * present_buffer_sequence_length * static_cast<int>(i);
-        math::MatMul<T>(sequence_length, head_size, present_buffer_sequence_length,
-                        attention_probs + attention_probs_offset,
-                        v, current_tmp_data, nullptr);
-
-        // Transpose: out(B, S, N, H_v) -> out_tmp(B, N, S, H_v)
-        T* src = current_tmp_data;
-        const int dest_offset = (batch_index * sequence_length * num_heads_ + head_index) * head_size;
-        T* dest = output + dest_offset;
-        for (int j = 0; j < sequence_length; j++) {
-          memcpy(dest, src, bytes_to_copy_trans);
-          src += head_size;
-          dest += hidden_size;
-        }
+        T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
+        ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * present_buffer_sequence_length * i;
+
+        math::GemmEx<T, ThreadPool>(CblasNoTrans,
+                                    CblasNoTrans,
+                                    sequence_length, head_size, total_seqlen,
+                                    1.f, /*alpha*/
+                                    attention_probs + attention_probs_offset, present_buffer_sequence_length,
+                                    v, head_size,
+                                    0.0f /*beta*/,
+                                    output_current, hidden_size, nullptr);
       }
     });
   }
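
For readers tracing the new GemmEx call: `total_seqlen` is computed outside the hunks shown here. Assuming the usual GroupQueryAttention convention that `seqlens_k[b]` holds the total (past + new) sequence length minus one for batch `b` — an assumption, not something visible in this diff — a per-batch derivation would look like:

```cpp
// Hypothetical sketch, assuming seqlens_k[b] == (past + new tokens) - 1;
// not copied from the kernel, where total_seqlen is defined outside these hunks.
#include <cstdint>

inline int compute_total_seqlen(const int32_t* seqlens_k, int batch_index) {
  return seqlens_k[batch_index] + 1;  // real tokens to attend over, <= present_buffer_sequence_length
}
```

Using this value as the GEMM's inner dimension is what lets the kernel skip the padded tail of the shared past/present buffer instead of masking it out.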
