From e156abd05a9910106351edaa4fa856c5ba93a0b5 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 19 Jan 2024 14:04:17 +0200 Subject: [PATCH] [Whisper] Finalize batched SOTA long-form generation (#27658) * finalize * make fix copies whisper * [Tests] Make sure that we don't run tests mulitple times * Update src/transformers/models/whisper/modeling_whisper.py * [Tests] Make sure that we don't run tests mulitple times * fix more * improve * improve * improve further * improve more * improve * fix more * git commit and git push * fix more * fix more * fix more * New try * Fix more whisper stuff * Improve * correct more * correct more * correct more * Fix some tests * Add more tests * correct more * correct more * correct more * push * correct more * Fix more * Better * without dec mask * correct more * clean * save intermediate * Fix more * Fix VAD for large-v2 * Save new * Correct more * make cleaner * correct tests * correct src * Finish * Fix more * Fix more * finish * Fix edge cases * fix return_dict_in_generate * fix all tests * make style * add docstrings * add docstrings * Fix logit processor * make style * fix pipeline test * fix more style * Apply suggestions from code review * apply feedback Sanchit * correct more * Apply suggestions from code review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Joao Gante Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * correct more * correct more * correct more * Fix staticmethod * correct more * fix * fix slow tests * make style * fix tokenizer test * fix tokenizer test * Apply suggestions from code review Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * finish * finish * revert kwargs change --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: Joao Gante Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/generation/logits_process.py | 74 +- src/transformers/generation/utils.py | 2 + .../models/whisper/generation_whisper.py | 1493 +++++++++++++++++ .../models/whisper/modeling_whisper.py | 848 +--------- .../models/whisper/tokenization_whisper.py | 17 +- .../whisper/tokenization_whisper_fast.py | 17 +- tests/models/whisper/test_modeling_whisper.py | 224 ++- ..._pipelines_automatic_speech_recognition.py | 2 +- 8 files changed, 1825 insertions(+), 852 deletions(-) create mode 100644 src/transformers/models/whisper/generation_whisper.py diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py index 2b1b9f5a50b6ef..04120e39fbd27c 100644 --- a/src/transformers/generation/logits_process.py +++ b/src/transformers/generation/logits_process.py @@ -95,6 +95,7 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa scores = processor(input_ids, scores, **kwargs) else: scores = processor(input_ids, scores) + return scores @@ -1657,6 +1658,9 @@ def __init__(self, begin_suppress_tokens, begin_index): self.begin_suppress_tokens = list(begin_suppress_tokens) self.begin_index = begin_index + def set_begin_index(self, begin_index): + self.begin_index = begin_index + @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: if input_ids.shape[1] == self.begin_index: @@ -1778,6 +1782,7 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor): max_initial_timestamp_index (`int`, *optional*, defaults to 1): Used to set the maximum value of the initial timestamp. This is used to prevent the model from predicting timestamps that are too far in the future. + begin_index (`Optional`, *optional*): Token index of the first token that is generated by the model. _detect_timestamp_from_logprob (`bool`, *optional*): Whether timestamps can be predicted from logprobs over all timestamps. Examples: @@ -1810,11 +1815,11 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor): """ def __init__( - self, generate_config, _detect_timestamp_from_logprob: Optional[bool] = None + self, generate_config, begin_index: Optional[int] = None, _detect_timestamp_from_logprob: Optional[bool] = None ): # support for the kwargs - self.eos_token_id = generate_config.eos_token_id self.no_timestamps_token_id = generate_config.no_timestamps_token_id self.timestamp_begin = generate_config.no_timestamps_token_id + 1 + self.eos_token_id = generate_config.eos_token_id or generate_config.bos_token_id # this variable is mostly just used for testing self._detect_timestamp_from_logprob = ( @@ -1823,10 +1828,17 @@ def __init__( else getattr(generate_config, "_detect_timestamp_from_logprob", True) ) - self.begin_index = ( - len(generate_config.forced_decoder_ids) + 1 if generate_config.forced_decoder_ids is not None else 1 + num_forced_ids = ( + len(generate_config.forced_decoder_ids) if generate_config.forced_decoder_ids is not None else 0 ) + self.begin_index = begin_index or (num_forced_ids + 1) + self.max_initial_timestamp_index = getattr(generate_config, "max_initial_timestamp_index", None) + # TODO(Patrick): Make sure that official models have max_initial_timestamp_index set to 50 + # self.max_initial_timestamp_index = 50 + + def set_begin_index(self, begin_index): + self.begin_index = begin_index @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: @@ -1878,6 +1890,60 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to return scores +class WhisperNoSpeechDetection(LogitsProcessor): + r"""This processor can be used to detect silence when using Whisper. It should take as input unprocessed logits to follow the original implementation""" + + def __init__(self, no_speech_token: int, begin_index: int, scores_is_logprobs: bool = False): + self.no_speech_token = no_speech_token + # offset between token, , in paper and first generated token + # is equal to the position of the first generated token index + self.start_of_trans_offset = begin_index + + # `self.begin_index` is a running value that is changed on the fly + self.begin_index = begin_index + self._no_speech_prob = [0.0] + self.is_scores_logprobs = scores_is_logprobs + + # overwritten dynamically + self.model = None + self.inputs = None + + def set_model(self, model): + self.model = model + + def set_inputs(self, inputs): + self.inputs = {**self.model.prepare_inputs_for_generation(**inputs), **inputs} + self.inputs["input_features"] = self.inputs.pop("inputs") + + @property + def no_speech_prob(self): + return self._no_speech_prob + + def set_begin_index(self, begin_index): + self.begin_index = begin_index + + @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING) + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if input_ids.shape[1] == self.begin_index: + if self.start_of_trans_offset > 1: + with torch.no_grad(): + logits = self.model(**self.inputs).logits + + no_speech_index = self.begin_index - self.start_of_trans_offset + no_speech_scores = logits[:, no_speech_index] + else: + no_speech_scores = scores + + if self.is_scores_logprobs: + probs = no_speech_scores.exp() + else: + probs = no_speech_scores.float().softmax(dim=-1) + + self._no_speech_prob = probs[:, self.no_speech_token] + + return scores + + class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor): r""" [`LogitsProcessor`] for classifier free guidance (CFG). The scores are split over the batch dimension, diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index f8d1e3c3ef423d..a16c6e4b302846 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -518,6 +518,8 @@ def _prepare_decoder_input_ids_for_generation( # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token elif self.config.model_type == "vision-encoder-decoder" and "donut" in self.name_or_path.lower(): pass + elif self.config.model_type in ["whisper"]: + pass # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust # decoder_attention_mask if provided) elif (decoder_input_ids[:, 0] != decoder_start_token_id).all().item(): diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py new file mode 100644 index 00000000000000..c45fffb984b113 --- /dev/null +++ b/src/transformers/models/whisper/generation_whisper.py @@ -0,0 +1,1493 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import math +import warnings +import zlib +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from ...generation.configuration_utils import GenerationConfig +from ...generation.logits_process import ( + ForceTokensLogitsProcessor, + LogitsProcessorList, + SuppressTokensAtBeginLogitsProcessor, + SuppressTokensLogitsProcessor, + WhisperNoSpeechDetection, + WhisperTimeStampLogitsProcessor, +) +from ...generation.stopping_criteria import StoppingCriteriaList +from ...modeling_outputs import BaseModelOutput +from ...utils import logging +from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE + + +logger = logging.get_logger(__name__) + + +def _median_filter(inputs: torch.Tensor, filter_width: int) -> torch.Tensor: + """ + Applies a median filter of width `filter_width` along the last dimension of the input. + + The `inputs` tensor is assumed to be 3- or 4-dimensional. + """ + if filter_width <= 0 or filter_width % 2 != 1: + raise ValueError("`filter_width` should be an odd number") + + pad_width = filter_width // 2 + if inputs.shape[-1] <= pad_width: + return inputs + + # Pad the left and right edges. + inputs = nn.functional.pad(inputs, (pad_width, pad_width, 0, 0), mode="reflect") + + # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450) + result = inputs.unfold(-1, filter_width, 1).sort()[0][..., pad_width] + return result + + +def _dynamic_time_warping(matrix: np.ndarray): + """ + Measures similarity between two temporal sequences: the input audio and the output tokens. Used to generate + token-level timestamps. + """ + output_length, input_length = matrix.shape + cost = np.ones((output_length + 1, input_length + 1), dtype=np.float32) * np.inf + trace = -np.ones((output_length + 1, input_length + 1), dtype=np.float32) + + cost[0, 0] = 0 + for j in range(1, input_length + 1): + for i in range(1, output_length + 1): + c0 = cost[i - 1, j - 1] + c1 = cost[i - 1, j] + c2 = cost[i, j - 1] + + if c0 < c1 and c0 < c2: + c, t = c0, 0 + elif c1 < c0 and c1 < c2: + c, t = c1, 1 + else: + c, t = c2, 2 + + cost[i, j] = matrix[i - 1, j - 1] + c + trace[i, j] = t + + # backtrace + i = trace.shape[0] - 1 + j = trace.shape[1] - 1 + trace[0, :] = 2 + trace[:, 0] = 1 + + text_indices = [] + time_indices = [] + while i > 0 or j > 0: + text_indices.append(i - 1) + time_indices.append(j - 1) + if trace[i, j] == 0: + i -= 1 + j -= 1 + elif trace[i, j] == 1: + i -= 1 + elif trace[i, j] == 2: + j -= 1 + else: + raise RuntimeError( + f"Internal error in dynamic time warping. Unexpected trace[{i}, {j}]. Please file a bug report." + ) + + text_indices = np.array(text_indices)[::-1] + time_indices = np.array(time_indices)[::-1] + return text_indices, time_indices + + +def _get_attr_from_logit_processors(logits_processor, logit_processor_class, attribute_name): + logit_processor = next((cls for cls in logits_processor if isinstance(cls, logit_processor_class)), None) + if logit_processor: + return getattr(logit_processor, attribute_name, None) + return None + + +def _pad_to_max_length(current_segments, pad_token_id, padding="right", bos_token_tensor=None, cut_off_length=None): + max_total_length = 0 + sequences = [] + if padding not in ["right", "left"]: + raise ValueError(f"`padding` must be either 'right' or 'left', not {padding}") + + for current_segment_list in current_segments: + if current_segment_list is not None and len([d["tokens"] for d in current_segment_list]) > 0: + sequence = torch.cat([d["tokens"] for d in current_segment_list], dim=-1) + + if cut_off_length is not None: + sequence = sequence[-cut_off_length:] + + if bos_token_tensor is not None: + sequence = torch.cat([bos_token_tensor, sequence]) + + sequences.append(sequence) + max_total_length = max(max_total_length, len(sequences[-1])) + else: + sequences.append(bos_token_tensor) + + for i in range(len(current_segments)): + pad_length = max_total_length - len(sequences[i]) + pad = (0, pad_length) if padding == "right" else (pad_length, 0) + sequences[i] = F.pad(sequences[i], pad=pad, value=pad_token_id) + + sequences = torch.stack(sequences, dim=0) + return sequences + + +class WhisperGenerationMixin: + def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None): + """ + Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to + map each output token to a position in the input audio. If `num_frames` is specified, the encoder-decoder + cross-attentions will be cropped before applying DTW. + + Returns: + tensor containing the timestamps in seconds for each predicted token + """ + # Create a list with `decoder_layers` elements, each a tensor of shape + # (batch size, attention_heads, output length, input length). + cross_attentions = [] + for i in range(self.config.decoder_layers): + cross_attentions.append(torch.cat([x[i] for x in generate_outputs.cross_attentions], dim=2)) + + # Select specific cross-attention layers and heads. This is a tensor + # of shape (batch size, num selected, output length, input length). + weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads]) + weights = weights.permute([1, 0, 2, 3]) + + if "beam_indices" in generate_outputs: + # If beam search has been used, the output sequences may have been generated for more timesteps than their sequence_lengths + # since the beam search strategy chooses the most probable sequences at the end of the search. + # In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length + weight_length = (generate_outputs.beam_indices != -1).sum(-1).max() + weights = weights[:, :, :weight_length] + + # If beam index is still -1, it means that the associated token id is EOS + # We need to replace the index with 0 since index_select gives an error if any of the indexes is -1. + beam_indices = generate_outputs.beam_indices[:, :weight_length] + beam_indices = beam_indices.masked_fill(beam_indices == -1, 0) + + # Select the cross attention from the right beam for each output sequences + weights = torch.stack( + [ + torch.index_select(weights[:, :, i, :], dim=0, index=beam_indices[:, i]) + for i in range(beam_indices.shape[1]) + ], + dim=2, + ) + + timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32) + batch_size = timestamps.shape[0] + + if num_frames is not None: + # two cases: + # 1. num_frames is the same for each sample -> compute the DTW matrix for each sample in parallel + # 2. num_frames is different, compute the DTW matrix for each sample sequentially + + # we're using np.unique because num_frames can be int/list/tuple + if len(np.unique(num_frames)) == 1: + # if num_frames is the same, no need to recompute matrix, std and mean for each element of the batch + num_frames = num_frames if isinstance(num_frames, int) else num_frames[0] + + weights = weights[..., : num_frames // 2] + else: + # num_frames is of shape (batch_size,) whereas batch_size is truely batch_size*num_return_sequences + repeat_time = batch_size if isinstance(num_frames, int) else batch_size // len(num_frames) + num_frames = np.repeat(num_frames, repeat_time) + + if num_frames is None or isinstance(num_frames, int): + # Normalize and smoothen the weights. + std = torch.std(weights, dim=-2, keepdim=True, unbiased=False) + mean = torch.mean(weights, dim=-2, keepdim=True) + weights = (weights - mean) / std + weights = _median_filter(weights, self.config.median_filter_width) + + # Average the different cross-attention heads. + weights = weights.mean(dim=1) + + # Perform dynamic time warping on each element of the batch. + for batch_idx in range(batch_size): + if num_frames is not None and isinstance(num_frames, (tuple, list, np.ndarray)): + matrix = weights[batch_idx, ..., : num_frames[batch_idx] // 2] + + # Normalize and smoothen the weights. + std = torch.std(matrix, dim=-2, keepdim=True, unbiased=False) + mean = torch.mean(matrix, dim=-2, keepdim=True) + matrix = (matrix - mean) / std + matrix = _median_filter(matrix, self.config.median_filter_width) + + # Average the different cross-attention heads. + matrix = matrix.mean(dim=0) + else: + matrix = weights[batch_idx] + + text_indices, time_indices = _dynamic_time_warping(-matrix.cpu().double().numpy()) + jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool) + jump_times = time_indices[jumps] * time_precision + timestamps[batch_idx, 1:] = torch.tensor(jump_times) + + return timestamps + + def generate( + self, + input_features: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + synced_gpus: bool = False, + return_timestamps: Optional[bool] = None, + task: Optional[str] = None, + language: Optional[str] = None, + is_multilingual: Optional[bool] = None, + prompt_ids: Optional[torch.Tensor] = None, + condition_on_prev_tokens: Optional[bool] = None, + temperature: Optional[Union[float, Tuple[float, ...]]] = None, + compression_ratio_threshold: Optional[float] = None, + logprob_threshold: Optional[float] = None, + no_speech_threshold: Optional[float] = None, + num_segment_frames: Optional[int] = None, + attention_mask: Optional[torch.Tensor] = None, + time_precision: float = 0.02, + return_token_timestamps: Optional[bool] = None, + return_segments: bool = False, + return_dict_in_generate: Optional[bool] = None, + **kwargs, + ): + """ + Transcribes or translates log-mel input features to a sequence of auto-regressively generated token ids. + + + + Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the + model's default generation configuration. You can override any `generation_config` by passing the corresponding + parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. + + For an overview of generation strategies and code examples, check out the [following + guide](./generation_strategies). + + + + Parameters: + input_features (`torch.Tensor` of shape `(batch_size, feature_size, sequence_length)`, *optional*): + Float values of log-mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by + loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via + the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the + [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a + tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details. + generation_config (`~generation.GenerationConfig`, *optional*): + The generation configuration to be used as base parametrization for the generation call. `**kwargs` + passed to generate matching the attributes of `generation_config` will override them. If + `generation_config` is not provided, the default will be used, which had the following loading + priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model + configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s + default values, whose documentation should be checked to parameterize generation. + logits_processor (`LogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + stopping_criteria (`StoppingCriteriaList`, *optional*): + Custom stopping criteria that complement the default stopping criteria built from arguments and a + generation config. If a stopping criteria is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. + prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): + If provided, this function constraints the beam search to allowed tokens only at each step. If not + provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and + `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned + on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful + for constrained generation conditioned on the prefix, as described in [Autoregressive Entity + Retrieval](https://arxiv.org/abs/2010.00904). + synced_gpus (`bool`, *optional*, defaults to `False`): + Whether to continue running the while loop until max_length (needed for ZeRO stage 3) + return_timestamps (`bool`, *optional*): + Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`. + task (`str`, *optional*): + Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids` + will be updated accordingly. + language (`str`, *optional*): + Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. You can + find all the possible language tokens in the `model.generation_config.lang_to_id` dictionary. + is_multilingual (`bool`, *optional*): + Whether or not the model is multilingual. + prompt_ids (`torch.Tensor`, *optional*): + Rank-1 tensor of token IDs created by passing text to [`~WhisperProcessor.get_prompt_ids`] that is + provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for + transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words + correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value. + condition_on_prev_tokens (`bool`, *optional*): + Only relevant for long-form transcription. Whether to condition each segment on the previous segment. + As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve + performance. + temperature (`float` or list of `float`, *optional*): + The temperature to be used for generation. Passing a single `float` value and `do_sample=True` activates + generation using sampling. For long-form transcription, temperature fallback can be activated by passing + a list of float values such as (0.0, 0.2, 0.4, 0.6, 0.8, 1.0). As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve + performance. + compression_ratio_threshold (`float`, *optional*): + Only relevant for long-form transcription. If defined, the zlib compression rate of each segment will be computed. If the compression rate of + a segment is higher than `compression_ratio_threshold`, temperature fallback is activated: the generated segment is discarded and the generation is + repeated using a higher temperature. The intuition behind this feature is that segments with very high compression rates + suffer from a lot of repetition. The unwanted repetition can be reduced by injecting more randomness by increasing the temperature. If `compression_ratio_threshold` is defined + make sure that `temperature` is a list of values. A common value for `compression_ratio_threshold` is 1.35. + As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve + performance. + logprob_threshold (`float`, *optional*): + Only relevant for long-form transcription. If defined, the average log-probability of each segment will be computed. If the log-probability of + a given segment is lower than `logprob_threshold`, temperature fallback is activated: the generated segment is discarded and the generation is + repeated using a higher temperature. The intuition behind this feature is that segments of low log-probability + can be improved by injecting more randomness by increasing the temperature. If `logprob_threshold` is defined + make sure that `temperature` is a list of values. A common value for `logprob_threshold` is -1.0. + As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve + performance. + no_speech_threshold (`float`, *optional*): + Only relevant for long-form transcription. If defined, the "no-speech" token combined with the `logprob_threshold` + is used to determine whether a segment contains only silence. In this case, the transcription for this segment + is skipped. + As shown in the [the Whisper paper](https://cdn.openai.com/papers/whisper.pdf), this can help to improve + performance. + num_segment_frames (`int`, *optional*): + The number of frames a single segment is made of. If not defined, `num_segment_frames` defaults to the model's stride + times the maximum input length. + attention_mask (`torch.Tensor`, *optional*): + `attention_mask` needs to be passed when doing long-form transcription using a batch size > 1. + time_precision (`int`, *optional*, defaults to 0.02): + The duration of output token in seconds. *E.g.* 0.02 means that a generated token on average accounts + for 20 ms. + return_token_timestamps (`bool`, *optional*): + Whether to return token-level timestamps with the text. This can be used with or without the + `return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into + words. + return_segments (`bool`, *optional*, defaults to `False`): + Whether to additionally return a list of all segments. Note that this option can only be enabled + when doing long-form transcription. + return_dict_in_generate (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of just returning the generated tokens. + Note that when doing long-form transcription, `return_dict_in_generate` can only be enabled when + `return_segments` is set True. In this case the generation outputs of each segment is added to each + segment. + kwargs (`Dict[str, Any]`, *optional*): + Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be + forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder + specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. + + Return: + [`~utils.ModelOutput`] or `torch.LongTensor` or `Dict[str, Any]`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` + or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor` or a dict of segments when `return_segments=True`. + + If the passed input is > 30 seconds / > 3000 mel input features and `return_segments=True` then a dictionary of generated sequence ids, called `sequences` and a list of each generated segment is returned. + + else if the passed input is <= 30 seconds / >= 3000 mel input features, the possible [`~utils.ModelOutput`] types are: + + - [`~generation.GenerateEncoderDecoderOutput`], + - [`~generation.GenerateBeamEncoderDecoderOutput`] + + else only the generated output sequence ids are returned. + + Example: + + - *Longform transcription*: To transcribe or translate audios longer than 30 seconds, process the audio files without truncation and pass all mel features at once to generate. + + ```python + >>> import torch + >>> from transformers import AutoProcessor, WhisperForConditionalGeneration + >>> from datasets import load_dataset, Audio + + >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en") + >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") + >>> model.cuda() + + >>> # load audios > 30 seconds + >>> ds = load_dataset("distil-whisper/meanwhile", "default")["test"] + >>> # resample to 16kHz + >>> ds = ds.cast_column("audio", Audio(sampling_rate=16000)) + >>> # take first 8 audios and retrieve array + >>> audio = ds[:8]["audio"] + >>> audio = [x["array"] for x in audio] + + >>> # make sure to NOT truncate the input audio, to return the `attention_mask` and to pad to the longest audio + >>> inputs = processor(audio, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True, sampling_rate=16_000) + >>> inputs = inputs.to("cuda", torch.float32) + + >>> # transcribe audio to ids + >>> generated_ids = model.generate(**inputs) + + >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True) + >>> transcription[0] + ' Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile!' + ``` + + - *Shortform transcription*: If passed mel input features are < 30 seconds, the whole audio will be transcribed with a single call to generate. + + ```python + >>> import torch + >>> from transformers import AutoProcessor, WhisperForConditionalGeneration + >>> from datasets import load_dataset + + >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en") + >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + + >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt") + >>> input_features = inputs.input_features + + >>> generated_ids = model.generate(inputs=input_features) + + >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + >>> transcription + ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' + ``` + + """ + # 0. deprecate old inputs + if "inputs" in kwargs: + input_features = kwargs.pop("inputs") + warnings.warn( + "The input name `inputs` is deprecated. Please make sure to use `input_features` instead.", + FutureWarning, + ) + # 1. copy generation config + if generation_config is None: + generation_config = copy.deepcopy(self.generation_config) + else: + generation_config = copy.deepcopy(generation_config) + + # 2. set global generate variables + input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0] + num_segment_frames = input_stride * self.config.max_source_positions + total_input_frames = self._retrieve_total_input_frames( + input_features=input_features, input_stride=input_stride, kwargs=kwargs + ) + is_shortform = total_input_frames <= num_segment_frames + + if is_shortform: + # warn user of ignored inputs + self._maybe_warn_unused_inputs( + condition_on_prev_tokens=condition_on_prev_tokens, + temperature=temperature, + compression_ratio_threshold=compression_ratio_threshold, + logprob_threshold=logprob_threshold, + no_speech_threshold=no_speech_threshold, + total_input_frames=total_input_frames, + ) + + # 3. Make sure generation config is correctly set + # Make sure the generation config is correctly set depending on whether timestamps are to be returned or not + self._set_return_outputs( + return_dict_in_generate=return_dict_in_generate, + return_token_timestamps=return_token_timestamps, + is_shortform=is_shortform, + logprob_threshold=logprob_threshold, + generation_config=generation_config, + ) + self._set_return_timestamps( + return_timestamps=return_timestamps, is_shortform=is_shortform, generation_config=generation_config + ) + self._set_language_and_task( + language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config + ) + # pass self.config for backward compatibility + self._set_forced_decoder_ids( + task=task, + language=language, + prompt_ids=prompt_ids, + generation_config=generation_config, + config=self.config, + kwargs=kwargs, + ) + self._set_token_ids(generation_config=generation_config, config=self.config, kwargs=kwargs) + self._set_num_frames( + return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs + ) + self._set_thresholds_and_condition( + generation_config=generation_config, + logprob_threshold=logprob_threshold, + compression_ratio_threshold=compression_ratio_threshold, + no_speech_threshold=no_speech_threshold, + condition_on_prev_tokens=condition_on_prev_tokens, + ) + + # 4. Retrieve logits processors + logits_processor = self._retrieve_logit_processors( + generation_config=generation_config, + logits_processor=logits_processor, + no_speech_threshold=no_speech_threshold, + is_shortform=is_shortform, + num_beams=kwargs.get("num_beams", 1), + ) + + # 5. If we're in shortform mode, simple generate the whole input at once and return the output + if is_shortform: + if temperature is not None: + kwargs["temperature"] = temperature + + outputs = super().generate( + input_features, + generation_config=generation_config, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + synced_gpus=synced_gpus, + **kwargs, + ) + + if generation_config.return_token_timestamps and hasattr(generation_config, "alignment_heads"): + outputs["token_timestamps"] = self._extract_token_timestamps( + outputs, generation_config.alignment_heads, num_frames=generation_config.num_frames + ) + + return outputs + + # 6. Else we're in longform mode which is more complex. + # We need to chunk the audio input depending on when the model generates timestamp tokens + + # 6.1 Set and retrieve global longform generation variables + self._set_condition_on_prev_tokens( + condition_on_prev_tokens=condition_on_prev_tokens, generation_config=generation_config + ) + + timestamp_begin = generation_config.no_timestamps_token_id + 1 + temperatures = [temperature] if not isinstance(temperature, (list, tuple)) else temperature + temperature = temperatures[0] + batch_size = input_features.shape[0] + + max_frames, seek = self._retrieve_max_frames_and_seek( + batch_size=batch_size, attention_mask=attention_mask, total_input_frames=total_input_frames + ) + init_tokens = self._retrieve_init_tokens_from_forced_decoder_ids(generation_config=generation_config) + + # 6.2 Preppare running variables, list for generation + cur_bsz = batch_size + current_segments = [[] for _ in range(batch_size)] + batch_idx_map = list(range(batch_size)) + do_condition_on_prev_tokens = [condition_on_prev_tokens for _ in range(batch_size)] + + # 6.2 Transcribe audio until we reach the end of all input audios + while (seek < max_frames).any(): + # 6.3 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop + # in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order + # to know which original audio is being decoded + # Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk + input_features, cur_bsz, batch_idx_map = self._maybe_reduce_batch( + input_features=input_features, + seek=seek, + max_frames=max_frames, + cur_bsz=cur_bsz, + batch_idx_map=batch_idx_map, + ) + time_offset = seek * time_precision / input_stride + seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames) + + # 6.4 cut out next 30s segment from input features + segment_input = self._get_input_segment( + input_features=input_features, + seek=seek, + seek_num_frames=seek_num_frames, + num_segment_frames=num_segment_frames, + cur_bsz=cur_bsz, + batch_idx_map=batch_idx_map, + ) + + # 6.5 prepare decoder input ids + suppress_tokens = _get_attr_from_logit_processors( + logits_processor, SuppressTokensLogitsProcessor, "suppress_tokens" + ) + decoder_input_ids, kwargs = self._prepare_decoder_input_ids( + cur_bsz=cur_bsz, + init_tokens=init_tokens, + current_segments=current_segments, + batch_idx_map=batch_idx_map, + do_condition_on_prev_tokens=do_condition_on_prev_tokens, + generation_config=generation_config, + config=self.config, + device=segment_input.device, + suppress_tokens=suppress_tokens, + kwargs=kwargs, + ) + + # 6.6 set max new tokens or max length + kwargs = self._set_max_new_tokens_and_length( + config=self.config, + decoder_input_ids=decoder_input_ids, + generation_config=generation_config, + kwargs=kwargs, + ) + + # 6.7 Set current `begin_index` for all logit processors + for proc in logits_processor: + if hasattr(proc, "set_begin_index"): + proc.set_begin_index(decoder_input_ids.shape[-1]) + + # 6.8 Run generate with fallback + seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens = self.generate_with_fallback( + segment_input=segment_input, + decoder_input_ids=decoder_input_ids, + cur_bsz=cur_bsz, + batch_idx_map=batch_idx_map, + seek=seek, + num_segment_frames=num_segment_frames, + max_frames=max_frames, + temperatures=temperatures, + generation_config=generation_config, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + synced_gpus=synced_gpus, + return_token_timestamps=return_token_timestamps, + do_condition_on_prev_tokens=do_condition_on_prev_tokens, + kwargs=kwargs, + ) + + # 6.9 In every generated sequence, split by timestamp tokens and extract segments + for i, seek_sequence in enumerate(seek_sequences): + prev_i = batch_idx_map[i] + + if should_skip[i]: + seek[prev_i] += seek_num_frames[prev_i] + continue + + segments, segment_offset = self._retrieve_segment( + seek_sequence=seek_sequence, + seek_outputs=seek_outputs, + time_offset=time_offset, + timestamp_begin=timestamp_begin, + seek_num_frames=seek_num_frames, + time_precision=time_precision, + input_stride=input_stride, + prev_idx=prev_i, + idx=i, + ) + + current_segments[prev_i] += segments + seek[prev_i] += segment_offset + + # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted + # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output + sequences = _pad_to_max_length(current_segments, generation_config.pad_token_id, padding="right") + + # 8. If we return all segments, the predicted output sequences are put under `"sequences"`. + if return_segments: + return {"sequences": sequences, "segments": current_segments} + + return sequences + + def generate_with_fallback( + self, + segment_input, + decoder_input_ids, + cur_bsz, + batch_idx_map, + seek, + num_segment_frames, + max_frames, + temperatures, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + return_token_timestamps, + do_condition_on_prev_tokens, + kwargs, + ): + # 6.6 Batch generate current chunk + seek_sequence_list = [None for _ in range(cur_bsz)] + seek_outputs_list = [None for _ in range(cur_bsz)] + needs_fallback = [False for _ in range(cur_bsz)] + should_skip = [False for _ in range(cur_bsz)] + fallback_index_map = list(range(cur_bsz)) + + if generation_config.no_speech_threshold is not None: + self._setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs) + + for fallback_idx, temperature in enumerate(temperatures): + generation_config.do_sample = temperature is not None and temperature > 0.0 + generation_config.temperature = temperature + generation_config.num_beams = kwargs.pop("num_beams", 1) if not generation_config.do_sample else 1 + + seek_outputs = super().generate( + segment_input, + generation_config, + logits_processor, + stopping_criteria, + prefix_allowed_tokens_fn, + synced_gpus, + decoder_input_ids=decoder_input_ids, + **kwargs, + ) + + # post-process sequence tokens and outputs to be in list form + sequence_tokens, seek_outputs = self._postprocess_outputs( + seek_outputs, return_token_timestamps, generation_config + ) + + # remove all previously passed decoder input ids + seek_sequences = sequence_tokens[:, decoder_input_ids.shape[-1] :] + + # 6.7 Extract cut sequences from every sequence and check if fallback should be applied + # Loop over each decoded audio individually as each decoding can be of a different length + new_fallback_index_map = [] + new_segment_input = [] + new_decoder_input_ids = [] + new_decoder_attention_mask = [] + + for i, seek_sequence in enumerate(seek_sequences): + # make sure we cut a predicted EOS token if we are not finished with the generation yet + prev_i = batch_idx_map[fallback_index_map[i]] + is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i] + + # remove eos token id + if is_not_final and seek_sequence[-1] == generation_config.eos_token_id: + seek_sequence = seek_sequence[:-1] + + # remove all padding tokens + if seek_sequence[-1] == generation_config.pad_token_id: + num_paddings = (seek_sequence == generation_config.pad_token_id).sum() + seek_sequence = seek_sequence[:-num_paddings] + + # check which sequences in batch need fallback & which should be skipped + needs_fallback[i], should_skip[i] = self._need_fallback( + seek_sequence, + seek_outputs, + i, + logits_processor, + generation_config, + self.config.vocab_size, + temperature, + ) + + seek_sequence_list[fallback_index_map[i]] = seek_sequence + seek_outputs_list[fallback_index_map[i]] = seek_outputs[i] + do_condition_on_prev_tokens[fallback_index_map[i]] = ( + generation_config.condition_on_prev_tokens and temperature is not None and temperature < 0.5 + ) + + if needs_fallback[i]: + new_fallback_index_map.append(fallback_index_map[i]) + new_segment_input.append(segment_input[i]) + new_decoder_input_ids.append(decoder_input_ids[i]) + if "decoder_attention_mask" in kwargs: + new_decoder_attention_mask.append(kwargs["decoder_attention_mask"][i]) + + fallback_index_map = new_fallback_index_map + + # if no sequence needs to be run with temperature fallback, we're finished + if len(fallback_index_map) == 0 or fallback_idx == len(temperatures) - 1: + seek_sequences = seek_sequence_list + seek_outputs = seek_outputs_list + break + + # if we're still in the loop, make sure that decoder_input_ids and segment inputs are tensors + decoder_input_ids = torch.stack(new_decoder_input_ids) + segment_input = torch.stack(new_segment_input) + if "decoder_attention_mask" in kwargs: + kwargs["decoder_attention_mask"] = torch.stack(new_decoder_attention_mask) + + return seek_sequences, seek_outputs, should_skip, do_condition_on_prev_tokens + + def _postprocess_outputs(self, seek_outputs, return_token_timestamps, generation_config): + if return_token_timestamps and hasattr(generation_config, "alignment_heads"): + num_frames = getattr(generation_config, "num_frames", None) + seek_outputs["token_timestamps"] = self._extract_token_timestamps( + seek_outputs, generation_config.alignment_heads, num_frames=num_frames + ) + + if generation_config.return_dict_in_generate: + + def split_by_batch_index(values, key, batch_idx): + if key == "scores": + return [v[batch_idx].cpu() for v in values] + if key == "past_key_values": + # we don't save `past_key_values` as this is too costly + return None + return values[batch_idx].cpu() + + sequence_tokens = seek_outputs["sequences"] + seek_outputs = [ + {k: split_by_batch_index(v, k, i) for k, v in seek_outputs.items()} + for i in range(sequence_tokens.shape[0]) + ] + else: + sequence_tokens = seek_outputs + + return sequence_tokens, seek_outputs + + def _need_fallback( + self, + seek_sequence, + seek_outputs, + index, + logits_processor, + generation_config, + vocab_size, + temperature, + ): + needs_fallback = False + should_skip = False + if generation_config.compression_ratio_threshold is not None: + compression_ratio = self._retrieve_compression_ratio(seek_sequence, vocab_size) + + if compression_ratio > generation_config.compression_ratio_threshold: + needs_fallback = True + + if generation_config.logprob_threshold is not None: + if "sequences_scores" in seek_outputs[0]: + logprobs = [s["sequences_scores"] for s in seek_outputs][index] + else: + scores = seek_outputs[index]["scores"] + logprobs = self._retrieve_avg_logprobs( + scores, seek_sequence, generation_config.eos_token_id, temperature + ) + + if logprobs < generation_config.logprob_threshold: + needs_fallback = True + + if generation_config.no_speech_threshold is not None: + no_speech_prob = _get_attr_from_logit_processors( + logits_processor, WhisperNoSpeechDetection, "no_speech_prob" + ) + + if ( + logprobs < generation_config.logprob_threshold + and no_speech_prob[index] > generation_config.no_speech_threshold + ): + needs_fallback = False + should_skip = True + + return needs_fallback, should_skip + + @staticmethod + def _setup_no_speech_detection(logits_processor, segment_input, decoder_input_ids, kwargs): + set_inputs = _get_attr_from_logit_processors(logits_processor, WhisperNoSpeechDetection, "set_inputs") + extra_kwargs = {k: v for k, v in kwargs.items() if torch.is_tensor(v)} + set_inputs({"inputs": segment_input, "decoder_input_ids": decoder_input_ids, **extra_kwargs}) + + @staticmethod + def _retrieve_total_input_frames(input_features, input_stride, kwargs): + if input_features is not None: + return input_features.shape[-1] + + if "encoder_outputs" in kwargs: + encoder_outputs_shape = ( + kwargs["encoder_outputs"][0].shape + if isinstance(kwargs["encoder_outputs"], BaseModelOutput) + else kwargs["encoder_outputs"].shape + ) + return encoder_outputs_shape[1] * input_stride + + raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `generate`.") + + @staticmethod + def _maybe_warn_unused_inputs( + condition_on_prev_tokens, + temperature, + compression_ratio_threshold, + logprob_threshold, + no_speech_threshold, + total_input_frames, + ): + warning_prefix = ( + f"Audio input consists of only {total_input_frames}. " + "Short-form transcription is activated." + "{}, but will be ignored." + ) + if condition_on_prev_tokens is not None: + logger.warn(warning_prefix.format(f"condition_on_prev_tokens is set to {condition_on_prev_tokens}")) + + if compression_ratio_threshold is not None: + logger.warn(warning_prefix.format(f"compression_ratio_threshold is set to {compression_ratio_threshold}")) + + if logprob_threshold is not None: + logger.warn(warning_prefix.format(f"logprob_threshold is set to {logprob_threshold}")) + + if no_speech_threshold is not None: + logger.warn(warning_prefix.format(f"no_speech_threshold is set to {no_speech_threshold}")) + + # when passing temperature as a list it cannot just be ignored => throw error in this case + if isinstance(temperature, (list, tuple)): + raise ValueError( + f"Audio input consists of only {total_input_frames}. Short-form transcription is activated." + f"temperature cannot be set to {temperature} which can only be used for temperature fallback for long-form generation. Make sure to set `temperature` to a float value or `None` for short-form generation." + ) + + @staticmethod + def _set_return_outputs( + return_dict_in_generate, return_token_timestamps, is_shortform, logprob_threshold, generation_config + ): + if return_dict_in_generate is None: + return_dict_in_generate = generation_config.return_dict_in_generate + + generation_config.return_token_timestamps = return_token_timestamps + if return_token_timestamps: + return_dict_in_generate = True + generation_config.output_attentions = True + generation_config.output_scores = True + + if not is_shortform and logprob_threshold is not None: + return_dict_in_generate = True + generation_config.output_scores = True + + generation_config.return_dict_in_generate = return_dict_in_generate + + @staticmethod + def _set_return_timestamps(return_timestamps, is_shortform, generation_config): + if return_timestamps is True: + if not hasattr(generation_config, "no_timestamps_token_id"): + raise ValueError( + "You are trying to return timestamps, but the generation config is not properly set. " + "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. " + "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363" + ) + generation_config.return_timestamps = True + elif not is_shortform: + if return_timestamps is False: + raise ValueError( + "You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which " + "requires the model to predict timestamp tokens. Please either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features." + ) + + if not hasattr(generation_config, "no_timestamps_token_id"): + raise ValueError( + "You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which " + "requires the generation config to have `no_timestamps_token_id` correctly. " + "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. " + "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363" + "or make sure to pass no more than 3000 mel input features." + ) + + logger.info("Setting `return_timestamps=True` for long-form generation.") + generation_config.return_timestamps = True + else: + generation_config.return_timestamps = False + + @staticmethod + def _set_language_and_task(language, task, is_multilingual, generation_config): + if is_multilingual is not None: + if not hasattr(generation_config, "is_multilingual"): + raise ValueError( + "The generation config is outdated and is thus not compatible with the `is_multilingual` argument " + "to `generate`. Please update the generation config as per the instructions " + "https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224" + ) + generation_config.is_multilingual = is_multilingual + + if hasattr(generation_config, "is_multilingual") and not generation_config.is_multilingual: + if task is not None or language is not None: + raise ValueError( + "Cannot specify `task` or `language` for an English-only model. If the model is intended to be " + "multilingual, pass `is_multilingual=True` to generate, or update the generation config." + ) + + if language is not None: + if not hasattr(generation_config, "lang_to_id"): + raise ValueError( + "The generation config is outdated and is thus not compatible with the `language` argument " + "to `generate`. Either set the language using the `forced_decoder_ids` in the model config, " + "or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224" + ) + language = language.lower() + generation_config.language = language + + if task is not None: + if not hasattr(generation_config, "task_to_id"): + raise ValueError( + "The generation config is outdated and is thus not compatible with the `task` argument " + "to `generate`. Either set the task using the `forced_decoder_ids` in the model config, " + "or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224" + ) + generation_config.task = task + + @staticmethod + def _set_forced_decoder_ids(task, language, prompt_ids, generation_config, config, kwargs): + forced_decoder_ids = None + # Legacy code for backward compatibility + if hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None: + forced_decoder_ids = config.forced_decoder_ids + elif hasattr(generation_config, "forced_decoder_ids") and generation_config.forced_decoder_ids is not None: + forced_decoder_ids = generation_config.forced_decoder_ids + else: + forced_decoder_ids = kwargs.pop("forced_decoder_ids", None) + + if task is not None or language is not None or (forced_decoder_ids is None and prompt_ids is not None): + forced_decoder_ids = [] + if hasattr(generation_config, "language"): + if generation_config.language in generation_config.lang_to_id.keys(): + language_token = generation_config.language + elif generation_config.language in TO_LANGUAGE_CODE.keys(): + language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>" + elif generation_config.language in TO_LANGUAGE_CODE.values(): + language_token = f"<|{generation_config.language}|>" + else: + is_language_code = len(generation_config.language) == 2 + raise ValueError( + f"Unsupported language: {generation_config.language}. Language should be one of:" + f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}." + ) + if language_token not in generation_config.lang_to_id: + raise ValueError( + f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`." + "(You should just add it to the generation config)" + ) + forced_decoder_ids.append((1, generation_config.lang_to_id[language_token])) + else: + forced_decoder_ids.append((1, None)) # automatically detect the language + + if hasattr(generation_config, "task"): + if generation_config.task in TASK_IDS: + forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task])) + else: + raise ValueError( + f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`" + ) + elif hasattr(generation_config, "task_to_id"): + forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) # defaults to transcribe + if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps: + idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1 + forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id)) + + if forced_decoder_ids is not None: + generation_config.forced_decoder_ids = forced_decoder_ids + + if prompt_ids is not None: + if kwargs.get("decoder_start_token_id") is not None: + raise ValueError( + "When specifying `prompt_ids`, you cannot also specify `decoder_start_token_id` as it gets overwritten." + ) + prompt_ids = prompt_ids.tolist() + decoder_start_token_id, *text_prompt_ids = prompt_ids + # Slicing the text prompt ids in a manner consistent with the OpenAI implementation + # to accomodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599) + text_prompt_ids = text_prompt_ids[-config.max_target_positions // 2 - 1 :] + # Set the decoder_start_token_id to <|startofprev|> + kwargs.update({"decoder_start_token_id": decoder_start_token_id}) + + # If the user passes `max_new_tokens`, increase its number to account for the prompt + if kwargs.get("max_new_tokens", None) is not None: + kwargs["max_new_tokens"] += len(text_prompt_ids) + if kwargs["max_new_tokens"] >= config.max_target_positions: + raise ValueError( + f"The length of the sliced `prompt_ids` is {len(text_prompt_ids)}, and the `max_new_tokens` " + f"{kwargs['max_new_tokens'] - len(text_prompt_ids)}. Thus, the combined length of the sliced " + f"`prompt_ids` and `max_new_tokens` is: {kwargs['max_new_tokens']}. This exceeds the " + f"`max_target_positions` of the Whisper model: {config.max_target_positions}. " + "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, " + f"so that their combined length is less that {config.max_target_positions}." + ) + + # Reformat the forced_decoder_ids to incorporate the prompt + non_prompt_forced_decoder_ids = ( + kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids + ) + forced_decoder_ids = [ + *text_prompt_ids, + generation_config.decoder_start_token_id, + *[token for _, token in non_prompt_forced_decoder_ids], + ] + forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)] + generation_config.forced_decoder_ids = forced_decoder_ids + + @staticmethod + def _set_token_ids(generation_config, config, kwargs): + eos_token_id = kwargs.pop("eos_token_id", None) + decoder_start_token_id = kwargs.pop("decoder_start_token_id", None) + + eos_token_id = eos_token_id if eos_token_id is not None else generation_config.eos_token_id + decoder_start_token_id = ( + decoder_start_token_id if decoder_start_token_id is not None else generation_config.decoder_start_token_id + ) + + generation_config.eos_token_id = eos_token_id if eos_token_id is not None else config.eos_token_id + generation_config.decoder_start_token_id = ( + decoder_start_token_id if decoder_start_token_id is not None else config.decoder_start_token_id + ) + + @staticmethod + def _set_num_frames(return_token_timestamps, generation_config, kwargs): + if return_token_timestamps: + if getattr(generation_config, "task", None) == "translate": + logger.warning("Token-level timestamps may not be reliable for task 'translate'.") + if not hasattr(generation_config, "alignment_heads"): + raise ValueError( + "Model generation config has no `alignment_heads`, token-level timestamps not available. " + "See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config." + ) + + generation_config.num_frames = kwargs.pop("num_frames", None) + + @staticmethod + def _set_thresholds_and_condition( + generation_config, + logprob_threshold, + compression_ratio_threshold, + no_speech_threshold, + condition_on_prev_tokens, + ): + generation_config.logprob_threshold = ( + logprob_threshold + if logprob_threshold is not None + else getattr(generation_config, "logprob_threshold", None) + ) + generation_config.compression_ratio_threshold = ( + compression_ratio_threshold + if compression_ratio_threshold is not None + else getattr(generation_config, "compression_ratio_threshold", None) + ) + generation_config.no_speech_threshold = ( + no_speech_threshold + if no_speech_threshold is not None + else getattr(generation_config, "no_speech_threshold", None) + ) + generation_config.condition_on_prev_tokens = ( + condition_on_prev_tokens + if condition_on_prev_tokens is not None + else getattr(generation_config, "condition_on_prev_tokens", None) + ) + + @staticmethod + def _set_condition_on_prev_tokens(condition_on_prev_tokens, generation_config): + condition_on_prev_tokens = ( + condition_on_prev_tokens + if condition_on_prev_tokens is not None + else getattr(generation_config, "condition_on_prev_tokens", False) + ) + generation_config.condition_on_prev_tokens = condition_on_prev_tokens + + @staticmethod + def _retrieve_max_frames_and_seek(batch_size, attention_mask, total_input_frames): + if batch_size > 1 and attention_mask is None: + raise ValueError( + "When doing long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` " + ) + elif batch_size > 1: + max_frames = attention_mask.sum(-1).cpu().to(torch.long) + seek = torch.zeros((batch_size,), dtype=torch.long) + else: + max_frames = torch.ones((1,), dtype=torch.long) * total_input_frames + seek = torch.zeros((1,), dtype=torch.long) + + return max_frames, seek + + @staticmethod + def _retrieve_init_tokens_from_forced_decoder_ids(generation_config): + init_tokens = [generation_config.decoder_start_token_id] + forced_decoder_ids = generation_config.forced_decoder_ids + if forced_decoder_ids is not None and forced_decoder_ids[0][0] == 1: + i = 1 + while len(forced_decoder_ids) > 0 and forced_decoder_ids[0][0] == i: + init_tokens += [forced_decoder_ids[0][1]] + forced_decoder_ids = forced_decoder_ids[1:] + i += 1 + + forced_decoder_ids = forced_decoder_ids if len(forced_decoder_ids) > 0 else None + generation_config.forced_decoder_ids = forced_decoder_ids + + return init_tokens + + def _retrieve_logit_processors( + self, generation_config, logits_processor, no_speech_threshold, is_shortform, num_beams + ): + forced_decoder_ids = generation_config.forced_decoder_ids + if generation_config.return_timestamps is True: + last_forced_decoder_ids = forced_decoder_ids[-1][-1] if forced_decoder_ids is not None else None + if last_forced_decoder_ids == generation_config.no_timestamps_token_id: + # remove no_timestamp to be forcefully generated if we want to return timestamps + # this is also important to make sure `WhisperTimeStampLogitsProcessor` functions correctly + forced_decoder_ids = forced_decoder_ids[:-1] if len(forced_decoder_ids) > 1 else None + # Make sure that if list is empty we set it to None + generation_config.forced_decoder_ids = forced_decoder_ids + + begin_index = len(forced_decoder_ids) + 1 if forced_decoder_ids is not None else 1 + + if generation_config.return_timestamps is True: + timestamp_processor = WhisperTimeStampLogitsProcessor(generation_config, begin_index=begin_index) + logits_processor = ( + [timestamp_processor] if logits_processor is None else [timestamp_processor] + logits_processor + ) + + if generation_config.suppress_tokens is not None: + suppress_tokens_processor = SuppressTokensLogitsProcessor(generation_config.suppress_tokens) + logits_processor = ( + [suppress_tokens_processor] + if logits_processor is None + else [suppress_tokens_processor] + logits_processor + ) + generation_config.suppress_tokens = None + + if generation_config.begin_suppress_tokens is not None: + begin_suppress_processor = SuppressTokensAtBeginLogitsProcessor( + generation_config.begin_suppress_tokens, begin_index=begin_index + ) + logits_processor = ( + [begin_suppress_processor] + if logits_processor is None + else [begin_suppress_processor] + logits_processor + ) + generation_config.begin_suppress_tokens = None + + if no_speech_threshold is not None and not is_shortform: + no_speech_detector = WhisperNoSpeechDetection( + no_speech_token=generation_config.no_timestamps_token_id - 1, + begin_index=begin_index, + scores_is_logprobs=num_beams > 1, + ) + logits_processor = ( + [no_speech_detector] if logits_processor is None else [no_speech_detector] + logits_processor + ) + no_speech_detector.set_model(self) + + if is_shortform and generation_config.forced_decoder_ids is not None: + forced_tokens_proc = ForceTokensLogitsProcessor(generation_config.forced_decoder_ids) + # TODO(Patrick): It's important that the `forced_tokens_proc` processor is appended after + # the suppress_tokens processor or else it might happen that all token logits are suppressed to -inf + # which would lead to unexpected behavior + # The better approach here is to NOT make use of the `forced_tokens_proc` for Whisper and instead + # initialize all of them as `decoder_input_ids`. + logits_processor = ( + [forced_tokens_proc] if logits_processor is None else logits_processor + [forced_tokens_proc] + ) + generation_config.forced_decoder_ids = None + + return logits_processor + + @staticmethod + def _maybe_reduce_batch(input_features, seek, max_frames, cur_bsz, batch_idx_map): + prev_bsz = cur_bsz + new_batch_idx_map = [] + for i in range(prev_bsz): + prev_i = batch_idx_map[i] + if seek[prev_i] >= max_frames[prev_i]: + cut_index = i + (cur_bsz - prev_bsz) + cur_bsz -= 1 + input_features = torch.cat([input_features[:cut_index], input_features[cut_index + 1 :]], dim=0) + else: + # cut out index that goes away + new_batch_idx_map.append(prev_i) + + return input_features, cur_bsz, new_batch_idx_map + + @staticmethod + def _get_input_segment(input_features, seek, seek_num_frames, num_segment_frames, cur_bsz, batch_idx_map): + segment_input = [] + for i in range(cur_bsz): + prev_i = batch_idx_map[i] + segment_input_slice = input_features[i : i + 1, :, seek[prev_i] : seek[prev_i] + seek_num_frames[prev_i]] + + if segment_input_slice.shape[-1] < num_segment_frames: + # pad to 3000 if necessary + segment_input_slice = F.pad( + segment_input_slice, pad=(0, num_segment_frames - segment_input_slice.shape[-1]) + ) + + segment_input.append(segment_input_slice) + + segment_input = torch.cat(segment_input, dim=0) + + return segment_input + + @staticmethod + def _prepare_decoder_input_ids( + cur_bsz, + init_tokens, + current_segments, + batch_idx_map, + do_condition_on_prev_tokens, + generation_config, + config, + device, + suppress_tokens, + kwargs, + ): + cut_off_length = config.max_target_positions // 2 - 1 + + one_tensor = torch.ones((cur_bsz, 1), device=device, dtype=torch.long) + decoder_input_ids = torch.cat([t * one_tensor for t in init_tokens], dim=-1) + + prev_start_of_text = getattr(generation_config, "prev_sot_token_id", None) + if prev_start_of_text is None: + prev_start_of_text = suppress_tokens[-2] if suppress_tokens is not None else None + + if any(do_condition_on_prev_tokens) and len(current_segments[0]) > 0: + # according to https://github.com/openai/whisper/blob/e58f28804528831904c3b6f2c0e473f346223433/whisper/decoding.py#L609 + active_segments = [current_segments[i] if do_condition_on_prev_tokens[i] else None for i in batch_idx_map] + prev_start_of_text = getattr(generation_config, "prev_bos_token_id", None) or prev_start_of_text + + bos_token_tensor = prev_start_of_text * one_tensor[0] + prev_tokens = _pad_to_max_length( + active_segments, + generation_config.pad_token_id, + padding="left", + bos_token_tensor=bos_token_tensor, + cut_off_length=cut_off_length, + ) + decoder_input_ids = torch.cat([prev_tokens, decoder_input_ids], dim=-1) + + kwargs["decoder_attention_mask"] = decoder_input_ids != generation_config.pad_token_id + else: + # make sure `"decoder_attention_mask"` is not passed to forward + kwargs.pop("decoder_attention_mask", None) + + return decoder_input_ids, kwargs + + @staticmethod + def _set_max_new_tokens_and_length(config, decoder_input_ids, generation_config, kwargs): + num_initial_tokens = min(config.max_target_positions // 2 - 1, decoder_input_ids.shape[-1] - 1) + + passed_max_length = kwargs.pop("max_length", None) + passed_max_new_tokens = kwargs.pop("max_new_tokens", None) + max_length_config = getattr(generation_config, "max_length", None) + max_new_tokens_config = getattr(generation_config, "max_new_tokens", None) + + max_new_tokens = None + max_length = None + + # Make sure we don't get larger than `max_length` + if passed_max_length is not None and passed_max_new_tokens is None: + max_length = min(passed_max_length + num_initial_tokens, config.max_target_positions) + logger.info( + f"Increase max_length from {passed_max_length} to {max_length} since input is conditioned on previous segment." + ) + elif max_length_config is not None and passed_max_new_tokens is None and max_new_tokens_config is None: + max_length = min(generation_config.max_length + num_initial_tokens, config.max_target_positions) + logger.info( + f"Increase max_length from {max_length_config} to {max_length} since input is conditioned on previous segment." + ) + elif ( + passed_max_new_tokens is not None + and passed_max_new_tokens + decoder_input_ids.shape[-1] > config.max_target_positions + ): + max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1] + elif ( + passed_max_new_tokens is None + and max_new_tokens_config is not None + and max_new_tokens_config + decoder_input_ids.shape[-1] > config.max_target_positions + ): + max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1] + + if max_new_tokens is not None: + kwargs["max_new_tokens"] = max_new_tokens + + if max_length is not None: + kwargs["max_length"] = max_length + + return kwargs + + @staticmethod + def _retrieve_compression_ratio(tokens, vocab_size): + """Compute byte length of zlib compressed token bytes vs. byte length of raw token bytes""" + length = int(math.log2(vocab_size) / 8) + 1 + token_bytes = b"".join([t.to_bytes(length, "little") for t in tokens.tolist()]) + compression_ratio = len(token_bytes) / len(zlib.compress(token_bytes)) + + return compression_ratio + + @staticmethod + def _retrieve_avg_logprobs(scores, tokens, eos_token_id, temperature): + rescale_temperature = temperature if temperature > 0.0 else 1 + scores = torch.stack(scores).to(tokens.device) + + if scores.shape[0] > tokens.shape[0]: + scores = scores[: tokens.shape[0]] + else: + tokens = tokens[-scores.shape[0] :] + + logprobs = F.log_softmax((scores * rescale_temperature).float(), dim=-1).to(scores.dtype) + + # retrieve logprob of selected tokens and sum + sum_logprobs = sum((logprobs[i][tokens[i]] * (tokens[i] != eos_token_id)) for i in range(logprobs.shape[0])) + length = (tokens != eos_token_id).sum(-1) if eos_token_id is not None else tokens.shape[0] + + avg_logprobs = sum_logprobs / (length + 1) + return avg_logprobs + + @staticmethod + def _retrieve_segment( + seek_sequence, + seek_outputs, + time_offset, + timestamp_begin, + seek_num_frames, + time_precision, + input_stride, + prev_idx, + idx, + ): + # find the predicted "end of segment" predictions of Whisper + # "end of segment" predictions occur whenever Whisper predicts a timestamp token + timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin) + single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True] + timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + timestamp_segment_indices.add_(1) + + # If whisper predicted a "end of segment" via a timestep token, let's go ever each + # "end of segment" prediction and slice the decoding into segments accordingly + if len(timestamp_segment_indices) > 0: + # if the output contains two consecutive timestamp tokens + slices = timestamp_segment_indices.tolist() + segments = [] + if single_timestamp_ending: + slices.append(len(seek_sequence)) + + last_slice = 0 + # Add each segment to list of all segments + for current_slice in slices: + sliced_tokens = seek_sequence[last_slice:current_slice] + start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin + end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin + segments.append( + { + "start": time_offset[prev_idx] + start_timestamp_pos * time_precision, + "end": time_offset[prev_idx] + end_timestamp_pos * time_precision, + "tokens": sliced_tokens, + "result": seek_outputs[idx], + } + ) + last_slice = current_slice + + if single_timestamp_ending: + # single timestamp at the end means no speech after the last timestamp. + segment_offset = seek_num_frames[prev_idx] + else: + # otherwise, ignore the unfinished segment and seek to the last timestamp + # here we throw away all predictions after the last predicted "end of segment" + # since we are cutting right in the middle of an audio + last_timestamp_pos = seek_sequence[last_slice - 1].item() - timestamp_begin + segment_offset = last_timestamp_pos * input_stride + else: + # If whisper does not predict any "end of segment" token, then + # the whole decoding is considered a segment and we add it to the list of segments + timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()] + last_timestamp_pos = seek_num_frames[prev_idx] + if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin: + # no consecutive timestamps but it has a timestamp; use the last one. + last_timestamp_pos = timestamps[-1].item() - timestamp_begin + + segments = [ + { + "start": time_offset[prev_idx], + "end": time_offset[prev_idx] + last_timestamp_pos * time_precision, + "tokens": seek_sequence, + "result": seek_outputs[idx], + } + ] + segment_offset = seek_num_frames[prev_idx] + + return segments, segment_offset diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 1e68f4f63e9aff..76ea27a954a84a 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -13,10 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ PyTorch Whisper model.""" - -import copy import math -import warnings from typing import Optional, Tuple, Union import numpy as np @@ -27,7 +24,6 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...generation.logits_process import WhisperTimeStampLogitsProcessor from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_outputs import ( BaseModelOutput, @@ -47,7 +43,7 @@ replace_return_docstrings, ) from .configuration_whisper import WhisperConfig -from .tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE +from .generation_whisper import WhisperGenerationMixin if is_flash_attn_2_available(): @@ -231,87 +227,15 @@ def compute_num_masked_span(input_length): return spec_aug_mask -def _median_filter(inputs: torch.Tensor, filter_width: int) -> torch.Tensor: - """ - Applies a median filter of width `filter_width` along the last dimension of the input. - - The `inputs` tensor is assumed to be 3- or 4-dimensional. - """ - if filter_width <= 0 or filter_width % 2 != 1: - raise ValueError("`filter_width` should be an odd number") - - pad_width = filter_width // 2 - if inputs.shape[-1] <= pad_width: - return inputs - - # Pad the left and right edges. - inputs = nn.functional.pad(inputs, (pad_width, pad_width, 0, 0), mode="reflect") - - # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450) - result = inputs.unfold(-1, filter_width, 1).sort()[0][..., pad_width] - return result - - -def _dynamic_time_warping(matrix: np.ndarray): - """ - Measures similarity between two temporal sequences: the input audio and the output tokens. Used to generate - token-level timestamps. - """ - output_length, input_length = matrix.shape - cost = np.ones((output_length + 1, input_length + 1), dtype=np.float32) * np.inf - trace = -np.ones((output_length + 1, input_length + 1), dtype=np.float32) - - cost[0, 0] = 0 - for j in range(1, input_length + 1): - for i in range(1, output_length + 1): - c0 = cost[i - 1, j - 1] - c1 = cost[i - 1, j] - c2 = cost[i, j - 1] - - if c0 < c1 and c0 < c2: - c, t = c0, 0 - elif c1 < c0 and c1 < c2: - c, t = c1, 1 - else: - c, t = c2, 2 - - cost[i, j] = matrix[i - 1, j - 1] + c - trace[i, j] = t - - # backtrace - i = trace.shape[0] - 1 - j = trace.shape[1] - 1 - trace[0, :] = 2 - trace[:, 0] = 1 - - text_indices = [] - time_indices = [] - while i > 0 or j > 0: - text_indices.append(i - 1) - time_indices.append(j - 1) - if trace[i, j] == 0: - i -= 1 - j -= 1 - elif trace[i, j] == 1: - i -= 1 - elif trace[i, j] == 2: - j -= 1 - else: - raise RuntimeError( - f"Internal error in dynamic time warping. Unexpected trace[{i}, {j}]. Please file a bug report." - ) - - text_indices = np.array(text_indices)[::-1] - time_indices = np.array(time_indices)[::-1] - return text_indices, time_indices - - class WhisperPositionalEmbedding(nn.Embedding): def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): super().__init__(num_positions, embedding_dim) - def forward(self, input_ids, past_key_values_length=0): - return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]] + def forward(self, input_ids, past_key_values_length=0, position_ids=None): + if position_ids is None: + return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]] + else: + return self.weight[position_ids] class WhisperAttention(nn.Module): @@ -1358,6 +1282,7 @@ def forward( cross_attn_head_mask=None, past_key_values=None, inputs_embeds=None, + position_ids=None, use_cache=None, output_attentions=None, output_hidden_states=None, @@ -1461,9 +1386,13 @@ def forward( # embed positions if input_ids is not None: - positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) + positions = self.embed_positions( + input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids + ) else: - positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length) + positions = self.embed_positions( + inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids + ) hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -1645,6 +1574,7 @@ def forward( encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, + decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -1703,6 +1633,7 @@ def forward( cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, inputs_embeds=decoder_inputs_embeds, + position_ids=decoder_position_ids, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -1728,7 +1659,7 @@ def forward( "The Whisper Model with a language modeling head. Can be used for automatic speech recognition.", WHISPER_START_DOCSTRING, ) -class WhisperForConditionalGeneration(WhisperPreTrainedModel): +class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedModel): base_model_prefix = "model" _tied_weights_keys = ["proj_out.weight"] @@ -1776,6 +1707,7 @@ def forward( encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, + decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -1830,6 +1762,7 @@ def forward( cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, decoder_inputs_embeds=decoder_inputs_embeds, + decoder_position_ids=decoder_position_ids, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -1860,647 +1793,6 @@ def forward( encoder_attentions=outputs.encoder_attentions, ) - def generate( - self, - input_features: Optional[torch.Tensor] = None, - generation_config=None, - logits_processor=None, - stopping_criteria=None, - prefix_allowed_tokens_fn=None, - synced_gpus=False, - return_timestamps=None, - task=None, - language=None, - is_multilingual=None, - prompt_ids: Optional[torch.Tensor] = None, - num_segment_frames: Optional[int] = None, - return_token_timestamps: Optional[bool] = None, - return_segments: bool = False, - attention_mask: Optional[torch.Tensor] = None, - time_precision: int = 0.02, - return_dict_in_generate: Optional[bool] = None, - **kwargs, - ): - """ - Transcribes or translates passed mel input features to a sequence of token ids. - - - - Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the - model's default generation configuration. You can override any `generation_config` by passing the corresponding - parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`. - - For an overview of generation strategies and code examples, check out the [following - guide](./generation_strategies). - - - - Parameters: - inputs (`torch.Tensor` of varying shape depending on the modality, *optional*): - The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the - method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs` - should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of - `input_ids`, `input_values`, `input_features`, or `pixel_values`. - generation_config (`~generation.GenerationConfig`, *optional*): - The generation configuration to be used as base parametrization for the generation call. `**kwargs` - passed to generate matching the attributes of `generation_config` will override them. If - `generation_config` is not provided, the default will be used, which had the following loading - priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model - configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s - default values, whose documentation should be checked to parameterize generation. - logits_processor (`LogitsProcessorList`, *optional*): - Custom logits processors that complement the default logits processors built from arguments and - generation config. If a logit processor is passed that is already created with the arguments or a - generation config an error is thrown. This feature is intended for advanced users. - stopping_criteria (`StoppingCriteriaList`, *optional*): - Custom stopping criteria that complement the default stopping criteria built from arguments and a - generation config. If a stopping criteria is passed that is already created with the arguments or a - generation config an error is thrown. This feature is intended for advanced users. - prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*): - If provided, this function constraints the beam search to allowed tokens only at each step. If not - provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and - `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned - on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful - for constrained generation conditioned on the prefix, as described in [Autoregressive Entity - Retrieval](https://arxiv.org/abs/2010.00904). - synced_gpus (`bool`, *optional*, defaults to `False`): - Whether to continue running the while loop until max_length (needed for ZeRO stage 3) - return_timestamps (`bool`, *optional*): - Whether to return the timestamps with the text. This enables the `WhisperTimestampsLogitsProcessor`. - task (`str`, *optional*): - Task to use for generation, either "translate" or "transcribe". The `model.config.forced_decoder_ids` - will be updated accordingly. - language (`str`, *optional*): - Language token to use for generation, can be either in the form of `<|en|>`, `en` or `english`. You can - find all the possible language tokens in the `model.generation_config.lang_to_id` dictionary. - is_multilingual (`bool`, *optional*): - Whether or not the model is multilingual. - prompt_ids (`torch.Tensor`, *optional*): - Rank-1 tensor of token IDs created by passing text to [`~WhisperProcessor.get_prompt_ids`] that is - provided as a prompt to each chunk. This can be used to provide or "prompt-engineer" a context for - transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those words - correctly. It cannot be used in conjunction with `decoder_start_token_id` as it overwrites this value. - return_token_timestamps (`bool`, *optional*): - Whether to return token-level timestamps with the text. This can be used with or without the - `return_timestamps` option. To get word-level timestamps, use the tokenizer to group the tokens into - words. - return_segments (`bool`, *optional*, defaults to `False`): - Whether to additionally return a list of all segments. Note that this option can only be enabled - when doing long-form transcription. - attention_mask (`torch.Tensor`, *optional*): - `attention_mask` needs to be passed when doing long-form transcription using a batch size > 1. - time_precision (`int`, *optional*, defaults to 0.02): - The duration of output token in seconds. *E.g.* 0.02 means that a generated token on average accounts - for 20 ms. - return_dict_in_generate (`bool`, *optional*, defaults to `False`): - Whether or not to return a [`~utils.ModelOutput`] instead of just returning the generated tokens. - Note that when doing long-form transcription, `return_dict_in_generate` can only be enabled when - `return_segments` is set True. In this case the generation outputs of each segment is added to each - segment. - kwargs (`Dict[str, Any]`, *optional*): - Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be - forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder - specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. - - Return: - [`~utils.ModelOutput`] or `torch.LongTensor` or `Dict[str, Any]`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` - or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor` or a dict of segments when `return_segments=True`. - - If the passed input is > 30 seconds / > 3000 mel input features and `return_segments=True` then a dictionary of generated sequence ids, called `sequences` and a list of each generated segment is returned. - - else if the passed input is <= 30 seconds / >= 3000 mel input features, the possible [`~utils.ModelOutput`] types are: - - - [`~generation.GenerateEncoderDecoderOutput`], - - [`~generation.GenerateBeamEncoderDecoderOutput`] - - else only the generated output sequence ids are returned. - - Example: - - - *Longform transcription*: To transcribe or translate audios longer than 30 seconds, process the audio files without truncation and pass all mel features at once to generate. - - ```python - >>> import torch - >>> from transformers import AutoProcessor, WhisperForConditionalGeneration - >>> from datasets import load_dataset, Audio - - >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en") - >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") - >>> model.cuda() - - >>> # load audios > 30 seconds - >>> ds = load_dataset("distil-whisper/meanwhile", "default")["test"] - >>> # resample to 16kHz - >>> ds = ds.cast_column("audio", Audio(sampling_rate=16000)) - >>> # take first 8 audios and retrieve array - >>> audio = ds[:8]["audio"] - >>> audio = [x["array"] for x in audio] - - >>> # make sure to NOT truncate the input audio, to return the `attention_mask` and to pad to the longest audio - >>> inputs = processor(audio, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True, sampling_rate=16_000) - >>> inputs = inputs.to("cuda", torch.float32) - - >>> # transcribe audio to ids - >>> generated_ids = model.generate(**inputs) - - >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True) - >>> transcription[0] - ' Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile!' - ``` - - - *Shortform transcription*: If passed mel input features are < 30 seconds, the whole audio will be transcribed with a single call to generate. - - ```python - >>> import torch - >>> from transformers import AutoProcessor, WhisperForConditionalGeneration - >>> from datasets import load_dataset - - >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en") - >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") - - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - - >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt") - >>> input_features = inputs.input_features - - >>> generated_ids = model.generate(inputs=input_features) - - >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] - >>> transcription - ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' - ``` - - """ - - if "inputs" in kwargs: - input_features = kwargs.pop("inputs") - warnings.warn( - "The input name `inputs` is deprecated. Please make sure to use `input_features` instead.", - FutureWarning, - ) - - if generation_config is None: - generation_config = copy.deepcopy(self.generation_config) - - return_dict_in_generate = ( - return_dict_in_generate - if return_dict_in_generate is not None - else generation_config.return_dict_in_generate - ) - - input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0] - if num_segment_frames is None: - num_segment_frames = input_stride * self.config.max_source_positions - - # 1. Check whether we're in shortform or longform mode - if input_features is not None: - total_input_frames = input_features.shape[-1] - elif "encoder_outputs" in kwargs: - encoder_outputs_shape = ( - kwargs["encoder_outputs"][0].shape - if isinstance(kwargs["encoder_outputs"], BaseModelOutput) - else kwargs["encoder_outputs"].shape - ) - total_input_frames = encoder_outputs_shape[1] * input_stride - else: - raise ValueError("Make sure to provide either `input_features` or `encoder_outputs` to `generate`.") - - is_shortform = total_input_frames <= num_segment_frames - - # 2. Make sure the generation config is correctly set depending on whether timestamps are to be returned or not - if return_timestamps is True: - if not hasattr(generation_config, "no_timestamps_token_id"): - raise ValueError( - "You are trying to return timestamps, but the generation config is not properly set. " - "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. " - "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363" - ) - generation_config.return_timestamps = return_timestamps - elif not is_shortform: - if return_timestamps is False: - raise ValueError( - "You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which " - "requires the model to predict timestamp tokens. Please either pass `return_timestamps=True` or make sure to pass no more than 3000 mel input features." - ) - - if not hasattr(generation_config, "no_timestamps_token_id"): - raise ValueError( - "You have passed more than 3000 mel input features (> 30 seconds) which automatically enables long-form generation which " - "requires the generation config to have `no_timestamps_token_id` correctly. " - "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`. " - "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363" - "or make sure to pass no more than 3000 mel input features." - ) - - logger.info("Setting `return_timestamps=True` for long-form generation.") - generation_config.return_timestamps = True - else: - generation_config.return_timestamps = False - - # 3. Make sure to correctly set language-related parameters - if is_multilingual is not None: - if not hasattr(generation_config, "is_multilingual"): - raise ValueError( - "The generation config is outdated and is thus not compatible with the `is_multilingual` argument " - "to `generate`. Please update the generation config as per the instructions " - "https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224" - ) - generation_config.is_multilingual = is_multilingual - - if hasattr(generation_config, "is_multilingual") and not generation_config.is_multilingual: - if task is not None or language is not None: - raise ValueError( - "Cannot specify `task` or `language` for an English-only model. If the model is intended to be " - "multilingual, pass `is_multilingual=True` to generate, or update the generation config." - ) - - if language is not None: - if not hasattr(generation_config, "lang_to_id"): - raise ValueError( - "The generation config is outdated and is thus not compatible with the `language` argument " - "to `generate`. Either set the language using the `forced_decoder_ids` in the model config, " - "or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224" - ) - language = language.lower() - generation_config.language = language - if task is not None: - if not hasattr(generation_config, "task_to_id"): - raise ValueError( - "The generation config is outdated and is thus not compatible with the `task` argument " - "to `generate`. Either set the task using the `forced_decoder_ids` in the model config, " - "or update the generation config as per the instructions https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224" - ) - generation_config.task = task - - # 4. Add forced decoder ids depending on passed `language`, `task`,`prompt_ids`, `return_token_timestamps` and `return_timestamps` - forced_decoder_ids = None - # Legacy code for backward compatibility - if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None: - forced_decoder_ids = self.config.forced_decoder_ids - elif ( - hasattr(self.generation_config, "forced_decoder_ids") - and self.generation_config.forced_decoder_ids is not None - ): - forced_decoder_ids = self.generation_config.forced_decoder_ids - else: - forced_decoder_ids = kwargs.get("forced_decoder_ids", None) - - if task is not None or language is not None or (forced_decoder_ids is None and prompt_ids is not None): - forced_decoder_ids = [] - if hasattr(generation_config, "language"): - if generation_config.language in generation_config.lang_to_id.keys(): - language_token = generation_config.language - elif generation_config.language in TO_LANGUAGE_CODE.keys(): - language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>" - elif generation_config.language in TO_LANGUAGE_CODE.values(): - language_token = f"<|{generation_config.language}|>" - else: - is_language_code = len(generation_config.language) == 2 - raise ValueError( - f"Unsupported language: {generation_config.language}. Language should be one of:" - f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}." - ) - if language_token not in generation_config.lang_to_id: - raise ValueError( - f"{language_token} is not supported by this specific model as it is not in the `generation_config.lang_to_id`." - "(You should just add it to the generation config)" - ) - forced_decoder_ids.append((1, generation_config.lang_to_id[language_token])) - else: - forced_decoder_ids.append((1, None)) # automatically detect the language - - if hasattr(generation_config, "task"): - if generation_config.task in TASK_IDS: - forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task])) - else: - raise ValueError( - f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`" - ) - elif hasattr(generation_config, "task_to_id"): - forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) # defaults to transcribe - if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps: - idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1 - forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id)) - - if forced_decoder_ids is not None: - generation_config.forced_decoder_ids = forced_decoder_ids - - if prompt_ids is not None: - if kwargs.get("decoder_start_token_id") is not None: - raise ValueError( - "When specifying `prompt_ids`, you cannot also specify `decoder_start_token_id` as it gets overwritten." - ) - prompt_ids = prompt_ids.tolist() - decoder_start_token_id, *text_prompt_ids = prompt_ids - # Slicing the text prompt ids in a manner consistent with the OpenAI implementation - # to accomodate context space for the prefix (see https://github.com/openai/whisper/blob/c09a7ae299c4c34c5839a76380ae407e7d785914/whisper/decoding.py#L599) - text_prompt_ids = text_prompt_ids[-self.config.max_target_positions // 2 - 1 :] - # Set the decoder_start_token_id to <|startofprev|> - kwargs.update({"decoder_start_token_id": decoder_start_token_id}) - - # If the user passes `max_new_tokens`, increase its number to account for the prompt - if kwargs.get("max_new_tokens", None) is not None: - kwargs["max_new_tokens"] += len(text_prompt_ids) - if kwargs["max_new_tokens"] >= self.config.max_target_positions: - raise ValueError( - f"The length of the sliced `prompt_ids` is {len(text_prompt_ids)}, and the `max_new_tokens` " - f"{kwargs['max_new_tokens'] - len(text_prompt_ids)}. Thus, the combined length of the sliced " - f"`prompt_ids` and `max_new_tokens` is: {kwargs['max_new_tokens']}. This exceeds the " - f"`max_target_positions` of the Whisper model: {self.config.max_target_positions}. " - "You should either reduce the length of your prompt, or reduce the value of `max_new_tokens`, " - f"so that their combined length is less that {self.config.max_target_positions}." - ) - - # Reformat the forced_decoder_ids to incorporate the prompt - non_prompt_forced_decoder_ids = ( - kwargs.pop("forced_decoder_ids", None) or generation_config.forced_decoder_ids - ) - forced_decoder_ids = [ - *text_prompt_ids, - generation_config.decoder_start_token_id, - *[token for _rank, token in non_prompt_forced_decoder_ids], - ] - forced_decoder_ids = [(rank + 1, token) for rank, token in enumerate(forced_decoder_ids)] - generation_config.forced_decoder_ids = forced_decoder_ids - - if return_token_timestamps: - kwargs["output_attentions"] = True - return_dict_in_generate = True - kwargs["output_scores"] = True - - if getattr(generation_config, "task", None) == "translate": - logger.warning("Token-level timestamps may not be reliable for task 'translate'.") - if not hasattr(generation_config, "alignment_heads"): - raise ValueError( - "Model generation config has no `alignment_heads`, token-level timestamps not available. " - "See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config." - ) - - if kwargs.get("num_frames") is not None: - generation_config.num_frames = kwargs.pop("num_frames") - - if generation_config.return_timestamps is True: - last_forced_decoder_ids = ( - generation_config.forced_decoder_ids[-1][-1] - if hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids - else None - ) - if last_forced_decoder_ids == self.generation_config.no_timestamps_token_id: - # remove no_timestamp to be forcefully generated if we want to return timestamps - # this is also important to make sure `WhisperTimeStampLogitsProcessor` functions correctly - forced_decoder_ids = generation_config.forced_decoder_ids[:-1] - # Make sure that if list is empty we set it to None - generation_config.forced_decoder_ids = None if len(forced_decoder_ids) == 0 else forced_decoder_ids - - timestamp_processor = [WhisperTimeStampLogitsProcessor(generation_config)] - logits_processor = ( - timestamp_processor if logits_processor is None else timestamp_processor + logits_processor - ) - - # 5. If we're in shortform mode, simple generate the whole input at once and return the output - if is_shortform: - outputs = super().generate( - input_features, - generation_config, - logits_processor, - stopping_criteria, - prefix_allowed_tokens_fn, - synced_gpus, - return_dict_in_generate=return_dict_in_generate, - **kwargs, - ) - - if return_token_timestamps and hasattr(generation_config, "alignment_heads"): - num_frames = getattr(generation_config, "num_frames", None) - outputs["token_timestamps"] = self._extract_token_timestamps( - outputs, generation_config.alignment_heads, num_frames=num_frames - ) - - return outputs - - # 6. Else we're in longform mode which is more complex. We need to chunk the audio input depending on when the model generated - # timestamp tokens - # 6.1 Set running parameters for while loop - if not return_segments and return_dict_in_generate: - raise ValueError( - "Make sure to set `return_segments=True` to return generation outputs as part of the `'segments' key.`" - ) - - # if input is longer than 30 seconds we default to long-form generation - timestamp_begin = self.generation_config.no_timestamps_token_id + 1 - # input stride is mel frames per encoder output vector which is the product of all conv strides - batch_size = input_features.shape[0] - - if batch_size > 1 and attention_mask is None: - raise ValueError( - "When doing long-form audio transcription, make sure to pass an `attention_mask`. You can retrieve the `attention_mask` by doing `processor(audio, ..., return_attention_mask=True)` " - ) - elif batch_size > 1: - max_frames = attention_mask.sum(-1).cpu().to(torch.long) - seek = torch.zeros((batch_size,), dtype=torch.long) - else: - max_frames = torch.ones((1,), dtype=torch.long) * total_input_frames - seek = torch.zeros((1,), dtype=torch.long) - - current_segments = [[] for _ in range(batch_size)] - cur_to_prev_index_map = list(range(batch_size)) - - # batch size can decrease during the run - cur_bsz = prev_bsz = batch_size - - # 6.2 Transcribe audio until we reach the end of all input audios - while (seek < max_frames).any(): - prev_bsz = cur_bsz - - # 6.3 NOTE: When in longform transcription mode and batch size > 1 we need to dynamically reduce the batch size during the loop - # in case one audio finished earlier than another one. Thus, we need to keep a table of "previous-index-2-current-index" in order - # to know which original audio is being decoded - new_cur_to_prev_index_map = [] - for i in range(prev_bsz): - prev_i = cur_to_prev_index_map[i] - if seek[prev_i] >= max_frames[prev_i]: - cut_index = i + (cur_bsz - prev_bsz) - cur_bsz -= 1 - input_features = torch.cat([input_features[:cut_index], input_features[cut_index + 1 :]], dim=0) - else: - # cut out index that goes away - new_cur_to_prev_index_map.append(prev_i) - - # 6.4 Set updated index map, duration of previously decoded chunks and number of max frames of current decoding chunk - cur_to_prev_index_map = new_cur_to_prev_index_map - time_offset = seek * time_precision / input_stride - seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames) - - # 6.5 Make sure that all inputs are padded to the same input length - segment_input = [] - for i in range(cur_bsz): - prev_i = cur_to_prev_index_map[i] - segment_input_slice = input_features[ - i : i + 1, :, seek[prev_i] : seek[prev_i] + seek_num_frames[prev_i] - ] - - if segment_input_slice.shape[-1] < num_segment_frames: - # pad to 3000 if necessary - segment_input_slice = F.pad( - segment_input_slice, pad=(0, num_segment_frames - segment_input_slice.shape[-1]) - ) - - segment_input.append(segment_input_slice) - - segment_input = torch.cat(segment_input, dim=0) - - # 6.6 Batch generate current chunk - seek_outputs = super().generate( - segment_input, - generation_config, - logits_processor, - stopping_criteria, - prefix_allowed_tokens_fn, - synced_gpus, - return_dict_in_generate=return_dict_in_generate, - **kwargs, - ) - - if return_token_timestamps and hasattr(generation_config, "alignment_heads"): - num_frames = getattr(generation_config, "num_frames", None) - seek_outputs["token_timestamps"] = self._extract_token_timestamps( - seek_outputs, generation_config.alignment_heads, num_frames=num_frames - ) - - if return_dict_in_generate: - seek_sequences = seek_outputs["sequences"] - seek_outputs = [ - {k: v[i] for k, v in seek_outputs.items()} - for i in range(next(iter(seek_outputs.values())).size(0)) - ] - else: - seek_sequences = seek_outputs - - # 6.7 Loop over each decoded audio individually as each decoding can be of a different length - for i, seek_sequence in enumerate(seek_sequences): - prev_i = cur_to_prev_index_map[i] - - # make sure we cut a predicted EOS token if we are not finished with the generation yet - is_not_final = (seek[prev_i] + num_segment_frames) < max_frames[prev_i] - if is_not_final and seek_sequence[-1] == self.generation_config.eos_token_id: - seek_sequence = seek_sequence[:-1] - - # remove all padding tokens - if seek_sequence[-1] == self.generation_config.pad_token_id: - num_paddings = (seek_sequence == self.generation_config.pad_token_id).sum() - seek_sequence = seek_sequence[:-num_paddings] - - segments, segment_offset = self._retrieve_segment( - seek_sequence=seek_sequence, - seek_outputs=seek_outputs, - time_offset=time_offset, - timestamp_begin=timestamp_begin, - seek_num_frames=seek_num_frames, - cur_bsz=cur_bsz, - time_precision=time_precision, - input_stride=input_stride, - prev_idx=prev_i, - idx=i, - ) - - current_segments[prev_i] += segments - seek[prev_i] += segment_offset - - # 7. Once all segments are added to the list of all segments, called `current_segments`, we extract the predicted - # output tokens from the list of dicts. If we use batch size > 1, we make sure to pad the output - sequences = [] - max_total_length = 0 - for current_segment_list in current_segments: - sequences.append(torch.cat([d["tokens"] for d in current_segment_list], dim=-1)) - max_total_length = max(max_total_length, len(sequences[-1])) - - for i in range(batch_size): - sequences[i] = F.pad( - sequences[i], pad=(0, max_total_length - len(sequences[i])), value=self.generation_config.pad_token_id - ) - - sequences = torch.stack(sequences, dim=0) - - # 8. If we return all segments, the predicted output sequences are put under `"sequences"`. - if return_segments: - return {"sequences": sequences, "segments": current_segments} - - return sequences - - @staticmethod - def _retrieve_segment( - seek_sequence, - seek_outputs, - time_offset, - timestamp_begin, - seek_num_frames, - cur_bsz, - time_precision, - input_stride, - prev_idx, - idx, - ): - # find the predicted "end of segment" predictions of Whisper - # "end of segment" predictions occur whenever Whisper predicts a timestamp token - timestamp_tokens: torch.Tensor = seek_sequence.ge(timestamp_begin) - single_timestamp_ending = timestamp_tokens[-2:].tolist() == cur_bsz * [[False, True]] - timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] - - # If whisper predicted a "end of segment" via a timestep token, let's go ever each - # "end of segment" prediction and slice the decoding into segments accordingly - if len(timestamp_segment_indices) > 0: - # if the output contains two consecutive timestamp tokens - slices = timestamp_segment_indices.tolist() - segments = [] - if single_timestamp_ending: - slices.append(len(seek_sequence)) - - last_slice = 0 - # Add each segment to list of all segments - for current_slice in slices: - sliced_tokens = seek_sequence[last_slice + 1 : current_slice + 1] - start_timestamp_pos = sliced_tokens[0].item() - timestamp_begin - end_timestamp_pos = sliced_tokens[-1].item() - timestamp_begin - segments.append( - { - "start": time_offset[prev_idx] + start_timestamp_pos * time_precision, - "end": time_offset[prev_idx] + end_timestamp_pos * time_precision, - "tokens": sliced_tokens, - "result": seek_outputs[idx], - } - ) - last_slice = current_slice - - if single_timestamp_ending: - # single timestamp at the end means no speech after the last timestamp. - segment_offset = seek_num_frames[prev_idx] - else: - # otherwise, ignore the unfinished segment and seek to the last timestamp - # here we throw away all predictions after the last predicted "end of segment" - # since we are cutting right in the middle of an audio - last_timestamp_pos = seek_sequence[last_slice].item() - timestamp_begin - segment_offset = last_timestamp_pos * input_stride - else: - # If whisper does not predict any "end of segment" token, then - # the whole decoding is considered a segment and we add it to the list of segments - timestamps = seek_sequence[timestamp_tokens.nonzero().flatten()] - last_timestamp_pos = seek_num_frames[prev_idx] - if timestamps.numel() > 0 and timestamps[-1].item() != timestamp_begin: - # no consecutive timestamps but it has a timestamp; use the last one. - last_timestamp_pos = timestamps[-1].item() - timestamp_begin - - segments = [ - { - "start": time_offset[prev_idx], - "end": time_offset[prev_idx] + last_timestamp_pos * time_precision, - "tokens": seek_sequence, - "result": seek_outputs[idx], - } - ] - segment_offset = seek_num_frames[prev_idx] - - return segments, segment_offset - def prepare_inputs_for_generation( self, decoder_input_ids, @@ -2508,8 +1800,13 @@ def prepare_inputs_for_generation( use_cache=None, encoder_outputs=None, attention_mask=None, + decoder_attention_mask=None, **kwargs, ): + decoder_position_ids = None + if decoder_attention_mask is not None: + decoder_position_ids = (decoder_attention_mask.cumsum(-1) - 1).clamp(min=0) + if past_key_values is not None: past_length = past_key_values[0][0].shape[2] @@ -2522,12 +1819,16 @@ def prepare_inputs_for_generation( decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] + if decoder_position_ids is not None and decoder_position_ids.shape[1] > decoder_input_ids.shape[1]: + decoder_position_ids = decoder_position_ids[:, remove_prefix_length:] + return { "encoder_outputs": encoder_outputs, "past_key_values": past_key_values, "decoder_input_ids": decoder_input_ids, "use_cache": use_cache, - "decoder_attention_mask": None, + "decoder_attention_mask": decoder_attention_mask, + "decoder_position_ids": decoder_position_ids, } @staticmethod @@ -2539,99 +1840,6 @@ def _reorder_cache(past_key_values, beam_idx): ) return reordered_past - def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None): - """ - Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to - map each output token to a position in the input audio. If `num_frames` is specified, the encoder-decoder - cross-attentions will be cropped before applying DTW. - - Returns: - tensor containing the timestamps in seconds for each predicted token - """ - # Create a list with `decoder_layers` elements, each a tensor of shape - # (batch size, attention_heads, output length, input length). - cross_attentions = [] - for i in range(self.config.decoder_layers): - cross_attentions.append(torch.cat([x[i] for x in generate_outputs.cross_attentions], dim=2)) - - # Select specific cross-attention layers and heads. This is a tensor - # of shape (batch size, num selected, output length, input length). - weights = torch.stack([cross_attentions[l][:, h] for l, h in alignment_heads]) - weights = weights.permute([1, 0, 2, 3]) - - if "beam_indices" in generate_outputs: - # If beam search has been used, the output sequences may have been generated for more timesteps than their sequence_lengths - # since the beam search strategy chooses the most probable sequences at the end of the search. - # In that case, the cross_attentions weights are too long and we have to make sure that they have the right output_length - weight_length = (generate_outputs.beam_indices != -1).sum(-1).max() - weights = weights[:, :, :weight_length] - - # If beam index is still -1, it means that the associated token id is EOS - # We need to replace the index with 0 since index_select gives an error if any of the indexes is -1. - beam_indices = generate_outputs.beam_indices[:, :weight_length] - beam_indices = beam_indices.masked_fill(beam_indices == -1, 0) - - # Select the cross attention from the right beam for each output sequences - weights = torch.stack( - [ - torch.index_select(weights[:, :, i, :], dim=0, index=beam_indices[:, i]) - for i in range(beam_indices.shape[1]) - ], - dim=2, - ) - - timestamps = torch.zeros_like(generate_outputs.sequences, dtype=torch.float32) - batch_size = timestamps.shape[0] - - if num_frames is not None: - # two cases: - # 1. num_frames is the same for each sample -> compute the DTW matrix for each sample in parallel - # 2. num_frames is different, compute the DTW matrix for each sample sequentially - - # we're using np.unique because num_frames can be int/list/tuple - if len(np.unique(num_frames)) == 1: - # if num_frames is the same, no need to recompute matrix, std and mean for each element of the batch - num_frames = num_frames if isinstance(num_frames, int) else num_frames[0] - - weights = weights[..., : num_frames // 2] - else: - # num_frames is of shape (batch_size,) whereas batch_size is truely batch_size*num_return_sequences - repeat_time = batch_size if isinstance(num_frames, int) else batch_size // len(num_frames) - num_frames = np.repeat(num_frames, repeat_time) - - if num_frames is None or isinstance(num_frames, int): - # Normalize and smoothen the weights. - std = torch.std(weights, dim=-2, keepdim=True, unbiased=False) - mean = torch.mean(weights, dim=-2, keepdim=True) - weights = (weights - mean) / std - weights = _median_filter(weights, self.config.median_filter_width) - - # Average the different cross-attention heads. - weights = weights.mean(dim=1) - - # Perform dynamic time warping on each element of the batch. - for batch_idx in range(batch_size): - if num_frames is not None and isinstance(num_frames, (tuple, list, np.ndarray)): - matrix = weights[batch_idx, ..., : num_frames[batch_idx] // 2] - - # Normalize and smoothen the weights. - std = torch.std(matrix, dim=-2, keepdim=True, unbiased=False) - mean = torch.mean(matrix, dim=-2, keepdim=True) - matrix = (matrix - mean) / std - matrix = _median_filter(matrix, self.config.median_filter_width) - - # Average the different cross-attention heads. - matrix = matrix.mean(dim=0) - else: - matrix = weights[batch_idx] - - text_indices, time_indices = _dynamic_time_warping(-matrix.cpu().double().numpy()) - jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool) - jump_times = time_indices[jumps] * time_precision - timestamps[batch_idx, 1:] = torch.tensor(jump_times) - - return timestamps - class WhisperDecoderWrapper(WhisperPreTrainedModel): """ diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 1316fc9b83fc20..127f5be6193d72 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -530,10 +530,21 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre """ timestamp_begin = self.all_special_ids[-1] + 1 outputs = [[]] + + cur_max_timestamp = 0.0 + prev_segments_len = 0.0 + for token in token_ids: if token >= timestamp_begin: - timestamp = f"<|{(token - timestamp_begin) * time_precision:.2f}|>" - outputs.append(timestamp) + timestamp = float((token - timestamp_begin) * time_precision) + + if timestamp < cur_max_timestamp: + # next segment has started + prev_segments_len += cur_max_timestamp + + cur_max_timestamp = timestamp + + outputs.append(f"<|{(timestamp + prev_segments_len):.2f}|>") outputs.append([]) else: outputs[-1].append(token) @@ -631,7 +642,7 @@ def decode( skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = None, output_offsets: bool = False, - time_precision=0.02, + time_precision: float = 0.02, decode_with_timestamps: bool = False, normalize: bool = False, basic_normalize: bool = False, diff --git a/src/transformers/models/whisper/tokenization_whisper_fast.py b/src/transformers/models/whisper/tokenization_whisper_fast.py index e987bc35683ec8..509175be994f75 100644 --- a/src/transformers/models/whisper/tokenization_whisper_fast.py +++ b/src/transformers/models/whisper/tokenization_whisper_fast.py @@ -224,10 +224,21 @@ def _decode_with_timestamps(self, token_ids, skip_special_tokens=False, time_pre """ timestamp_begin = self.all_special_ids[-1] + 1 outputs = [[]] + + cur_max_timestamp = 0.0 + prev_segments_len = 0.0 + for token in token_ids: if token >= timestamp_begin: - timestamp = f"<|{(token - timestamp_begin) * time_precision:.2f}|>" - outputs.append(timestamp) + timestamp = float((token - timestamp_begin) * time_precision) + + if timestamp < cur_max_timestamp: + # next segment has started + prev_segments_len += cur_max_timestamp + + cur_max_timestamp = timestamp + + outputs.append(f"<|{(timestamp + prev_segments_len):.2f}|>") outputs.append([]) else: outputs[-1].append(token) @@ -330,7 +341,7 @@ def decode( skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = None, output_offsets: bool = False, - time_precision=0.02, + time_precision: float = 0.02, decode_with_timestamps: bool = False, normalize: bool = False, basic_normalize: bool = False, diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index e0f369c7a678cb..505d2e991033d8 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -18,6 +18,7 @@ import inspect import os import random +import re import tempfile import time import unittest @@ -84,6 +85,7 @@ def __init__( self.batch_size = batch_size self.max_length = max_length self.count = 0 + self.begin_index = 0 self.let_pass = [[] for _ in range(batch_size)] for k in range(batch_size): @@ -91,9 +93,12 @@ def __init__( for _ in range(10000): self.let_pass[k].append(random.randint(1, 10) <= 3) + def set_begin_index(self, begin_index: int): + self.begin_index = begin_index + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: # we don't want to randomely sample timestamp tokens - if input_ids.shape[-1] > 1: + if input_ids.shape[-1] != self.begin_index: scores[:, self.timestamp_begin :] = -float("inf") self.no_time_stamp_counter = [x + 1 for x in self.no_time_stamp_counter] @@ -1314,7 +1319,7 @@ def test_generate_with_prompt_ids_max_length(self): model.generate(input_features, max_new_tokens=1, prompt_ids=prompt_ids) - def test_longform_generate_single_batch(self): + def _check_longform_generate_single_batch(self, condition_on_prev_tokens): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() model = WhisperForConditionalGeneration(config).eval().to(torch_device) @@ -1354,20 +1359,30 @@ def test_longform_generate_single_batch(self): timestamp_begin = vocab_size - num_timestamp_tokens model.generation_config.no_timestamps_token_id = timestamp_begin - 1 model.generation_config.eos_token_id = None + model.config.eos_token_id = None model.generation_config._detect_timestamp_from_logprob = False # make sure that we only have the same begin token model.generation_config.max_initial_timestamp_index = 0 + model.generation_config.prev_bos_token_id = timestamp_begin - 3 + + gen_kwargs = { + "logits_processor": logits_processor, + "return_segments": True, + "condition_on_prev_tokens": condition_on_prev_tokens, + } - outputs = model.generate(long_input_features, logits_processor=logits_processor, return_segments=True) + if condition_on_prev_tokens: + gen_kwargs["no_speech_threshold"] = 0.6 + gen_kwargs["temperature"] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0) + gen_kwargs["compression_ratio_threshold"] = 2.4 + gen_kwargs["logprob_threshold"] = -1.0 + + outputs = model.generate(long_input_features, **gen_kwargs) segments = outputs["segments"][0] - for i, segment in enumerate(segments): + for _, segment in enumerate(segments): assert segment["start"] <= segment["end"], "start has to be smaller equal end" - assert ( - segment["tokens"][0] == model.generation_config.decoder_start_token_id - or segment["tokens"][0] >= timestamp_begin - ), "First segment token should be a timestamp token" assert any( s > timestamp_begin for s in segment["tokens"][1:] ), f"At least one segment token should be a timestamp token, but not first., {segment['tokens']}" @@ -1375,7 +1390,13 @@ def test_longform_generate_single_batch(self): segment["tokens"].shape[-1] <= max_length ), "make sure that no segment is larger than max generation length" - def test_longform_generate_multi_batch(self): + def test_longform_generate_single_batch(self): + self._check_longform_generate_single_batch(condition_on_prev_tokens=False) + + def test_longform_generate_single_batch_cond_prev(self): + self._check_longform_generate_single_batch(condition_on_prev_tokens=True) + + def _check_longform_generate_multi_batch(self, condition_on_prev_tokens): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() model = WhisperForConditionalGeneration(config).eval().to(torch_device) @@ -1383,7 +1404,6 @@ def test_longform_generate_multi_batch(self): # len = 250 with num_input_frames = 60 long_input_features = torch.cat([input_features.repeat(1, 1, 4), input_features[:, :, :10]], dim=-1) - long_input_features[:1, :, :200] input_features_2 = long_input_features[1:] attention_mask = torch.ones( (2, long_input_features.shape[-1]), dtype=input_features.dtype, device=input_features.device @@ -1395,25 +1415,34 @@ def test_longform_generate_multi_batch(self): batch_size = 1 num_timestamp_tokens = 20 - max_length = 16 + max_new_tokens = 16 timestamp_begin = vocab_size - num_timestamp_tokens model.generation_config.no_timestamps_token_id = timestamp_begin - 1 model.generation_config.eos_token_id = None + model.config.eos_token_id = None model.generation_config._detect_timestamp_from_logprob = False # make sure that we only have the same begin token model.generation_config.max_initial_timestamp_index = 0 + model.generation_config.max_new_tokens = max_new_tokens + model.generation_config.prev_bos_token_id = timestamp_begin - 3 logits_processor = [ DummyTimestampLogitProcessor( vocab_size - num_timestamp_tokens, vocab_size, batch_size=batch_size, - max_length=max_length, + max_length=max_new_tokens, min_space=4, seed=1, ) ] - outputs_2 = model.generate(input_features_2, logits_processor=logits_processor, return_segments=True) + outputs_2 = model.generate( + input_features_2, + max_new_tokens=max_new_tokens, + logits_processor=logits_processor, + condition_on_prev_tokens=condition_on_prev_tokens, + return_segments=True, + ) tokens_2 = outputs_2["sequences"][0] segments_2 = outputs_2["segments"][0] @@ -1423,24 +1452,37 @@ def test_longform_generate_multi_batch(self): vocab_size - num_timestamp_tokens, vocab_size, batch_size=batch_size, - max_length=max_length, + max_length=max_new_tokens, min_space=4, seed=0, ) ] - outputs = model.generate( - long_input_features, attention_mask=attention_mask, logits_processor=logits_processor, return_segments=True - ) + gen_kwargs = { + "logits_processor": logits_processor, + "return_segments": True, + "condition_on_prev_tokens": condition_on_prev_tokens, + "attention_mask": attention_mask, + "max_new_tokens": max_new_tokens, + } + + outputs = model.generate(long_input_features, **gen_kwargs) tokens = outputs["sequences"][1] segments = outputs["segments"][1] - assert tokens_2.tolist() == tokens.tolist() + # make sure batched and non-batched is the same + assert tokens_2.tolist() == tokens[: tokens_2.shape[-1]].tolist() for seg1, seg2 in zip(segments_2, segments): assert seg1["start"] == seg2["start"] assert seg1["end"] == seg2["end"] assert seg1["tokens"].tolist() == seg2["tokens"].tolist() + def test_longform_generate_multi_batch(self): + self._check_longform_generate_multi_batch(condition_on_prev_tokens=False) + + def test_longform_generate_multi_batch_cond_prev(self): + self._check_longform_generate_multi_batch(condition_on_prev_tokens=True) + @require_torch @require_torchaudio @@ -2089,12 +2131,59 @@ def test_whisper_longform_single_batch(self): assert decoded == EXPECTED_TEXT + decoded_with_timestamps = processor.batch_decode(result, skip_special_tokens=True, decode_with_timestamps=True) + + no_timestamp_matches = re.split(r"<\|[\d\.]+\|>", decoded_with_timestamps[0]) + + assert ["".join(no_timestamp_matches)] == EXPECTED_TEXT + + timestamp_matches = re.findall(r"<\|[\d\.]+\|>", decoded_with_timestamps[0]) + + timestamp_floats = [float(t[2:-2]) for t in timestamp_matches] + + is_increasing = all(timestamp_floats[i] <= timestamp_floats[i + 1] for i in range(len(timestamp_floats) - 1)) + + assert is_increasing + + @slow + def test_whisper_longform_single_batch_prev_cond(self): + # fmt: off + EXPECTED_TEXT = [""" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grieved doubts whether Sir Frederick Layton's work is really Greek after all, and can discover in it but little of rocky Ithaca. Linnell's pictures are a sort of up-gards and atom paintings, and Mason's exquisite itals are as national as a jingo poem. Mr. Birk at Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. When Mr. John Collier gives his sitter a cheerful slap in the back, before he says like a shampooer and a Turkish bath, next man it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate an expression. On the general principles of art, Mr. Quilter writes with equal lucidity. He tells us is of a different quality to mathematics, and finish in art is adding more effect. As for etchings, there are two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures. Makes a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing upholsterer. Near the fire, any ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin for not recognizing that a picture should denote the frailty of man, and remarks was pleasing courtesy in felicitous grace that many faces are feeling. Unfortunately his own work never does get good. Mr. Quilter has missed his chance, for he has failed even to make himself the tupper of painting. By Harry Quilter M. A. A man said to the universe, Sir, I exist. Sweat covered Breon's body trickling into the tight-lowing cloth that was the only german he wore. The cut on his chest still dripping blood. The ache of his overstrained eyes, even the soaring arena around him with thousands of spectators, retroveilities not worth thinking about. His instant panic was followed by a small sharp blow high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzers were triggered as muscles into complete relaxation. Only his heart and lungs worked on at a strong measured rate. He was in reverie, sliding along the borders of consciousness. The contestants in the twenties needed undisturbed rest. Therefore, nights in the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I'm here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. The twenties, he must have drawn his gun because the intruder said quickly, but that away you're being a fool. But there was silence then, and still wondering, Breon was once more asleep. Ten seconds, he asked the handler who was needing his aching muscles. A red-haired mountain of a man with an apparently inexhaustible store of energy. There could be little art in this last and final round of fencing. Just thrust and parry and victory to the stronger. Your man who entered the twenties had his own training tricks. They were appeared to be an immediate association with the death trauma, as if the two were inextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported except at two points, the head and heels. This is physically impossible when conscious. Breon's death was in some ways easier than defeat. Breon's softly spoke the auto-hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. Our role looked amazed at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Breon saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our rogue. Breon sensed it and knew the fifth point was his. Then the powerful twist that's rested aside, in and under the guard, because he was sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, while poor Shaggy sits there, accooing dove. He has gone and gone for good, answered Polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stout chains as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has flooded disgrace, and your friends are asking for you. I begged Ruggido long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard, since Shaggy. He doesn't work at all. In fact, there's nothing he can do in these dominions, as well as our gnomes, whose numbers are so great that it worries us to keep them all busy. Not exactly, we've turned Calico. Where is my brother now, inquired Shaggy. In the metal forest. Where is that? The metal forest is in the great domed cavern, the largest and all-ard dominions, replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I'm quite sure he didn't. That's funny, remarked Betsy thoughtfully. I don't believe Anne knew any magic, or she'd have worked it before. I do not know, confessed Shaggy. True, agreed Calico. Calico went to the big gong, and pounded on it, just as we're good to be used to do. But no one answered the summons. Having returned to the royal cavern, Calico first pounded the gong, and then sat in the throne, wearing Regidos discarded Ruby Crown, and holding in his hand to scepter, which Regidos had so often thrown at his head."""] + # fmt: on + + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") + model = model.to("cuda") + + ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean") + one_audio = np.concatenate([x["array"] for x in ds["validation"]["audio"]], dtype=np.float32) + + input_features = processor(one_audio, return_tensors="pt", truncation=False, padding="longest")[ + "input_features" + ] + input_features = input_features.to(device="cuda") + + gen_kwargs = { + "return_timestamps": True, + "no_speech_threshold": 0.6, + "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), + "compression_ratio_threshold": 1.35, + "condition_on_prev_tokens": True, + "logprob_threshold": -1.0, + } + + torch.manual_seed(0) + result = model.generate(input_features, **gen_kwargs) + decoded = processor.batch_decode(result, skip_special_tokens=True) + + assert decoded == EXPECTED_TEXT + @slow def test_whisper_longform_multi_batch(self): # fmt: off EXPECTED_TEXT_1 = [" Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton's work is really Greek after all, and can discover in it but little of rocky Ithaca. Linnell's pictures are a sort of up-gards and atom paintings, and Mason's exquisite idles are as national as a jingo poem. Mr. Birkett Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap in the back, before he says, like a shampooer and a Turkish bath. Next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate an expression. On the general principles of art, Mr. Quilter writes with equal lucidity. Painting he tells us is of a different quality to mathematics, and finish in art is adding more effect. As for etchings, there are two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures. Mix a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing a poster or near the fire, and the ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin for not recognizing that a picture should denote the frailty of man. And remarks was pleasing courtesy in Felicitis Grace that many faces are feeling. Only unfortunately his own work never does get good. Mr. Quilter has missed his chance, for he has failed even to make himself the Tupper of painting. a Harry Quilter M.A. A man said to the universe, Sir, I exist. Sweat-covered Breon's body trickling into the tight-wing cloth that was the only germany war. The cut on his chest still dripping blood. The ache of his overstrained eyes, even the soaring arena around him with thousands of spectators, retrovealities not worth thinking about. His instant panic was followed by a small sharp blow high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzers were, triggered his muscles into complete relaxation. Oily his heart and lungs worked on at a strong, measured rate. He was in reverie, sliding along the borders of consciousness. The contestants in the 20s needed undisturbed rest. Therefore, knights in the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I'm here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. The twenty's he must have drawn his gun, because the intruder said quickly, but that away you're being a fool. Out there was silence then, and still wondering, Breon was once more asleep. Ten seconds he asked the handler who was needing his aching muscles. a red-haired mountain of a man with an apparently inexhaustible store of energy. There could be little art in this last and final round of fencing, just thrust and parry and victory to the stronger. Every man who entered the twenties had his own training tricks. There appeared to be an immediate association with the death trauma as if the two were andextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported except at two points, the head and heels. This is physically impossible when conscious. Others had died before during the twenties and death during the last round was, in some ways, easier than defeat. Breeding deeply, Breon's softly spoke the auto-hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. I rolled the mazed at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Breon saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our rogue, pre-inscented and new to fifth point was his. Then the powerful twist that's rest of the side, in and under the guard, because you were sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, while poor Shaggy sits there, a cooing dove. He has gone and gone for good, answered Polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stout chains as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has flooded disgrace, and your friends are asking for you. I begged Ruggadot long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard, since Shaggy. He doesn't work at all. In fact, there's nothing he can do in these dominions, as well as our gnomes, whose numbers are so great that it worries us to keep them all busy. Not exactly, return Calico. Where is my brother now? choir-dshaggy, in the metal forest. Where is that? The metal forest is in the great domed cavern, the largest and all-ard dominions, replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh, no, I'm quite sure he didn't. That's funny, remarked Betsy thoughtfully. I don't believe and knew any magic, or she'd have worked it before. I do not know, confess shaggy. True, a great calico. Calico went to the big gong and pounded on it, just as Virgado used to do, but no one answered the summons. Having returned to the Royal Cavern, Calico first pounded the gong and then sat in the throne, wearing Virgados discarded Ruby Crown, and holding in his hand to scepter, which Virgado had so often thrown at his head. head."] EXPECTED_TEXT_2 = [" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton's work is really Greek after all, and can discover in it but little of rocky Ithaca. Linnell's pictures are a sort of up-gards and atom paintings, and Mason's exquisite idles are as national as a jingo poem. Mr. Burkett Foster's landscapes smile at one much in the same way that Mr. Carker."] - EXPECTED_TEXT_3 = [" possible. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grieved doubts whether Sir Frederick Layton's work is really greek after all, and can discover in it but little of rocky Ithaca. Linnell's pictures are a sort of up-guards and atom paintings, and Mason's exquisite idles are as national as a jingo poem. Mr. Birk at Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap in the back, before he says, like a shampooer and a Turkish bath, next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate an expression. Under general principles of art, Mr. Quilter writes with equal lucidity. Painting, he tells us, is of a different quality to mathematics and finish in art is adding more effect. As for etchings, there are two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures. Mix a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing upholsterer. Near the fire. any ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin for not recognizing that a picture should denote the frailty of man, and remarks was pleasing courtesy in Felicitis Grace that many faces are feeling. Only, unfortunately, his own work never does get good. Mr. Quilter has missed his chance, for he has failed even to make himself the tupper of painting. By Harry Quilter M.A. A man said to the universe, Sir, I exist. Sweat-covered Breon's body trickling into the titling cloth that was the only german he wore. The cut on his chest still dripping blood. The ache of his overstrained eyes. Even to soaring arena around him with thousands of spectators, retrovealities not worth thinking about. His instant panic was followed by a small sharp blow high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzers were triggered as muscles into complete relaxation. Oily his heart and lungs worked on at a strong measured rate. He was in In reverie, sliding along the borders of consciousness. The contestants in the 20s needed undisturbed rest. Therefore, nights in the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency clearly used to command. I'm here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. The twenty's he must have drawn his gun, because the intruder said quickly, but that away you're being a fool. Out there was silence then, and still wondering, Breon was once more asleep. Ten seconds he asked the handler who was needing his aching muscles. a red-haired mountain of a man with an apparently inexhaustible store of energy. There could be little art in this last and final round of fencing, just thrust and parry and victory to the stronger. Every man who entered the twenties had his own training tricks. There appeared to be an immediate association with the death trauma as if the two were andextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported except at two points, the head and heels. This is physically impossible when conscious. Others had died before during the twenties and death during the last round was, in some ways, easier than defeat. Breeding deeply, Breon's softly spoke the auto-hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. Our role looked amazed at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Breon saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our rogue, re-insunced it and knew the fifth point was his. Then the powerful twist that's rest of the side, in and under the guard, because you were sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, while poor Shaggy sits there, a cooing dove. He has gone and gone for good, answered Polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stout chains as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has fled and disgraced, and your friends are asking for you. I begged Ruggadot long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard, since Shaggy. He doesn't work at all. In fact, there's nothing he can do in these dominions as well as our gnomes, whose numbers are so great that it worries us to keep them all busy. Not exactly, we've turned Calico. Where is my brother now? quared shaggy. In the metal forest. Where is that? The metal forest is in the great domed cavern, the largest and all-ard dominions, replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I'm quite sure he didn't. And that's funny, remarked Betsy thoughtfully. I don't believe Anne knew any magic, or she'd have worked it before. I do not know, confess Shaggy. True, a great calico. Calico went to the big gong and pounded on it, just as we're good to have used to do, but no one answered the summons. Having returned to the Royal Cavern, Calico first pounded the gong and then sat in the thrown wearing ruggedos discarded ruby crown and holding in his hand to septor which Ruggato had so often thrown at his head."] + EXPECTED_TEXT_3 = [" possible. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grieved doubts whether Sir Frederick Layton's work is really greek after all, and can discover in it but little of rocky Ithaca. Linnell's pictures are a sort of up-guards and atom paintings, and Mason's exquisite idles are as national as a jingo poem. Mr. Birk at Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap in the back, before he says, like a shampooer and a Turkish bath, next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate an expression. Under general principles of art, Mr. Quilter writes with equal lucidity. Painting, he tells us, is of a different quality to mathematics and finish in art is adding more effect. As for etchings, there are two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures. Mix a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing upholsterer. Near the fire. any ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin for not recognizing that a picture should denote the frailty of man, and remarks was pleasing courtesy in Felicitis Grace that many faces are feeling. Only, unfortunately, his own work never does get good. Mr. Quilter has missed his chance, for he has failed even to make himself the tupper of painting. By Harry Quilter M.A. A man said to the universe, Sir, I exist. Sweat-covered Breon's body trickling into the titling cloth that was the only german he wore. The cut on his chest still dripping blood. The ache of his overstrained eyes. Even to soaring arena around him with thousands of spectators, retrovealities not worth thinking about. His instant panic was followed by a small sharp blow high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzers were triggered as muscles into complete relaxation. Oily his heart and lungs worked on at a strong measured rate. He was in In reverie, sliding along the borders of consciousness. The contestants in the 20s needed undisturbed rest. Therefore, nights in the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency clearly used to command. I'm here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. The twenty's he must have drawn his gun, because the intruder said quickly, but that away you're being a fool. Out there was silence then, and still wondering, Breon was once more asleep. Ten seconds he asked the handler who was needing his aching muscles. a red-haired mountain of a man with an apparently inexhaustible store of energy. There could be little art in this last and final round of fencing, just thrust and parry and victory to the stronger. Every man who entered the twenties had his own training tricks. There appeared to be an immediate association with the death trauma as if the two were andextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported except at two points, the head and heels. This is physically impossible when conscious. Others had died before during the twenties and death during the last round was, in some ways, easier than defeat. Breeding deeply, Breon's softly spoke the auto-hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. Our role looked amazed at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Breon saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our rogue, re-insunced it and knew the fifth point was his. Then the powerful twist that's rest of the side, in and under the guard, because you were sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, while poor Shaggy sits there, a cooing dove. He has gone and gone for good, answered Polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stout chains as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has fled and disgraced, and your friends are asking for you. I begged Ruggadot long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard, since Shaggy. He doesn't work at all. In fact, there's nothing he can do in these dominions as well as our gnomes, whose numbers are so great that it worries us to keep them all busy. Not exactly, we've turned Calico. Where is my brother now? quared shaggy. In the metal forest. Where is that? The metal forest is in the great domed cavern, the largest and all-ard dominions, replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I'm quite sure he didn't. And that's funny, remarked Betsy thoughtfully. I don't believe Anne knew any magic, or she'd have worked it before. I do not know, confess Shaggy. True, a great calico. Calico went to the big gong and pounded on it, just as we're good to have used to do, but no one answered the summons. Having returned to the Royal Cavern, Calico first pounded the gong and then sat in the thrown wearing ruggedos discarded ruby crown and holding in his hand to septor which ruggedo had so often thrown at his head."] EXPECTED_TEXT_4 = [' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter\'s manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similes drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton\'s work is really Greek after all, and can discover in it but little of rocky Ithaca. Linnell\'s pictures are a sort of up-gards and atom paintings, and Mason\'s exquisite idles are as national as a jingo poem. Mr. Birk at Foster\'s landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. Mr. John Collier gives his sitter a cheerful slap in the back, before he says, like a shampoo or a Turkish bath. Next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate an expression. On the general principles of art, Mr. Quilter writes with equal lucidity. he tells us is of a different quality to mathematics, and finish in art is adding more effect. As for etchings, there are two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures. Makes the customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing upholsterer. Near the fire, any ornaments Fred brought home from India on the mantelboard. In fact, he is quite severe on Mr. Ruskin for not recognizing that a picture should denote the frailty of man. And remarks was pleasing courtesy in Felicitis Grace that many faces are feeling. Only, unfortunately, his own work never does get good. Mr. Quilter has missed his chance, for he has failed even to make himself the Tupper of painting. By Harry Quilter M.A. A man said to the universe, Sir, I exist. Sweat-covered Breon\'s body trickling into the tight-lowing cloth that was the only german he wore. The cut on his chest still dripping blood. The ache of his overstrained eyes, even the soaring arena around him with thousands of spectators, retrovealities not worth thinking about. His instant panic was followed by a small sharp blow high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzers were triggered his muscles into complete relaxation. Oli\'s heart and lungs worked on at a strong, measured rate. He was in reverie, sliding along the borders of consciousness. The contestants in the twenties needed undisturbed rest. Therefore, nights in the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, The thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I\'m here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. The twenties, he must have drawn his gun because the intruder said quickly, but that away you\'re being a fool. out, through his silence then, and still wondering, Breon was once more asleep. Ten seconds, he asked the handler who was needing his aching muscles. A red-haired mountain of a man, with an apparently inexhaustible store of energy. There could be little art in this last and final round of fencing. Just thrust and parry, and victory to the stronger. man who entered the twenties had his own training tricks. They were appeared to be an immediate association with the death trauma, as if the two were inextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported except at two points, the head and heels. This is physically impossible when conscious. had died before during the 20s and death during the last round was in some ways easier than defeat. Breathing deeply, Breon\'s softly spoke the auto-hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. Our role looked amazed at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Breon saw something close to panic on his opponent\'s face when the man finally recognized his error. A wave of despair rolled out from our rogue. Breon sensed it and knew the fifth point was his. Then the powerful twist that\'s rested aside, in and under the guard, because he was sleeping instead of conquering, the lovely rose princess has become a fiddle without a bow, while poor Shaggy sits there, accooing dove. He has gone, and gone for good," answered Polychrom, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with says he stepped forward and burst the stout chains as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has flooded disgrace, and your friends are asking for you. I begged Ruggadot long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn\'t work too hard, said Shaggy. He doesn\'t work at all. In fact, there\'s nothing he can do in these dominions as well as our gnomes, whose numbers are so great that it worries us to keep them all busy. Not exactly, we\'ve turned Calico. Where is my brother now, inquired Shaggy. In the metal forest. Where is that? The middle forest is in the great domed cavern, the largest and all-ard dominions, replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I\'m quite sure he didn\'t. That\'s funny, remarked Betsy thoughtfully. I don\'t believe Anne knew any magic, or she\'d have worked it before. I do not know, confess Shaggy. True, agreed Calico. Calico went to the big gong and pounded on it just as Virgato used to do, but no one answered the summons. Having returned to the Royal Cavern, Calico first pounded the gong and then sat in the throne, wearing Virgato\'s discarded ruby crown and holding in his hand to scepter which reggative head so often thrown at his head.'] # fmt: on @@ -2138,18 +2227,62 @@ def test_whisper_longform_multi_batch(self): assert decoded_all[2:3] == EXPECTED_TEXT_3 assert decoded_all[3:4] == EXPECTED_TEXT_4 + @slow + def test_whisper_longform_multi_batch_prev_cond(self): + # fmt: off + EXPECTED_TEXT_1 = [" Mr. Quilters manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky Ithaca. The Nils, pictures are sort of upguards and atom paintings and Mason's exquisite itals are as national as a jingo poem. Mr. Berkett Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap on the back before he says like a shampooer and a Turkish bath. Next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate and expression. On the general principles of art, Mr. Quilters writes with equal lucidity. Painting he tells us is of a different quality to mathematics and finish in art is adding more effect. As for etchings, there are of two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures makes a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing apostorer. Near the fire, any ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin, for not recognizing that a picture should denote the frailty of man. And remarks with pleasing courtesy and solicitous grace that many phases of feeling only, unfortunately, his own work never does get good. Mr. Quilters has missed his chance, for he has failed even to make himself the tougher of painting. My hair equal to M.A. A man said to the universe, Sir, I exist. Sweat covered Breon's body, trickling into the tight-wing cloth that was the only garment he wore. The cut on his chest still dripping blood. The ache of his overstrain dyes. Even the soaring arena around him with thousands of spectators, retrievalidies not worth thinking about. His instant panic was followed by a small sharp blow, high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzer's were triggered as muscles into complete relaxation. Only his heart and lungs worked on at a strong, measured rate. He was in reverie, sliding along the borders of consciousness. The contestants in the 20s needed undisturbed rest. Therefore, knights and the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I'm here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. To 20s, he must have drawn his gun because the intruder said quickly, but that away, you're being a fool. Out, the resoundance then, and still wondering, Brienne was once more asleep. Ten seconds, he asked the handler who was needing his aching muscles. Our red-haired mountain of a man, with an apparently inexhaustible story of energy. There could be little art in this last and final round of fencing, just thrust and parry and victory to the stronger. Every man who entered the 20s had his own training tricks. There appeared to be an immediate association with the death trauma as if the two were inexplicably linked into one. This strengthened enables someone in a trance to hold his body stiff and unsupported, except at two points, the head and heels. This is physically impossible when conscious. Others had died before during the 20s, and death during the last round was, in some ways, easier than defeat. Breathing deeply, Brienne softly spoke the other hypnotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. I rolled the maze at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Brienne saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our role. Brienne sensed it and knew the fifth point was his. Then the powerful twist that's right to the side, in and under the guard, because he was sleeping instead of conquering, the lovely rose princess has become a fiddle with a bow, while poor shaggy sits there, a cooling dove. He has gone and gone for good, answered polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stoutchanges as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has fled in disgrace in your friends, they're asking for you. I begged Ruggano a long ago to send him away, but he would not do so. I also offered to help you run into escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard since shaggy. He doesn't work at all. In fact, there's nothing he can do in these dominions, as well as our nooms, whose numbers are so great that it worries us to keep them all busy. Not exactly, we've turned Calico, whereas my brother now inquired shaggy in the metal forest. Where is that? The metal forest is in the great domed cavern, the largest and all our dominions replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I'm quite sure he didn't. That's funny, remarked to Bedsey thoughtfully. I don't believe Anne knew any magic or she'd have worked before. I do not know, confessed shaggy. True, agreed Calico. Calico went to the big gong and pounded on it just as Ruggano used to do, but no one answered the summons. Having returned to the royal cavern, Calico first pounded the gong and then sat in the throne, wearing Ruggano's discarded ruby crown. And holding in his hand the scepter which Ruggano had so often thrown at his head."] + EXPECTED_TEXT_2 = [" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins' work is really Greek after all, and can discover in it but little of rocky Ithaca. Lennials, pictures are a sort of upguards and atom paintings, and Mason's exquisite idles are as national as a jingo poem. Mr. Berkett Foster's landscapes smile at one much in the same way that Mr. Carker"] + EXPECTED_TEXT_3 = [" gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly drawn from eating in its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of rocky ithaka. Lennils, pictures, are a sort of upguards and atom paintings and Mason's exquisite itals are as national as a jingo poem. Mr. Birkut Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap on the back before he says like a shampooer and a Turkish bath. Next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate and expression. Under general principles of art, Mr. Quilter writes with equal lucidity. Painting he tells us is of a different quality to mathematics and finish in art is adding more effect. As for etchings, thereof two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures makes a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing apostoror. Near the fire, any ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin, for not recognizing that a picture should denote the frailty of man. And remarks with pleasing courtesy and falseness graced that many phases of feeling, only unfortunately his own work never does get good. Mr. Quilter has missed his chance, for he has failed even to make himself the tougher of painting. By Harry Quilter M.A. A man said to the universe, Sir, I exist. Sweat-covered Breon's body trickling into the tight-wing cloth that was the only garment you wore. The cut on his chest still dripping blood. The ache of his overstrained eyes. Even the soaring arena around him with thousands of spectators were trivealed, not worth thinking about. His instant panic was followed by a small sharp, blow high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzer's were triggered as muscles into complete relaxation. Only his heart and lungs worked on at a strong, measured rate. He was in reverie sliding out on the borders of consciousness. The contestants in the 20s needed undisturbed rest. Therefore, knights in the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I'm here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. The 20s, he must have drawn his gun because the intruder said quickly, but that away, he'll be in the fool. Out, there is silence then, and still wondering, Brienne was once more asleep. Ten seconds, he asked the handler who was needing his aching muscles. A red-haired mountain of a man, with an apparently inexhaustible story of energy. There could be little art in this last and final round of fencing, just thrust and parry and victory to the stronger. Every man who entered the 20s had his own training tricks. There appeared to be an immediate association with the death trauma, as if the two were inextricably linked into one. The strength that enables someone in a trance to hold his body stiff and unsupported, except at two points, the head and heels. This is physically impossible when conscious. Others had died before during the 20s, and death during the last round was, in some ways, easier than defeat. Breathing deeply, Brienne softly spoke the autohydrotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. I rolled up the maze at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Brienne saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our ol' Brienne sensed it and knew the fifth point was his. Then the powerful twist that's right to decide, in and under the guard, because he was sleeping instead of conquering, the lovely rose princess has become a fiddle with a bow, while poor shaggy sits there, a cooling dove. He has gone and gone for good, answered polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stoutchains as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has fled in disgrace in your friends, they're asking for you. I begged Ruggano a long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard since shaggy. He doesn't work at all. In fact, there's nothing he can do in these dominions as well as our nooms, whose numbers are so great that it worries us to keep them all busy. Not exactly, we've turned Calico. Whereas my brother now, in Quaragejjegi, in the metal forest. Where is that? The metal forest is in the great Dome to Cavern, the largest and all our dominions, replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I'm quite sure he didn't. That's funny remarked by the bad sea thoughtfully. I don't believe Anne knew any magic or she'd have worked it before. I do not know, confessed shaggy. True, a great Calico. Calico went to the big gong and pounded on it, just as we're good or used to do, but no one answered the summons. Having returned to the royal cavern, Calico first pounded the gong and then sat in the throne, wearing reggos, discarded ruby crown, and holding in his hand to scepter which reggos had so often thrown at his head."] + EXPECTED_TEXT_4 = [" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season of the year, with Christmas and roast beef looming before us, similarly drawn from eating and its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins' work is really Greek after all, and can discover in it but little of rocky Ithaca. Lennils, pictures, are a sort of upguards and atom paintings, and Mason's exquisite idles are as national as a jingo poem. Mr. Berkett Foster's landscapes smile at one much in the same way that Mr. Carker used to flash his teeth. And Mr. John Collier gives his sitter a cheerful slap on the back before he says, like a shampooer in a Turkish bath. Next man, it is obviously unnecessary for us to point out how luminous these criticisms are, how delicate and expression. On the general principles of art, Mr. Quilter writes with equal lucidity. Painting he tells us is of a different quality to mathematics, and finish in art is adding more effect. As for etchings, thereof two kinds, British and foreign. He laments most bitterly the divorce that has been made between decorative art and what we usually call pictures makes a customary appeal to the last judgment and reminds us that in the great days of art Michelangelo was the furnishing apostorer. Near the fire, any ornaments Fred brought home from India on the mental board. In fact, he is quite severe on Mr. Ruskin, for not recognizing that a picture should denote the frailty of man. And remarks with pleasing courtesy and solicitous grace that many phases of feeling only, unfortunately, his own work never does, get good. Mr. Quilter has missed his chance, for he has failed even to make himself the tougher of painting. By Harry Quilter, M.A. A man said to the universe, Sir, I exist. Sweat covered Breon's body, trickling into the tight-wing cloth that was the only garment you wore. The cut on his chest still dripping blood. The ache of his overstrained eyes, even the soaring arena around him with thousands of spectators were trivialities not worth thinking about. His instant panic was followed by a small sharp blow, high on his chest. One minute, a voice said, and a time buzzer sounded. A minute is not a very large measure of time, and his body needed every fraction of it. The buzzer's were triggered as muscles into complete relaxation. Only his heart and lungs worked on at a strong, measured rate. He was in reverie, sliding along the borders of consciousness. The contestants in the 20s needed undisturbed rest. Therefore, knights and the dormitories were as quiet as death. Particularly so, on this last night, when only two of the little cubicles were occupied, the thousands of others standing with dark empty doors. The other voice snapped with a harsh urgency, clearly used to command. I'm here because the matter is of utmost importance, and brand is the one I must see. Now stand aside. To 20s, he must have drawn his gun because the intruder said quickly, but that away, you're being a fool. Out, there is silence then, and still wondering, Brienne was once more asleep. Ten seconds, he asked the handler who was needing his aching muscles. I've read here at Mountain of a Man, with an apparently inexhaustible story of energy. There could be little art in this last and final round of fencing. Just thrust and parry and victory to the stronger. Every man who entered the 20s had his own training tricks. There appeared to be an immediate association with the death trauma, as if the two were inexplicably linked into one. Just strengthed and enabled someone in a trance to hold his body stiff and unsupported, except at two points, the head and heels. This is physically impossible when conscious. Others had died before during the 20s, and death during the last round was, in some ways, easier than defeat. Breathing deeply, Brienne softly spoke the autohydrotic phrases that triggered the process. When the buzzer sounded, he pulled his foil from his second startled grasp and ran forward. I rolled up the maze at the sudden fury of the attack, then smiled. He thought it was the last burst of energy. He knew how close they both were to exhaustion. Brienne saw something close to panic on his opponent's face when the man finally recognized his error. A wave of despair rolled out from our ol' Brienne sensed it and knew the fifth point was his. Then the powerful twist that's right to the side, in and under the guard, because he was sleeping instead of conquering, the lovely rose princess has become a fiddle with a bow, while poor shaggy sits there, a cooling dove. She has gone and gone for good, answered polychrome, who had managed to squeeze into the room beside the dragon, and had witnessed the occurrences with much interest. I have remained a prisoner only because I wished to be one. And with this, he stepped forward and burst the stoutchanges as easily as if they had been threads. The little girl had been asleep, but she heard the wraps and opened the door. The king has fled in disgrace and your friends are asking for you. I begged Ruggano a long ago to send him away, but he would not do so. I also offered to help your brother to escape, but he would not go. He eats and sleeps very steadily, replied the new king. I hope he doesn't work too hard since shaggy. He doesn't work at all. In fact, there's nothing he can do in these dominions as well as our nooms, whose numbers are so great that it worries us to keep them all busy. Not exactly, we've turned Calico. Where is my brother now, in Quaragejji, in the metal forest? Where is that? The metal forest is in the great Dome to Cavern, the largest and all our dominions replied Calico. Calico hesitated. However, if we look sharp, we may be able to discover one of these secret ways. Oh no, I'm quite sure he didn't. That's funny, remarked a bit, see you thoughtfully. I don't believe Anne knew any magic or she'd have worked it before. I do not know, confessed shaggy. True, agreed Calico. Calico went to the big gong and pounded on it just as we're good we used to do, but no one answered the summons. Having returned to the royal cavern, Calico first pounded the gong and then sat in the throne, wearing reggos, discarded ruby crown and holding it his hand to scepter which reggo had so often thrown at his head."] + # fmt: on + + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny") + model = model.to("cuda") + + ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean") + one_audio = np.concatenate([x["array"] for x in ds["validation"]["audio"]], dtype=np.float32) + audios = [] + audios.append(one_audio[110000:]) + audios.append(one_audio[:800000]) + audios.append(one_audio[80000:]) + audios.append(one_audio[:]) + + gen_kwargs = { + "return_timestamps": True, + "no_speech_threshold": 0.6, + "temperature": 0.0, + "compression_ratio_threshold": 1.35, + "condition_on_prev_tokens": True, + "logprob_threshold": -1.0, + } + + decoded_single = [] + for audio in audios: + inputs = processor(audio, return_tensors="pt", truncation=False) + inputs = inputs.to(device="cuda") + + result = model.generate(**inputs, **gen_kwargs) + decoded_single.append(processor.batch_decode(result, skip_special_tokens=True)) + + # exact match + assert decoded_single[0] == EXPECTED_TEXT_1 + assert decoded_single[1] == EXPECTED_TEXT_2 + assert decoded_single[2] == EXPECTED_TEXT_3 + assert decoded_single[3] == EXPECTED_TEXT_4 + @slow def test_whisper_longform_multi_batch_hard(self): # fmt: off EXPECTED_TEXT = [ - " Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile!", + " Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile.", " Folks, I spend a lot of time right over there, night after night after night, actually. Carefully selecting for you the day's noosiest, most aerodynamic headlines, stress testing, and those topical anti-lock breaks and power steering, painstakingly stitching, leather seating so soft, it would make JD power and her associates blush to create the luxury sedan that is my nightly monologue. But sometimes, you sometimes, folks. I lurched a consciousness in the back of an abandoned school and slap myself awake with a crusty floor mat. Before using a mouse-bitten timing belt to strap some old plywood to a couple of discarded oil drums, then by the light of a heathen moon, render a gas tank out of an empty big gulp, fill with white claw and denatured alcohol, then light a match and let her rip and the demented one man soapbox derby of news that is my segment. Me, Guadalupe! No!", " Ladies and gentlemen, you know, I spent a lot of time right over there Raising the finest Holstein news cattle firmly yet tenderly milking the latest headlines from their jokes swollen teats Churning the daily stories into the decadent proven-style style triple cream breed that is my nightly monologue But sometimes sometimes folks I stagger home hungry after being released by the police and Root around in the neighbor's trash can for an old milk carton scrape out the blooming dairy residue into the remains of a wet cheese rod I won from a rat in a pre-donned street fight. Put it in a discarded paint can to leave it to ferment next to a trash fire then hunker down and hallucinate while eating the listeria laden demon custard of news that is my segment. You mean one of them.", " Folks, if you watch this show, you know I spend most of my time right over there carefully sorting through the day's biggest stories and selecting only the most subtle and unblemished ostrich and crocodile news leather, which I then entrust to artisan graduates of the Ichol Gregoire Ferrandi, who carefully dye them in a palette of bright zesty shades and adorn them in the finest and most topical inlay work using hand tools and double magnifying glasses, then assemble them according to now classic and elegant geometry using our signature saddles stitching. In line it with bees, wax, coated linen, finely attached a mallet, hammered strap, pearled hardware, and close-shit to create for you the one-of-a-kind hoke couture, Erme's Birkin bag that is my monologue. But sometimes, sometimes folks, sometimes. Sometimes I wake up in the last car of an abandoned roller coaster at Coney Island where I'm I'm hiding from the triads. I have some engine lubricants out of a safe way bag and stagger down the shore to tear the sail off a beach schooner. Then I rip the coaxial cable out of an RV and elderly couple from Utah, Hank, and Mabel lovely folks. And use it to stitch the sail into a loose pouch like a rock sack. And I stow away in the back of a garbage truck to the junkyard where I pick through to the debris for only the broken toys that make me the saddest until I have loaded for you. The Hobo Fugitives bug out, bindle of news that is my segment. Me one!", " You know, folks, I spent a lot of time crafting for you a bespoke playlist of the day's biggest stories right over there. Meticulously selecting the most topical chakra affirming scented candles, and using Feng Shui to perfectly align the joke energy in the exclusive boutique yoga retreat that is my monologue. But sometimes just sometimes I go to the dumpster behind the waffle house at three in the morning, take off my shirt, cover myself, and used fry oil, wrap my hands with some double-duct tape by stole from the broken car window. Pound a six-pack of blueberry hard-seltzer and a sack of pills I stole from a parked ambulance. Then arm wrestle a raccoon in the back alley vision quest of news that is my segment. Meanwhile!", " You know, folks, I spend most of my time right over there. Mining the day's biggest, most important stories, collecting the finest, most topical iron or hand hammering it into joke panels. Then I craft sheets of bronze and blazing with patterns that tell an epic tale of conquest and glory. Then, using the Germanic tradition press-black process, I place thin sheets of foil against the scenes and by hammering or otherwise applying pressure from the back, I project these scenes into a pair of cheat cards in a faceplate and, finally, using fluted strips of white alloyed molding, I divide the designs into framed panels and hold it all together using bronze rivets to create the beautiful and intimidating, Anglo-Saxon battle helm that is my nightly monologue. Sometimes, sometimes folks. Sometimes, just sometimes, I come into my sense as fully naked on the deck of a pirate besieged melee container ship that picked me up floating on the detached door of a portapotty in the Indian Ocean. Then after a sunstroke-induced realization of the crew of this ship plans to sell me an exchange for a bag of oranges to fight off scurvy, I lead a mutiny using only a PVC pipe at a pool chain that accepting my new role as Captain and declaring myself king of the windarc seas. I grab a dirty mop bucket covered in barnacles and adorn it with the teeth of the vanquished to create the sopping wet pirate crown of news that is my segment. Meanwhile!", " Folks, if you watch this show, you know I spend most of my time right over there carefully blending for you the day's Newsiest most topical flower eggs milk and butter and Stranding into a fine batter to make delicate and informative comedy pancakes Then I glaze them in the juice and zest of the most relevant midnight Valencia oranges and douse it all and a fine Dela main de voyage cognac Before prom baying and basting them tables. I deserve for you the James Beard award worthy crepe suzzette That is my nightly monologue, but sometimes just sometimes folks. I wake up in the baggage hold of Greyhound bus. It's being hoisted by the scrap yard claw toward the burn pit. Escape to a nearby abandoned price chopper where I scrounge for old bread scraps and busted open bags of starfruit candies and expired eggs. Chuck it all on a dirty hubcap and slap it over a tire fire before using the legs of a strain, pair of sweatpants and as oven mitts to extract and serve the demented transience poundcake of news that is my segment. Me, Guadalupe!", - " Folks, if you watched the show and I hope you do, I spent a lot of time right over there. Tiredlessly studying the lineage of the days most important thoroughbred stories and whole-stiner headlines, working with the best trainers, money can buy to rear their comedy offspring with a hand that is stern yet gentle into the triple crown winning equine specimen. That is my nightly monologue, but sometimes, sometimes, folks, I break into an unincorporated veterinary genetics lab and grab whatever test tubes I can find and then under a grow light I got from a discarded chia pet. I mixed the pilfered DNA of a horse and whatever was in a tube labeled Keith Colan extra. Slurrying the concoction with caffeine pills and a microwave red bull, I screamed, sang a prayer to Janice, initiator of human life and God of transformation as a half horse, half man, freak. Seizes to life before me and the hideous collection of loose animal parts and corrupted man tissue that is my segment. Meanwhile!", + " Folks, if you watched the show and I hope you do, I spent a lot of time right over there. Tiredlessly studying the lineage of the days most important thoroughbred stories and whole-stiner headlines, working with the best trainers, money can buy to rear their comedy offspring with a hand that is stern yet gentle into the triple crown winning equine specimen. That is my nightly monologue, but sometimes, sometimes, folks, I break into an unincorporated veterinary genetics lab and grab whatever test tubes I can find and then under a grow light I got from a discarded chia pet. I mixed the pilfered DNA of a horse and whatever was in a tube labeled Keith Colan extra. Slurrying the concoction with caffeine pills and a microwave red bull, I screamed, sang a prayer to Janice, initiator of human life and God of transformation as a half horse, half man, freak. Seizes to life before me and the hideous collection of loose animal parts and corrupted man tissue that is my segment. Meanwhile!" ] # fmt: on @@ -2185,6 +2318,55 @@ def test_whisper_longform_multi_batch_hard(self): assert decoded_all[i] == decoded_single[i] assert decoded_all[i] == EXPECTED_TEXT[i] + @slow + def test_whisper_longform_multi_batch_hard_prev_cond(self): + # fmt: off + EXPECTED_TEXT = [ + " Folks, if you watch the show, you know I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories, developing the central headline pawns, definitely maneuvering an oh-so-topical night to F6, faming of classic Sicilian, named or variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a Fisher show's in lip-nitsky attack that culminates in the elegant lethal slow played all pass on checkmate that is my nightly monologue, but sometimes sometimes folks I sometimes I start a little wake-up side down in the monkey bars of a condemned playground on a super fun site, get all hepped up on goofballs, rummage that would discard a tag bag of defective toys, yank out a fistball of disembodied doll limbs, toss them on a stain kid's place mad from a defunked denies, set up a table inside a rusty cargo container down by the warf and challenge toothless drifters to the godless bughouse blitz of tournament that is my segment.", + " Folks, I spent a lot of time right over there night after night, actually. Carefully selecting for you the day's newsiest, most aerodynamic headlines, stress testing on those topical anti-lock breaks and power steering, painstakingly stitching, leather seating, so soft, it would make JD power and her associates blush. To create the luxury sedan that is my nightly monologue, but sometimes I just sometimes focus. I lurched to consciousness in the back of an abandoned school bus and slapped myself awake with a crusty floor mat. Before using a mouse-bitten timing belt to strap some old plywood to a couple of discarded oil drums, then by the light of a heathen-moon render a gas tank out of an empty big gulp, filled with white claw and de-natured alcohol, then light a match, letter-ripping the dis-mented one-man soapbox derby of news that is my segment.", + " Ladies and gentlemen, you know, I spent a lot of time right over there, raising the finest hosting news cattle firmly, yet tenderly milking the latest headlines from their jokes, swollen teats, churning the daily stories into the decadent Provincil style triple cream-breed. It is my nightly monologue, but sometimes sometimes I stagger home hungry after being released by the police and root around in the neighbors trash can for an old milk carton scrape out the blooming dairy residue into the remains of a wet cheese rind I won from a rat and a pre-drawn street fight. Put it into discarded paint can to leave it to ferment next to a trash fire than a hunker down in hallucinate while eating the lusteria latent demon custard of news that is my segment.", + " Folks, you watched this show, you know I spend most of my time right over there, carefully sorting through the days, big stories, and selecting only the most subtle, and unblemished ostrich and crocodile news leather, which I then entrust to artisan graduates of the Ickel Greg Waferandi, who carefully died them in a pallet of bright, zesty shades, and adorn them in the finest most topical inlay work, using hand tools and double magnifying glasses, then assemble them according to now classic and elegant geometry using our signature saddle stitching, and line it with bees, wax, coated linen, and finally attach a mallet hammered strap, perled hardware, and close-shet to create for you the one of a kind hope, kutur, earn-may is burkin bag that is my monologue, but sometimes, sometimes, sometimes. Sometimes, sometimes I wake up in the last car of an abandoned roller coaster at Kony Island, where I'm hiding from the triads, I have some engine lubricants out of a safe way bag and staggered down the shore to tear the sail off a beach sooner than I ripped the coaxial cable out of an RV and elderly couple from Utah, Hank, and Mabel Lovelyfokes, and use it to stitch the sail into a loose pouch like rock sack, and I stole a bag of a garbage truck to the junkyard, where I picked through to the debris for only the broken toys that make me the saddest, until I have loaded for you. The hobo fugitives bug out Bindle of news that is my segment.", + " You know, folks, I spent a lot of time crafting for you a bespoke playlist of the day's big stories right over there. meticulously selecting the most topical chakra affirming scented candles, using Feng Shui, to perfectly align the joke energy in the exclusive boutique yoga retreat that is my monologue, but sometimes just sometimes, I go to the dumpster behind the waffle house at three in the morning, take off my shirt, cover myself and use fry oil, wrap my hands and some old duct tape I stole from a broken car window, pound a six pack of blueberry hard-seller and a second pill, as I stole from a park damsel, and it's then arm wrestle a raccoon in the back alley vision quest of news that is my segment.", + " You know, folks, I spend most of my time right over there. Mining the days, biggest, most important stories, collecting the finest, most topical iron or hand hammering it into joke panels, then I craft sheets of bronze and blazing with patterns that tell an epic tale of conquest and glory. Then, using the Germanic tradition press, black process, I place thin sheets of foil against the scenes and by hammering or otherwise applying pressure from the back, I project these scenes into a pair of cheat cards and a face plate, and finally using fluted strips of white alloyed molding I divide the designs into framed panels and hold it all together using bronze rivets to create the beautiful and intimidating Anglo-Saxon battle helm that is my nightly monologue. Sometimes, sometimes, folks. Sometimes, just sometimes, I come to my senses fully naked on the deck of a pirate, beceived, melee, container ship that picked me up floating on the detainees. Then after I sunstroke in juice, realization of the crew of this ship plans to sell me and exchange for a bag of oranges to fight off scurvy, I lead a mutiny using only a PVC pipe in a pool chain that accepting my new role as captain and declaring myself king of the wind arc seas. I grab a dirty muck bucket covered in barnacles and a dornet with the teeth of the vanquished to create the softening wet pirate crown of news that is my segment. I'm going to use the white paper to create the softened white paper to create the softened white paper to create the softened white pirate crown of news that is my segment. Meanwhile.", + " Folks, if you watch this show, you know I spend most of my time right over there carefully blending for you the day's newsiest, most topical flower eggs, milk and butter. And straining into a fine batter to make delicate and informative comedy pancakes, then I glaze them in the juice and zest of the most relevant midnight valencio oranges. And doubts at all, and I find delimane de voyage cognac, before from bang and basting them tables, I deserve you the James Beard Award worthy creeps to ZET. That is my nightly monologue, but sometimes sometimes folks I wake up in the baggage hole of Greyhound bus, it's being hoisted by the scrapyard claw toward the burn pit. Escape to a nearby abandoned price chopper where I scrounge for old bread scraps, busted open bags of starfruit candies and expired eggs. Chuck it all on a dirty hubcap and slap it over a tire fire before using the legs of a strained pair of sweatpants and as ovenmets to extract and serve the demented transients pound cake of news that is my segment. Me wild!", + " Folks, if you watch the show and I hope you do, I spend a lot of time right over there. Tirelessly studying the lineage of the day's most important thoroughbred stories and whole-stiner headlines, working with the best trainers money can buy to rear their comedy offspring with a hand that is stern yet gentle into the triple crown winning equine specimen that is my nightly monologue. But sometimes sometimes folks I break into an unincorporated veterinary genetics lab. And grab whatever test tubes I can find and then under a grow light I got from it a discarded chia pet. I mixed the pill for DNA of a horse and whatever was in a tube labeled Keith Cole and extra. Sloering the concoction with caffeine pills and a microwave bread bowl, I screamed sing a prayer to Janice initiator of human life and God of transformation as a half horse, half man freak, seasons to life before me. And the hideous collection of loose animal parts and corrupted men tissue that is my segment. Meanwhile.", + ] + # fmt: on + + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny") + model = model.to("cuda") + + ds = load_dataset("distil-whisper/meanwhile", "default")["test"] + ds = ds.cast_column("audio", Audio(sampling_rate=16000)) + + num_samples = 8 + + audio = ds[:num_samples]["audio"] + audios = [x["array"] for x in audio] + + inputs = processor( + audios, return_tensors="pt", truncation=False, padding="longest", return_attention_mask=True + ) + inputs = inputs.to(device="cuda") + + gen_kwargs = { + "return_timestamps": True, + "no_speech_threshold": 0.6, + "temperature": (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), + "compression_ratio_threshold": 1.35, + "condition_on_prev_tokens": True, + "logprob_threshold": -1.0, + "num_beams": 5, + } + + torch.manual_seed(0) + result = model.generate(**inputs, **gen_kwargs) + decoded_all = processor.batch_decode(result, skip_special_tokens=True) + + for i in range(num_samples): + assert decoded_all[i] == EXPECTED_TEXT[i] + def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None): if head_mask is None: diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index 1aae29e5d45bec..7b6a9f30c55ac9 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -1152,7 +1152,7 @@ def test_with_local_lm_fast(self): @slow def test_whisper_longform(self): # fmt: off - EXPECTED_RESULT = """ Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out of fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct denny's, set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile!""" + EXPECTED_RESULT = """ Folks, if you watch the show, you know, I spent a lot of time right over there. Patiently and astutely scrutinizing the boxwood and mahogany chest set of the day's biggest stories developing the central headline pawns, definitely maneuvering an oso topical night to F6, fainting a classic Sicilian, nade door variation on the news, all the while seeing eight moves deep and patiently marshalling the latest press releases into a fisher's shows in Lip Nitsky attack that culminates in the elegant lethal slow-played, all-passant checkmate that is my nightly monologue. But sometimes, sometimes, folks, I. CHEERING AND APPLAUSE Sometimes I startle away, cubside down in the monkey bars of a condemned playground on a super fun site. Get all hept up on goofballs. Rummage that were discarded tag bag of defective toys. Yank out a fist bowl of disembodied doll limbs, toss them on a stained kid's place mat from a defunct dennies. set up a table inside a rusty cargo container down by the Wharf and challenged toothless drifters to the godless bughouse blitz of tournament that is my segment. Meanwhile.""" # fmt: on processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")