From d5800ee9450abbb5911112815c64bc0491a0ad61 Mon Sep 17 00:00:00 2001 From: aciddelgado <139922440+aciddelgado@users.noreply.github.com> Date: Tue, 8 Oct 2024 09:22:05 -0700 Subject: [PATCH] Fix Memory Issue GQA CPU Rotary (#22290) ### Description In GQA there was a memory issue which was best described by @edgchen1 [here](https://github.com/microsoft/onnxruntime/issues/22252#issuecomment-2384559255) > here's the problematic code: > > https://github.com/microsoft/onnxruntime/blob/d9de054eb53034e3dc18c298e47c6cc08d5aa884/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc#L149-L157 > > annotated: > > ```c++ > if (packed_qkv) { > // Q is an OrtValue declared in the enclosing scope. > OrtValue RotaryQKV; > Tensor::InitOrtValue(element_type, TensorShape({batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size}), allocator, RotaryQKV); > // Save pointer to Q's data in q_input. > q_input = Q.Get<Tensor>().Data<T>(); > k_input = q_input + num_heads_ * sequence_length * head_size; > q_rotary = RotaryQKV.GetMutable<Tensor>()->MutableData<T>(); > k_rotary = q_rotary + num_heads_ * sequence_length * head_size; > // Overwrite Q with RotaryQKV (OrtValues contain shared_ptr to contained value). > // Now, q_input is pointing to freed memory. > Q = RotaryQKV; > } > ``` > > later on, when we use `q_input`, there is a read access violation. > > https://github.com/microsoft/onnxruntime/blob/d9de054eb53034e3dc18c298e47c6cc08d5aa884/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc#L170-L172 > > this problem showed up when CPU allocator sharing between sessions was enabled. in that case, the CPU allocator's arena was disabled. I suspect that the default usage of the arena hid this issue. 
> > though I debugged into the first branch, this appears to be a problem in both branches: > > https://github.com/microsoft/onnxruntime/blob/d9de054eb53034e3dc18c298e47c6cc08d5aa884/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc#L149-L168 ### Motivation and Context Fixes a crucial bug. The issue was found here https://github.com/microsoft/onnxruntime/issues/22252 --- .../cpu/bert/group_query_attention.cc | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index a1ed35e54b008..8f662cd388c6d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -106,6 +106,11 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const { allocator, batch_size, kv_num_heads_, sequence_length, head_size, value, V)); } + OrtValue RotaryQKV; + OrtValue RotaryQ; + OrtValue RotaryK; + T* q_rotary = Q.GetMutable<Tensor>()->MutableData<T>(); + T* k_rotary = packed_qkv ? nullptr : K.GetMutable<Tensor>()->MutableData<T>(); if (do_rotary_) { // Initialize rotary parameters rotary_embedding_helper::RotaryParameters rotary_params = {}; @@ -128,7 +133,7 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const { if (parameters.is_first_prompt) { pos_ids[0] = static_cast<int64_t>(0); } else { - // Note: As of now, interactive decoding supports only batch size 1 and token generation supports only sequence length 1. + // Note: As of now, continuous decoding supports only batch size 1 and token generation supports only sequence length 1. 
for (int b = 0; b < batch_size; b++) { const int total_seqlen = seqlens_k->Data<int32_t>()[b] + 1; const int past_seqlen = total_seqlen - sequence_length; @@ -144,27 +149,19 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const { // Initialize separate buffers for rotary embeddings const T* q_input; const T* k_input; - T* q_rotary; - T* k_rotary; if (packed_qkv) { - OrtValue RotaryQKV; Tensor::InitOrtValue(element_type, TensorShape({batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size}), allocator, RotaryQKV); q_input = Q.Get<Tensor>().Data<T>(); k_input = q_input + num_heads_ * sequence_length * head_size; q_rotary = RotaryQKV.GetMutable<Tensor>()->MutableData<T>(); k_rotary = q_rotary + num_heads_ * sequence_length * head_size; - Q = RotaryQKV; } else { - OrtValue RotaryQ; Tensor::InitOrtValue(element_type, TensorShape({batch_size, num_heads_, sequence_length, head_size}), allocator, RotaryQ); - OrtValue RotaryK; Tensor::InitOrtValue(element_type, TensorShape({batch_size, kv_num_heads_, sequence_length, head_size}), allocator, RotaryK); q_input = Q.Get<Tensor>().Data<T>(); k_input = K.Get<Tensor>().Data<T>(); q_rotary = RotaryQ.GetMutable<Tensor>()->MutableData<T>(); k_rotary = RotaryK.GetMutable<Tensor>()->MutableData<T>(); - Q = RotaryQ; - K = RotaryK; } // Run rotary embedding for Q and K ORT_RETURN_IF_ERROR(RunRotaryEmbedding<T>(tp, rotary_params, q_input, @@ -196,8 +193,8 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); // Compute the attention score and apply the score to V - return ApplyAttention(Q.Get<Tensor>().Data<T>(), packed_qkv ? nullptr : K.Get<Tensor>().Data<T>(), - packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(), past_key, past_value, output, present_k, present_v, + return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(), + past_key, past_value, output, present_k, present_v, seqlens_k, parameters, allocator, context); } } // namespace contrib