Cannot batch them ({'num_frames', 'input_features', 'is_last'} != {'input_features', 'is_last'}) #33415
Comments
pipelines: @Rocketknight1 |
Hey @minmie, thanks for opening this issue and for providing a snippet! Note that while providing a snippet is already great, what's even better for us is a self-sufficient snippet, where the maintainer can reproduce the issue without having to change the code. Here, for example, it would have been great to change the model id to something available on the Hub (e.g. openai/whisper-small) and the data to something available on the web! This seems like the same issue as #33404; let me look at how to correct this. |
I can't seem to reproduce with this code, using the main branch:

import time
from transformers import pipeline, WhisperForConditionalGeneration, AutoModelForSpeechSeq2Seq, AutoProcessor
import torch
TASK = 'transcribe'
paths = ["https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac",
"https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/2.flac",
"https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/3.flac",
"https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/4.flac"]
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForSpeechSeq2Seq.from_pretrained(
"openai/whisper-small", torch_dtype=torch_dtype, low_cpu_mem_usage=True
)
model.to(device)
processor = AutoProcessor.from_pretrained( "openai/whisper-small", language="en", task=TASK)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
torch_dtype=torch_dtype,
device=device,
)
t1 = time.time()
print(pipe(paths, batch_size=4))
print(f'time cost:{time.time()-t1}')

Can you verify that your issue still holds using the main branch? And if so, try to make it reproducible? Many thanks! |
Hi @ylacombe, thanks for your help. I used your code with my own audio files, and the error still exists. Below is a self-sufficient snippet; I updated the MODEL_ID and attached the audio files.

import time
from transformers import pipeline, WhisperForConditionalGeneration, WhisperProcessor
import os
import torch
DATA_DIR = r'C:\Users\chenjq2\Desktop\wav_data'
LANGUAGE = 'zh'
TASK = 'transcribe'
files = os.listdir(DATA_DIR)
paths = []
for name in files:
paths.append(os.path.join(DATA_DIR, name))
MODEL_ID = "BELLE-2/Belle-whisper-large-v3-zh-punct"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = WhisperForConditionalGeneration.from_pretrained(
MODEL_ID, torch_dtype=torch_dtype, low_cpu_mem_usage=True
)
model.to(device)
processor = WhisperProcessor.from_pretrained(MODEL_ID, language=LANGUAGE, task=TASK)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
torch_dtype=torch_dtype,
device=device,
)
t1 = time.time()
print(pipe(paths, batch_size=2))
print(f'time cost:{time.time()-t1}') |
Hey @minmie, thanks for providing this snippet and the audio samples! I was able to reproduce the issue. In short, the bug happens when the same batch contains samples that are both longer and shorter than 30s, which is the fixed length on which Whisper was trained. You can already avoid this bug by passing
Making long- and short-form transcription compatible with the pipeline could probably be possible, but it goes against the pipeline philosophy IMO, since:
I think we should instead make it very clear that it's better to pass
I'd love to have your take on this, @amyeroberts and @Rocketknight1! |
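(The exact argument referenced in the comment above was lost in the quoted text. Purely as an illustration, one documented way to route every clip through the same chunked long-form path is the pipeline's chunk_length_s parameter; the file names below are hypothetical.)

import torch
from transformers import pipeline

device = "cuda:0" if torch.cuda.is_available() else "cpu"

pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=device,
)

# chunk_length_s switches to the chunked (sliding-window) algorithm, so clips
# shorter and longer than 30s go through the same preprocessing path and can
# be batched together.
paths = ["short_clip.wav", "long_clip.wav"]  # hypothetical local files
print(pipe(paths, batch_size=2, chunk_length_s=30))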
Hmmn, blending the two would be quite difficult. However, supporting long-form generation would be nice. Maybe we can split inputs in the batch, process long and short inputs separately, and then merge results at the end? The pipeline would need updates to handle it, but I don't think it's intuitive for users either that the pipeline breaks when an input is more than 30 seconds, because this is presumably a common use-case. |
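(A rough sketch of the split-then-merge idea above, done outside the pipeline rather than inside it. soundfile is used here only to read clip durations; the 30s threshold and the helper name are assumptions, not part of transformers.)

import soundfile as sf

def transcribe_mixed_lengths(pipe, paths, batch_size=2, threshold_s=30.0):
    # Partition clips by duration so each pipeline call sees a homogeneous batch.
    short_clips, long_clips = [], []
    for idx, path in enumerate(paths):
        duration = sf.info(path).duration  # clip length in seconds
        (short_clips if duration <= threshold_s else long_clips).append((idx, path))

    # Run the two groups separately, then restore the original input order.
    results = [None] * len(paths)
    for group in (short_clips, long_clips):
        if not group:
            continue
        indices, files = zip(*group)
        outputs = pipe(list(files), batch_size=batch_size)
        for idx, out in zip(indices, outputs):
            results[idx] = out
    return results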
The original sin here seems to be having this dual behaviour within Whisper's generation logic. If a user is using a pipeline, they shouldn't care which model is being used or the length of their audio; the whole principle is that one can easily swap out any checkpoint and get predictions in a single line, i.e. they shouldn't have to pass in different arguments to make things work.
How common a problem would you expect this to be? I.e. what's the general variance in audio clip length you'd expect in a batch, and what would the time/memory cost be? In general, I don't think we should try to over-optimize in the case of pipelines; they're not designed for heavy-duty training / inference but rather as a simple wrapper that enables predictions in just a few lines of code. If it's going to add a lot of time to get the predictions, then let's be wary, but a second or two might be OK. Would be good to have the opinion of @Rocketknight1 here though :)
At the moment this isn't clear, so we should think about how to improve it. The simplest solution, I think, would be to detect the model type, the audio length, and if
It would be nice to either:
My main questions are:
|
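(As an illustration of the detection idea in the comment above, i.e. inspect the audio lengths and warn before batching, here is a sketch; the constant, the helper name, and the use of soundfile are assumptions, not part of the transformers API.)

import warnings
import soundfile as sf

WHISPER_WINDOW_S = 30.0  # Whisper's fixed input length during training

def warn_on_mixed_lengths(paths):
    # Look at the inputs before calling the pipeline and flag batches that mix
    # short- and long-form clips, which is the case that currently fails.
    durations = [sf.info(p).duration for p in paths]
    has_short = any(d <= WHISPER_WINDOW_S for d in durations)
    has_long = any(d > WHISPER_WINDOW_S for d in durations)
    if has_short and has_long:
        warnings.warn(
            "This batch mixes clips shorter and longer than 30s; consider "
            "transcribing them separately or using chunked long-form decoding."
        )
    return durations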
Thanks for your insights @amyeroberts and @Rocketknight1!
No, it's exclusively for Whisper.
Yes it is.
I agree there. Long-form generation was added to better stick with Whisper's original implementation and to give better performance. I believe it made sense though, since the model is extremely popular!
I believe this would be quite common, but I don't think we can expect a typical audio clip length. That said, other ASR models don't support long-form generation by default, and you can only use those with long-form if you're passing
I think we agree that the pipeline should be usable whatever the audios' length. Here, it happens because both the pipeline and the
At the moment, the pipeline can't pad |
OK - that's good to know. There's still a question of how long is long-form, but if 30s is pretty standard, we can try and flag when the pipeline is called based on the inputs that
I'm happy with 1. + 2. if @Rocketknight1 is :) |
Sure, works for me! |
@ylacombe Thanks, it works! |
I have the same problem.
When I use the pipeline for inference with batch_size=1, everything is OK. However, the error occurs when inferring with batch_size>1.
transformers: 4.44.0
torch: 2.1.2
model: whisper-large-v3-zh-punct
audio_data: wav data
error msg:
The difference is due to this:
It will have an additional field num_frames if the code runs to block 2, but not if it runs to block 1.
XXX\transformers\pipelines\automatic_speech_recognition.py
Could anyone tell me how to solve it?
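(To make the mismatch in the error message concrete, here is a minimal sketch, not the pipeline's actual code, of why items with different key sets cannot be batched together; the placeholder values are made up.)

# One item comes from the short-form path, the other from the long-form path,
# which attaches an extra "num_frames" entry.
short_item = {"input_features": [0.0], "is_last": True}
long_item = {"num_frames": 3000, "input_features": [0.0], "is_last": False}

batch = [short_item, long_item]
key_sets = {frozenset(item.keys()) for item in batch}
if len(key_sets) > 1:
    # This mirrors the "Cannot batch them (... != ...)" error in the issue title.
    raise ValueError(f"Cannot batch them ({set(long_item)} != {set(short_item)})")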
Originally posted by @minmie in #33404 (comment)