[Whisper] Add sequential longform decoding #27492

Merged: 39 commits, Nov 22, 2023
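For context, the feature this PR adds can be exercised roughly as follows. This is a minimal sketch based on the `generate` API as merged; the dataset choice, sample count, and variable names are purely illustrative:

``` python
# A minimal sketch of sequential long-form decoding with Whisper.
# Assumes the post-merge `generate` API; the audio assembly is illustrative.
import numpy as np
from datasets import load_dataset
from transformers import AutoProcessor, WhisperForConditionalGeneration

processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
# Concatenate samples into one waveform longer than 30s so long-form decoding kicks in.
long_audio = np.concatenate([ds[i]["audio"]["array"] for i in range(8)])

# `truncation=False` keeps the full audio instead of cutting it to 30s; the model
# then transcribes window by window, starting each new 30s window at the last
# predicted timestamp of the previous one.
inputs = processor(
    long_audio,
    sampling_rate=16_000,
    return_tensors="pt",
    truncation=False,
    padding="longest",
    return_attention_mask=True,
)

generated_ids = model.generate(**inputs, return_timestamps=True)
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(transcript)
```

Unlike the pipeline's "chunked" transcription, which decodes overlapping 30s windows independently and merges them on timestamps, sequential decoding processes one window at a time and uses the last predicted timestamp to decide where the next window starts.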
Changes from 37 commits

Commits (39)
51d2a53
[Whisper] Add seq gen
patrickvonplaten Nov 14, 2023
b0c387d
[Whisper] Add seq gen
patrickvonplaten Nov 14, 2023
e7cb31b
more debug
patrickvonplaten Nov 14, 2023
0c3e8c5
Fix whisper logit processor
patrickvonplaten Nov 15, 2023
6e5ce42
Improve whisper code further
patrickvonplaten Nov 15, 2023
ba318f7
Fix more
patrickvonplaten Nov 15, 2023
c47c316
Merge branch 'main' of https://github.com/huggingface/transformers in…
patrickvonplaten Nov 16, 2023
a9ce5b4
Merge branch 'main' of https://github.com/huggingface/transformers in…
patrickvonplaten Nov 16, 2023
2b0dcf6
more debug
patrickvonplaten Nov 16, 2023
8b2281f
more debug
patrickvonplaten Nov 16, 2023
0afe526
Improve further
patrickvonplaten Nov 16, 2023
d2d16b4
Add tests
patrickvonplaten Nov 16, 2023
68e8226
Prep for batch size > 1
patrickvonplaten Nov 16, 2023
ee43be4
Get batch_size>1 working
patrickvonplaten Nov 17, 2023
1030f22
Correct more
patrickvonplaten Nov 17, 2023
04477e7
Add extensive tests
patrickvonplaten Nov 17, 2023
87b5d8d
more debug
patrickvonplaten Nov 17, 2023
a9cf2bb
more debug
patrickvonplaten Nov 17, 2023
5f3ff78
more debug
patrickvonplaten Nov 19, 2023
6c942be
Merge branch 'main' of https://github.com/huggingface/transformers in…
patrickvonplaten Nov 20, 2023
1c1d1e6
add more tests
patrickvonplaten Nov 20, 2023
74967ee
more debug
patrickvonplaten Nov 20, 2023
0e95291
Apply suggestions from code review
patrickvonplaten Nov 20, 2023
311995d
more debug
patrickvonplaten Nov 20, 2023
aeee0f2
add comments to explain the code better
patrickvonplaten Nov 20, 2023
c8507c7
add comments to explain the code better
patrickvonplaten Nov 20, 2023
0593495
add comments to explain the code better
patrickvonplaten Nov 20, 2023
4382898
Add more examples
patrickvonplaten Nov 20, 2023
708be99
add comments to explain the code better
patrickvonplaten Nov 20, 2023
0dfead2
fix more
patrickvonplaten Nov 20, 2023
79c39d8
Merge branch 'add_whisper_seq_gen' of https://github.com/huggingface/…
patrickvonplaten Nov 20, 2023
62ccd52
add comments to explain the code better
patrickvonplaten Nov 20, 2023
a5755d9
Merge branch 'add_whisper_seq_gen' of https://github.com/huggingface/…
patrickvonplaten Nov 20, 2023
a75ea30
add comments to explain the code better
patrickvonplaten Nov 20, 2023
cc4c19c
correct
patrickvonplaten Nov 21, 2023
0878ac6
correct
patrickvonplaten Nov 21, 2023
889eebb
finalize
patrickvonplaten Nov 22, 2023
c1c2042
Apply suggestions from code review
patrickvonplaten Nov 22, 2023
cc1f87c
Apply suggestions from code review
patrickvonplaten Nov 22, 2023
@@ -9,13 +9,13 @@
import jax.numpy as jnp
import joblib
import optax
import wandb
from flax import jax_utils, struct, traverse_util
from flax.serialization import from_bytes, to_bytes
from flax.training import train_state
from flax.training.common_utils import shard
from tqdm.auto import tqdm

import wandb
from transformers import BigBirdConfig, FlaxBigBirdForQuestionAnswering
from transformers.models.big_bird.modeling_flax_big_bird import FlaxBigBirdForQuestionAnsweringModule

2 changes: 1 addition & 1 deletion examples/research_projects/jax-projects/big_bird/train.py
@@ -2,11 +2,11 @@
from dataclasses import replace

import jax
import wandb
from bigbird_flax import Args, DataCollator, FlaxBigBirdForNaturalQuestions, Trainer, build_tx, train_step, val_step
from datasets import load_dataset
from flax import jax_utils

import wandb
from transformers import BigBirdTokenizerFast


2 changes: 1 addition & 1 deletion examples/research_projects/vqgan-clip/VQGAN_CLIP.py
@@ -4,12 +4,12 @@
import imageio
import torch
import torchvision
import wandb
from img_processing import custom_to_pil, loop_post_process, preprocess, preprocess_vqgan
from loaders import load_vqgan
from PIL import Image
from torch import nn

import wandb
from transformers import CLIPModel, CLIPTokenizerFast
from utils import get_device, get_timestamp, show_pil

50 changes: 36 additions & 14 deletions src/transformers/generation/logits_process.py
@@ -1487,6 +1487,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
max_initial_timestamp_index (`int`, *optional*, defaults to 1):
Used to set the maximum value of the initial timestamp. This is used to prevent the model from
predicting timestamps that are too far in the future.
_detect_timestamp_from_logprob (`bool`, *optional*): Whether timestamps can be predicted from logprobs over all timestamps.

Examples:
``` python
@@ -1517,29 +1518,35 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
```
"""

def __init__(self, generate_config): # support for the kwargs
def __init__(
Contributor Author:

This logit processor was incorrectly added before. Also note that since we never had support for "sequential" or "timestamp"-chunked transcription, the class wasn't really used before; it was only used when single 30s chunks were supposed to predict timestamps.

Collaborator:

Sorry if I am not understanding the statement 😅 No, it was used for chunked processing, because the merging algorithm heavily relies on timestamps and also produces timestamps. See this line, which always adds the processor.

Placing a breakpoint in the `__call__` of this class and running the following hits it:

from transformers import pipeline
from datasets import load_dataset
import numpy as np

transcriptor = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
audio_file = ds[0]["audio"]["array"]
long_audio = np.concatenate([ds[0]["audio"]["array"]] * 40)
out = transcriptor(long_audio, return_timestamps=True, chunk_length_s=30, stride_length_s=5)

(This is, for me, "timestamp"-chunked transcription.)

self, generate_config, _detect_timestamp_from_logprob: Optional[bool] = None
): # support for the kwargs
self.eos_token_id = generate_config.eos_token_id
self.no_timestamps_token_id = generate_config.no_timestamps_token_id
self.timestamp_begin = generate_config.no_timestamps_token_id + 1

self.begin_index = len(generate_config.forced_decoder_ids) + 2
if generate_config.forced_decoder_ids[-1][1] == self.no_timestamps_token_id:
self.begin_index -= 1
self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index
# this variable is mostly just used for testing
self._detect_timestamp_from_logprob = (
_detect_timestamp_from_logprob
if _detect_timestamp_from_logprob is not None
else getattr(generate_config, "_detect_timestamp_from_logprob", True)
)

self.begin_index = (
len(generate_config.forced_decoder_ids) + 1 if generate_config.forced_decoder_ids is not None else 1
)
self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# suppress <|notimestamps|> which is handled by without_timestamps
scores[:, self.no_timestamps_token_id] = -float("inf")

if input_ids.shape[1] == self.begin_index - 1:
scores[:, :] = -float("inf")
scores[:, self.timestamp_begin] = 0
return scores

# timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
for k in range(input_ids.shape[0]):
seq = list(input_ids[k, self.begin_index :].tolist())
sampled_tokens = input_ids[k, self.begin_index :]
seq = list(sampled_tokens.tolist())

last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.timestamp_begin
penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.timestamp_begin

@@ -1549,8 +1556,23 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
else: # cannot be normal text tokens
scores[k, : self.eos_token_id] = -float("inf")

# apply the `max_initial_timestamp` option
if input_ids.shape[1] == self.begin_index and self.max_initial_timestamp_index is not None:
timestamps = sampled_tokens[sampled_tokens.ge(self.timestamp_begin)]
if timestamps.numel() > 0:
# `timestamps` shouldn't decrease; forbid timestamp tokens smaller than the last
# The following lines of code are copied from: https://github.com/openai/whisper/pull/914/files#r1137085090
if last_was_timestamp and not penultimate_was_timestamp:
timestamp_last = timestamps[-1]
else:
# Avoid emitting <|0.00|> again
timestamp_last = timestamps[-1] + 1

scores[k, self.timestamp_begin : timestamp_last] = -float("inf")

# apply the `max_initial_timestamp` option
if input_ids.shape[1] == self.begin_index:
scores[:, : self.timestamp_begin] = -float("inf")

if self.max_initial_timestamp_index is not None:
last_allowed = self.timestamp_begin + self.max_initial_timestamp_index
scores[:, last_allowed + 1 :] = -float("inf")

@@ -1559,7 +1581,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
for k in range(input_ids.shape[0]):
timestamp_logprob = logprobs[k, self.timestamp_begin :].logsumexp(dim=-1)
max_text_token_logprob = logprobs[k, : self.timestamp_begin].max()
if timestamp_logprob > max_text_token_logprob:
if timestamp_logprob > max_text_token_logprob and self._detect_timestamp_from_logprob:
scores[k, : self.timestamp_begin] = -float("inf")

return scores
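To make the masking rules above concrete, here is a small self-contained sketch of the "timestamps come in pairs" logic. The token ids below are made up for illustration; in the real processor they come from the Whisper generation config:

``` python
# Illustrative only: hypothetical token ids for a tiny vocabulary, to show
# the pairing rule implemented in WhisperTimeStampLogitsProcessor.__call__.
import torch

TIMESTAMP_BEGIN = 10  # stand-in for the id of <|0.00|>
EOS_TOKEN_ID = 2      # stand-in for the end-of-sequence id

def mask_for_timestamp_pairs(seq, scores):
    # Mirrors the per-row check in the processor above.
    last_was_timestamp = len(seq) >= 1 and seq[-1] >= TIMESTAMP_BEGIN
    penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= TIMESTAMP_BEGIN
    if last_was_timestamp:
        if penultimate_was_timestamp:
            # a closed <t_start><t_end> pair: forbid a third timestamp
            scores[TIMESTAMP_BEGIN:] = -float("inf")
        else:
            # an open pair: it must be closed, so forbid plain text tokens
            scores[:EOS_TOKEN_ID] = -float("inf")
    return scores

scores = torch.zeros(20)
# [text, text, timestamp]: open pair, only timestamps (or EOS) may follow
print(mask_for_timestamp_pairs([4, 5, 12], scores.clone()))
# [timestamp, timestamp]: closed pair, no further timestamp may follow
print(mask_for_timestamp_pairs([11, 12], scores.clone()))
```

In Whisper's actual vocabulary the EOS token sits above all plain-text ids, so `scores[k, : self.eos_token_id]` masks exactly the text tokens; the new `_detect_timestamp_from_logprob` flag only gates the final logprob-based heuristic at the end of `__call__`.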