Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make static KV cache work. #23061

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open

Conversation

satyajandhyala
Copy link
Contributor

Description

Fix for GQA static KV cache

Motivation and Context

@@ -436,7 +436,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) {
const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)});
const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0;
const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length_;
const int total_sequence_length = parameters.is_gqa_ && parameters.past_present_share_buffer_ ? parameters.seqlen_present_kv_cache_ : (past_sequence_length + parameters.kv_sequence_length_);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For GQA, total_sequence_length is read from node input like

int total_sequence_length = *((*total_seqlen).template Data<int32_t>());

seqlen_present_kv_cache is the max buffer length, when past and present share buffer.

Copy link
Contributor Author

@satyajandhyala satyajandhyala Dec 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is correct. This name is a bit confusing because we are using this variable total_sequence_length in different ways to work for both MHA and GQA. I am trying to avoid code duplication. My intention is to use the same implementation of Attention for other variations we want to support so that we get the benefit of any optimizations for all Attention related operators. This way we can limit the binary size.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we are assigning present_sequence_length, which is seqlen_present_kv_cache in GQA parameters. The WebGPU implementation uses CheckInputs implementation in onnxruntime\contrib_ops\cpu\bert\group_query_attention_helper.h

Copy link
Contributor

@guschmue guschmue Dec 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parameters.total_sequence_length_ should also work - it gets set to *((*total_seqlen).template Data<int32_t>()) in CheckInputs

Copy link
Contributor Author

@satyajandhyala satyajandhyala Dec 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The use of the name total_sequence_length is generalized in that sense that it corresponds to the present key and present value buffer sequence_lengths.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@guschmue guschmue added the ep:WebGPU ort-web webgpu provider label Dec 10, 2024
@tianleiwu
Copy link
Contributor

Please add a test case.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ep:WebGPU ort-web webgpu provider
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants