From a9701953ffefcf8f10ec5a94dc2c55633648ba28 Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Tue, 2 Jul 2024 13:24:15 +0100 Subject: [PATCH] [whisper] static kv cache (#31166) * make work with cache abstraction * correct for static cache * hacks for compile * make fast * fix * fix pos ids * generate * fix sdpa * fix sdpa cache pos * fix fa2 * clean fa2 * integrate cache into generate * make style * copies * more copies * update eager * update sdpa * update fa2 * simplify * use cache pos * always compute cross-cache for debug * avoid recompiles Co-authored-by: Arthur Zucker * fix fix * fix fix fix * more fix * try encoder-decoder cache (too messy) * revert encoder-decoder cache * check cross-attn cache * use enc-dec dataclass * use richer enc-dec dataclass * clean-up * revert static cache changes * small fixes * revert to cpu flag * fix copies * add static slow test * past k/v docstring * more docstrings * cache_position docstrings * add to docs * add enc-dec cache to docs * make style * fix after rebase * fix beam * style * fix generation strategies * fix most decoder-only tests * style * skip test * more clean up * small docstrings * Apply suggestions from code review Co-authored-by: Joao Gante * add todo * only crop self-attn * check cache in mixin * style * fix re-compile after rebase * move `is_updated` logic to enc-dec wrapper * revert back * revert cache back * finalise design * fix * fix fix * style * Update src/transformers/cache_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * deprecate * updates * final updates * style * style --------- Co-authored-by: Joao Gante Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- docs/source/en/internal/generation_utils.md | 6 + docs/source/en/model_doc/whisper.md | 47 +- src/transformers/__init__.py | 2 + src/transformers/cache_utils.py | 160 ++++- src/transformers/generation/utils.py | 96 ++- .../models/whisper/configuration_whisper.py | 6 +- .../models/whisper/modeling_whisper.py | 552 +++++++++++------- src/transformers/utils/dummy_pt_objects.py | 7 + tests/generation/test_utils.py | 15 +- tests/models/whisper/test_modeling_whisper.py | 72 +++ 10 files changed, 705 insertions(+), 258 deletions(-) diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index 5bf8b5c4a0b36f..da7ea25e54b6b0 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -391,6 +391,12 @@ A [`Constraint`] can be used to force the generation to include specific tokens - get_seq_length - reset +[[autodoc]] EncoderDecoderCache + - get_seq_length + - to_legacy_cache + - from_legacy_cache + - reset + - reorder_cache ## Watermark Utils diff --git a/docs/source/en/model_doc/whisper.md b/docs/source/en/model_doc/whisper.md index 992ff71735db34..0565bd5aae111b 100644 --- a/docs/source/en/model_doc/whisper.md +++ b/docs/source/en/model_doc/whisper.md @@ -52,8 +52,6 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained >>> # Select an audio file and read it: >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") >>> audio_sample = ds[0]["audio"] ->>> waveform = audio_sample["array"] ->>> sampling_rate = audio_sample["sampling_rate"] >>> # Load the Whisper model in Hugging Face format: >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") @@ -61,7 +59,7 @@ Here is a step-by-step 
guide to transcribing an audio sample using a pre-trained

>>> # Use the model and processor to transcribe the audio:
>>> input_features = processor(
-... waveform, sampling_rate=sampling_rate, return_tensors="pt"
+... audio_sample["array"], sampling_rate=audio_sample["sampling_rate"], return_tensors="pt"
... ).input_features

>>> # Generate token ids
@@ -74,6 +72,49 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
```

+Whisper is compatible with the following optimisations:
+- [PyTorch Scaled Dot Product Attention (SDPA)](../perf_infer_gpu_one#pytorch-scaled-dot-product-attention): flash attention and memory-efficient attention kernels. Enabled by default for `torch>=2.1.1`.
+- [Flash Attention 2](../perf_infer_gpu_one#flashattention-2): improved implementation of flash attention through better parallelism and work partitioning.
+- [torch.compile](../llm_optims#static-kv-cache-and-torchcompile): JIT-compile the forward pass to dispatch to efficient fused kernels.
+
+As an example, the following code snippet enables SDPA and `torch.compile` for up to 5x faster inference:
+
+```python
+>>> import torch
+>>> from datasets import load_dataset
+>>> from transformers import WhisperProcessor, WhisperForConditionalGeneration
+
+>>> # Select an audio file and read it:
+>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+>>> audio_sample = ds[0]["audio"]
+
+>>> # Load the Whisper model with SDPA attention
+>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", attn_implementation="sdpa")
+
+>>> # Enable static cache and compile the forward pass
+>>> model.generation_config.cache_implementation = "static"
+>>> model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
+
+>>> # Use the model and processor to transcribe the audio:
+>>> input_features = processor(
+... audio_sample["array"], sampling_rate=audio_sample["sampling_rate"], return_tensors="pt"
+... ).input_features
+
+>>> # Warm-up: the first call to generate compiles the forward pass (slow)
+>>> _ = model.generate(input_features)
+
+>>> # Generate token ids using the compiled graph (fast!)
+>>> predicted_ids = model.generate(input_features)
+
+>>> # Decode token ids to text
+>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
+
+>>> transcription[0]
+' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
+```
+
+For more details on each optimisation, refer to the documentation linked above.
+
## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Whisper. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
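Beyond the documented `generate` path, the modeling changes below thread an `EncoderDecoderCache` (a self-attention cache paired with a cross-attention cache) through the Whisper decoder. The following is only a rough sketch of what the new wrapper looks like from user code, assuming the `openai/whisper-tiny.en` checkpoint used above and dummy inputs:

```python
import torch

from transformers import DynamicCache, EncoderDecoderCache, WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

# Dummy log-mel features with the shape Whisper expects: (batch, num_mel_bins, 3000 frames).
input_features = torch.randn(1, model.config.num_mel_bins, 3000)
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

# Pair a decoder self-attention cache with a cross-attention cache.
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())

with torch.no_grad():
    outputs = model(
        input_features,
        decoder_input_ids=decoder_input_ids,
        past_key_values=past_key_values,
        use_cache=True,
    )

# The self-attention cache holds the one decoder token seen so far, while the cross-attention
# cache holds the encoder sequence (1500 frames for Whisper) and is only computed once.
cache = outputs.past_key_values
print(type(cache).__name__)
print(cache.self_attention_cache.key_cache[0].shape)
print(cache.cross_attention_cache.key_cache[0].shape)
```

When no cache is passed, `generate` builds the same `EncoderDecoderCache(DynamicCache(), DynamicCache())` pairing internally, as the `generation/utils.py` diff below shows.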
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index c559ed61acad03..42c5b713c55aef 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -1212,6 +1212,7 @@
        "Cache",
        "CacheConfig",
        "DynamicCache",
+       "EncoderDecoderCache",
        "HQQQuantizedCache",
        "QuantizedCache",
        "QuantizedCacheConfig",
@@ -5895,6 +5896,7 @@
        Cache,
        CacheConfig,
        DynamicCache,
+       EncoderDecoderCache,
        HQQQuantizedCache,
        QuantizedCache,
        QuantizedCacheConfig,
diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index d572b8c8c71636..1f5a164815aaed 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -858,8 +858,12 @@ def update(
        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]

-       k_out[:, :, cache_position] = key_states
-       v_out[:, :, cache_position] = value_states
+       if cache_position is None:
+           k_out.copy_(key_states)
+           v_out.copy_(value_states)
+       else:
+           k_out[:, :, cache_position] = key_states
+           v_out[:, :, cache_position] = value_states

        return k_out, v_out

@@ -971,6 +975,158 @@ def get_max_length(self) -> Optional[int]:
        # no matter how long the sentence is
        return None

+    def reset(self):
+        self.key_cache.zero_()
+        self.value_cache.zero_()
+
+
+class EncoderDecoderCache(Cache):
+    """
+    Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and
+    cross-attention caches.
+    """
+
+    def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
+        self.self_attention_cache = self_attention_cache
+        self.cross_attention_cache = cross_attention_cache
+
+        self.is_updated = {}
+        for layer_idx in range(len(cross_attention_cache.key_cache)):
+            self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)
+
+    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
+        """
+        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
+        sequence length.
+        """
+        if layer_idx < len(self):
+            return (
+                self.self_attention_cache.key_cache[layer_idx],
+                self.self_attention_cache.value_cache[layer_idx],
+                self.cross_attention_cache.key_cache[layer_idx],
+                self.cross_attention_cache.value_cache[layer_idx],
+            )
+        else:
+            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
+
+    def __len__(self):
+        """
+        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
+        to the number of layers in the model.
+ """ + return len(self.self_attention_cache) + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format.""" + legacy_cache = () + if len(self.cross_attention_cache) > 0: + for self_attn, cross_attn in zip( + self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache() + ): + legacy_cache += (self_attn + cross_attn,) + else: + legacy_cache = self.self_attention_cache.to_legacy_cache() + return legacy_cache + + @classmethod + def from_legacy_cache( + cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + ) -> "EncoderDecoderCache": + """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`.""" + cache = cls(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache()) + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx][:2] + cache.self_attention_cache.update(key_states, value_states, layer_idx) + if len(past_key_values[layer_idx]) > 2: + key_states, value_states = past_key_values[layer_idx][2:] + cache.cross_attention_cache.update(key_states, value_states, layer_idx) + cache.is_updated[layer_idx] = True + return cache + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + if len(self.self_attention_cache.key_cache) <= layer_idx: + return 0 + return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + + def reset(self): + if hasattr(self.self_attention_cache, "reset"): + self.self_attention_cache.reset() + if hasattr(self.cross_attention_cache, "reset"): + self.cross_attention_cache.reset() + elif not hasattr(self.self_attention_cache, "reset") and not hasattr(self.cross_attention_cache, "reset"): + raise ValueError( + "Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should " + "only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. " + f"Got {self.self_attention_cache.__str__()} for the self attention cache and " + f"{self.cross_attention_cache.__str__()} for the cross attention cache." + ) + for layer_idx in self.is_updated: + self.is_updated[layer_idx] = False + + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + self.self_attention_cache.reorder_cache(beam_idx) + self.cross_attention_cache.reorder_cache(beam_idx) + + def check_dynamic_cache(self, method: str): + if not ( + isinstance(self.self_attention_cache, DynamicCache) + and isinstance(self.cross_attention_cache, DynamicCache) + ): + raise ValueError( + f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self " + f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache." + ) + + # TODO(gante, sanchit-gandhi): move following functionality into `.generate` + def crop(self, maximum_length: int): + """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be + negative to remove `maximum_length` tokens. 
This is used in assisted decoding and contrastive search.""" + self.check_dynamic_cache(self.crop.__name__) + self.self_attention_cache.crop(maximum_length) + + def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]": + """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by + `_split_model_inputs()` in `generation.utils`""" + self.check_dynamic_cache(self.batch_split.__name__) + self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size) + cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size) + + out = [] + for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache): + out.append(EncoderDecoderCache(self_attn, cross_attn)) + return out + + @classmethod + def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache": + """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in + `generation.utils`""" + self_attention_cache = DynamicCache() + cross_attention_cache = DynamicCache() + for idx in range(len(splits[0])): + layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0) + layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0) + self_attention_cache.update(layer_keys, layer_values, idx) + + layer_keys = torch.cat([current.cross_attention_cache.key_cache[idx] for current in splits], dim=0) + layer_values = torch.cat([current.cross_attention_cache.value_cache[idx] for current in splits], dim=0) + cross_attention_cache.update(layer_keys, layer_values, idx) + return cls(self_attention_cache, cross_attention_cache) + + def batch_repeat_interleave(self, repeats: int): + """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" + self.check_dynamic_cache(self.batch_repeat_interleave.__name__) + self.self_attention_cache.batch_repeat_interleave(repeats) + self.cross_attention_cache.batch_repeat_interleave(repeats) + + def batch_select_indices(self, indices: torch.Tensor): + """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" + self.check_dynamic_cache(self.batch_select_indices.__name__) + self.self_attention_cache.batch_select_indices(indices) + self.cross_attention_cache.batch_select_indices(indices) + class HybridCache(Cache): def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None: diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 9c69bb35d264fe..25ec7be1b57e9a 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -27,6 +27,7 @@ from ..cache_utils import ( Cache, DynamicCache, + EncoderDecoderCache, HQQQuantizedCache, HybridCache, QuantizedCacheConfig, @@ -1409,7 +1410,7 @@ def _get_initial_cache_position(self, input_ids, model_kwargs): model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device) return model_kwargs - def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int) -> Cache: + def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int, model_kwargs) -> Cache: """ Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a new `generate` call requires a larger cache. 
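The `from_legacy_cache`/`to_legacy_cache` helpers in the `cache_utils.py` diff above bridge the new wrapper and the old per-layer tuple format. A minimal round-trip sketch, with arbitrary tensor shapes chosen purely for illustration and assertions reflecting my reading of the diff:

```python
import torch

from transformers import DynamicCache, EncoderDecoderCache

# Legacy format: one tuple per decoder layer, laid out as
# (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value).
bsz, heads, dec_len, enc_len, head_dim = 1, 6, 4, 1500, 64
legacy = tuple(
    (
        torch.ones(bsz, heads, dec_len, head_dim),  # self-attention keys
        torch.ones(bsz, heads, dec_len, head_dim),  # self-attention values
        torch.ones(bsz, heads, enc_len, head_dim),  # cross-attention keys
        torch.ones(bsz, heads, enc_len, head_dim),  # cross-attention values
    )
    for _ in range(2)
)

cache = EncoderDecoderCache.from_legacy_cache(legacy)
assert isinstance(cache.self_attention_cache, DynamicCache)
assert int(cache.get_seq_length()) == dec_len  # decoder tokens cached so far
assert all(cache.is_updated.values())  # cross-attention states are marked as filled

# Converting back interleaves the two caches into the old per-layer 4-tuples.
round_trip = cache.to_legacy_cache()
assert torch.equal(round_trip[0][2], legacy[0][2])
```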
@@ -1417,28 +1418,46 @@ def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_l Returns the resulting cache object. """ cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation] + requires_cross_attention_cache = ( + self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None + ) + + if hasattr(self, "_cache"): + cache_to_check = self._cache.self_attention_cache if requires_cross_attention_cache else self._cache + if cache_implementation == "sliding_window": max_cache_len = min(self.config.sliding_window, max_cache_len) need_new_cache = ( not hasattr(self, "_cache") - or (not isinstance(self._cache, cache_cls)) - or self._cache.max_batch_size != max_batch_size - or self._cache.max_cache_len < max_cache_len + or (not isinstance(cache_to_check, cache_cls)) + or cache_to_check.max_batch_size != max_batch_size + or cache_to_check.max_cache_len < max_cache_len ) + if requires_cross_attention_cache and hasattr(self, "_cache"): + need_new_cache = ( + need_new_cache + or self._cache.cross_attention_cache.max_cache_len != model_kwargs["encoder_outputs"][0].shape[1] + ) + if need_new_cache: if hasattr(self.config, "_pre_quantization_dtype"): cache_dtype = self.config._pre_quantization_dtype else: cache_dtype = self.dtype - self._cache = cache_cls( - config=self.config, - max_batch_size=max_batch_size, - max_cache_len=max_cache_len, - device=self.device, - dtype=cache_dtype, - ) + cache_kwargs = { + "config": self.config, + "max_batch_size": max_batch_size, + "max_cache_len": max_cache_len, + "device": self.device, + "dtype": cache_dtype, + } + self._cache = cache_cls(**cache_kwargs) + if requires_cross_attention_cache: + encoder_kwargs = cache_kwargs.copy() + encoder_kwargs["max_cache_len"] = model_kwargs["encoder_outputs"][0].shape[1] + self._cache = EncoderDecoderCache(self._cache, cache_cls(**encoder_kwargs)) else: self._cache.reset() return self._cache @@ -1745,6 +1764,7 @@ def generate( generation_config.cache_implementation, getattr(generation_config, "num_beams", 1) * batch_size, generation_config.max_length, + model_kwargs, ) elif generation_config.cache_implementation == "quantized": if not self._supports_quantized_cache: @@ -1776,11 +1796,22 @@ def generate( # keeps copying the cache thus using much more memory elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache(): past = model_kwargs.get("past_key_values", None) + requires_cross_attention_cache = ( + self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None + ) if past is None: - model_kwargs["past_key_values"] = DynamicCache() + model_kwargs["past_key_values"] = ( + DynamicCache() + if not requires_cross_attention_cache + else EncoderDecoderCache(DynamicCache(), DynamicCache()) + ) use_dynamic_cache_by_default = True elif isinstance(past, tuple): - model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(past) + model_kwargs["past_key_values"] = ( + DynamicCache.from_legacy_cache(past) + if not requires_cross_attention_cache + else EncoderDecoderCache.from_legacy_cache(past) + ) use_dynamic_cache_by_default = True self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) @@ -2064,7 +2095,7 @@ def typeerror(): # Convert to legacy cache if needed if use_dynamic_cache_by_default and generation_config.return_legacy_cache: if isinstance(result, ModelOutput) and hasattr(result, "past_key_values"): - if isinstance(result.past_key_values, DynamicCache): + if 
isinstance(result.past_key_values, (DynamicCache, EncoderDecoderCache)): result.past_key_values = result.past_key_values.to_legacy_cache() return result @@ -2234,7 +2265,7 @@ def _contrastive_search( # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values; # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step if model_kwargs.get("past_key_values") is None or ( - isinstance(model_kwargs["past_key_values"], Cache) + isinstance(model_kwargs["past_key_values"], (Cache, EncoderDecoderCache)) and model_kwargs["past_key_values"].get_seq_length() == 0 ): # prepare inputs @@ -2323,7 +2354,9 @@ def _contrastive_search( # Replicates the new past_key_values to match the `top_k` candidates past = model_kwargs["past_key_values"] # If it is a static cache, modify it in-place layer after layer to save memory - if isinstance(past, DynamicCache): + if isinstance(past, DynamicCache) or ( + isinstance(past, EncoderDecoderCache) and isinstance(past.self_attention_cache, DynamicCache) + ): past.batch_repeat_interleave(top_k) else: new_key_values = [] @@ -2350,7 +2383,10 @@ def _contrastive_search( output_hidden_states=True, output_attentions=output_attentions, ) - if isinstance(outputs["past_key_values"], DynamicCache): + if isinstance(outputs["past_key_values"], DynamicCache) or ( + isinstance(outputs["past_key_values"], EncoderDecoderCache) + and isinstance(outputs["past_key_values"].self_attention_cache, DynamicCache) + ): # Remove past K-V from output since we don't need to stack later outputs["past_key_values"] = None # Remove last token from past K-V since we don't want to append it at this point @@ -2425,7 +2461,10 @@ def _contrastive_search( else: _, next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True) # Do it in-place layer per layer to save memory - if isinstance(next_past_key_values, DynamicCache): + if isinstance(next_past_key_values, DynamicCache) or ( + isinstance(next_past_key_values, EncoderDecoderCache) + and isinstance(next_past_key_values.self_attention_cache, DynamicCache) + ): next_past_key_values.batch_select_indices(augmented_idx) else: new_key_values = [] @@ -2498,7 +2537,10 @@ def _contrastive_search( # Contrastive search works by forward looking at the next token, so we need to exclude it from # `past_key_values` to be consistent with the other decoding methods if model_kwargs.get("past_key_values") is not None: - if isinstance(model_kwargs["past_key_values"], DynamicCache): + if isinstance(model_kwargs["past_key_values"], DynamicCache) or ( + isinstance(model_kwargs["past_key_values"], EncoderDecoderCache) + and isinstance(model_kwargs["past_key_values"].self_attention_cache, DynamicCache) + ): model_kwargs["past_key_values"].crop(-1) else: past_key_values = [] @@ -2757,7 +2799,7 @@ def _temporary_reorder_cache(self, past_key_values, beam_idx): # Exception 2: models with different cache formats. These are limited to `DynamicCache` until their # cache format is standardized, to avoid adding complexity to the codebase. elif "bloom" in model_class or "gptbigcode" in model_class: - if not isinstance(past_key_values, DynamicCache): + if not isinstance(past_key_values, (DynamicCache, EncoderDecoderCache)): raise ValueError( f"Using an unsupported cache format with {model_class}. 
Currently, it only supports the " "legacy tuple format or `DynamicCache`" @@ -3703,8 +3745,12 @@ def _assisted_decoding( # This is needed if return_dict_in_generate is True start_from_empty_dynamic_cache = False - if isinstance(model_kwargs.get("past_key_values", None), DynamicCache): - if len(model_kwargs["past_key_values"]) == 0: + past_key_values = model_kwargs.get("past_key_values", None) + if isinstance(past_key_values, DynamicCache) or ( + isinstance(past_key_values, EncoderDecoderCache) + and isinstance(past_key_values.self_attention_cache, DynamicCache) + ): + if len(past_key_values) == 0: start_from_empty_dynamic_cache = True this_peer_finished = False @@ -4022,7 +4068,9 @@ def _split(data, full_batch_size: int, split_size: int = None): if isinstance(data, torch.Tensor): return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)] # New cache format - elif isinstance(data, DynamicCache): + elif isinstance(data, DynamicCache) or ( + isinstance(data, EncoderDecoderCache) and isinstance(data.self_attention_cache, DynamicCache) + ): return data.batch_split(full_batch_size, split_size) elif isinstance(data, tuple): # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) @@ -4130,6 +4178,8 @@ def _concat(data): # New cache format elif isinstance(data[0], DynamicCache): return DynamicCache.from_batch_splits(data) + elif isinstance(data[0], EncoderDecoderCache): + return EncoderDecoderCache.from_batch_splits(data) elif isinstance(data[0], tuple): # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) if isinstance(data[0][0], tuple): diff --git a/src/transformers/models/whisper/configuration_whisper.py b/src/transformers/models/whisper/configuration_whisper.py index e7c0f47b58e587..d65811cbc8efe6 100644 --- a/src/transformers/models/whisper/configuration_whisper.py +++ b/src/transformers/models/whisper/configuration_whisper.py @@ -189,7 +189,11 @@ class WhisperConfig(PretrainedConfig): model_type = "whisper" keys_to_ignore_at_inference = ["past_key_values"] - attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} + attribute_map = { + "num_key_value_heads": "encoder_attention_heads", + "num_attention_heads": "encoder_attention_heads", + "hidden_size": "d_model", + } def __init__( self, diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index aedc0c43aca752..f1467a55e03b9b 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -25,7 +25,8 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -244,6 +245,7 @@ def __init__( is_decoder: bool = False, bias: bool = True, is_causal: bool = False, + layer_idx: Optional[int] = None, config: Optional[WhisperConfig] = None, ): super().__init__() @@ -262,6 +264,14 @@ def __init__( self.is_decoder = is_decoder self.is_causal = is_causal + if layer_idx is None and is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + 
"will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + self.layer_idx = layer_idx + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -271,84 +281,56 @@ def __init__( def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - # Copied from transformers.models.bart.modeling_bart.BartAttention.forward with BART->whisper def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[EncoderDecoderCache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + query_states = self._shape(self.q_proj(hidden_states) * self.scaling, tgt_len, bsz) + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache + + # use key_value_states if cross attention + current_states = key_value_states if key_value_states is not None else hidden_states + if is_cross_attention and past_key_value and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
- # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + key_states = self._shape(self.k_proj(current_states), -1, bsz) + value_states = self._shape(self.v_proj(current_states), -1, bsz) + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) @@ -358,42 +340,27 @@ def forward( f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" f" {layer_head_mask.size()}" ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. 
- # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_probs, value_states) - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" f" {attn_output.size()}" ) - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) attn_output = attn_output.transpose(1, 2) - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned across GPUs when using tensor-parallelism. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights, past_key_value -# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Whisper class WhisperFlashAttention2(WhisperAttention): """ Whisper flash attention module. This module inherits from `WhisperAttention` as the weights of the module stays @@ -410,18 +377,21 @@ def __init__(self, *args, **kwargs): # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[EncoderDecoderCache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "The `static` cache implementation is not compatible with `attn_implementation='flash_attention_2'`. 
" + "Use `attn_implementation='sdpa'` in the meantime, and open an issue at https://github.com/huggingface/transformers" + ) # WhisperFlashAttention2 attention does not support output_attentions if output_attentions: raise ValueError("WhisperFlashAttention2 attention does not support output_attentions") @@ -429,51 +399,45 @@ def forward( # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - - bsz, q_len, _ = hidden_states.size() + bsz, tgt_len, _ = hidden_states.size() # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz) + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache + + # use key_value_states if cross attention + current_states = key_value_states if key_value_states is not None else hidden_states + if is_cross_attention and past_key_value and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + key_states = self._shape(self.k_proj(current_states), -1, bsz) + value_states = self._shape(self.v_proj(current_states), -1, bsz) + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim] + # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need @@ -502,10 +466,10 @@ def forward( value_states = value_states.to(target_dtype) attn_output = self._flash_attention_forward( - query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout + query_states, key_states, value_states, causal_mask, tgt_len, dropout=self.dropout ) - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, tgt_len, -1) attn_output = self.out_proj(attn_output) if not output_attentions: @@ -614,15 +578,15 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query class WhisperSdpaAttention(WhisperAttention): - # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with BART->whisper, Bart->Whisper def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[EncoderDecoderCache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" if output_attentions or layer_head_mask is not None: @@ -638,59 +602,50 @@ def forward( attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning 
- if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz) + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache + + # use key_value_states if cross attention + current_states = key_value_states if key_value_states is not None else hidden_states + if is_cross_attention and past_key_value and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) + key_states = self._shape(self.k_proj(current_states), -1, bsz) + value_states = self._shape(self.v_proj(current_states), -1, bsz) + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. 
- is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False + is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 @@ -698,7 +653,7 @@ def forward( query_states, key_states, value_states, - attn_mask=attention_mask, + attn_mask=causal_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal, ) @@ -798,9 +753,8 @@ def forward( return outputs -# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Whisper, MBART->WHISPER class WhisperDecoderLayer(nn.Module): - def __init__(self, config: WhisperConfig): + def __init__(self, config: WhisperConfig, layer_idx: int = None): super().__init__() self.embed_dim = config.d_model @@ -810,6 +764,7 @@ def __init__(self, config: WhisperConfig): dropout=config.attention_dropout, is_decoder=True, is_causal=True, + layer_idx=layer_idx, config=config, ) self.dropout = config.dropout @@ -822,6 +777,7 @@ def __init__(self, config: WhisperConfig): config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + layer_idx=layer_idx, config=config, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) @@ -837,9 +793,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[EncoderDecoderCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.LongTensor] = None, ) -> torch.Tensor: """ Args: @@ -863,41 +820,35 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 
hidden_states = residual + hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value + # add cross-attn to positions 1 of present_key_value tuple + present_key_value = (present_key_value, cross_attn_present_key_value) # Fully Connected residual = hidden_states @@ -927,6 +878,8 @@ class WhisperPreTrainedModel(PreTrainedModel): _no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): std = self.config.init_std @@ -1024,14 +977,18 @@ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + past_key_values (`EncoderDecoderCache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states that can be used to speed up auto-regressive (sequential) decoding. There are + four sets of pre-computed hidden-states: key and values states in the self-attention blocks (2) and + in the cross-attention blocks (2). The `past_key_values` are returned when `use_cache=True` is passed or + when `config.use_cache=True` + + Two formats are allowed: + - An [`~cache_utils.EncoderDecoderCache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. @@ -1051,6 +1008,9 @@ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the cache + in the correct position and to infer the complete sequence length. 
""" WHISPER_ENCODER_INPUTS_DOCSTRING = r""" @@ -1256,7 +1216,9 @@ def __init__(self, config: WhisperConfig): self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx) self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model) - self.layers = nn.ModuleList([WhisperDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList( + [WhisperDecoderLayer(config, layer_idx) for layer_idx in range(config.decoder_layers)] + ) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self._use_sdpa = config._attn_implementation == "sdpa" @@ -1286,6 +1248,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): r""" Args: @@ -1320,13 +1283,17 @@ def forward( - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of - shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + past_key_values (`EncoderDecoderCache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states that can be used to speed up auto-regressive (sequential) decoding. There are + four sets of pre-computed hidden-states: key and values states in the self-attention blocks (2) and + in the cross-attention blocks (2). The `past_key_values` are returned when `use_cache=True` is passed or + when `config.use_cache=True` - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the - cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + Two formats are allowed: + - An [`~cache_utils.EncoderDecoderCache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of @@ -1344,6 +1311,9 @@ def forward( for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. 
""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1363,26 +1333,38 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._use_sdpa and head_mask is None and not output_attentions: - # output_attentions=True & head_mask can not be supported when using SDPA. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length + return_legacy_cache = False + return_self_attention_cache = False + if use_cache or past_key_values is not None: + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + past_key_values_length = 0 + if cache_position is not None: + past_key_values_length = cache_position[0] + elif past_key_values is not None: + past_key_values_length = past_key_values.get_seq_length() + + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + input_shape[1], device=inputs_embeds.device ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + # embed positions if input_ids is not None: positions = self.embed_positions( @@ -1396,6 +1378,14 @@ def forward( hidden_states = inputs_embeds + positions.to(inputs_embeds.device) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( @@ -1406,7 +1396,6 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1424,13 +1413,11 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, encoder_hidden_states, None, # encoder attention mask head_mask[idx] if head_mask is not None else None, @@ -1438,25 +1425,24 @@ def forward( None, # past_key_value output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, encoder_hidden_states=encoder_hidden_states, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values if use_cache else None, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) - if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1468,7 +1454,11 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None + next_cache = past_key_values if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() if not return_dict: return tuple( v @@ -1483,6 +1473,87 @@ def forward( cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, 
+        output_attentions: bool,
+    ):
+        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode step due to the dynamic shapes.
+        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+
+        if self.config._attn_implementation == "flash_attention_2":
+            if attention_mask is not None and 0.0 in attention_mask:
+                return attention_mask
+            return None
+
+        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+        # to infer the attention mask.
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        using_static_cache = isinstance(past_key_values, StaticCache)
+
+        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
+            if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                attention_mask,
+                inputs_embeds=input_tensor,
+                past_key_values_length=past_seen_tokens,
+                is_training=self.training,
+            ):
+                return None
+
+        dtype, device = input_tensor.dtype, input_tensor.device
+        min_dtype = torch.finfo(dtype).min
+        sequence_length = input_tensor.shape[1]
+        if using_static_cache:
+            target_length = past_key_values.get_max_length()
+        else:
+            target_length = (
+                attention_mask.shape[-1]
+                if isinstance(attention_mask, torch.Tensor)
+                else past_seen_tokens + sequence_length + 1
+            )
+
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+            if attention_mask.max() != 0:
+                raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0")
+            causal_mask = attention_mask
+        else:
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+        if (
+            self.config._attn_implementation == "sdpa"
+            and attention_mask is not None
+            and attention_mask.device.type == "cuda"
+            and not output_attentions
+        ):
+            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + @add_start_docstrings( "The bare Whisper Model outputting raw hidden-states without any specific head on top.", @@ -1571,13 +1642,14 @@ def forward( decoder_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None, decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]: r""" Returns: @@ -1637,6 +1709,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1704,7 +1777,7 @@ def forward( decoder_head_mask: Optional[torch.Tensor] = None, cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None, decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None, labels: Optional[torch.LongTensor] = None, @@ -1712,6 +1785,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1766,6 +1840,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.proj_out(outputs[0]) @@ -1800,14 +1875,19 @@ def prepare_inputs_for_generation( encoder_outputs=None, attention_mask=None, decoder_attention_mask=None, + cache_position=None, **kwargs, ): decoder_position_ids = None if decoder_attention_mask is not None: decoder_position_ids = (decoder_attention_mask.cumsum(-1) - 1).clamp(min=0) + past_length = 0 if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] + if isinstance(past_key_values, EncoderDecoderCache): + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + else: + past_length = past_key_values[0][0].shape[2] # Some generation methods already pass only the last input ID if decoder_input_ids.shape[1] > past_length: @@ -1821,6 +1901,13 @@ def prepare_inputs_for_generation( if decoder_position_ids is not None and decoder_position_ids.shape[1] > decoder_input_ids.shape[1]: decoder_position_ids = decoder_position_ids[:, remove_prefix_length:] + if cache_position is None: + cache_position = torch.arange( + past_length, past_length + decoder_input_ids.shape[1], device=decoder_input_ids.device + ) + elif use_cache: + cache_position = cache_position[-decoder_input_ids.shape[1] :] + 
return { "encoder_outputs": encoder_outputs, "past_key_values": past_key_values, @@ -1828,6 +1915,7 @@ def prepare_inputs_for_generation( "use_cache": use_cache, "decoder_attention_mask": decoder_attention_mask, "decoder_position_ids": decoder_position_ids, + "cache_position": cache_position, } @staticmethod @@ -1914,6 +2002,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" Args: @@ -1968,6 +2057,9 @@ def forward( for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the cache + in the correct position and to infer the complete sequence length. Returns: @@ -2019,6 +2111,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) logits = self.proj_out(outputs[0]) @@ -2049,10 +2142,15 @@ def prepare_inputs_for_generation( use_cache=None, encoder_outputs=None, attention_mask=None, + cache_position=None, **kwargs, ): + past_length = 0 if past_key_values is not None: - past_length = past_key_values[0][0].shape[2] + if isinstance(past_key_values, (Cache, EncoderDecoderCache)): + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + else: + past_length = past_key_values[0][0].shape[2] # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: @@ -2063,12 +2161,18 @@ def prepare_inputs_for_generation( input_ids = input_ids[:, remove_prefix_length:] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_ids.shape[1], device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_ids.shape[1] :] + return { "encoder_outputs": encoder_outputs, "past_key_values": past_key_values, "input_ids": input_ids, "use_cache": use_cache, "attention_mask": attention_mask, + "cache_position": cache_position, } @staticmethod diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index c9267debc5de81..925d8bbb2f6547 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -37,6 +37,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class EncoderDecoderCache(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class HQQQuantizedCache(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 3293cc279d019a..469bfa9206d2d9 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -57,7 +57,7 @@ ImageGPTForCausalImageModeling, SpeechEncoderDecoderModel, ) - from transformers.cache_utils import DynamicCache, QuantoQuantizedCache + from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache from transformers.generation import ( BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput, @@ -1636,7 +1636,6 @@ def test_new_cache_format(self, num_beams, do_sample): config, input_ids, attention_mask = 
self._get_input_ids_and_config() config.use_cache = True - config.is_decoder = True model = model_class(config).to(torch_device).eval() generation_kwargs = { @@ -1652,15 +1651,21 @@ def test_new_cache_format(self, num_beams, do_sample): set_seed(seed) legacy_results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) set_seed(seed) + if config.is_encoder_decoder: + cache_cls = EncoderDecoderCache + past_key_values = cache_cls(DynamicCache(), DynamicCache()) + else: + cache_cls = DynamicCache + past_key_values = cache_cls() new_results = model.generate( - input_ids, attention_mask=attention_mask, past_key_values=DynamicCache(), **generation_kwargs + input_ids, attention_mask=attention_mask, past_key_values=past_key_values, **generation_kwargs ) # The two sets of generated sequences must match, despite the cache format between forward passes being # different self.assertListEqual(legacy_results.sequences.tolist(), new_results.sequences.tolist()) self.assertTrue(isinstance(legacy_results.past_key_values, tuple)) - self.assertTrue(isinstance(new_results.past_key_values, DynamicCache)) + self.assertTrue(isinstance(new_results.past_key_values, cache_cls)) # The contents of the two caches, when converted to the same format (in both directions!), must match legacy_cache = legacy_results.past_key_values @@ -1675,7 +1680,7 @@ def test_new_cache_format(self, num_beams, do_sample): ) new_cache = new_results.past_key_values - legacy_cache_converted = DynamicCache.from_legacy_cache(legacy_results.past_key_values) + legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values) for layer_idx in range(len(new_cache)): for kv_idx in range(len(new_cache[layer_idx])): self.assertTrue( diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 86a89af8c13359..dcb495d95a6e4d 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1539,6 +1539,46 @@ def test_longform_generate_multi_batch(self): def test_longform_generate_multi_batch_cond_prev(self): self._check_longform_generate_multi_batch(condition_on_prev_tokens=True) + def test_custom_4d_attention_mask(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = WhisperForConditionalGeneration(config).to(device=torch_device, dtype=torch.float32) + model.eval() + + ( + input_ids, + position_ids, + input_ids_shared_prefix, + mask_shared_prefix, + position_ids_shared_prefix, + ) = self._get_custom_4d_mask_test_data() + + with torch.no_grad(): + logits = model.forward( + decoder_input_ids=input_ids, + input_features=input_dict["input_features"], + decoder_position_ids=position_ids, + ).logits + # logits.shape == torch.Size([3, 4, ...]) + + logits_shared_prefix = model( + decoder_input_ids=input_ids_shared_prefix, + input_features=input_dict["input_features"], + decoder_attention_mask=mask_shared_prefix, + decoder_position_ids=position_ids_shared_prefix, + )[0] + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) + + out_last_tokens = logits[:, -1, :] # last tokens in each batch line + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens + + # comparing greedily-chosen tokens: + assert torch.equal(out_last_tokens.max(axis=1).indices, out_shared_prefix_last_tokens.max(axis=1).indices) + + # comparing softmax-normalized logits: + normalized_0 = torch.nn.functional.softmax(out_last_tokens) + normalized_1 = 
torch.nn.functional.softmax(out_shared_prefix_last_tokens) + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) + @require_torch @require_torchaudio @@ -2961,6 +3001,34 @@ def test_whisper_empty_longform_multi_gpu(self): torch.manual_seed(0) model.generate(**inputs, **gen_kwargs) + @slow + def test_tiny_static_generation(self): + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") + model.to(torch_device) + + input_speech = self._load_datasamples(4) + input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features + input_features = input_features.to(torch_device) + eager_generated_ids = model.generate(input_features, max_new_tokens=64) + + model.generation_config.cache_implementation = "static" + model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) + + # compile the forward pass and assert equivalence + static_generated_ids = model.generate(input_features, max_new_tokens=64) + assert (eager_generated_ids == static_generated_ids).all() + + # check the compiled graph can be re-used and that the cache is correctly reset + # reverse the ordering of the input features + permutation_idx = ( + torch.arange(input_features.shape[0], 0, step=-1, dtype=torch.long, device=input_features.device) - 1 + ) + input_features = input_features[permutation_idx, ...] + static_generated_ids = model.generate(input_features, max_new_tokens=64) + # assert re-ordered generations match those from eager + assert (eager_generated_ids[permutation_idx, :] == static_generated_ids).all() + def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None): if head_mask is None: @@ -3564,6 +3632,10 @@ def test_decoder_model_attn_mask_past(self): config=config, input_ids=inputs_dict["input_ids"] ) + @unittest.skip(reason="Tested implicitly through the encoder-decoder tests") + def test_custom_4d_attention_mask(self): + pass + @unittest.skip(reason="Generate needs input ids") def test_generate_without_input_ids(self): # generate only works with input ids for whisper
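
The decoder changes in this patch route any legacy `past_key_values` tuples through `EncoderDecoderCache.from_legacy_cache` on entry and hand back `to_legacy_cache()` on exit when the caller used the old format. Below is a minimal sketch of that round trip; the layer count and tensor shapes are invented purely for illustration, and the expected outputs in the comments assume the `EncoderDecoderCache` API exported here (`from_legacy_cache`, `get_seq_length`, `self_attention_cache`, `to_legacy_cache`).

```python
import torch

from transformers.cache_utils import DynamicCache, EncoderDecoderCache

# Hypothetical sizes, chosen only for this sketch: 2 decoder layers, batch 1, 4 heads,
# head_dim 64, with 3 tokens already cached.
num_layers, batch, heads, seq, head_dim = 2, 1, 4, 3, 64
legacy = tuple(
    # each legacy layer holds (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value)
    tuple(torch.zeros(batch, heads, seq, head_dim) for _ in range(4))
    for _ in range(num_layers)
)

# Self-attention k/v are collected into one DynamicCache, cross-attention k/v into another.
cache = EncoderDecoderCache.from_legacy_cache(legacy)
print(isinstance(cache.self_attention_cache, DynamicCache))  # expected: True
print(cache.get_seq_length())  # expected: 3, the length of the self-attention cache

# Converting back recovers the old tuple-of-tuples format, as the decoder does for legacy callers.
roundtrip = cache.to_legacy_cache()
print(len(roundtrip), len(roundtrip[0]))  # expected: 2 4
```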
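`_update_causal_mask` above fills a `(sequence_length, target_length)` buffer with the dtype minimum, applies `triu` for multi-token inputs, and then unmasks every slot at or before the current `cache_position`, so the mask shape stays fixed for `torch.compile`. The standalone sketch below replays that core construction for a single decoded token; the cache size of 8 and the position index 3 are made-up values for illustration only.

```python
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
sequence_length, target_length = 1, 8  # one new token decoded against a static cache with 8 slots
cache_position = torch.tensor([3])  # the new token occupies slot 3

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
if sequence_length != 1:  # skipped here, exactly as during single-token decoding in the model code
    causal_mask = torch.triu(causal_mask, diagonal=1)
# keep the large negative fill only for slots strictly after the current position
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :]  # broadcast to (batch, 1, query_len, kv_len)

print(causal_mask[0, 0])
# slots 0..3 are 0.0 (attended to), slots 4..7 keep the dtype-minimum fill (masked out)
```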
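The updated `prepare_inputs_for_generation` keeps only the decoder ids that are not yet in the cache and derives `cache_position` for them when the caller did not pass one. A toy replay of that bookkeeping is sketched below; the token ids and the cache length are invented and do not correspond to a real Whisper prompt.

```python
import torch

# Invented state: 4 decoder ids seen so far, 3 of them already written to the cache.
decoder_input_ids = torch.tensor([[1, 2, 3, 4]])
past_length = 3

# Some generation methods already pass only the last input id; otherwise crop to the uncached suffix.
if decoder_input_ids.shape[1] > past_length:
    decoder_input_ids = decoder_input_ids[:, past_length:]

# cache_position indexes the slot(s) the remaining token(s) will occupy in the (possibly static) cache.
cache_position = torch.arange(past_length, past_length + decoder_input_ids.shape[1])

print(decoder_input_ids)  # tensor([[4]])
print(cache_position)  # tensor([3])
```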