[whisper] static kv cache #31166

Merged 76 commits on Jul 2, 2024

Commits (76)
738ed90
make work with cache abstraction
sanchit-gandhi May 31, 2024
624fa74
correct for static cache
sanchit-gandhi May 31, 2024
f2124f8
hacks for compile
sanchit-gandhi May 31, 2024
9f02f7d
make fast
sanchit-gandhi May 31, 2024
2d7102e
fix
sanchit-gandhi May 31, 2024
cd9ce9b
fix pos ids
sanchit-gandhi May 31, 2024
abad0b9
generate
sanchit-gandhi May 31, 2024
248be4d
fix sdpa
sanchit-gandhi May 31, 2024
9ba0da9
fix sdpa cache pos
sanchit-gandhi May 31, 2024
4ea437a
fix fa2
sanchit-gandhi May 31, 2024
92f94f8
clean fa2
sanchit-gandhi May 31, 2024
7ea0d16
integrate cache into generate
sanchit-gandhi May 31, 2024
b4478c1
make style
sanchit-gandhi May 31, 2024
b6cb739
copies
sanchit-gandhi May 31, 2024
57a219b
more copies
sanchit-gandhi May 31, 2024
2d91708
update eager
sanchit-gandhi Jun 5, 2024
11e79a9
update sdpa
sanchit-gandhi Jun 5, 2024
27d520b
update fa2
sanchit-gandhi Jun 5, 2024
f72224d
simplify
sanchit-gandhi Jun 5, 2024
fcf024a
use cache pos
sanchit-gandhi Jun 5, 2024
3f48947
always compute cross-cache for debug
sanchit-gandhi Jun 6, 2024
7a5a5eb
avoid recompiles
sanchit-gandhi Jun 7, 2024
2eba447
fix fix
sanchit-gandhi Jun 7, 2024
0bb8cb6
fix fix fix
sanchit-gandhi Jun 7, 2024
bfac769
more fix
sanchit-gandhi Jun 7, 2024
93c97c1
try encoder-decoder cache (too messy)
sanchit-gandhi Jun 10, 2024
05f12a3
revert encoder-decoder cache
sanchit-gandhi Jun 11, 2024
c1060df
check cross-attn cache
sanchit-gandhi Jun 13, 2024
6ee17cc
use enc-dec dataclass
sanchit-gandhi Jun 13, 2024
606417b
use richer enc-dec dataclass
sanchit-gandhi Jun 18, 2024
e13b38e
clean-up
sanchit-gandhi Jun 18, 2024
5a54a01
revert static cache changes
sanchit-gandhi Jun 18, 2024
3daa6ad
small fixes
sanchit-gandhi Jun 19, 2024
c244bcb
revert to cpu flag
sanchit-gandhi Jun 19, 2024
e0588df
fix copies
sanchit-gandhi Jun 19, 2024
5813aa3
Merge branch 'main' into whisper-static-kv
sanchit-gandhi Jun 19, 2024
b879c57
add static slow test
sanchit-gandhi Jun 19, 2024
86a46ed
past k/v docstring
sanchit-gandhi Jun 19, 2024
d209421
more docstrings
sanchit-gandhi Jun 19, 2024
0cba828
cache_position docstrings
sanchit-gandhi Jun 19, 2024
05e95dc
add to docs
sanchit-gandhi Jun 19, 2024
e5c8393
add enc-dec cache to docs
sanchit-gandhi Jun 19, 2024
959bae3
make style
sanchit-gandhi Jun 19, 2024
832e0b9
fix after rebase
sanchit-gandhi Jun 19, 2024
34d7873
fix beam
sanchit-gandhi Jun 19, 2024
a321cd6
style
sanchit-gandhi Jun 19, 2024
f825daf
fix generation strategies
sanchit-gandhi Jun 20, 2024
e5c33dc
fix most decoder-only tests
sanchit-gandhi Jun 20, 2024
216665a
style
sanchit-gandhi Jun 20, 2024
11a2791
skip test
sanchit-gandhi Jun 20, 2024
004e94d
more clean up
sanchit-gandhi Jun 20, 2024
23b7c22
small docstrings
sanchit-gandhi Jun 20, 2024
1a87b2b
Apply suggestions from code review
sanchit-gandhi Jun 20, 2024
d629233
add todo
sanchit-gandhi Jun 20, 2024
8c0ce1a
only crop self-attn
sanchit-gandhi Jun 20, 2024
0f8b34f
check cache in mixin
sanchit-gandhi Jun 20, 2024
2d09a41
Merge remote-tracking branch 'origin/whisper-static-kv' into whisper-…
sanchit-gandhi Jun 20, 2024
dba80a0
style
sanchit-gandhi Jun 20, 2024
df31a15
fix re-compile after rebase
sanchit-gandhi Jun 20, 2024
cadd3db
move `is_updated` logic to enc-dec wrapper
sanchit-gandhi Jun 21, 2024
7842215
revert back
sanchit-gandhi Jun 24, 2024
5fcfdea
Merge remote-tracking branch 'origin/whisper-static-kv' into whisper-…
sanchit-gandhi Jun 24, 2024
79db195
revert cache back
sanchit-gandhi Jun 24, 2024
6d3997f
finalise design
sanchit-gandhi Jun 24, 2024
6a377d1
fix
sanchit-gandhi Jun 24, 2024
0093919
fix fix
sanchit-gandhi Jun 24, 2024
ff57b4c
style
sanchit-gandhi Jun 25, 2024
2d4a2a8
Update src/transformers/cache_utils.py
sanchit-gandhi Jun 26, 2024
1860c31
deprecate
sanchit-gandhi Jun 26, 2024
d8e738f
Merge remote-tracking branch 'origin/whisper-static-kv' into whisper-…
sanchit-gandhi Jun 26, 2024
24183cb
updates
sanchit-gandhi Jun 26, 2024
2bad47c
final updates
sanchit-gandhi Jun 26, 2024
89823f3
Merge branch 'main' into whisper-static-kv
sanchit-gandhi Jun 27, 2024
f0f8130
style
sanchit-gandhi Jun 27, 2024
d8e8d64
Merge branch 'main' into whisper-static-kv
sanchit-gandhi Jul 2, 2024
e25c8e1
style
sanchit-gandhi Jul 2, 2024
Files changed
docs/source/en/internal/generation_utils.md (6 additions, 0 deletions)

@@ -391,6 +391,12 @@ A [`Constraint`] can be used to force the generation to include specific tokens
- get_seq_length
- reset

[[autodoc]] EncoderDecoderCache
- get_seq_length
- to_legacy_cache
- from_legacy_cache
- reset
- reorder_cache

## Watermark Utils

docs/source/en/model_doc/whisper.md (44 additions, 3 deletions)

@@ -52,16 +52,14 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained
>>> # Select an audio file and read it:
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> audio_sample = ds[0]["audio"]
>>> waveform = audio_sample["array"]
>>> sampling_rate = audio_sample["sampling_rate"]

>>> # Load the Whisper model in Hugging Face format:
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

>>> # Use the model and processor to transcribe the audio:
>>> input_features = processor(
... waveform, sampling_rate=sampling_rate, return_tensors="pt"
... audio_sample["array"], sampling_rate=audio_sample["sampling_rate"], return_tensors="pt"
... ).input_features

>>> # Generate token ids
@@ -74,6 +72,49 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
```

Whisper is compatible with the following optimisations:
- [PyTorch Scaled Dot Product Attention (SDPA)](../perf_infer_gpu_one#pytorch-scaled-dot-product-attention): flash attention and memory-efficient attention kernels. Enabled by default for `torch>=2.1.1`.
- [Flash Attention 2](../perf_infer_gpu_one#flashattention-2): improved implementation of flash attention through better parallelism and work partitioning.
- [torch.compile](../llm_optims#static-kv-cache-and-torchcompile): JIT-compile the forward pass to dispatch to efficient fused kernels.

As an example, the following code snippet enables SDPA and `torch.compile` for up to 5x faster inference:

```python
>>> import torch
>>> from datasets import load_dataset
>>> from transformers import WhisperProcessor, WhisperForConditionalGeneration

>>> # Select an audio file and read it:
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> audio_sample = ds[0]["audio"]

>>> # Load the Whisper model with SDPA attention
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", attn_implementation="sdpa")

>>> # Enable static cache and compile the forward pass
>>> model.generation_config.cache_implementation = "static"
>>> model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

>>> # Use the model and processor to transcribe the audio:
>>> input_features = processor(
... audio_sample["array"], sampling_rate=audio_sample["sampling_rate"], return_tensors="pt"
... ).input_features

>>> # Warm-up run: the first call compiles the forward pass (slow)
>>> _ = model.generate(input_features)

>>> # Generate token ids using compiled graph (fast!)
>>> predicted_ids = model.generate(input_features)

>>> # Decode token ids to text
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

>>> transcription[0]
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
```

For more details on each optimisation, refer to the documentation linked above.

## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Whisper. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
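The Whisper docs above also list Flash Attention 2 as a supported optimisation. As a minimal illustrative sketch (not part of this diff), it can be enabled at load time, assuming the `flash-attn` package is installed and a CUDA device is available, since Flash Attention 2 requires fp16/bf16 weights on GPU:

```python
import torch
from transformers import WhisperForConditionalGeneration

# Load Whisper with the Flash Attention 2 kernels (requires `pip install flash-attn`)
# in half precision on a CUDA device.
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny.en",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to("cuda")
```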
src/transformers/__init__.py (2 additions, 0 deletions)

@@ -1212,6 +1212,7 @@
"Cache",
"CacheConfig",
"DynamicCache",
"EncoderDecoderCache",
"HQQQuantizedCache",
"QuantizedCache",
"QuantizedCacheConfig",
@@ -5895,6 +5896,7 @@
Cache,
CacheConfig,
DynamicCache,
EncoderDecoderCache,
HQQQuantizedCache,
QuantizedCache,
QuantizedCacheConfig,
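With the new top-level export, `EncoderDecoderCache` can be imported directly from `transformers`. A minimal illustrative sketch (shapes and values are arbitrary, not taken from the PR):

```python
import torch
from transformers import DynamicCache, EncoderDecoderCache

# Wrap a self-attention cache and a cross-attention cache into a single object
past_key_values = EncoderDecoderCache(
    self_attention_cache=DynamicCache(),
    cross_attention_cache=DynamicCache(),
)

# Populate layer 0 of the self-attention cache with dummy key/value states
k = v = torch.randn(1, 8, 4, 64)  # (batch, num_heads, seq_len, head_dim)
past_key_values.self_attention_cache.update(k, v, layer_idx=0)

print(len(past_key_values))              # 1 cached layer
print(past_key_values.get_seq_length())  # tensor(4): 4 cached self-attention tokens
```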
src/transformers/cache_utils.py (158 additions, 2 deletions)

@@ -858,8 +858,12 @@ def update(
        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]

        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states
        if cache_position is None:
            k_out.copy_(key_states)
            v_out.copy_(value_states)
        else:
            k_out[:, :, cache_position] = key_states
            v_out[:, :, cache_position] = value_states

        return k_out, v_out
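As an illustrative sketch of the two update paths above, with plain tensors standing in for the static cache buffers (shapes are made up):

```python
import torch

# Static buffer with shape (batch, num_heads, max_cache_len, head_dim)
k_out = torch.zeros(1, 8, 16, 64)

# cache_position is None: the whole buffer is overwritten in place,
# e.g. for a cross-attention cache that is written once from the encoder output
encoder_keys = torch.randn(1, 8, 16, 64)
k_out.copy_(encoder_keys)

# cache_position given: only the new slots are written,
# the usual self-attention path during step-by-step decoding
cache_position = torch.tensor([3])
new_keys = torch.randn(1, 8, 1, 64)
k_out[:, :, cache_position] = new_keys
```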

@@ -971,6 +975,158 @@ def get_max_length(self) -> Optional[int]:
        # no matter how long the sentence is
        return None

    def reset(self):
        self.key_cache.zero_()
        self.value_cache.zero_()


class EncoderDecoderCache(Cache):
    """
    Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and
    cross-attention caches.
    """

    def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
        self.self_attention_cache = self_attention_cache
        self.cross_attention_cache = cross_attention_cache

        self.is_updated = {}
        for layer_idx in range(len(cross_attention_cache.key_cache)):
            self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length.
        """
        if layer_idx < len(self):
            return (
                self.self_attention_cache.key_cache[layer_idx],
                self.self_attention_cache.value_cache[layer_idx],
                self.cross_attention_cache.key_cache[layer_idx],
                self.cross_attention_cache.value_cache[layer_idx],
            )
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return len(self.self_attention_cache)

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
        legacy_cache = ()
        if len(self.cross_attention_cache) > 0:
            for self_attn, cross_attn in zip(
                self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache()
            ):
                legacy_cache += (self_attn + cross_attn,)
        else:
            legacy_cache = self.self_attention_cache.to_legacy_cache()
        return legacy_cache

    @classmethod
    def from_legacy_cache(
        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    ) -> "EncoderDecoderCache":
        """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
        cache = cls(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache())
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx][:2]
                cache.self_attention_cache.update(key_states, value_states, layer_idx)
                if len(past_key_values[layer_idx]) > 2:
                    key_states, value_states = past_key_values[layer_idx][2:]
                    cache.cross_attention_cache.update(key_states, value_states, layer_idx)
                    cache.is_updated[layer_idx] = True
        return cache

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.self_attention_cache.key_cache) <= layer_idx:
            return 0
        return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

    def reset(self):
        if hasattr(self.self_attention_cache, "reset"):
            self.self_attention_cache.reset()
        if hasattr(self.cross_attention_cache, "reset"):
            self.cross_attention_cache.reset()
        elif not hasattr(self.self_attention_cache, "reset") and not hasattr(self.cross_attention_cache, "reset"):
            raise ValueError(
                "Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should "
                "only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. "
                f"Got {self.self_attention_cache.__str__()} for the self attention cache and "
                f"{self.cross_attention_cache.__str__()} for the cross attention cache."
            )
        for layer_idx in self.is_updated:
            self.is_updated[layer_idx] = False

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        self.self_attention_cache.reorder_cache(beam_idx)
        self.cross_attention_cache.reorder_cache(beam_idx)

    def check_dynamic_cache(self, method: str):
        if not (
            isinstance(self.self_attention_cache, DynamicCache)
            and isinstance(self.cross_attention_cache, DynamicCache)
        ):
            raise ValueError(
                f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self "
                f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache."
            )

    # TODO(gante, sanchit-gandhi): move following functionality into `.generate`
    def crop(self, maximum_length: int):
"""Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""
self.check_dynamic_cache(self.crop.__name__)
self.self_attention_cache.crop(maximum_length)

def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
"""Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
`_split_model_inputs()` in `generation.utils`"""
self.check_dynamic_cache(self.batch_split.__name__)
self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)

out = []
for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
out.append(EncoderDecoderCache(self_attn, cross_attn))
return out

@classmethod
def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
"""This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
`generation.utils`"""
self_attention_cache = DynamicCache()
cross_attention_cache = DynamicCache()
for idx in range(len(splits[0])):
layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0)
layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0)
self_attention_cache.update(layer_keys, layer_values, idx)

layer_keys = torch.cat([current.cross_attention_cache.key_cache[idx] for current in splits], dim=0)
layer_values = torch.cat([current.cross_attention_cache.value_cache[idx] for current in splits], dim=0)
cross_attention_cache.update(layer_keys, layer_values, idx)
return cls(self_attention_cache, cross_attention_cache)

def batch_repeat_interleave(self, repeats: int):
"""Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
self.check_dynamic_cache(self.batch_repeat_interleave.__name__)
self.self_attention_cache.batch_repeat_interleave(repeats)
self.cross_attention_cache.batch_repeat_interleave(repeats)

def batch_select_indices(self, indices: torch.Tensor):
"""Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
self.check_dynamic_cache(self.batch_select_indices.__name__)
self.self_attention_cache.batch_select_indices(indices)
self.cross_attention_cache.batch_select_indices(indices)


class HybridCache(Cache):
    def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
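To illustrate the `EncoderDecoderCache` conversion helpers added above, here is a small self-contained sketch; the layer count, shapes, and values are made up for illustration:

```python
import torch
from transformers import EncoderDecoderCache

# Legacy format: one tuple per layer of (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value)
legacy = tuple(
    tuple(torch.randn(2, 8, 4, 64) for _ in range(4))  # (batch, num_heads, seq_len, head_dim)
    for _ in range(2)                                   # two decoder layers
)

cache = EncoderDecoderCache.from_legacy_cache(legacy)
print(len(cache))              # 2 layers
print(cache.get_seq_length())  # tensor(4): 4 cached self-attention tokens
print(cache.is_updated)        # {0: True, 1: True} -- cross-attention cache already filled

# Round-trip back to the legacy tuple format
roundtrip = cache.to_legacy_cache()
assert roundtrip[0][0].shape == legacy[0][0].shape

# DynamicCache-only helpers, e.g. splitting a batch of 2 into two caches of batch size 1
halves = cache.batch_split(full_batch_size=2, split_size=1)
print(len(halves))             # 2 EncoderDecoderCache objects
```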