Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Whisper] Finalize batched SOTA long-form generation #27658

Merged
merged 87 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from 83 commits
Commits
Show all changes
87 commits
Select commit Hold shift + click to select a range
8530d3d
finalize
patrickvonplaten Nov 22, 2023
c4826fd
make fix copies whisper
patrickvonplaten Nov 27, 2023
2de5fe0
[Tests] Make sure that we don't run tests mulitple times
patrickvonplaten Nov 27, 2023
3106c51
Update src/transformers/models/whisper/modeling_whisper.py
patrickvonplaten Nov 27, 2023
cfc1998
[Tests] Make sure that we don't run tests mulitple times
patrickvonplaten Nov 28, 2023
0ca149b
Merge branch 'improve_decoding' of https://github.com/huggingface/tra…
patrickvonplaten Nov 28, 2023
00df894
fix more
patrickvonplaten Dec 1, 2023
25bcd69
improve
patrickvonplaten Dec 2, 2023
30a78a3
improve
patrickvonplaten Dec 2, 2023
e3bff24
improve further
patrickvonplaten Dec 2, 2023
25c9345
improve more
patrickvonplaten Dec 2, 2023
cd7734b
improve
patrickvonplaten Dec 4, 2023
3cf1752
fix more
patrickvonplaten Dec 4, 2023
48c2e60
Merge branch 'improve_decoding' of https://github.com/huggingface/tra…
patrickvonplaten Dec 4, 2023
133d17e
git commit and git push
patrickvonplaten Dec 4, 2023
e86745a
fix more
patrickvonplaten Dec 4, 2023
cbae58c
fix more
patrickvonplaten Dec 4, 2023
3933896
fix more
patrickvonplaten Dec 4, 2023
0e413e7
New try
patrickvonplaten Dec 5, 2023
8411a9e
Fix more whisper stuff
patrickvonplaten Dec 5, 2023
45d1233
Merge branch 'main' of https://github.com/huggingface/transformers in…
patrickvonplaten Dec 5, 2023
10cfdc6
Improve
patrickvonplaten Dec 5, 2023
b0897c7
correct more
patrickvonplaten Dec 7, 2023
24fa463
correct more
patrickvonplaten Dec 7, 2023
e0b7af3
correct more
patrickvonplaten Dec 7, 2023
5a67f75
Fix some tests
patrickvonplaten Dec 7, 2023
404d542
Add more tests
patrickvonplaten Dec 7, 2023
46cdb43
correct more
patrickvonplaten Dec 8, 2023
6818ebf
correct more
patrickvonplaten Dec 8, 2023
b46b63d
correct more
patrickvonplaten Dec 8, 2023
184c888
push
patrickvonplaten Dec 9, 2023
e4b7827
correct more
patrickvonplaten Dec 9, 2023
380bd54
Fix more
patrickvonplaten Dec 9, 2023
3e49df8
Better
patrickvonplaten Dec 9, 2023
c2387ed
without dec mask
patrickvonplaten Dec 9, 2023
1cca405
correct more
patrickvonplaten Dec 10, 2023
29a9830
clean
patrickvonplaten Dec 10, 2023
032d45a
save intermediate
patrickvonplaten Dec 10, 2023
0724d47
correct more
patrickvonplaten Dec 10, 2023
06b598a
Fix more
patrickvonplaten Dec 10, 2023
d1021ec
Fix VAD for large-v2
patrickvonplaten Dec 12, 2023
23d2149
Save new
patrickvonplaten Dec 13, 2023
1caa2cb
Correct more
patrickvonplaten Dec 13, 2023
947e542
make cleaner
patrickvonplaten Dec 13, 2023
6f01cdb
merge from main
patrickvonplaten Dec 13, 2023
c0d03af
correct tests
patrickvonplaten Dec 13, 2023
abb3d56
correct src
patrickvonplaten Dec 13, 2023
a8b8446
Finish
patrickvonplaten Dec 19, 2023
875beab
Fix more
patrickvonplaten Dec 19, 2023
9d31fda
Merge branch 'main' of https://github.com/huggingface/transformers in…
patrickvonplaten Dec 19, 2023
39034c7
Fix more
patrickvonplaten Dec 19, 2023
66c08ee
finish
patrickvonplaten Dec 19, 2023
a6cb9ba
rebase
patrickvonplaten Jan 2, 2024
254026c
Fix edge cases
patrickvonplaten Jan 2, 2024
85c68f2
fix return_dict_in_generate
patrickvonplaten Jan 3, 2024
c2fa76a
fix all tests
patrickvonplaten Jan 3, 2024
affdb6d
make style
patrickvonplaten Jan 3, 2024
bf7ee48
add docstrings
patrickvonplaten Jan 3, 2024
8096e4a
add docstrings
patrickvonplaten Jan 3, 2024
bd8c076
Merge branch 'main' of https://github.com/huggingface/transformers in…
patrickvonplaten Jan 6, 2024
f667fc4
Fix logit processor
patrickvonplaten Jan 8, 2024
5d976b0
Merge branch 'main' of https://github.com/huggingface/transformers in…
patrickvonplaten Jan 8, 2024
33b1903
make style
patrickvonplaten Jan 8, 2024
76134ec
fix pipeline test
patrickvonplaten Jan 8, 2024
8f00ba3
fix more style
patrickvonplaten Jan 8, 2024
5634a64
Apply suggestions from code review
patrickvonplaten Jan 8, 2024
667e03e
fix merge conflict
patrickvonplaten Jan 16, 2024
530c246
apply feedback Sanchit
patrickvonplaten Jan 16, 2024
d1662af
correct more
patrickvonplaten Jan 16, 2024
d22a9b3
Apply suggestions from code review
patrickvonplaten Jan 16, 2024
c34574f
Apply suggestions from code review
patrickvonplaten Jan 16, 2024
0e7c86e
correct more
patrickvonplaten Jan 16, 2024
67c2ea4
correct more
patrickvonplaten Jan 16, 2024
c9da44d
correct more
patrickvonplaten Jan 16, 2024
6541bac
Fix staticmethod
patrickvonplaten Jan 16, 2024
aae16f3
correct more
patrickvonplaten Jan 16, 2024
d24a4d8
fix
patrickvonplaten Jan 16, 2024
85ec8fe
fix slow tests
patrickvonplaten Jan 16, 2024
32b745c
make style
patrickvonplaten Jan 16, 2024
113d678
fix tokenizer test
patrickvonplaten Jan 17, 2024
de34f23
fix tokenizer test
patrickvonplaten Jan 17, 2024
55817e2
Merge branch 'main' of https://github.com/huggingface/transformers in…
patrickvonplaten Jan 17, 2024
d26b2e4
Merge branch 'main' of https://github.com/huggingface/transformers in…
patrickvonplaten Jan 19, 2024
db02e95
Apply suggestions from code review
patrickvonplaten Jan 19, 2024
71b4893
finish
patrickvonplaten Jan 19, 2024
e3711a3
finish
patrickvonplaten Jan 19, 2024
e9673c5
revert kwargs change
patrickvonplaten Jan 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 73 additions & 4 deletions src/transformers/generation/logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
scores = processor(input_ids, scores, **kwargs)
else:
scores = processor(input_ids, scores)

