Skip to content

Commit

Permalink
fix bug: input q/k/v should not be modified by operator (microsoft#20555
Browse files Browse the repository at this point in the history
)

### Description
<!-- Describe your changes. -->
Operator should not modify input tensors because they are managed by
framework and may be reused by other nodes.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
yufenglee authored and Ted Themistokleous committed May 7, 2024
1 parent 76db926 commit 62c8dad
Showing 1 changed file with 1 addition and 5 deletions.
6 changes: 1 addition & 5 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,13 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
if (packed_qkv) {
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
allocator, batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size, query, Q));
} else if (sequence_length > 1) {
} else {
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
allocator, batch_size, num_heads_, sequence_length, head_size, query, Q));
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
allocator, batch_size, kv_num_heads_, sequence_length, head_size, key, K));
ORT_RETURN_IF_ERROR(MaybeTransposeToBNSH<T>(
allocator, batch_size, kv_num_heads_, sequence_length, head_size, value, V));
} else {
Tensor::InitOrtValue(std::move(const_cast<Tensor&>(*query)), Q);
Tensor::InitOrtValue(std::move(const_cast<Tensor&>(*key)), K);
Tensor::InitOrtValue(std::move(const_cast<Tensor&>(*value)), V);
}

if (do_rotary_) {
Expand Down

0 comments on commit 62c8dad

Please sign in to comment.