diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py
index fd65041595f6a4..26537fe68f912c 100644
--- a/src/transformers/models/whisper/tokenization_whisper.py
+++ b/src/transformers/models/whisper/tokenization_whisper.py
@@ -1174,7 +1174,22 @@ def _find_longest_common_sequence(sequences, token_timestamp_sequences=None):
                     "There is a bug within whisper `decode_asr` function, please report it. Dropping to prevent bad inference."
                 )
 
-            matches = np.sum(left == right)
+            if token_timestamp_sequences:
+                # Get length of longest subsequence of tokens that match
+                # and have timestamps that are in order
+                matches = sum(
+                    1
+                    for idx, elem in enumerate(left)
+                    if (
+                        elem == right[idx]
+                        and left_token_timestamp_sequence[left_start + idx]
+                        <= token_timestamp_sequences[seq_idx + 1][right_start + idx]
+                    )
+                )
+
+            else:
+                matches = np.sum(left == right)
+
             matching = matches / i + eps
             if matches > 1 and matching > max_:
                 max_ = matching
diff --git a/tests/models/whisper/test_tokenization_whisper.py b/tests/models/whisper/test_tokenization_whisper.py
index 530e23351cc09b..5c653f1984f632 100644
--- a/tests/models/whisper/test_tokenization_whisper.py
+++ b/tests/models/whisper/test_tokenization_whisper.py
@@ -338,6 +338,42 @@ def test_basic_normalizer(self):
         )
         self.assertEqual(decoded_output_diacritics, expected_output_diacritics)
 
+    def test_decode_asr_with_word_level_timestamps(self):
+        # fmt: off
+        model_outputs = [
+            {
+                'stride': [10, 0, 5],
+                'tokens': np.array([[ 50257, 50362, 3363, 11, 345, 460, 0, 2329, 466, 340, 0, 50256 ]]),
+                'token_timestamps': np.array([[ 0, 0, 5.18, 5.56, 5.56, 5.84, 6.36, 7.12, 7.54, 7.82, 8.16, 9.48 ]])
+            },
+            {
+                'stride': [10, 5, 0],
+                'tokens': np.array([[ 50257, 50362, 2329, 466, 340, 0, 3363, 345, 460, 0, 2329, 466, 340, 50256 ]]),
+                'token_timestamps': np.array([[ 0, 0, 0, 2.44, 4.3, 5.04, 5.06, 5.56, 5.8, 6.32, 7.12, 7.56, 7.8, 8.72 ]])
+            }
+        ]
+        # fmt: on
+
+        tokenizer = WhisperTokenizer.from_pretrained("onnx-community/whisper-tiny.en_timestamped")
+        result = tokenizer._decode_asr(
+            model_outputs, return_timestamps="word", return_language=False, time_precision=0.02
+        )
+
+        EXPECTED_OUTPUT = (
+            " Yes, you can! Just do it",
+            {
+                "chunks": [
+                    {"text": " Yes,", "timestamp": (5.18, 5.56)},
+                    {"text": " you", "timestamp": (5.56, 5.84)},
+                    {"text": " can!", "timestamp": (5.84, 7.12)},
+                    {"text": " Just", "timestamp": (7.12, 7.56)},
+                    {"text": " do", "timestamp": (7.56, 7.8)},
+                    {"text": " it", "timestamp": (7.8, 8.72)},
+                ]
+            },
+        )
+        self.assertEqual(result, EXPECTED_OUTPUT)
+
 
 class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
     checkpoint_name = "openai/whisper-small.en"