return scores


Expand Down Expand Up @@ -1657,6 +1658,9 @@ def __init__(self, begin_suppress_tokens, begin_index):
self.begin_suppress_tokens = list(begin_suppress_tokens)
self.begin_index = begin_index

def set_begin_index(self, begin_index):
self.begin_index = begin_index

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if input_ids.shape[1] == self.begin_index:
Expand Down Expand Up @@ -1778,6 +1782,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
max_initial_timestamp_index (`int`, *optional*, defaults to 1):
Used to set the maximum value of the initial timestamp. This is used to prevent the model from
predicting timestamps that are too far in the future.
begin_index (`Optional`, *optional*): Token index of the first token that is generated by the model.
_detect_timestamp_from_logprob (`bool`, *optional*): Whether timestamps can be predicted from logprobs over all timestamps.

Examples:
Expand Down Expand Up @@ -1810,11 +1815,11 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
"""

def __init__(
self, generate_config, _detect_timestamp_from_logprob: Optional[bool] = None
self, generate_config, begin_index: Optional[int] = None, _detect_timestamp_from_logprob: Optional[bool] = None
): # support for the kwargs
self.eos_token_id = generate_config.eos_token_id
self.no_timestamps_token_id = generate_config.no_timestamps_token_id
self.timestamp_begin = generate_config.no_timestamps_token_id + 1
self.eos_token_id = generate_config.eos_token_id or generate_config.bos_token_id

# this variable is mostly just used for testing
self._detect_timestamp_from_logprob = (
Expand All @@ -1823,10 +1828,17 @@ def __init__(
else getattr(generate_config, "_detect_timestamp_from_logprob", True)
)

self.begin_index = (
len(generate_config.forced_decoder_ids) + 1 if generate_config.forced_decoder_ids is not None else 1
num_forced_ids = (
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was incorrect before

len(generate_config.forced_decoder_ids) if generate_config.forced_decoder_ids is not None else 0
)
self.begin_index = begin_index or (num_forced_ids + 1)

self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)
# TODO(Patrick): Make sure that official models have max_initial_timestamp_index set to 50
# self.max_initial_timestamp_index = 50
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once the PR is in a mergable state, we should update the model configurations of all whisper models.


def set_begin_index(self, begin_index):
self.begin_index = begin_index

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
Expand Down Expand Up @@ -1878,6 +1890,63 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
return scores


class WhisperNoSpeechDetection(LogitsProcessor):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This processor class is quite awkward: it doesn't modify the scores (as opposed to all other LogitsProcessor), it requires a model, it requires inputs, and it may do a forward pass. Its sole purpose is to set self._no_speech_prob -- it would be nice if we could get it some other way 🤔

Looking at the generation code, this processor obtains the probability of voice activity, which may set should_skip in generate_with_fallback. generate_with_fallback calls generate, which then triggers this class' __call__. set_inputs is called with the same segment_input and decoder_input_ids, so the forward pass in this class receives the same inputs as the generate that triggers it. Doesn't this mean that we should be able to call generate with output_scores=True, and compute no_speech_scores from the returned scores? We even already set output_scores=True when not is_shortform and logprob_threshold is not None.

If what I wrote above is correct, the changes would have the following pros and cons:
(+) no need for this class, and a few other functions in generation_whisper.py
(+) more readable, because there isn't as much code nesting
(-) this PR would need some code changes 😛

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are a 100% correct that this logit processor is quite akward! I'm also not super about the way it's implemented, but for me it was a hard requirement that Whisper doesn't change the generate method at all.

The problem with retrieving the "no speech" prob from the scores is that these scores will be the ones that went through all the logit processors (see here) which will mean that the prob retrieved will necessarily be incorrect. In the original codebase the "no speech" prob is computed from the "untouched" logits (see here)

Other solutions that I have considered here:

  • 1.) Just not add the "NoSpeechDetection" logit processor since it doesn't help that much for performance in my tests.
  • 2.) Make the logit processor cleaner by just checking the "scores" that are forwarded to the logit processor. However this would require this lineto be changed so that all logits (not just the last one are passed into the logit processors). This would though mean I'd have to also add some hacky code to the "LogitsProcessor" class that makes sure that the logits are cut to the last in case WhisperNoSpeechDetection is not part of the processors.

=> Overall I think neither 1) nor 2) is better than what we have now. I'm willing to go for 1.) though if people prefer (also cc @sanchit-gandhi)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it was a hard requirement that Whisper doesn't change the generate method at all

100% agreed 👍

The problem with retrieving the "no speech" prob from the scores is that these scores will be the ones that went through all the logit processors (see here) which will mean that the prob retrieved will necessarily be incorrect.

I see 😞 Our design makes this feature quite ugly to implement, but I think it's more important to match the reference paper/implementation!

r"""This processor can be used to detect silence when using Whisper."""
patrickvonplaten marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, no_speech_token: int, begin_index: int, scores_is_logprobs: bool = False):
self.no_speech_token = no_speech_token
# offset between <start-of-transcription> token, <SOT>, in paper and first generated token
# is equal to the position of the first generated token index
self.start_of_trans_offset = begin_index

# `self.begin_index` is a running value that is changed on the fly
self.begin_index = begin_index
self._no_speech_prob = [0.0]
self.is_scores_logprobs = scores_is_logprobs

# make sure we pass all logits
self._pass_all_logits = True
patrickvonplaten marked this conversation as resolved.
Show resolved Hide resolved
patrickvonplaten marked this conversation as resolved.
Show resolved Hide resolved

# overwritten dynamically
self.model = None
self.inputs = None

def set_model(self, model):
self.model = model

def set_inputs(self, inputs):
self.inputs = {**self.model.prepare_inputs_for_generation(**inputs), **inputs}
self.inputs["input_features"] = self.inputs.pop("inputs")

@property
def no_speech_prob(self):
return self._no_speech_prob

def set_begin_index(self, begin_index):
self.begin_index = begin_index

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if input_ids.shape[1] == self.begin_index:
if self.start_of_trans_offset > 1:
with torch.no_grad():
logits = self.model(**self.inputs).logits
Comment on lines +1929 to +1930
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be used in the generate function so no need to specify this no ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


no_speech_index = self.begin_index - self.start_of_trans_offset
no_speech_scores = logits[:, no_speech_index]
else:
no_speech_scores = scores

if self.is_scores_logprobs:
probs = no_speech_scores.exp()
else:
probs = no_speech_scores.float().softmax(dim=-1)

self._no_speech_prob = probs[:, self.no_speech_token]

return scores


class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] for classifier free guidance (CFG). The scores are split over the batch dimension,
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,8 @@ def _prepare_decoder_input_ids_for_generation(
# exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token
elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower():
pass
elif self.config.model_type in ["whisper"]:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For Whisper we often don't start with the decoder_input_ids (e.g. when conditioning on the previous prompt OR when using "prompt_ids"). We should correct this here.

pass
# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
# decoder_attention_mask if provided)
elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item():
Expand Down
Loading
Loading