diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index a6258182fb0035..b298a7bdd0f5d6 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -38,6 +38,21 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
         raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
 
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states, if there is any."""
+        raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.")
+
+    def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
+        """Given the sequence length of the new inputs, returns the usable length of the cache."""
+        # Cache without size limit -> all cache is usable
+        # Cache with size limit -> if the cache length plus the length of the new inputs is larger than the maximum
+        # cache length, we will need to evict part of the cache (and thus not all cache is usable)
+        max_length = self.get_max_length()
+        previous_seq_length = self.get_seq_length(layer_idx)
+        if max_length is not None and previous_seq_length + new_seq_length > max_length:
+            return max_length - new_seq_length
+        return previous_seq_length
+
 
 class DynamicCache(Cache):
     """
@@ -120,6 +135,10 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
             return 0
         return self.key_cache[layer_idx].shape[-2]
 
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
+        return None
+
     def reorder_cache(self, beam_idx: torch.LongTensor):
         """Reorders the cache for beam search, given the selected beam indices."""
         for layer_idx in range(len(self.key_cache)):
@@ -209,8 +228,11 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
         if len(self.key_cache) <= layer_idx:
             return 0
-        cache_length = self.key_cache[layer_idx].shape[-2]
-        return min(cache_length, self.window_length - 1)
+        return self.key_cache[layer_idx].shape[-2]
+
+    def get_max_length(self) -> Optional[int]:
+        """Returns the maximum sequence length of the cached states."""
+        return self.window_length
 
     def update(
         self,
@@ -267,7 +289,9 @@ def update(
 
             # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
             if using_rope:
-                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(key_states, cos, sin)
+                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
+                    key_states, cos[: self.window_length], sin[: self.window_length]
+                )
                 if partial_rotation_size is not None:
                     keys_to_keep, keys_pass = (
                         keys_to_keep[..., :partial_rotation_size],
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 9791c99a1ffdb8..f4234e9a775499 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -398,7 +398,7 @@ def forward(
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
@@ -503,7 +503,7 @@ def forward(
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
@@ -910,7 +910,7 @@ def forward(
             use_legacy_cache = not isinstance(past_key_values, Cache)
             if use_legacy_cache:
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_seq_length()
+            past_key_values_length = past_key_values.get_usable_length(seq_length)
 
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -1127,8 +1127,10 @@ def prepare_inputs_for_generation(
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
                 past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
             else:
                 cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
 
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1142,10 +1144,13 @@ def prepare_inputs_for_generation(
                 input_ids = input_ids[:, past_length:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
 
-            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
-            # older attention values, as their corresponding values are not part of the input.
-            if cache_length < past_length and attention_mask is not None:
-                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index 34a8f0bfa8128e..ef65d8a0894b12 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -268,7 +268,7 @@ def forward(
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
@@ -363,7 +363,7 @@ def forward(
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
         # Because the input can be padded, the absolute sequence length depends on the max position id.
         rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
 
@@ -850,15 +850,13 @@ def forward(
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
-        seq_length_with_past = seq_length
         past_key_values_length = 0
 
         if use_cache:
             use_legacy_cache = not isinstance(past_key_values, Cache)
             if use_legacy_cache:
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_seq_length()
-            seq_length_with_past = seq_length_with_past + past_key_values_length
+            past_key_values_length = past_key_values.get_usable_length(seq_length)
 
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -1092,8 +1090,10 @@ def prepare_inputs_for_generation(
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
                 past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
             else:
                 cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
 
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1107,10 +1107,13 @@ def prepare_inputs_for_generation(
                 input_ids = input_ids[:, past_length:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
 
-            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
-            # older attention values, as their corresponding values are not part of the input.
-            if cache_length < past_length and attention_mask is not None:
-                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py
index eacccb1564c8d7..17163dcd8edf9b 100644
--- a/src/transformers/models/persimmon/modeling_persimmon.py
+++ b/src/transformers/models/persimmon/modeling_persimmon.py
@@ -295,7 +295,7 @@ def forward(
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
         # Partial rotary embedding
@@ -612,7 +612,7 @@ def forward(
             use_legacy_cache = not isinstance(past_key_values, Cache)
             if use_legacy_cache:
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_seq_length()
+            past_key_values_length = past_key_values.get_usable_length(seq_length)
             seq_length_with_past = seq_length_with_past + past_key_values_length
 
         if position_ids is None:
@@ -831,8 +831,10 @@ def prepare_inputs_for_generation(
            if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
                 past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
             else:
                 cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
 
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -846,10 +848,13 @@ def prepare_inputs_for_generation(
                 input_ids = input_ids[:, past_length:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
 
-            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
-            # older attention values, as their corresponding values are not part of the input.
-            if cache_length < past_length and attention_mask is not None:
-                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py
index 4431c10eebcec4..ca32193d535893 100644
--- a/src/transformers/models/phi/modeling_phi.py
+++ b/src/transformers/models/phi/modeling_phi.py
@@ -334,7 +334,7 @@ def forward(
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
         # Partial rotary embedding
@@ -444,7 +444,7 @@ def forward(
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_seq_length(self.layer_idx)
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
         # Partial rotary embedding
@@ -855,15 +855,13 @@ def forward(
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
-        seq_length_with_past = seq_length
         past_key_values_length = 0
 
         if use_cache:
             use_legacy_cache = not isinstance(past_key_values, Cache)
             if use_legacy_cache:
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_seq_length()
-            seq_length_with_past = seq_length_with_past + past_key_values_length
+            past_key_values_length = past_key_values.get_usable_length(seq_length)
 
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -1085,8 +1083,10 @@ def prepare_inputs_for_generation(
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
                 past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
             else:
                 cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
 
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1100,10 +1100,13 @@ def prepare_inputs_for_generation(
                 input_ids = input_ids[:, past_length:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
 
-            # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
-            # older attention values, as their corresponding values are not part of the input.
-            if cache_length < past_length and attention_mask is not None:
-                attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
 
         position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py
index 1ffc4962787905..72d055c8806afd 100644
--- a/tests/test_cache_utils.py
+++ b/tests/test_cache_utils.py
@@ -187,3 +187,45 @@ def test_sink_cache_hard(self):
         gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=3000, past_key_values=cache)
         decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
         self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network"))
+
+    def test_sink_cache_iterative_prompts(self):
+        """Tests that SinkCache supports more than one new token at once, when shifting the cache"""
+        tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
+        model = AutoModelForCausalLM.from_pretrained(
+            "HuggingFaceH4/zephyr-7b-beta", device_map="auto", torch_dtype=torch.float16
+        )
+        prompt = (
+            "Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences "
+            "and must-see attractions."
+        )
+
+        # Prepare generation settings
+        cache = SinkCache(window_length=256, num_sink_tokens=4)
+        input_ids = torch.tensor([], device=model.device, dtype=torch.int)
+        for _ in range(3):
+            # Tokenize the prompt with the correct chat template
+            chat = [{"role": "user", "content": prompt}]
+            tokenized_chat = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
+                model.device
+            )
+            input_ids = torch.cat((input_ids, tokenized_chat), dim=1)
+
+            # Perform the generation
+            gen_out = model.generate(
+                input_ids, do_sample=False, max_new_tokens=100, past_key_values=cache, use_cache=True
+            )
+            input_ids = gen_out
+
+        # We went well beyond the cache length
+        self.assertTrue(input_ids.shape[1] > cache.get_max_length() * 1.5)
+
+        # And it still produces coherent English
+        decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
+        last_output = (
+            "<|assistant|>\nAs the sun began to set over the Pacific Ocean, I found myself standing on the shores of "
+            "Waikiki Beach, my heart filled with awe and wonder. I had just returned from a two-week journey to the "
+            "beautiful island of Hawaii, and it had been an unforgettable experience filled with cultural experiences "
+            "and must-see attractions that left me breathless.\n\nOne of the most memorable experiences of my trip "
+            "was visiting the historic district of Honolulu. Here,"
+        )
+        self.assertTrue(decoded[0].endswith(last_output))
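
Usage sketch (illustrative, outside the diff): the standalone helper below mirrors the arithmetic added in `Cache.get_usable_length`, so the eviction behaviour can be checked in isolation. The helper name `usable_length` and the example numbers are assumptions made for this sketch; in the patch the logic lives on `Cache` and reads the previous length and the maximum length from the cache object itself.

from typing import Optional

def usable_length(new_seq_length: int, previous_seq_length: int, max_length: Optional[int]) -> int:
    # Mirrors Cache.get_usable_length: a cache without a size limit is fully usable;
    # a size-limited cache that would overflow keeps only `max_length - new_seq_length`
    # of its previous entries (the rest will be evicted when the new tokens are written).
    if max_length is not None and previous_seq_length + new_seq_length > max_length:
        return max_length - new_seq_length
    return previous_seq_length

# DynamicCache has no maximum length, so the whole cache is usable.
assert usable_length(new_seq_length=10, previous_seq_length=100, max_length=None) == 100
# Size-limited cache (e.g. a window of 256): the new tokens still fit, nothing is evicted.
assert usable_length(new_seq_length=10, previous_seq_length=100, max_length=256) == 100
# Size-limited cache: 250 + 50 > 256, so only 256 - 50 = 206 cached entries remain usable.
assert usable_length(new_seq_length=50, previous_seq_length=250, max_length=256) == 206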