
[Whisper] Add sequential longform decoding #27492

Merged

merged 39 commits into main from add_whisper_seq_gen on Nov 22, 2023

Conversation

patrickvonplaten
Contributor

@patrickvonplaten patrickvonplaten commented Nov 14, 2023

What does this PR do?

This PR adds the long-form transcription as originally proposed in the Whisper codebase: https://github.com/openai/whisper and in the paper: https://cdn.openai.com/papers/whisper.pdf

To better understand long-form transcription, please have a look at Section 4.5: Strategies for Reliable Long-form Transcription of the paper.

Before this PR, transformers only had "chunked" long-form transcription, which trades accuracy for speed (see table below). This PR adds the best-performing long-form transcription algorithm to Transformers.

Usage:

Long-form transcription can now be used easily with the pipeline object by simply passing long-form audio. Previously, long-form audio was truncated to just 30 seconds; this PR makes sure that long audio is not cut when passed to the pipeline:

```python
from datasets import load_dataset
from datasets import Audio
import numpy as np
from transformers import WhisperForConditionalGeneration, AutoProcessor, pipeline
import torch

processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", torch_dtype=torch.float16)
model = model.to("cuda")

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    torch_dtype=torch.float16,
    device="cuda:0",
)

ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
ds = ds.cast_column("audio", Audio(sampling_rate=16000))
audio = ds[:8]["audio"]

result = pipe(audio)
```

The pipeline is great for "easy-to-set-up" code but lacks customization and readability. For example, the pipeline currently does not allow running the model with batch sizes > 1 and instead runs each audio one by one, which is very suboptimal speed-wise. To use long-form transcription with batch size > 1, you can use the following snippet:

```python
from datasets import load_dataset
from datasets import Audio
import numpy as np
from transformers import WhisperForConditionalGeneration, AutoProcessor, pipeline
import torch

processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", torch_dtype=torch.float16)
model = model.to("cuda")

ds = load_dataset("distil-whisper/meanwhile", "default")["test"]
ds = ds.cast_column("audio", Audio(sampling_rate=16000))
audio = [x["array"] for x in ds[:8]["audio"]]

inputs = processor(audio, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True, sampling_rate=16_000)
inputs = inputs.to("cuda", torch.float16)

result = model.generate(**inputs, return_timestamps=True)
decoded = processor.batch_decode(result, skip_special_tokens=True)

print(decoded)
```

Docs have been added to give examples for both short- and long-form transcription:

[Screenshot from 2023-11-20 16-09-23: docs example]

But I don't think this is enough for people to notice this method. In my opinion, we should create much better guides for Whisper (this will be done in a follow-up PR).

Credits

IMPORTANT: Most of the code added in this PR was copied & tweaked from the original Whisper code: https://github.com/openai/whisper/blob/main/whisper/transcribe.py . Therefore 90% of the credit for this PR goes to @jongwook as the original author of the transcribe.py code.

Why copy the code?! We originally weren't planning on integrating the full long-form transcription algorithm into transformers, but a couple of reasons have now pushed us to add it.

Next steps:

When looking at all long-form generation strategies:

[Screenshot from 2023-11-20 16-14-00: overview of long-form generation strategies]

Transformers now has support for the following:

  • beam search
  • initial timestamp constraint

In a follow-up PR we will add: "temperature fallback", "voice activity detection", and "previous text conditioning".
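
For context, "temperature fallback" (Section 4.5 of the paper) retries a 30-second window at progressively higher temperature whenever the decoded text looks degenerate. Below is a minimal sketch of the idea only: `decode_fn` and `DecodeResult` are hypothetical stand-ins (not a transformers API), and the thresholds mirror the openai/whisper defaults.

```python
from dataclasses import dataclass

# Sketch of the paper's temperature fallback; `decode_fn` stands in for a single
# 30 s decoding pass and is NOT a real transformers API.
@dataclass
class DecodeResult:
    text: str
    avg_logprob: float        # mean log-probability of the generated tokens
    compression_ratio: float  # gzip compression ratio of the text (high = repetitive)

def decode_with_fallback(decode_fn, segment,
                         temperatures=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
                         compression_ratio_threshold=2.4,
                         logprob_threshold=-1.0):
    result = None
    for temperature in temperatures:
        result = decode_fn(segment, temperature=temperature)
        too_repetitive = result.compression_ratio > compression_ratio_threshold
        too_unlikely = result.avg_logprob < logprob_threshold
        if not (too_repetitive or too_unlikely):
            return result  # accept the first decode that passes both checks
    return result  # every temperature failed; keep the last attempt anyway
```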

Results:

Note that for "chunked transformers" the numbers are struck through because the original results from the distil-whisper paper seem to have been slightly incorrect; re-running the eval gives better results.

Here are the results (WER) for openai/whisper-tiny.en:

| Dataset | chunked transformers | this PR | openai/whisper repo |
|---|---|---|---|
| Earnings 21 | ~~17.5~~ 17.28 | 16.13 | 15.28 |
| Earnings 22 | ~~24.1~~ 23.69 | 22.24 | 20.95 |
| Meanwhile | ~~16.4~~ 14.55 | 13.27 | 12.90 |
| Rev | ~~17.4~~ 17.46 | 16.10 | 14.77 |

Here are the results for openai/whisper-small.en:

| Dataset | chunked transformers | this PR | openai/whisper repo |
|---|---|---|---|
| Earnings 21 | ~~15.1~~ 12.61 | 11.51 | 11.0 |
| Earnings 22 | ~~20.6~~ 16.31 | 15.08 | 14.98 |
| Meanwhile | ~~8.7~~ 7.06 | 6.52 | 6.49 |
| Rev | ~~14.5~~ 14.02 | 12.28 | 11.93 |

Here are the results for openai/whisper-large-v2:

| Dataset | chunked transformers | this PR | openai/whisper repo |
|---|---|---|---|
| Earnings 21 | ~~11.8~~ 11.8 | 10.66 | 9.7 |
| Earnings 22 | ~~15.1~~ 15.0 | 13.93 | 12.6 |
| Meanwhile | ~~6.3~~ 5.2 | 5.14 | 5.1 |
| Rev | ~~13.6~~ 13.5 | 11.47 | 11.3 |

Update:

It seems like the numbers we measured in the distil-whisper paper for chunked long-form were a bit off. Re-running them gives the following:

Here are the results for openai/whisper-tiny.en:

| Dataset | chunked transformers | chunked transformers (this PR) |
|---|---|---|
| Earnings 21 | 17.28 | 17.29 (+0.01) |
| Earnings 22 | 23.69 | 23.7 (+0.02) |
| Meanwhile | 14.55 | 14.58 (+0.03) |
| Rev | 17.46 | 17.44 (-0.02) |

=> The new algo is on avg. 0.01 WER abs. points worse, which means it's effectively identical.

Here are the results for openai/whisper-small.en:

| Dataset | chunked transformers | chunked transformers (this PR) |
|---|---|---|
| Earnings 21 | 12.61 | 12.63 (+0.02) |
| Earnings 22 | 16.31 | 16.31 (+0.0) |
| Meanwhile | 7.06 | 7.04 (-0.02) |
| Rev | 14.02 | 14.04 (+0.02) |

=> The new algo is on avg. 0.005 WER abs. points worse, which means it's effectively identical.

Here are the results for openai/whisper-large-v2:

| Dataset | chunked transformers | chunked transformers (this PR) |
|---|---|---|
| Earnings 21 | 11.75 | 11.72 (-0.03) |
| Earnings 22 | 15.0 | 14.97 (-0.03) |
| Meanwhile | 5.19 | 5.16 (-0.03) |
| Rev | 13.53 | 13.53 (+0.0) |

=> The new algo is on avg. 0.0225 WER abs. points better, which means it's identical (or a tiny, tiny bit better).

Batch size > 1

The code now fully works for batch size > 1 (I made sure that results on the four datasets are within +/- 0.1% WER). When using batch size = 8, there is a 4x speed-up for large-v2, a 2x speed-up for small, and a 1.5x speed-up for tiny. The bigger the model, the larger the speed-up!

One should definitely use larger batch sizes when doing long-form timestamp prediction!
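
If you want to sanity-check the speed-up on your own hardware, here is a rough timing sketch. It assumes `processor`, `model`, and `audio` are set up exactly as in the batched snippet above; the exact ratios will of course vary by GPU.

```python
import time

import torch

def time_generate(batch):
    # Prepare a batch of long-form audio exactly as in the snippet above.
    inputs = processor(
        batch,
        return_tensors="pt",
        truncation=False,
        padding="longest",
        return_attention_mask=True,
        sampling_rate=16_000,
    )
    inputs = inputs.to("cuda", torch.float16)
    torch.cuda.synchronize()
    start = time.perf_counter()
    model.generate(**inputs, return_timestamps=True)
    torch.cuda.synchronize()
    return time.perf_counter() - start

one_by_one = sum(time_generate([clip]) for clip in audio)  # batch size 1, one file at a time
batched = time_generate(audio)                             # batch size 8, all files at once
print(f"speed-up from batching: {one_by_one / batched:.1f}x")
```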

@patrickvonplaten patrickvonplaten marked this pull request as draft November 14, 2023 14:31
@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Nov 14, 2023

The documentation is not available anymore as the PR was closed or merged.

@Vaibhavs10
Member

QQ: @patrickvonplaten - wouldn't concatenating and passing the whole audio as input result in exploding GPU VRAM usage?

@ArthurZucker
Collaborator

ArthurZucker commented Nov 20, 2023

Hey @patrickvonplaten, would you mind adding the performance for bigger models? The worse the model is at predicting timestamps, the worse the performance of the chunked algorithm. I remember observing very little loss for large models! (Just as an FYI!)

@patrickvonplaten
Contributor Author

> QQ: @patrickvonplaten - wouldn't concatenating and passing the whole audio as input result in exploding GPU VRAM usage?

The audio is chunked on the fly (there is a while loop now).
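
To make this concrete, here is a minimal sketch of that seek loop (illustrative only, not the actual `generate` code). `decode_window` is a hypothetical stand-in for a single 30-second decoding pass that also reports how many input frames its last predicted timestamp covered.

```python
# Illustrative sketch only -- not the real implementation.
def sequential_transcribe(decode_window, input_features, window_frames=3000):
    seek, texts = 0, []
    total_frames = input_features.shape[-1]  # log-mel frames of the full audio
    while seek < total_frames:
        # Cut out (at most) one 30 s window starting at the current position.
        window = input_features[..., seek : seek + window_frames]
        text, frames_consumed = decode_window(window)
        texts.append(text)
        # Advance by the audio the model actually transcribed, so the next
        # window starts exactly where the last predicted timestamp ended.
        seek += max(frames_consumed, 1)
    return " ".join(texts)
```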

@patrickvonplaten
Contributor Author

> Hey @patrickvonplaten, would you mind adding the performance for bigger models? The worse the model is at predicting timestamps, the worse the performance of the chunked algorithm. I remember observing very little loss for large models!

Sure, I can run it for larger models as well. I'm not 100% sure why this matters, though - if we see such strong gains for smaller models, we should add it nevertheless.

Member

@gante gante left a comment


Awesome feature! Looking forward to the other upgrades mentioned in section 4.5 of the paper 🔥 "Previous text conditioning" can probably benefit from the newly added ability to return and reuse past_key_values.

Added a few minor nits

src/transformers/generation/logits_process.py (outdated review thread)
>>> # transcribe audio to ids
>>> generated_ids = model.generate(**inputs)

>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)
Member

I agree that printing the entire generated output here would be a bad idea, but we could add something like

>>> len(transcription[0])

That way, our doctests would fail if we start having numerical problems with the longform decoding. WDYT?

elif not is_shortform:
if return_timestamps is False:
raise ValueError(
"You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which "
Member

A few lines here have more than 120 characters :)

This particular sentence also has 2x "which", rewriting might improve readability

Member

@Vaibhavs10 Vaibhavs10 left a comment


There could be something wrong with the way I'm initialising the pipeline, but on my single-file benchmark it just truncates the output to 30 sec.

Repro: https://github.com/Vaibhavs10/scratchpad/blob/main/conditional_long_form_generation_whisper.ipynb
Note: I'm not defining chunk size as it isn't defined in the example snippet up top.

It works as intended with model + generate tho! 🚀

More of an overall usage remark from a developer's PoV:
How do we clarify whether the transcription strategy used is chunked or conditional? Can we allow developers to choose? Supporting this via pipeline is important IMO.

Edit: To clarify, one of the biggest use cases for pipeline is to throw in an audio file in whichever format and then get the transcription for it.
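
For readers following along, the examples in this thread suggest that `chunk_length_s` is the switch between the two strategies. A small sketch of how that would look from the pipeline side — "long_audio.wav" is a hypothetical file, and this is a reading of the PR's examples rather than a documented contract:

```python
from transformers import pipeline

asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")

# Chunked long-form: the audio is split into fixed windows with strides and merged.
chunked = asr("long_audio.wav", chunk_length_s=30, stride_length_s=5, return_timestamps=True)

# Sequential long-form (this PR): no chunk_length_s, so the full audio is handed
# to generate(), which walks over it window by window.
sequential = asr("long_audio.wav", return_timestamps=True)
```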

Collaborator

@ArthurZucker ArthurZucker left a comment


Thanks for the hard work!
Would just split the 6.2 part into a separate function, to avoid falling into the trap of having a huge generate. Same for the small functions used for decoding.

@@ -1517,29 +1518,33 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
```
"""

def __init__(self, generate_config): # support for the kwargs
def __init__(
Collaborator

Sorry if I am not understanding the statement 😅 No, it was used for chunked processing because the merging algorithm heavily relies on timestamps, and it also produces timestamps. See this line, which always adds the processor.

Placing a breakpoint in the call of this class with this:

```python
from transformers import pipeline
from datasets import load_dataset
import numpy as np

transcriptor = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
audio_file = ds[0]["audio"]["array"]
long_audio = np.concatenate([ds[0]["audio"]["array"]] * 40)
out = transcriptor(long_audio, return_timestamps=True, chunk_length_s=30, stride_length_s=5)
```

(This is, for me, "timestamp" chunked transcription.)

src/transformers/generation/logits_process.py (resolved review thread)
src/transformers/generation/logits_process.py (resolved review thread)
Comment on lines +1816 to +1818
return_segments (`bool`, *optional*, defaults to `False`):
Whether to additionally return a list of all segments. Note that this option can only be enabled
when doing long-form transcription.
Collaborator

By segment, do you mean segments of audio? (Not sure I understand.) This is only valid for sequential long-form, no?

Comment on lines +1821 to +1823
time_precision (`int`, *optional*, defaults to 0.02):
The duration of output token in seconds. *E.g.* 0.02 means that a generated token on average accounts
for 20 ms.
Collaborator

We have this in the pipeline:

```python
time_precision = self.feature_extractor.chunk_length / self.model.config.max_source_positions
```

which should be replaced I guess?

Contributor Author

We don't have access to the feature extractor here, sadly.
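
For reference, the pipeline formula and the hard-coded default agree under the standard Whisper configuration (`chunk_length = 30` seconds, `max_source_positions = 1500`), which is presumably why 0.02 is a safe default:

```python
# Sanity check of the default time_precision for the standard Whisper config.
chunk_length = 30            # seconds of audio per window (feature extractor default)
max_source_positions = 1500  # encoder output frames per 30 s window

time_precision = chunk_length / max_source_positions
print(time_precision)  # 0.02 -> each timestamp token covers 20 ms
```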

**kwargs,
)
if generation_config.return_timestamps is True:
last_forced_decoder_ids = (
Collaborator

Not sure I understand; if the last token is no_timestamps, we should still have the first 2-3 forced tokens (specifically one set for the task), unless it's done afterwards.

cur_bsz = prev_bsz = batch_size

# 6.2 Transcribe audio until we reach the end of all input audios
while (seek < max_frames).any():
Collaborator

Let's separate this into another function?
(Kind of like the merging function.)

Contributor Author

Factored out one part into a function

@patrickvonplaten
Contributor Author

> There could be something wrong with the way I'm initialising the pipeline, but on my single-file benchmark it just truncates the output to 30 sec.
>
> Repro: https://github.com/Vaibhavs10/scratchpad/blob/main/conditional_long_form_generation_whisper.ipynb Note: I'm not defining chunk size as it isn't defined in the example snippet up top.
>
> It works as intended with model + generate tho! 🚀
>
> More of an overall usage remark from a developer's PoV: How do we clarify whether the transcription strategy used is chunked or conditional? Can we allow developers to choose? Supporting this via pipeline is important IMO.
>
> Edit: To clarify, one of the biggest use cases for pipeline is to throw in an audio file in whichever format and then get the transcription for it.

Nice catch, there was a typo. Added a test for the pipeline now as well.

@patrickvonplaten
Contributor Author

patrickvonplaten commented Nov 22, 2023

I've adapted the tests to match the new timestamp logits processor, double-checked that it gives the same WER results on long-form, and applied the suggestions.

The failing test is a time-out that is not related to this PR - merging!

@patrickvonplaten patrickvonplaten merged commit 4151fbb into main Nov 22, 2023
2 of 3 checks passed
@patrickvonplaten patrickvonplaten deleted the add_whisper_seq_gen branch November 22, 2023 12:27
@antoinethl

Thanks for implementing this. It seems that long-form decoding doesn't work with return_token_timestamps=True for model.generate() (nor return_timestamps="word" for pipeline()) in v4.37.2.

It fails at line 822 of whisper/generation_whisper.py in the private method _postprocess_outputs with the error "AttributeError: 'tuple' object has no attribute 'cpu'".

@amyeroberts
Collaborator

Hi @antoinethl, could you open a new issue, detailing the error encountered (including full traceback) and a minimal reproducer?

@antoinethl

> Hi @antoinethl, could you open a new issue, detailing the error encountered (including full traceback) and a minimal reproducer?

Hi, just opened 2 issues with a traceback and a short example.
