Generate: SinkCache can handle iterative prompts #27907
The change touches three areas. First, the cache classes (`Cache`, `DynamicCache`, `SinkCache`):
```diff
@@ -38,6 +38,10 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
         raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
 
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states, if there is any."""
+        raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")
+
 
 class DynamicCache(Cache):
     """
```
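For context (not part of the diff): a `Cache` subclass is now expected to answer both "how long is the cached sequence" and "how long can it get". A minimal sketch of that contract is below; the class name and the fixed capacity are hypothetical, and a real implementation would subclass `Cache` and also implement `update`.

```python
from typing import List, Optional

import torch


class ToyBoundedCache:
    """Hypothetical fixed-capacity cache, only to illustrate the two getters."""

    def __init__(self, max_length: int):
        self.max_length = max_length
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        # Number of tokens currently cached for the given layer (0 if nothing is cached yet)
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        # Fixed capacity; an unbounded cache (like DynamicCache below) returns None instead
        return self.max_length
```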
```diff
@@ -120,6 +124,10 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
             return 0
         return self.key_cache[layer_idx].shape[-2]
 
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
+        return None
+
     def reorder_cache(self, beam_idx: torch.LongTensor):
         """Reorders the cache for beam search, given the selected beam indices."""
         for layer_idx in range(len(self.key_cache)):
```
```diff
@@ -209,8 +217,11 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
         if len(self.key_cache) <= layer_idx:
             return 0
-        cache_length = self.key_cache[layer_idx].shape[-2]
-        return min(cache_length, self.window_length - 1)
+        return self.key_cache[layer_idx].shape[-2]
+
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states."""
+        return self.window_length
 
     def update(
         self,
```
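With all three implementations in place, callers can tell bounded and unbounded caches apart without touching private attributes. A small usage sketch, assuming this branch where `DynamicCache` and `SinkCache` are importable from `transformers`:

```python
from transformers import DynamicCache, SinkCache

# DynamicCache grows without bound, so it reports no maximum length
dynamic_cache = DynamicCache()
assert dynamic_cache.get_max_length() is None

# SinkCache is bounded by its window, so callers can clamp lengths against it
sink_cache = SinkCache(window_length=256, num_sink_tokens=4)
assert sink_cache.get_max_length() == 256

# The pattern used by the modeling changes further down: only clamp when a maximum exists
kv_seq_len = 300  # hypothetical current key/value length
max_length = sink_cache.get_max_length()
if max_length is not None and kv_seq_len > max_length:
    kv_seq_len = max_length
print(kv_seq_len)  # 256
```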
```diff
@@ -239,8 +250,8 @@ def update(
         """
         # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
         # with partially rotated position embeddings, like Phi or Persimmon.
-        sin = cache_kwargs.get("sin")
-        cos = cache_kwargs.get("cos")
+        sin = cache_kwargs.get("sin")[: self.window_length]
+        cos = cache_kwargs.get("cos")[: self.window_length]
         partial_rotation_size = cache_kwargs.get("partial_rotation_size")
         using_rope = cos is not None and sin is not None
 
```

Review thread on the `sin`/`cos` slicing:

- "Slicing here is needed if more than one token is fed at once, after the cache is full."
- Suggested change: would slicing from the other end not make more sense, since that's the side we split the cache from?
- "I'll run some tests for this."
- "Updated my previous message with my findings."
- "@ArthurZucker @tomaarsen that's the beauty of sin and cos: the rerotation applied on the sink caches is based on the relative angles, and the relative angles are the same regardless of the side we slice 🙌 If you place a debugger, you can see that the sliced tensors are different, but the rerotation coefficients are exactly the same!"
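The relative-angle claim can be checked numerically outside the library. The sketch below is not library code: it builds toy RoPE angles (dimensions and positions are made up) and shows that the re-rotation coefficient depends only on the position difference, so slicing `sin`/`cos` from the front or the back of the window yields the same coefficients:

```python
import torch

# Toy RoPE angles; head_dim and sequence length are arbitrary, for illustration only
head_dim = 8
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
positions = torch.arange(300).float()
angles = positions[:, None] * inv_freq[None, :]  # [seq_len, head_dim // 2]
cos, sin = angles.cos(), angles.sin()

window_length = 256
cos_front, sin_front = cos[:window_length], sin[:window_length]  # slice from the front
cos_back, sin_back = cos[-window_length:], sin[-window_length:]  # slice from the back

# Re-rotating a key from position i to position j only needs the *relative* angle,
# composed via the angle-difference identity: cos(a_i - a_j) = cos_i*cos_j + sin_i*sin_j
i, j = 200, 10
rerot_front = cos_front[i] * cos_front[j] + sin_front[i] * sin_front[j]
rerot_back = cos_back[i] * cos_back[j] + sin_back[i] * sin_back[j]
print(torch.allclose(rerot_front, rerot_back))  # True: same coefficients either way
```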
Next, in the modeling code (the attention `forward` passes):
```diff
@@ -405,6 +405,12 @@ def forward(
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+            # If the cache has a fixed length and we were about to go beyond it, update the key/value length and the
+            # attention mask accordingly. `.update()` handles the cache cropping internally if needed.
+            kv_max_length = past_key_value.get_max_length()
+            if kv_max_length is not None and kv_seq_len > kv_max_length:
+                kv_seq_len = kv_max_length
+                attention_mask = attention_mask[:, :, :, -kv_seq_len:]
 
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
```

Review thread on the mask cropping:

- "The attention mask must be sliced to match the length of the key/value states."
- "Ouch, not a fan of that."
- "The alternative would be to make […]"
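In isolation, the cropping amounts to the following standalone sketch (shapes and numbers are made up; in the model, `attention_mask` is the 4D additive mask of shape `[batch, 1, query_len, kv_len]`):

```python
import torch

batch_size, query_len, kv_seq_len = 1, 3, 300
kv_max_length = 256  # e.g. what SinkCache.get_max_length() would return

# 4D additive mask, as used by the eager attention path
attention_mask = torch.zeros(batch_size, 1, query_len, kv_seq_len)

if kv_max_length is not None and kv_seq_len > kv_max_length:
    kv_seq_len = kv_max_length
    # Keep only the trailing columns, matching the (cropped) key/value states
    attention_mask = attention_mask[:, :, :, -kv_seq_len:]

print(attention_mask.shape)  # torch.Size([1, 1, 3, 256])
```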
The same addition appears in the Flash Attention path of the same file:

```diff
@@ -511,6 +517,12 @@ def forward(
         if past_key_value is not None:
             cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+            # If the cache has a fixed length and we were about to go beyond it, update the key/value length and the
+            # attention mask accordingly. `.update()` handles the cache cropping internally if needed.
+            kv_max_length = past_key_value.get_max_length()
+            if kv_max_length is not None and kv_seq_len > kv_max_length:
+                kv_seq_len = kv_max_length
+                attention_mask = attention_mask[:, :, :, -kv_seq_len:]
 
         # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
         # to be able to avoid many of these transpose/reshape/view.
```
Finally, in the cache tests:
```diff
@@ -187,3 +187,44 @@ def test_sink_cache_hard(self):
         gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=3000, past_key_values=cache)
         decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
         self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network"))
+
+    @require_auto_gptq
+    def test_sink_cache_iterative_prompts(self):
+        """Tests that SinkCache supports more than one new token at once, when shifting the cache"""
+        tokenizer = AutoTokenizer.from_pretrained("TheBloke/zephyr-7B-beta-GPTQ")
+        model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-beta-GPTQ", device_map="auto")
+        prompt = (
+            "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences "
+            "and must-see attractions."
+        )
+
+        # Prepare generation settings
+        cache = SinkCache(window_length=256, num_sink_tokens=4)
+        input_ids = torch.tensor([], device=model.device, dtype=torch.int)
+        for _ in range(3):
+            # Tokenize the prompt with the correct chat template
+            chat = [{"role": "user", "content": prompt}]
+            tokenized_chat = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
+                model.device
+            )
+            input_ids = torch.cat((input_ids, tokenized_chat), dim=1)
+
+            # Perform the generation
+            gen_out = model.generate(
+                input_ids, do_sample=False, max_new_tokens=100, past_key_values=cache, use_cache=True
+            )
+            input_ids = gen_out
+
+        # We went well beyond the cache length
+        self.assertTrue(input_ids.shape[1] > cache.get_max_length() * 1.5)
+
+        # And it still produces coherent English (the repetition is due to the prompt being repeated 3 times)
+        decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
+        last_output = (
+            "<|assistant|>\nHawaii, the Aloha State post for a travel destination you've taken. Your post's and "
+            "must-see landmarks. Use a descriptive and engaging writing style, incorporating personal anecdotes and "
+            "recommendations for fellow travelers. Your post should be at least 800 words and include high-quality "
+            "images to enhance the reader's experience. Be sure to cover a variety of experiences, from cultural "
+            "immersion to outdoor adventures, and provide practical"
+        )
+        self.assertTrue(decoded[0].endswith(last_output))
```

Review note on the new test: "This test would fail on […]"
Review note on the `get_seq_length` change: "This line was a bit of a hack; `get_max_length` makes us no longer need the hack :) `get_seq_length` now always does what the fn name and the docstring say it does."
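For the user-facing pattern the new test exercises (multi-turn prompting with one reused `SinkCache`), a condensed sketch is below. The checkpoint name, window size, and prompts are placeholders; any chat model using RoPE should behave similarly.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, SinkCache

model_id = "HuggingFaceH4/zephyr-7b-beta"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16)

# One SinkCache instance is reused across turns, so old context is evicted gracefully
cache = SinkCache(window_length=256, num_sink_tokens=4)
input_ids = torch.zeros((1, 0), dtype=torch.long, device=model.device)

for user_message in ["Hi, who are you?", "Tell me a short story.", "Now summarize it."]:
    chat = [{"role": "user", "content": user_message}]
    new_ids = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(model.device)
    input_ids = torch.cat((input_ids, new_ids), dim=1)
    gen_out = model.generate(input_ids, do_sample=False, max_new_tokens=100, past_key_values=cache, use_cache=True)
    input_ids = gen_out  # the whole conversation so far becomes the next turn's prompt

print(tokenizer.decode(input_ids[0], skip_special_tokens=True))
```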