Cannot batch them ({'num_frames', 'input_features', 'is_last'} != {'input_features', 'is_last'}) #33415

Closed
minmie opened this issue Sep 11, 2024 · 11 comments
Labels
Audio Core: Pipeline Internals of the library; Pipeline.

Comments

@minmie

minmie commented Sep 11, 2024

I have the same problem.
When I use the pipeline for inference with batch_size=1, everything is fine. However, the error occurs when inferring with batch_size>1.

transformers: 4.44.0
torch: 2.1.2
model: whisper-large-v3-zh-punct
audio_data: wav data

import time

from transformers import pipeline, WhisperForConditionalGeneration, AutoModelForSpeechSeq2Seq, AutoProcessor
import os
import torch

DATA_DIR = r'C:\Users\chenjq2\Desktop\wav格式录音'
# DATA_DIR = r'./test_data'
LANGUAGE = 'zh'
TASK = 'transcribe'
files = os.listdir(DATA_DIR)

paths = []
for name in files:
    paths.append(os.path.join(DATA_DIR, name))
MODEL_ID = r"C:\Users\chenjq2\Desktop\tmp\models--BELLE-2--Belle-whisper-large-v3-zh-punct\snapshots\f81f1ac2f123f118094a7baa69e532eab375600e"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    MODEL_ID, torch_dtype=torch_dtype, low_cpu_mem_usage=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(MODEL_ID, language=LANGUAGE, task=TASK)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
)
t1 = time.time()
print(pipe(paths, batch_size=4))
print(f'time cost:{time.time()-t1}')

Error message:

E:\program\anaconda3\envs\nlp\lib\site-packages\torch\_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  return self.fget.__get__(instance, owner)()
Traceback (most recent call last):
  File "E:\program\anaconda3\envs\nlp\lib\site-packages\torch\utils\data\dataloader.py", line 630, in __next__
    data = self._next_data()
  File "E:\program\anaconda3\envs\nlp\lib\site-packages\torch\utils\data\dataloader.py", line 674, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "E:\program\anaconda3\envs\nlp\lib\site-packages\torch\utils\data\_utils\fetch.py", line 42, in fetch
    return self.collate_fn(data)
  File "E:\program\anaconda3\envs\nlp\lib\site-packages\transformers\pipelines\base.py", line 175, in inner
    raise ValueError(
ValueError: The elements of the batch contain different keys. Cannot batch them ({'num_frames', 'input_features', 'is_last'} != {'input_features', 'is_last'})

The difference is due to this:

The output will have an additional num_frames field if the code runs through block 2, but not if it runs through block 1.

XXX\transformers\pipelines\automatic_speech_recognition.py
(screenshot of the relevant preprocessing code)
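
Paraphrasing that code based on the description above (an illustrative sketch only, not the actual transformers source; fake_preprocess and its values are made up):

# Illustrative sketch only -- NOT the actual transformers source.
# The short-form path ("block 2") attaches an extra "num_frames" key that the
# long-form path ("block 1") does not, so a mixed batch ends up with
# mismatched dict keys at the DataLoader collate step.
def fake_preprocess(duration_s, sampling_rate=16000):
    if duration_s > 30:   # "block 1": long-form audio
        return {"input_features": ..., "is_last": True}
    else:                 # "block 2": short-form audio
        return {"input_features": ..., "num_frames": int(duration_s * sampling_rate), "is_last": True}

batch = [fake_preprocess(32), fake_preprocess(5)]
print({frozenset(item) for item in batch})  # two different key sets -> "Cannot batch them"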

Could anyone tell me how to solve it?

Originally posted by @minmie in #33404 (comment)

@minmie closed this as not planned Sep 11, 2024
@minmie reopened this Sep 11, 2024
@minmie
Author

minmie commented Sep 11, 2024

pipelines: @Rocketknight1
speech models: @ylacombe, @eustlb

@ylacombe
Contributor

Hey @minmie, thanks for opening this issue and for providing a snippet!

Note that while providing a snippet is already great, what's even better for us is a self-sufficient snippet, where the maintainer can reproduce the issue without having to change the code. Here, for example, it would have been great to change the model id to something available on the Hub (e.g. openai/whisper-small) and the data to something available on the web!

This seems like the same issue as #33404; let me look into how to correct it.

@ylacombe
Contributor

I can't seem to reproduce with this code, using the main branch:

import time

from transformers import pipeline, WhisperForConditionalGeneration, AutoModelForSpeechSeq2Seq, AutoProcessor
import torch

TASK = 'transcribe'

paths = ["https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac",
         "https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/2.flac",
         "https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/3.flac",
         "https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/4.flac"]

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model = AutoModelForSpeechSeq2Seq.from_pretrained(
     "openai/whisper-small", torch_dtype=torch_dtype, low_cpu_mem_usage=True
)
model.to(device)

processor = AutoProcessor.from_pretrained("openai/whisper-small", language="en", task=TASK)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
)
t1 = time.time()
print(pipe(paths, batch_size=4))
print(f'time cost:{time.time()-t1}')

Can you verify that your issue still holds using the main branch? And if so, could you try to make it reproducible? Many thanks!

@minmie
Author

minmie commented Sep 13, 2024

Hi @ylacombe, thanks for your help.
I have tested the code you provided using the main branch, and it does indeed run normally, because all of the audio files are short and go through block 2. In my example, the data included a long audio file (32 seconds), which caused the code to run through block 1, resulting in a missing num_frames field.

When I use your code with my own audio files, the error still exists.

(screenshot of the same error)

Below is a self-sufficient snippet; I updated the MODEL_ID and attached the audio files:
wav_data.zip

import time

from transformers import pipeline, WhisperForConditionalGeneration, WhisperProcessor
import os
import torch

DATA_DIR = r'C:\Users\chenjq2\Desktop\wav_data'
LANGUAGE = 'zh'
TASK = 'transcribe'
files = os.listdir(DATA_DIR)

paths = []
for name in files:
    paths.append(os.path.join(DATA_DIR, name))
MODEL_ID = "BELLE-2/Belle-whisper-large-v3-zh-punct"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = WhisperForConditionalGeneration.from_pretrained(
    MODEL_ID, torch_dtype=torch_dtype, low_cpu_mem_usage=True
)
model.to(device)

processor = WhisperProcessor.from_pretrained(MODEL_ID, language=LANGUAGE, task=TASK)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
)
t1 = time.time()
print(pipe(paths, batch_size=2))
print(f'time cost:{time.time()-t1}')

@ylacombe
Contributor

Hey @minmie, thanks for providing this snippet and the audio samples! I was able to reproduce the issue.

In short, the bug happens when the same batch contains samples that are both longer and shorter than 30s, the fixed length on which Whisper was trained. You can already avoid this bug by passing chunk_length_s=30.0 to the pipeline.
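
For concreteness, a minimal sketch of that workaround (the checkpoint and file names here are placeholders, not the ones from the issue):

import torch
from transformers import pipeline

# Sketch of the workaround: ask the ASR pipeline to chunk every input into
# 30 s windows so short and long audio follow the same path.
# "openai/whisper-small" and the file names below are placeholders.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
pipe = pipeline("automatic-speech-recognition", model="openai/whisper-small", device=device)
paths = ["short_clip.wav", "long_clip_32s.wav"]
print(pipe(paths, batch_size=2, chunk_length_s=30.0))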

Making long- and short-form transcription compatible within the pipeline is probably possible, but it goes against the pipeline philosophy IMO, since:

  • The long-form generation of Whisper already chunks the inputs and processes them either in batch or sequentially
  • The ASR pipeline is supposed to do the same with the batch_size and the chunk_length_s arguments. How can we reconcile everything, and how can we deal with contradictory arguments?
  • To use long-form generation, you also have to pass return_token_timestamps to both the feature extractor and model.generate. It's not done by default in the pipeline, which adds to the steps the user must take (choose which model, which parameters, understand how to do long-form, batching, etc.)
  • In long-form, the input is not truncated; it is with short-form. If we decide to reconcile long and short form in the pipeline, we'd have to pad the short input to the long input size. In that case, we would do unnecessary operations (one generation pass on at least one empty input)

I think we should instead make it very clear that it's better to pass chunk_length_s=30.0 when using the ASR pipeline with Whisper (short audio is padded to 30s by default). Any other more advanced use cases should probably require using Whisper in the usual way.

I'd love to have your take on this @amyeroberts and @Rocketknight1 !

@Rocketknight1
Member

Hmmn, blending the two would be quite difficult. However, supporting long-form generation would be nice.

Maybe we can split inputs in the batch, process long and short inputs separately, and then merge results at the end? The pipeline would need updates to handle it, but I don't think it's intuitive for users either that the pipeline breaks when an input is more than 30 seconds, because this is presumably a common use-case.
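
As a rough, user-side illustration of that idea (a sketch only: it assumes soundfile is installed and that pipe and paths are defined as in the snippets above):

import soundfile as sf

# Rough sketch of "split by length, process separately, merge results at the
# end". Assumes `pipe` and `paths` exist as in the snippets above and that
# every path is a local audio file soundfile can read.
short_clips, long_clips = [], []
for i, p in enumerate(paths):
    (long_clips if sf.info(p).duration > 30.0 else short_clips).append((i, p))

results = [None] * len(paths)
for group in (short_clips, long_clips):
    if group:
        outputs = pipe([p for _, p in group], batch_size=len(group))
        for (i, _), out in zip(group, outputs):
            results[i] = out
print(results)  # transcriptions back in the original input order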

@amyeroberts
Collaborator

The long-form generation of Whisper already chunks the inputs and processes them either in batch or sequentially

The original sin here seems to be having this dual behaviour within Whisper's generation logic. If a user is using a pipeline, they shouldn't care which model is being used or the length of their audio; the whole principle is that one can easily swap out any checkpoint and get predictions in a single line, i.e. they shouldn't have to pass in different arguments to make things work.

In long-form, the input is not truncated; it is with short-form. If we decide to reconcile long and short form in the pipeline, we'd have to pad the short input to the long input size. In that case, we would do unnecessary operations (one generation pass on at least one empty input)

How common a problem would you expect this to be? i.e. what's the general variance in audio clip length you'd expect in a batch and what would the time/memory cost be? In general, I don't think we should try to over-optimize in the case of pipelines; they're not designed for heavy-duty training / inference, but rather as a simple wrapper that enables predictions in just a few lines of code. If it's going to add a lot of time to get the predictions, then let's be wary, but a second or two might be OK. Would be good to have the opinion of @Rocketknight1 here though :)

I think we should instead make it very clear that it's better to pass chunk_length_s=30.0 when using the ASR pipeline with Whisper (short audio is padded to 30s by default).

At the moment this isn't clear, so we should think about how to improve it. The simplest solution, I think, would be to detect the model type and audio length, check whether chunk_length_s is set, and raise a warning if the values would result in this bug. However, as the pipelines should be model-agnostic, this is a patch at best.


Maybe we can split inputs in the batch, process long and short inputs separately, and then merge results at the end?

It would be nice to either:

  • Have clearly separated long form and short form audio pipelines OR
  • Have an ASR pipeline which handles both long and short

My main questions are:

  • is "long" vs. "short" a standardized definition? i.e. does everyone agree >30s is long form?
  • is this only a whisper issue?

@ylacombe
Contributor

ylacombe commented Sep 17, 2024

Thanks for your insights @amyeroberts and @Rocketknight1!

is "long" vs. "short" a standardized definition? i.e. does everyone agree >30s is long form?

No, it's exclusive to Whisper.
Whisper's generation code automatically chunks audio every 30 seconds. Note that this is because its encoder was trained on 30s-only segments, and it's not something we can change about the model.
Other ASR models don't chunk, and the pipeline's chunk_length_s is especially useful for those.

is this only a whisper issue?

Yes it is.

The original sin here seems to be having this dual behaviour within Whisper's generation logic.

I agree. Long-form generation was added to better stick with Whisper's original implementation and to give better performance. I believe it made sense, though, since the model is extremely popular!

How common a problem would you expect this to be? i.e. what's the general variance in audio clip length you'd expect in a batch and what would the time/memory cost be?

I believe this would be quite common, but I don't think we can assume a typical audio clip length. That said, other ASR models don't support long-form generation by default, and you can only use them with long-form audio if you pass chunk_length_s, so users have to be careful about it whatever model is used.


I think we agree that the pipeline should be usable whatever the audio length. The issue arises here because both the pipeline and Whisper's generate are able to deal with long-form audio by chunking it. What I propose is:

  1. make the ASR pipeline compatible with Whisper usage when bs>1 and the input audio mixes short and long samples (at the cost of padding and having non-agnostic code)
  2. log a warning in that case, advising to either use bs=1, use chunk_length_s, or use the model directly.

At the moment, the pipeline can't pad input_features, so I'll probably have to modify those lines a bit. Do you think that's okay, @Rocketknight1?
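
To give an idea of the kind of padding meant here (purely illustrative, not the actual pipeline change; shapes are made up):

import torch
import torch.nn.functional as F

# Illustrative only: pad two Whisper-style mel feature tensors
# (batch, num_mel_bins, num_frames) to a common number of frames along the
# time axis so they could be stacked into a single batch.
short_feats = torch.randn(1, 80, 3000)   # 30 s of short-form features
long_feats = torch.randn(1, 80, 3200)    # >30 s of long-form features
target = max(short_feats.shape[-1], long_feats.shape[-1])
padded = [F.pad(x, (0, target - x.shape[-1])) for x in (short_feats, long_feats)]
print(torch.cat(padded, dim=0).shape)    # torch.Size([2, 80, 3200])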

@amyeroberts
Collaborator

That said, other ASR models don't support long-form generation by default, and you can only use them with long-form audio if you pass chunk_length_s, so users have to be careful about it whatever model is used.

OK, that's good to know. There's still the question of how long counts as long-form, but if 30s is pretty standard, we can try to flag, based on the inputs when the pipeline is called, that chunk_length_s needs to be specified, and make sure this is documented somewhere, ideally in a code demo.

I'm happy with 1. + 2. if @Rocketknight1 is :)

@Rocketknight1
Member

Sure, works for me!

@minmie
Author

minmie commented Sep 26, 2024

@ylacombe Thanks, it works!

@minmie closed this as completed Sep 26, 2024