
[Phi-3] Bug on stale kv cache #33129

Merged (13 commits into huggingface:main, Sep 13, 2024)

Conversation

garg-amit
Contributor

@garg-amit garg-amit commented Aug 26, 2024

What does this PR do?

When the input length is shorter than original_max_position_embeddings but input tokens + generated tokens exceed original_max_position_embeddings, the output often becomes garbled. This PR resolves the issue by recomputing the KV cache starting at original_max_position_embeddings + 1, the point where the model switches from the short RoPE factors to the long factors.
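For illustration, a minimal sketch of the idea with a hypothetical helper (the PR implements an equivalent check inside prepare_inputs_for_generation; this is not the actual diff): once the total length crosses the switch point, the cached keys/values were built with the short factors and are stale, so the cache is dropped and the full sequence is re-encoded with the long factors.

def maybe_reset_stale_cache(past_key_values, past_length, num_new_tokens,
                            original_max_position_embeddings):
    """Hypothetical helper: drop the KV cache when generation is about to cross the
    short->long RoPE factor switch point."""
    total_length = past_length + num_new_tokens
    if past_length <= original_max_position_embeddings < total_length:
        # The existing cache was computed with the short factors; returning None forces
        # the caller to re-encode the full sequence (prompt + tokens generated so far)
        # with the long factors.
        return None
    return past_key_values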

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@garg-amit garg-amit closed this Aug 26, 2024
@garg-amit garg-amit changed the title from "fix long seq bug" to "[Phi-3] Bug on stale kv cache" Aug 26, 2024
@garg-amit garg-amit reopened this Aug 26, 2024
@LysandreJik
Member

cc @ArthurZucker and @gante

Collaborator

@ArthurZucker ArthurZucker left a comment


Hey! I am not sure I understand how this is a fix unless the degradation comes from numerical differences; generation with / without cache should be equivalent.

It's more probable that the Phi embedding does not follow the new standard here:

def _dynamic_frequency_update(self, position_ids, device):
    """
    dynamic RoPE layers should recompute `inv_freq` in the following situations:
    1 - growing beyond the cached sequence length (allow scaling)
    2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
    """
    seq_len = torch.max(position_ids) + 1
    if seq_len > self.max_seq_len_cached:  # growth
        inv_freq, self.attention_scaling = self.rope_init_fn(
            self.config, device, seq_len=seq_len, **self.rope_kwargs
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
        self.max_seq_len_cached = seq_len

    if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
        self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
        self.max_seq_len_cached = self.original_max_seq_len

@jiazhan-msft

jiazhan-msft commented Aug 28, 2024

Hey! I am not sure I understand how this is a fix unless the degradation comes from numerical differences; generation with / without cache should be equivalent.

@ArthurZucker : Let me help answer your questions:

  • It is not due to numerical differences. Phi-3 128k models use two sets of RoPE factors, one for short sequences and one for long sequences (i.e., longer than original_max_position_embeddings; I'll use 4096 to refer to it for simplicity):

    • If the sequence length is <= original_max_position_embeddings (e.g., 4096), the short factors are used.
    • If the sequence length is > original_max_position_embeddings (4096), the long factors are used. Note: not only from the 4097-th token onward; the model assumes all previous 0-4096 tokens were also computed with the long factors. This is the key to the problem.

    The logic is reflected here.

  • However, at inference time, generation proceeds token by token and the sequence length grows. The above logic causes a problem in one scenario: when the initial input sequence is < 4096 but input + output is > 4096:

    • The KV cache of the first 0-4096 tokens is computed with the short factors.
    • From the 4097-th token, the model switches to the long factors and, most importantly, it assumes the KV cache of all previous 0-4096 tokens was computed with the long factors as well --- unfortunately this is not true. In other words, the KV cache of all previous 0-4096 tokens needs to be recomputed with the new long factors. Note: it is the KV cache of all 0-4096 tokens that needs recomputation, not the rotary-embedding variables such as inv_freq.

So this is not a numerical difference at all.
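For illustration, a simplified sketch of that selection logic (not the exact transformers implementation; the parameter names are illustrative). It shows that the chosen factor set rescales the RoPE frequencies for every position, not only the positions past the threshold:

import torch

def longrope_inv_freq(seq_len, head_dim, rope_theta, short_factor, long_factor,
                      original_max_position_embeddings=4096):
    # One factor set is chosen from the *current* sequence length and applied to the
    # whole sequence, which is why a cache built under the short factors becomes
    # inconsistent once the length crosses the threshold.
    factors = long_factor if seq_len > original_max_position_embeddings else short_factor
    factors = torch.tensor(factors, dtype=torch.float32)  # one entry per rotary frequency
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    return 1.0 / (factors * rope_theta ** exponents)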

It's more probable that the Phi embedding does not follow the new standard here:

(quoted _dynamic_frequency_update code omitted; see above)

@ArthurZucker: From the above explanation of the root cause, we can see that the code you pointed to does not address our need to recompute the KV cache of all previous 0-4096 tokens with the updated long factors. The only way we are aware of under the current mechanism is to set the KV cache to null, which simulates the clean case where the initial input is already 4096 tokens long.

Hope this addresses your concerns.

@jiazhan-msft

@ArthurZucker, @garg-amit: This bug was reported by the community months ago: https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/discussions/85 . @garg-amit: Could you link that post in the introduction of this PR and indicate that this is a targeted bug-fix PR? Thanks!

@tianleiwu
Contributor

I believe that in training, a sequence uses either the long or the short factor depending on whether the sequence length exceeds the 4096 threshold.

Ideally, inference should follow the same rule, i.e., first "estimate" the final sequence length. For example, if we know the final length will go beyond 4096, we should use the long factor from the beginning, because using the short factor in that case might bias generation toward short sequences early on. Resetting the KV cache just hides an inconsistency that is still there (the part of the sequence before 4096 was generated with the short factor, and the remainder with the long factor).

I know it is hard to know the final sequence length in advance, since early exit or different stopping criteria might be used, so an ideal fix for inference might not be straightforward.

@jiazhan-msft

I believe that in training, a sequence uses either the long or the short factor depending on whether the sequence length exceeds the 4096 threshold.

@tianleiwu: Yes.

Ideally, inference should follow the same rule, i.e., first "estimate" the final sequence length. [...] Resetting the KV cache just hides an inconsistency that is still there. [...] An ideal fix for inference might not be straightforward.

@tianleiwu: There are two issues with this approach:

  • (1) As you said, there is no way to accurately estimate the final length.
  • (2) In fact, the reason we need two sets of factors is that using the long factors to generate the first 0-4096 tokens is worse than using the short factors. So we actually require/prefer the model to use the short factors for the first 0-4096 tokens (even if we know that input + output will exceed 4096).

"the part of the sequence before 4096 was generated with the short factor, and the remainder with the long factor"

@tianleiwu: That is not a problem. As explained above, this is actually what we want: for the first part of the sequence (0-4096), the short factors generally give better quality, and as long as that first sub-sequence is generated with better quality, we prefer it. The problem is that, when we cross the 4096 switch point and switch to the long factors, the model was trained under the assumption that the KV cache of the first 0-4096 tokens was produced with those same long factors, even though the tokens themselves are already generated and correct. So what we need to do is simply recompute the KV cache of the 0-4096 tokens with the long factors; the tokens themselves do not need to change, since they are valid and in fact better than what regenerating with the long factors would produce.
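A tiny numeric illustration of the mismatch described above (toy factors, not the real Phi-3 values): the same position gets different rotation angles under the two factor sets, so keys cached under the short factors are not the keys the long-factor model expects.

import torch

head_dim, base = 8, 10000.0
short_factor = torch.ones(head_dim // 2)           # toy short factors
long_factor = torch.full((head_dim // 2,), 4.0)    # toy long factors

exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
inv_freq_short = 1.0 / (short_factor * base ** exponents)
inv_freq_long = 1.0 / (long_factor * base ** exponents)

position = 100
print(position * inv_freq_short)  # angles the cached keys were actually rotated by
print(position * inv_freq_long)   # angles the long-factor model assumes for the same position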

@tianleiwu
Contributor

@jiazhan-msft, thanks for the detailed explanation. The current fix makes sense to me.

@ArthurZucker
Collaborator

ArthurZucker commented Aug 29, 2024

Wow, super high quality issue here! Thanks a lot both and especially @jiazhan-msft for the details.

Okay, then we really need to consider this PR.

  1. Can we not just re-apply RoPE to the cache in the forward function?
  2. Otherwise, yes, clearing the cache might be the only solution, but it will be pretty weird.

Solution 1 is kind of an intermediate solution: we inverse-apply RoPE, then re-apply it with the correct inv_freq.

def _get_rerotation_cos_sin(
    self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    if key_states.shape[-2] not in self.cos_sin_rerotation_cache:
        # Upcast to float32 temporarily for better accuracy
        cos = cos.to(torch.float32)
        sin = sin.to(torch.float32)

        # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
        original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
        shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
        original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
        shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
        rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
        rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

        self.cos_sin_rerotation_cache[key_states.shape[-2]] = (
            rerotation_cos.to(key_states.dtype).unsqueeze(0),
            rerotation_sin.to(key_states.dtype).unsqueeze(0),
        )
    return self.cos_sin_rerotation_cache[key_states.shape[-2]]
The snippet above (from the sink-token cache) is somewhat of a draft of this: RoPE was applied to the sink tokens and needs to be removed / re-applied.

Really not sure how it will go, given that each new key and value depended on the old keys and values, but if we assume we are just rotating the vector's direction a little bit, it would make sense (as a kind of scaling).
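For reference, a rough sketch of what such a rerotation could look like for the short-to-long factor switch (hypothetical; this is not what the PR ends up doing): re-rotate each cached key by the angle difference between the long-factor and short-factor positions, assuming cos/sin tensors broadcastable to the key shape.

import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def rerotate_cached_keys(keys, cos_short, sin_short, cos_long, sin_long):
    """Re-express keys rotated with the short-factor angles as if they had been rotated
    with the long-factor angles, by applying a rotation of (theta_long - theta_short)."""
    cos_delta = cos_long * cos_short + sin_long * sin_short   # cos(long - short)
    sin_delta = sin_long * cos_short - cos_long * sin_short   # sin(long - short)
    return keys * cos_delta + rotate_half(keys) * sin_delta

Only the keys carry RoPE (values are not rotated), but every later hidden state was computed while attending under the short-factor geometry, which is the concern raised just above.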

@jiazhan-msft

(quoting @ArthurZucker's comment and the _get_rerotation_cos_sin snippet above)

@ArthurZucker: Sorry, I may not fully capture your idea; could you elaborate a bit more? But generally, whatever the approach, once we jump over the 4096 switch point, all previous tokens' KV cache entries are invalid and need to be recomputed --- this is dictated by consistency with how the model was trained. If we want to change the model setup, e.g., to ensure the previous tokens' KV cache stays valid, that means changing the model setup and re-training, which is a different story.

Collaborator

@ArthurZucker ArthurZucker left a comment


LGTM otherwise. We are deprecating the get_seq_len API as it is incompatible with compile.

Member

@gante gante left a comment

Thank you for opening this PR 🔥

The concept of the PR looks good to me 🤗 Regarding the implementation, I have two requests:

  1. The issue was addressed in prepare_inputs_for_generation. If a user uses a package that doesn't rely on generate, or decides to build a custom generation loop, they will still face the original issue. As such, I believe the fix should belong in the attention layers! We can use kv_seq_len (the cache length) and q_len (the length of the new tokens) to invalidate the cache.
  2. Let's add a test to ensure we don't regress! Suggestion: using a dummy model, confirm that the cache contents in the short-sequence region change after crossing the threshold (take this test as a reference)

@jiazhan-msft

jiazhan-msft commented Sep 4, 2024

(quoting @gante's review comment and its two requests above)

Hi @gante : Thanks for your comments.

Regarding 1., I am not sure how that would be fixed in the attention layer. Did you mean the model's attention (the Phi3Attention class)? The current signature of its forward is:

[screenshot of the Phi3Attention.forward signature]

There is no information about the original token length there, and no clear pattern for invalidating the cache either. It would presumably require changes to the flash-attention implementation as well. Please point to code if you have some.

To me, the current PR change is clean and better.

Regarding

"a user uses a package that doesn't rely on generate or decides to build a custom generation loop"

If a user wants to customize anything, that user should be aware of this specific modeling code and follow it; that is by design.

Regarding

"Let's add a test to ensure we don't regress!"

Yes, we will add one as suggested.

@garg-amit
Contributor Author

Regarding 2.: @gante thanks for the comment. I've added a test case for it.

Member

@gante gante left a comment

Thank you for iterating 🤗

@gante
Member

gante commented Sep 5, 2024

@garg-amit @jiazhan-msft The suggestion above (1.) was mostly to avoid bad model usage when generating text without generate 🤗 I agree that users should know about this detail but, in my experience, that is rarely the case -- models are often treated as replaceable black boxes that return the logits for the next token. As such, given that we have all the information in the forward pass of the model, we can apply the additional logic there and prevent misuse everywhere.

There are multiple ways that can be done, with varying levels of optimization. The easiest approach would be to move the lines you added to prepare_inputs_for_generation, which set past_key_values to None, into the forward pass of Phi3ForCausalLM. In Phi3ForCausalLM we don't have the full input_ids sequence, but we have all the needed information in cache_positions: cache_positions[0] is the cache length and cache_positions[-1] is the full sequence length. EDIT: not true, see below

(The most compute-optimal way would be to rerotate the existing cache in the attention layer, as it would avoid having to do a forward pass on the full input. However, that would require substantial additional logic.)

At the end of the day, it's your call -- I've approved the PR :)

@garg-amit
Contributor Author

@gante Regarding your suggestion to move the logic to the forward function, I’m still unclear on how we will recompute the kv cache without the previous tokens. The input_ids tensor will only contain the tokens from the current iteration when using the use_cache=True parameter.

@gante
Member

gante commented Sep 6, 2024

@garg-amit you're absolutely right. In that case, input_ids has been cropped in advance, so we cannot use it to recompute the cache. Minor oversight in my suggestion above 😉 The only viable option would be the rerotation one, because it uses data that is already present in the cache, but it requires significant additional logic.

In that case, may I suggest we throw a warning in the forward pass if we detect that a) we have a cache and b) we're about to switch to the long factor? That way, users wouldn't see their custom generation code failing silently 🤗
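A rough sketch of such a warning (hypothetical helper and wording; the message actually added in the PR may differ):

import logging

logger = logging.getLogger(__name__)

def warn_if_cache_is_stale(cache_length, num_new_tokens, original_max_position_embeddings):
    """Warn when an existing KV cache (built with the short RoPE factors) is about to be
    carried past the point where the model switches to the long factors."""
    if 0 < cache_length <= original_max_position_embeddings < cache_length + num_new_tokens:
        logger.warning(
            "The sequence length is crossing original_max_position_embeddings: the cached "
            "keys/values were computed with the short RoPE factors and are now stale. "
            "Reset the cache (e.g. pass past_key_values=None and re-run on the full "
            "sequence) to avoid degraded generations."
        )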

@garg-amit
Contributor Author

@gante that makes sense! Added it

Collaborator

@ArthurZucker ArthurZucker left a comment

Let's merge this! Thanks all for the high-quality issue!

@ArthurZucker ArthurZucker merged commit dfd3115 into huggingface:main Sep 13, 2024
13 checks passed
itazap pushed a commit to NielsRogge/transformers that referenced this pull request Sep 20, 2024
* fix long seq bug

* fixed format

* fixed fn copy inconsistency

* fix long seq bug

* fixed format

* fixed fn copy inconsistency

* Addressed comments

* added a unit test

* fixed cache position

* Added a warning msg to the forward fn

* fixed test case
amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request Oct 2, 2024
(same commit list as above)