Generate: SinkCache can handle iterative prompts #27907
Changes from all commits: e9c5b17, f1f12c1, d50b85e, 49e4d40, e319719, 029fac9, fa8bf62, dd3d069, e629c21
@@ -398,7 +398,7 @@ def forward(
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
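For context on the repeated get_seq_length → get_usable_length swap: with an unbounded DynamicCache the two calls return the same number, but with a bounded cache such as SinkCache the usable past can be smaller than the stored past once the incoming tokens would push it over its limit. Below is a rough sketch of the intended semantics, assuming only the Cache methods that appear in this diff (get_seq_length, get_max_length); it is not a verbatim copy of the transformers implementation.

    # Sketch only: the intended behaviour of Cache.get_usable_length as used in this diff.
    def get_usable_length(self, new_seq_length: int, layer_idx: int = 0) -> int:
        max_length = self.get_max_length()            # None for an unbounded cache
        previous_seq_length = self.get_seq_length(layer_idx)
        if max_length is not None and previous_seq_length + new_seq_length > max_length:
            # A size-limited cache drops old entries to make room, so only part of the
            # previously cached sequence remains usable alongside the new tokens.
            return max_length - new_seq_length
        return previous_seq_length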
@@ -503,7 +503,7 @@ def forward(
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
@@ -910,7 +910,7 @@ def forward(
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_seq_length()
+            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -1127,8 +1127,10 @@ def prepare_inputs_for_generation(
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1142,10 +1144,13 @@ def prepare_inputs_for_generation(
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

-            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
-            # older attention values, as their corresponding values are not part of the input.
-            if cache_length < past_length and attention_mask is not None:
-                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:

Review comment (on the new comment line): okay better! thanks
@@ -268,7 +268,7 @@ def forward(
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

Review comment: slice_length or window_length might be better? But a nit, feel free to ignore.
@@ -363,7 +363,7 @@ def forward(
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        # Because the input can be padded, the absolute sequence length depends on the max position id.
        rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
@@ -850,15 +850,13 @@ def forward(
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

-        seq_length_with_past = seq_length
        past_key_values_length = 0

        if use_cache:
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_seq_length()
-            seq_length_with_past = seq_length_with_past + past_key_values_length
+            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -1092,8 +1090,10 @@ def prepare_inputs_for_generation(
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1107,10 +1107,13 @@ def prepare_inputs_for_generation(
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

-            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
-            # older attention values, as their corresponding values are not part of the input.
-            if cache_length < past_length and attention_mask is not None:
-                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
@@ -187,3 +187,45 @@ def test_sink_cache_hard(self):
        gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=3000, past_key_values=cache)
        decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
        self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network"))
+
+    def test_sink_cache_iterative_prompts(self):
+        """Tests that SinkCache supports more than one new token at once, when shifting the cache"""
+        tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
+        model = AutoModelForCausalLM.from_pretrained(
+            "HuggingFaceH4/zephyr-7b-beta", device_map="auto", torch_dtype=torch.float16
+        )
+        prompt = (
+            "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences "
+            "and must-see attractions."
+        )
+
+        # Prepare generation settings
+        cache = SinkCache(window_length=256, num_sink_tokens=4)
+        input_ids = torch.tensor([], device=model.device, dtype=torch.int)
+        for _ in range(3):
+            # Tokenize the prompt with the correct chat template
+            chat = [{"role": "user", "content": prompt}]
+            tokenized_chat = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
+                model.device
+            )
+            input_ids = torch.cat((input_ids, tokenized_chat), dim=1)
+
+            # Perform the generation
+            gen_out = model.generate(
+                input_ids, do_sample=False, max_new_tokens=100, past_key_values=cache, use_cache=True
+            )
+            input_ids = gen_out
+
+        # We went well beyond the cache length
+        self.assertTrue(input_ids.shape[1] > cache.get_max_length() * 1.5)
+
+        # And it still produces coherent English
+        decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
+        last_output = (
+            "<|assistant|>\nAs the sun began to set over the Pacific Ocean, I found myself standing on the shores of "
+            "Waikiki Beach, my heart filled with awe and wonder. I had just returned from a two-week journey to the "
+            "beautiful island of Hawaii, and it had been an unforgettable experience filled with cultural experiences "
+            "and must-see attractions that left me breathless.\n\nOne of the most memorable experiences of my trip "
+            "was visiting the historic district of Honolulu. Here,"
+        )
+        self.assertTrue(decoded[0].endswith(last_output))

Review comment (on the new test): This test would fail on …
Review comment: This line was a bit of a hack; get_max_length makes us no longer need the hack :) get_seq_length now always does what the fn name and the docstring say it does.
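As a companion to the new test, here is a rough sketch (not the transformers implementation) of which positions a full SinkCache retains: the first num_sink_tokens plus the most recent tokens, capped at window_length entries in total, which is why generation can run well past the window and still stay coherent.

    def kept_positions(seen_tokens: int, window_length: int = 256, num_sink_tokens: int = 4) -> list:
        # Sketch of SinkCache retention: attention-sink tokens plus a sliding window.
        if seen_tokens <= window_length:
            return list(range(seen_tokens))
        sinks = list(range(num_sink_tokens))
        recent = list(range(seen_tokens - (window_length - num_sink_tokens), seen_tokens))
        return sinks + recent

    print(len(kept_positions(1000)))    # 256 -> the cache never grows past window_length
    print(kept_positions(1000)[:6])     # [0, 1, 2, 3, 748, 749]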