Whisper Sequential long-form decoding doesn't work when forcing task #28978

Closed

antoinethl opened this issue Feb 12, 2024 · 3 comments

@antoinethl

System Info

  • transformers version: 4.37.2
  • Platform: Linux-4.15.0-142-generic-x86_64-with-glibc2.23
  • Python version: 3.10.11
  • Huggingface_hub version: 0.20.3
  • Safetensors version: 0.4.2
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.0.1 (True)
  • Tensorflow version (GPU?): 2.12.0 (True)

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Similar to #28977, the long-form decoding recently added in [Whisper] Add sequential longform decoding seems to have issues with some parameters. Here, it's the task specification that seems problematic.

It is also linked to another issue: transformers' Whisper implementation seems to force the output language to English. Tested with French, German, and Dutch audio, the result is always the same: Whisper translates the audio into English when the task isn't set (and the language isn't set either, obviously).

Here is the discussion about the issue: https://huggingface.co/openai/whisper-large-v3/discussions/71

So while trying to bypass this English-only output issue, I tried, as mentioned in the discussion, setting task="transcribe" to force the model to transcribe the audio. But when working with long audio and the new long-form decoding implementation, the issue occurred.

Here is a minimal example to reproduce the issue:

from transformers import WhisperForConditionalGeneration, WhisperProcessor
import librosa

SR = 16000
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium")
processor = WhisperProcessor.from_pretrained("openai/whisper-medium")

file_path = "path_to_more_than_30_sec_audio"
audio, _ = librosa.load(file_path, sr=SR)

# Long-form transcription with model.generate()
input_features = processor(audio,
                           sampling_rate=SR,
                           return_tensors="pt",
                           truncation=False,  # don't truncate, so the whole audio is sent to the model
                           return_attention_mask=True,
                           padding="longest")

predicted_ids = model.generate(**input_features,
                               task="transcribe")  # if you remove this parameter, it works as expected

Traceback

TypeError                                 Traceback (most recent call last)
Cell In[39], line 19
     11 # Long-form generation
     12 input_features = processor(audio, 
     13                            sampling_rate=16000, 
     14                            return_tensors="pt", 
     15                            truncation=False, 
     16                            return_attention_mask=True, 
     17                            padding="longest")
---> 19 predicted_ids = model.generate(**input_features, task="transcribe")

File ~/miniconda3/envs/py310-fast/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py:614, in WhisperGenerationMixin.generate(self, input_features, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, return_timestamps, task, language, is_multilingual, prompt_ids, condition_on_prev_tokens, temperature, compression_ratio_threshold, logprob_threshold, no_speech_threshold, num_segment_frames, attention_mask, time_precision, return_token_timestamps, return_segments, return_dict_in_generate, **kwargs)
    610 # 6.5 prepare decoder input ids
    611 suppress_tokens = _get_attr_from_logit_processors(
    612     logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens"
    613 )
--> 614 decoder_input_ids, kwargs = self._prepare_decoder_input_ids(
    615     cur_bsz=cur_bsz,
    616     init_tokens=init_tokens,
    617     current_segments=current_segments,
    618     batch_idx_map=batch_idx_map,
    619     do_condition_on_prev_tokens=do_condition_on_prev_tokens,
    620     generation_config=generation_config,
    621     config=self.config,
    622     device=segment_input.device,
    623     suppress_tokens=suppress_tokens,
    624     kwargs=kwargs,
    625 )
    627 # 6.6 set max new tokens or max length
    628 kwargs = self._set_max_new_tokens_and_length(
    629     config=self.config,
    630     decoder_input_ids=decoder_input_ids,
    631     generation_config=generation_config,
    632     kwargs=kwargs,
    633 )

File ~/miniconda3/envs/py310-fast/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py:1322, in WhisperGenerationMixin._prepare_decoder_input_ids(cur_bsz, init_tokens, current_segments, batch_idx_map, do_condition_on_prev_tokens, generation_config, config, device, suppress_tokens, kwargs)
   1319 cut_off_length = config.max_target_positions // 2 - 1
   1321 one_tensor = torch.ones((cur_bsz, 1), device=device, dtype=torch.long)
-> 1322 decoder_input_ids = torch.cat([t * one_tensor for t in init_tokens], dim=-1)
   1324 prev_start_of_text = getattr(generation_config, "prev_sot_token_id", None)
   1325 if prev_start_of_text is None:

File ~/miniconda3/envs/py310-fast/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py:1322, in <listcomp>(.0)
   1319 cut_off_length = config.max_target_positions // 2 - 1
   1321 one_tensor = torch.ones((cur_bsz, 1), device=device, dtype=torch.long)
-> 1322 decoder_input_ids = torch.cat([t * one_tensor for t in init_tokens], dim=-1)
   1324 prev_start_of_text = getattr(generation_config, "prev_sot_token_id", None)
   1325 if prev_start_of_text is None:

TypeError: unsupported operand type(s) for *: 'NoneType' and 'Tensor'
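The failure mode is easy to see in isolation: when the task is forced, one of the entries in init_tokens apparently ends up as None, and multiplying None by a tensor raises exactly this error. A standalone illustration (50258 is Whisper's <|startoftranscript|> id; the list contents here are hypothetical):

import torch

one_tensor = torch.ones((1, 1), dtype=torch.long)
init_tokens = [50258, None]  # hypothetical: a None slips in when the task is forced
decoder_input_ids = torch.cat([t * one_tensor for t in init_tokens], dim=-1)
# -> TypeError: unsupported operand type(s) for *: 'NoneType' and 'Tensor'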

Expected behavior

The model should be able to work with the task parameter when processing long audio, after #27492.

@amyeroberts
Collaborator

cc @sanchit-gandhi @ylacombe

@patrickvonplaten
Contributor

Hey @antoinethl,

Thanks for reporting the bug! Note that the bug is already solved on "main" with #28687. Could you try to install transformers as follows:

!pip install git+https://github.com/huggingface/transformers

and run your code snippet again?
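(A quick way to confirm the dev install took effect, purely illustrative:)

import transformers

# After installing from git, the version should report a dev build,
# e.g. something like "4.38.0.dev0".
print(transformers.__version__)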

@antoinethl
Author

Hi, thanks for the quick reply. It does seem fixed with PR #28687; it works after updating to the 4.38 dev version.
