[Phi-3] Bug on stale kv cache #33129
Conversation
cc @ArthurZucker and @gante
Hey! I am not sure I understand how this is a fix unless the degradation comes from numerical differences; generation with and without cache should be equivalent.
It's more probable that the Phi embedding does not follow the new standard here:
transformers/src/transformers/models/llama/modeling_llama.py
Lines 180 to 196 in 6cffc90
```python
def _dynamic_frequency_update(self, position_ids, device):
    """
    dynamic RoPE layers should recompute `inv_freq` in the following situations:
    1 - growing beyond the cached sequence length (allow scaling)
    2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
    """
    seq_len = torch.max(position_ids) + 1
    if seq_len > self.max_seq_len_cached:  # growth
        inv_freq, self.attention_scaling = self.rope_init_fn(
            self.config, device, seq_len=seq_len, **self.rope_kwargs
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
        self.max_seq_len_cached = seq_len

    if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
        self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
        self.max_seq_len_cached = self.original_max_seq_len
```
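For context on the Phi-3 side of this discussion: the LongRoPE scaling keeps two sets of per-dimension rescale factors and switches between them at `original_max_position_embeddings`. A minimal illustrative sketch of that selection (the config field names follow the Phi-3 configuration, but the helper itself is an assumption, not the actual modeling code):

```python
import torch

def select_longrope_ext_factors(config, seq_len):
    # Illustrative sketch: pick the rescale factors depending on whether the
    # sequence exceeds the original training context window.
    if seq_len > config.original_max_position_embeddings:
        return torch.tensor(config.rope_scaling["long_factor"], dtype=torch.float32)
    return torch.tensor(config.rope_scaling["short_factor"], dtype=torch.float32)
```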
@ArthurZucker : Let me help answer your questions:
So, this is far more than a numerical difference.
@ArthurZucker : From the explanation of the problem above, we can see that the code you pointed to does not cover our need to "recompute the kv cache for all previous 0-4096 tokens using the updated long factors". The only way we are aware of, with the current mechanism, is to set the kv cache to null, which simulates the clean case where the initial input is already 4096 tokens long. Hope this addresses your concerns.
@ArthurZucker , @garg-amit : This bug was reported by the community months ago: https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/discussions/85 . @garg-amit : Could you link that post in the introduction of this PR and indicate that it is a targeted bug-fix PR? Thanks!
I believe that in training, a sequence uses either the long or the short factor depending on whether the sequence length exceeds the 4096 threshold. Ideally, inference should follow the same rule, i.e. first "estimate" the final sequence length. For example, if we know the final length will go beyond 4096, we should use the long factor from the beginning, because using the short factor in that case could bias generation toward a short sequence from the start. Resetting the kv cache just hides the inconsistency that is still there (the part of the sequence before 4096 was generated with the short factor, and the remainder is generated with the long factor). I know it is hard to know the final sequence length, since early exits or different exit criteria might be used, so an ideal fix for inference might not be straightforward.
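A rough sketch of that idea, purely for illustration (the helper and its arguments are hypothetical, and in practice the final length is rarely known up front):

```python
def pick_rope_factor_up_front(config, prompt_len, max_new_tokens):
    # Commit to one factor set based on the *estimated* final length, so the whole
    # sequence (prompt + generated tokens) uses consistent RoPE scaling.
    estimated_final_len = prompt_len + max_new_tokens
    if estimated_final_len > config.original_max_position_embeddings:
        return config.rope_scaling["long_factor"]
    return config.rope_scaling["short_factor"]
```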
@tianleiwu : Yes.
@tianleiwu : There are two parts to this approach:
@tianleiwu : That is not a problem. As explained above, this is actually what we want, because generally speaking, for the first part of the sequence (e.g. tokens 0-4096), using the short factor gives better quality. As long as the first sub-sequence of 0-4096 tokens is generated with better quality, we prefer it. The problem is that when we jump over the 4096 switch point and switch to the long factors, although the first 0-4096 tokens are already generated and correct, the model training assumed the kv cache of those first 0-4096 tokens was produced with the same long factors. So all we need to do is recompute the kv cache of the 0-4096 tokens based on the long factors; the tokens themselves do not need to change, as they are valid and in fact better than what re-generating with the long factors would give.
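A hedged sketch of that recomputation in a custom decoding loop (the model handle is a placeholder; this illustrates the idea, not the PR's implementation):

```python
import torch

@torch.no_grad()
def rebuild_cache_after_factor_switch(model, generated_ids):
    # The generated tokens themselves stay valid; only their cached key/values were
    # produced with the short factors. Rebuild the cache with one forward pass over
    # the full sequence, now that the long factors are in effect.
    out = model(input_ids=generated_ids, past_key_values=None, use_cache=True)
    return out.past_key_values  # fresh cache, consistent with the long factors
```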
@jiazhan-msft, thanks for the detailed explanation. The current fix makes sense to me.
Wow, super high-quality issue here! Thanks a lot both, and especially @jiazhan-msft for the details. Okay, then we really need to consider this PR.
Solution 1 is a kind of intermediate solution: we inverse-apply RoPE, then re-apply it with the correct inv_freq. See transformers/src/transformers/cache_utils.py, lines 844 to 864 in 18527bd.
I'm really not sure how it will go, given that each new key and value depended on the old keys and values, but if we assume we are just shifting the direction of the vector a little bit, it would make sense (as a kind of scaling).
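A rough sketch of what "inverse-apply RoPE, then re-apply" could look like, assuming the standard `rotate_half` convention and precomputed `cos`/`sin` tensors for the old and new frequencies (illustrative only: it ignores the differing attention-scaling term, and values are never rotated in the first place):

```python
import torch

def rotate_half(x):
    # Same convention as the transformers RoPE helpers.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def rerotate_cached_keys(key_cache, cos_old, sin_old, cos_new, sin_new):
    # Undo the rotation that used the old (short-factor) frequencies ...
    unrotated = key_cache * cos_old - rotate_half(key_cache) * sin_old
    # ... then re-apply it with the new (long-factor) frequencies.
    return unrotated * cos_new + rotate_half(unrotated) * sin_new
```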
@ArthurZucker : Sorry, I may not have fully captured your idea; could you elaborate a bit more? But generally, whatever the approach, when we jump over the 4096 switch point, all previous tokens' kv cache entries are invalid and need to be recomputed; this is dictated by consistency with how the model was trained. If we want to change the model setup, e.g. to ensure previous tokens' kv cache stays valid, that means changing the model setup and redoing training, which is a different story.
LGTM otherwise. We are deprecating the get_seq_len API, as it is incompatible with compile.
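For reference, a hedged sketch of how a forward pass can derive the lengths from `cache_position` instead of querying the cache length (shapes follow the current transformers convention, but treat the details as assumptions):

```python
# `cache_position` holds one absolute position id per new token in this forward pass.
past_seen_tokens = cache_position[0]   # tokens already in the cache (compile-friendly, stays a tensor)
kv_seq_len = cache_position[-1] + 1    # total length including the new tokens
```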
Thank you for opening this PR 🔥
The concept of the PR looks good to me 🤗 Regarding the implementation, I have two requests:
1. The issue was addressed in `prepare_inputs_for_generation`. If a user uses a package that doesn't rely on `generate`, or decides to build a custom generation loop, they will still face the original issue. As such, I believe the fix should belong in the attention layers! We can use `kv_seq_len` (the cache length) and `q_len` (the length of the new tokens) to invalidate the cache.
2. Let's add a test to ensure we don't regress! Suggestion: using a dummy model, confirm that the cache contents in the short-sequence region change after crossing the threshold (take this test as a reference); a rough sketch of such a test follows this list.
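A rough sketch of what such a regression test could look like (the tiny checkpoint name and the exact assertions are placeholders, not the test that was eventually added):

```python
import torch
from transformers import AutoModelForCausalLM

def test_cache_recomputed_after_long_factor_switch():
    # Hypothetical tiny Phi-3 checkpoint with longrope scaling and a small
    # original_max_position_embeddings, so the threshold is crossed quickly.
    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-Phi3ForCausalLM")
    threshold = model.config.original_max_position_embeddings

    prompt = torch.randint(0, model.config.vocab_size, (1, threshold - 2))

    # Cache built while still in the short-factor regime.
    short_cache = model(prompt, use_cache=True).past_key_values
    short_keys = short_cache[0][0].clone()  # layer-0 keys for the prompt

    # Generate past the threshold so the long factors kick in.
    out = model.generate(prompt, max_new_tokens=8, use_cache=True, return_dict_in_generate=True)
    long_keys = out.past_key_values[0][0][:, :, : prompt.shape[1], :]

    # The prompt keys should have been recomputed with the long factors,
    # so they must no longer match the short-factor cache.
    assert not torch.allclose(short_keys, long_keys)
```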
Hi @gante : Thanks for your comments. About 1., I don't see how that could be fixed in the attention layer. Did you mean the model's attention (the Phi3Attention class)? Its current forward signature carries no information about the original token length, and there is no clear pattern for invalidating the cache there either; it would also require changing the flash-attention implementation. Please point to code if you have something in mind. To me, the current PR change is cleaner and better. As for a user building a custom generation loop: if a user wants to customize anything, that user should be aware of this specific modeling code and follow it; that's by design. About 2.: yes, will add as suggested.
Regarding 2.: @gante, thanks for the comment. I've added a test case for it.
Thank you for iterating 🤗
@garg-amit @jiazhan-msft The suggestion above (1.) was mostly to avoid bad model usage when generating text without `generate`.
(The most compute-optimal way would be to re-rotate the existing cache in the attention layer, as it would avoid a forward pass over the full input. However, that would require substantial additional logic.) At the end of the day, it's your call; I've approved the PR :)
@gante Regarding your suggestion to move the logic to the attention layer:
@garg-amit you're absolutely right. In that case, may I suggest we throw a warning in the forward pass if we detect that we a) have a cache and b) are about to switch to the long factor? That way, users won't see their custom generation code fail silently 🤗
@gante that makes sense! Added it.
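For readers following along, a sketch of the kind of check involved (attribute names and wording are illustrative, not the exact warning added in this PR):

```python
import logging

logger = logging.getLogger(__name__)

def maybe_warn_on_long_factor_switch(config, past_length, num_new_tokens):
    # Warn when a cache already exists and this forward pass pushes the total length
    # over the original context window: the short-factor cache entries become stale
    # unless they are recomputed (which `generate` now handles).
    crosses_threshold = (
        past_length > 0
        and past_length <= config.original_max_position_embeddings
        and past_length + num_new_tokens > config.original_max_position_embeddings
    )
    if crosses_threshold:
        logger.warning(
            "The sequence length is crossing `original_max_position_embeddings`: RoPE scaling "
            "switches from the short to the long factor, so cached key/values computed with the "
            "short factor are stale and should be recomputed."
        )
```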
Let's merge this! Thanks all for the high-quality issue!
* fix long seq bug
* fixed format
* fixed fn copy inconsistency
* fix long seq bug
* fixed format
* fixed fn copy inconsistency
* Addressed comments
* added a unit test
* fixed cache position
* Added a warning msg to the forward fn
* fixed test case
What does this PR do?
When `input < original_max_position_embeddings && input tokens + generation tokens > original_max_position_embeddings`, the output often becomes garbled. This PR resolves the issue by recomputing the kv cache starting at `original_max_position_embeddings + 1`, as the model switches from the short factor to the long factor.
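Conceptually, the fix amounts to something like the following sketch of the `prepare_inputs_for_generation` logic (simplified and hedged; not the exact code in this PR):

```python
def prepare_inputs_sketch(config, input_ids, past_key_values):
    # `input_ids` holds the full sequence so far; `past_key_values` is the running cache.
    past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
    total_length = input_ids.shape[1]

    switching_to_long_factor = (
        past_length <= config.original_max_position_embeddings
        and total_length > config.original_max_position_embeddings
    )
    if switching_to_long_factor:
        # Drop the short-factor cache and re-feed the whole sequence so the
        # key/values are rebuilt under the long factors.
        past_key_values = None
        model_inputs = input_ids
    else:
        model_inputs = input_ids[:, past_length:]

    return {"input_ids": model_inputs, "past_key_values": past_key_values}
```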
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.