diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index ad91edfcbb50b2..42a5ee0a66dd1e 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -935,3 +935,39 @@ def get_max_length(self) -> Optional[int]: def reset(self): self.key_cache.zero_() self.value_cache.zero_() + + +class OneShotStaticCache(StaticCache): + """ + OneShotStaticCache is for cases where we update the cache only once and the cache remains constant after, it's useful in + encoder decoder models where we need to cache the key and value states of cross attention layer, in which case we only need + to compute and update the cache in the first generation step. + + Parameters: + config (`PretrainedConfig): + The configuration file defining the shape-related attributes required to initialize the static cache. + max_batch_size (`int`): + The maximum batch size with which the model will be used. + max_cache_len (`int`): + The maximum sequence length with which the model will be used. + device (`torch.device`): + The device on which the cache should be initialized. Should be the same as the layer. + dtype (*optional*, defaults to `torch.float32`): + The default `dtype` to use when initializing the layer. + """ + def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None: + super().__init__(config, max_batch_size, max_cache_len, device, dtype) + self.cache_filled = [False for _ in range(config.num_hidden_layers)] + + def update(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, cache_kwargs: Dict[str, Any] | None = None) -> Tuple[torch.Tensor]: + if self.cache_filled[layer_idx]: + return self.key_cache[layer_idx], self.value_cache[layer_idx] + self.cache_filled[layer_idx] = True + return super().update(key_states, value_states, layer_idx, cache_kwargs) + + def query_cache_filled_status(self, layer_idx: int) -> bool: + return self.cache_filled[layer_idx] + + def reset(self): + super().reset() + self.cache_filled = [False for _ in range(len(self.cache_filled))] \ No newline at end of file diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 84c9dd995eb4f1..6c9643fa871919 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -32,6 +32,7 @@ QuantoQuantizedCache, SlidingWindowCache, StaticCache, + OneShotStaticCache ) from ..integrations.deepspeed import is_deepspeed_zero3_enabled from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput @@ -111,7 +112,7 @@ if is_accelerate_available(): from accelerate.hooks import AlignDevicesHook, add_hook_to_module -NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache} +NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache, "one_shot": OneShotStaticCache} QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache} @@ -1348,10 +1349,15 @@ def _get_initial_cache_position(self, input_ids, model_kwargs): past_length = 0 if "past_key_values" in model_kwargs: - if isinstance(model_kwargs["past_key_values"], Cache): - past_length = model_kwargs["past_key_values"].get_seq_length() + past_key_values = model_kwargs["past_key_values"] + # double cache case in encoder decoder arch + if isinstance(past_key_values, tuple) and isinstance(past_key_values[0], Cache): + past_key_values = past_key_values[0] + + if isinstance(past_key_values, Cache): + past_length = past_key_values.get_seq_length() else: - past_length = model_kwargs["past_key_values"][0][0].shape[2] + past_length = past_key_values[0][0].shape[2] if "inputs_embeds" in model_kwargs: cur_len = model_kwargs["inputs_embeds"].shape[1] else: @@ -1359,7 +1365,7 @@ def _get_initial_cache_position(self, input_ids, model_kwargs): model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device) return model_kwargs - def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int) -> Cache: + def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int, cache_name = '_cache') -> Cache: """ Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a new `generate` call requires a larger cache. @@ -1367,34 +1373,36 @@ def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_l Returns the resulting cache object. """ cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation] - need_new_cache = ( - not hasattr(self, "_cache") - or (not isinstance(self._cache, cache_cls)) - or self._cache.max_batch_size < max_batch_size - ) - if cache_implementation == "sliding_window": - need_new_cache = need_new_cache or ( - self._cache.sliding_window_size < self._cache.model_sliding_window_size - and max_cache_len > self._cache.max_cache_len - ) - elif cache_implementation == "static": - need_new_cache = need_new_cache or self._cache.max_cache_len < max_cache_len + need_new_cache = not hasattr(self, cache_name) + + if not need_new_cache: + current_cache: Cache = getattr(self, cache_name) + need_new_cache = not isinstance(current_cache, cache_cls) or current_cache.max_batch_size != max_batch_size + if cache_implementation == "sliding_window": + need_new_cache = need_new_cache or ( + current_cache.sliding_window_size < current_cache.model_sliding_window_size + and max_cache_len > current_cache.max_cache_len + ) + elif cache_implementation == "static": + need_new_cache = need_new_cache or current_cache.max_cache_len < max_cache_len + elif cache_implementation == "one_shot": + need_new_cache = need_new_cache or current_cache.max_cache_len != max_cache_len if need_new_cache: if hasattr(self.config, "_pre_quantization_dtype"): cache_dtype = self.config._pre_quantization_dtype else: cache_dtype = self.dtype - self._cache = cache_cls( + setattr(self, cache_name, cache_cls( config=self.config, max_batch_size=max_batch_size, max_cache_len=max_cache_len, device=self.device, dtype=cache_dtype, - ) + )) else: - self._cache.reset() - return self._cache + current_cache.reset() + return getattr(self, cache_name) def _get_decoder_start_token_id( self, decoder_start_token_id: Union[int, List[int]] = None, bos_token_id: int = None @@ -1681,9 +1689,14 @@ def generate( "This model does not support `cache_implementation='static'`. Please check the following " "issue: https://github.com/huggingface/transformers/issues/28981" ) - model_kwargs["past_key_values"] = self._get_cache( - generation_config.cache_implementation, batch_size, generation_config.max_length - ) + model_kwargs["past_key_values"] = self._get_cache(generation_config.cache_implementation, batch_size, generation_config.max_length) + if self.config.is_encoder_decoder: + # manually set another cache for cross attention + encoder_outputs = model_kwargs["encoder_outputs"][0] + model_kwargs["past_key_values"] = ( + model_kwargs["past_key_values"], + self._get_cache("one_shot", encoder_outputs.shape[0], encoder_outputs.shape[1], '_cross_attn_cache') + ) elif generation_config.cache_implementation == "quantized": if not self._supports_quantized_cache: raise ValueError( @@ -2454,7 +2467,6 @@ def _sample( this_peer_finished = False unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) - while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): # prepare model inputs model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) diff --git a/src/transformers/models/whisper/configuration_whisper.py b/src/transformers/models/whisper/configuration_whisper.py index e7c0f47b58e587..b2e85d09b8dee3 100644 --- a/src/transformers/models/whisper/configuration_whisper.py +++ b/src/transformers/models/whisper/configuration_whisper.py @@ -189,7 +189,7 @@ class WhisperConfig(PretrainedConfig): model_type = "whisper" keys_to_ignore_at_inference = ["past_key_values"] - attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} + attribute_map = {"num_key_value_heads": "encoder_attention_heads", "num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} def __init__( self, diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index f30cfe19476504..8074d983b3b204 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -221,7 +221,7 @@ def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_prec else: # num_frames is of shape (batch_size,) whereas batch_size is truely batch_size*num_return_sequences repeat_time = batch_size if isinstance(num_frames, int) else batch_size // len(num_frames) - num_frames = np.repeat(num_frames, repeat_time) + num_frames = np.repeat(num_frames.cpu() if torch.is_tensor(num_frames) else num_frames, repeat_time) if num_frames is None or isinstance(num_frames, int): # Normalize and smoothen the weights. diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index d2a7107c1eeb98..244a7c210a3b03 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -25,7 +25,9 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa +from ...cache_utils import Cache, DynamicCache, StaticCache, OneShotStaticCache +from ...modeling_attn_mask_utils import AttentionMaskConverter + from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -245,6 +247,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[WhisperConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -252,6 +255,13 @@ def __init__( self.dropout = dropout self.head_dim = embed_dim // num_heads self.config = config + self.layer_idx = layer_idx + if is_decoder and layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) if (self.head_dim * num_heads) != self.embed_dim: raise ValueError( @@ -267,19 +277,16 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Copied from transformers.models.bart.modeling_bart.BartAttention._shape with BART->whisper - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - # Copied from transformers.models.bart.modeling_bart.BartAttention.forward with BART->whisper def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -290,66 +297,42 @@ def forward( bsz, tgt_len, _ = hidden_states.size() # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + query_states = self.q_proj(hidden_states) + need_update_cache = self.is_decoder and past_key_value is not None + + # reuse cross attn kv cache except for the first generation step + # case 1: DynamicCache, check if we have computed the current layer + # case 2: OneShotStaticCache, check if the flag has been set, for `torch.compile` + if is_cross_attention and ((isinstance(past_key_value, DynamicCache) and self.layer_idx < len(past_key_value.key_cache)) + or isinstance(past_key_value, OneShotStaticCache) and past_key_value.query_cache_filled_status(self.layer_idx)): + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + need_update_cache = False else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - + if is_cross_attention: + key_states = self.k_proj(key_value_states) + value_states = self.v_proj(key_value_states) + else: + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1,2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1,2) + + if need_update_cache: + # when we do cross attention we need to set cache_position for source sequence + if is_cross_attention: + cache_position = torch.arange(0, key_value_states.shape[1], dtype=cache_position.dtype, device=cache_position.device) + cache_kwargs = {"cache_position" : cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling + + # this will be skipped in cross attention where attention_mask is None if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: @@ -358,31 +341,19 @@ def forward( f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" f" {layer_head_mask.size()}" ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.bmm(attn_probs, value_states) + attn_output = torch.matmul(attn_probs, value_states) - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" f" {attn_output.size()}" ) - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.transpose(1, 2).contiguous() # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned across GPUs when using tensor-parallelism. @@ -390,7 +361,10 @@ def forward( attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value # Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Whisper @@ -410,22 +384,27 @@ def __init__(self, *args, **kwargs): # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # WhisperFlashAttention2 attention does not support output_attentions if output_attentions: raise ValueError("WhisperFlashAttention2 attention does not support output_attentions") + if layer_head_mask is not None: + raise ValueError("WhisperFlashAttention2 attention does not support layer_head_mask") + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None @@ -433,47 +412,40 @@ def forward( bsz, q_len, _ = hidden_states.size() # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + query_states = self.q_proj(hidden_states) + need_update_cache = self.is_decoder and past_key_value is not None + + # reuse cross attn kv cache except for the first generation step + # can only be DynamicCache becuase StaticCache is banned for flash attention + if is_cross_attention and isinstance(past_key_value, DynamicCache) and self.layer_idx < len(past_key_value.key_cache): + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + need_update_cache = False else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + if is_cross_attention: + key_states = self.k_proj(key_value_states) + value_states = self.v_proj(key_value_states) + else: + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1,2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1,2) + + if need_update_cache: + # when we do cross attention we need to set cache_position for source sequence + if is_cross_attention: + cache_position = torch.arange(0, key_value_states.shape[1], dtype=cache_position.dtype, device=cache_position.device) + cache_kwargs = {"cache_position" : cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need @@ -619,10 +591,11 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" if output_attentions or layer_head_mask is not None: @@ -638,83 +611,71 @@ def forward( attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder - is_cross_attention = key_value_states is not None - + is_cross_attention = True if key_value_states is not None else False bsz, tgt_len, _ = hidden_states.size() - # get query proj query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + need_update_cache = self.is_decoder and past_key_value is not None + + # reuse cross attn kv cache except for the first generation step + # case 1: DynamicCache, check if we have computed the current layer + # case 2: OneShotStaticCache, check if the flag has been set, for `torch.compile` + if is_cross_attention and ((isinstance(past_key_value, DynamicCache) and self.layer_idx < len(past_key_value.key_cache)) + or isinstance(past_key_value, OneShotStaticCache) and past_key_value.query_cache_filled_status(self.layer_idx)): + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + need_update_cache = False else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) + if is_cross_attention: + key_states = self.k_proj(key_value_states) + value_states = self.v_proj(key_value_states) + else: + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1,2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1,2) + + if need_update_cache: + # when we do cross attention we need to set cache_position for source sequence + if is_cross_attention: + cache_position = torch.arange(0, key_value_states.shape[1], dtype=cache_position.dtype, device=cache_position.device) + cache_kwargs = {"cache_position" : cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 + is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False + attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, - attn_mask=attention_mask, + attn_mask=causal_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal, ) - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output, None, past_key_value @@ -800,7 +761,7 @@ def forward( # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Whisper, MBART->WHISPER class WhisperDecoderLayer(nn.Module): - def __init__(self, config: WhisperConfig): + def __init__(self, config: WhisperConfig, layer_idx: int): super().__init__() self.embed_dim = config.d_model @@ -811,6 +772,7 @@ def __init__(self, config: WhisperConfig): is_decoder=True, is_causal=True, config=config, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -823,6 +785,7 @@ def __init__(self, config: WhisperConfig): dropout=config.attention_dropout, is_decoder=True, config=config, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -837,10 +800,11 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Tuple[Cache]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, - ) -> torch.Tensor: + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -854,7 +818,7 @@ def forward( `(encoder_attention_heads,)`. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of size `(decoder_attention_heads,)`. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + past_key_value (`Tuple(Cache)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -862,43 +826,41 @@ def forward( residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) + self_atten_key_value, cross_attn_key_value = None, None + if past_key_value is not None: + self_atten_key_value, cross_attn_key_value = past_key_value + # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, self_atten_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=self_atten_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, cross_attn_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=cross_attn_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) @@ -914,7 +876,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += ((self_atten_key_value, cross_attn_key_value),) return outputs @@ -925,8 +887,12 @@ class WhisperPreTrainedModel(PreTrainedModel): main_input_name = "input_features" supports_gradient_checkpointing = True _no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"] + _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True + _supports_cache_class = True + _supports_static_cache = True + def _init_weights(self, module): std = self.config.init_std @@ -1024,13 +990,17 @@ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + past_key_values (`Cache` or `tuple(Cache)` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Three formats are allowed: + - `Cache` instance, when we do pure decoding without outputs from encoders; + - Tuple of [`~cache_utils.Cache`] instance, when we do generation based on outputs from encoders; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all @@ -1256,7 +1226,7 @@ def __init__(self, config: WhisperConfig): self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model) - self.layers = nn.ModuleList([WhisperDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList([WhisperDecoderLayer(config, idx) for idx in range(config.decoder_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self._use_sdpa = config._attn_implementation == "sdpa" @@ -1286,6 +1256,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): r""" Args: @@ -1320,17 +1291,21 @@ def forward( - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of - shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + past_key_values (`Cache` or `tuple(Cache)` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the - cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + Three formats are allowed: + - `Cache` instance, when we do pure decoding without outputs from encoders; + - Tuple of [`~cache_utils.Cache`] instance, when we do generation based on outputs from encoders; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those - that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of - all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more @@ -1344,6 +1319,10 @@ def forward( for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1352,56 +1331,73 @@ def forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." + ) + use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._use_sdpa and head_mask is None and not output_attentions: - # output_attentions=True & head_mask can not be supported when using SDPA. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - - # embed positions - if input_ids is not None: - positions = self.embed_positions( - input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids - ) - else: - positions = self.embed_positions( - inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids + return_legacy_cache = False + if use_cache: + if past_key_values is None: + return_legacy_cache = True + past_key_values = (DynamicCache(), DynamicCache() if encoder_hidden_states is not None else None) + elif isinstance(past_key_values, Cache): + if encoder_hidden_states is not None: + raise ValueError( + "Passing a single cache instance with `encoder_hidden_states` passed are not allowed, " + "you should considering passing in a tuple of two cache instances (`DynamicCache`, `DynamicCache`) " + "or (`StaticCache`, `OneShotStaticCache`) if you are using `torch.compile`" + ) + past_key_values = (past_key_values, None) + else: + assert isinstance(past_key_values, tuple) + if not isinstance(past_key_values[0], Cache): + # tuple of tuple of tensors + return_legacy_cache = True + cross_attn_kv_cache = None + if encoder_hidden_states is not None and len(past_key_values[0]) == 4: + cross_attn_kv_cache = DynamicCache.from_legacy_cache(tuple(past_key_value[2:] for past_key_value in past_key_values)) + + past_key_values = ( + DynamicCache.from_legacy_cache(tuple(past_key_value[:2] for past_key_value in past_key_values)), + cross_attn_kv_cache + ) + else: + # tuple of caches, do sanity check + if encoder_hidden_states is None or len(past_key_values) != 2: + raise ValueError( + "Please pass a single cache instance instead of a tuple if `encoder_hidden_states` is not passed ," + "and make sure to pass a tuple of two cache instances otherwise" + ) + + if cache_position is None: + past_seen_tokens = past_key_values[0].get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + positions = self.embed_positions(inputs_embeds, position_ids=position_ids) hidden_states = inputs_embeds + positions.to(inputs_embeds.device) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..." - ) - use_cache = False + + use_head_mask = head_mask is not None + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values[0] if past_key_values is not None else None, + use_head_mask, output_attentions + ) + # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None @@ -1424,13 +1420,11 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, encoder_hidden_states, None, # encoder attention mask head_mask[idx] if head_mask is not None else None, @@ -1442,21 +1436,21 @@ def forward( else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, encoder_hidden_states=encoder_hidden_states, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) - + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1469,6 +1463,20 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if isinstance(next_cache, tuple) and next_cache[1] is None: + next_cache = next_cache[0] + + if return_legacy_cache: + if isinstance(next_cache, Cache): + # only one cache for decoding + next_cache = next_cache.to_legacy_cache() + else: + # two caches scenario + next_cache = tuple( + self_attn_kv + cross_attn_kv for self_attn_kv, cross_attn_kv in + zip(next_cache[0].to_legacy_cache(), next_cache[1].to_legacy_cache()) + ) + if not return_dict: return tuple( v @@ -1483,6 +1491,87 @@ def forward( cross_attentions=all_cross_attentions, ) + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + use_head_mask: bool, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions and not use_head_mask: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + @add_start_docstrings( "The bare Whisper Model outputting raw hidden-states without any specific head on top.", @@ -1571,13 +1660,14 @@ def forward( decoder_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Union[Cache, Tuple[Union[Cache, Tuple[torch.FloatTensor]]]]] = None, decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]: r""" Returns: @@ -1604,7 +1694,7 @@ def forward( ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - + # encoder_outputs will never be None in generate if encoder_outputs is None: input_features = self._mask_input_features(input_features, attention_mask=attention_mask) @@ -1637,6 +1727,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1704,7 +1795,7 @@ def forward( decoder_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Union[Cache, Tuple[Union[Cache, Tuple[torch.FloatTensor]]]]] = None, decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None, labels: Optional[torch.LongTensor] = None, @@ -1712,6 +1803,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1766,6 +1858,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.proj_out(outputs[0]) @@ -1800,6 +1893,7 @@ def prepare_inputs_for_generation( encoder_outputs=None, attention_mask=None, decoder_attention_mask=None, + cache_position=None, **kwargs, ): decoder_position_ids = None @@ -1807,8 +1901,15 @@ def prepare_inputs_for_generation( decoder_position_ids = (decoder_attention_mask.cumsum(-1) - 1).clamp(min=0) if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] + if isinstance(past_key_values, tuple) and isinstance(past_key_values[0], Cache): + past_self_attn_key_value = past_key_values[0] + else: + past_self_attn_key_value = past_key_values + if isinstance(past_self_attn_key_value, Cache): + past_length = cache_position[0] if cache_position is not None else past_self_attn_key_value.get_seq_length() + else: + past_length = past_self_attn_key_value[0][0].shape[2] # Some generation methods already pass only the last input ID if decoder_input_ids.shape[1] > past_length: remove_prefix_length = past_length @@ -1822,12 +1923,13 @@ def prepare_inputs_for_generation( decoder_position_ids = decoder_position_ids[:, remove_prefix_length:] return { + "decoder_input_ids": decoder_input_ids.contiguous(), "encoder_outputs": encoder_outputs, "past_key_values": past_key_values, - "decoder_input_ids": decoder_input_ids, "use_cache": use_cache, "decoder_attention_mask": decoder_attention_mask, "decoder_position_ids": decoder_position_ids, + "cache_position": cache_position, } @staticmethod @@ -1904,16 +2006,18 @@ def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None, head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Union[Cache, Tuple[Union[Cache, Tuple[torch.FloatTensor]]]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" Args: @@ -1926,6 +2030,9 @@ def forward( - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. encoder_outputs (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is configured as a decoder. @@ -1937,14 +2044,19 @@ def forward( Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of - shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional - tensors are only required when the model is used as a decoder in a Sequence to Sequence model. Contains - pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. If - `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + past_key_values (`Cache` or `tuple(Cache)` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Three formats are allowed: + - `Cache` instance, when we do pure decoding without outputs from encoders; + - Tuple of [`~cache_utils.Cache`] instance, when we do generation based on outputs from encoders; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): @@ -1968,7 +2080,10 @@ def forward( for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. Returns: Example: @@ -2015,10 +2130,12 @@ def forward( cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, + position_ids=position_ids, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) logits = self.proj_out(outputs[0]) @@ -2049,10 +2166,18 @@ def prepare_inputs_for_generation( use_cache=None, encoder_outputs=None, attention_mask=None, + cache_position=None, **kwargs, ): if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] + if isinstance(past_key_values, tuple) and isinstance(past_key_values[0], Cache): + past_self_attn_key_value = past_key_values[0] + else: + past_self_attn_key_value = past_key_values + if isinstance(past_self_attn_key_value, Cache): + past_length = cache_position[0] if cache_position is not None else past_self_attn_key_value.get_seq_length() + else: + past_length = past_self_attn_key_value[0][0].shape[2] # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: @@ -2062,13 +2187,13 @@ def prepare_inputs_for_generation( remove_prefix_length = input_ids.shape[1] - 1 input_ids = input_ids[:, remove_prefix_length:] - return { + "input_ids": input_ids.contiguous(), "encoder_outputs": encoder_outputs, "past_key_values": past_key_values, - "input_ids": input_ids, "use_cache": use_cache, "attention_mask": attention_mask, + "cache_position": cache_position, } @staticmethod diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 70e37d76492db1..dedecd50ad422d 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -25,11 +25,13 @@ import numpy as np import pytest +from parameterized import parameterized from huggingface_hub import hf_hub_download import transformers from transformers import WhisperConfig from transformers.testing_utils import ( + is_flaky, is_pt_flax_cross_test, require_flash_attn, require_torch, @@ -897,7 +899,6 @@ def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_ @slow def test_flash_attn_2_inference_equivalence(self): import torch - for model_class in self.all_model_classes: if not model_class._supports_flash_attn_2: return @@ -957,9 +958,11 @@ def test_flash_attn_2_inference_equivalence_right_padding(self): tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_2" ) model_fa.to(torch_device) + model_fa.eval() model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16) model.to(torch_device) + model.eval() dummy_input = inputs_dict[model.main_input_name][:1] dummy_input = dummy_input.to(torch.float16) @@ -1535,6 +1538,46 @@ def test_longform_generate_multi_batch(self): def test_longform_generate_multi_batch_cond_prev(self): self._check_longform_generate_multi_batch(condition_on_prev_tokens=True) + + @unittest.skip("Skip this for now because conditional generation needs two caches") + @parameterized.expand([(1, False), (1, True), (4, False)]) + def test_new_cache_format(self, num_beams, do_sample): + pass + + @is_flaky(description="Flaky on conditional generation") + def test_custom_4d_attention_mask(self): + set_seed(0) + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = WhisperForConditionalGeneration(config).eval().to(device=torch_device, dtype=torch.float32) + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self._get_custom_4d_mask_test_data() + + logits = model.forward(decoder_input_ids=input_ids, input_features=input_dict["input_features"], decoder_position_ids=position_ids).logits + # logits.shape == torch.Size([3, 4, ...]) + + logits_shared_prefix = model( + decoder_input_ids=input_ids_shared_prefix, + input_features=input_dict["input_features"], + attention_mask=mask_shared_prefix, + decoder_position_ids=position_ids_shared_prefix, + )[0] + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) + + out_last_tokens = logits[:, -1, :] # last tokens in each batch line + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens + + # comparing greedily-chosen tokens: + assert torch.equal(out_last_tokens.max(axis=1).indices, out_shared_prefix_last_tokens.max(axis=1).indices) + + # comparing softmax-normalized logits: + normalized_0 = torch.nn.functional.softmax(out_last_tokens) + normalized_1 = torch.nn.functional.softmax(out_shared_prefix_last_tokens) + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) @require_torch @@ -2866,6 +2909,53 @@ def test_whisper_longform_no_speech_detection(self): for i in range(num_samples): assert decoded_all[i] == EXPECTED_TEXT[i] + @slow + def test_compile_static_cache(self): + set_seed(0) + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") + model.to(torch_device) + + input_speech = self._load_datasamples(4) + input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features + input_features = input_features.to(torch_device) + + # fmt: off + EXPECTED_LOGITS = torch.tensor( + [ + [50257, 50362, 1770, 13, 2264, 346, 353, 318, 262, 46329, 286, 262, 3504, 6097, 11, 290, 356, 389, 9675, 284], + [50257, 50362, 5414, 318, 1770, 13, 2264, 346, 353, 338, 5642, 1342, 3499, 621, 465, 2300, 13, 50256, 50256, 50256], + [50257, 50362, 679, 4952, 514, 326, 379, 428, 43856, 1622, 286, 262, 614, 11, 351, 6786, 290, 32595, 12023, 28236], + [50257, 50362, 679, 468, 12296, 17188, 1771, 7361, 26113, 18881, 1122, 338, 670, 318, 1107, 8312, 706, 477, 290, 460] + ] + ) + EXPECTED_TRANSCRIPT = [ + " Mr. Quilter is the apostle of the middle classes, and we are glad to", + " Nor is Mr. Quilter's manner less interesting than his matter.", + " He tells us that at this festive season of the year, with Christmas and roast beef looming", + " He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can", + ] + # fmt: on + + # dynamic cache + generated_ids = model.generate(input_features, max_length=20).to("cpu") + self.assertTrue(torch.allclose(generated_ids, EXPECTED_LOGITS)) + transcript = processor.batch_decode(generated_ids, skip_special_tokens=True) + self.assertListEqual(transcript, EXPECTED_TRANSCRIPT) + + # static cache + generated_ids = model.generate(input_features, max_length=20, cache_implementation="static").to("cpu") + self.assertTrue(torch.allclose(generated_ids, EXPECTED_LOGITS)) + transcript = processor.batch_decode(generated_ids, skip_special_tokens=True) + self.assertListEqual(transcript, EXPECTED_TRANSCRIPT) + + # torch.compile + static + model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead") + generated_ids = model.generate(input_features, max_length=20, cache_implementation="static").to("cpu") + self.assertTrue(torch.allclose(generated_ids, EXPECTED_LOGITS)) + transcript = processor.batch_decode(generated_ids, skip_special_tokens=True) + self.assertListEqual(transcript, EXPECTED_TRANSCRIPT) + def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None): if head_mask is None: diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 30010cde9116dc..b9aaed78c54329 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4413,6 +4413,8 @@ def test_custom_4d_attention_mask(self): if getattr(config, "sliding_window", 0) > 0: self.skipTest(f"{model_class.__name__} with sliding window attention is not supported by this test") model = model_class(config).to(device=torch_device, dtype=torch.float32) + # for models with un-certain ops like dropout in training mode + model = model.eval() ( input_ids,