Whisper tokenizer word level timestamps (huggingface#32197)
* fix _fix_key in PreTrainedModel

* fix _find_longest_common_sequence

* add test

* remove result.json

* nit

* update test
kamilakesbi authored Jul 29, 2024
1 parent 7ffe25f commit 3fbaaaa
Showing 2 changed files with 52 additions and 1 deletion.
src/transformers/models/whisper/tokenization_whisper.py (16 additions, 1 deletion)

```diff
@@ -1174,7 +1174,22 @@ def _find_longest_common_sequence(sequences, token_timestamp_sequences=None):
                     "There is a bug within whisper `decode_asr` function, please report it. Dropping to prevent bad inference."
                 )
 
-            matches = np.sum(left == right)
+            if token_timestamp_sequences:
+                # Get length of longest subsequence of tokens that match
+                # and have timestamps that are in order
+                matches = sum(
+                    1
+                    for idx, elem in enumerate(left)
+                    if (
+                        elem == right[idx]
+                        and left_token_timestamp_sequence[left_start + idx]
+                        <= token_timestamp_sequences[seq_idx + 1][right_start + idx]
+                    )
+                )
+
+            else:
+                matches = np.sum(left == right)
+
             matching = matches / i + eps
             if matches > 1 and matching > max_:
                 max_ = matching
```
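For intuition, here is a minimal standalone sketch of what the new timestamp-aware branch computes, using toy windows; the names `left`, `right`, `left_ts`, and `right_ts` are illustrative stand-ins for the sliding windows the real function builds, not code from the commit:

```python
import numpy as np

# Two overlapping token windows from consecutive chunks (token ids borrowed
# from the test below: " Yes", ",", " you", " can").
left = np.array([3363, 11, 345, 460])   # tail of chunk 1
right = np.array([3363, 11, 345, 999])  # head of chunk 2; last token differs

# Per-token timestamps for the same windows.
left_ts = np.array([5.18, 5.56, 5.56, 5.84])
right_ts = np.array([5.06, 5.56, 5.80, 6.32])

# Old behaviour: every position where the token ids agree is a match.
plain_matches = np.sum(left == right)  # 3

# New behaviour: a position only counts when the token ids agree AND the
# timestamps are in order, so an id collision with an out-of-order timestamp
# can no longer inflate the overlap score.
ts_matches = sum(
    1
    for idx, tok in enumerate(left)
    if tok == right[idx] and left_ts[idx] <= right_ts[idx]
)

print(plain_matches, ts_matches)  # 3 2 -- position 0 is rejected (5.18 > 5.06)
```

The resulting `matches` count is the score fed into `matching = matches / i + eps` in the context lines above, so rejecting out-of-order positions directly changes which chunk alignment wins.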
tests/models/whisper/test_tokenization_whisper.py (36 additions)

```diff
@@ -338,6 +338,42 @@ def test_basic_normalizer(self):
         )
         self.assertEqual(decoded_output_diacritics, expected_output_diacritics)
 
+    def test_decode_asr_with_word_level_timestamps(self):
+        # fmt: off
+        model_outputs = [
+            {
+                'stride': [10, 0, 5],
+                'tokens': np.array([[ 50257, 50362, 3363, 11, 345, 460, 0, 2329, 466, 340, 0, 50256 ]]),
+                'token_timestamps': np.array([[ 0, 0, 5.18, 5.56, 5.56, 5.84, 6.36, 7.12, 7.54, 7.82, 8.16, 9.48 ]])
+            },
+            {
+                'stride': [10, 5, 0],
+                'tokens': np.array([[ 50257, 50362, 2329, 466, 340, 0, 3363, 345, 460, 0, 2329, 466, 340, 50256 ]]),
+                'token_timestamps': np.array([[ 0, 0, 0, 2.44, 4.3, 5.04, 5.06, 5.56, 5.8, 6.32, 7.12, 7.56, 7.8, 8.72 ]])
+            }
+        ]
+        # fmt: on
+
+        tokenizer = WhisperTokenizer.from_pretrained("onnx-community/whisper-tiny.en_timestamped")
+        result = tokenizer._decode_asr(
+            model_outputs, return_timestamps="word", return_language=False, time_precision=0.02
+        )
+
+        EXPECTED_OUTPUT = (
+            " Yes, you can! Just do it",
+            {
+                "chunks": [
+                    {"text": " Yes,", "timestamp": (5.18, 5.56)},
+                    {"text": " you", "timestamp": (5.56, 5.84)},
+                    {"text": " can!", "timestamp": (5.84, 7.12)},
+                    {"text": " Just", "timestamp": (7.12, 7.56)},
+                    {"text": " do", "timestamp": (7.56, 7.8)},
+                    {"text": " it", "timestamp": (7.8, 8.72)},
+                ]
+            },
+        )
+        self.assertEqual(result, EXPECTED_OUTPUT)
+
+
 class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
     checkpoint_name = "openai/whisper-small.en"
```
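For context, `_decode_asr` is a private helper that the automatic-speech-recognition pipeline calls to stitch overlapping chunks back together, so end users normally reach this code path through the pipeline. A minimal sketch of that user-facing route, assuming a local audio file `sample.wav` (hypothetical) and that `openai/whisper-tiny.en` can be downloaded:

```python
from transformers import pipeline

# chunk_length_s enables chunked long-form transcription; the overlapping
# chunk borders are merged by _find_longest_common_sequence via _decode_asr.
asr = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-tiny.en",
    chunk_length_s=30,
)

# return_timestamps="word" requests per-word timestamps, the mode this
# commit fixes.
result = asr("sample.wav", return_timestamps="word")

print(result["text"])
for chunk in result["chunks"]:
    print(chunk["text"], chunk["timestamp"])  # e.g. " Yes," (5.18, 5.56)
```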
