[Whisper] Add sequential longform decoding #27492

Merged on Nov 22, 2023
Changes from 1 commit

Commits (39, all by patrickvonplaten):
51d2a53  [Whisper] Add seq gen (Nov 14, 2023)
b0c387d  [Whisper] Add seq gen (Nov 14, 2023)
e7cb31b  more debug (Nov 14, 2023)
0c3e8c5  Fix whisper logit processor (Nov 15, 2023)
6e5ce42  Improve whisper code further (Nov 15, 2023)
ba318f7  Fix more (Nov 15, 2023)
c47c316  Merge branch 'main' of https://github.com/huggingface/transformers in… (Nov 16, 2023)
a9ce5b4  Merge branch 'main' of https://github.com/huggingface/transformers in… (Nov 16, 2023)
2b0dcf6  more debug (Nov 16, 2023)
8b2281f  more debug (Nov 16, 2023)
0afe526  Improve further (Nov 16, 2023)
d2d16b4  Add tests (Nov 16, 2023)
68e8226  Prep for batch size > 1 (Nov 16, 2023)
ee43be4  Get batch_size>1 working (Nov 17, 2023)
1030f22  Correct more (Nov 17, 2023)
04477e7  Add extensive tests (Nov 17, 2023)
87b5d8d  more debug (Nov 17, 2023)
a9cf2bb  more debug (Nov 17, 2023)
5f3ff78  more debug (Nov 19, 2023)
6c942be  Merge branch 'main' of https://github.com/huggingface/transformers in… (Nov 20, 2023)
1c1d1e6  add more tests (Nov 20, 2023)
74967ee  more debug (Nov 20, 2023)
0e95291  Apply suggestions from code review (Nov 20, 2023)
311995d  more debug (Nov 20, 2023)
aeee0f2  add comments to explain the code better (Nov 20, 2023)
c8507c7  add comments to explain the code better (Nov 20, 2023)
0593495  add comments to explain the code better (Nov 20, 2023)
4382898  Add more examples (Nov 20, 2023)
708be99  add comments to explain the code better (Nov 20, 2023)
0dfead2  fix more (Nov 20, 2023)
79c39d8  Merge branch 'add_whisper_seq_gen' of https://github.com/huggingface/… (Nov 20, 2023)
62ccd52  add comments to explain the code better (Nov 20, 2023)
a5755d9  Merge branch 'add_whisper_seq_gen' of https://github.com/huggingface/… (Nov 20, 2023)
a75ea30  add comments to explain the code better (Nov 20, 2023)
cc4c19c  correct (Nov 21, 2023)
0878ac6  correct (Nov 21, 2023)
889eebb  finalize (Nov 22, 2023)
c1c2042  Apply suggestions from code review (Nov 22, 2023)
cc1f87c  Apply suggestions from code review (Nov 22, 2023)
Viewing commit 87b5d8dc9dbaaf10dea7d803d699995458879d6a ("more debug"), committed by patrickvonplaten on Nov 17, 2023.
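For orientation, this commit only touches the Whisper long-form generation tests. The call pattern those tests exercise can be sketched from the test code in this diff; the openai/whisper-tiny checkpoint choice and the final decode step below are illustrative assumptions, not part of the commit:

import numpy as np
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# Build one long (> 30s) audio by concatenating the dummy dataset, as the tests do.
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean")
one_audio = np.concatenate([x["array"] for x in ds["validation"]["audio"]], dtype=np.float32)

# No truncation: keep the full long-form features instead of a single 30s window.
input_features = processor(one_audio, return_tensors="pt", truncation=False, padding="longest")["input_features"]

# Long inputs must be decoded with timestamps so chunk boundaries can be found;
# the tests below check that return_timestamps=False on long inputs raises a ValueError.
result = model.generate(input_features, return_timestamps=True)
transcript = processor.batch_decode(result, skip_special_tokens=True)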
tests/models/whisper/test_modeling_whisper.py (78 changes: 61 additions, 17 deletions)
@@ -106,7 +106,9 @@ def prepare_whisper_inputs_dict(
 class DummyTimestampLogitProcessor(LogitsProcessor):
     """This processor fakes the correct timestamps tokens pattern [TOK_1] [TOK_2] ... [TOK_N] [TIME_STAMP_TOK_1] [TIME_STAMP_TOK_2] [TOK_N+1] ..."""
 
-    def __init__(self, timestamp_begin, vocab_size, batch_size, max_length, min_space=3, seed=0, is_length_ascending=True):
+    def __init__(
+        self, timestamp_begin, vocab_size, batch_size, max_length, min_space=3, seed=0, is_length_ascending=True
+    ):
         self.timestamp_begin = timestamp_begin
         self.vocab_size = vocab_size
 
@@ -130,7 +132,7 @@ def __init__(self, timestamp_begin, vocab_size, batch_size, max_length, min_spac
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         # we don't want to randomely sample timestamp tokens
         if input_ids.shape[-1] > 1:
-            scores[:, self.timestamp_begin:] = -float("inf")
+            scores[:, self.timestamp_begin :] = -float("inf")
 
         self.no_time_stamp_counter = [x + 1 for x in self.no_time_stamp_counter]
         for k in range(input_ids.shape[0]):
@@ -144,12 +146,14 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
                 self.no_time_stamp_counter[prev_k] = 0
 
             can_produce = self.no_time_stamp_counter[prev_k] > self.min_space_between_timestamps
-            must_produce = input_ids[k][2:].le(self.timestamp_begin).all() and input_ids.shape[-1] == self.max_length - 1
+            must_produce = (
+                input_ids[k][2:].le(self.timestamp_begin).all() and input_ids.shape[-1] == self.max_length - 1
+            )
             # produce timestamp with 30%
             if (can_produce and self.let_pass[prev_k][self.count]) or must_produce:
                 self.no_time_stamp_counter[prev_k] = 0
                 self.prev_highest_timestamp[prev_k] = max(input_ids[k].max() + 1, self.timestamp_tokens[0].item())
 
                 # force a timestamp
                 scores[k, :] = -float("inf")
                 scores[k, self.prev_highest_timestamp[prev_k]] = 10.0
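For readers unfamiliar with the hook being faked above: a LogitsProcessor is invoked at every decoding step with the tokens generated so far and the scores for the next token, and returns modified scores. A minimal sketch of that interface, assuming a hypothetical processor that simply suppresses all timestamp tokens (the masking line mirrors the dummy processor above):

import torch
from transformers import LogitsProcessor


class SuppressTimestampsProcessor(LogitsProcessor):  # hypothetical example class
    def __init__(self, timestamp_begin: int):
        # All token ids >= timestamp_begin are treated as timestamp tokens.
        self.timestamp_begin = timestamp_begin

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Setting the scores to -inf removes these tokens from sampling entirely.
        scores[:, self.timestamp_begin :] = -float("inf")
        return scores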
@@ -1323,15 +1327,22 @@ def test_longform_generate_single_batch(self):
         batch_size = 1
         num_timestamp_tokens = 20
         max_length = 16
-        logits_processor = [DummyTimestampLogitProcessor(vocab_size - num_timestamp_tokens, vocab_size, batch_size=batch_size, max_length=max_length, min_space=4)]
+        logits_processor = [
+            DummyTimestampLogitProcessor(
+                vocab_size - num_timestamp_tokens,
+                vocab_size,
+                batch_size=batch_size,
+                max_length=max_length,
+                min_space=4,
+            )
+        ]
 
         # each chunk should not be longer than 10
         model.generation_config.max_length = max_length
 
         # if input features are long can't set return_timestamps to False
         with self.assertRaises(ValueError):
             _ = model.generate(long_input_features, logits_processor=logits_processor, return_timestamps=False)
 
-
         # if input features are long need to set generation config
         with self.assertRaises(ValueError):
@@ -1350,9 +1361,16 @@ def test_longform_generate_single_batch(self):
 
         for i, segment in enumerate(segments):
             assert segment["start"] <= segment["end"], "start has to be smaller equal end"
-            assert segment["tokens"][0] == model.generation_config.decoder_start_token_id or segment["tokens"][0] >= timestamp_begin, "First segment token should be a timestamp token"
-            assert any(s > timestamp_begin for s in segment["tokens"][1:]), f"At least one segment token should be a timestamp token, but not first., {segment['tokens']}"
-            assert segment["tokens"].shape[-1] <= max_length, "make sure that no segment is larger than max generation length"
+            assert (
+                segment["tokens"][0] == model.generation_config.decoder_start_token_id
+                or segment["tokens"][0] >= timestamp_begin
+            ), "First segment token should be a timestamp token"
+            assert any(
+                s > timestamp_begin for s in segment["tokens"][1:]
+            ), f"At least one segment token should be a timestamp token, but not first., {segment['tokens']}"
+            assert (
+                segment["tokens"].shape[-1] <= max_length
+            ), "make sure that no segment is larger than max generation length"
 
     def test_longform_generate_multi_batch(self):
         config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
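The assertions above pin down the structure that return_segments=True yields: for each batch item, a list of segment dicts with "start", "end", and "tokens" keys. A hedged sketch of consuming that structure, reusing the test's variable names:

# Sketch based on the assertions above, not a verbatim excerpt from the PR.
outputs = model.generate(long_input_features, logits_processor=logits_processor, return_segments=True)
for segment in outputs["segments"][0]:  # segments of the first batch item
    # start <= end, the first token is the decoder start token or a timestamp,
    # and no segment is longer than the configured max generation length.
    print(segment["start"], segment["end"], segment["tokens"].shape[-1])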
@@ -1362,9 +1380,11 @@ def test_longform_generate_multi_batch(self):
 
         # len = 250 with num_input_frames = 60
         long_input_features = torch.cat([input_features.repeat(1, 1, 4), input_features[:, :, :10]], dim=-1)
-        input_features_1 = long_input_features[:1, :, :200]
+        long_input_features[:1, :, :200]
         input_features_2 = long_input_features[1:]
-        attention_mask = torch.ones((2, long_input_features.shape[-1]), dtype=input_features.dtype, device=input_features.device)
+        attention_mask = torch.ones(
+            (2, long_input_features.shape[-1]), dtype=input_features.dtype, device=input_features.device
+        )
         attention_mask[0, 200:] = 0
 
         # force bsz=1
@@ -1380,14 +1400,34 @@ def test_longform_generate_multi_batch(self):
         # make sure that we only have the same begin token
         model.generation_config.max_initial_timestamp_index = 0
 
-        logits_processor = [DummyTimestampLogitProcessor(vocab_size - num_timestamp_tokens, vocab_size, batch_size=batch_size, max_length=max_length, min_space=4, seed=1)]
+        logits_processor = [
+            DummyTimestampLogitProcessor(
+                vocab_size - num_timestamp_tokens,
+                vocab_size,
+                batch_size=batch_size,
+                max_length=max_length,
+                min_space=4,
+                seed=1,
+            )
+        ]
         outputs_2 = model.generate(input_features_2, logits_processor=logits_processor, return_segments=True)
         tokens_2 = outputs_2["sequences"][0]
         segments_2 = outputs_2["segments"][0]
 
         batch_size = 2
-        logits_processor = [DummyTimestampLogitProcessor(vocab_size - num_timestamp_tokens, vocab_size, batch_size=batch_size, max_length=max_length, min_space=4, seed=0)]
-        outputs = model.generate(long_input_features, attention_mask=attention_mask, logits_processor=logits_processor, return_segments=True)
+        logits_processor = [
+            DummyTimestampLogitProcessor(
+                vocab_size - num_timestamp_tokens,
+                vocab_size,
+                batch_size=batch_size,
+                max_length=max_length,
+                min_space=4,
+                seed=0,
+            )
+        ]
+        outputs = model.generate(
+            long_input_features, attention_mask=attention_mask, logits_processor=logits_processor, return_segments=True
+        )
         tokens = outputs["sequences"][1]
         segments = outputs["segments"][1]
 
@@ -2007,7 +2047,9 @@ def test_whisper_longform_single_batch(self):
         ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean")
         one_audio = np.concatenate([x["array"] for x in ds["validation"]["audio"]], dtype=np.float32)
 
-        input_features = processor(one_audio, return_tensors="pt", truncation=False, padding="longest")["input_features"]
+        input_features = processor(one_audio, return_tensors="pt", truncation=False, padding="longest")[
+            "input_features"
+        ]
         input_features = input_features.to(device="cuda")
 
         result = model.generate(input_features, return_timestamps=True)
@@ -2044,7 +2086,9 @@ def test_whisper_longform_multi_batch(self):
             result = model.generate(**inputs, return_timestamps=True)
             decoded_single.append(processor.batch_decode(result, skip_special_tokens=True))
 
-        inputs = processor(audios, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True)
+        inputs = processor(
+            audios, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True
+        )
         inputs = inputs.to(device="cuda")
 
         result = model.generate(**inputs, return_timestamps=True)
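Taken together with the single-batch test, the hunk above also shows the batched long-form pattern: pad the audios to a common length, keep the attention mask, and decode with timestamps. A short sketch assuming audios is a list of 1-D float32 waveforms and a CUDA device is available:

# Batched long-form transcription, mirroring test_whisper_longform_multi_batch.
inputs = processor(
    audios, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True
)
inputs = inputs.to(device="cuda")

# The attention mask tells generate() where the shorter inputs were padded.
result = model.generate(**inputs, return_timestamps=True)
decoded = processor.batch_decode(result, skip_special_tokens=True)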