diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 22d0e44b2d90cb..abdc3c7c0707bc 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -344,17 +344,15 @@ class StaticCache(Cache):
             The default `dtype` to use when initializing the layer.
     """

-    def __init__(
-        self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=torch.float32
-    ) -> None:
+    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
         super().__init__()
         self.max_batch_size = max_batch_size
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
         self.head_dim = config.hidden_size // config.num_attention_heads
+        self.dtype = dtype if dtype is not None else torch.float32
         self.num_key_value_heads = (
             config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
         )
-        self.dtype = config.torch_dtype if config.torch_dtype is not None else dtype

         cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
         self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
@@ -386,20 +384,23 @@ def update(
         Return:
             A tuple containing the updated key and value states.
         """
-        new_cache_positions = cache_kwargs.get("position_ids")
+        new_cache_positions = cache_kwargs.get("cache_position")
         k_out = self.key_cache
         v_out = self.value_cache

         k_out[:, :, new_cache_positions] = key_states
         v_out[:, :, new_cache_positions] = value_states

-        self.seen_tokens += key_states.shape[-2]
+        self.seen_tokens += key_states.shape[2]
         return k_out, v_out

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
         return self.seen_tokens

+    def get_usable_length(self, new_sequence_length=None, layer_idx: Optional[int] = 0) -> int:
+        return self.seen_tokens
+
     def get_max_length(self) -> Optional[int]:
         """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
         return self.max_cache_len
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 0bbdd643421996..dd8fa604d63e94 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -4776,8 +4776,9 @@ def _split_model_inputs(
     # Here we can have four types of values: tensors, tuples of tensors and booleans, and encoder_outputs which is a
    # ModelOutput object.
     # bool should not be split but replicated for each split
-    bool_keys = [k for k in keys if isinstance(model_input[k], bool)]
-    non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and not k == "encoder_outputs"]
+    bool_keys = [k for k in keys if isinstance(model_input[k], bool) or k == "cache_position"]
+    keys_to_ignore = ["cache_position", "encoder_outputs"]
+    non_bool_keys = [k for k in keys if not isinstance(model_input[k], bool) and k not in keys_to_ignore]

     # we split the tensors and tuples of tensors
     data_split_list = [
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 426db7a8c09208..c30be2a2da4f63 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -29,7 +29,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
-from ...cache_utils import Cache, DynamicCache
+from ...cache_utils import Cache, DynamicCache, StaticCache
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -303,6 +303,7 @@ def forward(
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
@@ -333,21 +334,13 @@ def forward(
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        kv_seq_len = key_states.shape[-2]
-        past_seen_tokens = 0
         past_key_value = getattr(self, "past_key_value", past_key_value)
-        if past_key_value is not None:
-            past_seen_tokens = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)  # add what was seen
-            kv_seq_len += past_seen_tokens
-
-        new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
-        position_ids = new_cache_positions.unsqueeze(0) if position_ids is None else position_ids
-        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

         if past_key_value is not None:
             # sin and cos are specific to RoPE models; position_ids needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions}
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         key_states = repeat_kv(key_states, self.num_key_value_groups)
@@ -356,7 +349,8 @@ def forward(
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

         if attention_mask is not None:  # no matter the length, we just slice it
-            causal_mask = attention_mask[..., past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
+            if cache_position is not None:
+                causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
             attn_weights = attn_weights + causal_mask

         # upcast attention to fp32
@@ -410,6 +404,7 @@ def forward(
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         output_attentions = False
@@ -427,20 +422,14 @@ def forward(
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        kv_seq_len = key_states.shape[-2]
-        past_seen_tokens = 0
-        past_key_value = getattr(self, "past_key_value", past_key_value)
-        if past_key_value is not None:
-            past_seen_tokens = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)  # add what was seen
-            kv_seq_len += past_seen_tokens
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

-        new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
-        position_ids = new_cache_positions.unsqueeze(0) if position_ids is None else position_ids
-        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        past_key_value = getattr(self, "past_key_value", past_key_value)

         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions}  # Specific to RoPE models
+            # sin and cos are specific to RoPE models; position_ids needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
@@ -603,6 +592,7 @@ def forward(
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if output_attentions:
             # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
@@ -617,6 +607,7 @@ def forward(
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
+                cache_position=cache_position,
             )

         bsz, q_len, _ = hidden_states.size()
@@ -629,29 +620,22 @@ def forward(
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        kv_seq_len = key_states.shape[-2]
-        past_seen_tokens = 0
-        past_key_value = getattr(self, "past_key_value", past_key_value)
-        if past_key_value is not None:
-            past_seen_tokens = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)  # add what was seen
-            kv_seq_len += past_seen_tokens
+        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None)

-        new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
-        position_ids = new_cache_positions.unsqueeze(0) if position_ids is None else position_ids
-        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        past_key_value = getattr(self, "past_key_value", past_key_value)

         if past_key_value is not None:
             # sin and cos are specific to RoPE models; position_ids needed for the static cache
-            cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions}
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

-        causal_mask = None
-        if attention_mask is not None:
-            causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
+        causal_mask = attention_mask
+        if attention_mask is not None and cache_position is not None:
+            causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]

         # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
         # Reference: https://github.com/pytorch/pytorch/issues/112577.
@@ -666,7 +650,6 @@ def forward(
             value_states,
             attn_mask=causal_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
@@ -703,6 +686,7 @@ def forward(
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
+        cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
@@ -736,6 +720,7 @@ def forward(
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
+            cache_position=cache_position,
             **kwargs,
         )
         hidden_states = residual + hidden_states
@@ -800,13 +785,20 @@ def _init_weights(self, module):
                 module.weight.data[module.padding_idx].zero_()

     def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None):
+        if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache:
+            raise ValueError(
+                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+                "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+            )
+
         if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
             causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=1, device=self.device)
             self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)

         for layer in self.model.layers:
+            weights = layer.self_attn.o_proj.weight
             layer.self_attn.past_key_value = cache_cls(
-                self.config, max_batch_size, max_cache_len, device=layer.self_attn.o_proj.weight.device
+                self.config, max_batch_size, max_cache_len, device=weights.device, dtype=weights.dtype
             )

     def _reset_cache(self):
@@ -932,6 +924,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -951,12 +944,23 @@ def forward(
             )
             use_cache = False

-        if use_cache and not isinstance(past_key_values, Cache):
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

+        past_seen_tokens = 0
+        if use_cache:  # kept for BC (cache positions)
+            if not isinstance(past_key_values, StaticCache):
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                past_seen_tokens = past_key_values.get_seq_length()
+
+        if cache_position is None:
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
         causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)

         # embed positions
@@ -980,6 +984,7 @@ def forward(
                 past_key_values,
                 output_attentions,
                 use_cache,
+                cache_position,
             )
         else:
             layer_outputs = decoder_layer(
@@ -989,6 +994,7 @@ def forward(
                 past_key_value=past_key_values,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
+                cache_position=cache_position,
             )

         hidden_states = layer_outputs[0]
@@ -1021,8 +1027,9 @@ def forward(

     def _update_causal_mask(self, attention_mask, input_tensor):
         if self.config._attn_implementation == "flash_attention_2":
-            causal_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-            return causal_mask
+            if attention_mask is not None and 0.0 in attention_mask:
+                return attention_mask
+            return None

         batch_size, seq_length = input_tensor.shape[:2]
         dtype = input_tensor.dtype
@@ -1051,14 +1058,11 @@ def _update_causal_mask(self, attention_mask, input_tensor):
         )

         if self.config._attn_implementation == "sdpa":
-            if attention_mask is None:
-                return None
             is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy)
-            if not is_tracing and (torch.all(attention_mask == 1)):
-                return None
-            if is_tracing and seq_length == 1:
-                return None
-            causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(dtype)
+            if not is_tracing and attention_mask is not None and torch.any(attention_mask != 1):
+                causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(
+                    dtype
+                )

         return causal_mask

@@ -1107,6 +1111,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -1150,6 +1155,7 @@ def forward(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            cache_position=cache_position,
         )

         hidden_states = outputs[0]
@@ -1189,6 +1195,7 @@ def forward(
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
+        past_length = 0
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
@@ -1228,9 +1235,17 @@ def prepare_inputs_for_generation(

         if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
             # generation with static cache
-            seen_tokens = past_key_value.get_seq_length()
-            input_ids = input_ids[:, seen_tokens:]
-            position_ids = position_ids[:, seen_tokens:]
+            past_length = past_key_value.get_seq_length()
+            input_ids = input_ids[:, past_length:]
+            position_ids = position_ids[:, past_length:]
+
+        # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
+        # same goes for position ids. Could also help with continued generation.
+        cache_position = kwargs.get("cache_position", None)
+        if cache_position is None:
+            cache_position = torch.arange(
+                past_length, past_length + position_ids.shape[-1], device=position_ids.device
+            )

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
@@ -1241,6 +1256,7 @@ def prepare_inputs_for_generation(
         model_inputs.update(
             {
                 "position_ids": position_ids,
+                "cache_position": cache_position,
                 "past_key_values": past_key_values,
                 "use_cache": kwargs.get("use_cache"),
                 "attention_mask": attention_mask,
diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py
index 592d3e914106d0..f0de7ef29346ea 100644
--- a/src/transformers/models/persimmon/modeling_persimmon.py
+++ b/src/transformers/models/persimmon/modeling_persimmon.py
@@ -823,7 +823,6 @@ def forward(
             attentions=outputs.attentions,
         )

-    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
@@ -864,12 +863,6 @@ def prepare_inputs_for_generation(
         if past_key_values:
             position_ids = position_ids[:, -input_ids.shape[1] :]

-        if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
-            # generation with static cache
-            seen_tokens = past_key_value.get_seq_length()
-            input_ids = input_ids[:, seen_tokens:]
-            position_ids = position_ids[:, seen_tokens:]
-
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py
index 2f4bfbad89a475..799fe02c8f48d6 100644
--- a/src/transformers/models/phi/modeling_phi.py
+++ b/src/transformers/models/phi/modeling_phi.py
@@ -1084,7 +1084,7 @@ def forward(
             attentions=outputs.attentions,
         )

-    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
+    # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
@@ -1125,12 +1125,6 @@ def prepare_inputs_for_generation(
         if past_key_values:
             position_ids = position_ids[:, -input_ids.shape[1] :]

-        if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
-            # generation with static cache
-            seen_tokens = past_key_value.get_seq_length()
-            input_ids = input_ids[:, seen_tokens:]
-            position_ids = position_ids[:, seen_tokens:]
-
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py
index 06d34bcc92d4ab..9baaac1f513505 100755
--- a/src/transformers/models/stablelm/modeling_stablelm.py
+++ b/src/transformers/models/stablelm/modeling_stablelm.py
@@ -1048,7 +1048,6 @@ def forward(
             attentions=outputs.attentions,
         )

-    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
@@ -1089,12 +1088,6 @@ def prepare_inputs_for_generation(
         if past_key_values:
             position_ids = position_ids[:, -input_ids.shape[1] :]

-        if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
-            # generation with static cache
-            seen_tokens = past_key_value.get_seq_length()
-            input_ids = input_ids[:, seen_tokens:]
-            position_ids = position_ids[:, seen_tokens:]
-
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}
diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py
index c6a07bb268b753..5f3af2acf5723c 100644
--- a/tests/test_cache_utils.py
+++ b/tests/test_cache_utils.py
@@ -143,7 +143,7 @@ def _random_kvs(config):
         mha_config = LlamaConfig(num_attention_heads=32)
         mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device)
         cached_keys, cached_values = mha_static_cache.update(
-            *_random_kvs(mha_config), 0, cache_kwargs={"position_ids": torch.arange(1)}
+            *_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1)}
         )
         self.assertTrue(cached_keys.shape == (1, 32, 10, 128))
         self.assertTrue(cached_values.shape == (1, 32, 10, 128))
@@ -151,7 +151,7 @@ def _random_kvs(config):
         gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
         gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
         cached_keys, cached_values = gqa_static_cache.update(
-            *_random_kvs(gqa_config), 0, cache_kwargs={"position_ids": torch.arange(1)}
+            *_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1)}
         )
         self.assertTrue(cached_keys.shape == (1, 4, 10, 128))
         self.assertTrue(cached_values.shape == (1, 4, 10, 128))
@@ -159,7 +159,7 @@ def _random_kvs(config):
         mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1)
         mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
         cached_keys, cached_values = mqa_static_cache.update(
-            *_random_kvs(mqa_config), 0, cache_kwargs={"position_ids": torch.arange(1)}
+            *_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1)}
         )
         self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
         self.assertTrue(cached_values.shape == (1, 1, 10, 128))
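Usage sketch (not part of the patch): the updated tests above exercise the reworked `StaticCache.update`, which now reads `cache_position` from `cache_kwargs` instead of `position_ids`. The snippet below mirrors the GQA test case; the config values and expected shapes come from `tests/test_cache_utils.py`, while the device choice and random tensors are illustrative only.

import torch

from transformers import LlamaConfig
from transformers.cache_utils import StaticCache

# GQA config as in the tests: 32 query heads, 4 key/value heads, head_dim = 4096 // 32 = 128
config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
cache = StaticCache(config=config, max_batch_size=1, max_cache_len=10, device="cpu", dtype=torch.float32)

head_dim = config.hidden_size // config.num_attention_heads
# One new token for layer 0, written into slot 0 of the preallocated cache
key = torch.rand(1, config.num_key_value_heads, 1, head_dim)
value = torch.rand(1, config.num_key_value_heads, 1, head_dim)
cached_keys, cached_values = cache.update(key, value, 0, cache_kwargs={"cache_position": torch.arange(1)})

print(cached_keys.shape)       # torch.Size([1, 4, 10, 128]) -- the full static buffer
print(cache.get_seq_length())  # 1 token seen so far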