diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index a1ed35e54b008..8f662cd388c6d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -106,6 +106,11 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { allocator, batch_size, kv_num_heads_, sequence_length, head_size, value, V)); } + OrtValue RotaryQKV; + OrtValue RotaryQ; + OrtValue RotaryK; + T* q_rotary = Q.GetMutable()->MutableData(); + T* k_rotary = packed_qkv ? nullptr : K.GetMutable()->MutableData(); if (do_rotary_) { // Initialize rotary parameters rotary_embedding_helper::RotaryParameters rotary_params = {}; @@ -128,7 +133,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { if (parameters.is_first_prompt) { pos_ids[0] = static_cast(0); } else { - // Note: As of now, interactive decoding supports only batch size 1 and token generation supports only sequence length 1. + // Note: As of now, continuous decoding supports only batch size 1 and token generation supports only sequence length 1. for (int b = 0; b < batch_size; b++) { const int total_seqlen = seqlens_k->Data()[b] + 1; const int past_seqlen = total_seqlen - sequence_length; @@ -144,27 +149,19 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { // Initialize separate buffers for rotary embeddings const T* q_input; const T* k_input; - T* q_rotary; - T* k_rotary; if (packed_qkv) { - OrtValue RotaryQKV; Tensor::InitOrtValue(element_type, TensorShape({batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size}), allocator, RotaryQKV); q_input = Q.Get().Data(); k_input = q_input + num_heads_ * sequence_length * head_size; q_rotary = RotaryQKV.GetMutable()->MutableData(); k_rotary = q_rotary + num_heads_ * sequence_length * head_size; - Q = RotaryQKV; } else { - OrtValue RotaryQ; Tensor::InitOrtValue(element_type, TensorShape({batch_size, num_heads_, sequence_length, head_size}), allocator, RotaryQ); - OrtValue RotaryK; Tensor::InitOrtValue(element_type, TensorShape({batch_size, kv_num_heads_, sequence_length, head_size}), allocator, RotaryK); q_input = Q.Get().Data(); k_input = K.Get().Data(); q_rotary = RotaryQ.GetMutable()->MutableData(); k_rotary = RotaryK.GetMutable()->MutableData(); - Q = RotaryQ; - K = RotaryK; } // Run rotary embedding for Q and K ORT_RETURN_IF_ERROR(RunRotaryEmbedding(tp, rotary_params, q_input, @@ -196,8 +193,8 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); // Compute the attention score and apply the score to V - return ApplyAttention(Q.Get().Data(), packed_qkv ? nullptr : K.Get().Data(), - packed_qkv ? nullptr : V.Get().Data(), past_key, past_value, output, present_k, present_v, + return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get().Data(), + past_key, past_value, output, present_k, present_v, seqlens_k, parameters, allocator, context); } } // namespace contrib