# Fix Memory Issue GQA CPU Rotary (microsoft#22290)
### Description
In GQA there was a use-after-free memory issue, best described by @edgchen1 [here](microsoft#22252 (comment)):

> here's the problematic code:
>
> https://github.com/microsoft/onnxruntime/blob/d9de054eb53034e3dc18c298e47c6cc08d5aa884/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc#L149-L157
> 
> annotated:
> 
> ```c++
>     if (packed_qkv) {
>       // Q is an OrtValue declared in the enclosing scope.
>       OrtValue RotaryQKV;
>       Tensor::InitOrtValue(element_type, TensorShape({batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size}), allocator, RotaryQKV);
>       // Save pointer to Q's data in q_input.
>       q_input = Q.Get<Tensor>().Data<T>();
>       k_input = q_input + num_heads_ * sequence_length * head_size;
>       q_rotary = RotaryQKV.GetMutable<Tensor>()->MutableData<T>();
>       k_rotary = q_rotary + num_heads_ * sequence_length * head_size;
>       // Overwrite Q with RotaryQKV (OrtValues contain shared_ptr to contained value).
>       // Now, q_input is pointing to freed memory.
>       Q = RotaryQKV;
>     }
> ```
> 
> later on, when we use `q_input`, there is a read access violation.
> 
> https://github.com/microsoft/onnxruntime/blob/d9de054eb53034e3dc18c298e47c6cc08d5aa884/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc#L170-L172
> 
> this problem showed up when CPU allocator sharing between sessions was enabled. in that case, the CPU allocator's arena was disabled. I suspect that the default usage of the arena hid this issue.
> 
> though I debugged into the first branch, this appears to be a problem in both branches:
> 
> https://github.com/microsoft/onnxruntime/blob/d9de054eb53034e3dc18c298e47c6cc08d5aa884/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc#L149-L168
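
To make the lifetime bug concrete outside of onnxruntime, here is a minimal standalone sketch (hypothetical names; a plain `std::shared_ptr` stands in for the `OrtValue`'s shared ownership of its tensor buffer):

```c++
#include <memory>
#include <vector>

int main() {
  // "Q" owns a buffer, like the OrtValue declared in the enclosing scope.
  auto Q = std::make_shared<std::vector<float>>(16, 1.0f);

  // Save a raw pointer into Q's buffer, like q_input.
  const float* q_input = Q->data();

  // Allocate the rotary output buffer, like RotaryQKV.
  auto RotaryQKV = std::make_shared<std::vector<float>>(16, 0.0f);

  // Overwrite Q. The last reference to the old buffer is dropped and the
  // buffer is freed, so q_input now dangles, just as in the annotated code.
  Q = RotaryQKV;

  // Any later read through q_input (as RunRotaryEmbedding does) is undefined
  // behavior; with the arena disabled it surfaced as a read access violation.
  // float x = *q_input;  // UB: reads freed memory
  return 0;
}
```

With the default arena, the "freed" memory typically remains mapped inside the arena, which is presumably why the bug stayed hidden until allocator sharing disabled it.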





### Motivation and Context
Fixes a crucial use-after-free bug in the GQA CPU operator. The issue was reported in microsoft#22252.
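
The fix, shown in the diff below, hoists the rotary `OrtValue`s and the `q_rotary`/`k_rotary` pointers into the enclosing scope and stops reassigning `Q`/`K`, so every raw pointer is backed by a value that outlives its last use. A minimal sketch of the corrected ownership pattern (same hypothetical stand-ins as above, not the real ORT API):

```c++
#include <memory>
#include <vector>

int main() {
  auto Q = std::make_shared<std::vector<float>>(16, 1.0f);

  // Declared in the enclosing scope, as in the fix: the rotary buffer now
  // outlives every pointer taken from it.
  auto RotaryQKV = std::make_shared<std::vector<float>>(16, 0.0f);

  const float* q_input = Q->data();     // stays valid: Q is never reassigned
  float* q_rotary = RotaryQKV->data();  // rotary output buffer

  // ... rotary embedding would read q_input and write q_rotary ...
  q_rotary[0] = q_input[0];

  // Downstream consumers (ApplyAttention in the real code) read q_rotary
  // directly instead of going through a reassigned Q.
  return 0;
}
```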
aciddelgado authored and Ishwar Raut committed Nov 19, 2024
1 parent c0a3b0f commit d5800ee
Showing 1 changed file with 8 additions and 11 deletions.
19 changes: 8 additions & 11 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
```diff
@@ -106,6 +106,11 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
                                                   allocator, batch_size, kv_num_heads_, sequence_length, head_size, value, V));
   }
 
+  OrtValue RotaryQKV;
+  OrtValue RotaryQ;
+  OrtValue RotaryK;
+  T* q_rotary = Q.GetMutable<Tensor>()->MutableData<T>();
+  T* k_rotary = packed_qkv ? nullptr : K.GetMutable<Tensor>()->MutableData<T>();
   if (do_rotary_) {
     // Initialize rotary parameters
     rotary_embedding_helper::RotaryParameters rotary_params = {};
@@ -128,7 +133,7 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
       if (parameters.is_first_prompt) {
         pos_ids[0] = static_cast<int64_t>(0);
       } else {
-        // Note: As of now, interactive decoding supports only batch size 1 and token generation supports only sequence length 1.
+        // Note: As of now, continuous decoding supports only batch size 1 and token generation supports only sequence length 1.
         for (int b = 0; b < batch_size; b++) {
           const int total_seqlen = seqlens_k->Data<int32_t>()[b] + 1;
           const int past_seqlen = total_seqlen - sequence_length;
@@ -144,27 +149,19 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
     // Initialize separate buffers for rotary embeddings
     const T* q_input;
     const T* k_input;
-    T* q_rotary;
-    T* k_rotary;
     if (packed_qkv) {
-      OrtValue RotaryQKV;
       Tensor::InitOrtValue(element_type, TensorShape({batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size}), allocator, RotaryQKV);
       q_input = Q.Get<Tensor>().Data<T>();
       k_input = q_input + num_heads_ * sequence_length * head_size;
       q_rotary = RotaryQKV.GetMutable<Tensor>()->MutableData<T>();
       k_rotary = q_rotary + num_heads_ * sequence_length * head_size;
-      Q = RotaryQKV;
     } else {
-      OrtValue RotaryQ;
       Tensor::InitOrtValue(element_type, TensorShape({batch_size, num_heads_, sequence_length, head_size}), allocator, RotaryQ);
-      OrtValue RotaryK;
       Tensor::InitOrtValue(element_type, TensorShape({batch_size, kv_num_heads_, sequence_length, head_size}), allocator, RotaryK);
       q_input = Q.Get<Tensor>().Data<T>();
       k_input = K.Get<Tensor>().Data<T>();
       q_rotary = RotaryQ.GetMutable<Tensor>()->MutableData<T>();
       k_rotary = RotaryK.GetMutable<Tensor>()->MutableData<T>();
-      Q = RotaryQ;
-      K = RotaryK;
     }
     // Run rotary embedding for Q and K
     ORT_RETURN_IF_ERROR(RunRotaryEmbedding<T>(tp, rotary_params, q_input,
@@ -196,8 +193,8 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
 
   ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
   // Compute the attention score and apply the score to V
-  return ApplyAttention(Q.Get<Tensor>().Data<T>(), packed_qkv ? nullptr : K.Get<Tensor>().Data<T>(),
-                        packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(), past_key, past_value, output, present_k, present_v,
+  return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(),
+                        past_key, past_value, output, present_k, present_v,
                         seqlens_k, parameters, allocator, context);
 }
 }  // namespace contrib
```
