
Add static cache support for Whisper #30707

Closed
mobicham opened this issue May 8, 2024 · 9 comments
Labels: Audio, Feature request

Comments

@mobicham
Contributor

mobicham commented May 8, 2024

Feature request

It would be great to have static cache support for Whisper to make it faster with torch.compile. Currently, the generate() function doesn't support cache_implementation="static" for Whisper.
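
For concreteness, a rough sketch of the usage being requested (the checkpoint name and the dummy audio below are placeholders for illustration; this call is exactly what does not work for Whisper at the time of this issue):

import numpy as np
from transformers import AutoProcessor, WhisperForConditionalGeneration

# Dummy one-second clip at 16 kHz; a real waveform would come from a dataset or file.
audio = np.zeros(16_000, dtype=np.float32)

processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

inputs = processor(audio, sampling_rate=16_000, return_tensors="pt")

# The request: let generate() accept a static KV cache so the decoder can be
# compiled with fixed shapes. Not supported for Whisper when this issue was opened.
generated_ids = model.generate(
    inputs.input_features,
    cache_implementation="static",
    max_new_tokens=64,
)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))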

Motivation

Static cache with torch.compile can make generation much faster.

Your contribution

Static cache is already supported for LLMs, and we see a great speed-up there.

@amyeroberts added the Feature request and Audio labels May 8, 2024
@amyeroberts
Collaborator

cc @sanchit-gandhi

@mobicham mobicham changed the title Add support for static cache with Whisper Add support for static cache for Whisper May 9, 2024
@mobicham mobicham changed the title Add support for static cache for Whisper Add static cache support for Whisper May 9, 2024
@huseinzol05
Contributor

Let me try; I think I can make it work. We just need to patch https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py#L313 like the Llama model and pass cache_position, and it should be OK.
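
For anyone following along, a minimal sketch of the idea (simplified shapes and a hypothetical helper name, not the actual WhisperAttention code): with a pre-allocated static cache, the new key/value states are written in place at the indices given by cache_position instead of being concatenated onto a growing tensor.

import torch

def update_static_cache(key_cache, value_cache, key_states, value_states, cache_position):
    # key_cache / value_cache: pre-allocated [batch, max_cache_len, hidden_dim]
    # key_states / value_states: new states for the current step(s), [batch, seq_len, hidden_dim]
    # cache_position: 1-D tensor of the absolute positions for those steps
    key_cache.index_copy_(1, cache_position, key_states)
    value_cache.index_copy_(1, cache_position, value_states)
    return key_cache, value_cache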

@mobicham
Contributor Author

@huseinzol05 great, thanks! I think you also need to make sure the model supports initializing the static cache via _setup_cache:

from transformers import StaticCache

# batch_size and max_cache_len must be fixed up front for the static cache
model._setup_cache(StaticCache, batch_size, max_cache_len=max_cache_length)

@huseinzol05
Contributor

huseinzol05 commented May 11, 2024

I got hit by pytorch/pytorch#123592 at https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py#L230, but the static cache is already working without torch.compile locally; using arange should solve the problem.
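
A rough illustration of the arange workaround (a hypothetical helper, not the actual line in modeling_whisper.py): instead of slicing the positional-embedding table with Python integers derived from the past length, index it with a position tensor, which torch.compile can trace with a static cache.

import torch

def positional_embedding(weight, cache_position):
    # cache_position is an arange of absolute decoding positions, built once per step, e.g.
    # cache_position = torch.arange(past_length, past_length + seq_len, device=weight.device)
    # Indexing with it avoids the data-dependent slicing
    # weight[past_length : past_length + seq_len] that breaks under torch.compile.
    return weight[cache_position]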


@huseinzol05
Contributor

huseinzol05 commented May 11, 2024

Anything dynamic is not possible; feeding position_ids solved the problem, just like cache_position. I will push the initial version later so you can verify; the speed is good.
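
Roughly, at the call site that would look like the following (a hypothetical decode step; the cache_position and position_ids arguments follow this comment's proposal and may not match the final PR's signature):

import torch

def decode_step(model, input_features, decoder_input_ids, static_cache, past_length):
    # Precompute the position tensors outside the model, so nothing dynamic
    # has to be derived inside the compiled decoder.
    seq_len = decoder_input_ids.shape[1]
    cache_position = torch.arange(
        past_length, past_length + seq_len, device=decoder_input_ids.device
    )
    return model(
        input_features=input_features,
        decoder_input_ids=decoder_input_ids,
        cache_position=cache_position,
        position_ids=cache_position.unsqueeze(0),  # as proposed here; may differ in the PR
        past_key_values=static_cache,
        use_cache=True,
    )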

@mobicham
Contributor Author

mobicham commented May 11, 2024

Great 👍! But that arange works well in Llama with fullgraph torch.compile.

@huseinzol05
Contributor

huseinzol05 commented May 11, 2024

#30760

The compiled static cache is able to achieve 186.26 it/s, while non-compiled got 150.20 it/s.

@ArthurZucker
Collaborator

Closing as this is fixed by #31166 and #31772.
