[Whisper] Add sequential longform decoding #27492
Conversation
The documentation is not available anymore as the PR was closed or merged.
QQ: @patrickvonplaten - wouldn't concatenating and passing the whole audio as input result in exploding GPU VRAM usage?
Hey @patrickvonplaten, would you mind adding the performance for bigger models? The worse the model is at predicting timestamps, the worse the performance of the chunked algorithm. I remember observing very little loss for large models! (Just as an FYI!)
The audio is chunked on the fly (there is a while loop now).
Sure, I can run it for larger models as well. I'm not 100% sure why this matters, though - if we see such strong gains for smaller models, we should add it nevertheless.
Awesome feature! Looking forward to the other upgrades mentioned in section 4.5 of the paper 🔥 The "previous text conditioning" probably can benefit from the newly added ability to return and reuse past_key_values
Added a few minor nits
```python
>>> # transcribe audio to ids
>>> generated_ids = model.generate(**inputs)

>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
```
I agree that printing the entire generated output here would be a bad idea, but we could add something like
```python
>>> len(transcription[0])
```
That way, our doctests would fail if we start having numerical problems with the longform decoding. WDYT?
```python
elif not is_shortform:
    if return_timestamps is False:
        raise ValueError(
            "You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
```
A few lines here have more than 120 characters :)
This particular sentence also has 2x "which", rewriting might improve readability
There could be something wrong with the way I'm initialising the pipeline, but on my single-file benchmark it just truncates the output to 30 seconds.
Repro: https://github.com/Vaibhavs10/scratchpad/blob/main/conditional_long_form_generation_whisper.ipynb
Note: I'm not defining a chunk size, as it isn't defined in the example snippet up top.
It works as intended with model + generate though! 🚀
More of an overall usage remark from a developer's PoV:
How do we clarify whether the transcription strategy used is chunked or conditional? Can we allow developers to choose? Supporting this via the pipeline is important IMO.
Edit: To clarify, one of the biggest use cases for the pipeline is to throw an audio file at it in whichever format and get the transcription back.
Thanks for the hard work!
I would just split 6.2 into a separate function, to avoid falling into the trap of having a huge generate. Same for the small functions used for decoding.
````diff
@@ -1517,29 +1518,33 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
     ```
     """

-    def __init__(self, generate_config):  # support for the kwargs
+    def __init__(
````
Sorry if I am not understanding the statement 😅 No, it was used for chunked processing because the merging algorithm heavily relies on timestamps, and also produces timestamps. See this line, which always adds the processor.
Placing a breakpoint in the call of this class with this:
```python
from transformers import pipeline
from datasets import load_dataset
import numpy as np

transcriptor = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
audio_file = ds[0]["audio"]["array"]

# repeat the short sample 40 times to obtain a long-form input
long_audio = np.concatenate([ds[0]["audio"]["array"]] * 40)

out = transcriptor(long_audio, return_timestamps=True, chunk_length_s=30, stride_length_s=5)
```
(This is, for me, "timestamp" chunked transcription.)
```
return_segments (`bool`, *optional*, defaults to `False`):
    Whether to additionally return a list of all segments. Note that this option can only be enabled
    when doing long-form transcription.
```
By segment, do you mean segments of audio? (Not sure I understand.) This is only valid for sequential long-form, no?
```
time_precision (`int`, *optional*, defaults to 0.02):
    The duration of output token in seconds. *E.g.* 0.02 means that a generated token on average accounts
    for 20 ms.
```
We have this in the pipeline:
```python
time_precision = self.feature_extractor.chunk_length / self.model.config.max_source_positions
```
which should be replaced, I guess?
We don't have access to the feature extractor here, sadly.
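One possible workaround, sketched here as an assumption rather than what the PR actually does: compute the value where the feature extractor is available and forward it through the documented `time_precision` argument of `generate`:

```python
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# generate() cannot reach the feature extractor, so the caller derives the value itself
time_precision = processor.feature_extractor.chunk_length / model.config.max_source_positions  # 30 / 1500 = 0.02 s

# ...and then passes it along with the usual long-form arguments (inputs omitted here):
# generated_ids = model.generate(**inputs, return_timestamps=True, time_precision=time_precision)
```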
```python
    **kwargs,
)
if generation_config.return_timestamps is True:
    last_forced_decoder_ids = (
```
Not sure I understand: if the last token is `no_timestamps`, we should still have the first 2-3 forced tokens (specifically the one set for the task), unless it's done after?
```python
cur_bsz = prev_bsz = batch_size

# 6.2 Transcribe audio until we reach the end of all input audios
while (seek < max_frames).any():
```
Let's separate this into another function? (Kind of like the merging function.)
Factored out one part into a function
Nice catch, there was a typo. Added a test for the pipeline now as well.
I've adapted the tests to match the new timestamp logits processor. I've double-checked that the new timestamp logits processor gives the same WER results on long-form, and applied the suggestions. The failing test is a time-out which is not related to this PR - merging!
Thanks for implementing. It seems that longform decoding doesn't work with return_token_timestamps=True for model.generate() (nor return_timestamps="word" for pipeline()) in v4.37.2. It fails at line 822 of whisper/generation_whisper.py in the private method _postprocess_outputs with the error "AttributeError: 'tuple' object has no attribute 'cpu'".
Hi @antoinethl, could you open a new issue, detailing the error encountered (including the full traceback) and a minimal reproducer?
Hi, just opened 2 issues with traceback and short example. |
What does this PR do?
This PR adds long-form transcription as originally proposed in the Whisper codebase: https://github.com/openai/whisper and in the paper: https://cdn.openai.com/papers/whisper.pdf
To better understand long-form transcription, please have a look at Section 4.5: Strategies for Reliable Long-form Transcription of the paper.
Before this PR, Transformers only had "chunked" long-form transcription, which trades accuracy for speed (see the table below). In this PR, we add the best-performing long-form transcription algorithm to Transformers.
Usage:
One can now easily use long-form transcription with the pipeline object by simply passing long-form audio. Previously, long-form audio was truncated to just 30 seconds. This PR makes sure that long audio is not cut when passed to the pipeline:
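For illustration, a minimal sketch of the kind of pipeline call meant here (the model choice and the audio file name are placeholder assumptions, not taken from the PR):

```python
from transformers import pipeline

# sequential long-form transcription: just pass audio longer than 30 seconds,
# the pipeline no longer truncates it to the first 30-second window
asr = pipeline("automatic-speech-recognition", model="openai/whisper-large-v2")

output = asr("long_audio.wav", return_timestamps=True)
print(output["text"])
```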
The pipeline is great for "easy-to-set-up" code but lacks customization and readability. For example, the pipeline currently does not allow running the model with batch sizes > 1 and instead runs each audio one by one, which is very suboptimal in terms of speed. To use long-form transcription with batch size > 1, you can use the following snippet:
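A hedged sketch of what such a batched call could look like (the dataset name, device, and dtype below are illustrative assumptions; the key point is passing the features with `truncation=False`, `padding="longest"`, and an attention mask so the audios are not cut at 30 seconds):

```python
import torch
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v2", torch_dtype=torch.float16
).to("cuda")

# any long-form audio works here; this dataset is an assumption for the example
ds = load_dataset("distil-whisper/meanwhile", split="test")
raw_audios = [sample["array"] for sample in ds[:8]["audio"]]

# do NOT truncate at 30 seconds; pad the batch to the longest audio instead
inputs = processor(
    raw_audios,
    sampling_rate=16_000,
    return_tensors="pt",
    truncation=False,
    padding="longest",
    return_attention_mask=True,
).to("cuda", torch.float16)

# long-form generation seeks through each audio sequentially, batched across audios
generated_ids = model.generate(**inputs, return_timestamps=True)
transcriptions = processor.batch_decode(generated_ids, skip_special_tokens=True)
print(transcriptions[0])
```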
Docs have been added to give examples for both short- and long-form transcription. But I don't think this is enough for people to notice this method. In my opinion, we should create much better guides for Whisper (this will be done in a follow-up PR).
Credits
IMPORTANT: Most of the code added in this PR was copied & tweaked from the original Whisper code: https://github.com/openai/whisper/blob/main/whisper/transcribe.py. Therefore, 90% of the credit for this PR goes to @jongwook as the original author of the `transcribe.py` code.
Why copy the code?!: We originally weren't planning on integrating the full long-form transcription algorithm into `transformers`, but a couple of reasons have now forced us to add it.
Next steps:
When looking at all long-form generation strategies, Transformers now has support for chunked long-form transcription and the sequential long-form transcription added in this PR.
In a follow-up PR we will add: "temperature fallback", "voice activity detection", and "previous text conditioning".
Results:
Note that for "chunked transformers" the numbers are crossed through, because the original results from the Whisper paper seem to have been slightly incorrect. Re-running the eval gives better results.

Here the results for `openai/whisper-tiny.en` (chunked / sequential WER on the four long-form datasets): 17.5 / 17.28, 24.1 / 23.69, 16.4 / 14.55, 17.4 / 17.46

Here the results for `openai/whisper-small.en`: 15.1 / 12.61, 20.6 / 16.31, 8.7 / 7.06, 14.5 / 14.02

Here the results for `openai/whisper-large-v2`: 11.8 / 11.8, 15.1 / 15.0, 6.3 / 5.2, 13.6 / 13.5

Update:
It seems like the numbers we measured in the distil-whisper paper for chunked long-form are a bit off. Re-running them gives the following:

Here the results for `openai/whisper-tiny.en`:
=> The new algo is on avg. 0.01 WER abs. points worse, which means it's identical.

Here the results for `openai/whisper-small.en`:
=> The new algo is on avg. 0.005 WER abs. points worse, which means it's identical.

Here the results for `openai/whisper-large-v2`:
=> The new algo is on avg. 0225 WER abs. points better, which means it's identical (or a tiny, tiny bit better).
Batch size > 1
The code now fully functions for batch size > 1 (I made sure that results on the four datasets are within +/- 0.1% WER). When using batch size = 8, there is a 4x speed-up for large-v2, a 2x speed-up for small (and a 1.5x speed-up for tiny). The bigger the model, the larger the speed-up!
One should definitely use larger batch sizes when doing long-form timestamp prediction!