
Whisper Sequential long-form decoding doesn't work with timestamps per token #28977

Open
antoinethl opened this issue Feb 12, 2024 · 3 comments
Labels: Feature request

antoinethl commented Feb 12, 2024

System Info

  • transformers version: 4.37.2
  • Platform: Linux-4.15.0-142-generic-x86_64-with-glibc2.23
  • Python version: 3.10.11
  • Huggingface_hub version: 0.20.3
  • Safetensors version: 0.4.2
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.0.1 (True)
  • Tensorflow version (GPU?): 2.12.0 (True)

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Following [Whisper] Add sequential longform decoding, there seems to be an issue when requesting token timestamps with the new long-form transcription handling.
With the model.generate() method, passing return_token_timestamps=True triggers the error. It also occurs with the pipeline object when setting return_timestamps="word".

Here is a simple example to reproduce the issue:

from transformers import WhisperForConditionalGeneration, WhisperProcessor, pipeline
import librosa

SR = 16000
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-medium")
processor = WhisperProcessor.from_pretrained("openai/whisper-medium")

file_path = "path_to_more_than_30_sec_audio"
audio, _ = librosa.load(file_path, sr=SR)

# Long-form transcription with model.generate()
input_features = processor(audio, 
                           sampling_rate=SR, 
                           return_tensors="pt", 
                           truncation=False, # don't truncate, so the whole audio is sent to the model
                           return_attention_mask=True, 
                           padding="longest")

predicted_ids = model.generate(**input_features,
                               return_token_timestamps=True)

# With pipeline
pipe = pipeline("automatic-speech-recognition", 
                model=model, 
                tokenizer=processor.tokenizer, 
                feature_extractor=processor.feature_extractor, 
                return_timestamps="word",
                return_language=True
                )

pipe(audio)

Traceback:

AttributeError                            Traceback (most recent call last)
Cell In[26], line 19
     11 # Long-form generation
     12 input_features = processor(audio, 
     13                            sampling_rate=16000, 
     14                            return_tensors="pt", 
     15                            truncation=False, 
     16                            return_attention_mask=True, 
     17                            padding="longest")
---> 19 predicted_ids = model.generate(**input_features,
     20                                return_token_timestamps=True)

File ~/miniconda3/envs/py310-fast/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py:641, in WhisperGenerationMixin.generate(self, input_features, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, return_timestamps, task, language, is_multilingual, prompt_ids, condition_on_prev_tokens, temperature, compression_ratio_threshold, logprob_threshold, no_speech_threshold, num_segment_frames, attention_mask, time_precision, return_token_timestamps, return_segments, return_dict_in_generate, **kwargs)
    638         proc.set_begin_index(decoder_input_ids.shape[-1])
    640 # 6.8 Run generate with fallback
--> 641 seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens = self.generate_with_fallback(
    642     segment_input=segment_input,
    643     decoder_input_ids=decoder_input_ids,
    644     cur_bsz=cur_bsz,
    645     batch_idx_map=batch_idx_map,
    646     seek=seek,
    647     num_segment_frames=num_segment_frames,
    648     max_frames=max_frames,
    649     temperatures=temperatures,
    650     generation_config=generation_config,
    651     logits_processor=logits_processor,
    652     stopping_criteria=stopping_criteria,
    653     prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
    654     synced_gpus=synced_gpus,
    655     return_token_timestamps=return_token_timestamps,
    656     do_condition_on_prev_tokens=do_condition_on_prev_tokens,
    657     kwargs=kwargs,
    658 )
    660 # 6.9 In every generated sequence, split by timestamp tokens and extract segments
    661 for i, seek_sequence in enumerate(seek_sequences):

File ~/miniconda3/envs/py310-fast/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py:739, in WhisperGenerationMixin.generate_with_fallback(self, segment_input, decoder_input_ids, cur_bsz, batch_idx_map, seek, num_segment_frames, max_frames, temperatures, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, return_token_timestamps, do_condition_on_prev_tokens, kwargs)
    727 seek_outputs = super().generate(
    728     segment_input,
    729     generation_config,
   (...)
    735     **kwargs,
    736 )
    738 # post-process sequence tokens and outputs to be in list form
--> 739 sequence_tokens, seek_outputs = self._postprocess_outputs(
    740     seek_outputs, return_token_timestamps, generation_config
    741 )
    743 # remove all previously passed decoder input ids
    744 seek_sequences = sequence_tokens[:, decoder_input_ids.shape[-1] :]

File ~/miniconda3/envs/py310-fast/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py:825, in WhisperGenerationMixin._postprocess_outputs(self, seek_outputs, return_token_timestamps, generation_config)
    822         return values[batch_idx].cpu()
    824     sequence_tokens = seek_outputs["sequences"]
--> 825     seek_outputs = [
    826         {k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()}
    827         for i in range(sequence_tokens.shape[0])
    828     ]
    829 else:
    830     sequence_tokens = seek_outputs

File ~/miniconda3/envs/py310-fast/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py:826, in <listcomp>(.0)
    822         return values[batch_idx].cpu()
    824     sequence_tokens = seek_outputs["sequences"]
    825     seek_outputs = [
--> 826         {k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()}
    827         for i in range(sequence_tokens.shape[0])
    828     ]
    829 else:
    830     sequence_tokens = seek_outputs

File ~/miniconda3/envs/py310-fast/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py:826, in <dictcomp>(.0)
    822         return values[batch_idx].cpu()
    824     sequence_tokens = seek_outputs["sequences"]
    825     seek_outputs = [
--> 826         {k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()}
    827         for i in range(sequence_tokens.shape[0])
    828     ]
    829 else:
    830     sequence_tokens = seek_outputs

File ~/miniconda3/envs/py310-fast/lib/python3.10/site-packages/transformers/models/whisper/generation_whisper.py:822, in WhisperGenerationMixin._postprocess_outputs.<locals>.split_by_batch_index(values, key, batch_idx)
    819 if key == "past_key_values":
    820     # we don't save `past_key_values` as this is too costly
    821     return None
--> 822 return values[batch_idx].cpu()

AttributeError: 'tuple' object has no attribute 'cpu'
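
For context, a minimal sketch of the failure mode (the dict contents below are illustrative, not copied from the library): split_by_batch_index indexes every entry of the generate() output by batch index and calls .cpu() on the result, but some entries come back as nested tuples of tensors rather than tensors, so the indexed value is itself a tuple with no .cpu() method.

import torch

# Illustrative stand-in for the outputs dict handed to _postprocess_outputs:
# "sequences" is a tensor, while attention-style outputs (requested internally
# for token timestamps) are nested tuples, one tuple per generated token,
# each holding one tensor per layer.
seek_outputs = {
    "sequences": torch.ones(2, 5, dtype=torch.long),
    "cross_attentions": tuple(
        tuple(torch.rand(2, 8, 1, 1500) for _layer in range(4))
        for _token in range(5)
    ),
}

for key, values in seek_outputs.items():
    item = values[0]                 # index by batch, as split_by_batch_index does
    print(key, type(item).__name__)  # sequences -> Tensor, cross_attentions -> tuple
    # item.cpu()                     # AttributeError: 'tuple' object has no attribute 'cpu'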

It works fine if you don't ask for timestamps per token.
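
For comparison, the same generate() call goes through the long-form path without error when only segment-level timestamps are requested (a sketch reusing input_features from the snippet above):

# Segment-level timestamps work with sequential long-form decoding;
# only return_token_timestamps=True hits the AttributeError above.
predicted_ids = model.generate(**input_features, return_timestamps=True)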

Expected behavior

The model should be able to return timestamps per token when working with long audio, after #27492
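
For reference, a sketch of the access pattern that already works for short (under 30 s) inputs and that long-form decoding should support as well; the token_timestamps field is what return_token_timestamps=True produces in the short-form case:

# Hypothetical once long-form support lands, mirroring the short-form output:
out = model.generate(**input_features, return_token_timestamps=True)
print(out.sequences)         # generated token ids
print(out.token_timestamps)  # one timestamp (in seconds) per generated token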

amyeroberts (Collaborator) commented

cc @sanchit-gandhi @ylacombe

patrickvonplaten (Contributor) commented

This is more of a feature request than a bug, I'd say. Happy to have a look with #28984

amyeroberts added the Feature request label on Feb 12, 2024
Gheovgos commented

Same issue here, using transformers with whisper-tiny.
