-
Notifications
You must be signed in to change notification settings - Fork 27.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Whisper] Finalize batched SOTA long-form generation #27658
Changes from all commits
8530d3d
c4826fd
2de5fe0
3106c51
cfc1998
0ca149b
00df894
25bcd69
30a78a3
e3bff24
25c9345
cd7734b
3cf1752
48c2e60
133d17e
e86745a
cbae58c
3933896
0e413e7
8411a9e
45d1233
10cfdc6
b0897c7
24fa463
e0b7af3
5a67f75
404d542
46cdb43
6818ebf
b46b63d
184c888
e4b7827
380bd54
3e49df8
c2387ed
1cca405
29a9830
032d45a
0724d47
06b598a
d1021ec
23d2149
1caa2cb
947e542
6f01cdb
c0d03af
abb3d56
a8b8446
875beab
9d31fda
39034c7
66c08ee
a6cb9ba
254026c
85c68f2
c2fa76a
affdb6d
bf7ee48
8096e4a
bd8c076
f667fc4
5d976b0
33b1903
76134ec
8f00ba3
5634a64
667e03e
530c246
d1662af
d22a9b3
c34574f
0e7c86e
67c2ea4
c9da44d
6541bac
aae16f3
d24a4d8
85ec8fe
32b745c
113d678
de34f23
55817e2
d26b2e4
db02e95
71b4893
e3711a3
e9673c5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -95,6 +95,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa | |
scores = processor(input_ids, scores, **kwargs) | ||
else: | ||
scores = processor(input_ids, scores) | ||
|
||
return scores | ||
|
||
|
||
|
@@ -1657,6 +1658,9 @@ def __init__(self, begin_suppress_tokens, begin_index): | |
self.begin_suppress_tokens = list(begin_suppress_tokens) | ||
self.begin_index = begin_index | ||
|
||
def set_begin_index(self, begin_index): | ||
self.begin_index = begin_index | ||
|
||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) | ||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: | ||
if input_ids.shape[1] == self.begin_index: | ||
|
@@ -1778,6 +1782,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor): | |
max_initial_timestamp_index (`int`, *optional*, defaults to 1): | ||
Used to set the maximum value of the initial timestamp. This is used to prevent the model from | ||
predicting timestamps that are too far in the future. | ||
begin_index (`Optional`, *optional*): Token index of the first token that is generated by the model. | ||
_detect_timestamp_from_logprob (`bool`, *optional*): Whether timestamps can be predicted from logprobs over all timestamps. | ||
Examples: | ||
|
@@ -1810,11 +1815,11 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor): | |
""" | ||
|
||
def __init__( | ||
self, generate_config, _detect_timestamp_from_logprob: Optional[bool] = None | ||
self, generate_config, begin_index: Optional[int] = None, _detect_timestamp_from_logprob: Optional[bool] = None | ||
): # support for the kwargs | ||
self.eos_token_id = generate_config.eos_token_id | ||
self.no_timestamps_token_id = generate_config.no_timestamps_token_id | ||
self.timestamp_begin = generate_config.no_timestamps_token_id + 1 | ||
self.eos_token_id = generate_config.eos_token_id or generate_config.bos_token_id | ||
|
||
# this variable is mostly just used for testing | ||
self._detect_timestamp_from_logprob = ( | ||
|
@@ -1823,10 +1828,17 @@ def __init__( | |
else getattr(generate_config, "_detect_timestamp_from_logprob", True) | ||
) | ||
|
||
self.begin_index = ( | ||
len(generate_config.forced_decoder_ids) + 1 if generate_config.forced_decoder_ids is not None else 1 | ||
num_forced_ids = ( | ||
len(generate_config.forced_decoder_ids) if generate_config.forced_decoder_ids is not None else 0 | ||
) | ||
self.begin_index = begin_index or (num_forced_ids + 1) | ||
|
||
self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None) | ||
# TODO(Patrick): Make sure that official models have max_initial_timestamp_index set to 50 | ||
# self.max_initial_timestamp_index = 50 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Once the PR is in a mergable state, we should update the model configurations of all whisper models. |
||
|
||
def set_begin_index(self, begin_index): | ||
self.begin_index = begin_index | ||
|
||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) | ||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: | ||
|
@@ -1878,6 +1890,60 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to | |
return scores | ||
|
||
|
||
class WhisperNoSpeechDetection(LogitsProcessor): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This processor class is quite awkward: it doesn't modify the Looking at the generation code, this processor obtains the probability of voice activity, which may set If what I wrote above is correct, the changes would have the following pros and cons: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You are a 100% correct that this logit processor is quite akward! I'm also not super about the way it's implemented, but for me it was a hard requirement that Whisper doesn't change the generate method at all. The problem with retrieving the "no speech" prob from the Other solutions that I have considered here:
=> Overall I think neither 1) nor 2) is better than what we have now. I'm willing to go for 1.) though if people prefer (also cc @sanchit-gandhi) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
100% agreed 👍
I see 😞 Our design makes this feature quite ugly to implement, but I think it's more important to match the reference paper/implementation! |
||
r"""This processor can be used to detect silence when using Whisper. It should take as input unprocessed logits to follow the original implementation""" | ||
|
||
def __init__(self, no_speech_token: int, begin_index: int, scores_is_logprobs: bool = False): | ||
self.no_speech_token = no_speech_token | ||
# offset between <start-of-transcription> token, <SOT>, in paper and first generated token | ||
# is equal to the position of the first generated token index | ||
self.start_of_trans_offset = begin_index | ||
|
||
# `self.begin_index` is a running value that is changed on the fly | ||
self.begin_index = begin_index | ||
self._no_speech_prob = [0.0] | ||
self.is_scores_logprobs = scores_is_logprobs | ||
|
||
# overwritten dynamically | ||
self.model = None | ||
self.inputs = None | ||
|
||
def set_model(self, model): | ||
self.model = model | ||
|
||
def set_inputs(self, inputs): | ||
self.inputs = {**self.model.prepare_inputs_for_generation(**inputs), **inputs} | ||
self.inputs["input_features"] = self.inputs.pop("inputs") | ||
|
||
@property | ||
def no_speech_prob(self): | ||
return self._no_speech_prob | ||
|
||
def set_begin_index(self, begin_index): | ||
self.begin_index = begin_index | ||
|
||
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) | ||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: | ||
if input_ids.shape[1] == self.begin_index: | ||
if self.start_of_trans_offset > 1: | ||
with torch.no_grad(): | ||
logits = self.model(**self.inputs).logits | ||
Comment on lines
+1929
to
+1930
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this should be used in the generate function so no need to specify this no ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
no_speech_index = self.begin_index - self.start_of_trans_offset | ||
no_speech_scores = logits[:, no_speech_index] | ||
else: | ||
no_speech_scores = scores | ||
|
||
if self.is_scores_logprobs: | ||
probs = no_speech_scores.exp() | ||
else: | ||
probs = no_speech_scores.float().softmax(dim=-1) | ||
|
||
self._no_speech_prob = probs[:, self.no_speech_token] | ||
|
||
return scores | ||
|
||
|
||
class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor): | ||
r""" | ||
[`LogitsProcessor`] for classifier free guidance (CFG). The scores are split over the batch dimension, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -518,6 +518,8 @@ def _prepare_decoder_input_ids_for_generation( | |
# exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token | ||
elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower(): | ||
pass | ||
elif self.config.model_type in ["whisper"]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For Whisper we often don't start with the |
||
pass | ||
# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust | ||
# decoder_attention_mask if provided) | ||
elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item(): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was incorrect before