diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 87a2f09403e2f0..18cbce20d73bde 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -390,7 +390,7 @@ Flax), PyTorch, and/or TensorFlow. | Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ✅ | | Wav2Vec2-Conformer | ❌ | ❌ | ✅ | ❌ | ❌ | | WavLM | ❌ | ❌ | ✅ | ❌ | ❌ | -| Whisper | ✅ | ❌ | ✅ | ✅ | ❌ | +| Whisper | ✅ | ❌ | ✅ | ✅ | ✅ | | X-CLIP | ❌ | ❌ | ✅ | ❌ | ❌ | | XGLM | ✅ | ✅ | ✅ | ✅ | ✅ | | XLM | ✅ | ❌ | ✅ | ✅ | ❌ | diff --git a/docs/source/en/model_doc/auto.mdx b/docs/source/en/model_doc/auto.mdx index b39920151db424..9df4fa9c995d78 100644 --- a/docs/source/en/model_doc/auto.mdx +++ b/docs/source/en/model_doc/auto.mdx @@ -286,6 +286,10 @@ The following auto classes are available for the following audio tasks. [[autodoc]] TFAutoModelForSpeechSeq2Seq +### FlaxAutoModelForSpeechSeq2Seq + +[[autodoc]] FlaxAutoModelForSpeechSeq2Seq + ### AutoModelForAudioXVector [[autodoc]] AutoModelForAudioXVector diff --git a/docs/source/en/model_doc/whisper.mdx b/docs/source/en/model_doc/whisper.mdx index 4b7a6028618427..26d4ef8af91306 100644 --- a/docs/source/en/model_doc/whisper.mdx +++ b/docs/source/en/model_doc/whisper.mdx @@ -79,3 +79,14 @@ The original code can be found [here](https://github.com/openai/whisper). [[autodoc]] TFWhisperForConditionalGeneration - call + + +## FlaxWhisperModel + +[[autodoc]] FlaxWhisperModel + - __call__ + +## FlaxWhisperForConditionalGeneration + +[[autodoc]] FlaxWhisperForConditionalGeneration + - __call__ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index e8cef7dfe4e2be..7adda898e8bc66 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3293,6 +3293,7 @@ "FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING", "FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", "FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", + "FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING", "FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", "FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING", "FLAX_MODEL_MAPPING", @@ -3306,6 +3307,7 @@ "FlaxAutoModelForQuestionAnswering", "FlaxAutoModelForSeq2SeqLM", "FlaxAutoModelForSequenceClassification", + "FlaxAutoModelForSpeechSeq2Seq", "FlaxAutoModelForTokenClassification", "FlaxAutoModelForVision2Seq", ] @@ -3489,6 +3491,13 @@ _import_structure["models.wav2vec2"].extend( ["FlaxWav2Vec2ForCTC", "FlaxWav2Vec2ForPreTraining", "FlaxWav2Vec2Model", "FlaxWav2Vec2PreTrainedModel"] ) + _import_structure["models.whisper"].extend( + [ + "FlaxWhisperForConditionalGeneration", + "FlaxWhisperModel", + "FlaxWhisperPreTrainedModel", + ] + ) _import_structure["models.xglm"].extend( [ "FlaxXGLMForCausalLM", @@ -6208,6 +6217,7 @@ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, + FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING, FLAX_MODEL_MAPPING, @@ -6221,6 +6231,7 @@ FlaxAutoModelForQuestionAnswering, FlaxAutoModelForSeq2SeqLM, FlaxAutoModelForSequenceClassification, + FlaxAutoModelForSpeechSeq2Seq, FlaxAutoModelForTokenClassification, FlaxAutoModelForVision2Seq, ) @@ -6356,6 +6367,7 @@ FlaxWav2Vec2Model, FlaxWav2Vec2PreTrainedModel, ) + from .models.whisper import FlaxWhisperForConditionalGeneration, FlaxWhisperModel, FlaxWhisperPreTrainedModel from .models.xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel from .models.xlm_roberta import ( FLAX_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, diff --git 
a/src/transformers/generation/flax_logits_process.py b/src/transformers/generation/flax_logits_process.py index 92e3cff82cb5c2..c5be0cfca937ce 100644 --- a/src/transformers/generation/flax_logits_process.py +++ b/src/transformers/generation/flax_logits_process.py @@ -258,10 +258,197 @@ def __init__(self, min_length: int, eos_token_id: int): self.eos_token_id = eos_token_id def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: - # create boolean flag to decide if min length penalty should be applied apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1) scores = jnp.where(apply_penalty, scores.at[:, self.eos_token_id].set(-float("inf")), scores) return scores + + +class FlaxSuppressTokensAtBeginLogitsProcessor(FlaxLogitsProcessor): + r""" + [`FlaxLogitsProcessor`] suppressing a list of tokens as soon as the `generate` function starts generating using + `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` are not sampled at the + beginning of the generation. + + Args: + begin_suppress_tokens (`List[int]`): + Tokens to not sample. + begin_index (`int`): + Index where the tokens are suppressed. + """ + + def __init__(self, begin_suppress_tokens, begin_index): + self.begin_suppress_tokens = list(begin_suppress_tokens) + self.begin_index = begin_index + + def __call__(self, input_ids, scores, cur_len: int): + apply_penalty = 1 - jnp.bool_(cur_len - self.begin_index) + + scores = jnp.where(apply_penalty, scores.at[:, self.begin_suppress_tokens].set(-float("inf")), scores) + + return scores + + +class FlaxSuppressTokensLogitsProcessor(FlaxLogitsProcessor): + r""" + [`FlaxLogitsProcessor`] suppressing a list of tokens at each decoding step. The processor will set their log probs + to be `-inf` so they are not sampled. + + Args: + suppress_tokens (`list`): + Tokens to not sample. + """ + + def __init__(self, suppress_tokens: list): + self.suppress_tokens = list(suppress_tokens) + + def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: + scores = scores.at[..., self.suppress_tokens].set(-float("inf")) + + return scores + + +class FlaxForceTokensLogitsProcessor(FlaxLogitsProcessor): + r""" + [`FlaxLogitsProcessor`] that takes a list of pairs of integers that indicates a mapping from generation indices to + token indices that will be forced before sampling. The processor will set their log probs to 0 and all other tokens + to `-inf` so that they are sampled at their corresponding index. + + Args: + force_token_map (`list`): + Map giving token ids and indices where they will be forced to be sampled. + """ + + def __init__(self, force_token_map): + force_token_map = dict(force_token_map) + # Converts the dictionary of format {index: token} containing the tokens to be forced to an array, where the + # index of the array corresponds to the index of the token to be forced, for XLA compatibility. + # Indices without forced tokens will have a negative value.
+ force_token_array = jnp.ones((max(force_token_map.keys()) + 1), dtype=jnp.int32) * -1 + for index, token in force_token_map.items(): + force_token_array = force_token_array.at[index].set(token) + self.force_token_array = jnp.int32(force_token_array) + + def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: + def _force_token(generation_idx): + batch_size = scores.shape[0] + current_token = self.force_token_array[generation_idx] + + new_scores = jnp.ones_like(scores, dtype=scores.dtype) * -float("inf") + updates = jnp.zeros((batch_size, 1), dtype=scores.dtype) + new_scores = lax.dynamic_update_slice(new_scores, updates, (0, current_token)) + return new_scores + + scores = lax.cond( + cur_len >= self.force_token_array.shape[0], + # If the current length is greater than or equal to the length of force_token_array, the processor does nothing. + lambda: scores, + # Otherwise, it may force a certain token. + lambda: lax.cond( + self.force_token_array[cur_len] >= 0, + # Only valid (non-negative) tokens are forced + lambda: _force_token(cur_len), + # Otherwise, the processor does nothing. + lambda: scores, + ), + ) + return scores + + +class FlaxWhisperTimeStampLogitsProcessor(FlaxLogitsProcessor): + r""" + Whisper-specific processor. This processor enforces Whisper's timestamp rules: it suppresses tokens that would + break the timestamp format and, where the rules require a timestamp, sets the log probs of all other tokens to + `-inf` so that a timestamp token is sampled. + + Args: + generate_config (`GenerateConfig`): + The generate config used to generate the output. The following parameters are required: + eos_token_id (`int`, *optional*, defaults to 50257): + The id of the *end-of-sequence* token. + no_timestamps_token_id (`int`, *optional*, defaults to 50363): + The id of the `"<|notimestamps|>"` token. + max_initial_timestamp_index (`int`, *optional*, defaults to 1): + Used to set the maximum value of the initial timestamp. This is used to prevent the model from + predicting timestamps that are too far in the future.
+ """ + + def __init__(self, generate_config, model_config, decoder_input_length): + self.eos_token_id = generate_config.eos_token_id + self.no_timestamps_token_id = generate_config.no_timestamps_token_id + self.timestamp_begin = generate_config.no_timestamps_token_id + 1 + + self.begin_index = decoder_input_length + 1 + + if generate_config.is_multilingual: + # room for language token and task token + self.begin_index += 2 + if hasattr(generate_config, "max_initial_timestamp_index"): + self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index + else: + self.max_initial_timestamp_index = model_config.vocab_size + if self.max_initial_timestamp_index is None: + self.max_initial_timestamp_index = model_config.vocab_size + + def __call__(self, input_ids, scores, cur_len): + # suppress <|notimestamps|> which is handled by without_timestamps + scores = scores.at[:, self.no_timestamps_token_id].set(-float("inf")) + + def handle_pairs(input_ids_k, scores_k): + last_was_timestamp = jnp.where((cur_len - self.begin_index) >= 1, True, False) + last_was_timestamp = jnp.where( + input_ids_k[cur_len - 1] >= self.timestamp_begin, + True and last_was_timestamp, + False, + ) + + penultimate_was_timestamp = jnp.where((cur_len - self.begin_index) < 2, True, False) + penultimate_was_timestamp = jnp.where( + input_ids_k[cur_len - 2] >= self.timestamp_begin, + True, + penultimate_was_timestamp, + ) + + return jnp.where( + last_was_timestamp, + jnp.where( + penultimate_was_timestamp > 0, + scores_k.at[self.timestamp_begin :].set(-float("inf")), + scores_k.at[: self.eos_token_id].set(-float("inf")), + ), + scores_k, + ) + + scores = jax.vmap(handle_pairs)(input_ids, scores) + + apply_max_initial_timestamp = jnp.where(cur_len == self.begin_index, True, False) + apply_max_initial_timestamp = jnp.where( + self.max_initial_timestamp_index is not None, + True and apply_max_initial_timestamp, + False, + ) + + last_allowed = self.timestamp_begin + self.max_initial_timestamp_index + + scores = jnp.where( + apply_max_initial_timestamp, + scores.at[:, last_allowed + 1 :].set(-float("inf")), + scores, + ) + + # if sum of probability over timestamps is above any other token, sample timestamp + logprobs = jax.nn.log_softmax(scores, axis=-1) + + def handle_cumulative_probs(logprobs_k, scores_k): + timestamp_logprob = jax.nn.logsumexp(logprobs_k[self.timestamp_begin :], axis=-1) + max_text_token_logprob = jnp.max(logprobs_k[: self.timestamp_begin]) + return jnp.where( + timestamp_logprob > max_text_token_logprob, + scores_k.at[: self.timestamp_begin].set(-float("inf")), + scores_k, + ) + + scores = jax.vmap(handle_cumulative_probs)(logprobs, scores) + + return scores diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 1cfc07b9786c8b..5ced61434d30e0 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -38,8 +38,11 @@ from .flax_logits_process import ( FlaxForcedBOSTokenLogitsProcessor, FlaxForcedEOSTokenLogitsProcessor, + FlaxForceTokensLogitsProcessor, FlaxLogitsProcessorList, FlaxMinLengthLogitsProcessor, + FlaxSuppressTokensAtBeginLogitsProcessor, + FlaxSuppressTokensLogitsProcessor, FlaxTemperatureLogitsWarper, FlaxTopKLogitsWarper, FlaxTopPLogitsWarper, @@ -165,6 +168,50 @@ def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, params, mode model_kwargs["encoder_outputs"] = self.encode(input_ids, params=params, return_dict=True, **encoder_kwargs) return model_kwargs + def 
_prepare_decoder_input_ids_for_generation( + self, + batch_size: int, + decoder_start_token_id: int = None, + bos_token_id: int = None, + model_kwargs: Optional[Dict[str, jnp.ndarray]] = None, + ) -> jnp.ndarray: + if model_kwargs is not None and "decoder_input_ids" in model_kwargs: + # Only use this arg if not None, otherwise just remove from model_kwargs + decoder_input_ids = model_kwargs.pop("decoder_input_ids") + if decoder_input_ids is not None: + return decoder_input_ids + decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) + return jnp.array(decoder_start_token_id, dtype="i4").reshape(1, -1).repeat(batch_size, axis=0) + + def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int: + # retrieve decoder_start_token_id for encoder-decoder models + # fall back to bos_token_id if necessary + decoder_start_token_id = ( + decoder_start_token_id + if decoder_start_token_id is not None + else self.generation_config.decoder_start_token_id + ) + bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id + if decoder_start_token_id is not None: + return decoder_start_token_id + elif ( + hasattr(self.config, "decoder") + and hasattr(self.config.decoder, "decoder_start_token_id") + and self.config.decoder.decoder_start_token_id is not None + ): + return self.config.decoder.decoder_start_token_id + elif bos_token_id is not None: + return bos_token_id + elif ( + hasattr(self.config, "decoder") + and hasattr(self.config.decoder, "bos_token_id") + and self.config.decoder.bos_token_id is not None + ): + return self.config.decoder.bos_token_id + raise ValueError( + "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." + ) + @staticmethod def _expand_to_num_beams(tensor, num_beams): return jnp.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:]) @@ -225,6 +272,7 @@ def generate( prng_key: Optional[jnp.ndarray] = None, trace: bool = True, params: Optional[Dict[str, jnp.ndarray]] = None, + logits_processor: Optional[FlaxLogitsProcessorList] = None, **kwargs, ): r""" @@ -245,6 +293,10 @@ def generate( considerably slower runtime. params (`Dict[str, jnp.ndarray]`, *optional*): Optionally the model parameters can be passed. Can be useful for parallelized generation. + logits_processor (`FlaxLogitsProcessorList`, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logits processor is passed that has already been created with the arguments or + the generation config, an error is thrown. This feature is intended for advanced users. kwargs: Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder @@ -278,6 +330,8 @@ def generate( generation_config.validate() self._validate_model_kwargs(model_kwargs.copy()) + logits_processor = logits_processor if logits_processor is not None else FlaxLogitsProcessorList() + # set init values prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0) @@ -307,12 +361,19 @@ def generate( "generation results, please set `padding_side='left'` when initializing the tokenizer."
) + batch_size = input_ids.shape[0] + if self.config.is_encoder_decoder: # add encoder_outputs to model_kwargs if model_kwargs.get("encoder_outputs") is None: model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs) # prepare decoder_input_ids for generation - input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * generation_config.decoder_start_token_id + input_ids = self._prepare_decoder_input_ids_for_generation( + batch_size, + decoder_start_token_id=generation_config.decoder_start_token_id, + bos_token_id=generation_config.bos_token_id, + model_kwargs=model_kwargs, + ) # Prepare `max_length` depending on other stopping criteria. input_ids_seq_length = input_ids.shape[-1] @@ -348,7 +409,11 @@ def generate( " increasing`max_new_tokens`." ) - logits_processor = self._get_logits_processor(generation_config=generation_config) + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + logits_processor=logits_processor, + ) if not generation_config.do_sample and generation_config.num_beams == 1: return self._greedy_search( @@ -420,7 +485,12 @@ def _get_logits_warper(self, generation_config: GenerationConfig) -> FlaxLogitsP return warpers - def _get_logits_processor(self, generation_config: GenerationConfig) -> FlaxLogitsProcessorList: + def _get_logits_processor( + self, + generation_config: GenerationConfig, + input_ids_seq_length: int, + logits_processor: Optional[FlaxLogitsProcessorList], + ) -> FlaxLogitsProcessorList: """ This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsProcessor`] instances used to modify the scores of the language model head. @@ -441,9 +511,51 @@ def _get_logits_processor(self, generation_config: GenerationConfig) -> FlaxLogi processors.append( FlaxForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id) ) + if generation_config.suppress_tokens is not None: + processors.append(FlaxSuppressTokensLogitsProcessor(generation_config.suppress_tokens)) + if generation_config.begin_suppress_tokens is not None: + begin_index = input_ids_seq_length + begin_index = ( + begin_index + if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None) + else begin_index + 1 + ) + if generation_config.forced_decoder_ids is not None and len(generation_config.forced_decoder_ids) > 0: + # generation starts after the last token that is forced + begin_index += generation_config.forced_decoder_ids[-1][0] + processors.append( + FlaxSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index) + ) + if generation_config.forced_decoder_ids is not None: + forced_decoder_ids = [ + [input_ids_seq_length + i[0] - 1, i[1]] for i in generation_config.forced_decoder_ids + ] + processors.append(FlaxForceTokensLogitsProcessor(forced_decoder_ids)) + processors = self._merge_criteria_processor_list(processors, logits_processor) return processors + def _merge_criteria_processor_list( + self, + default_list: FlaxLogitsProcessorList, + custom_list: FlaxLogitsProcessorList, + ) -> FlaxLogitsProcessorList: + if len(custom_list) == 0: + return default_list + for default in default_list: + for custom in custom_list: + if type(custom) is type(default): + object_type = "logits processor" + raise ValueError( + f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to" + f" `generate`, but it has already been created with the 
values {default}. {default} has been" + " created by passing the corresponding arguments to generate or by the model's config default" + f" values. If you just want to change the default values of {object_type} consider passing" + f" them as arguments to `generate` instead of using a custom {object_type}." + ) + default_list.extend(custom_list) + return default_list + def _greedy_search( self, input_ids: None, diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index da8ceb8e7e6258..d86f738fa9df68 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -167,6 +167,7 @@ "FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING", "FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", "FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", + "FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING", "FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", "FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING", "FLAX_MODEL_MAPPING", @@ -180,6 +181,7 @@ "FlaxAutoModelForQuestionAnswering", "FlaxAutoModelForSeq2SeqLM", "FlaxAutoModelForSequenceClassification", + "FlaxAutoModelForSpeechSeq2Seq", "FlaxAutoModelForTokenClassification", "FlaxAutoModelForVision2Seq", ] @@ -324,6 +326,7 @@ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, + FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING, FLAX_MODEL_MAPPING, @@ -337,6 +340,7 @@ FlaxAutoModelForQuestionAnswering, FlaxAutoModelForSeq2SeqLM, FlaxAutoModelForSequenceClassification, + FlaxAutoModelForSpeechSeq2Seq, FlaxAutoModelForTokenClassification, FlaxAutoModelForVision2Seq, ) diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index 61d34f0f082675..77be9b33f0a701 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -55,6 +55,7 @@ ("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"), ("vit", "FlaxViTModel"), ("wav2vec2", "FlaxWav2Vec2Model"), + ("whisper", "FlaxWhisperModel"), ("xglm", "FlaxXGLMModel"), ("xlm-roberta", "FlaxXLMRobertaModel"), ] @@ -76,6 +77,7 @@ ("roformer", "FlaxRoFormerForMaskedLM"), ("t5", "FlaxT5ForConditionalGeneration"), ("wav2vec2", "FlaxWav2Vec2ForPreTraining"), + ("whisper", "FlaxWhisperForConditionalGeneration"), ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"), ] ) @@ -219,6 +221,7 @@ FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict( [ ("speech-encoder-decoder", "FlaxSpeechEncoderDecoderModel"), + ("whisper", "FlaxWhisperForConditionalGeneration"), ] ) diff --git a/src/transformers/models/whisper/__init__.py b/src/transformers/models/whisper/__init__.py index 2528e03a4d2c88..31e167c4ff6c0c 100644 --- a/src/transformers/models/whisper/__init__.py +++ b/src/transformers/models/whisper/__init__.py @@ -17,7 +17,13 @@ # limitations under the License. 
from typing import TYPE_CHECKING -from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_flax_available, + is_tf_available, + is_torch_available, +) _import_structure = { @@ -54,6 +60,19 @@ "TFWhisperPreTrainedModel", ] +try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_flax_whisper"] = [ + "FlaxWhisperForConditionalGeneration", + "FlaxWhisperModel", + "FlaxWhisperPreTrainedModel", + ] + + if TYPE_CHECKING: from .configuration_whisper import WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP, WhisperConfig, WhisperOnnxConfig from .feature_extraction_whisper import WhisperFeatureExtractor @@ -86,6 +105,18 @@ TFWhisperPreTrainedModel, ) + try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_flax_whisper import ( + FlaxWhisperForConditionalGeneration, + FlaxWhisperModel, + FlaxWhisperPreTrainedModel, + ) + else: import sys diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py new file mode 100644 index 00000000000000..f66a02453d7936 --- /dev/null +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -0,0 +1,1470 @@ +# coding=utf-8 +# Copyright 2022 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Flax whisper model.""" + +import random +from functools import partial +from typing import Optional, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax +from jax.random import PRNGKey + +from ...generation.flax_logits_process import FlaxWhisperTimeStampLogitsProcessor +from ...modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxBaseModelOutputWithPastAndCrossAttentions, + FlaxCausalLMOutputWithCrossAttentions, + FlaxSeq2SeqLMOutput, + FlaxSeq2SeqModelOutput, +) +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_call_sample_docstring, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from .configuration_whisper import WhisperConfig + + +logger = logging.get_logger(__name__) + + +_CHECKPOINT_FOR_DOC = "openai/whisper-tiny" +_CONFIG_FOR_DOC = "WhisperConfig" + + +WHISPER_START_DOCSTRING = r""" + This model inherits from [`FlaxPreTrainedModel`]. 
Check the superclass documentation for the generic methods the + library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads + etc.). This model is also a Flax Linen + [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a + regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior. + Finally, this model supports inherent JAX features such as: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + + Parameters: + config ([`WhisperConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). This can be used to enable mixed-precision training or half-precision + inference on GPUs or TPUs. If specified, all the computation will be performed with the given `dtype`. + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] + and [`~FlaxPreTrainedModel.to_bf16`]. +""" + +WHISPER_INPUTS_DOCSTRING = r""" + Args: + input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`): + Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by + loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via + the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the + [`WhisperFeatureExtractor`] should be used for extracting the features, padding and conversion into a + tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`]. + attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but + it is not used. By default the silence in the input log mel spectrogram is ignored. + decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using + [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + [What are decoder input IDs?](../glossary#decoder-input-ids) Whisper uses the `decoder_start_token_id` as + the starting token for `decoder_input_ids` generation. + decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. If you want to change padding behavior, you should modify it to your needs.
See diagram 1 + in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. + position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Whisper does not use `position_ids` in the encoder as `input_features` is always the same size and doesn't + use masking, but this argument is preserved for compatibility. By default the silence in the input log mel + spectrogram is ignored. + decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence token in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +WHISPER_ENCODE_INPUTS_DOCSTRING = r""" + Args: + input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`): + Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by + loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via + the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the + [`WhisperFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a + tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`]. + attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but + it is not used. By default the silence in the input log mel spectrogram is ignored. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +WHISPER_DECODE_INPUTS_DOCSTRING = r""" + Args: + decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`): + Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using + [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + [What are decoder input IDs?](../glossary#decoder-input-ids) + encoder_outputs (`tuple(tuple(numpy.ndarray))`): + Tuple consisting of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + encoder_attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, + but it is not used.
By default the silence in the input log mel spectrogram is ignored. + decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. If you want to change padding behavior, you should modify it to your needs. See diagram 1 + in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. + decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence token in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. + past_key_values (`Dict[str, numpy.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast + auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class FlaxWhisperAttention(nn.Module): + config: WhisperConfig + embed_dim: int + num_heads: int + dropout: float = 0.0 + causal: bool = False + bias: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {self.num_heads})."
+ ) + + dense = partial( + nn.Dense, + self.embed_dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + + self.q_proj = dense(use_bias=self.bias) + self.k_proj = dense(use_bias=False) + self.v_proj = dense(use_bias=self.bias) + self.out_proj = dense(use_bias=self.bias) + + if self.causal: + self.causal_mask = make_causal_mask( + jnp.ones((1, self.config.max_target_positions), dtype="bool"), dtype="bool" + ) + + def __call__( + self, + hidden_states: jnp.ndarray, + key_value_states: Optional[jnp.ndarray] = None, + attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + ) -> Tuple[jnp.ndarray]: + is_cross_attention = key_value_states is not None + batch_size = hidden_states.shape[0] + + query_states = self.q_proj(hidden_states) + + if is_cross_attention: + key_states = self.k_proj(key_value_states) + value_states = self.v_proj(key_value_states) + else: + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + if self.causal: + query_length, key_length = query_states.shape[1], key_states.shape[1] + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, + (0, 0, mask_shift, 0), + (1, 1, query_length, max_decoder_length), + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + # combine masks if needed + if attention_mask is not None and self.causal: + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + elif self.causal: + attention_mask = causal_mask + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) + + # Convert the boolean attention mask to an attention bias. 
+ if attention_mask is not None: + # attention mask in the form of attention bias + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.dropout, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = self._merge_heads(attn_output) + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + def _split_heads(self, hidden_state) -> jnp.ndarray: + return hidden_state.reshape(hidden_state.shape[:2] + (self.num_heads, self.head_dim)) + + def _merge_heads(self, hidden_state) -> jnp.ndarray: + return hidden_state.reshape(hidden_state.shape[:2] + (self.embed_dim,)) + + @nn.compact + def _concatenate_to_cache(self, key, value, query, attention_mask) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only + # attend to those key positions that have already been generated and cached, not the + # remaining zero elements. 
+ pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + + return key, value, attention_mask + + +# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartEncoderLayer with MBart->Whisper +class FlaxWhisperEncoderLayer(nn.Module): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.embed_dim = self.config.d_model + self.self_attn = FlaxWhisperAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.encoder_attention_heads, + dropout=self.config.attention_dropout, + dtype=self.dtype, + ) + self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + self.activation_fn = ACT2FN[self.config.activation_function] + self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) + self.fc1 = nn.Dense( + self.config.encoder_ffn_dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + self.fc2 = nn.Dense( + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) + ) + self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def __call__( + self, + hidden_states: jnp.ndarray, + attention_mask: jnp.ndarray, + output_attentions: bool = True, + deterministic: bool = True, + ) -> Tuple[jnp.ndarray]: + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartEncoderLayerCollection with MBart->Whisper +class FlaxWhisperEncoderLayerCollection(nn.Module): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxWhisperEncoderLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.encoder_layers) + ] + self.layerdrop = self.config.encoder_layerdrop + + def __call__( + self, + hidden_states, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + for encoder_layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if not deterministic and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None) + else: + layer_outputs = encoder_layer( + 
hidden_states, + attention_mask, + output_attentions, + deterministic, + ) + hidden_states = layer_outputs[0] + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states, all_hidden_states, all_attentions) + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayer with MBart->Whisper +class FlaxWhisperDecoderLayer(nn.Module): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.embed_dim = self.config.d_model + self.self_attn = FlaxWhisperAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.decoder_attention_heads, + dropout=self.config.attention_dropout, + causal=True, + dtype=self.dtype, + ) + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + self.activation_fn = ACT2FN[self.config.activation_function] + self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) + + self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + self.encoder_attn = FlaxWhisperAttention( + config=self.config, + embed_dim=self.embed_dim, + num_heads=self.config.decoder_attention_heads, + dropout=self.config.attention_dropout, + dtype=self.dtype, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + self.fc1 = nn.Dense( + self.config.decoder_ffn_dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + self.fc2 = nn.Dense( + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) + ) + self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def __call__( + self, + hidden_states: jnp.ndarray, + attention_mask: jnp.ndarray, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + output_attentions: bool = True, + deterministic: bool = True, + ) -> Tuple[jnp.ndarray]: + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache + ) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + # Cross-Attention Block + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + hidden_states = self.encoder_attn_layer_norm(hidden_states) + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + ) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + outputs = 
(hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + +# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayerCollection with MBart->Whisper +class FlaxWhisperDecoderLayerCollection(nn.Module): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.layers = [ + FlaxWhisperDecoderLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.decoder_layers) + ] + self.layerdrop = self.config.decoder_layerdrop + + def __call__( + self, + hidden_states, + attention_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + deterministic: bool = True, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if not deterministic and (dropout_probability < self.layerdrop): + layer_outputs = (None, None, None) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + output_attentions=output_attentions, + deterministic=deterministic, + ) + + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions] + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +class FlaxWhisperEncoder(nn.Module): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.conv1 = nn.Conv( + self.config.d_model, + kernel_size=(3,), + padding=1, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + dtype=self.dtype, + ) + self.conv2 = nn.Conv( + self.config.d_model, + kernel_size=(3,), + strides=2, + padding=1, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + dtype=self.dtype, + ) + + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + + self.layers = FlaxWhisperEncoderLayerCollection( + self.config, + dtype=self.dtype, + ) + self.embed_positions = nn.Embed(self.config.max_source_positions, self.config.d_model, dtype=self.dtype) + + self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def __call__( + self, + input_features: jnp.ndarray, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ) -> Tuple[jnp.ndarray]: + if input_features.shape[1:] != (self.config.num_mel_bins, self.config.max_source_positions * 2): + 
raise ValueError( + "input_features.shape[1:], must be equal to (self.config.num_mel_bins," + f" self.config.max_source_positions * 2) (got {input_features.shape[1:]}, but should be" + f" ({self.config.num_mel_bins}, {self.config.max_source_positions * 2}))" + ) + + input_features = input_features.transpose(0, 2, 1) + hidden_states = jax.nn.gelu(self.conv1(input_features), approximate=False) + hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False) + + embed_positions = self.embed_positions(jnp.arange(self.config.max_source_positions)) + hidden_states = hidden_states + embed_positions + + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + + outputs = self.layers( + hidden_states, + attention_mask=None, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_states = outputs[0] + last_hidden_states = self.layer_norm(last_hidden_states) + + # update the last element in `hidden_states` after applying `layernorm` above + hidden_states = None + if output_hidden_states: + hidden_states = outputs[1] + hidden_states = hidden_states[:-1] + (last_hidden_states,) + + if not return_dict: + outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=last_hidden_states, + hidden_states=hidden_states, + attentions=outputs.attentions, + ) + + +class FlaxWhisperDecoder(nn.Module): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.embed_tokens = nn.Embed(self.config.vocab_size, self.config.d_model, dtype=self.dtype) + self.embed_positions = nn.Embed(self.config.max_target_positions, self.config.d_model, dtype=self.dtype) + + self.layers = FlaxWhisperDecoderLayerCollection(self.config, dtype=self.dtype) + + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + + self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-5) + + def __call__( + self, + input_ids: jnp.ndarray, + attention_mask: jnp.ndarray, + position_ids: jnp.ndarray, + encoder_hidden_states: Optional[jnp.ndarray] = None, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ) -> Tuple[jnp.ndarray]: + input_embeds = self.embed_tokens(input_ids) + position_embeds = self.embed_positions(position_ids) + + hidden_states = input_embeds + position_embeds + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + + outputs = self.layers( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_states = outputs[0] + last_hidden_states = self.layer_norm(last_hidden_states) + + # update the last element in `hidden_states` after applying `layernorm` above + hidden_states = None + if output_hidden_states: + hidden_states = outputs[1] + hidden_states = hidden_states[:-1] + (last_hidden_states,) + + if not return_dict: + outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + 
last_hidden_state=last_hidden_states, + hidden_states=hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +class FlaxWhisperModule(nn.Module): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.encoder = FlaxWhisperEncoder(self.config, dtype=self.dtype) + self.decoder = FlaxWhisperDecoder(self.config, dtype=self.dtype) + + def __call__( + self, + input_features: jnp.ndarray, + decoder_input_ids: jnp.ndarray, + decoder_attention_mask: jnp.ndarray, + decoder_position_ids: jnp.ndarray, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + encoder_outputs = self.encoder( + input_features, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return FlaxSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def _get_encoder_module(self): + return self.encoder + + def _get_decoder_module(self): + return self.decoder + + +class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel): + config_class = WhisperConfig + base_model_prefix: str = "model" + main_input_name = "input_features" + module_class: nn.Module = None + + def __init__( + self, + config: WhisperConfig, + input_shape: Tuple[int] = (1, 80, 3000), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_features = jnp.zeros(input_shape, dtype="f4") + input_features = input_features.at[(..., -1)].set(self.config.eos_token_id) + + decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + + batch_size, sequence_length = decoder_input_ids.shape + decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init( + rngs, + input_features=input_features, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + )["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return 
freeze(unflatten_dict(params)) + else: + return random_params + + # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartPreTrainedModel.init_cache with Bart->Whisper + def init_cache(self, batch_size, max_length, encoder_outputs): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): + `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: + `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) + is a sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross-attention of the decoder. + """ + # init input variables to retrieve cache + decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + decoder_position_ids = jnp.broadcast_to( + jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape + ) + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, + **kwargs, + ) + + init_variables = self.module.init( + jax.random.PRNGKey(0), + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + init_cache=True, + method=_decoder_forward, # we only need to call the decoder to init the cache + ) + return unfreeze(init_variables["cache"]) + + @add_start_docstrings(WHISPER_ENCODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=WhisperConfig) + def encode( + self, + input_features: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + **kwargs, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration + >>> from datasets import load_dataset + + >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") + >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True) + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np") + >>> input_features = inputs.input_features + >>> encoder_outputs = model.encode(input_features=input_features) + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + def _encoder_forward(module, input_features, **kwargs): + encode_module = module._get_encoder_module() + return 
encode_module(input_features, **kwargs) + + return self.module.apply( + {"params": params or self.params}, + input_features=jnp.array(input_features, dtype="f4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + method=_encoder_forward, + ) + + @add_start_docstrings(WHISPER_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=WhisperConfig) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration + >>> from datasets import load_dataset + + >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") + >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True) + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np") + >>> input_features = inputs.input_features + >>> encoder_outputs = model.encode(input_features=input_features) + >>> decoder_start_token_id = model.config.decoder_start_token_id + + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> last_decoder_hidden_states = outputs.last_hidden_state + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_position_ids is None: + if past_key_values is not None: + raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") + + if decoder_attention_mask is not None: + decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1 + else: + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. 
It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxWhisperAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + **kwargs, + ) + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past = outputs + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past = outputs + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING) + def __call__( + self, + input_features: jnp.ndarray, + decoder_input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # prepare decoder inputs + if decoder_position_ids is None: + if decoder_attention_mask is not None: + decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1 + else: + batch_size, sequence_length = decoder_input_ids.shape + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + input_features=jnp.array(input_features, dtype="f4"), + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + ) + + +@add_start_docstrings( + "The bare Whisper Model transformer outputting raw hidden-states without any specific head on top.", + WHISPER_START_DOCSTRING, +) +class FlaxWhisperModel(FlaxWhisperPreTrainedModel): + config: 
WhisperConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + module_class = FlaxWhisperModule + + +append_call_sample_docstring(FlaxWhisperModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC) + + +class FlaxWhisperForConditionalGenerationModule(nn.Module): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.model = FlaxWhisperModule(config=self.config, dtype=self.dtype) + self.lm_head = nn.Dense( + self.config.vocab_size, + use_bias=False, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + + def _get_encoder_module(self): + return self.model.encoder + + def _get_decoder_module(self): + return self.model.decoder + + def __call__( + self, + input_features, + decoder_input_ids, + decoder_attention_mask: jnp.ndarray = None, + decoder_position_ids: jnp.ndarray = None, + position_ids: jnp.ndarray = None, + attention_mask: jnp.ndarray = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + outputs = self.model( + input_features=input_features, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_embedding = self.model.decoder.embed_tokens.variables["params"]["embedding"] + lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + lm_logits = self.lm_head(hidden_states) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return output + + return FlaxSeq2SeqLMOutput( + logits=lm_logits, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +@add_start_docstrings("The Whisper Model with a language modeling head.", WHISPER_START_DOCSTRING) +class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel): + module_class = FlaxWhisperForConditionalGenerationModule + dtype: jnp.dtype = jnp.float32 + + @add_start_docstrings(WHISPER_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=WhisperConfig) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration + >>> from datasets import load_dataset + + >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") + >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True) + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> 
inputs = processor(ds[0]["audio"]["array"], return_tensors="np") + >>> input_features = inputs.input_features + >>> encoder_outputs = model.encode(input_features=input_features) + >>> decoder_start_token_id = model.config.decoder_start_token_id + + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> last_decoder_hidden_states = outputs.last_hidden_state + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_position_ids is None: + if past_key_values is not None: + raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") + + if decoder_attention_mask is not None: + decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1 + else: + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length), dtype="i4") + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxWhisperAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + outputs = decoder_module( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + **kwargs, + ) + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_embedding = module.model.decoder.embed_tokens.variables["params"]["embedding"] + lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + lm_logits = module.lm_head(hidden_states) + + return lm_logits, outputs + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + if past_key_values is None: + lm_logits, decoder_outputs = outputs + else: + (lm_logits, decoder_outputs), past = outputs + + if return_dict: + outputs = FlaxCausalLMOutputWithCrossAttentions( + logits=lm_logits, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + ) + else: + outputs = (lm_logits,) + decoder_outputs[1:] + + # add updated cache 
to model output
+        if past_key_values is not None and return_dict:
+            outputs["past_key_values"] = unfreeze(past["cache"])
+            return outputs
+        elif past_key_values is not None and not return_dict:
+            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
+
+        return outputs
+
+    def generate(
+        self,
+        input_features,
+        generation_config=None,
+        logits_processor=None,
+        return_timestamps=None,
+        task=None,
+        language=None,
+        is_multilingual=None,
+        **kwargs
+    ):
+        if generation_config is None:
+            generation_config = self.generation_config
+
+        if return_timestamps is not None:
+            generation_config.return_timestamps = return_timestamps
+
+        if task is not None:
+            generation_config.task = task
+
+        if is_multilingual is not None:
+            generation_config.is_multilingual = is_multilingual
+
+        if language is not None:
+            generation_config.language = language
+
+        if "decoder_input_ids" in kwargs:
+            decoder_input_length = len(kwargs["decoder_input_ids"])
+        else:
+            decoder_input_length = 1
+
+        forced_decoder_ids = []
+
+        if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual:
+            if hasattr(generation_config, "language"):
+                forced_decoder_ids.append((1, generation_config.lang_to_id[generation_config.language]))
+            else:
+                forced_decoder_ids.append((1, None))
+
+            if hasattr(generation_config, "task"):
+                forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
+            else:
+                forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))
+
+        if (
+            hasattr(generation_config, "return_timestamps") and generation_config.return_timestamps
+        ) or return_timestamps:
+            logits_processor = [
+                FlaxWhisperTimeStampLogitsProcessor(generation_config, self.config, decoder_input_length)
+            ]
+        else:
+            if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id:
+                idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
+                forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
+
+        if len(forced_decoder_ids) > 0:
+            generation_config.forced_decoder_ids = forced_decoder_ids
+
+        return super().generate(
+            input_features,
+            generation_config,
+            logits_processor=logits_processor,
+            **kwargs,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        decoder_input_ids,
+        max_length,
+        attention_mask: Optional[jnp.DeviceArray] = None,
+        decoder_attention_mask: Optional[jnp.DeviceArray] = None,
+        encoder_outputs=None,
+        **kwargs
+    ):
+        # initializing the cache
+        batch_size, seq_length = decoder_input_ids.shape
+
+        past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
+        # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
+        # But since the decoder uses a causal mask, those positions are masked anyway.
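+        # (XLA needs static shapes inside `jit`, so re-building the mask at each decoding step would force a re-trace.)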
+        # Thus we can create a single static attention_mask here, which is more efficient for compilation
+        extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
+        if decoder_attention_mask is not None:
+            # the running count of attended tokens, shifted to be 0-indexed, gives the position ids
+            position_ids = decoder_attention_mask.cumsum(-1) - 1
+            extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
+        else:
+            position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
+
+        return {
+            "past_key_values": past_key_values,
+            "encoder_outputs": encoder_outputs,
+            "encoder_attention_mask": attention_mask,
+            "decoder_attention_mask": extended_attention_mask,
+            "decoder_position_ids": position_ids,
+        }
+
+    def update_inputs_for_generation(self, model_outputs, model_kwargs):
+        model_kwargs["past_key_values"] = model_outputs.past_key_values
+        model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
+        return model_kwargs
+
+
+FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING = r"""
+    Returns:
+
+    Transcription example:
+
+    ```python
+    >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
+    >>> from datasets import load_dataset
+
+    >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
+    >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
+    >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+    >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
+    >>> input_features = inputs.input_features
+    >>> generated_ids = model.generate(input_features=input_features)
+    >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+    >>> transcription
+    ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
+    ```
+"""
+
+overwrite_call_docstring(
+    FlaxWhisperForConditionalGeneration, WHISPER_INPUTS_DOCSTRING + FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING
+)
+append_replace_return_docstrings(
+    FlaxWhisperForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC
+)
diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index 38bfe133a076b2..0f36edaf1dfa3c 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -102,7 +102,6 @@ def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional
         super().__init__(num_positions, embedding_dim)
 
     def forward(self, input_ids, past_key_values_length=0):
-
         return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[-1]]
 
 
@@ -897,7 +896,6 @@ def forward(
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             if self.gradient_checkpointing and self.training:
-
                 if use_cache:
                     logger.warning(
                         "`use_cache = True` is incompatible with gradient checkpointing.
Setting `use_cache =" @@ -923,7 +921,6 @@ def custom_forward(*inputs): None, # past_key_value ) else: - layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index 8c093e3c4510f0..380dc3d2adc7c8 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -163,6 +163,9 @@ def __init__(self, *args, **kwargs): FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None +FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = None + + FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None @@ -242,6 +245,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) +class FlaxAutoModelForSpeechSeq2Seq(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxAutoModelForTokenClassification(metaclass=DummyObject): _backends = ["flax"] @@ -1131,6 +1141,27 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) +class FlaxWhisperForConditionalGeneration(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxWhisperModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxWhisperPreTrainedModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxXGLMForCausalLM(metaclass=DummyObject): _backends = ["flax"] diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py new file mode 100644 index 00000000000000..a102f5d48df0e5 --- /dev/null +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -0,0 +1,706 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
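+#
+# Tests for the Flax Whisper models: common shape/signature/jit checks, PyTorch <-> Flax
+# weight round-trips, and slow integration tests against the openai/whisper-* checkpoints.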
+ + +import functools +import inspect +import tempfile +import unittest + +import transformers +from transformers import WhisperConfig, is_flax_available +from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow +from transformers.utils import cached_property +from transformers.utils.import_utils import is_datasets_available + +from ...test_configuration_common import ConfigTester +from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor + + +if is_datasets_available(): + import datasets + from datasets import load_dataset + +if is_flax_available(): + import numpy as np + + import jax + from flax.core.frozen_dict import unfreeze + from flax.traverse_util import flatten_dict + from transformers import ( + FLAX_MODEL_MAPPING, + FlaxWhisperForConditionalGeneration, + FlaxWhisperModel, + WhisperFeatureExtractor, + WhisperProcessor, + ) + from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model + + +@require_flax +class FlaxWhisperModelTester: + config_cls = WhisperConfig + config_updates = {} + hidden_act = "gelu" + + def __init__( + self, + parent, + batch_size=13, + seq_length=60, + is_training=True, + use_labels=False, + vocab_size=99, + d_model=16, + decoder_attention_heads=4, + decoder_ffn_dim=16, + decoder_layers=2, + encoder_attention_heads=4, + encoder_ffn_dim=16, + encoder_layers=2, + input_channels=1, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=70, + max_source_positions=30, + max_target_positions=40, + bos_token_id=98, + eos_token_id=98, + pad_token_id=0, + num_mel_bins=80, + decoder_start_token_id=85, + num_conv_layers=1, + suppress_tokens=None, + begin_suppress_tokens=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_labels = use_labels + self.vocab_size = vocab_size + self.d_model = d_model + self.hidden_size = d_model + self.num_hidden_layers = encoder_layers + self.num_attention_heads = encoder_attention_heads + self.decoder_attention_heads = decoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_seq_length = seq_length // 2 + self.decoder_seq_length = 1 + self.input_channels = input_channels + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.num_mel_bins = num_mel_bins + self.max_position_embeddings = max_position_embeddings + self.max_source_positions = max_source_positions + self.max_target_positions = max_target_positions + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.decoder_start_token_id = decoder_start_token_id + self.num_conv_layers = num_conv_layers + self.suppress_tokens = suppress_tokens + self.begin_suppress_tokens = begin_suppress_tokens + + def prepare_config_and_inputs_for_common(self): + input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length], self.vocab_size) + + decoder_input_ids = np.array(self.batch_size * [[self.decoder_start_token_id]]) + + config = WhisperConfig( + vocab_size=self.vocab_size, + num_mel_bins=self.num_mel_bins, + decoder_start_token_id=self.decoder_start_token_id, + is_encoder_decoder=True, + 
activation_function=self.hidden_act, + dropout=self.hidden_dropout_prob, + attention_dropout=self.attention_probs_dropout_prob, + max_source_positions=self.max_source_positions, + max_target_positions=self.max_target_positions, + pad_token_id=self.pad_token_id, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + tie_word_embeddings=True, + d_model=self.d_model, + decoder_attention_heads=self.decoder_attention_heads, + decoder_ffn_dim=self.decoder_ffn_dim, + decoder_layers=self.decoder_layers, + encoder_attention_heads=self.encoder_attention_heads, + encoder_ffn_dim=self.encoder_ffn_dim, + encoder_layers=self.encoder_layers, + suppress_tokens=self.suppress_tokens, + begin_suppress_tokens=self.begin_suppress_tokens, + ) + inputs_dict = prepare_whisper_inputs_dict(config, input_features, decoder_input_ids) + return config, inputs_dict + + +def prepare_whisper_inputs_dict( + config, + input_ids, + decoder_input_ids, + attention_mask=None, + decoder_attention_mask=None, +): + if decoder_attention_mask is None: + decoder_attention_mask = np.concatenate( + [ + np.ones(decoder_input_ids[:, :1].shape, dtype=np.int8), + np.not_equal(decoder_input_ids[:, 1:], config.pad_token_id).astype(np.int8), + ], + axis=-1, + ) + return { + "input_features": input_ids, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + } + + +def partialclass(cls, *args, **kwargs): + class NewCls(cls): + __init__ = functools.partialmethod(cls.__init__, *args, **kwargs) + + return NewCls + + +def make_partial_class(full_class, *args, **kwargs): + partial_class = partialclass(full_class, *args, **kwargs) + partial_class.__name__ = full_class.__name__ + partial_class.__module__ = full_class.__module__ + + return partial_class + + +@require_flax +class FlaxWhisperModelTest(FlaxModelTesterMixin, unittest.TestCase): + all_model_classes = (FlaxWhisperForConditionalGeneration, FlaxWhisperModel) if is_flax_available() else () + all_generative_model_classes = (FlaxWhisperForConditionalGeneration,) if is_flax_available() else () + is_encoder_decoder = True + test_pruning = False + test_head_masking = False + test_onnx = False + + def setUp(self): + self.model_tester = FlaxWhisperModelTester(self) + _, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + self.init_shape = (1,) + inputs_dict["input_features"].shape[1:] + + self.all_model_classes = ( + make_partial_class(model_class, input_shape=self.init_shape) for model_class in self.all_model_classes + ) + self.config_tester = ConfigTester(self, config_class=WhisperConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + # overwrite because of `input_features` + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.__call__) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["input_features", "decoder_input_ids"] + self.assertListEqual(arg_names[:2], expected_arg_names) + + # overwrite because of `input_features` + def test_jit_compilation(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config) + + 
@jax.jit + def model_jitted(input_features, decoder_input_ids, **kwargs): + return model(input_features=input_features, decoder_input_ids=decoder_input_ids, **kwargs) + + with self.subTest("JIT Enabled"): + jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple() + + with self.subTest("JIT Disabled"): + with jax.disable_jit(): + outputs = model_jitted(**prepared_inputs_dict).to_tuple() + + self.assertEqual(len(outputs), len(jitted_outputs)) + for jitted_output, output in zip(jitted_outputs, outputs): + self.assertEqual(jitted_output.shape, output.shape) + + # overwrite because of `input_features` + @is_pt_flax_cross_test + def test_save_load_bf16_to_base_pt(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape) + + for model_class in self.all_model_classes: + if model_class.__name__ == base_class.__name__: + continue + + model = model_class(config) + model.params = model.to_bf16(model.params) + base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix])) + + # convert Flax model to PyTorch model + pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning + pt_model = pt_model_class(config).eval() + pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + base_model = base_class.from_pretrained(tmpdirname, from_pt=True) + + base_params = flatten_dict(unfreeze(base_model.params)) + + for key in base_params_from_head.keys(): + max_diff = (base_params[key] - base_params_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + # overwrite because of `input_features` + @is_pt_flax_cross_test + def test_save_load_from_base_pt(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape) + + for model_class in self.all_model_classes: + if model_class.__name__ == base_class.__name__: + continue + + model = base_class(config) + base_params = flatten_dict(unfreeze(model.params)) + + # convert Flax model to PyTorch model + pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning + pt_model = pt_model_class(config).eval() + pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + # save pt model + pt_model.save_pretrained(tmpdirname) + head_model = model_class.from_pretrained(tmpdirname, from_pt=True) + + base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix])) + + for key in base_param_from_head.keys(): + max_diff = (base_params[key] - base_param_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + # overwrite because of `input_features` + @is_pt_flax_cross_test + def test_save_load_to_base_pt(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape) + + for model_class in self.all_model_classes: + if model_class.__name__ == base_class.__name__: + continue + + model = model_class(config) + base_params_from_head = 
flatten_dict(unfreeze(model.params[model.base_model_prefix])) + + # convert Flax model to PyTorch model + pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning + pt_model = pt_model_class(config).eval() + pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + base_model = base_class.from_pretrained(tmpdirname, from_pt=True) + + base_params = flatten_dict(unfreeze(base_model.params)) + + for key in base_params_from_head.keys(): + max_diff = (base_params[key] - base_params_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + # overwrite because of `input_features` + def test_save_load_from_base(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape) + + for model_class in self.all_model_classes: + if model_class.__name__ == base_class.__name__: + continue + + model = base_class(config) + base_params = flatten_dict(unfreeze(model.params)) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + head_model = model_class.from_pretrained(tmpdirname) + + base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix])) + + for key in base_param_from_head.keys(): + max_diff = (base_params[key] - base_param_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + # overwrite because of `input_features` + def test_save_load_to_base(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape) + + for model_class in self.all_model_classes: + if model_class.__name__ == base_class.__name__: + continue + + model = model_class(config) + base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix])) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + base_model = base_class.from_pretrained(tmpdirname) + + base_params = flatten_dict(unfreeze(base_model.params)) + + for key in base_params_from_head.keys(): + max_diff = (base_params[key] - base_params_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + +@slow +@require_flax +class FlaxWhisperModelIntegrationTest(unittest.TestCase): + @cached_property + def default_processor(self): + return WhisperProcessor.from_pretrained("openai/whisper-base") + + def _load_datasamples(self, num_samples): + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + # automatic decoding with librispeech + speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"] + + return [x["array"] for x in speech_samples] + + def test_tiny_logits_librispeech(self): + model = FlaxWhisperModel.from_pretrained("openai/whisper-tiny", from_pt=True) + input_speech = self._load_datasamples(1) + feature_extractor = WhisperFeatureExtractor() + input_features = feature_extractor(input_speech, return_tensors="np").input_features + + logits = model( + input_features, + decoder_input_ids=np.array([[50258, 50259, 50359]]), 
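+            # decoder prompt ids: 50258 = <|startoftranscript|>, 50259 = <|en|>, 50359 = <|transcribe|>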
+ output_hidden_states=False, + output_attentions=False, + return_dict=False, + ) + + # fmt: off + EXPECTED_LOGITS = np.array( + [ + 2.9892, -6.7607, 5.7348, 3.6096, 0.2152, -5.7321, 4.8855, -1.6407, + 0.2823, -1.5718, 10.4269, 3.4427, 0.0219, -8.0612, 3.4784, 8.4246, + 4.0575, -2.2864, 11.1084, 0.9963, 0.9884, -8.5154, -3.5469, -9.3713, + 0.9786, 3.5435, 7.4850, -5.2579, -1.4366, 10.4841 + ] + ) + # fmt: on + self.assertTrue(np.allclose(logits[0][0, 0, :30], EXPECTED_LOGITS, atol=1e-4)) + + def test_small_en_logits_librispeech(self): + model = FlaxWhisperModel.from_pretrained("openai/whisper-small.en", from_pt=True) + input_speech = self._load_datasamples(1) + feature_extractor = WhisperFeatureExtractor() + input_features = feature_extractor(input_speech, return_tensors="np").input_features + + logits = model( + input_features, + decoder_input_ids=np.array([model.config.decoder_start_token_id]), + output_hidden_states=False, + output_attentions=False, + return_dict=False, + ) + + logits = logits[0] @ model.params["model"]["decoder"]["embed_tokens"]["embedding"].T + + # fmt: off + EXPECTED_LOGITS = np.array( + [ + -3.6784, -7.7211, -9.5070, -11.9286, -7.6489, -9.7026, -5.6188, + -8.0104, -4.6238, -5.1833, -9.0485, -3.4079, -5.4874, -2.6935, + -6.3479, -7.3398, -6.9558, -7.6867, -7.4748, -8.3463, -9.9781, + -10.8389, -10.3105, -11.7201, -9.7261, -7.1590, -5.9272, -12.4509, + -11.1146, -8.1918 + ] + ) + # fmt: on + self.assertTrue(np.allclose(logits[0, 0, :30], EXPECTED_LOGITS, atol=1e-4)) + + def test_large_logits_librispeech(self): + model = FlaxWhisperModel.from_pretrained("openai/whisper-large", from_pt=True) + input_speech = self._load_datasamples(1) + processor = WhisperProcessor.from_pretrained("openai/whisper-large") + processed_inputs = processor( + audio=input_speech, text="This part of the speech", add_special_tokens=False, return_tensors="np" + ) + input_features = processed_inputs.input_features + decoder_input_ids = processed_inputs.labels + + logits = model( + input_features, + decoder_input_ids=decoder_input_ids, + output_hidden_states=False, + output_attentions=False, + return_dict=False, + ) + + logits = logits[0] @ model.params["model"]["decoder"]["embed_tokens"]["embedding"].T + + # fmt: off + EXPECTED_LOGITS = np.array( + [ + 2.1382, 0.9381, 4.4671, 3.5589, 2.4022, 3.8576, -0.6521, 2.5472, + 1.8301, 1.9957, 2.3432, 1.4678, 0.5459, 2.2597, 1.5179, 2.5357, + 1.1624, 0.6194, 1.0757, 1.8259, 2.4076, 1.6601, 2.3503, 1.3376, + 1.9891, 1.8635, 3.8931, 5.3699, 4.4772, 3.9184 + ] + ) + # fmt: on + self.assertTrue(np.allclose(logits[0, 0, :30], EXPECTED_LOGITS, atol=1e-4)) + + def test_tiny_en_generation(self): + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") + model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True) + model.config.decoder_start_token_id = 50257 + + input_speech = self._load_datasamples(1) + input_features = processor.feature_extractor( + raw_speech=input_speech, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="jax" + ).input_features + + generated_ids = model.generate(input_features, num_beams=5, max_length=20).sequences + transcript = processor.tokenizer.decode(generated_ids[0]) + + EXPECTED_TRANSCRIPT = ( + "<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. 
Quilter is the apostle of the middle" + " classes and we are glad to" + ) + self.assertEqual(transcript, EXPECTED_TRANSCRIPT) + + def test_tiny_generation(self): + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") + model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny", from_pt=True) + + input_speech = self._load_datasamples(1) + input_features = processor.feature_extractor( + raw_speech=input_speech, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="jax" + ).input_features + + generated_ids = model.generate(input_features, num_beams=5, max_length=20).sequences + transcript = processor.tokenizer.decode(generated_ids[0]) + + EXPECTED_TRANSCRIPT = ( + "<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle" + " classes and we are glad" + ) + self.assertEqual(transcript, EXPECTED_TRANSCRIPT) + + def test_large_generation(self): + processor = WhisperProcessor.from_pretrained("openai/whisper-large") + model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-large", from_pt=True) + + input_speech = self._load_datasamples(1) + input_features = processor.feature_extractor( + raw_speech=input_speech, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="jax" + ).input_features + + model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe") + + generated_ids = model.generate(input_features, num_beams=5, max_length=20).sequences + transcript = processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True) + + EXPECTED_TRANSCRIPT = " Mr. Quilter is the apostle of the middle classes and we are glad" + self.assertEqual(transcript, EXPECTED_TRANSCRIPT) + + def test_large_generation_multilingual(self): + processor = WhisperProcessor.from_pretrained("openai/whisper-large") + model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-large", from_pt=True) + + ds = load_dataset("common_voice", "ja", split="test", streaming=True) + ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000)) + input_speech = next(iter(ds))["audio"]["array"] + input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="np") + + model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe") + generated_ids = model.generate(input_features, do_sample=False, max_length=20).sequences + transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + + EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました" + self.assertEqual(transcript, EXPECTED_TRANSCRIPT) + + model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe") + generated_ids = model.generate( + input_features, + do_sample=False, + max_length=20, + ).sequences + transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + + EXPECTED_TRANSCRIPT = " Kimura-san called me." 
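+        # forcing language="en" with task="transcribe" makes the multilingual model emit an English rendering of the Japanese audio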
+ self.assertEqual(transcript, EXPECTED_TRANSCRIPT) + + model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate") + generated_ids = model.generate(input_features, do_sample=False, max_length=20).sequences + transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + + EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san" + self.assertEqual(transcript, EXPECTED_TRANSCRIPT) + + def test_large_batched_generation(self): + processor = WhisperProcessor.from_pretrained("openai/whisper-large") + model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-large", from_pt=True) + + input_speech = self._load_datasamples(4) + input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="np").input_features + generated_ids = model.generate(input_features, max_length=20).sequences + + # fmt: off + EXPECTED_LOGITS = np.array( + [ + [50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281], + [50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257], + [50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256], + [50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11] + ] + ) + # fmt: on + + self.assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS)) + + # fmt: off + EXPECTED_TRANSCRIPT = [ + " Mr. Quilter is the apostle of the middle classes and we are glad to", + " Nor is Mr. Quilter's manner less interesting than his matter.", + " He tells us that at this festive season of the year, with Christmas and roast beef", + " He has grave doubts whether Sir Frederick Layton's work is really Greek after all,", + ] + # fmt: on + + transcript = processor.batch_decode(generated_ids, skip_special_tokens=True) + self.assertListEqual(transcript, EXPECTED_TRANSCRIPT) + + def test_tiny_en_batched_generation(self): + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") + model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True) + + input_speech = self._load_datasamples(4) + input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="np").input_features + generated_ids = model.generate(input_features, max_length=20).sequences + + # fmt: off + EXPECTED_LOGITS = np.array( + [ + [50257, 50362, 1770, 13, 2264, 346, 353, 318, 262, 46329, 286, 262, 3504, 6097, 11, 290, 356, 389, 9675, 284], + [50257, 50362, 5414, 318, 1770, 13, 2264, 346, 353, 338, 5642, 1342, 3499, 621, 465, 2300, 13, 50256, 50256, 50256], + [50257, 50362, 679, 4952, 514, 326, 379, 428, 43856, 1622, 286, 262, 614, 11, 351, 6786, 290, 32595, 12023, 28236], + [50257, 50362, 679, 468, 12296, 17188, 1771, 7361, 26113, 18881, 1122, 338, 670, 318, 1107, 8312, 706, 477, 290, 460] + ] + + ) + # fmt: on + + self.assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS)) + + # fmt: off + EXPECTED_TRANSCRIPT = [ + " Mr. Quilter is the apostle of the middle classes, and we are glad to", + " Nor is Mr. 
Quilter's manner less interesting than his matter.", + " He tells us that at this festive season of the year, with Christmas and roast beef looming", + " He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can", + ] + # fmt: on + + transcript = processor.batch_decode(generated_ids, skip_special_tokens=True) + self.assertListEqual(transcript, EXPECTED_TRANSCRIPT) + + @slow + def test_tiny_timestamp_generation(self): + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") + model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny") + + input_speech = np.concatenate(self._load_datasamples(4)) + input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="jax").input_features + + generate_fn = jax.jit(functools.partial(model.generate, max_length=448, return_timestamps=True)) + + generated_ids = generate_fn(input_features) + + # fmt: off + EXPECTED_OUTPUT = np.array([50258, 50259, 50359, 50364, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50692, 50692, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50926, 50926, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 51208, 51208, 949, 505, 11, 14138, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51552, 51552, 634, 575, 12525, 22618, 1968, 6144, 35617, 7354, 1292, 6, 589, 307, 534, 10281, 934, 439, 11, 293, 51836, 51836, 50257]) + # fmt: on + + self.assertTrue(np.allclose(generated_ids, EXPECTED_OUTPUT)) + + EXPECTED_TRANSCRIPT = [ + { + "text": ( + " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is" + " Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season" + " of the year, with Christmas and roast beef looming before us, similarly drawn from eating and" + " its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins'" + " work is really Greek after all, and" + ), + "offsets": [ + { + "text": ( + " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." + ), + "timestamp": (0.0, 6.5600000000000005), + }, + { + "text": " Nor is Mr. Quilter's manner less interesting than his matter.", + "timestamp": (6.5600000000000005, 11.24), + }, + { + "text": ( + " He tells us that at this festive season of the year, with Christmas and roast beef" + " looming" + ), + "timestamp": (11.24, 16.88), + }, + { + "text": ( + " before us, similarly drawn from eating and its results occur most readily to the mind." 
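+                            # each offset pairs a text chunk with a (start, end) timestamp in seconds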
+ ), + "timestamp": (16.88, 23.76), + }, + { + "text": ( + " He has grave doubts whether Sir Frederick Latins' work is really Greek after all, and" + ), + "timestamp": (23.76, 29.44), + }, + ], + } + ] + + transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True) + self.assertEqual(transcript, EXPECTED_TRANSCRIPT) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 0944a3a64634dd..68f8a5317a0b58 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -22,9 +22,10 @@ import numpy as np +import transformers from transformers import WhisperConfig -from transformers.testing_utils import is_torch_available, require_torch, require_torchaudio, slow, torch_device -from transformers.utils import cached_property +from transformers.testing_utils import is_pt_flax_cross_test, require_torch, require_torchaudio, slow, torch_device +from transformers.utils import cached_property, is_flax_available, is_torch_available from transformers.utils.import_utils import is_datasets_available from ...generation.test_utils import GenerationTesterMixin @@ -48,6 +49,13 @@ ) from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder +if is_flax_available(): + import jax.numpy as jnp + from transformers.modeling_flax_pytorch_utils import ( + convert_pytorch_state_dict_to_flax, + load_flax_weights_in_pytorch_model, + ) + def prepare_whisper_inputs_dict( config, @@ -747,6 +755,159 @@ def _create_and_check_torchscript(self, config, inputs_dict): self.assertTrue(models_equal) + @is_pt_flax_cross_test + def test_equivalence_pt_to_flax(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + init_shape = (1,) + inputs_dict["input_features"].shape[1:] + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + fx_model_class_name = "Flax" + model_class.__name__ + + if not hasattr(transformers, fx_model_class_name): + # no flax model exists for this class + return + + # Output all for aggressive testing + config.output_hidden_states = True + config.output_attentions = self.has_attentions + + fx_model_class = getattr(transformers, fx_model_class_name) + + # load PyTorch class + pt_model = model_class(config).eval() + # Flax models don't use the `use_cache` option and cache is not returned as a default. + # So we disable `use_cache` here for PyTorch model. 
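+                # with the cache disabled, the PT model returns no `past_key_values`, so its output keys match the Flax outputs compared below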
+ pt_model.config.use_cache = False + + # load Flax class + fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32) + + # make sure only flax inputs are forward that actually exist in function args + fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys() + + # prepare inputs + pt_inputs = self._prepare_for_class(inputs_dict, model_class) + + # remove function args that don't exist in Flax + pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys} + + # send pytorch inputs to the correct device + pt_inputs = { + k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items() + } + + # convert inputs to Flax + fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} + + fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) + fx_model.params = fx_state + + # send pytorch model to the correct device + pt_model.to(torch_device) + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs) + fx_outputs = fx_model(**fx_inputs) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class) + + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True) + + fx_outputs_loaded = fx_model_loaded(**fx_inputs) + + fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class) + + @is_pt_flax_cross_test + def test_equivalence_flax_to_pt(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + init_shape = (1,) + inputs_dict["input_features"].shape[1:] + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + fx_model_class_name = "Flax" + model_class.__name__ + + if not hasattr(transformers, fx_model_class_name): + # no flax model exists for this class + return + + # Output all for aggressive testing + config.output_hidden_states = True + config.output_attentions = self.has_attentions + + fx_model_class = getattr(transformers, fx_model_class_name) + + # load PyTorch class + pt_model = model_class(config).eval() + # Flax models don't use the `use_cache` option and cache is not returned as a default. + # So we disable `use_cache` here for PyTorch model. 
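+                # as in `test_equivalence_pt_to_flax` above: keep the PT and Flax output key sets aligned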
+ pt_model.config.use_cache = False + + # load Flax class + fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32) + + # make sure only flax inputs are forward that actually exist in function args + fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys() + + # prepare inputs + pt_inputs = self._prepare_for_class(inputs_dict, model_class) + + # remove function args that don't exist in Flax + pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys} + + # send pytorch inputs to the correct device + pt_inputs = { + k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items() + } + + # convert inputs to Flax + fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} + + pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) + + # make sure weights are tied in PyTorch + pt_model.tie_weights() + + # send pytorch model to the correct device + pt_model.to(torch_device) + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs) + fx_outputs = fx_model(**fx_inputs) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class) + + with tempfile.TemporaryDirectory() as tmpdirname: + fx_model.save_pretrained(tmpdirname) + pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True) + + # send pytorch model to the correct device + pt_model_loaded.to(torch_device) + pt_model_loaded.eval() + + with torch.no_grad(): + pt_outputs_loaded = pt_model_loaded(**pt_inputs) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class) + @require_torch @require_torchaudio @@ -756,7 +917,6 @@ def default_processor(self): return WhisperProcessor.from_pretrained("openai/whisper-base") def _load_datasamples(self, num_samples): - ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") # automatic decoding with librispeech speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"] @@ -886,7 +1046,6 @@ def test_large_logits_librispeech(self): @slow def test_tiny_en_generation(self): - torch_device = "cpu" set_seed(0) processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") @@ -910,7 +1069,6 @@ def test_tiny_en_generation(self): @slow def test_tiny_generation(self): - torch_device = "cpu" set_seed(0) processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") diff --git a/utils/check_repo.py b/utils/check_repo.py index ad68a43c8a2e5d..6afa9f3f0c7739 100755 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -67,6 +67,8 @@ "DeformableDetrEncoder", # Building part of bigger (tested) model. "DeformableDetrDecoder", # Building part of bigger (tested) model. "OPTDecoder", # Building part of bigger (tested) model. + "FlaxWhisperDecoder", # Building part of bigger (tested) model. + "FlaxWhisperEncoder", # Building part of bigger (tested) model. "WhisperDecoder", # Building part of bigger (tested) model. "WhisperEncoder", # Building part of bigger (tested) model. "DecisionTransformerGPT2Model", # Building part of bigger (tested) model.