
[Whisper] Finalize batched SOTA long-form generation #27658

Merged: 87 commits, merged Jan 19, 2024

Commits (87)
8530d3d
finalize
patrickvonplaten Nov 22, 2023
c4826fd
make fix copies whisper
patrickvonplaten Nov 27, 2023
2de5fe0
[Tests] Make sure that we don't run tests mulitple times
patrickvonplaten Nov 27, 2023
3106c51
Update src/transformers/models/whisper/modeling_whisper.py
patrickvonplaten Nov 27, 2023
cfc1998
[Tests] Make sure that we don't run tests mulitple times
patrickvonplaten Nov 28, 2023
0ca149b
Merge branch 'improve_decoding' of https://github.com/huggingface/tra…
patrickvonplaten Nov 28, 2023
00df894
fix more
patrickvonplaten Dec 1, 2023
25bcd69
improve
patrickvonplaten Dec 2, 2023
30a78a3
improve
patrickvonplaten Dec 2, 2023
e3bff24
improve further
patrickvonplaten Dec 2, 2023
25c9345
improve more
patrickvonplaten Dec 2, 2023
cd7734b
improve
patrickvonplaten Dec 4, 2023
3cf1752
fix more
patrickvonplaten Dec 4, 2023
48c2e60
Merge branch 'improve_decoding' of https://github.com/huggingface/tra…
patrickvonplaten Dec 4, 2023
133d17e
git commit and git push
patrickvonplaten Dec 4, 2023
e86745a
fix more
patrickvonplaten Dec 4, 2023
cbae58c
fix more
patrickvonplaten Dec 4, 2023
3933896
fix more
patrickvonplaten Dec 4, 2023
0e413e7
New try
patrickvonplaten Dec 5, 2023
8411a9e
Fix more whisper stuff
patrickvonplaten Dec 5, 2023
45d1233
Merge branch 'main' of https://github.com/huggingface/transformers in…
patrickvonplaten Dec 5, 2023
10cfdc6
Improve
patrickvonplaten Dec 5, 2023
b0897c7
correct more
patrickvonplaten Dec 7, 2023
24fa463
correct more
patrickvonplaten Dec 7, 2023
e0b7af3
correct more
patrickvonplaten Dec 7, 2023
5a67f75
Fix some tests
patrickvonplaten Dec 7, 2023
404d542
Add more tests
patrickvonplaten Dec 7, 2023
46cdb43
correct more
patrickvonplaten Dec 8, 2023
6818ebf
correct more
patrickvonplaten Dec 8, 2023
b46b63d
correct more
patrickvonplaten Dec 8, 2023
184c888
push
patrickvonplaten Dec 9, 2023
e4b7827
correct more
patrickvonplaten Dec 9, 2023
380bd54
Fix more
patrickvonplaten Dec 9, 2023
3e49df8
Better
patrickvonplaten Dec 9, 2023
c2387ed
without dec mask
patrickvonplaten Dec 9, 2023
1cca405
correct more
patrickvonplaten Dec 10, 2023
29a9830
clean
patrickvonplaten Dec 10, 2023
032d45a
save intermediate
patrickvonplaten Dec 10, 2023
0724d47
correct more
patrickvonplaten Dec 10, 2023
06b598a
Fix more
patrickvonplaten Dec 10, 2023
d1021ec
Fix VAD for large-v2
patrickvonplaten Dec 12, 2023
23d2149
Save new
patrickvonplaten Dec 13, 2023
1caa2cb
Correct more
patrickvonplaten Dec 13, 2023
947e542
make cleaner
patrickvonplaten Dec 13, 2023
6f01cdb
merge from main
patrickvonplaten Dec 13, 2023
c0d03af
correct tests
patrickvonplaten Dec 13, 2023
abb3d56
correct src
patrickvonplaten Dec 13, 2023
a8b8446
Finish
patrickvonplaten Dec 19, 2023
875beab
Fix more
patrickvonplaten Dec 19, 2023
9d31fda
Merge branch 'main' of https://github.com/huggingface/transformers in…
patrickvonplaten Dec 19, 2023
39034c7
Fix more
patrickvonplaten Dec 19, 2023
66c08ee
finish
patrickvonplaten Dec 19, 2023
a6cb9ba
rebase
patrickvonplaten Jan 2, 2024
254026c
Fix edge cases
patrickvonplaten Jan 2, 2024
85c68f2
fix return_dict_in_generate
patrickvonplaten Jan 3, 2024
c2fa76a
fix all tests
patrickvonplaten Jan 3, 2024
affdb6d
make style
patrickvonplaten Jan 3, 2024
bf7ee48
add docstrings
patrickvonplaten Jan 3, 2024
8096e4a
add docstrings
patrickvonplaten Jan 3, 2024
bd8c076
Merge branch 'main' of https://github.com/huggingface/transformers in…
patrickvonplaten Jan 6, 2024
f667fc4
Fix logit processor
patrickvonplaten Jan 8, 2024
5d976b0
Merge branch 'main' of https://github.com/huggingface/transformers in…
patrickvonplaten Jan 8, 2024
33b1903
make style
patrickvonplaten Jan 8, 2024
76134ec
fix pipeline test
patrickvonplaten Jan 8, 2024
8f00ba3
fix more style
patrickvonplaten Jan 8, 2024
5634a64
Apply suggestions from code review
patrickvonplaten Jan 8, 2024
667e03e
fix merge conflict
patrickvonplaten Jan 16, 2024
530c246
apply feedback Sanchit
patrickvonplaten Jan 16, 2024
d1662af
correct more
patrickvonplaten Jan 16, 2024
d22a9b3
Apply suggestions from code review
patrickvonplaten Jan 16, 2024
c34574f
Apply suggestions from code review
patrickvonplaten Jan 16, 2024
0e7c86e
correct more
patrickvonplaten Jan 16, 2024
67c2ea4
correct more
patrickvonplaten Jan 16, 2024
c9da44d
correct more
patrickvonplaten Jan 16, 2024
6541bac
Fix staticmethod
patrickvonplaten Jan 16, 2024
aae16f3
correct more
patrickvonplaten Jan 16, 2024
d24a4d8
fix
patrickvonplaten Jan 16, 2024
85ec8fe
fix slow tests
patrickvonplaten Jan 16, 2024
32b745c
make style
patrickvonplaten Jan 16, 2024
113d678
fix tokenizer test
patrickvonplaten Jan 17, 2024
de34f23
fix tokenizer test
patrickvonplaten Jan 17, 2024
55817e2
Merge branch 'main' of https://github.com/huggingface/transformers in…
patrickvonplaten Jan 17, 2024
d26b2e4
Merge branch 'main' of https://github.com/huggingface/transformers in…
patrickvonplaten Jan 19, 2024
db02e95
Apply suggestions from code review
patrickvonplaten Jan 19, 2024
71b4893
finish
patrickvonplaten Jan 19, 2024
e3711a3
finish
patrickvonplaten Jan 19, 2024
e9673c5
revert kwargs change
patrickvonplaten Jan 19, 2024
3 changes: 3 additions & 0 deletions src/transformers/generation/logits_process.py
@@ -1537,6 +1537,9 @@ def __init__(
)
self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None)

def set_begin_index(self, begin_index):
self.begin_index = begin_index

@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
# suppress <|notimestamps|> which is handled by without_timestamps
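The new `set_begin_index` hook exists because, in long-form generation, each audio chunk may prepend a different-length prompt of previous tokens, so the index at which generation really starts changes per chunk. A minimal standalone sketch of the pattern (hypothetical processor class; the actual one in the PR is `WhisperTimeStampLogitsProcessor`):

```python
import torch


class BeginIndexSuppressProcessor:
    """Toy logits processor that masks all scores until `begin_index` tokens
    exist -- a sketch of why a mutable `set_begin_index` is needed when the
    prompt length changes between long-form chunks. Hypothetical class."""

    def __init__(self, begin_index: int):
        self.begin_index = begin_index

    def set_begin_index(self, begin_index: int):
        # Called once per chunk: the prompt (previous tokens + start token)
        # shifts where "real" generation begins.
        self.begin_index = begin_index

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if input_ids.shape[-1] < self.begin_index:
            # Still inside the prompt: suppress everything.
            return torch.full_like(scores, float("-inf"))
        return scores
```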
2 changes: 2 additions & 0 deletions src/transformers/generation/utils.py
@@ -773,6 +773,8 @@ def _prepare_decoder_input_ids_for_generation(
# exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token
elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower():
pass
elif self.config.model_type in ['whisper']:
pass
# user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
# decoder_attention_mask if provided)
elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item():
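The new `elif` branch makes generic generation leave Whisper's `decoder_input_ids` untouched: Whisper assembles its own decoder prompt (previous-segment tokens followed by its start token), so blindly prepending `decoder_start_token_id` would corrupt it. A simplified sketch of the branching, with a hypothetical helper name:

```python
import torch


def prepend_decoder_start_if_needed(decoder_input_ids, decoder_start_token_id, model_type):
    # Simplified sketch of the branch added above (hypothetical helper, not
    # the PR's exact code). Whisper builds its own prompt, so it is skipped.
    if model_type == "whisper":
        return decoder_input_ids
    # Generic case: prepend the start token if it is not already there.
    if (decoder_input_ids[:, 0] != decoder_start_token_id).all().item():
        start = torch.full(
            (decoder_input_ids.shape[0], 1), decoder_start_token_id, dtype=torch.long
        )
        return torch.cat([start, decoder_input_ids], dim=-1)
    return decoder_input_ids
```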
49 changes: 39 additions & 10 deletions src/transformers/models/whisper/modeling_whisper.py
@@ -15,6 +15,7 @@
""" PyTorch Whisper model."""

import math
from pickle import decode_long
patrickvonplaten marked this conversation as resolved.
import warnings
from typing import Optional, Tuple, Union

@@ -1741,6 +1742,7 @@ def generate(
task=None,
language=None,
is_multilingual=None,
condition_on_previous_tokens: Optional[bool] = None,
prompt_ids: Optional[torch.Tensor] = None,
num_segment_frames: Optional[int] = None,
return_token_timestamps: Optional[bool] = None,
@@ -2143,6 +2145,9 @@ def generate(

return outputs

condition_on_previous_tokens = condition_on_previous_tokens or getattr(self.generation_config, "condition_on_previous_tokens", False)
self.generation_config.condition_on_previous_tokens = condition_on_previous_tokens

# 6. Else we're in longform mode which is more complex. We need to chunk the audio input depending on when the model generated
# timestamp tokens
# 6.1 Set running parameters for while loop
@@ -2214,6 +2219,17 @@

segment_input = torch.cat(segment_input, dim=0)

decoder_input_ids = None
if condition_on_previous_tokens and len(current_segments[0]) > 0:
# according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609
cut_off_length = self.config.max_target_positions // 2 - 1
active_segments = [current_segments[i] for i in new_cur_to_prev_index_map]
decoder_input_ids = self._pad_to_max_length(active_segments, self.generation_config.pad_token_id, padding="left")
decoder_input_ids = torch.cat([decoder_input_ids[:, -cut_off_length:], torch.ones((cur_bsz, 1), device=decoder_input_ids.device, dtype=torch.long) * self.config.decoder_start_token_id], dim=-1)

timestamp_processor = [proc for proc in logits_processor if isinstance(proc, WhisperTimeStampLogitsProcessor)][0]
timestamp_processor.set_begin_index(decoder_input_ids.shape[-1])

# 6.6 Batch generate current chunk
seek_outputs = super().generate(
segment_input,
@@ -2223,6 +2239,7 @@
prefix_allowed_tokens_fn,
synced_gpus,
return_dict_in_generate=return_dict_in_generate,
decoder_input_ids=decoder_input_ids,
**kwargs,
)

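The conditioning prompt built above (the last `max_target_positions // 2 - 1` tokens of each active segment, left-padded, followed by `decoder_start_token_id`) can be sketched as a standalone helper. This is a simplified illustration rather than the PR's exact code; `build_condition_prompt` is a hypothetical name:

```python
import torch


def build_condition_prompt(active_token_lists, pad_token_id, decoder_start_token_id, max_target_positions):
    # Keep at most max_target_positions // 2 - 1 previous tokens per item
    # (mirroring openai/whisper's decoding.py), left-pad so the most recent
    # tokens sit directly before the start token, then append the start token.
    cut_off_length = max_target_positions // 2 - 1
    tails = [tokens[-cut_off_length:] for tokens in active_token_lists]
    max_len = max(len(t) for t in tails)
    rows = [[pad_token_id] * (max_len - len(t)) + t for t in tails]
    prompt = torch.tensor(rows, dtype=torch.long)
    start = torch.full((prompt.shape[0], 1), decoder_start_token_id, dtype=torch.long)
    return torch.cat([prompt, start], dim=-1)
```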
@@ -2241,6 +2258,9 @@
else:
seek_sequences = seek_outputs

if decoder_input_ids is not None:
seek_sequences = seek_sequences[:, decoder_input_ids.shape[-1]:]

# 6.7 Loop over each decoded audio individually as each decoding can be of a different length
for i, seek_sequence in enumerate(seek_sequences):
prev_i = cur_to_prev_index_map[i]
@@ -2273,25 +2293,34 @@

# 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted
# output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output
sequences = []
sequences = self._pad_to_max_length(current_segments, self.generation_config.pad_token_id, padding="right")

# 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
if return_segments:
return {"sequences": sequences, "segments": current_segments}

return sequences

@staticmethod
def _pad_to_max_length(current_segments, pad_token_id, padding="right"):
max_total_length = 0
sequences = []
if padding not in ["right", "left"]:
raise ValueError(f"`padding` must be either 'right' or 'left', not {padding}")

for current_segment_list in current_segments:
sequences.append(torch.cat([d["tokens"] for d in current_segment_list], dim=-1))
max_total_length = max(max_total_length, len(sequences[-1]))

for i in range(batch_size):
sequences[i] = F.pad(
sequences[i], pad=(0, max_total_length - len(sequences[i])), value=self.generation_config.pad_token_id
)
for i in range(len(current_segments)):
pad_length = max_total_length - len(sequences[i])
pad = (0, pad_length) if padding == "right" else (pad_length, 0)
sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)

sequences = torch.stack(sequences, dim=0)

# 8. If we return all segments, the predicted output sequences are put under `"sequences"`.
if return_segments:
return {"sequences": sequences, "segments": current_segments}

return sequences


@staticmethod
def _retrieve_segment(
seek_sequence,
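A runnable sketch of the `_pad_to_max_length` helper introduced above, operating on plain lists of segment dicts:

```python
import torch
import torch.nn.functional as F


def pad_to_max_length(current_segments, pad_token_id, padding="right"):
    # Each batch item is a list of segment dicts; concatenate each item's
    # "tokens", then pad every sequence to the longest one in the batch.
    if padding not in ["right", "left"]:
        raise ValueError(f"`padding` must be either 'right' or 'left', not {padding}")

    sequences = []
    max_total_length = 0
    for segment_list in current_segments:
        sequences.append(torch.cat([d["tokens"] for d in segment_list], dim=-1))
        max_total_length = max(max_total_length, len(sequences[-1]))

    for i in range(len(current_segments)):
        pad_length = max_total_length - len(sequences[i])
        pad = (0, pad_length) if padding == "right" else (pad_length, 0)
        sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id)

    return torch.stack(sequences, dim=0)
```

Note the two call sites in the diff: `padding="left"` for the conditioning prompt (so the most recent tokens sit flush against the start token) and `padding="right"` for the final returned sequences.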