[Batched Whisper] ValueError on input mel features #30740
Comments
Hey @kerem0comert! Thanks for the detailed reproducer. There are three paradigms for Whisper generation:
I explain these decoding strategies in full in this YouTube video: https://www.youtube.com/live/92xX-E2y4GQ?si=GBxyimNo9-4z1tx_&t=1919

The corresponding code for each of these generation strategies is detailed on the Distil-Whisper model card: https://huggingface.co/distil-whisper/distil-large-v3#transformers-usage

It's not apparent from the documentation how to use each of these strategies -> this is something we should definitely highlight better in the docs.

The problem you're facing is that you have short-form audio (<30 seconds), but are not padding/truncating it to 30 seconds before passing it to the model. This throws an error, since Whisper expects fixed inputs of 30 seconds. You can remedy this quickly by changing your args to the feature extractor:

    def predict_batch(
        self, df_batch: pd.DataFrame, column_to_transcribe_into: str
    ) -> pd.DataFrame:
        inputs: list[np.ndarray] = df_batch[COLUMN_AUDIO_DATA].tolist()
        processed_inputs: BatchFeature = self.processor(
            inputs,
            return_tensors="pt",
    -       truncation=False,
    -       padding="longest",
            return_attention_mask=True,
            sampling_rate=self.asr_params.sampling_rate,
        )
        results = self.model.generate(
            **(processed_inputs.to(self.device, torch.float16)),
        )
        df_batch[column_to_transcribe_into] = [str(r["text"]).strip() for r in results]
        return df_batch

Your code will now work for short-form generation, but not sequential long-form! To handle both automatically, I suggest you use the following code:

    def predict_batch(
        self, df_batch: pd.DataFrame, column_to_transcribe_into: str
    ) -> pd.DataFrame:
        inputs: list[np.ndarray] = df_batch[COLUMN_AUDIO_DATA].tolist()
        # assume we have long-form audios
        processed_inputs: BatchFeature = self.processor(
            inputs,
            return_tensors="pt",
            truncation=False,
            padding="longest",
            return_attention_mask=True,
            sampling_rate=self.asr_params.sampling_rate,
        )
        if processed_inputs.input_features.shape[-1] < 3000:
            # we in fact have short-form -> pre-process accordingly
            processed_inputs: BatchFeature = self.processor(
                inputs,
                return_tensors="pt",
                sampling_rate=self.asr_params.sampling_rate,
            )
        results = self.model.generate(
            **(processed_inputs.to(self.device, torch.float16)),
        )
        df_batch[column_to_transcribe_into] = [str(r["text"]).strip() for r in results]
        return df_batch

What we're doing is first assuming we have a long audio segment. If we compute the log-mel features and in fact find we have short-form audio, then we re-compute the log-mel with padding and truncation to 30 seconds. You can use a similar logic to pass the long-form kwargs
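To make the `3000` threshold above concrete, here is a minimal sketch of the arithmetic behind it. It assumes Whisper's usual preprocessing constants (16 kHz audio, a hop length of 160 samples, 30-second windows). The helper names `pad_or_trim`, `approx_num_frames` and `is_short_form` are hypothetical and written only for illustration; they are not part of `transformers`, and the frame count is an approximation of what the feature extractor returns.

```python
# Assumed constants: Whisper's usual preprocessing settings.
SAMPLING_RATE = 16_000                     # samples per second
HOP_LENGTH = 160                           # samples between consecutive mel frames
CHUNK_SECONDS = 30                         # fixed window length the model expects
N_SAMPLES = SAMPLING_RATE * CHUNK_SECONDS  # 480_000 samples per window
MAX_FRAMES = N_SAMPLES // HOP_LENGTH       # 3_000 mel frames per window


def pad_or_trim(samples: list[float], n_samples: int = N_SAMPLES) -> list[float]:
    """Hypothetical helper: zero-pad or cut a waveform to exactly n_samples."""
    if len(samples) >= n_samples:
        return samples[:n_samples]
    return samples + [0.0] * (n_samples - len(samples))


def approx_num_frames(num_samples: int) -> int:
    """Approximate number of log-mel frames produced for num_samples of audio."""
    return num_samples // HOP_LENGTH


def is_short_form(num_frames: int) -> bool:
    """Mirror the `input_features.shape[-1] < 3000` check from the snippet above."""
    return num_frames < MAX_FRAMES


five_seconds = [0.0] * (5 * SAMPLING_RATE)
print(approx_num_frames(len(five_seconds)))                # 500
print(is_short_form(approx_num_frames(len(five_seconds)))) # True
print(approx_num_frames(len(pad_or_trim(five_seconds))))   # 3000
```

A 5-second clip processed with `padding="longest"` and `truncation=False` yields only about 500 frames, which is why the second pass through the processor with its default padding is needed before calling `generate()`.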
Thanks for your very detailed response, this was indeed it! Just for completeness, I had to make a small change, since your version returned the tokens but I need the detokenized text version:

    def predict_batch(
        self, df_batch: pd.DataFrame, column_to_transcribe_into: str
    ) -> pd.DataFrame:
        inputs: list[np.ndarray] = df_batch[COLUMN_AUDIO_DATA].tolist()
        # assume we have long-form audios
        processed_inputs: BatchFeature = self.processor(
            inputs,
            return_tensors="pt",
            truncation=False,
            padding="longest",
            return_attention_mask=True,
            sampling_rate=self.asr_params.sampling_rate,
        )
        if processed_inputs.input_features.shape[-1] < 3000:
            # we in fact have short-form -> pre-process accordingly
            processed_inputs = self.processor(
                inputs,
                return_tensors="pt",
                sampling_rate=self.asr_params.sampling_rate,
            )
        result_tokens = self.model.generate(
            **(processed_inputs.to(self.device, torch.float16)),
        )
        results_texts = self.processor.batch_decode(
            result_tokens, skip_special_tokens=True
        )
        df_batch[column_to_transcribe_into] = [
            str(result_text).strip() for result_text in results_texts
        ]
        return df_batch
System Info
transformers version: 4.36.2
Who can help?
Hello,
I am using a finetuned Whisper model for transcription, and it works well. However, I get the warning:
and as such I would like to take advantage of batching, given that I run this on a GPU.
As such I implemented the code that I shared in the Reproduction section. I wanted to do it via this fork, but I see that in its README, it is recommended that I follow this instead. In my code snippet, self.model is an instance of: <class 'transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration'>
and I have two problems:
generate() call like so:
I get:
If I exclude the flags (which I do not mind) like so:
This time I get:
for which I could not find a satisfactory explanation yet, so any help would be much appreciated. Thanks!
Reproduction
Expected behavior
Transcribed results of Whisper, ideally with timestamps