Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix Memory Issue GQA CPU Rotary (microsoft#22290)
### Description In GQA there was a memory issue which was best described by @edgchen1 [here](microsoft#22252 (comment)) > here's the problematic code: > > https://github.com/microsoft/onnxruntime/blob/d9de054eb53034e3dc18c298e47c6cc08d5aa884/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc#L149-L157 > > annotated: > > ```c++ > if (packed_qkv) { > // Q is an OrtValue declared in the enclosing scope. > OrtValue RotaryQKV; > Tensor::InitOrtValue(element_type, TensorShape({batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size}), allocator, RotaryQKV); > // Save pointer to Q's data in q_input. > q_input = Q.Get<Tensor>().Data<T>(); > k_input = q_input + num_heads_ * sequence_length * head_size; > q_rotary = RotaryQKV.GetMutable<Tensor>()->MutableData<T>(); > k_rotary = q_rotary + num_heads_ * sequence_length * head_size; > // Overwrite Q with RotaryQKV (OrtValues contain shared_ptr to contained value). > // Now, q_input is pointing to freed memory. > Q = RotaryQKV; > } > ``` > > later on, when we use `q_input`, there is a read access violation. > > https://github.com/microsoft/onnxruntime/blob/d9de054eb53034e3dc18c298e47c6cc08d5aa884/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc#L170-L172 > > this problem showed up when CPU allocator sharing between sessions was enabled. in that case, the CPU allocator's arena was disabled. I suspect that the default usage of the arena hid this issue. > > though I debugged into the first branch, this appears to be a problem in both branches: > > https://github.com/microsoft/onnxruntime/blob/d9de054eb53034e3dc18c298e47c6cc08d5aa884/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc#L149-L168 ### Motivation and Context Fixes a crucial bug. The issue was found here microsoft#22252
- Loading branch information