Skip to content

Commit

Permalink
[Whisper] Finalize batched SOTA long-form generation (#27658)
Browse files Browse the repository at this point in the history
* finalize

* make fix copies whisper

* [Tests] Make sure that we don't run tests mulitple times

* Update src/transformers/models/whisper/modeling_whisper.py

* [Tests] Make sure that we don't run tests mulitple times

* fix more

* improve

* improve

* improve further

* improve more

* improve

* fix more

* git commit and git push

* fix more

* fix more

* fix more

* New try

* Fix more whisper stuff

* Improve

* correct more

* correct more

* correct more

* Fix some tests

* Add more tests

* correct more

* correct more

* correct more

* push

* correct more

* Fix more

* Better

* without dec mask

* correct more

* clean

* save intermediate

* Fix more

* Fix VAD for large-v2

* Save new

* Correct more

* make cleaner

* correct tests

* correct src

* Finish

* Fix more

* Fix more

* finish

* Fix edge cases

* fix return_dict_in_generate

* fix all tests

* make style

* add docstrings

* add docstrings

* Fix logit processor

* make style

* fix pipeline test

* fix more style

* Apply suggestions from code review

* apply feedback Sanchit

* correct more

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <[email protected]>

* Apply suggestions from code review

Co-authored-by: Joao Gante <[email protected]>
Co-authored-by: Sanchit Gandhi <[email protected]>

* correct more

* correct more

* correct more

* Fix staticmethod

* correct more

* fix

* fix slow tests

* make style

* fix tokenizer test

* fix tokenizer test

* Apply suggestions from code review

Co-authored-by: Arthur <[email protected]>

* finish

* finish

* revert kwargs change

---------

Co-authored-by: Sanchit Gandhi <[email protected]>
Co-authored-by: Joao Gante <[email protected]>
Co-authored-by: Arthur <[email protected]>
  • Loading branch information
4 people authored Jan 19, 2024
1 parent d4fc1eb commit 690fe73
Show file tree
Hide file tree
Showing 8 changed files with 1,825 additions and 852 deletions.
74 changes: 70 additions & 4 deletions src/transformers/generation/logits_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa
scores = processor(input_ids, scores, **kwargs)
else:
scores = processor(input_ids, scores)

return scores


Expand Down Expand Up @@ -1657,6 +1658,9 @@ def __init__(self, begin_suppress_tokens, begin_index):
self.begin_suppress_tokens = list(begin_suppress_tokens)
self.begin_index = begin_index

def set_begin_index(self, begin_index):
self.begin_index = begin_index

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if input_ids.shape[1] == self.begin_index:
Expand Down Expand Up @@ -1778,6 +1782,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
max_initial_timestamp_index (`int`, *optional*, defaults to 1):
Used to set the maximum value of the initial timestamp. This is used to prevent the model from
predicting timestamps that are too far in the future.
begin_index (`Optional`, *optional*): Token index of the first token that is generated by the model.
_detect_timestamp_from_logprob (`bool`, *optional*): Whether timestamps can be predicted from logprobs over all timestamps.
Examples:
Expand Down Expand Up @@ -1810,11 +1815,11 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor):
"""

def __init__(
self, generate_config, _detect_timestamp_from_logprob: Optional[bool] = None
self, generate_config, begin_index: Optional[int] = None, _detect_timestamp_from_logprob: Optional[bool] = None
): # support for the kwargs
self.eos_token_id = generate_config.eos_token_id
self.no_timestamps_token_id = generate_config.no_timestamps_token_id
self.timestamp_begin = generate_config.no_timestamps_token_id + 1
self.eos_token_id = generate_config.eos_token_id or generate_config.bos_token_id

# this variable is mostly just used for testing
self._detect_timestamp_from_logprob = (
Expand All @@ -1823,10 +1828,17 @@ def __init__(
else getattr(generate_config, "_detect_timestamp_from_logprob", True)
)

self.begin_index = (
len(generate_config.forced_decoder_ids) + 1 if generate_config.forced_decoder_ids is not None else 1
num_forced_ids = (
len(generate_config.forced_decoder_ids) if generate_config.forced_decoder_ids is not None else 0
)
self.begin_index = begin_index or (num_forced_ids + 1)

self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)
# TODO(Patrick): Make sure that official models have max_initial_timestamp_index set to 50
# self.max_initial_timestamp_index = 50

def set_begin_index(self, begin_index):
self.begin_index = begin_index

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
Expand Down Expand Up @@ -1878,6 +1890,60 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
return scores


class WhisperNoSpeechDetection(LogitsProcessor):
r"""This processor can be used to detect silence when using Whisper. It should take as input unprocessed logits to follow the original implementation"""

def __init__(self, no_speech_token: int, begin_index: int, scores_is_logprobs: bool = False):
self.no_speech_token = no_speech_token
# offset between <start-of-transcription> token, <SOT>, in paper and first generated token
# is equal to the position of the first generated token index
self.start_of_trans_offset = begin_index

# `self.begin_index` is a running value that is changed on the fly
self.begin_index = begin_index
self._no_speech_prob = [0.0]
self.is_scores_logprobs = scores_is_logprobs

# overwritten dynamically
self.model = None
self.inputs = None

def set_model(self, model):
self.model = model

def set_inputs(self, inputs):
self.inputs = {**self.model.prepare_inputs_for_generation(**inputs), **inputs}
self.inputs["input_features"] = self.inputs.pop("inputs")

@property
def no_speech_prob(self):
return self._no_speech_prob

def set_begin_index(self, begin_index):
self.begin_index = begin_index

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
if input_ids.shape[1] == self.begin_index:
if self.start_of_trans_offset > 1:
with torch.no_grad():
logits = self.model(**self.inputs).logits

no_speech_index = self.begin_index - self.start_of_trans_offset
no_speech_scores = logits[:, no_speech_index]
else:
no_speech_scores = scores

if self.is_scores_logprobs:
probs = no_speech_scores.exp()
else:
probs = no_speech_scores.float().softmax(dim=-1)

self._no_speech_prob = probs[:, self.no_speech_token]

return scores


class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] for classifier free guidance (CFG). The scores are split over the batch dimension,
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,8 @@ def _prepare_decoder_input_ids_for_generation(
# exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token
elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower():
pass
elif self.config.model_type in ["whisper"]:
pass
# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
# decoder_attention_mask if provided)
elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item():
Expand Down
Loading

0 comments on commit 690fe73

Please sign in to comment.