From 7d3b6ef3ac10feecb29a7a4a4f26325856f2d782 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Mon, 28 Nov 2022 14:13:22 -0800 Subject: [PATCH 001/111] add flax whisper implementation --- docs/source/en/index.mdx | 2 +- docs/source/en/model_doc/whisper.mdx | 11 + setup.py | 6 +- src/transformers/__init__.py | 12 + .../generation/flax_logits_process.py | 45 +- src/transformers/generation/flax_utils.py | 99 +- .../models/auto/modeling_flax_auto.py | 3 + src/transformers/models/whisper/__init__.py | 33 +- .../models/whisper/modeling_flax_whisper.py | 1411 +++++++++++++++++ src/transformers/utils/dummy_flax_objects.py | 28 +- .../whisper/test_modeling_flax_whisper.py | 289 ++++ utils/check_repo.py | 2 + 12 files changed, 1923 insertions(+), 18 deletions(-) create mode 100644 src/transformers/models/whisper/modeling_flax_whisper.py create mode 100644 tests/models/whisper/test_modeling_flax_whisper.py diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 4cba7ee9692811..e80c46f6643917 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -354,7 +354,7 @@ Flax), PyTorch, and/or TensorFlow. | Wav2Vec2 | ✅ | ❌ | ✅ | ✅ | ✅ | | Wav2Vec2-Conformer | ❌ | ❌ | ✅ | ❌ | ❌ | | WavLM | ❌ | ❌ | ✅ | ❌ | ❌ | -| Whisper | ✅ | ❌ | ✅ | ✅ | ❌ | +| Whisper | ✅ | ❌ | ✅ | ✅ | ✅ | | X-CLIP | ❌ | ❌ | ✅ | ❌ | ❌ | | XGLM | ✅ | ✅ | ✅ | ✅ | ✅ | | XLM | ✅ | ❌ | ✅ | ✅ | ❌ | diff --git a/docs/source/en/model_doc/whisper.mdx b/docs/source/en/model_doc/whisper.mdx index 4b7a6028618427..26d4ef8af91306 100644 --- a/docs/source/en/model_doc/whisper.mdx +++ b/docs/source/en/model_doc/whisper.mdx @@ -79,3 +79,14 @@ The original code can be found [here](https://github.com/openai/whisper). [[autodoc]] TFWhisperForConditionalGeneration - call + + +## FlaxWhisperModel + +[[autodoc]] FlaxWhisperModel + - __call__ + +## FlaxWhisperForConditionalGeneration + +[[autodoc]] FlaxWhisperForConditionalGeneration + - __call__ diff --git a/setup.py b/setup.py index ef533058b7bf10..2dc58263cf383f 100644 --- a/setup.py +++ b/setup.py @@ -159,7 +159,7 @@ "starlette", "tensorflow-cpu>=2.4,<2.11", "tensorflow>=2.4,<2.11", - "tensorflow-text", + #"tensorflow-text", "tf2onnx", "timeout-decorator", "timm", @@ -247,8 +247,8 @@ def run(self): extras["ja"] = deps_list("fugashi", "ipadic", "unidic_lite", "unidic", "sudachipy", "sudachidict_core", "pyknp") extras["sklearn"] = deps_list("scikit-learn") -extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx", "tensorflow-text") -extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "tf2onnx", "tensorflow-text") +extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx") +extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "tf2onnx") extras["torch"] = deps_list("torch") extras["accelerate"] = deps_list("accelerate") diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 714ce483ea71bb..f06f239e94f37e 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3223,6 +3223,13 @@ _import_structure["models.wav2vec2"].extend( ["FlaxWav2Vec2ForCTC", "FlaxWav2Vec2ForPreTraining", "FlaxWav2Vec2Model", "FlaxWav2Vec2PreTrainedModel"] ) + _import_structure["models.whisper"].extend( + [ + "FlaxWhisperForConditionalGeneration", + "FlaxWhisperModel", + "FlaxWhisperPreTrainedModel", + ] + ) _import_structure["models.xglm"].extend( [ "FlaxXGLMForCausalLM", @@ -5872,6 +5879,11 @@ FlaxWav2Vec2Model, FlaxWav2Vec2PreTrainedModel, ) + from .models.whisper import ( + 
FlaxWhisperForConditionalGeneration, + FlaxWhisperModel, + FlaxWhisperPreTrainedModel, + ) from .models.xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel from .models.xlm_roberta import ( FlaxXLMRobertaForMaskedLM, diff --git a/src/transformers/generation/flax_logits_process.py b/src/transformers/generation/flax_logits_process.py index 12fc9c39e5999f..641c52c21a1568 100644 --- a/src/transformers/generation/flax_logits_process.py +++ b/src/transformers/generation/flax_logits_process.py @@ -259,10 +259,53 @@ def __init__(self, min_length: int, eos_token_id: int): self.eos_token_id = eos_token_id def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: - # create boolean flag to decide if min length penalty should be applied apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1) scores = jnp.where(apply_penalty, scores.at[:, self.eos_token_id].set(-float("inf")), scores) return scores + + +class FlaxSuppressTokensAtBeginLogitsProcessor(FlaxLogitsProcessor): + r""" + [`FlaxSuppressTokensAtBeginLogitsProcessor`] supresses a list of tokens as soon as the `generate` function starts + generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` are not + sampled at the begining of the generation. + """ + + def __init__(self, begin_suppress_tokens, begin_index): + self.begin_suppress_tokens = list(begin_suppress_tokens) + self.begin_index = begin_index + + def __call__(self, input_ids, scores, cur_len: int): + if input_ids.shape[1] == self.begin_index: + scores = scores.at[:, self.begin_suppress_tokens].set(-float("inf")) + + return scores + + +class FlaxSuppressTokensLogitsProcessor(FlaxLogitsProcessor): + def __init__(self, suppress_tokens: list): + self.suppress_tokens = list(suppress_tokens) + + def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: + scores = scores.at[..., self.suppress_tokens].set(-float("inf")) + + return scores + + +class FlaxForceTokensLogitsProcessor(FlaxLogitsProcessor): + r"""This processor can be used to force a list of tokens. 
The processor will set their log probs to `inf` so that they + are sampled at their corresponding index.""" + + def __init__(self, force_token_map): + self.force_token_map = dict(force_token_map) + + def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int): + generation_idx = input_ids.shape[-1] + current_token = self.force_token_map.get(generation_idx, None) + if current_token is not None: + scores = scores.at[:, :].set(-float("inf")) + scores = scores.at[:, current_token].set(0) + return scores diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 5d936ce5b1dccd..33abc101351a54 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -18,7 +18,7 @@ import inspect import warnings from functools import partial -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import numpy as np @@ -36,8 +36,11 @@ from .flax_logits_process import ( FlaxForcedBOSTokenLogitsProcessor, FlaxForcedEOSTokenLogitsProcessor, + FlaxForceTokensLogitsProcessor, FlaxLogitsProcessorList, FlaxMinLengthLogitsProcessor, + FlaxSuppressTokensAtBeginLogitsProcessor, + FlaxSuppressTokensLogitsProcessor, FlaxTemperatureLogitsWarper, FlaxTopKLogitsWarper, FlaxTopPLogitsWarper, @@ -155,6 +158,35 @@ def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, params, mode model_kwargs["encoder_outputs"] = self.encode(input_ids, params=params, return_dict=True, **encoder_kwargs) return model_kwargs + def _prepare_decoder_input_ids_for_generation( + self, + batch_size: int, + decoder_start_token_id: int = None, + bos_token_id: int = None, + model_kwargs: Optional[Dict[str, jnp.ndarray]] = None, + ) -> jnp.ndarray: + if model_kwargs is not None and "decoder_input_ids" in model_kwargs: + # Only use this arg if not None, otherwise just remove from model_kwargs + decoder_input_ids = model_kwargs.pop("decoder_input_ids") + if decoder_input_ids is not None: + return decoder_input_ids + decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) + return jnp.array(decoder_start_token_id).reshape(1, -1).repeat(batch_size, axis=0) + + def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int: + decoder_start_token_id = ( + decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id + ) + bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id + + if decoder_start_token_id is not None: + return decoder_start_token_id + elif bos_token_id is not None: + return bos_token_id + raise ValueError( + "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." + ) + @staticmethod def _expand_to_num_beams(tensor, num_beams): return jnp.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:]) @@ -227,6 +259,9 @@ def generate( min_length: Optional[int] = None, forced_bos_token_id: Optional[int] = None, forced_eos_token_id: Optional[int] = None, + suppress_tokens: Optional[List[int]] = None, + begin_suppress_tokens: Optional[List[int]] = None, + forced_decoder_ids: Optional[List[int]] = None, length_penalty: Optional[float] = None, early_stopping: Optional[bool] = None, trace: bool = True, @@ -334,12 +369,19 @@ def generate( "generation results, please set `padding_side='left'` when initializing the tokenizer." 
) + batch_size = input_ids.shape[0] + if self.config.is_encoder_decoder: # add encoder_outputs to model_kwargs if model_kwargs.get("encoder_outputs") is None: model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs) # prepare decoder_input_ids for generation - input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + input_ids = self._prepare_decoder_input_ids_for_generation( + batch_size, + decoder_start_token_id=decoder_start_token_id, + bos_token_id=bos_token_id, + model_kwargs=model_kwargs, + ) # Prepare `max_length` depending on other stopping criteria. input_ids_seq_length = input_ids.shape[-1] @@ -382,7 +424,16 @@ def generate( if not do_sample and num_beams == 1: logits_processor = self._get_logits_processor( - no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id + no_repeat_ngram_size, + min_length, + max_length, + eos_token_id, + forced_bos_token_id, + forced_eos_token_id, + input_ids_seq_length, + suppress_tokens=suppress_tokens, + begin_suppress_tokens=begin_suppress_tokens, + forced_decoder_ids=forced_decoder_ids, ) return self._greedy_search( input_ids, @@ -397,7 +448,16 @@ def generate( elif do_sample and num_beams == 1: logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature) logits_processor = self._get_logits_processor( - no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id + no_repeat_ngram_size, + min_length, + max_length, + eos_token_id, + forced_bos_token_id, + forced_eos_token_id, + input_ids_seq_length, + suppress_tokens=suppress_tokens, + begin_suppress_tokens=begin_suppress_tokens, + forced_decoder_ids=forced_decoder_ids, ) return self._sample( input_ids, @@ -426,7 +486,16 @@ def generate( ) logits_processor = self._get_logits_processor( - no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id + no_repeat_ngram_size, + min_length, + max_length, + eos_token_id, + forced_bos_token_id, + forced_eos_token_id, + input_ids_seq_length, + suppress_tokens=suppress_tokens, + begin_suppress_tokens=begin_suppress_tokens, + forced_decoder_ids=forced_decoder_ids, ) return self._beam_search( @@ -478,6 +547,10 @@ def _get_logits_processor( eos_token_id: int, forced_bos_token_id: int, forced_eos_token_id: int, + input_ids_seq_length: int, + suppress_tokens: Optional[List[int]] = None, + begin_suppress_tokens: Optional[List[int]] = None, + forced_decoder_ids: Optional[List[int]] = None, ) -> FlaxLogitsProcessorList: """ This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsProcessor`] @@ -496,6 +569,12 @@ def _get_logits_processor( forced_eos_token_id = ( forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id ) + suppress_tokens = suppress_tokens if suppress_tokens is not None else self.config.suppress_tokens + begin_suppress_tokens = ( + begin_suppress_tokens if begin_suppress_tokens is not None else self.config.begin_suppress_tokens + ) + if forced_decoder_ids is None and hasattr(self.config, "forced_decoder_ids"): + forced_decoder_ids = self.config.forced_decoder_ids # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files # all samplers can be found in `generation_utils_samplers.py` @@ -505,6 +584,16 @@ def _get_logits_processor( 
processors.append(FlaxForcedBOSTokenLogitsProcessor(forced_bos_token_id)) if forced_eos_token_id is not None: processors.append(FlaxForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)) + if suppress_tokens is not None: + processors.append(FlaxSuppressTokensLogitsProcessor(suppress_tokens)) + if begin_suppress_tokens is not None: + begin_index = input_ids_seq_length + begin_index = begin_index if (input_ids_seq_length > 1 or forced_bos_token_id is None) else begin_index + 1 + if forced_decoder_ids is not None: + begin_index += forced_decoder_ids[-1][0] # generation starts after the last token that is forced + processors.append(FlaxSuppressTokensAtBeginLogitsProcessor(begin_suppress_tokens, begin_index)) + if forced_decoder_ids is not None: + processors.append(FlaxForceTokensLogitsProcessor(forced_decoder_ids)) return processors def _greedy_search( diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index 98c5d6fb5a1045..6b65f1de7396a0 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -53,6 +53,7 @@ ("vision-text-dual-encoder", "FlaxVisionTextDualEncoderModel"), ("vit", "FlaxViTModel"), ("wav2vec2", "FlaxWav2Vec2Model"), + ("whisper", "FlaxWhisperModel"), ("xglm", "FlaxXGLMModel"), ("xlm-roberta", "FlaxXLMRobertaModel"), ] @@ -73,6 +74,7 @@ ("roformer", "FlaxRoFormerForMaskedLM"), ("t5", "FlaxT5ForConditionalGeneration"), ("wav2vec2", "FlaxWav2Vec2ForPreTraining"), + ("whisper", "FlaxWhisperForConditionalGeneration"), ("xlm-roberta", "FlaxXLMRobertaForMaskedLM"), ] ) @@ -208,6 +210,7 @@ FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict( [ ("speech-encoder-decoder", "FlaxSpeechEncoderDecoderModel"), + ("whisper", "FlaxWhisperForConditionalGeneration"), ] ) diff --git a/src/transformers/models/whisper/__init__.py b/src/transformers/models/whisper/__init__.py index 2528e03a4d2c88..31e167c4ff6c0c 100644 --- a/src/transformers/models/whisper/__init__.py +++ b/src/transformers/models/whisper/__init__.py @@ -17,7 +17,13 @@ # limitations under the License. 
from typing import TYPE_CHECKING -from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_available, is_torch_available +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_flax_available, + is_tf_available, + is_torch_available, +) _import_structure = { @@ -54,6 +60,19 @@ "TFWhisperPreTrainedModel", ] +try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_flax_whisper"] = [ + "FlaxWhisperForConditionalGeneration", + "FlaxWhisperModel", + "FlaxWhisperPreTrainedModel", + ] + + if TYPE_CHECKING: from .configuration_whisper import WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP, WhisperConfig, WhisperOnnxConfig from .feature_extraction_whisper import WhisperFeatureExtractor @@ -86,6 +105,18 @@ TFWhisperPreTrainedModel, ) + try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_flax_whisper import ( + FlaxWhisperForConditionalGeneration, + FlaxWhisperModel, + FlaxWhisperPreTrainedModel, + ) + else: import sys diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py new file mode 100644 index 00000000000000..3029410d505a69 --- /dev/null +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -0,0 +1,1411 @@ +from typing import Union, Dict, Optional, Tuple +from functools import partial + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen.attention import dot_product_attention_weights +from flax.linen import partitioning as nn_partitioning +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax +from jax.random import PRNGKey + +from ...modeling_flax_outputs import ( + FlaxBaseModelOutput, + FlaxBaseModelOutputWithPastAndCrossAttentions, + FlaxCausalLMOutputWithCrossAttentions, + FlaxSeq2SeqLMOutput, + FlaxSeq2SeqModelOutput, +) + +from ...modeling_flax_utils import ( + ACT2FN, + FlaxPreTrainedModel, + append_call_sample_docstring, + append_replace_return_docstrings, + overwrite_call_docstring, +) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from .configuration_whisper import WhisperConfig + +scan_with_axes = nn_partitioning.scan_with_axes + + +logger = logging.get_logger(__name__) + + +_CHECKPOINT_FOR_DOC = "openai/whisper-tiny" +_CONFIG_FOR_DOC = "WhisperConfig" +_TOKENIZER_FOR_DOC = "WhisperTokenizer" + + +WHISPER_START_DOCSTRING = r""" + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + This model is also a Flax Linen + [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a + regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. 
+ Finally, this model supports inherent JAX features such as: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) + - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) + - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) + - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + Parameters: + config ([`WhisperConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): + The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and + `jax.numpy.bfloat16` (on TPUs). + This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If + specified all the computation will be performed with the given `dtype`. + **Note that this only specifies the dtype of the computation and does not influence the dtype of model + parameters.** + If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and + [`~FlaxPreTrainedModel.to_bf16`]. +""" + +WHISPER_INPUTS_DOCSTRING = r""" + Args: + input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`): + Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by + loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via + the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the + [`WhisperFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a + tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`] + attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but it is not used. By default the silence in the input log mel spectrogram are ignored. + decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are decoder input IDs?](../glossary#decoder-input-ids) + Whisper uses the `decoder_start_token_id` as the starting token for `decoder_input_ids` generation. + decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the + paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. + position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Whisper does not use position_ids in the encoder as input_features is always the same size and doesn't use masking, but this argument is preserved for compatibility. By default the silence in the input log mel spectrogram are ignored. 
+ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +WHISPER_ENCODE_INPUTS_DOCSTRING = r""" + Args: + input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`): + Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by + loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via + the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the + [`WhisperFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a + tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`] + attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but it is not used. By default the silence in the input log mel spectrogram are ignored. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +WHISPER_DECODE_INPUTS_DOCSTRING = r""" + Args: + decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`): + Indices of decoder input sequence tokens in the vocabulary. + Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are decoder input IDs?](../glossary#decoder-input-ids) + encoder_outputs (`tuple(tuple(numpy.ndarray)`): + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + encoder_attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but it is not used. By default the silence in the input log mel spectrogram are ignored. + decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + If you want to change padding behavior, you should modify to your needs. 
See diagram 1 in [the + paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. + decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the + range `[0, config.max_position_embeddings - 1]`. + past_key_values (`Dict[str, numpy.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast + auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +class FlaxWhisperAttention(nn.Module): + config: WhisperConfig + embed_dim: int + num_heads: int + dropout: float = 0.0 + causal: bool = False + bias: bool = True + dtype: jnp.dtype = jnp.float32 + + def setup(self) -> None: + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {self.num_heads})." + ) + + dense = partial( + nn.Dense, + self.embed_dim, + use_bias=self.bias, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + + self.q_proj = dense(use_bias=self.bias) + self.k_proj = dense(use_bias=False) + self.v_proj = dense(use_bias=self.bias) + self.out_proj = dense(use_bias=self.bias) + + if self.causal: + self.causal_mask = make_causal_mask( + jnp.ones((1, self.config.max_target_positions), dtype="bool"), dtype="bool" + ) + + def __call__( + self, + hidden_states: jnp.ndarray, + key_value_states: Optional[jnp.ndarray] = None, + attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + ) -> jnp.ndarray: + is_cross_attention = key_value_states is not None + batch_size = hidden_states.shape[0] + + q = self.q_proj(hidden_states) + + if is_cross_attention: + k = self.k_proj(key_value_states) + v = self.v_proj(key_value_states) + else: + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + q, k, v = jax.tree_util.tree_map(self._split_heads, (q, k, v)) + + if self.causal: + query_length, key_length = q.shape[1], k.shape[1] + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, + (0, 0, mask_shift, 0), + (1, 1, query_length, max_decoder_length), + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + # combine masks if needed + if attention_mask is not None and self.causal: + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + elif self.causal: + attention_mask = causal_mask + 
elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + k, v, attention_mask = self._concatenate_to_cache(k, v, q, attention_mask) + + # Convert the boolean attention mask to an attention bias. + if attention_mask is not None: + # attention mask in the form of attention bias + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0), + jnp.full(attention_mask.shape, float("-inf")), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.dropout > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + q, + k, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.dropout, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, v) + attn_output = self._merge_heads(attn_output) + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + def _split_heads(self, hidden_state) -> jnp.ndarray: + return hidden_state.reshape(hidden_state.shape[:2] + (self.num_heads, self.head_dim)) + + def _merge_heads(self, hidden_state) -> jnp.ndarray: + return hidden_state.reshape(hidden_state.shape[:2] + (self.embed_dim,)) + + @nn.compact + def _concatenate_to_cache(self, key, value, query, attention_mask) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only + # attend to those key positions that have already been generated and cached, not the + # remaining zero elements. 
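+            # `pad_mask` below has shape (*batch_dims, 1, num_updated_cache_vectors, max_length) and is True
+            # only for cache slots that have already been written (index < cur_index + num_updated_cache_vectors),
+            # so combining it with the causal/padding mask keeps queries from attending to still-empty cache slots.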
+ pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + + return key, value, attention_mask + + +class FlaxWhisperEncoderLayer(nn.Module): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 + use_scan: bool = False + + def setup(self) -> None: + self.embed_dim = self.config.d_model + self.self_attn = FlaxWhisperAttention( + config=self.config, + num_heads=self.config.encoder_attention_heads, + embed_dim=self.embed_dim, + causal=False, + dtype=self.dtype, + ) + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) + self.activation_fn = ACT2FN[self.config.activation_function] + self.self_attn_layer_norm = nn.LayerNorm(epsilon=1e-05) + self.fc1 = nn.Dense( + self.config.encoder_ffn_dim, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + self.fc2 = nn.Dense( + self.embed_dim, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + dtype=self.dtype, + ) + self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def __call__( + self, + hidden_states: jnp.ndarray, + output_attentions: bool = False, + deterministic: bool = True, + ) -> jnp.ndarray: + if self.use_scan: + hidden_states = hidden_states[0] + + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weights = self.self_attn(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + if self.use_scan: + outputs = (outputs, None) + + return outputs + + +class EncoderLayerCollection(nn.Module): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 + use_scan: bool = False + + @nn.compact + def __call__( + self, + hidden_states: jnp.ndarray, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ) -> jnp.ndarray: + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if self.use_scan: + assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`" + assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`" + hidden_states = (hidden_states,) + + hidden_states, _ = scan_with_axes( + FlaxWhisperEncoderLayer, + variable_axes={"params": 0, "cache": 0}, + split_rngs={"params": True, "dropout": True}, + in_axes=(nn.broadcast, nn.broadcast), + length=self.config.encoder_layers, + )(self.config, dtype=self.dtype, use_scan=self.use_scan, name="Layers")( + hidden_states, + output_attentions, + deterministic, + ) + hidden_states = hidden_states[0] + else: + for layer in range(self.config.encoder_layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + layer_outputs = FlaxWhisperEncoderLayer( + self.config, dtype=self.dtype, use_scan=self.use_scan, 
name=str(layer) + )( + hidden_states, + output_attentions=output_attentions, + deterministic=deterministic, + ) + hidden_states = layer_outputs[0] + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states, all_hidden_states, all_attentions) + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +class FlaxWhisperDecoderLayer(nn.Module): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 + use_scan: bool = False + + def setup(self) -> None: + self.embed_dim = self.config.d_model + self.self_attn = FlaxWhisperAttention( + config=self.config, + num_heads=self.config.decoder_attention_heads, + embed_dim=self.embed_dim, + causal=True, + ) + self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + self.encoder_attn = FlaxWhisperAttention( + config=self.config, + num_heads=self.config.decoder_attention_heads, + embed_dim=self.embed_dim, + causal=False, + dtype=self.dtype, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) + self.activation_fn = ACT2FN[self.config.activation_function] + self.fc1 = nn.Dense( + self.config.decoder_ffn_dim, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + dtype=self.dtype, + ) + self.fc2 = nn.Dense( + self.embed_dim, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + dtype=self.dtype, + ) + self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def __call__( + self, + hidden_states: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + init_cache: bool = False, + output_attentions: bool = True, + deterministic: bool = True, + ) -> Tuple[jnp.ndarray]: + if self.use_scan: + hidden_states = hidden_states[0] + residual = hidden_states + + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, self_attn_weights = self.self_attn( + hidden_states, + attention_mask=attention_mask, + init_cache=init_cache, + ) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states, + encoder_hidden_states, + ) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if self.use_scan: + outputs = (outputs, None) + + return outputs + + +class FlaxWhisperDecoderLayerCollection(nn.Module): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 + use_scan: bool = False + + @nn.compact + def __call__( + self, + 
hidden_states: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + init_cache: bool = False, + output_hidden_states: bool = False, + output_attentions: bool = False, + deterministic: bool = True, + return_dict: bool = True, + ) -> jnp.ndarray: + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + if self.use_scan: + hidden_states = (hidden_states,) + hidden_states, _ = scan_with_axes( + FlaxWhisperDecoderLayer, + variable_axes={"params": 0, "cache": 0}, + split_rngs={"params": True, "dropout": True}, + in_axes=(nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast), + length=self.config.decoder_layers, + )(self.config, dtype=self.dtype, use_scan=self.use_scan, name="Layers",)( + hidden_states, + attention_mask, + encoder_hidden_states, + init_cache, + output_attentions, + deterministic, + ) + hidden_states = hidden_states[0] + else: + for layer in range(self.config.decoder_layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + layer_outputs = FlaxWhisperDecoderLayer( + self.config, dtype=self.dtype, use_scan=self.use_scan, name=str(layer) + )( + hidden_states, + attention_mask, + encoder_hidden_states, + init_cache=init_cache, + output_attentions=output_attentions, + deterministic=deterministic, + ) + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions] + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + +class FlaxWhisperEncoder(nn.Module): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 + use_scan: bool = False + + def setup(self) -> None: + self.conv1 = nn.Conv( + self.config.d_model, + kernel_size=(3,), + padding=1, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + dtype=self.dtype, + ) + self.conv2 = nn.Conv( + self.config.d_model, + kernel_size=(3,), + strides=2, + padding=1, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + dtype=self.dtype, + ) + + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + + self.layers = EncoderLayerCollection( + self.config, + dtype=self.dtype, + use_scan=self.use_scan, + ) + self.embed_positions = nn.Embed(self.config.max_source_positions, self.config.d_model, dtype=self.dtype) + + self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + + def __call__( + self, + input_features: jnp.ndarray, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ) -> jnp.ndarray: + input_features = input_features.transpose(0, 2, 1) + hidden_states = jax.nn.gelu(self.conv1(input_features), approximate=False) + hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False) + + assert hidden_states.shape[1:] == ( + self.config.max_source_positions, + self.config.d_model, + ), "incorrect audio 
shape" + hidden_states = hidden_states + self.embed_positions(jnp.arange(self.config.max_source_positions)) + + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + + outputs = self.layers( + hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = self.layer_norm(outputs[0]) + + if not return_dict: + return (hidden_states,) + outputs[1:] + + return FlaxBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class FlaxWhisperDecoder(nn.Module): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 + use_scan: bool = False + + def setup(self) -> None: + self.embed_tokens = nn.Embed(self.config.vocab_size, self.config.d_model, dtype=self.dtype) + self.embed_positions = nn.Embed(self.config.max_target_positions, self.config.d_model, dtype=self.dtype) + + self.layers = FlaxWhisperDecoderLayerCollection(self.config, dtype=self.dtype, use_scan=self.use_scan) + + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + + self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-5) + + def __call__( + self, + input_ids: jnp.ndarray, + encoder_hidden_states: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + init_cache: bool = False, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ) -> jnp.ndarray: + if attention_mask is not None: + if position_ids is None: + position_ids = attention_mask.cumsum(-1) - 1 + if position_ids is None: + batch_size, sequence_length = input_ids.shape + position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + pos_emb = self.embed_positions(position_ids) + + hidden_states = self.embed_tokens(input_ids) + pos_emb + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) + + outputs = self.layers( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = self.layer_norm(outputs[0]) + + if not return_dict: + return (hidden_states,) + outputs[1:] + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +class FlaxWhisperModule(nn.Module): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 + use_scan: bool = False + + def setup(self) -> None: + self.encoder = FlaxWhisperEncoder(self.config, dtype=self.dtype, use_scan=self.use_scan) + self.decoder = FlaxWhisperDecoder(self.config, dtype=self.dtype, use_scan=self.use_scan) + + def __call__( + self, + input_features, + decoder_input_ids, + decoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + encoder_outputs = self.encoder( + input_features, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + decoder_outputs = 
self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return FlaxSeq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def _get_encoder_module(self): + return self.encoder + + def _get_decoder_module(self): + return self.decoder + + +class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel): + config_class = WhisperConfig + base_model_prefix: str = "model" + module_class: nn.Module = None + + def __init__( + self, + config: WhisperConfig, + input_shape: Tuple[int] = (1, 80, 3000), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + **kwargs + ): + module = self.module_class(config=config, dtype=dtype, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_features = jnp.zeros(input_shape) + input_features = input_features.at[(..., -1)].set(self.config.eos_token_id) + + decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + + batch_size, sequence_length = decoder_input_ids.shape + decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + random_params = self.module.init( + rngs, + input_features=input_features, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + )["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + def init_cache(self, batch_size, max_length, encoder_outputs): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): + `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: + `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) + is a sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross-attention of the decoder. 
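+
+        A minimal sketch of how the cache is typically initialized before auto-regressive decoding (the
+        feature shape is the Whisper default of 80 mel bins over 3000 frames; variable names are illustrative):
+
+        ```python
+        >>> # input_features: log-mel features from `WhisperFeatureExtractor`, shape (batch_size, 80, 3000)
+        >>> encoder_outputs = model.encode(input_features=input_features)
+        >>> past_key_values = model.init_cache(input_features.shape[0], model.config.max_target_positions, encoder_outputs)
+        ```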
+ """ + # init input variables to retrieve cache + decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + decoder_position_ids = jnp.broadcast_to( + jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape + ) + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + **kwargs, + ) + + init_variables = self.module.init( + jax.random.PRNGKey(0), + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + encoder_hidden_states=encoder_outputs[0], + init_cache=True, + method=_decoder_forward, # we only need to call the decoder to init the cache + ) + + return unfreeze(init_variables["cache"]) + + @add_start_docstrings(WHISPER_ENCODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=WhisperConfig) + def encode( + self, + input_features: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + **kwargs, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + def _encoder_forward(module, input_features, **kwargs): + encode_module = module._get_encoder_module() + return encode_module(input_features, **kwargs) + + return self.module.apply( + {"params": params or self.params}, + input_features=input_features, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + method=_encoder_forward, + ) + + @add_start_docstrings(WHISPER_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxBaseModelOutputWithPastAndCrossAttentions, config_class=WhisperConfig) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + if decoder_position_ids is None: + 
if past_key_values is not None: + raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") + + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxWhisperAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + return decoder_module( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + **kwargs, + ) + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past = outputs + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past = outputs + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING) + def __call__( + self, + input_features: jnp.ndarray, + decoder_input_ids: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # prepare decoder inputs + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones_like(decoder_input_ids) + if decoder_position_ids is None: + batch_size, sequence_length = decoder_input_ids.shape + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} + + return self.module.apply( + {"params": params or self.params}, + input_features=input_features, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, 
dtype="i4"), + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + ) + + def scan(self, params: Optional[FrozenDict] = None): + self._module = self.module_class(config=self.config, dtype=self.dtype, use_scan=True) + init_fn = partial(self.init_weights, input_shape=self.input_shape) + params_shape_tree = jax.eval_shape(init_fn, self.key) + + # get the shape of the parameters + self._params_shape_tree = params_shape_tree + + # save required_params as set + self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys()) + + # initialize the parameters + + if params is not None: + params = convert_unroll_to_scan(self, params) + return params + elif self._is_initialized: + self.params = convert_unroll_to_scan(self, self.params) + + def scan_disable(self, params: Optional[FrozenDict] = None): + self._module = self.module_class( + config=self.config, + dtype=self.dtype, + use_scan=False, + ) + init_fn = partial(self.init_weights, input_shape=self.input_shape) + params_shape_tree = jax.eval_shape(init_fn, self.key) + + # get the shape of the parameters + self._params_shape_tree = params_shape_tree + + # save required_params as set + self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys()) + + # initialize the parameters + if params is not None: + params = convert_scan_to_unroll(self, params) + return params + elif self._is_initialized: + self.params = convert_scan_to_unroll(self, self.params) + + +@add_start_docstrings( + "The bare Whisper Model transformer outputting raw hidden-states without any specific head on top.", + WHISPER_START_DOCSTRING, +) +class FlaxWhisperModel(FlaxWhisperPreTrainedModel): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + module_class = FlaxWhisperModule + + +append_call_sample_docstring( + FlaxWhisperModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC +) + + +class FlaxWhisperForConditionalGenerationModule(nn.Module): + config: WhisperConfig + dtype: jnp.dtype = jnp.float32 + use_scan: bool = False + + def setup(self): + self.model = FlaxWhisperModule(config=self.config, dtype=self.dtype, use_scan=self.use_scan) + self.lm_head = nn.Dense( + self.config.vocab_size, + use_bias=False, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), + ) + + def _get_encoder_module(self): + return self.model.encoder + + def _get_decoder_module(self): + return self.model.decoder + + def __call__( + self, + input_features, + decoder_input_ids, + decoder_attention_mask: jnp.ndarray = None, + decoder_position_ids: jnp.ndarray = None, + position_ids: jnp.ndarray = None, + attention_mask: jnp.ndarray = None, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + deterministic: bool = True, + ): + outputs = self.model( + input_features=input_features, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + decoder_position_ids=decoder_position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=deterministic, + ) + + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_embedding = self.model.decoder.embed_tokens.variables["params"]["embedding"] + lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + lm_logits = 
self.lm_head(hidden_states) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return output + + return FlaxSeq2SeqLMOutput( + logits=lm_logits, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +@add_start_docstrings("The Whisper Model with a language modeling head.", WHISPER_START_DOCSTRING) +class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel): + module_class = FlaxWhisperForConditionalGenerationModule + dtype: jnp.dtype = jnp.float32 + + @add_start_docstrings(WHISPER_DECODE_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=WhisperConfig) + def decode( + self, + decoder_input_ids, + encoder_outputs, + encoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_attention_mask: Optional[jnp.ndarray] = None, + decoder_position_ids: Optional[jnp.ndarray] = None, + past_key_values: dict = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + train: bool = False, + params: dict = None, + dropout_rng: PRNGKey = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + encoder_hidden_states = encoder_outputs[0] + + batch_size, sequence_length = decoder_input_ids.shape + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) + + if decoder_position_ids is None: + if past_key_values is not None: + raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") + + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be + # passed down to ensure cache is used. 
It has to be made sure that cache is marked as mutable so that + # it can be changed by FlaxWhisperAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): + decoder_module = module._get_decoder_module() + outputs = decoder_module( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + position_ids=decoder_position_ids, + **kwargs, + ) + hidden_states = outputs[0] + + if self.config.tie_word_embeddings: + shared_embedding = module.model.decoder.embed_tokens.variables["params"]["embedding"] + lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + lm_logits = module.lm_head(hidden_states) + + return lm_logits, outputs + + outputs = self.module.apply( + inputs, + decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), + decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), + decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + deterministic=not train, + rngs=rngs, + mutable=mutable, + method=_decoder_forward, + ) + + if past_key_values is None: + lm_logits, decoder_outputs = outputs + else: + (lm_logits, decoder_outputs), past = outputs + + if return_dict: + outputs = FlaxCausalLMOutputWithCrossAttentions( + logits=lm_logits, + hidden_states=decoder_outputs.hidden_states, + attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + ) + else: + outputs = (lm_logits,) + decoder_outputs[1:] + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs["past_key_values"] = unfreeze(past["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:] + + return outputs + + def prepare_inputs_for_generation( + self, + decoder_input_ids, + max_length, + attention_mask: Optional[jnp.DeviceArray] = None, + decoder_attention_mask: Optional[jnp.DeviceArray] = None, + encoder_outputs=None, + **kwargs + ): + # initializing the cache + batch_size, seq_length = decoder_input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since the decoder uses a causal mask, those positions are masked anyways. 
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if decoder_attention_mask is not None: + position_ids = decoder_attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "encoder_attention_mask": attention_mask, + "decoder_attention_mask": extended_attention_mask, + "decoder_position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1 + return model_kwargs + + +def convert_unroll_to_scan(model, params: Union[Dict, FrozenDict]): + r""" + Convert a `PyTree` of unrolled model parameters to a scanned block of model parameters. This method can be used + to explicitly convert the model parameters to scanned format. This returns a new `params` tree and does not + convert the `params` in place. + ```""" + if isinstance(params, FrozenDict): + params = unfreeze(params) + + params = flatten_dict(params, sep="/") + keys = list(params.keys()) + + for k in keys: + # Identify all "unrolled" layers formed in LayerCollections + # These params contain the identifier `layers` in their key + if "layers/0" in k: + # Squash the keys for the N unrolled layers into one single key: + # (layers/0, ..., layers/N) -> layers/Layers + scan_key = k.replace("0", "Layers", 1) + stacked_params = [] + + # Iterate over the unrolled layers (1,...,N) + for i in range(model.config.encoder_layers): + # Stack the params for the N layers into one super block + # and remove the unrolled layer params on the fly + # -> no memory overhead for conversion! + unrolled_layer = params.pop(k.replace("0", str(i), 1)) + stacked_params.append(unrolled_layer) + + params[scan_key] = jnp.stack(stacked_params) + + # Finally, unflatten the dict to restore the nested pytree structure + params = unflatten_dict(params, sep="/") + return params + + +def convert_scan_to_unroll(model, params: Union[Dict, FrozenDict]): + r""" + Convert a `PyTree` of scanned model parameters to an unrolled stack of model parameters. This method can be + used to explicitly convert the model parameters to unrolled format. This returns a new `params` tree and does + not convert the `params` in place. + ```""" + + if isinstance(params, FrozenDict): + params = unfreeze(params) + + params = flatten_dict(params, sep="/") + keys = list(params.keys()) + + for k in keys: + # Identify all "scan" layers in LayerCollections + # These params contain the identifier `EncoderLayers` or `DecoderLayers` in their key + if "Layers" in k: + # Remove the scan layer from the PyTree of params + scan_layer = params.pop(k) + + # Unroll the key for the stacked scan matrix into N separate keys, indexed by layer number + # layer/Layers -> (layer/0, ..., layer/N) + for i in range(model.dims.n_audio_layer): + # Unstack the params for the i-th scan layer to unrolled + # and remove corresponding scan params on the fly + # -> no memory overhead for conversion! 
+ unrolled_key = k.replace("Layers", str(i)) + params[unrolled_key], scan_layer = scan_layer[0], scan_layer[1:] + + params = unflatten_dict(params, sep="/") + return params diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index 20339c94b7cf02..26e52b4cec59f7 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -17,13 +17,6 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) -class FlaxGenerationMixin(metaclass=DummyObject): - _backends = ["flax"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["flax"]) - - class FlaxLogitsProcessor(metaclass=DummyObject): _backends = ["flax"] @@ -1075,6 +1068,27 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) +class FlaxWhisperForConditionalGeneration(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxWhisperModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + +class FlaxWhisperPreTrainedModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxXGLMForCausalLM(metaclass=DummyObject): _backends = ["flax"] diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py new file mode 100644 index 00000000000000..d0fde5a694738e --- /dev/null +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -0,0 +1,289 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import inspect +import unittest + +from datasets import load_dataset + +from transformers import WhisperConfig, is_flax_available +from transformers.testing_utils import require_flax, slow +from transformers.utils import cached_property + +from ...test_configuration_common import ConfigTester +from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor + + +if is_flax_available(): + import numpy as np + + import jax + from transformers import ( + FlaxWhisperForConditionalGeneration, + FlaxWhisperModel, + WhisperFeatureExtractor, + WhisperProcessor, + ) + + +@require_flax +class FlaxWhisperModelTester: + config_cls = WhisperConfig + config_updates = {} + hidden_act = "gelu" + + def __init__( + self, + parent, + batch_size=1, + seq_length=3000, + is_training=True, + use_labels=False, + vocab_size=99, + d_model=384, + decoder_attention_heads=6, + decoder_ffn_dim=1536, + decoder_layers=4, + encoder_attention_heads=6, + encoder_ffn_dim=1536, + encoder_layers=4, + input_channels=1, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=20, + max_source_positions=1500, + max_target_positions=448, + bos_token_id=98, + eos_token_id=98, + pad_token_id=0, + num_mel_bins=80, + decoder_start_token_id=85, + num_conv_layers=1, + suppress_tokens=None, + begin_suppress_tokens=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_labels = use_labels + self.vocab_size = vocab_size + self.d_model = d_model + self.hidden_size = d_model + self.num_hidden_layers = encoder_layers + self.num_attention_heads = encoder_attention_heads + self.decoder_attention_heads = decoder_attention_heads + self.decoder_ffn_dim = decoder_ffn_dim + self.decoder_layers = decoder_layers + self.encoder_attention_heads = encoder_attention_heads + self.encoder_ffn_dim = encoder_ffn_dim + self.encoder_layers = encoder_layers + self.encoder_seq_length = seq_length // 2 + self.decoder_seq_length = 1 + self.input_channels = input_channels + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.num_mel_bins = num_mel_bins + self.max_position_embeddings = max_position_embeddings + self.max_source_positions = max_source_positions + self.max_target_positions = max_target_positions + self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id + self.bos_token_id = bos_token_id + self.decoder_start_token_id = decoder_start_token_id + self.num_conv_layers = num_conv_layers + self.suppress_tokens = suppress_tokens + self.begin_suppress_tokens = begin_suppress_tokens + + def prepare_config_and_inputs_for_common(self): + input_features = floats_tensor([self.batch_size, self.num_mel_bins, self.seq_length], self.vocab_size) + + decoder_input_ids = np.array(self.batch_size * [[self.decoder_start_token_id]]) + + config = WhisperConfig( + vocab_size=self.vocab_size, + num_mel_bins=self.num_mel_bins, + decoder_start_token_id=self.decoder_start_token_id, + is_encoder_decoder=True, + activation_function=self.hidden_act, + dropout=self.hidden_dropout_prob, + attention_dropout=self.attention_probs_dropout_prob, + max_source_positions=self.max_source_positions, + max_target_positions=self.max_target_positions, + pad_token_id=self.pad_token_id, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + tie_word_embeddings=True, + d_model=self.d_model, + 
decoder_attention_heads=self.decoder_attention_heads, + decoder_ffn_dim=self.decoder_ffn_dim, + decoder_layers=self.decoder_layers, + encoder_attention_heads=self.encoder_attention_heads, + encoder_ffn_dim=self.encoder_ffn_dim, + encoder_layers=self.encoder_layers, + suppress_tokens=self.suppress_tokens, + begin_suppress_tokens=self.begin_suppress_tokens, + ) + inputs_dict = prepare_whisper_inputs_dict(config, input_features, decoder_input_ids) + return config, inputs_dict + + +def prepare_whisper_inputs_dict( + config, + input_ids, + decoder_input_ids, + attention_mask=None, + decoder_attention_mask=None, +): + if attention_mask is None: + attention_mask = np.not_equal(input_ids, config.pad_token_id).astype(np.int8) + if decoder_attention_mask is None: + decoder_attention_mask = np.concatenate( + [ + np.ones(decoder_input_ids[:, :1].shape, dtype=np.int8), + np.not_equal(decoder_input_ids[:, 1:], config.pad_token_id).astype(np.int8), + ], + axis=-1, + ) + return { + "input_features": input_ids, + "decoder_input_ids": decoder_input_ids, + } + + +@require_flax +class FlaxWhisperModelTest(FlaxModelTesterMixin, unittest.TestCase): + all_model_classes = ( + ( + FlaxWhisperForConditionalGeneration, + FlaxWhisperModel, + ) + if is_flax_available() + else () + ) + all_generative_model_classes = (FlaxWhisperForConditionalGeneration,) if is_flax_available() else () + is_encoder_decoder = True + test_pruning = False + test_head_masking = False + test_onnx = False + + def setUp(self): + self.model_tester = FlaxWhisperModelTester(self) + self.config_tester = ConfigTester(self, config_class=WhisperConfig) + + def test_config(self): + self.config_tester.run_common_tests() + + # overwrite because of `input_features` + def test_forward_signature(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + signature = inspect.signature(model.__call__) + # signature.parameters is an OrderedDict => so arg_names order is deterministic + arg_names = [*signature.parameters.keys()] + + expected_arg_names = ["input_features", "decoder_input_ids"] + self.assertListEqual(arg_names[:2], expected_arg_names) + + # overwrite because of `input_features` + def test_jit_compilation(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config) + + @jax.jit + def model_jitted(input_features, decoder_input_ids, **kwargs): + return model(input_features=input_features, decoder_input_ids=decoder_input_ids, **kwargs) + + with self.subTest("JIT Enabled"): + jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple() + + with self.subTest("JIT Disabled"): + with jax.disable_jit(): + outputs = model_jitted(**prepared_inputs_dict).to_tuple() + + self.assertEqual(len(outputs), len(jitted_outputs)) + for jitted_output, output in zip(jitted_outputs, outputs): + self.assertEqual(jitted_output.shape, output.shape) + + +def _assert_tensors_equal(a, b, atol=1e-12, prefix=""): + """If tensors not close, or a and b arent both tensors, raise a nice Assertion error.""" + if a is None and b is None: + return True + try: + if _assert_tensors_equal(a, b, atol=atol): + return True + raise + except Exception: + if len(prefix) > 0: + prefix = f"{prefix}: " + raise AssertionError(f"{prefix}{a} != {b}") + + +def _long_tensor(tok_lst): + 
return np.array(tok_lst, dtype=np.int32) + + +TOLERANCE = 1e-4 + + +@slow +@require_flax +class FlaxWhisperModelIntegrationTest(unittest.TestCase): + @cached_property + def default_processor(self): + return WhisperProcessor.from_pretrained("openai/whisper-base") + + def _load_datasamples(self, num_samples): + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + # automatic decoding with librispeech + speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"] + + return [x["array"] for x in speech_samples] + + def test_tiny_logits_librispeech(self): + model = FlaxWhisperModel.from_pretrained("openai/whisper-tiny", from_pt=True) + input_speech = self._load_datasamples(1) + feature_extractor = WhisperFeatureExtractor() + input_features = feature_extractor(input_speech, return_tensors="np").input_features + + logits = model( + input_features, + decoder_input_ids=np.array([[50258, 50259, 50359]]), + output_hidden_states=False, + output_attentions=False, + return_dict=False, + ) + + # fmt: off + EXPECTED_LOGITS = np.array( + [ + 2.9892, -6.7607, 5.7348, 3.6096, 0.2152, -5.7321, 4.8855, -1.6407, + 0.2823, -1.5718, 10.4269, 3.4427, 0.0219, -8.0612, 3.4784, 8.4246, + 4.0575, -2.2864, 11.1084, 0.9963, 0.9884, -8.5154, -3.5469, -9.3713, + 0.9786, 3.5435, 7.4850, -5.2579, -1.4366, 10.4841 + ] + ) + # fmt: on + self.assertTrue(np.allclose(logits[0][0, 0, :30], EXPECTED_LOGITS, atol=1e-4)) \ No newline at end of file diff --git a/utils/check_repo.py b/utils/check_repo.py index 494a703f7e26bb..e42d54a2ab9abf 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -59,6 +59,8 @@ "DeformableDetrEncoder", # Building part of bigger (tested) model. "DeformableDetrDecoder", # Building part of bigger (tested) model. "OPTDecoder", # Building part of bigger (tested) model. + "FlaxWhisperDecoder", # Building part of bigger (tested) model. + "FlaxWhisperEncoder", # Building part of bigger (tested) model. "WhisperDecoder", # Building part of bigger (tested) model. "WhisperEncoder", # Building part of bigger (tested) model. "DecisionTransformerGPT2Model", # Building part of bigger (tested) model. 
From a9bed4c7e0fe356289715a97f83268e131c1bef3 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Mon, 28 Nov 2022 14:17:28 -0800 Subject: [PATCH 002/111] rever change to setup --- setup.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 2dc58263cf383f..ef533058b7bf10 100644 --- a/setup.py +++ b/setup.py @@ -159,7 +159,7 @@ "starlette", "tensorflow-cpu>=2.4,<2.11", "tensorflow>=2.4,<2.11", - #"tensorflow-text", + "tensorflow-text", "tf2onnx", "timeout-decorator", "timm", @@ -247,8 +247,8 @@ def run(self): extras["ja"] = deps_list("fugashi", "ipadic", "unidic_lite", "unidic", "sudachipy", "sudachidict_core", "pyknp") extras["sklearn"] = deps_list("scikit-learn") -extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx") -extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "tf2onnx") +extras["tf"] = deps_list("tensorflow", "onnxconverter-common", "tf2onnx", "tensorflow-text") +extras["tf-cpu"] = deps_list("tensorflow-cpu", "onnxconverter-common", "tf2onnx", "tensorflow-text") extras["torch"] = deps_list("torch") extras["accelerate"] = deps_list("accelerate") From 03129935b7b73da97cd64f9bed8ff327352c4136 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Mon, 28 Nov 2022 14:25:21 -0800 Subject: [PATCH 003/111] remove unused imports --- src/transformers/models/whisper/modeling_flax_whisper.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 3029410d505a69..f82ed70c9ec588 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -24,8 +24,6 @@ ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, - append_replace_return_docstrings, - overwrite_call_docstring, ) from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_whisper import WhisperConfig From c71fe4f2fc492a4c2b2d3915ecf17cee27790bd9 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Tue, 29 Nov 2022 08:34:36 -0800 Subject: [PATCH 004/111] revert generation changes --- .../generation/flax_logits_process.py | 44 --------- src/transformers/generation/flax_utils.py | 99 +------------------ .../whisper/test_modeling_flax_whisper.py | 2 +- 3 files changed, 6 insertions(+), 139 deletions(-) diff --git a/src/transformers/generation/flax_logits_process.py b/src/transformers/generation/flax_logits_process.py index 641c52c21a1568..c8abeea0982374 100644 --- a/src/transformers/generation/flax_logits_process.py +++ b/src/transformers/generation/flax_logits_process.py @@ -265,47 +265,3 @@ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> scores = jnp.where(apply_penalty, scores.at[:, self.eos_token_id].set(-float("inf")), scores) return scores - - -class FlaxSuppressTokensAtBeginLogitsProcessor(FlaxLogitsProcessor): - r""" - [`FlaxSuppressTokensAtBeginLogitsProcessor`] supresses a list of tokens as soon as the `generate` function starts - generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` are not - sampled at the begining of the generation. 
- """ - - def __init__(self, begin_suppress_tokens, begin_index): - self.begin_suppress_tokens = list(begin_suppress_tokens) - self.begin_index = begin_index - - def __call__(self, input_ids, scores, cur_len: int): - if input_ids.shape[1] == self.begin_index: - scores = scores.at[:, self.begin_suppress_tokens].set(-float("inf")) - - return scores - - -class FlaxSuppressTokensLogitsProcessor(FlaxLogitsProcessor): - def __init__(self, suppress_tokens: list): - self.suppress_tokens = list(suppress_tokens) - - def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: - scores = scores.at[..., self.suppress_tokens].set(-float("inf")) - - return scores - - -class FlaxForceTokensLogitsProcessor(FlaxLogitsProcessor): - r"""This processor can be used to force a list of tokens. The processor will set their log probs to `inf` so that they - are sampled at their corresponding index.""" - - def __init__(self, force_token_map): - self.force_token_map = dict(force_token_map) - - def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int): - generation_idx = input_ids.shape[-1] - current_token = self.force_token_map.get(generation_idx, None) - if current_token is not None: - scores = scores.at[:, :].set(-float("inf")) - scores = scores.at[:, current_token].set(0) - return scores diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 33abc101351a54..5d936ce5b1dccd 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -18,7 +18,7 @@ import inspect import warnings from functools import partial -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional import numpy as np @@ -36,11 +36,8 @@ from .flax_logits_process import ( FlaxForcedBOSTokenLogitsProcessor, FlaxForcedEOSTokenLogitsProcessor, - FlaxForceTokensLogitsProcessor, FlaxLogitsProcessorList, FlaxMinLengthLogitsProcessor, - FlaxSuppressTokensAtBeginLogitsProcessor, - FlaxSuppressTokensLogitsProcessor, FlaxTemperatureLogitsWarper, FlaxTopKLogitsWarper, FlaxTopPLogitsWarper, @@ -158,35 +155,6 @@ def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, params, mode model_kwargs["encoder_outputs"] = self.encode(input_ids, params=params, return_dict=True, **encoder_kwargs) return model_kwargs - def _prepare_decoder_input_ids_for_generation( - self, - batch_size: int, - decoder_start_token_id: int = None, - bos_token_id: int = None, - model_kwargs: Optional[Dict[str, jnp.ndarray]] = None, - ) -> jnp.ndarray: - if model_kwargs is not None and "decoder_input_ids" in model_kwargs: - # Only use this arg if not None, otherwise just remove from model_kwargs - decoder_input_ids = model_kwargs.pop("decoder_input_ids") - if decoder_input_ids is not None: - return decoder_input_ids - decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) - return jnp.array(decoder_start_token_id).reshape(1, -1).repeat(batch_size, axis=0) - - def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int: - decoder_start_token_id = ( - decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id - ) - bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id - - if decoder_start_token_id is not None: - return decoder_start_token_id - elif bos_token_id is not None: - return bos_token_id - raise ValueError( - "`decoder_start_token_id` or 
`bos_token_id` has to be defined for encoder-decoder generation." - ) - @staticmethod def _expand_to_num_beams(tensor, num_beams): return jnp.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:]) @@ -259,9 +227,6 @@ def generate( min_length: Optional[int] = None, forced_bos_token_id: Optional[int] = None, forced_eos_token_id: Optional[int] = None, - suppress_tokens: Optional[List[int]] = None, - begin_suppress_tokens: Optional[List[int]] = None, - forced_decoder_ids: Optional[List[int]] = None, length_penalty: Optional[float] = None, early_stopping: Optional[bool] = None, trace: bool = True, @@ -369,19 +334,12 @@ def generate( "generation results, please set `padding_side='left'` when initializing the tokenizer." ) - batch_size = input_ids.shape[0] - if self.config.is_encoder_decoder: # add encoder_outputs to model_kwargs if model_kwargs.get("encoder_outputs") is None: model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs) # prepare decoder_input_ids for generation - input_ids = self._prepare_decoder_input_ids_for_generation( - batch_size, - decoder_start_token_id=decoder_start_token_id, - bos_token_id=bos_token_id, - model_kwargs=model_kwargs, - ) + input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id # Prepare `max_length` depending on other stopping criteria. input_ids_seq_length = input_ids.shape[-1] @@ -424,16 +382,7 @@ def generate( if not do_sample and num_beams == 1: logits_processor = self._get_logits_processor( - no_repeat_ngram_size, - min_length, - max_length, - eos_token_id, - forced_bos_token_id, - forced_eos_token_id, - input_ids_seq_length, - suppress_tokens=suppress_tokens, - begin_suppress_tokens=begin_suppress_tokens, - forced_decoder_ids=forced_decoder_ids, + no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id ) return self._greedy_search( input_ids, @@ -448,16 +397,7 @@ def generate( elif do_sample and num_beams == 1: logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature) logits_processor = self._get_logits_processor( - no_repeat_ngram_size, - min_length, - max_length, - eos_token_id, - forced_bos_token_id, - forced_eos_token_id, - input_ids_seq_length, - suppress_tokens=suppress_tokens, - begin_suppress_tokens=begin_suppress_tokens, - forced_decoder_ids=forced_decoder_ids, + no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id ) return self._sample( input_ids, @@ -486,16 +426,7 @@ def generate( ) logits_processor = self._get_logits_processor( - no_repeat_ngram_size, - min_length, - max_length, - eos_token_id, - forced_bos_token_id, - forced_eos_token_id, - input_ids_seq_length, - suppress_tokens=suppress_tokens, - begin_suppress_tokens=begin_suppress_tokens, - forced_decoder_ids=forced_decoder_ids, + no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id ) return self._beam_search( @@ -547,10 +478,6 @@ def _get_logits_processor( eos_token_id: int, forced_bos_token_id: int, forced_eos_token_id: int, - input_ids_seq_length: int, - suppress_tokens: Optional[List[int]] = None, - begin_suppress_tokens: Optional[List[int]] = None, - forced_decoder_ids: Optional[List[int]] = None, ) -> FlaxLogitsProcessorList: """ This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsProcessor`] @@ -569,12 +496,6 @@ def _get_logits_processor( forced_eos_token_id = ( 
forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id ) - suppress_tokens = suppress_tokens if suppress_tokens is not None else self.config.suppress_tokens - begin_suppress_tokens = ( - begin_suppress_tokens if begin_suppress_tokens is not None else self.config.begin_suppress_tokens - ) - if forced_decoder_ids is None and hasattr(self.config, "forced_decoder_ids"): - forced_decoder_ids = self.config.forced_decoder_ids # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files # all samplers can be found in `generation_utils_samplers.py` @@ -584,16 +505,6 @@ def _get_logits_processor( processors.append(FlaxForcedBOSTokenLogitsProcessor(forced_bos_token_id)) if forced_eos_token_id is not None: processors.append(FlaxForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)) - if suppress_tokens is not None: - processors.append(FlaxSuppressTokensLogitsProcessor(suppress_tokens)) - if begin_suppress_tokens is not None: - begin_index = input_ids_seq_length - begin_index = begin_index if (input_ids_seq_length > 1 or forced_bos_token_id is None) else begin_index + 1 - if forced_decoder_ids is not None: - begin_index += forced_decoder_ids[-1][0] # generation starts after the last token that is forced - processors.append(FlaxSuppressTokensAtBeginLogitsProcessor(begin_suppress_tokens, begin_index)) - if forced_decoder_ids is not None: - processors.append(FlaxForceTokensLogitsProcessor(forced_decoder_ids)) return processors def _greedy_search( diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index d0fde5a694738e..220826cb84d809 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -286,4 +286,4 @@ def test_tiny_logits_librispeech(self): ] ) # fmt: on - self.assertTrue(np.allclose(logits[0][0, 0, :30], EXPECTED_LOGITS, atol=1e-4)) \ No newline at end of file + self.assertTrue(np.allclose(logits[0][0, 0, :30], EXPECTED_LOGITS, atol=1e-4)) From 828d8001e118c785ccd2d9314bae662fd040e1cd Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Tue, 29 Nov 2022 09:37:00 -0800 Subject: [PATCH 005/111] flax whisper docs --- .../models/whisper/modeling_flax_whisper.py | 92 +++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index f82ed70c9ec588..5aa2cbdf8d85e9 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -24,6 +24,8 @@ ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, + append_replace_return_docstrings, + overwrite_call_docstring, ) from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_whisper import WhisperConfig @@ -903,6 +905,22 @@ def encode( dropout_rng: PRNGKey = None, **kwargs, ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration + >>> from datasets import load_dataset + + >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") + >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True) + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> inputs = 
processor(ds[0]["audio"]["array"], return_tensors="np") + >>> input_features = inputs.input_features + >>> encoder_outputs = model.encode(input_features=input_features) + ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -946,6 +964,29 @@ def decode( params: dict = None, dropout_rng: PRNGKey = None, ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration + >>> from datasets import load_dataset + + >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") + >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True) + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np") + >>> input_features = inputs.input_features + >>> encoder_outputs = model.encode(input_features=input_features) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> last_decoder_hidden_states = outputs.last_hidden_state + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1210,6 +1251,28 @@ def decode( params: dict = None, dropout_rng: PRNGKey = None, ): + r""" + Returns: + + Example: + + ```python + >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration + >>> from datasets import load_dataset + + >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") + >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True) + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np") + >>> input_features = inputs.input_features + >>> encoder_outputs = model.encode(input_features=input_features) + + >>> decoder_start_token_id = model.config.decoder_start_token_id + >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id + + >>> outputs = model.decode(decoder_input_ids, encoder_outputs) + >>> last_decoder_hidden_states = outputs.last_hidden_state + ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1340,6 +1403,35 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs): return model_kwargs +FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING = r""" + Returns: + + Transcription example: + + ```python + >>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration + >>> from datasets import load_dataset + + >>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") + >>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True) + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + 
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np") + >>> input_features = inputs.input_features + >>> generated_ids = model.generate(input_ids=input_features) + >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + >>> transcription + ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' + ``` +""" + +overwrite_call_docstring( + FlaxWhisperForConditionalGeneration, WHISPER_INPUTS_DOCSTRING + FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING +) +append_replace_return_docstrings( + FlaxWhisperForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC +) + + def convert_unroll_to_scan(model, params: Union[Dict, FrozenDict]): r""" Convert a `PyTree` of unrolled model parameters to a scanned block of model parameters. This method can be used From baafb1c3c141441143f6a92496c88c96f2bbac59 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Thu, 1 Dec 2022 08:05:41 -0800 Subject: [PATCH 006/111] docs --- src/transformers/models/whisper/modeling_flax_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 5aa2cbdf8d85e9..4b60e200bff932 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -43,7 +43,7 @@ WHISPER_START_DOCSTRING = r""" This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a Flax Linen [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. 
Use it as a From 2da5a58a44d896e88eb4d6b9ce773d3b9cbd134b Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Thu, 1 Dec 2022 08:29:14 -0800 Subject: [PATCH 007/111] import order --- src/transformers/models/whisper/modeling_flax_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 4b60e200bff932..fc90bc7176b084 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -6,8 +6,8 @@ import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, make_causal_mask -from flax.linen.attention import dot_product_attention_weights from flax.linen import partitioning as nn_partitioning +from flax.linen.attention import dot_product_attention_weights from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax from jax.random import PRNGKey From 00f695fec1dbab56b572f4257aa6b3c9570e386f Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Thu, 1 Dec 2022 08:49:17 -0800 Subject: [PATCH 008/111] import sorting --- src/transformers/models/whisper/modeling_flax_whisper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index fc90bc7176b084..63311ad5424b03 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -1,5 +1,5 @@ -from typing import Union, Dict, Optional, Tuple from functools import partial +from typing import Dict, Optional, Tuple, Union import flax.linen as nn import jax @@ -19,7 +19,6 @@ FlaxSeq2SeqLMOutput, FlaxSeq2SeqModelOutput, ) - from ...modeling_flax_utils import ( ACT2FN, FlaxPreTrainedModel, @@ -30,6 +29,7 @@ from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_whisper import WhisperConfig + scan_with_axes = nn_partitioning.scan_with_axes From 0ecc03b3513ec6c56c323d286d16d6f195f50199 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Thu, 1 Dec 2022 09:35:32 -0800 Subject: [PATCH 009/111] isort --- src/transformers/__init__.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 45af853af7729c..797788fc81b8c3 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -5938,11 +5938,7 @@ FlaxWav2Vec2Model, FlaxWav2Vec2PreTrainedModel, ) - from .models.whisper import ( - FlaxWhisperForConditionalGeneration, - FlaxWhisperModel, - FlaxWhisperPreTrainedModel, - ) + from .models.whisper import FlaxWhisperForConditionalGeneration, FlaxWhisperModel, FlaxWhisperPreTrainedModel from .models.xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel from .models.xlm_roberta import ( FlaxXLMRobertaForMaskedLM, From f66a005df2d17df7c84c4ed4a09424e617fbeadb Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Thu, 1 Dec 2022 09:48:52 -0800 Subject: [PATCH 010/111] add dummy objects --- src/transformers/utils/dummy_flax_objects.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index 26e52b4cec59f7..36f1932ba6b75e 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ 
b/src/transformers/utils/dummy_flax_objects.py @@ -17,6 +17,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) +class FlaxGenerationMixin(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxLogitsProcessor(metaclass=DummyObject): _backends = ["flax"] From 175f3442940893828455c93499a8a6f934c3d21e Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Thu, 1 Dec 2022 10:03:27 -0800 Subject: [PATCH 011/111] doc formatting --- .../models/whisper/modeling_flax_whisper.py | 113 ++++++++++-------- 1 file changed, 62 insertions(+), 51 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 63311ad5424b03..baa5d3c1b63365 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -42,9 +42,9 @@ WHISPER_START_DOCSTRING = r""" - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods + the library implements for all its models (such as downloading or saving, resizing the input embeddings, + pruning heads etc.) This model is also a Flax Linen [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. @@ -56,7 +56,8 @@ Parameters: config ([`WhisperConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model + weights. dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and `jax.numpy.bfloat16` (on TPUs). @@ -71,13 +72,14 @@ WHISPER_INPUTS_DOCSTRING = r""" Args: input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`): - Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by - loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via - the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the - [`WhisperFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a - tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`] + Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained + by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, + *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into + `input_features`, the [`WhisperFeatureExtractor`] should be used for extracting the features, padding + and conversion into a tensor of type `numpy.ndarray`. 
See [`~WhisperFeatureExtractor.__call__`] attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but it is not used. By default the silence in the input log mel spectrogram are ignored. + Whisper does not support masking of the `input_features`, this argument is preserved for + compatibility, but is not used. By default the silence in the input log mel spectrogram are ignored. decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and @@ -85,21 +87,23 @@ [What are decoder input IDs?](../glossary#decoder-input-ids) Whisper uses the `decoder_start_token_id` as the starting token for `decoder_input_ids` generation. decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will + also be used by default. If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Whisper does not use position_ids in the encoder as input_features is always the same size and doesn't use masking, but this argument is preserved for compatibility. By default the silence in the input log mel spectrogram are ignored. + Whisper does not use position_ids in the encoder as input_features is always the same size and doesn't + use masking, but this argument is preserved for compatibility. By default the silence in the input log + mel spectrogram are ignored. decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, config.max_position_embeddings - 1]`. output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -107,19 +111,21 @@ WHISPER_ENCODE_INPUTS_DOCSTRING = r""" Args: input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`): - Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by - loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via - the soundfile library (`pip install soundfile`). 
To prepare the array into `input_features`, the - [`WhisperFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a - tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`] + Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained + by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, + *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into + `input_features`, the [`WhisperFeatureExtractor`] should be used for extracting the mel features, + padding and conversion into a tensor of type `numpy.ndarray`. See + [`~WhisperFeatureExtractor.__call__`]. attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but it is not used. By default the silence in the input log mel spectrogram are ignored. + Whisper does not support masking of the `input_features`, this argument is preserved for + compatibility, but is not used. By default the silence in the input log mel spectrogram are ignored. output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -132,28 +138,32 @@ [`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids) encoder_outputs (`tuple(tuple(numpy.ndarray)`): - Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) - `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of - hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence + of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the + decoder. encoder_attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but it is not used. By default the silence in the input log mel spectrogram are ignored. + Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, + but it is not used. By default the silence in the input log mel spectrogram are ignored. decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also - be used by default. - If you want to change padding behavior, you should modify to your needs. 
See diagram 1 in [the - paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will + also be used by default. If you want to change padding behavior, you should modify to your needs. See + diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default + strategy. decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, config.max_position_embeddings - 1]`. - past_key_values (`Dict[str, numpy.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): - Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast - auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. + past_key_values (`Dict[str, numpy.ndarray]`, *optional*, returned by `init_cache` or when passing previous + `past_key_values`): + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for + fast auto-regressive decoding. Pre-computed key and value hidden-states are of shape + *[batch_size, max_length]*. output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -853,15 +863,16 @@ def init_cache(self, batch_size, max_length, encoder_outputs): r""" Args: batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. - max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the + initialized cache. encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: - `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) - is a sequence of hidden-states at the output of the last layer of the encoder. Used in the - cross-attention of the decoder. + `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, + *optional*) is a sequence of hidden-states at the output of the last layer of the encoder. + Used in the cross-attention of the decoder. 
""" # init input variables to retrieve cache decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") @@ -1434,9 +1445,9 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs): def convert_unroll_to_scan(model, params: Union[Dict, FrozenDict]): r""" - Convert a `PyTree` of unrolled model parameters to a scanned block of model parameters. This method can be used - to explicitly convert the model parameters to scanned format. This returns a new `params` tree and does not - convert the `params` in place. + Convert a `PyTree` of unrolled model parameters to a scanned block of model parameters. This method can be + used to explicitly convert the model parameters to scanned format. This returns a new `params` tree and does + not convert the `params` in place. ```""" if isinstance(params, FrozenDict): params = unfreeze(params) From 3329e6ceffdbe361d778b6c534931d78fccacc90 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Thu, 1 Dec 2022 10:10:58 -0800 Subject: [PATCH 012/111] formatting --- src/transformers/models/whisper/modeling_flax_whisper.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index baa5d3c1b63365..70451d6942dbe9 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -863,14 +863,14 @@ def init_cache(self, batch_size, max_length, encoder_outputs): r""" Args: batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized cache. encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: - `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, + `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. """ @@ -1445,7 +1445,7 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs): def convert_unroll_to_scan(model, params: Union[Dict, FrozenDict]): r""" - Convert a `PyTree` of unrolled model parameters to a scanned block of model parameters. This method can be + Convert a `PyTree` of unrolled model parameters to a scanned block of model parameters. This method can be used to explicitly convert the model parameters to scanned format. This returns a new `params` tree and does not convert the `params` in place. 
```""" From c05089b8b0c738250aa46d9f4fbbc0c0e7c9e710 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Thu, 1 Dec 2022 10:19:53 -0800 Subject: [PATCH 013/111] remove trailing whitespaces --- .../models/whisper/modeling_flax_whisper.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 70451d6942dbe9..06895a8acac124 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -43,7 +43,7 @@ WHISPER_START_DOCSTRING = r""" This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods - the library implements for all its models (such as downloading or saving, resizing the input embeddings, + the library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a Flax Linen [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a @@ -74,11 +74,11 @@ input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`): Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, - *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into + *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the [`WhisperFeatureExtractor`] should be used for extracting the features, padding and conversion into a tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`] attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Whisper does not support masking of the `input_features`, this argument is preserved for + Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but is not used. By default the silence in the input log mel spectrogram are ignored. decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): Indices of decoder input sequence tokens in the vocabulary. @@ -87,7 +87,7 @@ [What are decoder input IDs?](../glossary#decoder-input-ids) Whisper uses the `decoder_start_token_id` as the starting token for `decoder_input_ids` generation. decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default. If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. @@ -99,7 +99,7 @@ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, config.max_position_embeddings - 1]`. output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. 
output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors @@ -113,15 +113,15 @@ input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`): Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, - *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into - `input_features`, the [`WhisperFeatureExtractor`] should be used for extracting the mel features, + *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into + `input_features`, the [`WhisperFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`]. attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Whisper does not support masking of the `input_features`, this argument is preserved for + Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but is not used. By default the silence in the input log mel spectrogram are ignored. output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors @@ -138,9 +138,9 @@ [`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids) encoder_outputs (`tuple(tuple(numpy.ndarray)`): - Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) + Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence - of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the + of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. encoder_attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, @@ -153,10 +153,10 @@ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, config.max_position_embeddings - 1]`. - past_key_values (`Dict[str, numpy.ndarray]`, *optional*, returned by `init_cache` or when passing previous + past_key_values (`Dict[str, numpy.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for - fast auto-regressive decoding. Pre-computed key and value hidden-states are of shape + fast auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under From 7551181b5fb8272d9029c76f675c6f9390b3f59a Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Thu, 1 Dec 2022 10:31:07 -0800 Subject: [PATCH 014/111] fix flax whisper docs --- .../models/whisper/modeling_flax_whisper.py | 146 ++++++++---------- 1 file changed, 67 insertions(+), 79 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 06895a8acac124..7719d54d68051f 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -42,10 +42,9 @@ WHISPER_START_DOCSTRING = r""" - This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods - the library implements for all its models (such as downloading or saving, resizing the input embeddings, - pruning heads etc.) - This model is also a Flax Linen + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) This model is also a Flax Linen [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. Finally, this model supports inherent JAX features such as: @@ -56,54 +55,49 @@ Parameters: config ([`WhisperConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the - configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model - weights. + configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and - `jax.numpy.bfloat16` (on TPUs). - This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. + `jax.numpy.bfloat16` (on TPUs). This can be used to enable mixed-precision training or half-precision + inference on GPUs or TPUs. If specified all the computation will be performed with the given `dtype`. **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** - If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and - [`~FlaxPreTrainedModel.to_bf16`]. + parameters.** If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] + and [`~FlaxPreTrainedModel.to_bf16`]. """ WHISPER_INPUTS_DOCSTRING = r""" Args: input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`): - Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained - by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, - *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into - `input_features`, the [`WhisperFeatureExtractor`] should be used for extracting the features, padding - and conversion into a tensor of type `numpy.ndarray`. 
See [`~WhisperFeatureExtractor.__call__`] + Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by + loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via + the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the + [`WhisperFeatureExtractor`] should be used for extracting the features, padding and conversion into a + tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`] attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Whisper does not support masking of the `input_features`, this argument is preserved for - compatibility, but is not used. By default the silence in the input log mel spectrogram are ignored. + Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but + is not used. By default the silence in the input log mel spectrogram are ignored. decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Indices of decoder input sequence tokens in the vocabulary. - Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - [What are decoder input IDs?](../glossary#decoder-input-ids) - Whisper uses the `decoder_start_token_id` as the starting token for `decoder_input_ids` generation. + Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using + [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + [What are decoder input IDs?](../glossary#decoder-input-ids) Whisper uses the `decoder_start_token_id` as + the starting token for `decoder_input_ids` generation. decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will - also be used by default. - If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the - paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. If you want to change padding behavior, you should modify to your needs. See diagram 1 + in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Whisper does not use position_ids in the encoder as input_features is always the same size and doesn't - use masking, but this argument is preserved for compatibility. By default the silence in the input log - mel spectrogram are ignored. + Whisper does not use position_ids in the encoder as input_features is always the same size and doesn't use + masking, but this argument is preserved for compatibility. By default the silence in the input log mel + spectrogram are ignored. decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, config.max_position_embeddings - 1]`. output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under - returned tensors for more detail. + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -111,21 +105,20 @@ WHISPER_ENCODE_INPUTS_DOCSTRING = r""" Args: input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`): - Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained - by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, - *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into - `input_features`, the [`WhisperFeatureExtractor`] should be used for extracting the mel features, - padding and conversion into a tensor of type `numpy.ndarray`. See - [`~WhisperFeatureExtractor.__call__`]. + Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by + loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via + the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the + [`WhisperFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a + tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`]. attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Whisper does not support masking of the `input_features`, this argument is preserved for - compatibility, but is not used. By default the silence in the input log mel spectrogram are ignored. + Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but + is not used. By default the silence in the input log mel spectrogram are ignored. output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -133,37 +126,33 @@ WHISPER_DECODE_INPUTS_DOCSTRING = r""" Args: decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`): - Indices of decoder input sequence tokens in the vocabulary. - Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. + Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using + [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. 
[What are decoder input IDs?](../glossary#decoder-input-ids) encoder_outputs (`tuple(tuple(numpy.ndarray)`): Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) - `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence - of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the - decoder. + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. encoder_attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but it is not used. By default the silence in the input log mel spectrogram are ignored. decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): - Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will - also be used by default. If you want to change padding behavior, you should modify to your needs. See - diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default - strategy. + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. If you want to change padding behavior, you should modify to your needs. See diagram 1 + in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, config.max_position_embeddings - 1]`. past_key_values (`Dict[str, numpy.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): - Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for - fast auto-regressive decoding. Pre-computed key and value hidden-states are of shape - *[batch_size, max_length]*. + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast + auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ @@ -863,16 +852,15 @@ def init_cache(self, batch_size, max_length, encoder_outputs): r""" Args: batch_size (`int`): - batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized - cache. + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. 
max_length (`int`): - maximum possible length for auto-regressive decoding. Defines the sequence length of the - initialized cache. + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: - `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, - *optional*) is a sequence of hidden-states at the output of the last layer of the encoder. - Used in the cross-attention of the decoder. + `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) + is a sequence of hidden-states at the output of the last layer of the encoder. Used in the + cross-attention of the decoder. """ # init input variables to retrieve cache decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") @@ -990,8 +978,8 @@ def decode( >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np") >>> input_features = inputs.input_features >>> encoder_outputs = model.encode(input_features=input_features) + >>> decoder_start_token_id = model.config.decoder_start_token_id - >>> decoder_start_token_id = model.config.decoder_start_token_id >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id >>> outputs = model.decode(decoder_input_ids, encoder_outputs) @@ -1277,8 +1265,8 @@ def decode( >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np") >>> input_features = inputs.input_features >>> encoder_outputs = model.encode(input_features=input_features) + >>> decoder_start_token_id = model.config.decoder_start_token_id - >>> decoder_start_token_id = model.config.decoder_start_token_id >>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id >>> outputs = model.decode(decoder_input_ids, encoder_outputs) @@ -1445,9 +1433,9 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs): def convert_unroll_to_scan(model, params: Union[Dict, FrozenDict]): r""" - Convert a `PyTree` of unrolled model parameters to a scanned block of model parameters. This method can be - used to explicitly convert the model parameters to scanned format. This returns a new `params` tree and does - not convert the `params` in place. + Convert a `PyTree` of unrolled model parameters to a scanned block of model parameters. This method can be used to + explicitly convert the model parameters to scanned format. This returns a new `params` tree and does not convert + the `params` in place. ```""" if isinstance(params, FrozenDict): params = unfreeze(params) @@ -1481,9 +1469,9 @@ def convert_unroll_to_scan(model, params: Union[Dict, FrozenDict]): def convert_scan_to_unroll(model, params: Union[Dict, FrozenDict]): r""" - Convert a `PyTree` of scanned model parameters to an unrolled stack of model parameters. This method can be - used to explicitly convert the model parameters to unrolled format. This returns a new `params` tree and does - not convert the `params` in place. + Convert a `PyTree` of scanned model parameters to an unrolled stack of model parameters. This method can be used to + explicitly convert the model parameters to unrolled format. This returns a new `params` tree and does not convert + the `params` in place. 
     ```"""
     if isinstance(params, FrozenDict):
         params = unfreeze(params)

From e255a97e9660173ee0804481efe76164a6189de7 Mon Sep 17 00:00:00 2001
From: andyehrenberg
Date: Fri, 2 Dec 2022 09:04:07 -0800
Subject: [PATCH 015/111] add generation logic to unlock flax whisper

---
 .../generation/flax_logits_process.py     |  81 +++++++++++
 src/transformers/generation/flax_utils.py | 135 +++++++++++++++---
 2 files changed, 193 insertions(+), 23 deletions(-)

diff --git a/src/transformers/generation/flax_logits_process.py b/src/transformers/generation/flax_logits_process.py
index c8abeea0982374..8ae0b44729d35a 100644
--- a/src/transformers/generation/flax_logits_process.py
+++ b/src/transformers/generation/flax_logits_process.py
@@ -265,3 +265,84 @@ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) ->
         scores = jnp.where(apply_penalty, scores.at[:, self.eos_token_id].set(-float("inf")), scores)
 
         return scores
+
+
+class FlaxSuppressTokensAtBeginLogitsProcessor(FlaxLogitsProcessor):
+    r"""
+    Args:
+        [`FlaxLogitsProcessor`] suppressing a list of tokens at an index.
+        begin_suppress_tokens (`list`):
+            Tokens to not sample.
+        begin_index (`int`):
+            Index where the tokens are suppressed.
+    """
+
+    def __init__(self, begin_suppress_tokens, begin_index):
+        self.begin_suppress_tokens = list(begin_suppress_tokens)
+        self.begin_index = begin_index
+
+    def __call__(self, input_ids, scores, cur_len: int):
+        apply_penalty = 1 - jnp.bool_(cur_len - self.begin_index)
+
+        scores = jnp.where(apply_penalty, scores.at[:, self.begin_suppress_tokens].set(-float("inf")), scores)
+
+        return scores
+
+
+class FlaxSuppressTokensLogitsProcessor(FlaxLogitsProcessor):
+    r"""
+    Args:
+        [`FlaxLogitsProcessor`] suppressing a list of tokens.
+        suppress_tokens (`list`):
+            Tokens to not sample.
+    """
+
+    def __init__(self, suppress_tokens: list):
+        self.suppress_tokens = list(suppress_tokens)
+
+    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
+        scores = scores.at[..., self.suppress_tokens].set(-float("inf"))
+
+        return scores
+
+
+class FlaxForceTokenAtIdxLogitsProcessor(FlaxLogitsProcessor):
+    r"""
+    Args:
+        [`FlaxLogitsProcessor`] that forces a token to be sampled at an index.
+        apply_idx (`int`):
+            Index where sampling is forced.
+        token_id (`int`):
+            Token that is forced to be sampled.
+    """
+
+    def __init__(self, apply_idx: int, token_id: int):
+        self.apply_idx = apply_idx
+        self.token_id = token_id
+
+    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
+        new_scores = jnp.full(scores.shape, -float("inf"))
+
+        apply_penalty = 1 - jnp.bool_(cur_len - self.apply_idx)
+
+        scores = jnp.where(apply_penalty, new_scores.at[:, self.token_id].set(0), scores)
+
+        return scores
+
+
+class FlaxForceTokensLogitsProcessor(FlaxLogitsProcessor):
+    r"""
+    Args:
+        [`FlaxLogitsProcessor`] that forces tokens to be sampled at given indices.
+        force_token_map (`list`):
+            Map giving token ids and indices where they will be forced to be sampled.
+ """ + + def __init__(self, force_token_map): + self.processors = [FlaxForceTokenAtIdxLogitsProcessor(i[0], i[1]) for i in force_token_map] + + def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: + for processor in self.processors: + scores = processor(input_ids, scores, cur_len) + + return scores diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 5d936ce5b1dccd..21686e456e3f77 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -18,7 +18,7 @@ import inspect import warnings from functools import partial -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import numpy as np @@ -36,8 +36,11 @@ from .flax_logits_process import ( FlaxForcedBOSTokenLogitsProcessor, FlaxForcedEOSTokenLogitsProcessor, + FlaxForceTokensLogitsProcessor, FlaxLogitsProcessorList, FlaxMinLengthLogitsProcessor, + FlaxSuppressTokensAtBeginLogitsProcessor, + FlaxSuppressTokensLogitsProcessor, FlaxTemperatureLogitsWarper, FlaxTopKLogitsWarper, FlaxTopPLogitsWarper, @@ -50,10 +53,8 @@ @flax.struct.dataclass class FlaxGreedySearchOutput(ModelOutput): """ - Flax Base class for outputs of decoder-only generation models using greedy search. - - Args: + Flax Base class for outputs of decoder-only generation models using greedy search. sequences (`jnp.ndarray` of shape `(batch_size, max_length)`): The generated sequences. """ @@ -64,10 +65,8 @@ class FlaxGreedySearchOutput(ModelOutput): @flax.struct.dataclass class FlaxSampleOutput(ModelOutput): """ - Flax Base class for outputs of decoder-only generation models using sampling. - - Args: + Flax Base class for outputs of decoder-only generation models using sampling. sequences (`jnp.ndarray` of shape `(batch_size, max_length)`): The generated sequences. """ @@ -78,10 +77,8 @@ class FlaxSampleOutput(ModelOutput): @flax.struct.dataclass class FlaxBeamSearchOutput(ModelOutput): """ - Flax Base class for outputs of decoder-only generation models using greedy search. - - Args: + Flax Base class for outputs of decoder-only generation models using greedy search. sequences (`jnp.ndarray` of shape `(batch_size, max_length)`): The generated sequences. scores (`jnp.ndarray` of shape `(batch_size,)`): @@ -125,9 +122,7 @@ class BeamSearchState: class FlaxGenerationMixin: """ A class containing all functions for auto-regressive text generation, to be used as a mixin in - [`FlaxPreTrainedModel`]. - - The class exposes [`~generation.FlaxGenerationMixin.generate`], which can be used for: + [`FlaxPreTrainedModel`]. The class exposes [`~generation.FlaxGenerationMixin.generate`], which can be used for: - *greedy decoding* by calling [`~generation.FlaxGenerationMixin._greedy_search`] if `num_beams=1` and `do_sample=False`. 
- *multinomial sampling* by calling [`~generation.FlaxGenerationMixin._sample`] if `num_beams=1` and @@ -155,6 +150,42 @@ def _prepare_encoder_decoder_kwargs_for_generation(self, input_ids, params, mode model_kwargs["encoder_outputs"] = self.encode(input_ids, params=params, return_dict=True, **encoder_kwargs) return model_kwargs + def _prepare_decoder_input_ids_for_generation( + self, + batch_size: int, + decoder_start_token_id: int = None, + bos_token_id: int = None, + model_kwargs: Optional[Dict[str, jnp.ndarray]] = None, + ) -> jnp.ndarray: + if model_kwargs is not None and "decoder_input_ids" in model_kwargs: + # Only use this arg if not None, otherwise just remove from model_kwargs + decoder_input_ids = model_kwargs.pop("decoder_input_ids") + if decoder_input_ids is not None: + return decoder_input_ids + decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) + return jnp.array(decoder_start_token_id).reshape(1, -1).repeat(batch_size, axis=0) + + def _get_decoder_start_token_id( + self, + decoder_start_token_id: int = None, + bos_token_id: int = None, + ) -> int: + decoder_start_token_id = ( + decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id + ) + bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id + + if decoder_start_token_id is not None: + start_token = decoder_start_token_id + elif bos_token_id is not None: + start_token = bos_token_id + else: + raise ValueError( + "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." + ) + + return start_token + @staticmethod def _expand_to_num_beams(tensor, num_beams): return jnp.broadcast_to(tensor[:, None], (tensor.shape[0], num_beams) + tensor.shape[1:]) @@ -227,6 +258,9 @@ def generate( min_length: Optional[int] = None, forced_bos_token_id: Optional[int] = None, forced_eos_token_id: Optional[int] = None, + suppress_tokens: Optional[List[int]] = None, + begin_suppress_tokens: Optional[List[int]] = None, + forced_decoder_ids: Optional[List[List[int]]] = None, length_penalty: Optional[float] = None, early_stopping: Optional[bool] = None, trace: bool = True, @@ -236,7 +270,6 @@ def generate( r""" Generates sequences of token ids for models with a language modeling head. The method supports the following generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models: - - *greedy decoding* by calling [`~generation.FlaxGenerationMixin._greedy_search`] if `num_beams=1` and `do_sample=False`. - *multinomial sampling* by calling [`~generation.FlaxGenerationMixin._sample`] if `num_beams=1` and @@ -252,10 +285,9 @@ def generate( + Parameters: Most of these parameters are explained in more detail in [this blog post](https://huggingface.co/blog/how-to-generate). - - Parameters: input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. max_length (`int`, *optional*, defaults to `model.config.max_length`): @@ -283,6 +315,12 @@ def generate( Number of beams for beam search. 1 means no beam search. decoder_start_token_id (`int`, *optional*): If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token. 
+            suppress_tokens (`list`, *optional*):
+                Token ids that will not be sampled.
+            begin_suppress_tokens (`list`, *optional*):
+                Token ids that will not be sampled at the first generation step.
+            forced_decoder_ids (`list`, *optional*):
+                Pairs of `[index, token_id]` giving tokens that are forced at the given indices at the start of generation.
             trace (`bool`, *optional*, defaults to `True`):
                 Whether to trace generation. Setting `trace=False` should only be used for debugging and will lead to a
                 considerably slower runtime.
             params (`Dict[str, jnp.ndarray]`, *optional*):
                 Optionally, the model parameters can be passed. Can be useful for parallelized generation.
             model_kwargs:
                 Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model
                 is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs
                 should be prefixed with *decoder_*. Also accepts `encoder_outputs` to skip encoder part.
-
         Return:
             [`~utils.ModelOutput`].
-
         Examples:
-
         ```python
         >>> from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

@@ -334,12 +369,21 @@ def generate(
                 "generation results, please set `padding_side='left'` when initializing the tokenizer."
             )
 
+        batch_size = input_ids.shape[0]
+
         if self.config.is_encoder_decoder:
             # add encoder_outputs to model_kwargs
            if model_kwargs.get("encoder_outputs") is None:
                 model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs)
+            if forced_decoder_ids is None and hasattr(self.config, "forced_decoder_ids"):
+                forced_decoder_ids = self.config.forced_decoder_ids
             # prepare decoder_input_ids for generation
-            input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
+            input_ids = self._prepare_decoder_input_ids_for_generation(
+                batch_size,
+                decoder_start_token_id=decoder_start_token_id,
+                bos_token_id=bos_token_id,
+                model_kwargs=model_kwargs,
+            )
 
         # Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1] @@ -382,7 +426,15 @@ def generate( if not do_sample and num_beams == 1: logits_processor = self._get_logits_processor( - no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id + no_repeat_ngram_size, + min_length, + max_length, + eos_token_id, + forced_bos_token_id, + forced_eos_token_id, + input_ids_seq_length, + suppress_tokens=suppress_tokens, + begin_suppress_tokens=begin_suppress_tokens, ) return self._greedy_search( input_ids, @@ -397,7 +449,15 @@ def generate( elif do_sample and num_beams == 1: logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature) logits_processor = self._get_logits_processor( - no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id + no_repeat_ngram_size, + min_length, + max_length, + eos_token_id, + forced_bos_token_id, + forced_eos_token_id, + input_ids_seq_length, + suppress_tokens=suppress_tokens, + begin_suppress_tokens=begin_suppress_tokens, ) return self._sample( input_ids, @@ -426,7 +486,15 @@ def generate( ) logits_processor = self._get_logits_processor( - no_repeat_ngram_size, min_length, max_length, eos_token_id, forced_bos_token_id, forced_eos_token_id + no_repeat_ngram_size, + min_length, + max_length, + eos_token_id, + forced_bos_token_id, + forced_eos_token_id, + input_ids_seq_length, + suppress_tokens=suppress_tokens, + begin_suppress_tokens=begin_suppress_tokens, ) return self._beam_search( @@ -478,6 +546,10 @@ def _get_logits_processor( eos_token_id: int, forced_bos_token_id: int, forced_eos_token_id: int, + input_ids_seq_length: int, + suppress_tokens: Optional[List[int]] = None, + begin_suppress_tokens: Optional[List[int]] = None, + forced_decoder_ids: Optional[List[List[int]]] = None, ) -> FlaxLogitsProcessorList: """ This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsProcessor`] @@ -496,6 +568,12 @@ def _get_logits_processor( forced_eos_token_id = ( forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id ) + suppress_tokens = suppress_tokens if suppress_tokens is not None else self.config.suppress_tokens + begin_suppress_tokens = ( + begin_suppress_tokens if begin_suppress_tokens is not None else self.config.begin_suppress_tokens + ) + if forced_decoder_ids is None and hasattr(self.config, "forced_decoder_ids"): + forced_decoder_ids = self.config.forced_decoder_ids # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files # all samplers can be found in `generation_utils_samplers.py` @@ -505,6 +583,17 @@ def _get_logits_processor( processors.append(FlaxForcedBOSTokenLogitsProcessor(forced_bos_token_id)) if forced_eos_token_id is not None: processors.append(FlaxForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id)) + if suppress_tokens is not None: + processors.append(FlaxSuppressTokensLogitsProcessor(suppress_tokens)) + if begin_suppress_tokens is not None: + begin_index = input_ids_seq_length + begin_index = begin_index if (input_ids_seq_length > 1 or forced_bos_token_id is None) else begin_index + 1 + if forced_decoder_ids is not None: + begin_index += forced_decoder_ids[-1][0] # generation starts after the last token that is forced + processors.append(FlaxSuppressTokensAtBeginLogitsProcessor(begin_suppress_tokens, begin_index)) + if forced_decoder_ids is not None: + forced_decoder_ids = [[input_ids_seq_length + i[0] - 
1, i[1]] for i in forced_decoder_ids] + processors.append(FlaxForceTokensLogitsProcessor(forced_decoder_ids)) return processors def _greedy_search( From d003074b5cb7211949e99eff27a0a3f5e425320a Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Fri, 2 Dec 2022 09:20:39 -0800 Subject: [PATCH 016/111] remove scans --- .../models/whisper/modeling_flax_whisper.py | 252 +++--------------- 1 file changed, 44 insertions(+), 208 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 7719d54d68051f..88317ce381bcf2 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -6,7 +6,6 @@ import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, make_causal_mask -from flax.linen import partitioning as nn_partitioning from flax.linen.attention import dot_product_attention_weights from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax @@ -30,9 +29,6 @@ from .configuration_whisper import WhisperConfig -scan_with_axes = nn_partitioning.scan_with_axes - - logger = logging.get_logger(__name__) @@ -200,7 +196,7 @@ def __call__( attention_mask: Optional[jnp.ndarray] = None, init_cache: bool = False, deterministic: bool = True, - ) -> jnp.ndarray: + ) -> Tuple[jnp.ndarray]: is_cross_attention = key_value_states is not None batch_size = hidden_states.shape[0] @@ -317,7 +313,6 @@ def _concatenate_to_cache(self, key, value, query, attention_mask) -> Tuple[jnp. class FlaxWhisperEncoderLayer(nn.Module): config: WhisperConfig dtype: jnp.dtype = jnp.float32 - use_scan: bool = False def setup(self) -> None: self.embed_dim = self.config.d_model @@ -349,10 +344,7 @@ def __call__( hidden_states: jnp.ndarray, output_attentions: bool = False, deterministic: bool = True, - ) -> jnp.ndarray: - if self.use_scan: - hidden_states = hidden_states[0] - + ) -> Tuple[jnp.ndarray]: residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states, attn_weights = self.self_attn(hidden_states) @@ -371,18 +363,19 @@ def __call__( if output_attentions: outputs += (attn_weights,) - if self.use_scan: - outputs = (outputs, None) - return outputs class EncoderLayerCollection(nn.Module): config: WhisperConfig dtype: jnp.dtype = jnp.float32 - use_scan: bool = False - @nn.compact + def setup(self) -> None: + self.layers = [ + FlaxWhisperEncoderLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.encoder_layers) + ] + def __call__( self, hidden_states: jnp.ndarray, @@ -390,41 +383,21 @@ def __call__( output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, - ) -> jnp.ndarray: + ) -> Tuple[jnp.ndarray]: all_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None - if self.use_scan: - assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`" - assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`" - hidden_states = (hidden_states,) - - hidden_states, _ = scan_with_axes( - FlaxWhisperEncoderLayer, - variable_axes={"params": 0, "cache": 0}, - split_rngs={"params": True, "dropout": True}, - in_axes=(nn.broadcast, nn.broadcast), - length=self.config.encoder_layers, - )(self.config, dtype=self.dtype, use_scan=self.use_scan, name="Layers")( + for encoder_layer in self.layers: + if 
output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + layer_outputs = encoder_layer( hidden_states, - output_attentions, - deterministic, + output_attentions=output_attentions, + deterministic=deterministic, ) - hidden_states = hidden_states[0] - else: - for layer in range(self.config.encoder_layers): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - layer_outputs = FlaxWhisperEncoderLayer( - self.config, dtype=self.dtype, use_scan=self.use_scan, name=str(layer) - )( - hidden_states, - output_attentions=output_attentions, - deterministic=deterministic, - ) - hidden_states = layer_outputs[0] - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) + hidden_states = layer_outputs[0] + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) if output_hidden_states: all_hidden_states += (hidden_states,) @@ -442,7 +415,6 @@ def __call__( class FlaxWhisperDecoderLayer(nn.Module): config: WhisperConfig dtype: jnp.dtype = jnp.float32 - use_scan: bool = False def setup(self) -> None: self.embed_dim = self.config.d_model @@ -486,8 +458,6 @@ def __call__( output_attentions: bool = True, deterministic: bool = True, ) -> Tuple[jnp.ndarray]: - if self.use_scan: - hidden_states = hidden_states[0] residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) @@ -521,18 +491,19 @@ def __call__( if output_attentions: outputs += (self_attn_weights, cross_attn_weights) - if self.use_scan: - outputs = (outputs, None) - return outputs class FlaxWhisperDecoderLayerCollection(nn.Module): config: WhisperConfig dtype: jnp.dtype = jnp.float32 - use_scan: bool = False - @nn.compact + def setup(self) -> None: + self.layers = [ + FlaxWhisperDecoderLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.decoder_layers) + ] + def __call__( self, hidden_states: jnp.ndarray, @@ -543,49 +514,29 @@ def __call__( output_attentions: bool = False, deterministic: bool = True, return_dict: bool = True, - ) -> jnp.ndarray: + ) -> Tuple[jnp.ndarray]: # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - if self.use_scan: - hidden_states = (hidden_states,) - hidden_states, _ = scan_with_axes( - FlaxWhisperDecoderLayer, - variable_axes={"params": 0, "cache": 0}, - split_rngs={"params": True, "dropout": True}, - in_axes=(nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast), - length=self.config.decoder_layers, - )(self.config, dtype=self.dtype, use_scan=self.use_scan, name="Layers",)( + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + layer_outputs = decoder_layer( hidden_states, attention_mask, encoder_hidden_states, - init_cache, - output_attentions, - deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + deterministic=deterministic, ) - hidden_states = hidden_states[0] - else: - for layer in range(self.config.decoder_layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - layer_outputs = FlaxWhisperDecoderLayer( - self.config, dtype=self.dtype, use_scan=self.use_scan, name=str(layer) - )( - hidden_states, - attention_mask, - encoder_hidden_states, - init_cache=init_cache, - output_attentions=output_attentions, - deterministic=deterministic, - ) - hidden_states = layer_outputs[0] - if 
output_attentions: - all_self_attns += (layer_outputs[1],) + hidden_states = layer_outputs[0] + if output_attentions: + all_self_attns += (layer_outputs[1],) - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2],) + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) # add hidden states from the last decoder layer if output_hidden_states: @@ -607,7 +558,6 @@ def __call__( class FlaxWhisperEncoder(nn.Module): config: WhisperConfig dtype: jnp.dtype = jnp.float32 - use_scan: bool = False def setup(self) -> None: self.conv1 = nn.Conv( @@ -631,7 +581,6 @@ def setup(self) -> None: self.layers = EncoderLayerCollection( self.config, dtype=self.dtype, - use_scan=self.use_scan, ) self.embed_positions = nn.Embed(self.config.max_source_positions, self.config.d_model, dtype=self.dtype) @@ -644,7 +593,7 @@ def __call__( output_hidden_states: bool = False, return_dict: bool = True, deterministic: bool = True, - ) -> jnp.ndarray: + ) -> Tuple[jnp.ndarray]: input_features = input_features.transpose(0, 2, 1) hidden_states = jax.nn.gelu(self.conv1(input_features), approximate=False) hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False) @@ -680,13 +629,12 @@ def __call__( class FlaxWhisperDecoder(nn.Module): config: WhisperConfig dtype: jnp.dtype = jnp.float32 - use_scan: bool = False def setup(self) -> None: self.embed_tokens = nn.Embed(self.config.vocab_size, self.config.d_model, dtype=self.dtype) self.embed_positions = nn.Embed(self.config.max_target_positions, self.config.d_model, dtype=self.dtype) - self.layers = FlaxWhisperDecoderLayerCollection(self.config, dtype=self.dtype, use_scan=self.use_scan) + self.layers = FlaxWhisperDecoderLayerCollection(self.config, dtype=self.dtype) self.dropout_layer = nn.Dropout(rate=self.config.dropout) @@ -703,7 +651,7 @@ def __call__( output_hidden_states: bool = False, return_dict: bool = True, deterministic: bool = True, - ) -> jnp.ndarray: + ) -> Tuple[jnp.ndarray]: if attention_mask is not None: if position_ids is None: position_ids = attention_mask.cumsum(-1) - 1 @@ -743,11 +691,10 @@ def __call__( class FlaxWhisperModule(nn.Module): config: WhisperConfig dtype: jnp.dtype = jnp.float32 - use_scan: bool = False def setup(self) -> None: - self.encoder = FlaxWhisperEncoder(self.config, dtype=self.dtype, use_scan=self.use_scan) - self.decoder = FlaxWhisperDecoder(self.config, dtype=self.dtype, use_scan=self.use_scan) + self.encoder = FlaxWhisperEncoder(self.config, dtype=self.dtype) + self.decoder = FlaxWhisperDecoder(self.config, dtype=self.dtype) def __call__( self, @@ -1104,47 +1051,6 @@ def __call__( rngs=rngs, ) - def scan(self, params: Optional[FrozenDict] = None): - self._module = self.module_class(config=self.config, dtype=self.dtype, use_scan=True) - init_fn = partial(self.init_weights, input_shape=self.input_shape) - params_shape_tree = jax.eval_shape(init_fn, self.key) - - # get the shape of the parameters - self._params_shape_tree = params_shape_tree - - # save required_params as set - self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys()) - - # initialize the parameters - - if params is not None: - params = convert_unroll_to_scan(self, params) - return params - elif self._is_initialized: - self.params = convert_unroll_to_scan(self, self.params) - - def scan_disable(self, params: Optional[FrozenDict] = None): - self._module = self.module_class( - config=self.config, - dtype=self.dtype, - use_scan=False, - ) - init_fn = partial(self.init_weights, 
input_shape=self.input_shape) - params_shape_tree = jax.eval_shape(init_fn, self.key) - - # get the shape of the parameters - self._params_shape_tree = params_shape_tree - - # save required_params as set - self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys()) - - # initialize the parameters - if params is not None: - params = convert_scan_to_unroll(self, params) - return params - elif self._is_initialized: - self.params = convert_scan_to_unroll(self, self.params) - @add_start_docstrings( "The bare Whisper Model transformer outputting raw hidden-states without any specific head on top.", @@ -1164,10 +1070,9 @@ class FlaxWhisperModel(FlaxWhisperPreTrainedModel): class FlaxWhisperForConditionalGenerationModule(nn.Module): config: WhisperConfig dtype: jnp.dtype = jnp.float32 - use_scan: bool = False - def setup(self): - self.model = FlaxWhisperModule(config=self.config, dtype=self.dtype, use_scan=self.use_scan) + def setup(self) -> None: + self.model = FlaxWhisperModule(config=self.config, dtype=self.dtype) self.lm_head = nn.Dense( self.config.vocab_size, use_bias=False, @@ -1429,72 +1334,3 @@ def update_inputs_for_generation(self, model_outputs, model_kwargs): append_replace_return_docstrings( FlaxWhisperForConditionalGeneration, output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC ) - - -def convert_unroll_to_scan(model, params: Union[Dict, FrozenDict]): - r""" - Convert a `PyTree` of unrolled model parameters to a scanned block of model parameters. This method can be used to - explicitly convert the model parameters to scanned format. This returns a new `params` tree and does not convert - the `params` in place. - ```""" - if isinstance(params, FrozenDict): - params = unfreeze(params) - - params = flatten_dict(params, sep="/") - keys = list(params.keys()) - - for k in keys: - # Identify all "unrolled" layers formed in LayerCollections - # These params contain the identifier `layers` in their key - if "layers/0" in k: - # Squash the keys for the N unrolled layers into one single key: - # (layers/0, ..., layers/N) -> layers/Layers - scan_key = k.replace("0", "Layers", 1) - stacked_params = [] - - # Iterate over the unrolled layers (1,...,N) - for i in range(model.config.encoder_layers): - # Stack the params for the N layers into one super block - # and remove the unrolled layer params on the fly - # -> no memory overhead for conversion! - unrolled_layer = params.pop(k.replace("0", str(i), 1)) - stacked_params.append(unrolled_layer) - - params[scan_key] = jnp.stack(stacked_params) - - # Finally, unflatten the dict to restore the nested pytree structure - params = unflatten_dict(params, sep="/") - return params - - -def convert_scan_to_unroll(model, params: Union[Dict, FrozenDict]): - r""" - Convert a `PyTree` of scanned model parameters to an unrolled stack of model parameters. This method can be used to - explicitly convert the model parameters to unrolled format. This returns a new `params` tree and does not convert - the `params` in place. 
- ```""" - - if isinstance(params, FrozenDict): - params = unfreeze(params) - - params = flatten_dict(params, sep="/") - keys = list(params.keys()) - - for k in keys: - # Identify all "scan" layers in LayerCollections - # These params contain the identifier `EncoderLayers` or `DecoderLayers` in their key - if "Layers" in k: - # Remove the scan layer from the PyTree of params - scan_layer = params.pop(k) - - # Unroll the key for the stacked scan matrix into N separate keys, indexed by layer number - # layer/Layers -> (layer/0, ..., layer/N) - for i in range(model.dims.n_audio_layer): - # Unstack the params for the i-th scan layer to unrolled - # and remove corresponding scan params on the fly - # -> no memory overhead for conversion! - unrolled_key = k.replace("Layers", str(i)) - params[unrolled_key], scan_layer = scan_layer[0], scan_layer[1:] - - params = unflatten_dict(params, sep="/") - return params From ba8a35836494263fa45576cf609ed712212cf9c0 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Fri, 2 Dec 2022 09:28:37 -0800 Subject: [PATCH 017/111] give credits to Flax Bart implementation --- .../models/whisper/modeling_flax_whisper.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 88317ce381bcf2..cae34ea76beba2 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -153,7 +153,7 @@ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ - +# Copied with a few changes from transformers.models.bart.modeling_flax_bart.FlaxBartAttention class FlaxWhisperAttention(nn.Module): config: WhisperConfig embed_dim: int @@ -310,6 +310,7 @@ def _concatenate_to_cache(self, key, value, query, attention_mask) -> Tuple[jnp. 
return key, value, attention_mask +# Copied with a few changes from transformers.models.bart.modeling_flax_bart.FlaxBartEncoderLayer class FlaxWhisperEncoderLayer(nn.Module): config: WhisperConfig dtype: jnp.dtype = jnp.float32 @@ -366,7 +367,8 @@ def __call__( return outputs -class EncoderLayerCollection(nn.Module): +# Copied with a few changes from transformers.models.bart.modeling_flax_bart.FlaxBartEncoderLayerCollection +class FlaxWhisperEncoderLayerCollection(nn.Module): config: WhisperConfig dtype: jnp.dtype = jnp.float32 @@ -412,6 +414,7 @@ def __call__( ) +# Copied with a few changes from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderLayer class FlaxWhisperDecoderLayer(nn.Module): config: WhisperConfig dtype: jnp.dtype = jnp.float32 @@ -494,6 +497,7 @@ def __call__( return outputs +# Copied with a few changes from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderLayerCollection class FlaxWhisperDecoderLayerCollection(nn.Module): config: WhisperConfig dtype: jnp.dtype = jnp.float32 @@ -578,7 +582,7 @@ def setup(self) -> None: self.dropout_layer = nn.Dropout(rate=self.config.dropout) - self.layers = EncoderLayerCollection( + self.layers = FlaxWhisperEncoderLayerCollection( self.config, dtype=self.dtype, ) @@ -626,6 +630,7 @@ def __call__( ) +# Copied with a few changes from transformers.models.bart.modeling_flax_bart.FlaxBartDecoder class FlaxWhisperDecoder(nn.Module): config: WhisperConfig dtype: jnp.dtype = jnp.float32 From 9f4578dd67b70ef882eaf129fd0f54d61fe8952f Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Fri, 2 Dec 2022 09:30:29 -0800 Subject: [PATCH 018/111] remove unused imports --- src/transformers/models/whisper/modeling_flax_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index cae34ea76beba2..0d0319c62d0cbb 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Tuple import flax.linen as nn import jax From be33fbd1f55477322ae1f2f3c57897ca553ecb57 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Fri, 2 Dec 2022 09:32:36 -0800 Subject: [PATCH 019/111] add license --- .../models/whisper/modeling_flax_whisper.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 0d0319c62d0cbb..642a9104914c72 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -1,3 +1,20 @@ +# coding=utf-8 +# Copyright 2022 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" Flax whisper model.""" + + from functools import partial from typing import Optional, Tuple From 8b1338b07c8f215e7d1f71a8fcf24c375ac8b2cf Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Fri, 2 Dec 2022 09:44:58 -0800 Subject: [PATCH 020/111] remove assert --- .../models/whisper/modeling_flax_whisper.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 642a9104914c72..a1819cb05a06a4 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -619,10 +619,12 @@ def __call__( hidden_states = jax.nn.gelu(self.conv1(input_features), approximate=False) hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False) - assert hidden_states.shape[1:] == ( - self.config.max_source_positions, - self.config.d_model, - ), "incorrect audio shape" + if hidden_states.shape[1:] != (self.config.max_source_positions, self.config.d_model): + raise ValueError( + f"hidden_states.shape[1:] must be equal to (self.config.max_source_positions, self.config.d_model)" + f"(got {hidden_states.shape[1:]}, but should be ({self.config.max_source_positions}, {self.config.d_model}))" + ) + hidden_states = hidden_states + self.embed_positions(jnp.arange(self.config.max_source_positions)) hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) From c567f79e3bf1d0cd69aa60f15ff99fe8f8ffb500 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Fri, 2 Dec 2022 09:48:19 -0800 Subject: [PATCH 021/111] more credits to Bart --- src/transformers/models/whisper/modeling_flax_whisper.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index a1819cb05a06a4..1629bbab4123de 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -770,6 +770,7 @@ def _get_decoder_module(self): return self.decoder +# Copied with a few changes from transformers.models.bart.modeling_flax_bart.FlaxBartPreTrainedModel class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel): config_class = WhisperConfig base_model_prefix: str = "model" @@ -1090,7 +1091,7 @@ class FlaxWhisperModel(FlaxWhisperPreTrainedModel): FlaxWhisperModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC ) - +# Copied with a few changes from transformers.models.bart.modeling_flax_bart.FlaxBartForConditionalGenerationModule class FlaxWhisperForConditionalGenerationModule(nn.Module): config: WhisperConfig dtype: jnp.dtype = jnp.float32 @@ -1157,6 +1158,7 @@ def __call__( ) +# Copied with a few changes from transformers.models.bart.modeling_flax_bart.FlaxBartForConditionalGeneration @add_start_docstrings("The Whisper Model with a language modeling head.", WHISPER_START_DOCSTRING) class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel): module_class = FlaxWhisperForConditionalGenerationModule From fbe4e2591dc0e9bf935ed7d8ab94b43fef69d175 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Fri, 2 Dec 2022 10:01:58 -0800 Subject: [PATCH 022/111] fix style --- src/transformers/models/whisper/modeling_flax_whisper.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 
1629bbab4123de..3180e51d21cdda 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -170,6 +170,7 @@ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ + # Copied with a few changes from transformers.models.bart.modeling_flax_bart.FlaxBartAttention class FlaxWhisperAttention(nn.Module): config: WhisperConfig @@ -1091,6 +1092,7 @@ class FlaxWhisperModel(FlaxWhisperPreTrainedModel): FlaxWhisperModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC ) + # Copied with a few changes from transformers.models.bart.modeling_flax_bart.FlaxBartForConditionalGenerationModule class FlaxWhisperForConditionalGenerationModule(nn.Module): config: WhisperConfig From cde5afd912797aed4bf98364ded5f1c10c23d54a Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Fri, 2 Dec 2022 10:15:07 -0800 Subject: [PATCH 023/111] formatting --- src/transformers/models/whisper/modeling_flax_whisper.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 3180e51d21cdda..3cdc1b27e48f4b 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -622,8 +622,9 @@ def __call__( if hidden_states.shape[1:] != (self.config.max_source_positions, self.config.d_model): raise ValueError( - f"hidden_states.shape[1:] must be equal to (self.config.max_source_positions, self.config.d_model)" - f"(got {hidden_states.shape[1:]}, but should be ({self.config.max_source_positions}, {self.config.d_model}))" + "hidden_states.shape[1:] must be equal to (self.config.max_source_positions, self.config.d_model)(got" + f" {hidden_states.shape[1:]}, but should be ({self.config.max_source_positions}," + f" {self.config.d_model}))" ) hidden_states = hidden_states + self.embed_positions(jnp.arange(self.config.max_source_positions)) From 6aeb8c89c9f59135fe4c3f74bcdba3a10298db66 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Fri, 2 Dec 2022 12:40:36 -0800 Subject: [PATCH 024/111] support left padding --- src/transformers/models/whisper/modeling_flax_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 3cdc1b27e48f4b..ff3284b8148323 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -264,7 +264,7 @@ def __call__( attention_bias = lax.select( attention_mask > 0, jnp.full(attention_mask.shape, 0.0), - jnp.full(attention_mask.shape, float("-inf")), + jnp.full(attention_mask.shape, -1e4), ) else: attention_bias = None From ec9ca19fdb323dfc2c526543767e0de47189a781 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Mon, 5 Dec 2022 09:16:11 -0800 Subject: [PATCH 025/111] add flax whisper generation test --- .../whisper/test_modeling_flax_whisper.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index 220826cb84d809..6a00be3dc3006f 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -163,6 +163,8 @@ def prepare_whisper_inputs_dict( return { "input_features": input_ids, "decoder_input_ids": 
decoder_input_ids, + "attention_mask": attention_mask, + "decoder_attention_mask": decoder_attention_mask, } @@ -287,3 +289,21 @@ def test_tiny_logits_librispeech(self): ) # fmt: on self.assertTrue(np.allclose(logits[0][0, 0, :30], EXPECTED_LOGITS, atol=1e-4)) + + def test_tiny_generation(self): + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") + model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny", from_pt=True) + + input_speech = self._load_datasamples(1) + input_features = processor.feature_extractor( + raw_speech=input_speech, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="jax" + ).input_features + + generated_ids = model.generate(input_features, num_beams=5, max_length=20) + transcript = processor.tokenizer.decode(generated_ids[0]) + + EXPECTED_TRANSCRIPT = ( + "<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle" + " classes and we are glad" + ) + self.assertEqual(transcript, EXPECTED_TRANSCRIPT) From 3f902f6b40669effd933a57e48cb1b24ff2a5500 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 7 Dec 2022 09:32:57 -0800 Subject: [PATCH 026/111] remove copied from comments whenever not a full copy --- .../models/whisper/modeling_flax_whisper.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index ff3284b8148323..e984153316fce9 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -171,7 +171,6 @@ """ -# Copied with a few changes from transformers.models.bart.modeling_flax_bart.FlaxBartAttention class FlaxWhisperAttention(nn.Module): config: WhisperConfig embed_dim: int @@ -328,7 +327,6 @@ def _concatenate_to_cache(self, key, value, query, attention_mask) -> Tuple[jnp. 
return key, value, attention_mask -# Copied with a few changes from transformers.models.bart.modeling_flax_bart.FlaxBartEncoderLayer class FlaxWhisperEncoderLayer(nn.Module): config: WhisperConfig dtype: jnp.dtype = jnp.float32 @@ -353,8 +351,8 @@ def setup(self) -> None: ) self.fc2 = nn.Dense( self.embed_dim, - kernel_init=jax.nn.initializers.normal(self.config.init_std), dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) @@ -385,7 +383,6 @@ def __call__( return outputs -# Copied with a few changes from transformers.models.bart.modeling_flax_bart.FlaxBartEncoderLayerCollection class FlaxWhisperEncoderLayerCollection(nn.Module): config: WhisperConfig dtype: jnp.dtype = jnp.float32 @@ -432,7 +429,6 @@ def __call__( ) -# Copied with a few changes from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderLayer class FlaxWhisperDecoderLayer(nn.Module): config: WhisperConfig dtype: jnp.dtype = jnp.float32 @@ -515,7 +511,6 @@ def __call__( return outputs -# Copied with a few changes from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderLayerCollection class FlaxWhisperDecoderLayerCollection(nn.Module): config: WhisperConfig dtype: jnp.dtype = jnp.float32 @@ -651,7 +646,6 @@ def __call__( ) -# Copied with a few changes from transformers.models.bart.modeling_flax_bart.FlaxBartDecoder class FlaxWhisperDecoder(nn.Module): config: WhisperConfig dtype: jnp.dtype = jnp.float32 @@ -772,7 +766,6 @@ def _get_decoder_module(self): return self.decoder -# Copied with a few changes from transformers.models.bart.modeling_flax_bart.FlaxBartPreTrainedModel class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel): config_class = WhisperConfig base_model_prefix: str = "model" @@ -1094,7 +1087,6 @@ class FlaxWhisperModel(FlaxWhisperPreTrainedModel): ) -# Copied with a few changes from transformers.models.bart.modeling_flax_bart.FlaxBartForConditionalGenerationModule class FlaxWhisperForConditionalGenerationModule(nn.Module): config: WhisperConfig dtype: jnp.dtype = jnp.float32 @@ -1161,7 +1153,6 @@ def __call__( ) -# Copied with a few changes from transformers.models.bart.modeling_flax_bart.FlaxBartForConditionalGeneration @add_start_docstrings("The Whisper Model with a language modeling head.", WHISPER_START_DOCSTRING) class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel): module_class = FlaxWhisperForConditionalGenerationModule From 3fd0a7c230c8cd749d23b6ad437628217dd336e6 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 7 Dec 2022 09:48:07 -0800 Subject: [PATCH 027/111] fix docstrings for logits processors --- .../generation/flax_logits_process.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/transformers/generation/flax_logits_process.py b/src/transformers/generation/flax_logits_process.py index 8ae0b44729d35a..be08692aa8202b 100644 --- a/src/transformers/generation/flax_logits_process.py +++ b/src/transformers/generation/flax_logits_process.py @@ -269,9 +269,12 @@ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> class FlaxSuppressTokensAtBeginLogitsProcessor: r""" + [`FlaxSuppressTokensAtBeginLogitsProcessor`] supresses a list of tokens as soon as the `generate` function starts + generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` are + not sampled at the begining of the generation. 
+ Args: - [`FlaxLogitsProcessor`] supressing a list of tokens at an index. - begin_suppress_tokens (`list`): + begin_suppress_tokens (`List[int]`): Tokens to not sample. begin_index (`int`): Index where the tokens are suppressed. @@ -291,8 +294,10 @@ def __call__(self, input_ids, scores, cur_len: int): class FlaxSuppressTokensLogitsProcessor(FlaxLogitsProcessor): r""" + [`FlaxSuppressTokensLogitsProcessor`] suppresses a list of tokens at each decoding step. The processor will + set their log probs to be `-inf` so they are not sampled. + Args: - [`FlaxLogitsProcessor`] supressing a list of tokens. suppress_tokens (`list`): Tokens to not sample. """ @@ -308,8 +313,9 @@ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> class FlaxForceTokenAtIdxLogitsProcessor(FlaxLogitsProcessor): r""" + [`FlaxForceTokenAtIdxLogitsProcessor`] forces a token to be sampled at an index. + Args: - [`FlaxLogitsProcessor`] that forces a token to be sampled at an index. apply_idx (`int`): Index where sampling is forced. token_id (`int`): @@ -332,14 +338,17 @@ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> class FlaxForceTokensLogitsProcessor(FlaxLogitsProcessor): r""" + [`FlaxForceTokensLogitsProcessor`] takes a list of pairs of integers which indicates a mapping from generation + indices to token indices that will be forced before sampling. The processor will set their log probs to `inf` so + that they are sampled at their corresponding index. + Args: - [`FlaxLogitsProcessor`] that forces tokens to be sampled at given indices. force_token_map (`list`): Map giving token ids and indices where they will be forced to be sampled. """ def __init__(self, force_token_map): - self.processors = [FlaxForceTokenAtIdxLogitsProcessor(i[0], i[1]) for i in force_token_map] + self.processors = dict(force_token_map) def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: for processor in self.processors: From abc14a14da5a406ba84a78c03043b22ac176ca06 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 7 Dec 2022 09:53:15 -0800 Subject: [PATCH 028/111] revert change to FlaxForceTokensLogitsProcessor --- src/transformers/generation/flax_logits_process.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/flax_logits_process.py b/src/transformers/generation/flax_logits_process.py index be08692aa8202b..59cfc79f73af10 100644 --- a/src/transformers/generation/flax_logits_process.py +++ b/src/transformers/generation/flax_logits_process.py @@ -348,7 +348,7 @@ class FlaxForceTokensLogitsProcessor(FlaxLogitsProcessor): """ def __init__(self, force_token_map): - self.processors = dict(force_token_map) + self.processors = [FlaxForceTokenAtIdxLogitsProcessor(i[0], i[1]) for i in force_token_map] def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: for processor in self.processors: From d784a2385adfcfa27d3bd64d36f0e8bc967e3315 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 7 Dec 2022 10:02:17 -0800 Subject: [PATCH 029/111] revert doc changes --- src/transformers/generation/flax_utils.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 21686e456e3f77..d5346f1546a464 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -53,8 +53,10 @@ @flax.struct.dataclass class 
FlaxGreedySearchOutput(ModelOutput): """ - Args: Flax Base class for outputs of decoder-only generation models using greedy search. + + + Args: sequences (`jnp.ndarray` of shape `(batch_size, max_length)`): The generated sequences. """ @@ -65,8 +67,10 @@ class FlaxGreedySearchOutput(ModelOutput): @flax.struct.dataclass class FlaxSampleOutput(ModelOutput): """ - Args: Flax Base class for outputs of decoder-only generation models using sampling. + + + Args: sequences (`jnp.ndarray` of shape `(batch_size, max_length)`): The generated sequences. """ @@ -77,8 +81,10 @@ class FlaxSampleOutput(ModelOutput): @flax.struct.dataclass class FlaxBeamSearchOutput(ModelOutput): """ - Args: Flax Base class for outputs of decoder-only generation models using greedy search. + + + Args: sequences (`jnp.ndarray` of shape `(batch_size, max_length)`): The generated sequences. scores (`jnp.ndarray` of shape `(batch_size,)`): @@ -122,7 +128,8 @@ class BeamSearchState: class FlaxGenerationMixin: """ A class containing all functions for auto-regressive text generation, to be used as a mixin in - [`FlaxPreTrainedModel`]. The class exposes [`~generation.FlaxGenerationMixin.generate`], which can be used for: + [`FlaxPreTrainedModel`]. + The class exposes [`~generation.FlaxGenerationMixin.generate`], which can be used for: - *greedy decoding* by calling [`~generation.FlaxGenerationMixin._greedy_search`] if `num_beams=1` and `do_sample=False`. - *multinomial sampling* by calling [`~generation.FlaxGenerationMixin._sample`] if `num_beams=1` and @@ -270,6 +277,7 @@ def generate( r""" Generates sequences of token ids for models with a language modeling head. The method supports the following generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models: + - *greedy decoding* by calling [`~generation.FlaxGenerationMixin._greedy_search`] if `num_beams=1` and `do_sample=False`. - *multinomial sampling* by calling [`~generation.FlaxGenerationMixin._sample`] if `num_beams=1` and @@ -285,9 +293,10 @@ def generate( - Parameters: Most of these parameters are explained in more detail in [this blog post](https://huggingface.co/blog/how-to-generate). + + Parameters: input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. max_length (`int`, *optional*, defaults to `model.config.max_length`): From 3dd8282a202765cdfa53089e05bb4b1b6bb2f143 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 7 Dec 2022 10:23:17 -0800 Subject: [PATCH 030/111] improve generation docs --- .../generation/flax_logits_process.py | 2 +- src/transformers/generation/flax_utils.py | 56 ++++++++++++------- 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/src/transformers/generation/flax_logits_process.py b/src/transformers/generation/flax_logits_process.py index 59cfc79f73af10..d3cd6f801fbfcc 100644 --- a/src/transformers/generation/flax_logits_process.py +++ b/src/transformers/generation/flax_logits_process.py @@ -267,7 +267,7 @@ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> return scores -class FlaxSuppressTokensAtBeginLogitsProcessor: +class FlaxSuppressTokensAtBeginLogitsProcessor(FlaxLogitsProcessor): r""" [`FlaxSuppressTokensAtBeginLogitsProcessor`] supresses a list of tokens as soon as the `generate` function starts generating using `begin_index` tokens. 
This should ensure that the tokens defined by `begin_suppress_tokens` are diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index d5346f1546a464..4e75cd936466e5 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -172,26 +172,32 @@ def _prepare_decoder_input_ids_for_generation( decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) return jnp.array(decoder_start_token_id).reshape(1, -1).repeat(batch_size, axis=0) - def _get_decoder_start_token_id( - self, - decoder_start_token_id: int = None, - bos_token_id: int = None, - ) -> int: + def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int: + # retrieve decoder_start_token_id for encoder-decoder models + # fall back to bos_token_id if necessary decoder_start_token_id = ( decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id ) bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id - if decoder_start_token_id is not None: - start_token = decoder_start_token_id + return decoder_start_token_id + elif ( + hasattr(self.config, "decoder") + and hasattr(self.config.decoder, "decoder_start_token_id") + and self.config.decoder.decoder_start_token_id is not None + ): + return self.config.decoder.decoder_start_token_id elif bos_token_id is not None: - start_token = bos_token_id - else: - raise ValueError( - "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." - ) - - return start_token + return bos_token_id + elif ( + hasattr(self.config, "decoder") + and hasattr(self.config.decoder, "bos_token_id") + and self.config.decoder.bos_token_id is not None + ): + return self.config.decoder.bos_token_id + raise ValueError( + "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation." + ) @staticmethod def _expand_to_num_beams(tensor, num_beams): @@ -295,7 +301,7 @@ def generate( Most of these parameters are explained in more detail in [this blog post](https://huggingface.co/blog/how-to-generate). - + Parameters: input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`): The sequence used as a prompt for the generation. @@ -324,12 +330,17 @@ def generate( Number of beams for beam search. 1 means no beam search. decoder_start_token_id (`int`, *optional*): If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token. - suppress_tokens (`list`, *optional*): - Token ids to not sample - begin_suppress_tokens (`list`, *optional*): - Token ids to not sample during the first sample step - forced_decoder_ids (`list`, *optional*): - Sample indices and token ids to force sampling at beginning of generation + suppress_tokens (`List[int]`, *optional*, defaults to model.config.suppress_tokens): + A list of tokens that will be supressed at generation. The `FlaxSupressTokensLogitsProcessor` will set + their log probs to `-inf` so that they are not sampled. + begin_suppress_tokens (`List[int]`, *optional*, defaults to model.config.begin_supress_tokens): + A list of tokens that will be supressed at the begining of the generation. The + `FlaxSuppressTokensAtBeginLogitsProcessor` will set their log probs to `-inf` so that they are + not sampled. 
+ forced_decoder_ids (`List[List[int]]`, *optional*, defaults to model.config.forced_decoder_ids): + A list of pairs of integers which indicates a mapping from generation indices to token indices that + will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always + be a token of index 123. trace (`bool`, *optional*, defaults to `True`): Whether to trace generation. Setting `trace=False` should only be used for debugging and will lead to a considerably slower runtime. @@ -339,9 +350,12 @@ def generate( Additional model specific kwargs will be forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*. Also accepts `encoder_outputs` to skip encoder part. + Return: [`~utils.ModelOutput`]. + Examples: + ```python >>> from transformers import AutoTokenizer, FlaxAutoModelForCausalLM From 77fce32a433239f3a37bf8108fc522fc598b9e6f Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 7 Dec 2022 10:28:09 -0800 Subject: [PATCH 031/111] reorganize --- src/transformers/generation/flax_utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 4e75cd936466e5..e4cdeec9f5d98e 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -355,7 +355,7 @@ def generate( [`~utils.ModelOutput`]. Examples: - + ```python >>> from transformers import AutoTokenizer, FlaxAutoModelForCausalLM @@ -379,6 +379,8 @@ def generate( decoder_start_token_id = ( decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id ) + if forced_decoder_ids is None and hasattr(self.config, "forced_decoder_ids"): + forced_decoder_ids = self.config.forced_decoder_ids prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0) if decoder_start_token_id is None and self.config.is_encoder_decoder: @@ -398,8 +400,6 @@ def generate( # add encoder_outputs to model_kwargs if model_kwargs.get("encoder_outputs") is None: model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, params, model_kwargs) - if forced_decoder_ids is None and hasattr(self.config, "forced_decoder_ids"): - forced_decoder_ids = self.config.forced_decoder_ids # prepare decoder_input_ids for generation input_ids = self._prepare_decoder_input_ids_for_generation( batch_size, @@ -458,6 +458,7 @@ def generate( input_ids_seq_length, suppress_tokens=suppress_tokens, begin_suppress_tokens=begin_suppress_tokens, + forced_decoder_ids=forced_decoder_ids, ) return self._greedy_search( input_ids, @@ -481,6 +482,7 @@ def generate( input_ids_seq_length, suppress_tokens=suppress_tokens, begin_suppress_tokens=begin_suppress_tokens, + forced_decoder_ids=forced_decoder_ids, ) return self._sample( input_ids, @@ -518,6 +520,7 @@ def generate( input_ids_seq_length, suppress_tokens=suppress_tokens, begin_suppress_tokens=begin_suppress_tokens, + forced_decoder_ids=forced_decoder_ids, ) return self._beam_search( @@ -595,8 +598,6 @@ def _get_logits_processor( begin_suppress_tokens = ( begin_suppress_tokens if begin_suppress_tokens is not None else self.config.begin_suppress_tokens ) - if forced_decoder_ids is None and hasattr(self.config, "forced_decoder_ids"): - forced_decoder_ids = self.config.forced_decoder_ids # the following idea is largely copied from this PR: 
https://github.com/huggingface/transformers/pull/5420/files # all samplers can be found in `generation_utils_samplers.py` From fefefdedd86aa18b2865dc4b0ce03acba3929be1 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 7 Dec 2022 10:35:06 -0800 Subject: [PATCH 032/111] formatting --- src/transformers/generation/flax_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index e4cdeec9f5d98e..1cd849151e02c3 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -334,7 +334,7 @@ def generate( A list of tokens that will be supressed at generation. The `FlaxSupressTokensLogitsProcessor` will set their log probs to `-inf` so that they are not sampled. begin_suppress_tokens (`List[int]`, *optional*, defaults to model.config.begin_supress_tokens): - A list of tokens that will be supressed at the begining of the generation. The + A list of tokens that will be supressed at the begining of the generation. The `FlaxSuppressTokensAtBeginLogitsProcessor` will set their log probs to `-inf` so that they are not sampled. forced_decoder_ids (`List[List[int]]`, *optional*, defaults to model.config.forced_decoder_ids): From 04ad6510544471f687a6a1096b033368cc0024c8 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 7 Dec 2022 11:03:32 -0800 Subject: [PATCH 033/111] cleanup docs --- .../generation/flax_logits_process.py | 18 +++++++++--------- src/transformers/generation/flax_utils.py | 7 +++---- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/transformers/generation/flax_logits_process.py b/src/transformers/generation/flax_logits_process.py index d3cd6f801fbfcc..b63f5aac37e220 100644 --- a/src/transformers/generation/flax_logits_process.py +++ b/src/transformers/generation/flax_logits_process.py @@ -269,9 +269,9 @@ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> class FlaxSuppressTokensAtBeginLogitsProcessor(FlaxLogitsProcessor): r""" - [`FlaxSuppressTokensAtBeginLogitsProcessor`] supresses a list of tokens as soon as the `generate` function starts - generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` are - not sampled at the begining of the generation. + [`FlaxLogitsProcessor`] supressing a list of tokens as soon as the `generate` function starts generating using + `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` are not sampled at the + begining of the generation. Args: begin_suppress_tokens (`List[int]`): @@ -294,8 +294,8 @@ def __call__(self, input_ids, scores, cur_len: int): class FlaxSuppressTokensLogitsProcessor(FlaxLogitsProcessor): r""" - [`FlaxSuppressTokensLogitsProcessor`] suppresses a list of tokens at each decoding step. The processor will - set their log probs to be `-inf` so they are not sampled. + [`FlaxLogitsProcessor`] suppressing a list of tokens at each decoding step. The processor will set their log probs + to be `-inf` so they are not sampled. Args: suppress_tokens (`list`): @@ -313,7 +313,7 @@ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> class FlaxForceTokenAtIdxLogitsProcessor(FlaxLogitsProcessor): r""" - [`FlaxForceTokenAtIdxLogitsProcessor`] forces a token to be sampled at an index. + [`FlaxLogitsProcessor`] forcing a token to be sampled at an index. 
Args: apply_idx (`int`): @@ -338,9 +338,9 @@ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> class FlaxForceTokensLogitsProcessor(FlaxLogitsProcessor): r""" - [`FlaxForceTokensLogitsProcessor`] takes a list of pairs of integers which indicates a mapping from generation - indices to token indices that will be forced before sampling. The processor will set their log probs to `inf` so - that they are sampled at their corresponding index. + [`FlaxLogitsProcessor`] that takes a list of pairs of integers which indicates a mapping from generation indices to + token indices that will be forced before sampling. The processor will set their log probs to `inf` so that they are + sampled at their corresponding index. Args: force_token_map (`list`): diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 1cd849151e02c3..c4d6b7c8a78bc6 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -128,8 +128,7 @@ class BeamSearchState: class FlaxGenerationMixin: """ A class containing all functions for auto-regressive text generation, to be used as a mixin in - [`FlaxPreTrainedModel`]. - The class exposes [`~generation.FlaxGenerationMixin.generate`], which can be used for: + [`FlaxPreTrainedModel`]. The class exposes [`~generation.FlaxGenerationMixin.generate`], which can be used for: - *greedy decoding* by calling [`~generation.FlaxGenerationMixin._greedy_search`] if `num_beams=1` and `do_sample=False`. - *multinomial sampling* by calling [`~generation.FlaxGenerationMixin._sample`] if `num_beams=1` and @@ -335,8 +334,8 @@ def generate( their log probs to `-inf` so that they are not sampled. begin_suppress_tokens (`List[int]`, *optional*, defaults to model.config.begin_supress_tokens): A list of tokens that will be supressed at the begining of the generation. The - `FlaxSuppressTokensAtBeginLogitsProcessor` will set their log probs to `-inf` so that they are - not sampled. + `FlaxSuppressTokensAtBeginLogitsProcessor` will set their log probs to `-inf` so that they are not + sampled. forced_decoder_ids (`List[List[int]]`, *optional*, defaults to model.config.forced_decoder_ids): A list of pairs of integers which indicates a mapping from generation indices to token indices that will be forced before sampling. 
For example, `[[1, 123]]` means the second generated token will always From 14e19c087b0ea5db04e954fd207ff4c423b80e45 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 7 Dec 2022 12:09:39 -0800 Subject: [PATCH 034/111] add tests --- .../whisper/test_modeling_flax_whisper.py | 207 +++++++++++++++++- 1 file changed, 204 insertions(+), 3 deletions(-) diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index 6a00be3dc3006f..e8aec29d82ed4d 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -17,16 +17,19 @@ import inspect import unittest -from datasets import load_dataset - from transformers import WhisperConfig, is_flax_available from transformers.testing_utils import require_flax, slow from transformers.utils import cached_property +from transformers.utils.import_utils import is_datasets_available from ...test_configuration_common import ConfigTester from ...test_modeling_flax_common import FlaxModelTesterMixin, floats_tensor +if is_datasets_available(): + import datasets + from datasets import load_dataset + if is_flax_available(): import numpy as np @@ -290,6 +293,86 @@ def test_tiny_logits_librispeech(self): # fmt: on self.assertTrue(np.allclose(logits[0][0, 0, :30], EXPECTED_LOGITS, atol=1e-4)) + def test_small_en_logits_librispeech(self): + model = FlaxWhisperModel.from_pretrained("openai/whisper-small.en", from_pt=True) + input_speech = self._load_datasamples(1) + feature_extractor = WhisperFeatureExtractor() + input_features = feature_extractor(input_speech, return_tensors="np").input_features + + logits = model( + input_features, + decoder_input_ids=np.array([model.config.decoder_start_token_id]), + output_hidden_states=False, + output_attentions=False, + return_dict=False, + ) + + logits = logits[0] @ model.params["model"]["decoder"]["embed_tokens"]["embedding"].T + + # fmt: off + EXPECTED_LOGITS = np.array( + [ + -3.6784, -7.7211, -9.5070, -11.9286, -7.6489, -9.7026, -5.6188, + -8.0104, -4.6238, -5.1833, -9.0485, -3.4079, -5.4874, -2.6935, + -6.3479, -7.3398, -6.9558, -7.6867, -7.4748, -8.3463, -9.9781, + -10.8389, -10.3105, -11.7201, -9.7261, -7.1590, -5.9272, -12.4509, + -11.1146, -8.1918 + ] + ) + # fmt: on + self.assertTrue(np.allclose(logits[0, 0, :30], EXPECTED_LOGITS, atol=1e-4)) + + def test_large_logits_librispeech(self): + model = FlaxWhisperModel.from_pretrained("openai/whisper-large", from_pt=True) + input_speech = self._load_datasamples(1) + processor = WhisperProcessor.from_pretrained("openai/whisper-large") + processed_inputs = processor( + audio=input_speech, text="This part of the speech", add_special_tokens=False, return_tensors="np" + ) + input_features = processed_inputs.input_features + decoder_input_ids = processed_inputs.labels + + logits = model( + input_features, + decoder_input_ids=decoder_input_ids, + output_hidden_states=False, + output_attentions=False, + return_dict=False, + ) + + logits = logits[0] @ model.params["model"]["decoder"]["embed_tokens"]["embedding"].T + + # fmt: off + EXPECTED_LOGITS = np.array( + [ + 2.1382, 0.9381, 4.4671, 3.5589, 2.4022, 3.8576, -0.6521, 2.5472, + 1.8301, 1.9957, 2.3432, 1.4678, 0.5459, 2.2597, 1.5179, 2.5357, + 1.1624, 0.6194, 1.0757, 1.8259, 2.4076, 1.6601, 2.3503, 1.3376, + 1.9891, 1.8635, 3.8931, 5.3699, 4.4772, 3.9184 + ] + ) + # fmt: on + self.assertTrue(np.allclose(logits[0, 0, :30], EXPECTED_LOGITS, atol=1e-4)) + + def test_tiny_en_generation(self): + processor = 
WhisperProcessor.from_pretrained("openai/whisper-tiny.en") + model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True) + model.config.decoder_start_token_id = 50257 + + input_speech = self._load_datasamples(1) + input_features = processor.feature_extractor( + raw_speech=input_speech, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="jax" + ).input_features + + generated_ids = model.generate(input_features, num_beams=5, max_length=20).sequences + transcript = processor.tokenizer.decode(generated_ids[0]) + + EXPECTED_TRANSCRIPT = ( + "<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. Quilter is the apostle of the middle" + " classes and we are glad" + ) + self.assertEqual(transcript, EXPECTED_TRANSCRIPT) + def test_tiny_generation(self): processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny", from_pt=True) @@ -299,7 +382,7 @@ def test_tiny_generation(self): raw_speech=input_speech, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="jax" ).input_features - generated_ids = model.generate(input_features, num_beams=5, max_length=20) + generated_ids = model.generate(input_features, num_beams=5, max_length=20).sequences transcript = processor.tokenizer.decode(generated_ids[0]) EXPECTED_TRANSCRIPT = ( @@ -307,3 +390,121 @@ def test_tiny_generation(self): " classes and we are glad" ) self.assertEqual(transcript, EXPECTED_TRANSCRIPT) + + def test_large_generation(self): + processor = WhisperProcessor.from_pretrained("openai/whisper-large") + model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-large", from_pt=True) + + input_speech = self._load_datasamples(1) + input_features = processor.feature_extractor( + raw_speech=input_speech, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="jax" + ).input_features + + model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe") + + generated_ids = model.generate(input_features, num_beams=5, max_length=20).sequences + transcript = processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True) + + EXPECTED_TRANSCRIPT = " Mr. 
Quilter is the apostle of the middle classes and we are glad" + self.assertEqual(transcript, EXPECTED_TRANSCRIPT) + + def test_large_generation_multilingual(self): + processor = WhisperProcessor.from_pretrained("openai/whisper-large") + model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-large", from_pt=True) + + ds = load_dataset("common_voice", "ja", split="test", streaming=True) + ds = ds.cast_column("audio", datasets.Audio(sampling_rate=16_000)) + input_speech = next(iter(ds))["audio"]["array"] + input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="np") + + model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe") + generated_ids = model.generate(input_features, do_sample=False, max_length=20).sequences + transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + + EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました" + self.assertEqual(transcript, EXPECTED_TRANSCRIPT) + + model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe") + generated_ids = model.generate( + input_features, + do_sample=False, + max_length=20, + ).sequences + transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + + EXPECTED_TRANSCRIPT = " Kimura-san called me." + self.assertEqual(transcript, EXPECTED_TRANSCRIPT) + + model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate") + generated_ids = model.generate(input_features, do_sample=False, max_length=20).sequences + transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + + EXPECTED_TRANSCRIPT = " I borrowed a phone from Kimura san" + self.assertEqual(transcript, EXPECTED_TRANSCRIPT) + + def test_large_batched_generation(self): + processor = WhisperProcessor.from_pretrained("openai/whisper-large") + model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-large", from_pt=True) + + input_speech = self._load_datasamples(4) + input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="np").input_features + generated_ids = model.generate(input_features, max_length=20).sequences + + # fmt: off + EXPECTED_LOGITS = np.array( + [ + [50258, 50358, 50363, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 293, 321, 366, 5404, 281], + [50258, 50358, 50363, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50257, 50257], + [50258, 50358, 50363, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256], + [50258, 50358, 50363, 634, 575, 12525, 22618, 1968, 6144, 35617, 20084, 1756, 311, 589, 307, 534, 10281, 934, 439, 11] + ] + ) + # fmt: on + + self.assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS)) + + # fmt: off + EXPECTED_TRANSCRIPT = [ + " Mr. Quilter is the apostle of the middle classes and we are glad to", + " Nor is Mr. 
Quilter's manner less interesting than his matter.", + " He tells us that at this festive season of the year, with Christmas and roast beef", + " He has grave doubts whether Sir Frederick Layton's work is really Greek after all,", + ] + # fmt: on + + transcript = processor.batch_decode(generated_ids, skip_special_tokens=True) + self.assertListEqual(transcript, EXPECTED_TRANSCRIPT) + + def test_tiny_en_batched_generation(self): + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") + model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True) + + input_speech = self._load_datasamples(4) + input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="np").input_features + generated_ids = model.generate(input_features, max_length=20).sequences + + # fmt: off + EXPECTED_LOGITS = np.array( + [ + [50257, 50362, 1770, 13, 2264, 346, 353, 318, 262, 46329, 286, 262, 3504, 6097, 11, 290, 356, 389, 9675, 284], + [50257, 50362, 5414, 318, 1770, 13, 2264, 346, 353, 338, 5642, 1342, 3499, 621, 465, 2300, 13, 50256, 50256, 50256], + [50257, 50362, 679, 4952, 514, 326, 379, 428, 43856, 1622, 286, 262, 614, 11, 351, 6786, 290, 32595, 12023, 28236], + [50257, 50362, 679, 468, 12296, 17188, 1771, 7361, 26113, 18881, 1122, 338, 670, 318, 1107, 8312, 706, 477, 290, 460] + ] + + ) + # fmt: on + + self.assertTrue(np.allclose(generated_ids, EXPECTED_LOGITS)) + + # fmt: off + EXPECTED_TRANSCRIPT = [ + " Mr. Quilter is the apostle of the middle classes, and we are glad to", + " Nor is Mr. Quilter's manner less interesting than his matter.", + " He tells us that at this festive season of the year, with Christmas and roast beef looming", + " He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can", + ] + # fmt: on + + transcript = processor.batch_decode(generated_ids, skip_special_tokens=True) + self.assertListEqual(transcript, EXPECTED_TRANSCRIPT) From cf67b3816d270739cbd0e23736b3af775dfddd20 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 7 Dec 2022 15:54:35 -0800 Subject: [PATCH 035/111] handle empty list case --- src/transformers/generation/flax_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index c4d6b7c8a78bc6..2165e884e1004d 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -611,7 +611,7 @@ def _get_logits_processor( if begin_suppress_tokens is not None: begin_index = input_ids_seq_length begin_index = begin_index if (input_ids_seq_length > 1 or forced_bos_token_id is None) else begin_index + 1 - if forced_decoder_ids is not None: + if forced_decoder_ids is not None and len(forced_decoder_ids) > 0: begin_index += forced_decoder_ids[-1][0] # generation starts after the last token that is forced processors.append(FlaxSuppressTokensAtBeginLogitsProcessor(begin_suppress_tokens, begin_index)) if forced_decoder_ids is not None: From 3de7509e861d8c58accf3764a86cc958e8a50097 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 7 Dec 2022 16:10:05 -0800 Subject: [PATCH 036/111] fix forced decoder ids in flax tests --- tests/models/whisper/test_modeling_flax_whisper.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index e8aec29d82ed4d..00b5419397f12e 100644 --- 
a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -400,7 +400,8 @@ def test_large_generation(self): raw_speech=input_speech, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="jax" ).input_features - model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe") + prompt_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe") + model.config.forced_decoder_ids = [[i[0] - 1, i[1]] for i in prompt_ids[1:]] generated_ids = model.generate(input_features, num_beams=5, max_length=20).sequences transcript = processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True) @@ -417,14 +418,16 @@ def test_large_generation_multilingual(self): input_speech = next(iter(ds))["audio"]["array"] input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="np") - model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe") + prompt_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe") + model.config.forced_decoder_ids = [[i[0] - 1, i[1]] for i in prompt_ids[1:]] generated_ids = model.generate(input_features, do_sample=False, max_length=20).sequences transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました" self.assertEqual(transcript, EXPECTED_TRANSCRIPT) - model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe") + prompt_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe") + model.config.forced_decoder_ids = [[i[0] - 1, i[1]] for i in prompt_ids[1:]] generated_ids = model.generate( input_features, do_sample=False, @@ -435,7 +438,8 @@ def test_large_generation_multilingual(self): EXPECTED_TRANSCRIPT = " Kimura-san called me." 
self.assertEqual(transcript, EXPECTED_TRANSCRIPT) - model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate") + prompt_ids = processor.get_decoder_prompt_ids(language="ja", task="translate") + model.config.forced_decoder_ids = [[i[0] - 1, i[1]] for i in prompt_ids[1:]] generated_ids = model.generate(input_features, do_sample=False, max_length=20).sequences transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] From 5e2256a324ac0081af0fd636582365d0a8ea2170 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Mon, 12 Dec 2022 13:59:53 -0800 Subject: [PATCH 037/111] add flax whisper to inits --- src/transformers/__init__.py | 4 ++++ src/transformers/models/auto/__init__.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 8955e7b922ad39..f8b92f0f1999bc 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3092,6 +3092,7 @@ "FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING", "FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", "FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", + "FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING", "FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", "FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING", "FLAX_MODEL_MAPPING", @@ -3105,6 +3106,7 @@ "FlaxAutoModelForQuestionAnswering", "FlaxAutoModelForSeq2SeqLM", "FlaxAutoModelForSequenceClassification", + "FlaxAutoModelForSpeechSeq2Seq", "FlaxAutoModelForTokenClassification", "FlaxAutoModelForVision2Seq", ] @@ -5832,6 +5834,7 @@ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, + FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING, FLAX_MODEL_MAPPING, @@ -5845,6 +5848,7 @@ FlaxAutoModelForQuestionAnswering, FlaxAutoModelForSeq2SeqLM, FlaxAutoModelForSequenceClassification, + FlaxAutoModelForSpeechSeq2Seq, FlaxAutoModelForTokenClassification, FlaxAutoModelForVision2Seq, ) diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index a6ee30366b3915..74639d73f8f534 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -165,6 +165,7 @@ "FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING", "FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", "FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", + "FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING", "FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", "FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING", "FLAX_MODEL_MAPPING", @@ -178,6 +179,7 @@ "FlaxAutoModelForQuestionAnswering", "FlaxAutoModelForSeq2SeqLM", "FlaxAutoModelForSequenceClassification", + "FlaxAutoModelForSpeechSeq2Seq", "FlaxAutoModelForTokenClassification", "FlaxAutoModelForVision2Seq", ] @@ -320,6 +322,7 @@ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, + FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING, FLAX_MODEL_MAPPING, @@ -333,6 +336,7 @@ FlaxAutoModelForQuestionAnswering, FlaxAutoModelForSeq2SeqLM, FlaxAutoModelForSequenceClassification, + FlaxAutoModelForSpeechSeq2Seq, FlaxAutoModelForTokenClassification, FlaxAutoModelForVision2Seq, ) From 669db4ebc5c39fc36e988e274dd4043f4078d929 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Mon, 12 Dec 2022 14:44:53 -0800 Subject: [PATCH 038/111] upate dummy objects --- 
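A quick usage sketch of how the new `FlaxAutoModelForSpeechSeq2Seq` class is expected to resolve to `FlaxWhisperForConditionalGeneration`, assuming the `FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING` entry from the previous patch registers the Whisper config. The checkpoint, dataset, and decoding parameters below simply mirror the slow integration tests added earlier in this series; none of this is part of the diff itself.

```python
from datasets import load_dataset

from transformers import FlaxAutoModelForSpeechSeq2Seq, WhisperProcessor

# Assumes the speech seq2seq mapping resolves WhisperConfig to
# FlaxWhisperForConditionalGeneration; weights come from the PyTorch checkpoint
# since no Flax weights have been published for Whisper yet.
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = FlaxAutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny", from_pt=True)

# One short LibriSpeech sample (the same dummy dataset the slow tests rely on).
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
input_features = processor(
    audio=ds[0]["audio"]["array"],
    sampling_rate=processor.feature_extractor.sampling_rate,
    return_tensors="np",
).input_features

generated_ids = model.generate(input_features, num_beams=5, max_length=20).sequences
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(transcription)
```

The dummy entries added below keep this import path importable in environments without Flax installed; the usual backend error is only raised if the class is actually instantiated.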
src/transformers/utils/dummy_flax_objects.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index 36f1932ba6b75e..95a81b9e869940 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -163,6 +163,9 @@ def __init__(self, *args, **kwargs): FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = None +FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = None + + FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = None @@ -242,6 +245,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) +class FlaxAutoModelForSpeechSeq2Seq(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxAutoModelForTokenClassification(metaclass=DummyObject): _backends = ["flax"] From bea6cf0e572cdae694efc3fad98132adc09d78cd Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Mon, 12 Dec 2022 14:52:56 -0800 Subject: [PATCH 039/111] docs for FlaxAutoModelForSpeechSeq2Seq --- docs/source/en/model_doc/auto.mdx | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/en/model_doc/auto.mdx b/docs/source/en/model_doc/auto.mdx index 7957f453a2fb86..01551f4f2f45bf 100644 --- a/docs/source/en/model_doc/auto.mdx +++ b/docs/source/en/model_doc/auto.mdx @@ -282,6 +282,10 @@ The following auto classes are available for the following audio tasks. [[autodoc]] TFAutoModelForSpeechSeq2Seq +### FlaxAutoModelForSpeechSeq2Seq + +[[autodoc]] FlaxAutoModelForSpeechSeq2Seq + ### AutoModelForAudioXVector [[autodoc]] AutoModelForAudioXVector From e4270b41b25f7b1a5c6f8afeb61e2341199060a0 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 14 Dec 2022 13:39:37 -0800 Subject: [PATCH 040/111] fix decoder_position_ids computation in pretrained model decode/__call__ fns --- .../models/whisper/modeling_flax_whisper.py | 43 +++++++++++-------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index e984153316fce9..be821d598f9590 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -962,16 +962,19 @@ def decode( encoder_hidden_states = encoder_outputs[0] batch_size, sequence_length = decoder_input_ids.shape - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length)) - if decoder_position_ids is None: if past_key_values is not None: raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) + if decoder_attention_mask is not None: + decoder_position_ids = decoder_attention_mask.cumsum(-1) - 1 + else: + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length)) # Handle any PRNG if needed rngs = {} @@ -1047,13 +1050,17 @@ def __call__( return_dict = return_dict if return_dict is not None else self.config.return_dict # prepare decoder inputs + if decoder_position_ids is None: + if decoder_attention_mask is not None: + decoder_position_ids = decoder_attention_mask.cumsum(-1) - 1 + else: + batch_size, sequence_length = decoder_input_ids.shape + 
+ decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) if decoder_attention_mask is None: decoder_attention_mask = jnp.ones_like(decoder_input_ids) - if decoder_position_ids is None: - batch_size, sequence_length = decoder_input_ids.shape - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) # Handle any PRNG if needed rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} @@ -1206,16 +1213,18 @@ def decode( encoder_hidden_states = encoder_outputs[0] batch_size, sequence_length = decoder_input_ids.shape - if decoder_attention_mask is None: - decoder_attention_mask = jnp.ones((batch_size, sequence_length)) - if decoder_position_ids is None: if past_key_values is not None: raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") - decoder_position_ids = jnp.broadcast_to( - jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) - ) + if decoder_attention_mask is not None: + decoder_position_ids = decoder_attention_mask.cumsum(-1) - 1 + else: + decoder_position_ids = jnp.broadcast_to( + jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) + ) + if decoder_attention_mask is None: + decoder_attention_mask = jnp.ones((batch_size, sequence_length), dtype="i4") # Handle any PRNG if needed rngs = {} From 135b6347a54860f217e76dab5e55e5abb4734502 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Thu, 15 Dec 2022 15:50:25 -0800 Subject: [PATCH 041/111] add Copied from statements as necessary --- .../models/whisper/modeling_flax_whisper.py | 131 +++++++++++------- 1 file changed, 79 insertions(+), 52 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index be821d598f9590..ed831579e52f35 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -14,7 +14,7 @@ # limitations under the License. """ Flax whisper model.""" - +import random from functools import partial from typing import Optional, Tuple @@ -327,6 +327,7 @@ def _concatenate_to_cache(self, key, value, query, attention_mask) -> Tuple[jnp. 
return key, value, attention_mask +# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartEncoderLayer with MBart->Whisper class FlaxWhisperEncoderLayer(nn.Module): config: WhisperConfig dtype: jnp.dtype = jnp.float32 @@ -335,36 +336,35 @@ def setup(self) -> None: self.embed_dim = self.config.d_model self.self_attn = FlaxWhisperAttention( config=self.config, - num_heads=self.config.encoder_attention_heads, embed_dim=self.embed_dim, - causal=False, + num_heads=self.config.encoder_attention_heads, + dropout=self.config.attention_dropout, dtype=self.dtype, ) + self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) self.dropout_layer = nn.Dropout(rate=self.config.dropout) - self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) self.activation_fn = ACT2FN[self.config.activation_function] - self.self_attn_layer_norm = nn.LayerNorm(epsilon=1e-05) + self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) self.fc1 = nn.Dense( self.config.encoder_ffn_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.fc2 = nn.Dense( - self.embed_dim, - dtype=self.dtype, - kernel_init=jax.nn.initializers.normal(self.config.init_std), + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) ) self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) def __call__( self, hidden_states: jnp.ndarray, - output_attentions: bool = False, + attention_mask: jnp.ndarray, + output_attentions: bool = True, deterministic: bool = True, ) -> Tuple[jnp.ndarray]: residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) - hidden_states, attn_weights = self.self_attn(hidden_states) + hidden_states, attn_weights = self.self_attn(hidden_states=hidden_states, attention_mask=attention_mask) hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) hidden_states = residual + hidden_states @@ -373,6 +373,7 @@ def __call__( hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) hidden_states = residual + hidden_states outputs = (hidden_states,) @@ -383,35 +384,44 @@ def __call__( return outputs +# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartEncoderLayerCollection with MBart->Whisper class FlaxWhisperEncoderLayerCollection(nn.Module): config: WhisperConfig - dtype: jnp.dtype = jnp.float32 + dtype: jnp.dtype = jnp.float32 # the dtype of the computation - def setup(self) -> None: + def setup(self): self.layers = [ FlaxWhisperEncoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.encoder_layers) ] + self.layerdrop = self.config.encoder_layerdrop def __call__( self, - hidden_states: jnp.ndarray, + hidden_states, + attention_mask, deterministic: bool = True, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, - ) -> Tuple[jnp.ndarray]: + ): all_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None for encoder_layer in self.layers: if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - layer_outputs = encoder_layer( - hidden_states, - output_attentions=output_attentions, - deterministic=deterministic, - ) + # add LayerDrop (see 
https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if not deterministic and (dropout_probability < self.layerdrop): # skip the layer + layer_outputs = (None, None) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions, + deterministic, + ) hidden_states = layer_outputs[0] if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) @@ -429,6 +439,7 @@ def __call__( ) +# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayer with MBart->Whisper class FlaxWhisperDecoderLayer(nn.Module): config: WhisperConfig dtype: jnp.dtype = jnp.float32 @@ -437,101 +448,110 @@ def setup(self) -> None: self.embed_dim = self.config.d_model self.self_attn = FlaxWhisperAttention( config=self.config, - num_heads=self.config.decoder_attention_heads, embed_dim=self.embed_dim, + num_heads=self.config.decoder_attention_heads, + dropout=self.config.attention_dropout, causal=True, + dtype=self.dtype, ) - self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) + self.dropout_layer = nn.Dropout(rate=self.config.dropout) + self.activation_fn = ACT2FN[self.config.activation_function] + self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) + self.self_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) self.encoder_attn = FlaxWhisperAttention( config=self.config, - num_heads=self.config.decoder_attention_heads, embed_dim=self.embed_dim, - causal=False, + num_heads=self.config.decoder_attention_heads, + dropout=self.config.attention_dropout, dtype=self.dtype, ) self.encoder_attn_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) - self.dropout_layer = nn.Dropout(rate=self.config.dropout) - self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout) - self.activation_fn = ACT2FN[self.config.activation_function] self.fc1 = nn.Dense( self.config.decoder_ffn_dim, - kernel_init=jax.nn.initializers.normal(self.config.init_std), dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.init_std), ) self.fc2 = nn.Dense( - self.embed_dim, - kernel_init=jax.nn.initializers.normal(self.config.init_std), - dtype=self.dtype, + self.embed_dim, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std) ) self.final_layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05) def __call__( self, hidden_states: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, + attention_mask: jnp.ndarray, encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, init_cache: bool = False, output_attentions: bool = True, deterministic: bool = True, ) -> Tuple[jnp.ndarray]: residual = hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention hidden_states, self_attn_weights = self.self_attn( - hidden_states, - attention_mask=attention_mask, - init_cache=init_cache, + hidden_states=hidden_states, attention_mask=attention_mask, init_cache=init_cache ) hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) hidden_states = residual + hidden_states + # Cross-Attention Block cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) hidden_states, cross_attn_weights = self.encoder_attn( - hidden_states, - encoder_hidden_states, + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + 
attention_mask=encoder_attention_mask, ) hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) hidden_states = residual + hidden_states + # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = self.activation_dropout_layer(hidden_states, deterministic=deterministic) hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) hidden_states = residual + hidden_states outputs = (hidden_states,) + if output_attentions: outputs += (self_attn_weights, cross_attn_weights) return outputs +# Copied from transformers.models.mbart.modeling_flax_mbart.FlaxMBartDecoderLayerCollection with MBart->Whisper class FlaxWhisperDecoderLayerCollection(nn.Module): config: WhisperConfig - dtype: jnp.dtype = jnp.float32 + dtype: jnp.dtype = jnp.float32 # the dtype of the computation - def setup(self) -> None: + def setup(self): self.layers = [ FlaxWhisperDecoderLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.decoder_layers) ] + self.layerdrop = self.config.decoder_layerdrop def __call__( self, - hidden_states: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, + hidden_states, + attention_mask, encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + deterministic: bool = True, init_cache: bool = False, - output_hidden_states: bool = False, output_attentions: bool = False, - deterministic: bool = True, + output_hidden_states: bool = False, return_dict: bool = True, - ) -> Tuple[jnp.ndarray]: + ): # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None @@ -540,14 +560,21 @@ def __call__( for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = decoder_layer( - hidden_states, - attention_mask, - encoder_hidden_states, - init_cache=init_cache, - output_attentions=output_attentions, - deterministic=deterministic, - ) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = random.uniform(0, 1) + if not deterministic and (dropout_probability < self.layerdrop): + layer_outputs = (None, None, None) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + output_attentions=output_attentions, + deterministic=deterministic, + ) + hidden_states = layer_outputs[0] if output_attentions: all_self_attns += (layer_outputs[1],) From 21fe767cd83ba5b594da76eaecfd5f1e9266693d Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Fri, 16 Dec 2022 15:30:22 -0700 Subject: [PATCH 042/111] compute position_ids only in __call__ and decode methods of pretrained model subclasses --- .../models/whisper/modeling_flax_whisper.py | 43 ++++++++++--------- .../whisper/test_modeling_flax_whisper.py | 3 -- 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index ed831579e52f35..4972ff805f76d5 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -690,25 +690,19 @@ def setup(self) -> None: def __call__( self, input_ids: jnp.ndarray, - 
encoder_hidden_states: jnp.ndarray, - attention_mask: Optional[jnp.ndarray] = None, - position_ids: Optional[jnp.ndarray] = None, + attention_mask: jnp.ndarray, + position_ids: jnp.ndarray, + encoder_hidden_states: Optional[jnp.ndarray] = None, init_cache: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, deterministic: bool = True, ) -> Tuple[jnp.ndarray]: - if attention_mask is not None: - if position_ids is None: - position_ids = attention_mask.cumsum(-1) - 1 - if position_ids is None: - batch_size, sequence_length = input_ids.shape - position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) + input_embeds = self.embed_tokens(input_ids) + position_embeds = self.embed_positions(position_ids) - pos_emb = self.embed_positions(position_ids) - - hidden_states = self.embed_tokens(input_ids) + pos_emb + hidden_states = input_embeds + position_embeds hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) outputs = self.layers( @@ -722,10 +716,18 @@ def __call__( return_dict=return_dict, ) - hidden_states = self.layer_norm(outputs[0]) + last_hidden_states = outputs[0] + last_hidden_states = self.layer_norm(last_hidden_states) + + # update the last element in `hidden_states` after applying `layernorm` above + hidden_states = None + if output_hidden_states: + hidden_states = outputs[1] + hidden_states = hidden_states[:-1] + (last_hidden_states,) if not return_dict: - return (hidden_states,) + outputs[1:] + outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) + return tuple(v for v in outputs if v is not None) return FlaxBaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, @@ -745,10 +747,10 @@ def setup(self) -> None: def __call__( self, - input_features, - decoder_input_ids, - decoder_attention_mask: Optional[jnp.ndarray] = None, - decoder_position_ids: Optional[jnp.ndarray] = None, + input_features: jnp.ndarray, + decoder_input_ids: jnp.ndarray, + decoder_attention_mask: jnp.ndarray, + decoder_position_ids: jnp.ndarray, output_attentions: bool = False, output_hidden_states: bool = False, return_dict: bool = True, @@ -931,7 +933,7 @@ def _encoder_forward(module, input_features, **kwargs): return self.module.apply( {"params": params or self.params}, - input_features=input_features, + input_features=jnp.array(input_features, dtype="f4"), output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -1082,7 +1084,6 @@ def __call__( decoder_position_ids = decoder_attention_mask.cumsum(-1) - 1 else: batch_size, sequence_length = decoder_input_ids.shape - decoder_position_ids = jnp.broadcast_to( jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) ) @@ -1094,7 +1095,7 @@ def __call__( return self.module.apply( {"params": params or self.params}, - input_features=input_features, + input_features=jnp.array(input_features, dtype="f4"), decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index 00b5419397f12e..813ed5e18cf113 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -250,9 +250,6 @@ def _long_tensor(tok_lst): return 
np.array(tok_lst, dtype=np.int32) -TOLERANCE = 1e-4 - - @slow @require_flax class FlaxWhisperModelIntegrationTest(unittest.TestCase): From a9016748940419157c05b35e173745b417595c3e Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Fri, 16 Dec 2022 15:46:55 -0700 Subject: [PATCH 043/111] improve readabilityof compute positional embeddings --- src/transformers/models/whisper/modeling_flax_whisper.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 4972ff805f76d5..6bd772b0156bf6 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -649,7 +649,8 @@ def __call__( f" {self.config.d_model}))" ) - hidden_states = hidden_states + self.embed_positions(jnp.arange(self.config.max_source_positions)) + embed_positions = self.embed_positions(jnp.arange(self.config.max_source_positions)) + hidden_states = hidden_states + embed_positions hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic) From f8d4686a71c6bb032a52fe54ada53be68faefe50 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Fri, 16 Dec 2022 15:53:18 -0700 Subject: [PATCH 044/111] check dimensionality of input_features instead of hidden_states --- .../models/whisper/modeling_flax_whisper.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 6bd772b0156bf6..892ea777d9a9d1 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -638,17 +638,17 @@ def __call__( return_dict: bool = True, deterministic: bool = True, ) -> Tuple[jnp.ndarray]: + if input_features.shape[1:] != (self.config.num_mel_bins, self.config.max_source_positions * 2): + raise ValueError( + "input_features.shape[1:], must be equal to (self.config.num_mel_bins, self.config.max_source_positions * 2) (got" + f" {input_features.shape[1:]}, but should be ({self.config.num_mel_bins}," + f" {self.config.max_source_positions * 2}))" + ) + input_features = input_features.transpose(0, 2, 1) hidden_states = jax.nn.gelu(self.conv1(input_features), approximate=False) hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False) - if hidden_states.shape[1:] != (self.config.max_source_positions, self.config.d_model): - raise ValueError( - "hidden_states.shape[1:] must be equal to (self.config.max_source_positions, self.config.d_model)(got" - f" {hidden_states.shape[1:]}, but should be ({self.config.max_source_positions}," - f" {self.config.d_model}))" - ) - embed_positions = self.embed_positions(jnp.arange(self.config.max_source_positions)) hidden_states = hidden_states + embed_positions From b40761119ae71c001af16b4c04fe441ef617a144 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Fri, 16 Dec 2022 15:55:06 -0700 Subject: [PATCH 045/111] copied from statement for init_cache --- src/transformers/models/whisper/modeling_flax_whisper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 892ea777d9a9d1..544f613a773a66 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -845,6 +845,7 @@ def init_weights(self, rng: 
jax.random.PRNGKey, input_shape: Tuple, params: Froz else: return random_params + # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartPreTrainedModel.init_cache with Bart->Whisper def init_cache(self, batch_size, max_length, encoder_outputs): r""" Args: From 8e78c86ae66d569e0b49afe0b12ab92641634650 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Fri, 16 Dec 2022 16:01:57 -0700 Subject: [PATCH 046/111] formatting --- src/transformers/models/whisper/modeling_flax_whisper.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 544f613a773a66..e2795f8311816c 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -640,9 +640,9 @@ def __call__( ) -> Tuple[jnp.ndarray]: if input_features.shape[1:] != (self.config.num_mel_bins, self.config.max_source_positions * 2): raise ValueError( - "input_features.shape[1:], must be equal to (self.config.num_mel_bins, self.config.max_source_positions * 2) (got" - f" {input_features.shape[1:]}, but should be ({self.config.num_mel_bins}," - f" {self.config.max_source_positions * 2}))" + "input_features.shape[1:], must be equal to (self.config.num_mel_bins," + f" self.config.max_source_positions * 2) (got {input_features.shape[1:]}, but should be" + f" ({self.config.num_mel_bins}, {self.config.max_source_positions * 2}))" ) input_features = input_features.transpose(0, 2, 1) From 810358cbf099a2b45736117843d993f5f24735bb Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Fri, 16 Dec 2022 16:24:42 -0700 Subject: [PATCH 047/111] fix copies --- src/transformers/models/whisper/modeling_flax_whisper.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index e2795f8311816c..e58e6255c88991 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -870,9 +870,9 @@ def init_cache(self, batch_size, max_length, encoder_outputs): def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): decoder_module = module._get_decoder_module() return decoder_module( - input_ids=decoder_input_ids, - attention_mask=decoder_attention_mask, - position_ids=decoder_position_ids, + decoder_input_ids, + decoder_attention_mask, + decoder_position_ids, **kwargs, ) From b06a6baf9129adafd9b18c7cc012014e073da488 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Fri, 16 Dec 2022 16:30:50 -0700 Subject: [PATCH 048/111] fix copies --- src/transformers/models/whisper/modeling_flax_whisper.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index e58e6255c88991..5b0ed4d5b34527 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -885,7 +885,6 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_ init_cache=True, method=_decoder_forward, # we only need to call the decoder to init the cache ) - return unfreeze(init_variables["cache"]) @add_start_docstrings(WHISPER_ENCODE_INPUTS_DOCSTRING) From 45efd60da3258585d990c6b3e75b7514cd34f8c6 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 
21 Dec 2022 11:42:14 -0800 Subject: [PATCH 049/111] pass attention mask to encoder layers --- src/transformers/models/whisper/modeling_flax_whisper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 5b0ed4d5b34527..0612c83cf87d19 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -656,6 +656,7 @@ def __call__( outputs = self.layers( hidden_states, + attention_mask=None, deterministic=deterministic, output_attentions=output_attentions, output_hidden_states=output_hidden_states, From 718f53bc839b0efe1741c2ce056945a931643f23 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 21 Dec 2022 12:30:37 -0800 Subject: [PATCH 050/111] fix decoder module outputs --- src/transformers/models/whisper/modeling_flax_whisper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 0612c83cf87d19..3bbb42f3ddb373 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -732,8 +732,8 @@ def __call__( return tuple(v for v in outputs if v is not None) return FlaxBaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - hidden_states=outputs.hidden_states, + last_hidden_state=last_hidden_states, + hidden_states=hidden_states, attentions=outputs.attentions, cross_attentions=outputs.cross_attentions, ) From 07a24a8a8baf5bd40b08543770abc660848f095c Mon Sep 17 00:00:00 2001 From: Andy Ehrenberg <32784181+andyehrenberg@users.noreply.github.com> Date: Thu, 22 Dec 2022 11:06:20 -0800 Subject: [PATCH 051/111] set dtype Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> --- src/transformers/generation/flax_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 2165e884e1004d..38fb8c4215a7fb 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -169,7 +169,7 @@ def _prepare_decoder_input_ids_for_generation( if decoder_input_ids is not None: return decoder_input_ids decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id) - return jnp.array(decoder_start_token_id).reshape(1, -1).repeat(batch_size, axis=0) + return jnp.array(decoder_start_token_id, dtype="i4").reshape(1, -1).repeat(batch_size, axis=0) def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_token_id: int = None) -> int: # retrieve decoder_start_token_id for encoder-decoder models From 43c4ed83654135f369ad24416b0bfea5074e6f0e Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Thu, 22 Dec 2022 12:51:21 -0800 Subject: [PATCH 052/111] smaller flax model for whisper test --- .../whisper/test_modeling_flax_whisper.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index 813ed5e18cf113..cdfcc72a13cf95 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -51,25 +51,25 @@ class FlaxWhisperModelTester: def __init__( self, parent, - batch_size=1, - seq_length=3000, 
+ batch_size=13, + seq_length=60, is_training=True, use_labels=False, vocab_size=99, - d_model=384, - decoder_attention_heads=6, - decoder_ffn_dim=1536, - decoder_layers=4, - encoder_attention_heads=6, - encoder_ffn_dim=1536, - encoder_layers=4, + d_model=16, + decoder_attention_heads=4, + decoder_ffn_dim=16, + decoder_layers=2, + encoder_attention_heads=4, + encoder_ffn_dim=16, + encoder_layers=2, input_channels=1, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, - max_position_embeddings=20, - max_source_positions=1500, - max_target_positions=448, + max_position_embeddings=70, + max_source_positions=30, + max_target_positions=40, bos_token_id=98, eos_token_id=98, pad_token_id=0, From 7b359078725c541d5768bf56cce229489df64ae9 Mon Sep 17 00:00:00 2001 From: Andy Ehrenberg <32784181+andyehrenberg@users.noreply.github.com> Date: Sat, 31 Dec 2022 11:10:09 -0500 Subject: [PATCH 053/111] Update src/transformers/generation/flax_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/generation/flax_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 38fb8c4215a7fb..d78babf74f0682 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -329,14 +329,14 @@ def generate( Number of beams for beam search. 1 means no beam search. decoder_start_token_id (`int`, *optional*): If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token. - suppress_tokens (`List[int]`, *optional*, defaults to model.config.suppress_tokens): + suppress_tokens (`List[int]`, *optional*, defaults to `model.config.suppress_tokens`): A list of tokens that will be supressed at generation. The `FlaxSupressTokensLogitsProcessor` will set their log probs to `-inf` so that they are not sampled. - begin_suppress_tokens (`List[int]`, *optional*, defaults to model.config.begin_supress_tokens): + begin_suppress_tokens (`List[int]`, *optional*, defaults to `model.config.begin_supress_tokens`): A list of tokens that will be supressed at the begining of the generation. The `FlaxSuppressTokensAtBeginLogitsProcessor` will set their log probs to `-inf` so that they are not sampled. - forced_decoder_ids (`List[List[int]]`, *optional*, defaults to model.config.forced_decoder_ids): + forced_decoder_ids (`List[List[int]]`, *optional*, defaults to `model.config.forced_decoder_ids`): A list of pairs of integers which indicates a mapping from generation indices to token indices that will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always be a token of index 123. 
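
The three generation arguments documented above can be exercised end to end with the Flax port. A minimal sketch, assuming the `openai/whisper-tiny.en` checkpoint (converted from the PyTorch weights via `from_pt=True`) and zero-valued placeholder features in place of real audio; illustration only, not part of the committed diff:

# Sketch: where the new suppress_tokens / begin_suppress_tokens / forced_decoder_ids
# arguments plug into generate(). The checkpoint name and the placeholder features are
# assumptions for the example; real features come from the WhisperProcessor.
import numpy as np

from transformers import FlaxWhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
# from_pt=True converts the PyTorch checkpoint on the fly (requires torch installed);
# drop it once Flax weights are published for the checkpoint.
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)

# Placeholder log-mel features with shape (batch, num_mel_bins, 2 * max_source_positions);
# in practice: processor(audio, sampling_rate=16000, return_tensors="np").input_features
input_features = np.zeros((1, 80, 3000), dtype=np.float32)

generated = model.generate(
    input_features,
    max_length=20,
    suppress_tokens=model.config.suppress_tokens,
    begin_suppress_tokens=model.config.begin_suppress_tokens,
    forced_decoder_ids=model.config.forced_decoder_ids,
)
print(processor.batch_decode(generated.sequences, skip_special_tokens=True))

Passing the config values explicitly is redundant (they are the documented defaults); the call only shows where the new keyword arguments are accepted.
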
From 8a4d990f2891bdb04a9a07489fe902929044fdb1 Mon Sep 17 00:00:00 2001 From: Andy Ehrenberg <32784181+andyehrenberg@users.noreply.github.com> Date: Sat, 31 Dec 2022 11:10:42 -0500 Subject: [PATCH 054/111] Update src/transformers/models/whisper/modeling_flax_whisper.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/models/whisper/modeling_flax_whisper.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 3bbb42f3ddb373..aa4453f2bf9fe3 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -156,8 +156,7 @@ decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, config.max_position_embeddings - 1]`. - past_key_values (`Dict[str, numpy.ndarray]`, *optional*, returned by `init_cache` or when passing previous - `past_key_values`): + past_key_values (`Dict[str, numpy.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. output_attentions (`bool`, *optional*): From 17c22fec5aaaa430de0817599745d9842b4cad15 Mon Sep 17 00:00:00 2001 From: Andy Ehrenberg <32784181+andyehrenberg@users.noreply.github.com> Date: Sat, 31 Dec 2022 11:11:16 -0500 Subject: [PATCH 055/111] Update tests/models/whisper/test_modeling_flax_whisper.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- tests/models/whisper/test_modeling_flax_whisper.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index cdfcc72a13cf95..51099ece2536e5 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -173,14 +173,7 @@ def prepare_whisper_inputs_dict( @require_flax class FlaxWhisperModelTest(FlaxModelTesterMixin, unittest.TestCase): - all_model_classes = ( - ( - FlaxWhisperForConditionalGeneration, - FlaxWhisperModel, - ) - if is_flax_available() - else () - ) + all_model_classes = (FlaxWhisperForConditionalGeneration, FlaxWhisperModel) if is_flax_available() else () all_generative_model_classes = (FlaxWhisperForConditionalGeneration,) if is_flax_available() else () is_encoder_decoder = True test_pruning = False From 8c021ae7da01b609a324679cba7932feb841cf71 Mon Sep 17 00:00:00 2001 From: Andy Ehrenberg <32784181+andyehrenberg@users.noreply.github.com> Date: Sat, 31 Dec 2022 11:12:48 -0500 Subject: [PATCH 056/111] cleanup Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/models/whisper/modeling_flax_whisper.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index aa4453f2bf9fe3..02a9bc58328a37 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -157,7 +157,6 @@ Indices of positions of each decoder input sequence 
tokens in the position embeddings. Selected in the range `[0, config.max_position_embeddings - 1]`. past_key_values (`Dict[str, numpy.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): - Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned From 2aed9afc0dbea69fff7d2fed56a4e1b1d6ff3876 Mon Sep 17 00:00:00 2001 From: Andy Ehrenberg <32784181+andyehrenberg@users.noreply.github.com> Date: Sat, 31 Dec 2022 11:14:24 -0500 Subject: [PATCH 057/111] Update src/transformers/models/whisper/modeling_flax_whisper.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/models/whisper/modeling_flax_whisper.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 02a9bc58328a37..ae835f0a0670b2 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -157,7 +157,6 @@ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, config.max_position_embeddings - 1]`. past_key_values (`Dict[str, numpy.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): - auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. 
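
The `past_key_values` wording cleaned up in the hunks above refers to the cached decoding path exposed by `encode`, `init_cache` and `decode`. A minimal sketch of one cached decoder step, assuming a small randomly initialized config (the values below are arbitrary and chosen only to keep the example fast); illustration only, not part of the committed diff:

import jax.numpy as jnp

from transformers import FlaxWhisperForConditionalGeneration, WhisperConfig

# Tiny arbitrary config so no pretrained weights are required.
config = WhisperConfig(
    d_model=16,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=4,
    decoder_attention_heads=4,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
    max_source_positions=30,
    max_target_positions=40,
)
model = FlaxWhisperForConditionalGeneration(
    config, input_shape=(1, config.num_mel_bins, 2 * config.max_source_positions)
)

# Zero-valued placeholder features with the expected (batch, num_mel_bins, 2 * max_source_positions) shape.
input_features = jnp.zeros((1, config.num_mel_bins, 2 * config.max_source_positions), dtype=jnp.float32)
encoder_outputs = model.encode(input_features=input_features)

# The cache created here is what the docstring calls `past_key_values`.
max_new_tokens = 8
past_key_values = model.init_cache(batch_size=1, max_length=max_new_tokens, encoder_outputs=encoder_outputs)

decoder_input_ids = jnp.array([[config.decoder_start_token_id]], dtype="i4")
# The attention mask spans the full cache length; position ids index only the tokens fed so far.
decoder_attention_mask = jnp.ones((1, max_new_tokens), dtype="i4")
decoder_position_ids = jnp.array([[0]], dtype="i4")

outputs = model.decode(
    decoder_input_ids,
    encoder_outputs,
    decoder_attention_mask=decoder_attention_mask,
    decoder_position_ids=decoder_position_ids,
    past_key_values=past_key_values,
)
next_token = jnp.argmax(outputs.logits[:, -1], axis=-1)

Later steps would feed the sampled token back in together with the returned cache and an incremented position id, which is the situation the `decoder_position_ids` check in `decode` guards against when `past_key_values` is supplied without explicit position ids.
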
From 64da8fae66a0d36eb8b21bfe4c89831a40f41dc9 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Sat, 31 Dec 2022 11:15:54 -0500 Subject: [PATCH 058/111] bias cleanup --- src/transformers/models/whisper/modeling_flax_whisper.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 3bbb42f3ddb373..bfedc590df6ed0 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -196,10 +196,10 @@ def setup(self) -> None: kernel_init=jax.nn.initializers.normal(self.config.init_std), ) - self.q_proj = dense(use_bias=self.bias) + self.q_proj = dense() self.k_proj = dense(use_bias=False) - self.v_proj = dense(use_bias=self.bias) - self.out_proj = dense(use_bias=self.bias) + self.v_proj = dense() + self.out_proj = dense() if self.causal: self.causal_mask = make_causal_mask( From 618f85b16456598e005744f1af14065852104556 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Sat, 31 Dec 2022 11:20:40 -0500 Subject: [PATCH 059/111] doc fix --- src/transformers/models/whisper/modeling_flax_whisper.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index ba5b1dc8e6a923..5088803aace38a 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -157,6 +157,8 @@ Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the range `[0, config.max_position_embeddings - 1]`. past_key_values (`Dict[str, numpy.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): + Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast + auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. From 8b56bf45a3bec31745ba765f717fa81e8928a045 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Mon, 2 Jan 2023 10:42:31 -0500 Subject: [PATCH 060/111] align style for force tokens processor --- .../generation/flax_logits_process.py | 65 ++++++++++--------- 1 file changed, 34 insertions(+), 31 deletions(-) diff --git a/src/transformers/generation/flax_logits_process.py b/src/transformers/generation/flax_logits_process.py index b63f5aac37e220..38d70eb9ae7269 100644 --- a/src/transformers/generation/flax_logits_process.py +++ b/src/transformers/generation/flax_logits_process.py @@ -15,6 +15,8 @@ import inspect +import numpy as np + import jax import jax.lax as lax import jax.numpy as jnp @@ -311,36 +313,11 @@ def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> return scores -class FlaxForceTokenAtIdxLogitsProcessor(FlaxLogitsProcessor): - r""" - [`FlaxLogitsProcessor`] forcing a token to be sampled at an index. - - Args: - apply_idx (`int`): - Index where sampling is forced. - token_id (`int`): - Token that is forced to be sampled. 
- """ - - def __init__(self, apply_idx: int, token_id: int): - self.apply_idx = apply_idx - self.token_id = token_id - - def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: - new_scores = jnp.full(scores.shape, -float("inf")) - - apply_penalty = 1 - jnp.bool_(cur_len - self.apply_idx) - - scores = jnp.where(apply_penalty, new_scores.at[:, self.token_id].set(0), scores) - - return scores - - class FlaxForceTokensLogitsProcessor(FlaxLogitsProcessor): r""" [`FlaxLogitsProcessor`] that takes a list of pairs of integers which indicates a mapping from generation indices to - token indices that will be forced before sampling. The processor will set their log probs to `inf` so that they are - sampled at their corresponding index. + token indices that will be forced before sampling. The processor will set their log probs to 0 and all other tokens + to `-inf` so that they are sampled at their corresponding index. Args: force_token_map (`list`): @@ -348,10 +325,36 @@ class FlaxForceTokensLogitsProcessor(FlaxLogitsProcessor): """ def __init__(self, force_token_map): - self.processors = [FlaxForceTokenAtIdxLogitsProcessor(i[0], i[1]) for i in force_token_map] + force_token_map = dict(force_token_map) + # Converts the dictionary of format {index: token} containing the tokens to be forced to an array, where the + # index of the array corresponds to the index of the token to be forced, for XLA compatibility. + # Indexes without forced tokens will have a negative value. + force_token_array = np.ones((max(force_token_map.keys()) + 1), dtype=np.int32) * -1 + for index, token in force_token_map.items(): + force_token_array[index] = token + self.force_token_array = jnp.array(force_token_array) def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: - for processor in self.processors: - scores = processor(input_ids, scores, cur_len) - + def _force_token(generation_idx): + batch_size = scores.shape[0] + current_token = self.force_token_array[generation_idx] + + new_scores = jnp.ones_like(scores, dtype=scores.dtype) * -float("inf") + updates = jnp.zeros((batch_size, 1), dtype=scores.dtype) + new_scores = lax.dynamic_update_slice(new_scores, updates, (0, current_token)) + return new_scores + + scores = lax.cond( + cur_len >= self.force_token_array.shape[0], + # If the current length is geq than the length of force_token_array, the processor does nothing. + lambda: scores, + # Otherwise, it may force a certain token. + lambda: lax.cond( + self.force_token_array[cur_len] >= 0, + # Only valid (positive) tokens are forced + lambda: _force_token(cur_len), + # Otherwise, the processor does nothing. 
+ lambda: scores, + ), + ) return scores From 209834df5109d6c81233b9ff5c208eb5b28679f1 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Tue, 3 Jan 2023 10:03:26 -0500 Subject: [PATCH 061/111] readability --- src/transformers/models/whisper/modeling_flax_whisper.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 5088803aace38a..2824996d5f9810 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -225,7 +225,9 @@ def __call__( k = self.k_proj(hidden_states) v = self.v_proj(hidden_states) - q, k, v = jax.tree_util.tree_map(self._split_heads, (q, k, v)) + q = self._split_heads(q) + k = self._split_heads(k) + v = self._split_heads(v) if self.causal: query_length, key_length = q.shape[1], k.shape[1] From fac30a0b90414a3eefbf900a7f1b6ee0124e6e4c Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Tue, 3 Jan 2023 11:09:55 -0500 Subject: [PATCH 062/111] fix input shape in tests --- .../whisper/test_modeling_flax_whisper.py | 233 +++++++++++++++++- 1 file changed, 226 insertions(+), 7 deletions(-) diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index 51099ece2536e5..de927c2f204398 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -15,10 +15,12 @@ import inspect +import tempfile import unittest -from transformers import WhisperConfig, is_flax_available -from transformers.testing_utils import require_flax, slow +import transformers +from transformers import WhisperConfig, is_flax_available, is_torch_available +from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow, torch_device from transformers.utils import cached_property from transformers.utils.import_utils import is_datasets_available @@ -34,12 +36,23 @@ import numpy as np import jax + import jax.numpy as jnp + from flax.core.frozen_dict import unfreeze + from flax.traverse_util import flatten_dict from transformers import ( + FLAX_MODEL_MAPPING, FlaxWhisperForConditionalGeneration, FlaxWhisperModel, WhisperFeatureExtractor, WhisperProcessor, ) + from transformers.modeling_flax_pytorch_utils import ( + convert_pytorch_state_dict_to_flax, + load_flax_weights_in_pytorch_model, + ) + +if is_torch_available(): + import torch @require_flax @@ -203,11 +216,12 @@ def test_forward_signature(self): # overwrite because of `input_features` def test_jit_compilation(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + init_shape = (1,) + inputs_dict["input_features"].shape[1:] for model_class in self.all_model_classes: with self.subTest(model_class.__name__): prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - model = model_class(config) + model = model_class(config, input_shape=init_shape) @jax.jit def model_jitted(input_features, decoder_input_ids, **kwargs): @@ -224,6 +238,215 @@ def model_jitted(input_features, decoder_input_ids, **kwargs): for jitted_output, output in zip(jitted_outputs, outputs): self.assertEqual(jitted_output.shape, output.shape) + # overwrite because of `input_features` + @is_pt_flax_cross_test + def test_equivalence_flax_to_pt(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + init_shape = (1,) + inputs_dict["input_features"].shape[1:] + + for 
model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + # Output all for aggressive testing + config.output_hidden_states = True + config.output_attentions = self.has_attentions + + # prepare inputs + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()} + + # load corresponding PyTorch class + pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning + pt_model_class = getattr(transformers, pt_model_class_name) + + pt_model = pt_model_class(config).eval() + # Flax models don't use the `use_cache` option and cache is not returned as a default. + # So we disable `use_cache` here for PyTorch model. + pt_model.config.use_cache = False + fx_model = model_class(config, input_shape=init_shape, dtype=jnp.float32) + + pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) + + # make sure weights are tied in PyTorch + pt_model.tie_weights() + + # send pytorch model to the correct device + pt_model.to(torch_device) + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs) + fx_outputs = fx_model(**prepared_inputs_dict) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class) + + with tempfile.TemporaryDirectory() as tmpdirname: + fx_model.save_pretrained(tmpdirname) + pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True) + + # send pytorch model to the correct device + pt_model_loaded.to(torch_device) + pt_model_loaded.eval() + + with torch.no_grad(): + pt_outputs_loaded = pt_model_loaded(**pt_inputs) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class) + + # overwrite because of `input_features` + @is_pt_flax_cross_test + def test_equivalence_pt_to_flax(self): + # It might be better to put this inside the for loop below (because we modify the config there). + # But logically, it is fine. + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + init_shape = (1,) + inputs_dict["input_features"].shape[1:] + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + # Output all for aggressive testing + config.output_hidden_states = True + config.output_attentions = self.has_attentions + + # prepare inputs + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()} + + # load corresponding PyTorch class + pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning + pt_model_class = getattr(transformers, pt_model_class_name) + + pt_model = pt_model_class(config).eval() + # Flax models don't use the `use_cache` option and cache is not returned as a default. + # So we disable `use_cache` here for PyTorch model. 
+ pt_model.config.use_cache = False + fx_model = model_class(config, input_shape=init_shape, dtype=jnp.float32) + + fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) + fx_model.params = fx_state + + # send pytorch model to the correct device + pt_model.to(torch_device) + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs) + fx_outputs = fx_model(**prepared_inputs_dict) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class) + + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + fx_model_loaded = model_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True) + + fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict) + + fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class) + + # overwrite because of `input_features` + @is_pt_flax_cross_test + def test_save_load_bf16_to_base_pt(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + init_shape = (1,) + inputs_dict["input_features"].shape[1:] + base_class = FLAX_MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = model_class(config, input_shape=init_shape) + model.params = model.to_bf16(model.params) + base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix])) + + # convert Flax model to PyTorch model + pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning + pt_model = pt_model_class(config).eval() + pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + base_model = base_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True) + + base_params = flatten_dict(unfreeze(base_model.params)) + + for key in base_params_from_head.keys(): + max_diff = (base_params[key] - base_params_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + # overwrite because of `input_features` + @is_pt_flax_cross_test + def test_save_load_from_base_pt(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + init_shape = (1,) + inputs_dict["input_features"].shape[1:] + base_class = FLAX_MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = base_class(config, input_shape=init_shape) + base_params = flatten_dict(unfreeze(model.params)) + + # convert Flax model to PyTorch model + pt_model_class = getattr(transformers, base_class.__name__[4:]) # Skip the "Flax" at the beginning + pt_model = pt_model_class(config).eval() + pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + # save pt model + pt_model.save_pretrained(tmpdirname) + head_model = model_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True) + + 
base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix])) + + for key in base_param_from_head.keys(): + max_diff = (base_params[key] - base_param_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + # overwrite because of `input_features` + @is_pt_flax_cross_test + def test_save_load_to_base_pt(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + init_shape = (1,) + inputs_dict["input_features"].shape[1:] + base_class = FLAX_MODEL_MAPPING[config.__class__] + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = model_class(config, input_shape=init_shape) + base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix])) + + # convert Flax model to PyTorch model + pt_model_class = getattr(transformers, model_class.__name__[4:]) # Skip the "Flax" at the beginning + pt_model = pt_model_class(config).eval() + pt_model = load_flax_weights_in_pytorch_model(pt_model, model.params) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + base_model = base_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True) + + base_params = flatten_dict(unfreeze(base_model.params)) + + for key in base_params_from_head.keys(): + max_diff = (base_params[key] - base_params_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + def _assert_tensors_equal(a, b, atol=1e-12, prefix=""): """If tensors not close, or a and b arent both tensors, raise a nice Assertion error.""" @@ -239,10 +462,6 @@ def _assert_tensors_equal(a, b, atol=1e-12, prefix=""): raise AssertionError(f"{prefix}{a} != {b}") -def _long_tensor(tok_lst): - return np.array(tok_lst, dtype=np.int32) - - @slow @require_flax class FlaxWhisperModelIntegrationTest(unittest.TestCase): From aa87c9886d520196358a868943e96151b8083e13 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Tue, 3 Jan 2023 11:14:56 -0500 Subject: [PATCH 063/111] revert FlaxGenerationMixin docstring --- src/transformers/generation/flax_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index d78babf74f0682..6915b7f2915252 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -128,7 +128,9 @@ class BeamSearchState: class FlaxGenerationMixin: """ A class containing all functions for auto-regressive text generation, to be used as a mixin in - [`FlaxPreTrainedModel`]. The class exposes [`~generation.FlaxGenerationMixin.generate`], which can be used for: + [`FlaxPreTrainedModel`]. + + The class exposes [`~generation.FlaxGenerationMixin.generate`], which can be used for: - *greedy decoding* by calling [`~generation.FlaxGenerationMixin._greedy_search`] if `num_beams=1` and `do_sample=False`. 
- *multinomial sampling* by calling [`~generation.FlaxGenerationMixin._sample`] if `num_beams=1` and From 23af05ba68fb329951d8cdbacd17ce412473d821 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Tue, 3 Jan 2023 11:22:04 -0500 Subject: [PATCH 064/111] formatting --- src/transformers/generation/flax_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 6915b7f2915252..67a0bdce2e1d7a 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -129,7 +129,7 @@ class FlaxGenerationMixin: """ A class containing all functions for auto-regressive text generation, to be used as a mixin in [`FlaxPreTrainedModel`]. - + The class exposes [`~generation.FlaxGenerationMixin.generate`], which can be used for: - *greedy decoding* by calling [`~generation.FlaxGenerationMixin._greedy_search`] if `num_beams=1` and `do_sample=False`. From b8086b62c23b1c97f296008bfbf428082b4baeae Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Tue, 3 Jan 2023 12:12:36 -0500 Subject: [PATCH 065/111] fix tests --- .../whisper/test_modeling_flax_whisper.py | 3 - tests/models/whisper/test_modeling_whisper.py | 176 +++++++++++++++++- 2 files changed, 171 insertions(+), 8 deletions(-) diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index de927c2f204398..4f13c91241ee0d 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -166,8 +166,6 @@ def prepare_whisper_inputs_dict( attention_mask=None, decoder_attention_mask=None, ): - if attention_mask is None: - attention_mask = np.not_equal(input_ids, config.pad_token_id).astype(np.int8) if decoder_attention_mask is None: decoder_attention_mask = np.concatenate( [ @@ -179,7 +177,6 @@ def prepare_whisper_inputs_dict( return { "input_features": input_ids, "decoder_input_ids": decoder_input_ids, - "attention_mask": attention_mask, "decoder_attention_mask": decoder_attention_mask, } diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 8b854a2b20e038..544e9cb5dcd22d 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -20,9 +20,18 @@ import tempfile import unittest +import numpy as np + +import transformers from transformers import WhisperConfig -from transformers.testing_utils import is_torch_available, require_torch, require_torchaudio, slow, torch_device -from transformers.utils import cached_property +from transformers.testing_utils import ( + is_pt_flax_cross_test, + require_torch, + require_torchaudio, + slow, + torch_device, +) +from transformers.utils import cached_property, is_flax_available, is_torch_available from transformers.utils.import_utils import is_datasets_available from ...generation.test_utils import GenerationTesterMixin @@ -46,6 +55,13 @@ ) from transformers.models.whisper.modeling_whisper import WhisperDecoder, WhisperEncoder +if is_flax_available(): + import jax.numpy as jnp + from transformers.modeling_flax_pytorch_utils import ( + convert_pytorch_state_dict_to_flax, + load_flax_weights_in_pytorch_model, + ) + def prepare_whisper_inputs_dict( config, @@ -745,6 +761,159 @@ def _create_and_check_torchscript(self, config, inputs_dict): self.assertTrue(models_equal) + @is_pt_flax_cross_test + def test_equivalence_pt_to_flax(self): + config, inputs_dict = 
self.model_tester.prepare_config_and_inputs_for_common() + init_shape = (1,) + inputs_dict["input_features"].shape[1:] + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + fx_model_class_name = "Flax" + model_class.__name__ + + if not hasattr(transformers, fx_model_class_name): + # no flax model exists for this class + return + + # Output all for aggressive testing + config.output_hidden_states = True + config.output_attentions = self.has_attentions + + fx_model_class = getattr(transformers, fx_model_class_name) + + # load PyTorch class + pt_model = model_class(config).eval() + # Flax models don't use the `use_cache` option and cache is not returned as a default. + # So we disable `use_cache` here for PyTorch model. + pt_model.config.use_cache = False + + # load Flax class + fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32) + + # make sure only flax inputs are forward that actually exist in function args + fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys() + + # prepare inputs + pt_inputs = self._prepare_for_class(inputs_dict, model_class) + + # remove function args that don't exist in Flax + pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys} + + # send pytorch inputs to the correct device + pt_inputs = { + k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items() + } + + # convert inputs to Flax + fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} + + fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) + fx_model.params = fx_state + + # send pytorch model to the correct device + pt_model.to(torch_device) + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs) + fx_outputs = fx_model(**fx_inputs) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class) + + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True) + + fx_outputs_loaded = fx_model_loaded(**fx_inputs) + + fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class) + + @is_pt_flax_cross_test + def test_equivalence_flax_to_pt(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + init_shape = (1,) + inputs_dict["input_features"].shape[1:] + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + fx_model_class_name = "Flax" + model_class.__name__ + + if not hasattr(transformers, fx_model_class_name): + # no flax model exists for this class + return + + # Output all for aggressive testing + config.output_hidden_states = True + config.output_attentions = self.has_attentions + + fx_model_class = getattr(transformers, fx_model_class_name) + + # load PyTorch class + pt_model = model_class(config).eval() + # Flax models don't use the `use_cache` option and cache is not returned as a default. + # So we disable `use_cache` here for PyTorch model. 
+ pt_model.config.use_cache = False + + # load Flax class + fx_model = fx_model_class(config, input_shape=init_shape, dtype=jnp.float32) + + # make sure only flax inputs are forward that actually exist in function args + fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys() + + # prepare inputs + pt_inputs = self._prepare_for_class(inputs_dict, model_class) + + # remove function args that don't exist in Flax + pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys} + + # send pytorch inputs to the correct device + pt_inputs = { + k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items() + } + + # convert inputs to Flax + fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} + + pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) + + # make sure weights are tied in PyTorch + pt_model.tie_weights() + + # send pytorch model to the correct device + pt_model.to(torch_device) + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs) + fx_outputs = fx_model(**fx_inputs) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class) + + with tempfile.TemporaryDirectory() as tmpdirname: + fx_model.save_pretrained(tmpdirname) + pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True) + + # send pytorch model to the correct device + pt_model_loaded.to(torch_device) + pt_model_loaded.eval() + + with torch.no_grad(): + pt_outputs_loaded = pt_model_loaded(**pt_inputs) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class) + @require_torch @require_torchaudio @@ -754,7 +923,6 @@ def default_processor(self): return WhisperProcessor.from_pretrained("openai/whisper-base") def _load_datasamples(self, num_samples): - ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") # automatic decoding with librispeech speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"] @@ -884,7 +1052,6 @@ def test_large_logits_librispeech(self): @slow def test_tiny_en_generation(self): - torch_device = "cpu" set_seed(0) processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en") @@ -908,7 +1075,6 @@ def test_tiny_en_generation(self): @slow def test_tiny_generation(self): - torch_device = "cpu" set_seed(0) processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") From acef3e02c71bf0997c45fc4efbeddf2e47787c3a Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Tue, 3 Jan 2023 12:18:29 -0500 Subject: [PATCH 066/111] fix imports --- tests/models/whisper/test_modeling_whisper.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 544e9cb5dcd22d..5dabe0378e0a12 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -24,13 +24,7 @@ import transformers from transformers import WhisperConfig -from transformers.testing_utils import ( - is_pt_flax_cross_test, - require_torch, - require_torchaudio, - slow, - torch_device, -) +from transformers.testing_utils 
import is_pt_flax_cross_test, require_torch, require_torchaudio, slow, torch_device from transformers.utils import cached_property, is_flax_available, is_torch_available from transformers.utils.import_utils import is_datasets_available From da1df33ebf37640ff8f78b69a5323160eeda7916 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Tue, 3 Jan 2023 13:06:48 -0500 Subject: [PATCH 067/111] consistent encoder hidden states --- src/transformers/models/whisper/modeling_whisper.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 300c6e62a15bbd..1c29db492628d7 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -704,10 +704,11 @@ def custom_forward(*inputs): if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) - hidden_states = self.layer_norm(hidden_states) if output_hidden_states: encoder_states = encoder_states + (hidden_states,) + hidden_states = self.layer_norm(hidden_states) + if not return_dict: return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return BaseModelOutput( From 4cdba95a4ae22307960792b78497a47287063fe4 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Tue, 3 Jan 2023 13:18:33 -0500 Subject: [PATCH 068/111] consistent hidden states --- src/transformers/models/whisper/modeling_tf_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_tf_whisper.py b/src/transformers/models/whisper/modeling_tf_whisper.py index 40cfb3839327af..a525db40aa7c1b 100644 --- a/src/transformers/models/whisper/modeling_tf_whisper.py +++ b/src/transformers/models/whisper/modeling_tf_whisper.py @@ -706,9 +706,9 @@ def call( if output_attentions: all_attentions += (attn,) - hidden_states = self.layer_norm(hidden_states) if output_hidden_states: encoder_states = encoder_states + (hidden_states,) + hidden_states = self.layer_norm(hidden_states) if not return_dict: return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) From dd7473babfefcb3076d61c15c2f58d6514666212 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Tue, 3 Jan 2023 14:05:15 -0500 Subject: [PATCH 069/111] input shapes --- tests/models/whisper/test_modeling_flax_whisper.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index 4f13c91241ee0d..72aa94ccc73fdc 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -14,6 +14,7 @@ # limitations under the License. 
+import functools import inspect import tempfile import unittest @@ -181,6 +182,10 @@ def prepare_whisper_inputs_dict( } +def adjust_input_shape(cls, input_shape): + return functools.partial(cls, input_shape=input_shape) + + @require_flax class FlaxWhisperModelTest(FlaxModelTesterMixin, unittest.TestCase): all_model_classes = (FlaxWhisperForConditionalGeneration, FlaxWhisperModel) if is_flax_available() else () @@ -192,6 +197,11 @@ class FlaxWhisperModelTest(FlaxModelTesterMixin, unittest.TestCase): def setUp(self): self.model_tester = FlaxWhisperModelTester(self) + _, inputs_dict = self.model_test.prepare_config_and_inputs_for_common() + init_shape = (1,) + inputs_dict["input_features"].shape[1:] + self.all_model_classes = ( + adjust_input_shape(model_class, init_shape) for model_class in self.all_model_classes + ) self.config_tester = ConfigTester(self, config_class=WhisperConfig) def test_config(self): From c5621f73e36f26d96aefc3d5ea8c257bb310e5e4 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Tue, 3 Jan 2023 14:15:30 -0500 Subject: [PATCH 070/111] typo --- tests/models/whisper/test_modeling_flax_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index 72aa94ccc73fdc..2e2ee18145a85f 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -197,7 +197,7 @@ class FlaxWhisperModelTest(FlaxModelTesterMixin, unittest.TestCase): def setUp(self): self.model_tester = FlaxWhisperModelTester(self) - _, inputs_dict = self.model_test.prepare_config_and_inputs_for_common() + _, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() init_shape = (1,) + inputs_dict["input_features"].shape[1:] self.all_model_classes = ( adjust_input_shape(model_class, init_shape) for model_class in self.all_model_classes From 46aec123c0f711b6674254c38b08a30b3a276f5e Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Tue, 3 Jan 2023 14:32:01 -0500 Subject: [PATCH 071/111] partial class trick --- tests/models/whisper/test_modeling_flax_whisper.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index 2e2ee18145a85f..ec0b5441a82c99 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -182,8 +182,11 @@ def prepare_whisper_inputs_dict( } -def adjust_input_shape(cls, input_shape): - return functools.partial(cls, input_shape=input_shape) +def partialclass(cls, *args, **kwargs): + class NewCls(cls): + __init__ = functools.partialmethod(cls.__init__, *args, **kwargs) + + return NewCls @require_flax @@ -200,7 +203,7 @@ def setUp(self): _, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() init_shape = (1,) + inputs_dict["input_features"].shape[1:] self.all_model_classes = ( - adjust_input_shape(model_class, init_shape) for model_class in self.all_model_classes + partialclass(model_class, init_shape) for model_class in self.all_model_classes ) self.config_tester = ConfigTester(self, config_class=WhisperConfig) From a0036169c88bba7f46d5c64a8245e355b0f9f3f0 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Tue, 3 Jan 2023 14:53:01 -0500 Subject: [PATCH 072/111] partial class for input shape --- .../whisper/test_modeling_flax_whisper.py | 42 ++++++++++--------- 1 file changed, 22 insertions(+), 20 
deletions(-) diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index ec0b5441a82c99..133bdddf99b6f2 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -189,6 +189,13 @@ class NewCls(cls): return NewCls +def make_partial_class(full_class, *args, **kwargs): + partial_class = partialclass(full_class, *args, **kwargs) + partial_class.__name__ = full_class.__name__ + + return partial_class + + @require_flax class FlaxWhisperModelTest(FlaxModelTesterMixin, unittest.TestCase): all_model_classes = (FlaxWhisperForConditionalGeneration, FlaxWhisperModel) if is_flax_available() else () @@ -202,8 +209,9 @@ def setUp(self): self.model_tester = FlaxWhisperModelTester(self) _, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() init_shape = (1,) + inputs_dict["input_features"].shape[1:] + self.all_model_classes = ( - partialclass(model_class, init_shape) for model_class in self.all_model_classes + make_partial_class(model_class, input_shape=init_shape) for model_class in self.all_model_classes ) self.config_tester = ConfigTester(self, config_class=WhisperConfig) @@ -226,12 +234,11 @@ def test_forward_signature(self): # overwrite because of `input_features` def test_jit_compilation(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - init_shape = (1,) + inputs_dict["input_features"].shape[1:] for model_class in self.all_model_classes: with self.subTest(model_class.__name__): prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - model = model_class(config, input_shape=init_shape) + model = model_class(config) @jax.jit def model_jitted(input_features, decoder_input_ids, **kwargs): @@ -252,7 +259,6 @@ def model_jitted(input_features, decoder_input_ids, **kwargs): @is_pt_flax_cross_test def test_equivalence_flax_to_pt(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - init_shape = (1,) + inputs_dict["input_features"].shape[1:] for model_class in self.all_model_classes: with self.subTest(model_class.__name__): @@ -272,7 +278,7 @@ def test_equivalence_flax_to_pt(self): # Flax models don't use the `use_cache` option and cache is not returned as a default. # So we disable `use_cache` here for PyTorch model. pt_model.config.use_cache = False - fx_model = model_class(config, input_shape=init_shape, dtype=jnp.float32) + fx_model = model_class(config, dtype=jnp.float32) pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) @@ -315,7 +321,6 @@ def test_equivalence_pt_to_flax(self): # It might be better to put this inside the for loop below (because we modify the config there). # But logically, it is fine. config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - init_shape = (1,) + inputs_dict["input_features"].shape[1:] for model_class in self.all_model_classes: with self.subTest(model_class.__name__): @@ -335,7 +340,7 @@ def test_equivalence_pt_to_flax(self): # Flax models don't use the `use_cache` option and cache is not returned as a default. # So we disable `use_cache` here for PyTorch model. 
pt_model.config.use_cache = False - fx_model = model_class(config, input_shape=init_shape, dtype=jnp.float32) + fx_model = model_class(config, dtype=jnp.float32) fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) fx_model.params = fx_state @@ -355,7 +360,7 @@ def test_equivalence_pt_to_flax(self): with tempfile.TemporaryDirectory() as tmpdirname: pt_model.save_pretrained(tmpdirname) - fx_model_loaded = model_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True) + fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True) fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict) @@ -368,15 +373,14 @@ def test_equivalence_pt_to_flax(self): # overwrite because of `input_features` @is_pt_flax_cross_test def test_save_load_bf16_to_base_pt(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - init_shape = (1,) + inputs_dict["input_features"].shape[1:] + config, _ = self.model_tester.prepare_config_and_inputs_for_common() base_class = FLAX_MODEL_MAPPING[config.__class__] for model_class in self.all_model_classes: if model_class == base_class: continue - model = model_class(config, input_shape=init_shape) + model = model_class(config) model.params = model.to_bf16(model.params) base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix])) @@ -388,7 +392,7 @@ def test_save_load_bf16_to_base_pt(self): # check that all base model weights are loaded correctly with tempfile.TemporaryDirectory() as tmpdirname: pt_model.save_pretrained(tmpdirname) - base_model = base_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True) + base_model = base_class.from_pretrained(tmpdirname, from_pt=True) base_params = flatten_dict(unfreeze(base_model.params)) @@ -399,15 +403,14 @@ def test_save_load_bf16_to_base_pt(self): # overwrite because of `input_features` @is_pt_flax_cross_test def test_save_load_from_base_pt(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - init_shape = (1,) + inputs_dict["input_features"].shape[1:] + config, _ = self.model_tester.prepare_config_and_inputs_for_common() base_class = FLAX_MODEL_MAPPING[config.__class__] for model_class in self.all_model_classes: if model_class == base_class: continue - model = base_class(config, input_shape=init_shape) + model = base_class(config) base_params = flatten_dict(unfreeze(model.params)) # convert Flax model to PyTorch model @@ -419,7 +422,7 @@ def test_save_load_from_base_pt(self): with tempfile.TemporaryDirectory() as tmpdirname: # save pt model pt_model.save_pretrained(tmpdirname) - head_model = model_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True) + head_model = model_class.from_pretrained(tmpdirname, from_pt=True) base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix])) @@ -430,15 +433,14 @@ def test_save_load_from_base_pt(self): # overwrite because of `input_features` @is_pt_flax_cross_test def test_save_load_to_base_pt(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - init_shape = (1,) + inputs_dict["input_features"].shape[1:] + config, _ = self.model_tester.prepare_config_and_inputs_for_common() base_class = FLAX_MODEL_MAPPING[config.__class__] for model_class in self.all_model_classes: if model_class == base_class: continue - model = model_class(config, input_shape=init_shape) + model = model_class(config) base_params_from_head = 
flatten_dict(unfreeze(model.params[model.base_model_prefix])) # convert Flax model to PyTorch model @@ -449,7 +451,7 @@ def test_save_load_to_base_pt(self): # check that all base model weights are loaded correctly with tempfile.TemporaryDirectory() as tmpdirname: pt_model.save_pretrained(tmpdirname) - base_model = base_class.from_pretrained(tmpdirname, input_shape=init_shape, from_pt=True) + base_model = base_class.from_pretrained(tmpdirname, from_pt=True) base_params = flatten_dict(unfreeze(base_model.params)) From a9604a5688e84a2f2b8f792095c73fbc098a17ad Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Tue, 3 Jan 2023 15:10:46 -0500 Subject: [PATCH 073/111] base_class with correct input shape --- tests/models/whisper/test_modeling_flax_whisper.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index 133bdddf99b6f2..b86e7d20cf354a 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -208,10 +208,10 @@ class FlaxWhisperModelTest(FlaxModelTesterMixin, unittest.TestCase): def setUp(self): self.model_tester = FlaxWhisperModelTester(self) _, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - init_shape = (1,) + inputs_dict["input_features"].shape[1:] + self.init_shape = (1,) + inputs_dict["input_features"].shape[1:] self.all_model_classes = ( - make_partial_class(model_class, input_shape=init_shape) for model_class in self.all_model_classes + make_partial_class(model_class, input_shape=self.init_shape) for model_class in self.all_model_classes ) self.config_tester = ConfigTester(self, config_class=WhisperConfig) @@ -392,7 +392,7 @@ def test_save_load_bf16_to_base_pt(self): # check that all base model weights are loaded correctly with tempfile.TemporaryDirectory() as tmpdirname: pt_model.save_pretrained(tmpdirname) - base_model = base_class.from_pretrained(tmpdirname, from_pt=True) + base_model = base_class.from_pretrained(tmpdirname, input_shape=self.init_shape, from_pt=True) base_params = flatten_dict(unfreeze(base_model.params)) @@ -410,7 +410,7 @@ def test_save_load_from_base_pt(self): if model_class == base_class: continue - model = base_class(config) + model = base_class(config, input_shape=self.init_shape) base_params = flatten_dict(unfreeze(model.params)) # convert Flax model to PyTorch model @@ -451,7 +451,7 @@ def test_save_load_to_base_pt(self): # check that all base model weights are loaded correctly with tempfile.TemporaryDirectory() as tmpdirname: pt_model.save_pretrained(tmpdirname) - base_model = base_class.from_pretrained(tmpdirname, from_pt=True) + base_model = base_class.from_pretrained(tmpdirname, input_shape=self.init_shape, from_pt=True) base_params = flatten_dict(unfreeze(base_model.params)) From 5120afe53cc4ae4d12a4f8a1568ced8d280cac1f Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Tue, 3 Jan 2023 15:46:07 -0500 Subject: [PATCH 074/111] partial base classes --- tests/models/whisper/test_modeling_flax_whisper.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index b86e7d20cf354a..94a5a61d22c650 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -374,7 +374,7 @@ def test_equivalence_pt_to_flax(self): @is_pt_flax_cross_test def 
test_save_load_bf16_to_base_pt(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = FLAX_MODEL_MAPPING[config.__class__] + base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape) for model_class in self.all_model_classes: if model_class == base_class: @@ -392,7 +392,7 @@ def test_save_load_bf16_to_base_pt(self): # check that all base model weights are loaded correctly with tempfile.TemporaryDirectory() as tmpdirname: pt_model.save_pretrained(tmpdirname) - base_model = base_class.from_pretrained(tmpdirname, input_shape=self.init_shape, from_pt=True) + base_model = base_class.from_pretrained(tmpdirname, from_pt=True) base_params = flatten_dict(unfreeze(base_model.params)) @@ -404,13 +404,13 @@ def test_save_load_bf16_to_base_pt(self): @is_pt_flax_cross_test def test_save_load_from_base_pt(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = FLAX_MODEL_MAPPING[config.__class__] + base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape) for model_class in self.all_model_classes: if model_class == base_class: continue - model = base_class(config, input_shape=self.init_shape) + model = base_class(config) base_params = flatten_dict(unfreeze(model.params)) # convert Flax model to PyTorch model @@ -434,7 +434,7 @@ def test_save_load_from_base_pt(self): @is_pt_flax_cross_test def test_save_load_to_base_pt(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() - base_class = FLAX_MODEL_MAPPING[config.__class__] + base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape) for model_class in self.all_model_classes: if model_class == base_class: @@ -451,7 +451,7 @@ def test_save_load_to_base_pt(self): # check that all base model weights are loaded correctly with tempfile.TemporaryDirectory() as tmpdirname: pt_model.save_pretrained(tmpdirname) - base_model = base_class.from_pretrained(tmpdirname, input_shape=self.init_shape, from_pt=True) + base_model = base_class.from_pretrained(tmpdirname, from_pt=True) base_params = flatten_dict(unfreeze(base_model.params)) From c6b1ae4f2b1a1ba295fb70266bf5855cf4862073 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Tue, 3 Jan 2023 16:10:53 -0500 Subject: [PATCH 075/111] match by name --- tests/models/whisper/test_modeling_flax_whisper.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index 94a5a61d22c650..fbdb62a3e44fcd 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -377,7 +377,7 @@ def test_save_load_bf16_to_base_pt(self): base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape) for model_class in self.all_model_classes: - if model_class == base_class: + if model_class.__name__ == base_class.__name__: continue model = model_class(config) @@ -407,7 +407,7 @@ def test_save_load_from_base_pt(self): base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape) for model_class in self.all_model_classes: - if model_class == base_class: + if model_class.__name__ == base_class.__name__: continue model = base_class(config) @@ -437,7 +437,7 @@ def test_save_load_to_base_pt(self): base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], 
input_shape=self.init_shape) for model_class in self.all_model_classes: - if model_class == base_class: + if model_class.__name__ == base_class.__name__: continue model = model_class(config) From 4c239fcd9189a022cb74598acd384a8298f26e93 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 4 Jan 2023 09:36:42 -0500 Subject: [PATCH 076/111] set main_input_name --- .../models/whisper/modeling_flax_whisper.py | 1 + .../whisper/test_modeling_flax_whisper.py | 47 +++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 2824996d5f9810..b8da13de00e256 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -801,6 +801,7 @@ def _get_decoder_module(self): class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel): config_class = WhisperConfig base_model_prefix: str = "model" + main_input_name = "input_features" module_class: nn.Module = None def __init__( diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index fbdb62a3e44fcd..dbd0b4f7f6bc0d 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -192,6 +192,7 @@ class NewCls(cls): def make_partial_class(full_class, *args, **kwargs): partial_class = partialclass(full_class, *args, **kwargs) partial_class.__name__ = full_class.__name__ + partial_class.__module__ = full_class.__module__ return partial_class @@ -459,6 +460,52 @@ def test_save_load_to_base_pt(self): max_diff = (base_params[key] - base_params_from_head[key]).sum().item() self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + # overwrite because of `input_features` + def test_save_load_from_base(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape) + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = base_class(config) + base_params = flatten_dict(unfreeze(model.params)) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + head_model = model_class.from_pretrained(tmpdirname) + + base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix])) + + for key in base_param_from_head.keys(): + max_diff = (base_params[key] - base_param_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + + # overwrite because of `input_features` + def test_save_load_to_base(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape) + + for model_class in self.all_model_classes: + if model_class == base_class: + continue + + model = model_class(config) + base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix])) + + # check that all base model weights are loaded correctly + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + base_model = base_class.from_pretrained(tmpdirname) + + base_params = flatten_dict(unfreeze(base_model.params)) + + for key in base_params_from_head.keys(): + max_diff = (base_params[key] - 
base_params_from_head[key]).sum().item() + self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") + def _assert_tensors_equal(a, b, atol=1e-12, prefix=""): """If tensors not close, or a and b arent both tensors, raise a nice Assertion error.""" From 279ceb6f562fca434bf0ccf45b0c771dbc0ebdad Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 4 Jan 2023 10:12:42 -0500 Subject: [PATCH 077/111] compare on names --- tests/models/whisper/test_modeling_flax_whisper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index dbd0b4f7f6bc0d..3fea507f27e955 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -466,7 +466,7 @@ def test_save_load_from_base(self): base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape) for model_class in self.all_model_classes: - if model_class == base_class: + if model_class.__name__ == base_class.__name__: continue model = base_class(config) @@ -489,7 +489,7 @@ def test_save_load_to_base(self): base_class = make_partial_class(FLAX_MODEL_MAPPING[config.__class__], input_shape=self.init_shape) for model_class in self.all_model_classes: - if model_class == base_class: + if model_class.__name__ == base_class.__name__: continue model = model_class(config) From 797fab1684f72c9978b597a0e3072afd5b506f43 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Mon, 9 Jan 2023 15:44:37 -0500 Subject: [PATCH 078/111] formatting --- src/transformers/generation/flax_utils.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 19ad241d359737..5350c4280e8251 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -583,12 +583,22 @@ def _get_logits_processor( processors.append(FlaxSuppressTokensLogitsProcessor(generation_config.suppress_tokens)) if generation_config.begin_suppress_tokens is not None: begin_index = input_ids_seq_length - begin_index = begin_index if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None) else begin_index + 1 + begin_index = ( + begin_index + if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None) + else begin_index + 1 + ) if generation_config.forced_decoder_ids is not None and len(generation_config.forced_decoder_ids) > 0: - begin_index += generation_config.forced_decoder_ids[-1][0] # generation starts after the last token that is forced - processors.append(FlaxSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)) + begin_index += generation_config.forced_decoder_ids[-1][ + 0 + ] # generation starts after the last token that is forced + processors.append( + FlaxSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index) + ) if generation_config.forced_decoder_ids is not None: - forced_decoder_ids = [[input_ids_seq_length + i[0] - 1, i[1]] for i in generation_config.forced_decoder_ids] + forced_decoder_ids = [ + [input_ids_seq_length + i[0] - 1, i[1]] for i in generation_config.forced_decoder_ids + ] processors.append(FlaxForceTokensLogitsProcessor(forced_decoder_ids)) return processors From f3173d8e8f3a678c55c7acae52322ff3f8b649f5 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Mon, 9 Jan 2023 15:52:50 -0500 Subject: [PATCH 
079/111] remove unused import --- src/transformers/generation/flax_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 5350c4280e8251..c4687377dbe307 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -19,7 +19,7 @@ import inspect import warnings from functools import partial -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional import numpy as np From b4696ca7cb88778795162c380eed27d591491484 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Tue, 10 Jan 2023 09:42:02 -0500 Subject: [PATCH 080/111] safer position ids computation --- src/transformers/models/whisper/modeling_flax_whisper.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index b8da13de00e256..c7593a4bf813ee 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -1000,7 +1000,7 @@ def decode( raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") if decoder_attention_mask is not None: - decoder_position_ids = decoder_attention_mask.cumsum(-1) - 1 + decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1 else: decoder_position_ids = jnp.broadcast_to( jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) @@ -1085,7 +1085,7 @@ def __call__( # prepare decoder inputs if decoder_position_ids is None: if decoder_attention_mask is not None: - decoder_position_ids = decoder_attention_mask.cumsum(-1) - 1 + decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1 else: batch_size, sequence_length = decoder_input_ids.shape decoder_position_ids = jnp.broadcast_to( @@ -1250,7 +1250,7 @@ def decode( raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.") if decoder_attention_mask is not None: - decoder_position_ids = decoder_attention_mask.cumsum(-1) - 1 + decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1 else: decoder_position_ids = jnp.broadcast_to( jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) @@ -1349,7 +1349,7 @@ def prepare_inputs_for_generation( # Thus we can create a single static attention_mask here, which is more efficient for compilation extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") if decoder_attention_mask is not None: - position_ids = decoder_attention_mask.cumsum(axis=-1) - 1 + position_ids = jnp.maximum(0, decoder_attention_mask.cumsum(-1) - 1) extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) else: position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) From 1c11ca6e8a045ec4e2af6b255b4cfd1f44a874da Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Tue, 10 Jan 2023 09:48:17 -0500 Subject: [PATCH 081/111] safer position id computation --- src/transformers/models/whisper/modeling_flax_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index c7593a4bf813ee..e2d7bf9910e8d5 100644 --- 
a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -1349,7 +1349,7 @@ def prepare_inputs_for_generation( # Thus we can create a single static attention_mask here, which is more efficient for compilation extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") if decoder_attention_mask is not None: - position_ids = jnp.maximum(0, decoder_attention_mask.cumsum(-1) - 1) + position_ids = decoder_attention_mask.cumsum(-1) - 1 extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) else: position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) From c128fd857be67bf5a2f3e397671af90339d1c232 Mon Sep 17 00:00:00 2001 From: Andy Ehrenberg <32784181+andyehrenberg@users.noreply.github.com> Date: Wed, 18 Jan 2023 08:00:42 -0500 Subject: [PATCH 082/111] Update src/transformers/models/whisper/modeling_flax_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> --- src/transformers/models/whisper/modeling_flax_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index e2d7bf9910e8d5..ea5358096d147a 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -99,7 +99,7 @@ be used by default. If you want to change padding behavior, you should modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Whisper does not use position_ids in the encoder as input_features is always the same size and doesn't use + Whisper does not use `position_ids` in the encoder as `input_features` is always the same size and doesn't use masking, but this argument is preserved for compatibility. By default the silence in the input log mel spectrogram are ignored. 
decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): From 2ae5b08e50576295e76e8477a8d2c5574e1b2b12 Mon Sep 17 00:00:00 2001 From: Andy Ehrenberg <32784181+andyehrenberg@users.noreply.github.com> Date: Wed, 18 Jan 2023 08:04:31 -0500 Subject: [PATCH 083/111] Update src/transformers/models/whisper/modeling_flax_whisper.py Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> --- src/transformers/models/whisper/modeling_flax_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index ea5358096d147a..be6e7d69bbc3c4 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -818,7 +818,7 @@ def __init__( def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors - input_features = jnp.zeros(input_shape) + input_features = jnp.zeros(input_shape, dtype="f4") input_features = input_features.at[(..., -1)].set(self.config.eos_token_id) decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4") From 48583bd8b1b9e0b3da1f4092e9b3ef0ca79a4e4b Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 18 Jan 2023 08:14:31 -0500 Subject: [PATCH 084/111] remove identical inherited tests --- .../whisper/test_modeling_flax_whisper.py | 117 +----------------- 1 file changed, 1 insertion(+), 116 deletions(-) diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index 3fea507f27e955..aeae14cb824e4e 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -255,122 +255,7 @@ def model_jitted(input_features, decoder_input_ids, **kwargs): self.assertEqual(len(outputs), len(jitted_outputs)) for jitted_output, output in zip(jitted_outputs, outputs): self.assertEqual(jitted_output.shape, output.shape) - - # overwrite because of `input_features` - @is_pt_flax_cross_test - def test_equivalence_flax_to_pt(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - # Output all for aggressive testing - config.output_hidden_states = True - config.output_attentions = self.has_attentions - - # prepare inputs - prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()} - - # load corresponding PyTorch class - pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning - pt_model_class = getattr(transformers, pt_model_class_name) - - pt_model = pt_model_class(config).eval() - # Flax models don't use the `use_cache` option and cache is not returned as a default. - # So we disable `use_cache` here for PyTorch model. 
- pt_model.config.use_cache = False - fx_model = model_class(config, dtype=jnp.float32) - - pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) - - # make sure weights are tied in PyTorch - pt_model.tie_weights() - - # send pytorch model to the correct device - pt_model.to(torch_device) - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs) - fx_outputs = fx_model(**prepared_inputs_dict) - - fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) - pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) - - self.assertEqual(fx_keys, pt_keys) - self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class) - - with tempfile.TemporaryDirectory() as tmpdirname: - fx_model.save_pretrained(tmpdirname) - pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True) - - # send pytorch model to the correct device - pt_model_loaded.to(torch_device) - pt_model_loaded.eval() - - with torch.no_grad(): - pt_outputs_loaded = pt_model_loaded(**pt_inputs) - - fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) - pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None]) - - self.assertEqual(fx_keys, pt_keys) - self.check_pt_flax_outputs(fx_outputs, pt_outputs_loaded, model_class) - - # overwrite because of `input_features` - @is_pt_flax_cross_test - def test_equivalence_pt_to_flax(self): - # It might be better to put this inside the for loop below (because we modify the config there). - # But logically, it is fine. - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - with self.subTest(model_class.__name__): - # Output all for aggressive testing - config.output_hidden_states = True - config.output_attentions = self.has_attentions - - # prepare inputs - prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()} - - # load corresponding PyTorch class - pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning - pt_model_class = getattr(transformers, pt_model_class_name) - - pt_model = pt_model_class(config).eval() - # Flax models don't use the `use_cache` option and cache is not returned as a default. - # So we disable `use_cache` here for PyTorch model. 
- pt_model.config.use_cache = False - fx_model = model_class(config, dtype=jnp.float32) - - fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) - fx_model.params = fx_state - - # send pytorch model to the correct device - pt_model.to(torch_device) - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs) - fx_outputs = fx_model(**prepared_inputs_dict) - - fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) - pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) - - self.assertEqual(fx_keys, pt_keys) - self.check_pt_flax_outputs(fx_outputs, pt_outputs, model_class) - - with tempfile.TemporaryDirectory() as tmpdirname: - pt_model.save_pretrained(tmpdirname) - fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True) - - fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict) - - fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None]) - pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) - - self.assertEqual(fx_keys, pt_keys) - self.check_pt_flax_outputs(fx_outputs_loaded, pt_outputs, model_class) - + # overwrite because of `input_features` @is_pt_flax_cross_test def test_save_load_bf16_to_base_pt(self): From 1c18f6113e196af3b9696d3f17d909a14ef9e05c Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 18 Jan 2023 08:36:10 -0500 Subject: [PATCH 085/111] fix prompt ids in tests --- .../whisper/test_modeling_flax_whisper.py | 28 ++++--------------- 1 file changed, 5 insertions(+), 23 deletions(-) diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index aeae14cb824e4e..dd62ca4223662a 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -392,20 +392,6 @@ def test_save_load_to_base(self): self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical") -def _assert_tensors_equal(a, b, atol=1e-12, prefix=""): - """If tensors not close, or a and b arent both tensors, raise a nice Assertion error.""" - if a is None and b is None: - return True - try: - if _assert_tensors_equal(a, b, atol=atol): - return True - raise - except Exception: - if len(prefix) > 0: - prefix = f"{prefix}: " - raise AssertionError(f"{prefix}{a} != {b}") - - @slow @require_flax class FlaxWhisperModelIntegrationTest(unittest.TestCase): @@ -522,7 +508,7 @@ def test_tiny_en_generation(self): EXPECTED_TRANSCRIPT = ( "<|startoftranscript|><|en|><|transcribe|><|notimestamps|> Mr. 
Quilter is the apostle of the middle" - " classes and we are glad" + " classes and we are glad to" ) self.assertEqual(transcript, EXPECTED_TRANSCRIPT) @@ -553,8 +539,7 @@ def test_large_generation(self): raw_speech=input_speech, sampling_rate=processor.feature_extractor.sampling_rate, return_tensors="jax" ).input_features - prompt_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe") - model.config.forced_decoder_ids = [[i[0] - 1, i[1]] for i in prompt_ids[1:]] + model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe") generated_ids = model.generate(input_features, num_beams=5, max_length=20).sequences transcript = processor.tokenizer.decode(generated_ids[0], skip_special_tokens=True) @@ -571,16 +556,14 @@ def test_large_generation_multilingual(self): input_speech = next(iter(ds))["audio"]["array"] input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="np") - prompt_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe") - model.config.forced_decoder_ids = [[i[0] - 1, i[1]] for i in prompt_ids[1:]] + model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="transcribe") generated_ids = model.generate(input_features, do_sample=False, max_length=20).sequences transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] EXPECTED_TRANSCRIPT = "木村さんに電話を貸してもらいました" self.assertEqual(transcript, EXPECTED_TRANSCRIPT) - prompt_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe") - model.config.forced_decoder_ids = [[i[0] - 1, i[1]] for i in prompt_ids[1:]] + model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="en", task="transcribe") generated_ids = model.generate( input_features, do_sample=False, @@ -591,8 +574,7 @@ def test_large_generation_multilingual(self): EXPECTED_TRANSCRIPT = " Kimura-san called me." 
self.assertEqual(transcript, EXPECTED_TRANSCRIPT) - prompt_ids = processor.get_decoder_prompt_ids(language="ja", task="translate") - model.config.forced_decoder_ids = [[i[0] - 1, i[1]] for i in prompt_ids[1:]] + model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="ja", task="translate") generated_ids = model.generate(input_features, do_sample=False, max_length=20).sequences transcript = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] From c3b1d34a4c218799afc3530a869564a9ce986949 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 18 Jan 2023 08:39:53 -0500 Subject: [PATCH 086/111] use generation config --- src/transformers/generation/flax_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index c4687377dbe307..357f32cf5c1a32 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -184,9 +184,9 @@ def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_to # retrieve decoder_start_token_id for encoder-decoder models # fall back to bos_token_id if necessary decoder_start_token_id = ( - decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id + decoder_start_token_id if decoder_start_token_id is not None else self.generation_config.decoder_start_token_id ) - bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id + bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id if decoder_start_token_id is not None: return decoder_start_token_id elif ( From bf15d5fc11830439dfe0c910ba2d34c828d0329d Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 18 Jan 2023 08:41:06 -0500 Subject: [PATCH 087/111] use jnp array --- src/transformers/generation/flax_logits_process.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/generation/flax_logits_process.py b/src/transformers/generation/flax_logits_process.py index 38d70eb9ae7269..3fae8f1a3530b8 100644 --- a/src/transformers/generation/flax_logits_process.py +++ b/src/transformers/generation/flax_logits_process.py @@ -15,8 +15,6 @@ import inspect -import numpy as np - import jax import jax.lax as lax import jax.numpy as jnp @@ -329,7 +327,7 @@ def __init__(self, force_token_map): # Converts the dictionary of format {index: token} containing the tokens to be forced to an array, where the # index of the array corresponds to the index of the token to be forced, for XLA compatibility. # Indexes without forced tokens will have a negative value. 
- force_token_array = np.ones((max(force_token_map.keys()) + 1), dtype=np.int32) * -1 + force_token_array = jnp.ones((max(force_token_map.keys()) + 1), dtype=jnp.int32) * -1 for index, token in force_token_map.items(): force_token_array[index] = token self.force_token_array = jnp.array(force_token_array) From c5fc14b5b46543748ae6f979b37c7aa5aff3ae94 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 18 Jan 2023 08:44:02 -0500 Subject: [PATCH 088/111] better var names --- .../models/whisper/modeling_flax_whisper.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index be6e7d69bbc3c4..64a7aa93ec0a14 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -216,21 +216,21 @@ def __call__( is_cross_attention = key_value_states is not None batch_size = hidden_states.shape[0] - q = self.q_proj(hidden_states) + query_states = self.q_proj(hidden_states) if is_cross_attention: - k = self.k_proj(key_value_states) - v = self.v_proj(key_value_states) + key_states = self.k_proj(key_value_states) + value_states = self.v_proj(key_value_states) else: - k = self.k_proj(hidden_states) - v = self.v_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) - q = self._split_heads(q) - k = self._split_heads(k) - v = self._split_heads(v) + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) if self.causal: - query_length, key_length = q.shape[1], k.shape[1] + query_length, key_length = query_states.shape[1], key_states.shape[1] if self.has_variable("cache", "cached_key"): mask_shift = self.variables["cache"]["cache_index"] max_decoder_length = self.variables["cache"]["cached_key"].shape[1] @@ -256,7 +256,7 @@ def __call__( # and cache the keys and values step by step. if self.causal and (self.has_variable("cache", "cached_key") or init_cache): - k, v, attention_mask = self._concatenate_to_cache(k, v, q, attention_mask) + key_states, value_states, attention_mask = self._concatenate_to_cache(key_states, value_states, query_states, attention_mask) # Convert the boolean attention mask to an attention bias. 
if attention_mask is not None: @@ -274,8 +274,8 @@ def __call__( dropout_rng = self.make_rng("dropout") attn_weights = dot_product_attention_weights( - q, - k, + query_states, + key_states, bias=attention_bias, dropout_rng=dropout_rng, dropout_rate=self.dropout, @@ -285,7 +285,7 @@ def __call__( precision=None, ) - attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, v) + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) attn_output = self._merge_heads(attn_output) attn_output = self.out_proj(attn_output) From 161cb8afcc5cb46f3948c390984f8da44e57b63b Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 18 Jan 2023 08:46:31 -0500 Subject: [PATCH 089/111] more explicit bias use --- src/transformers/models/whisper/modeling_flax_whisper.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 64a7aa93ec0a14..426fff5e3f6bf3 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -190,15 +190,14 @@ def setup(self) -> None: dense = partial( nn.Dense, self.embed_dim, - use_bias=self.bias, dtype=self.dtype, kernel_init=jax.nn.initializers.normal(self.config.init_std), ) - self.q_proj = dense() + self.q_proj = dense(use_bias=self.bias) self.k_proj = dense(use_bias=False) - self.v_proj = dense() - self.out_proj = dense() + self.v_proj = dense(use_bias=self.bias) + self.out_proj = dense(use_bias=self.bias) if self.causal: self.causal_mask = make_causal_mask( From bb9d0af98438737261129109a5ee6a419b359498 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 18 Jan 2023 09:21:14 -0500 Subject: [PATCH 090/111] import transformers --- tests/models/whisper/test_modeling_whisper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index de012f872c5227..a023a833c809fc 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -22,6 +22,7 @@ import numpy as np +import transformers from transformers import WhisperConfig from transformers.testing_utils import is_pt_flax_cross_test, require_torch, require_torchaudio, slow, torch_device from transformers.utils import cached_property, is_flax_available, is_torch_available From f1d90d2379ed446941c27d28e33516d8289de6e8 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 18 Jan 2023 09:37:58 -0500 Subject: [PATCH 091/111] formatting --- src/transformers/generation/flax_utils.py | 14 +++++++++----- .../models/whisper/modeling_flax_whisper.py | 4 +++- .../models/whisper/modeling_whisper.py | 3 --- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index e4b09c6d84599d..3cd4cada619f2d 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -187,7 +187,9 @@ def _get_decoder_start_token_id(self, decoder_start_token_id: int = None, bos_to # retrieve decoder_start_token_id for encoder-decoder models # fall back to bos_token_id if necessary decoder_start_token_id = ( - decoder_start_token_id if decoder_start_token_id is not None else self.generation_config.decoder_start_token_id + decoder_start_token_id + if decoder_start_token_id is not None + else self.generation_config.decoder_start_token_id ) bos_token_id = 
bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id if decoder_start_token_id is not None: @@ -370,10 +372,12 @@ def generate( has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None if has_default_max_length and generation_config.max_new_tokens is None: warnings.warn( - "Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to" - f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the" - " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we" - " recommend using `max_new_tokens` to control the maximum length of the generation.", + ( + "Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to" + f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via" + " the config is deprecated and `max_length` will be removed from the config in v5 of Transformers" + " -- we recommend using `max_new_tokens` to control the maximum length of the generation." + ), UserWarning, ) elif has_default_max_length and generation_config.max_new_tokens is not None: diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 426fff5e3f6bf3..1b123a5e1edfb2 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -255,7 +255,9 @@ def __call__( # and cache the keys and values step by step. if self.causal and (self.has_variable("cache", "cached_key") or init_cache): - key_states, value_states, attention_mask = self._concatenate_to_cache(key_states, value_states, query_states, attention_mask) + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) # Convert the boolean attention mask to an attention bias. if attention_mask is not None: diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 122235d9b2d2ed..61234efd58f646 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -101,7 +101,6 @@ def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional super().__init__(num_positions, embedding_dim) def forward(self, input_ids, past_key_values_length=0): - return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[-1]] @@ -898,7 +897,6 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None if self.gradient_checkpointing and self.training: - if use_cache: logger.warning( "`use_cache = True` is incompatible with gradient checkpointing. 
Setting `use_cache =" @@ -924,7 +922,6 @@ def custom_forward(*inputs): None, # past_key_value ) else: - layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, From 733ae2bab3c63058a34c94e4e40e04354f505590 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 18 Jan 2023 10:28:43 -0500 Subject: [PATCH 092/111] test formatting --- tests/models/whisper/test_modeling_flax_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index dd62ca4223662a..b69c83661844e6 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -255,7 +255,7 @@ def model_jitted(input_features, decoder_input_ids, **kwargs): self.assertEqual(len(outputs), len(jitted_outputs)) for jitted_output, output in zip(jitted_outputs, outputs): self.assertEqual(jitted_output.shape, output.shape) - + # overwrite because of `input_features` @is_pt_flax_cross_test def test_save_load_bf16_to_base_pt(self): From 6295691a34ab19a131edc422fec97bb3bbadbb27 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 18 Jan 2023 10:32:50 -0500 Subject: [PATCH 093/111] remove unused imports --- tests/models/whisper/test_modeling_flax_whisper.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index b69c83661844e6..4b685b26c26f0c 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -21,7 +21,7 @@ import transformers from transformers import WhisperConfig, is_flax_available, is_torch_available -from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow, torch_device +from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow from transformers.utils import cached_property from transformers.utils.import_utils import is_datasets_available @@ -37,7 +37,6 @@ import numpy as np import jax - import jax.numpy as jnp from flax.core.frozen_dict import unfreeze from flax.traverse_util import flatten_dict from transformers import ( @@ -48,13 +47,9 @@ WhisperProcessor, ) from transformers.modeling_flax_pytorch_utils import ( - convert_pytorch_state_dict_to_flax, load_flax_weights_in_pytorch_model, ) -if is_torch_available(): - import torch - @require_flax class FlaxWhisperModelTester: From 902555e24431dbc4cdb853f56f6f66555722a22f Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 18 Jan 2023 11:11:25 -0500 Subject: [PATCH 094/111] remove unused imports --- tests/models/whisper/test_modeling_flax_whisper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index 4b685b26c26f0c..a24c3993a598df 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -20,7 +20,7 @@ import unittest import transformers -from transformers import WhisperConfig, is_flax_available, is_torch_available +from transformers import WhisperConfig, is_flax_available from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow from transformers.utils import cached_property from transformers.utils.import_utils import is_datasets_available From cba4942162ecbf24f25d5242e160b968e64eb2d2 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 
18 Jan 2023 15:59:56 -0500 Subject: [PATCH 095/111] formatting --- src/transformers/generation/flax_utils.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 3cd4cada619f2d..a26ff7885e7d76 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -372,12 +372,10 @@ def generate( has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None if has_default_max_length and generation_config.max_new_tokens is None: warnings.warn( - ( - "Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to" - f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via" - " the config is deprecated and `max_length` will be removed from the config in v5 of Transformers" - " -- we recommend using `max_new_tokens` to control the maximum length of the generation." - ), + "Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to" + f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via" + " the config is deprecated and `max_length` will be removed from the config in v5 of Transformers" + " -- we recommend using `max_new_tokens` to control the maximum length of the generation.", UserWarning, ) elif has_default_max_length and generation_config.max_new_tokens is not None: From 0173945485c520b6ebdce85d1dedacb7e0d5f199 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 18 Jan 2023 16:08:55 -0500 Subject: [PATCH 096/111] isort --- tests/models/whisper/test_modeling_flax_whisper.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index a24c3993a598df..aab9e3c15fc291 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -46,9 +46,7 @@ WhisperFeatureExtractor, WhisperProcessor, ) - from transformers.modeling_flax_pytorch_utils import ( - load_flax_weights_in_pytorch_model, - ) + from transformers.modeling_flax_pytorch_utils import load_flax_weights_in_pytorch_model @require_flax From 48640e5aac13c779c40713e68206422d0664ab27 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 18 Jan 2023 16:26:14 -0500 Subject: [PATCH 097/111] docs --- src/transformers/models/whisper/modeling_flax_whisper.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index 1b123a5e1edfb2..c34673f85a0368 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -65,6 +65,7 @@ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) + Parameters: config ([`WhisperConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the @@ -99,8 +100,8 @@ be used by default. If you want to change padding behavior, you should modify to your needs. 
See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): - Whisper does not use `position_ids` in the encoder as `input_features` is always the same size and doesn't use - masking, but this argument is preserved for compatibility. By default the silence in the input log mel + Whisper does not use `position_ids` in the encoder as `input_features` is always the same size and doesn't + use masking, but this argument is preserved for compatibility. By default the silence in the input log mel spectrogram are ignored. decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the From 1daee2bac02497a86e607677791dc62ca8c33b3c Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Thu, 26 Jan 2023 08:51:33 -0500 Subject: [PATCH 098/111] fix ln orders for encoder hidden states --- .../models/whisper/modeling_flax_whisper.py | 16 ++++++++++++---- .../models/whisper/modeling_tf_whisper.py | 2 +- .../models/whisper/modeling_whisper.py | 3 +-- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index c34673f85a0368..a99a49d2f9f649 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -666,14 +666,22 @@ def __call__( return_dict=return_dict, ) - hidden_states = self.layer_norm(outputs[0]) + last_hidden_states = outputs[0] + last_hidden_states = self.layer_norm(last_hidden_states) + + # update the last element in `hidden_states` after applying `layernorm` above + hidden_states = None + if output_hidden_states: + hidden_states = outputs[1] + hidden_states = hidden_states[:-1] + (last_hidden_states,) if not return_dict: - return (hidden_states,) + outputs[1:] + outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:]) + return tuple(v for v in outputs if v is not None) return FlaxBaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=outputs.hidden_states, + last_hidden_state=last_hidden_states, + hidden_states=hidden_states, attentions=outputs.attentions, ) diff --git a/src/transformers/models/whisper/modeling_tf_whisper.py b/src/transformers/models/whisper/modeling_tf_whisper.py index e83cf43846f025..7a76d42fd526b6 100644 --- a/src/transformers/models/whisper/modeling_tf_whisper.py +++ b/src/transformers/models/whisper/modeling_tf_whisper.py @@ -713,9 +713,9 @@ def call( if output_attentions: all_attentions += (attn,) + hidden_states = self.layer_norm(hidden_states) if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - hidden_states = self.layer_norm(hidden_states) if not return_dict: return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 61234efd58f646..495a9555e3db4e 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -703,11 +703,10 @@ def custom_forward(*inputs): if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) + hidden_states = self.layer_norm(hidden_states) if output_hidden_states: encoder_states 
= encoder_states + (hidden_states,) - hidden_states = self.layer_norm(hidden_states) - if not return_dict: return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return BaseModelOutput( From 632c4be4149641b9afdee88763909706a4929336 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Fri, 3 Feb 2023 09:20:59 -0500 Subject: [PATCH 099/111] whisper unique generation stuff --- .../generation/flax_logits_process.py | 118 +++++++++++++++++- src/transformers/generation/flax_utils.py | 30 ++++- .../models/whisper/modeling_flax_whisper.py | 66 ++++++++++ 3 files changed, 211 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/flax_logits_process.py b/src/transformers/generation/flax_logits_process.py index 626f90f24de701..1af117aed9daa9 100644 --- a/src/transformers/generation/flax_logits_process.py +++ b/src/transformers/generation/flax_logits_process.py @@ -328,8 +328,8 @@ def __init__(self, force_token_map): # Indexes without forced tokens will have a negative value. force_token_array = jnp.ones((max(force_token_map.keys()) + 1), dtype=jnp.int32) * -1 for index, token in force_token_map.items(): - force_token_array[index] = token - self.force_token_array = jnp.array(force_token_array) + force_token_array = force_token_array.at[index].set(token) + self.force_token_array = jnp.int32(force_token_array) def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: def _force_token(generation_idx): @@ -355,3 +355,117 @@ def _force_token(generation_idx): ), ) return scores + + +class FlaxWhisperTimeStampLogitsProcessor(FlaxLogitsProcessor): + r""" + Whisper specific Processor. This processor can be used to force a list of tokens. The processor will set their log + probs to `inf` so that they are sampled at their corresponding index. + Args: + generate_config (`GenerateConfig`): + The generate config used to generate the output. The following parameters are required: + eos_token_id (`int`, *optional*, defaults to 50257): + The id of the *end-of-sequence* token. + no_timestamps_token_id (`int`, *optional*, defaults to 50363): + The id of the `"<|notimestamps|>"` token. + max_initial_timestamp_index (`int`, *optional*, defaults to 1): + Used to set the maximum value of the initial timestamp. This is used to prevent the model from + predicting timestamps that are too far in the future. 
+ """ + + def __init__(self, generate_config, model_config, decoder_input_length): + self.eos_token_id = generate_config.eos_token_id + self.no_timestamps_token_id = generate_config.no_timestamps_token_id + self.timestamp_begin = generate_config.no_timestamps_token_id + 1 + + self.begin_index = decoder_input_length + 1 # len(generate_config.forced_decoder_ids) + 1 + # if generate_config.forced_decoder_ids[-1][1] == self.no_timestamps_token_id: + # self.begin_index -= 1 + if generate_config.is_multilingual: + # room for language token and task token + self.begin_index += 2 + if hasattr(generate_config, "max_initial_timestamp_index"): + self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index + else: + self.max_initial_timestamp_index = model_config.vocab_size + if self.max_initial_timestamp_index is None: + self.max_initial_timestamp_index = model_config.vocab_size + + def __call__(self, input_ids, scores, cur_len): + # suppress <|notimestamps|> which is handled by without_timestamps + scores = scores.at[:, self.no_timestamps_token_id].set(-float("inf")) + # if input_ids.shape[1] == self.begin_index: + # scores[:, self.timestamp_begin] = 0 + + def handle_pairs(input_ids_k, scores_k): + last_was_timestamp_1 = jax.lax.cond( + (cur_len - self.begin_index) >= 1, + lambda: True, + lambda: False, + ) + last_was_timestamp_2 = jax.lax.cond( + input_ids_k[cur_len - 1] >= self.timestamp_begin, + lambda: True, + lambda: False, + ) + last_was_timestamp = last_was_timestamp_1 * last_was_timestamp_2 + + penultimate_was_timestamp_1 = jax.lax.cond( + (cur_len - self.begin_index) < 2, + lambda: True, + lambda: False, + ) + penultimate_was_timestamp_2 = jax.lax.cond( + input_ids_k[cur_len - 2] >= self.timestamp_begin, + lambda: True, + lambda: False, + ) + penultimate_was_timestamp = penultimate_was_timestamp_1 + penultimate_was_timestamp_2 + + def if_true(): + return jax.lax.cond( + penultimate_was_timestamp > 0, + lambda: scores_k.at[self.timestamp_begin :].set(-float("inf")), + lambda: scores_k.at[: self.eos_token_id].set(-float("inf")), + ) + + return jax.lax.cond(last_was_timestamp, if_true, lambda: scores_k) + + scores = jax.vmap(handle_pairs)(input_ids, scores) + + apply_max_initial_timestamp = jax.lax.cond( + cur_len == self.begin_index, + lambda: True, + lambda: False, + ) + apply_max_initial_timestamp = jax.lax.cond( + self.max_initial_timestamp_index is not None, + lambda: True and apply_max_initial_timestamp, + lambda: False, + ) + + def if_true(): + last_allowed = self.timestamp_begin + self.max_initial_timestamp_index + return scores.at[:, last_allowed + 1 :].set(-float("inf")) + + scores = jnp.where( + apply_max_initial_timestamp == True, + if_true(), + scores, + ) + + # if sum of probability over timestamps is above any other token, sample timestamp + logprobs = jax.nn.log_softmax(scores, axis=-1) + + def handle_cumulative_probs(logprobs_k, scores_k): + timestamp_logprob = jax.nn.logsumexp(logprobs_k[self.timestamp_begin :], axis=-1) + max_text_token_logprob = jnp.max(logprobs_k[: self.timestamp_begin]) + return jnp.where( + timestamp_logprob > max_text_token_logprob, + scores_k.at[: self.timestamp_begin].set(-float("inf")), + scores_k, + ) + + scores = jax.vmap(handle_cumulative_probs)(logprobs, scores) + + return scores diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index a26ff7885e7d76..e07932f5db1ac6 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -272,6 
+272,7 @@ def generate( prng_key: Optional[jnp.ndarray] = None, trace: bool = True, params: Optional[Dict[str, jnp.ndarray]] = None, + logits_processor: Optional[FlaxLogitsProcessorList] = None, **kwargs, ): r""" @@ -324,6 +325,8 @@ def generate( model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs self._validate_model_kwargs(model_kwargs.copy()) + logits_processor = logits_processor if logits_processor is not None else FlaxLogitsProcessorList() + # set init values prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0) @@ -402,7 +405,9 @@ def generate( ) logits_processor = self._get_logits_processor( - generation_config=generation_config, input_ids_seq_length=input_ids_seq_length + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + logits_processor=logits_processor, ) if not generation_config.do_sample and generation_config.num_beams == 1: @@ -479,6 +484,7 @@ def _get_logits_processor( self, generation_config: GenerationConfig, input_ids_seq_length: int, + logits_processor: Optional[FlaxLogitsProcessorList], ) -> FlaxLogitsProcessorList: """ This class returns a [`FlaxLogitsProcessorList`] list object that contains all relevant [`FlaxLogitsProcessor`] @@ -521,9 +527,31 @@ def _get_logits_processor( [input_ids_seq_length + i[0] - 1, i[1]] for i in generation_config.forced_decoder_ids ] processors.append(FlaxForceTokensLogitsProcessor(forced_decoder_ids)) + processors = self._merge_criteria_processor_list(processors, logits_processor) return processors + def _merge_criteria_processor_list( + self, + default_list: FlaxLogitsProcessorList, + custom_list: FlaxLogitsProcessorList, + ) -> FlaxLogitsProcessorList: + if len(custom_list) == 0: + return default_list + for default in default_list: + for custom in custom_list: + if type(custom) is type(default): + object_type = "logits processor" + raise ValueError( + f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to" + f" `generate`, but it has already been created with the values {default}. {default} has been" + " created by passing the corresponding arguments to generate or by the model's config default" + f" values. If you just want to change the default values of {object_type} consider passing" + f" them as arguments to `generate` instead of using a custom {object_type}." 
+ ) + default_list.extend(custom_list) + return default_list + def _greedy_search( self, input_ids: None, diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index a99a49d2f9f649..ba71e2f5b8b3aa 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -28,6 +28,7 @@ from jax import lax from jax.random import PRNGKey +from ...generation.flax_logits_process import FlaxWhisperTimeStampLogitsProcessor from ...modeling_flax_outputs import ( FlaxBaseModelOutput, FlaxBaseModelOutputWithPastAndCrossAttentions, @@ -1341,6 +1342,71 @@ def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_ return outputs + def generate( + self, + input_features, + generation_config=None, + logits_processor=None, + return_timestamps=None, + task=None, + language=None, + is_multilingual=None, + **kwargs + ): + if generation_config is None: + generation_config = self.generation_config + + if return_timestamps is not None: + generation_config.return_timestamps = return_timestamps + + if task is not None: + generation_config.task = task + + if is_multilingual is not None: + generation_config.is_multilingual = is_multilingual + + if language is not None: + generation_config.language = language + + if kwargs is not None and "decoder_input_ids" in kwargs: + decoder_input_length = len(kwargs["decoder_input_ids"]) + else: + decoder_input_length = 1 + + forced_decoder_ids = [] + + if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual: + if hasattr(generation_config, "language"): + forced_decoder_ids.append((1, generation_config.lang_to_id[generation_config.language])) + else: + forced_decoder_ids.append((1, None)) + + if hasattr(generation_config, "task"): + forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task])) + else: + forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"])) + + if ( + hasattr(generation_config, "return_timestamps") and generation_config.return_timestamps + ) or return_timestamps: + logits_processor = [ + FlaxWhisperTimeStampLogitsProcessor(generation_config, self.config, decoder_input_length) + ] + else: + if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id: + idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1 + forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id)) + + if len(forced_decoder_ids) > 0: + generation_config.forced_decoder_ids = forced_decoder_ids + + return super().generate( + input_features, + generation_config, + logits_processor=logits_processor, + **kwargs, + ) + def prepare_inputs_for_generation( self, decoder_input_ids, From c5c3ac1e3ec9613d6870291331c8e867120dc372 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Fri, 3 Feb 2023 14:11:28 -0500 Subject: [PATCH 100/111] flake --- src/transformers/generation/flax_logits_process.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/flax_logits_process.py b/src/transformers/generation/flax_logits_process.py index 1af117aed9daa9..a18bbcd5f1e3d0 100644 --- a/src/transformers/generation/flax_logits_process.py +++ b/src/transformers/generation/flax_logits_process.py @@ -449,7 +449,7 @@ def if_true(): return scores.at[:, last_allowed + 1 :].set(-float("inf")) scores = jnp.where( - apply_max_initial_timestamp == True, + apply_max_initial_timestamp, if_true(), scores, ) 
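
The commits above give the Flax port a high-level generation entry point comparable to the PyTorch and TensorFlow models: FlaxWhisperForConditionalGeneration.generate accepts return_timestamps, task, language and is_multilingual, builds the forced decoder ids, and wires in FlaxWhisperTimeStampLogitsProcessor when timestamps are requested. A minimal usage sketch follows; the checkpoint name and the placeholder waveform are illustrative assumptions, and the decode call mirrors the test added later in this series rather than a finalised API.

import functools

import jax
import numpy as np

from transformers import FlaxWhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# Placeholder input: 30 seconds of silence at 16 kHz (replace with real audio).
audio = np.zeros(30 * 16000, dtype=np.float32)
input_features = processor.feature_extractor(raw_speech=audio, return_tensors="jax").input_features

# Baking the static generation kwargs in via functools.partial lets jax.jit trace generate once;
# return_timestamps=True routes decoding through FlaxWhisperTimeStampLogitsProcessor.
generate_fn = jax.jit(functools.partial(model.generate, max_length=448, return_timestamps=True))
generated_ids = generate_fn(input_features)

# output_offsets=True also converts the predicted timestamp tokens into (start, end) offsets.
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True)
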
From 907905f8a8753b3e21b94c6249ea423e8645a3c7 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Fri, 3 Feb 2023 14:18:10 -0500 Subject: [PATCH 101/111] use finfo for attention bias --- src/transformers/models/whisper/modeling_flax_whisper.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/whisper/modeling_flax_whisper.py b/src/transformers/models/whisper/modeling_flax_whisper.py index ba71e2f5b8b3aa..f66a02453d7936 100644 --- a/src/transformers/models/whisper/modeling_flax_whisper.py +++ b/src/transformers/models/whisper/modeling_flax_whisper.py @@ -52,7 +52,6 @@ _CHECKPOINT_FOR_DOC = "openai/whisper-tiny" _CONFIG_FOR_DOC = "WhisperConfig" -_TOKENIZER_FOR_DOC = "WhisperTokenizer" WHISPER_START_DOCSTRING = r""" @@ -266,8 +265,8 @@ def __call__( # attention mask in the form of attention bias attention_bias = lax.select( attention_mask > 0, - jnp.full(attention_mask.shape, 0.0), - jnp.full(attention_mask.shape, -1e4), + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), ) else: attention_bias = None @@ -1132,9 +1131,7 @@ class FlaxWhisperModel(FlaxWhisperPreTrainedModel): module_class = FlaxWhisperModule -append_call_sample_docstring( - FlaxWhisperModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC -) +append_call_sample_docstring(FlaxWhisperModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC) class FlaxWhisperForConditionalGenerationModule(nn.Module): From 9dbcda869eb1e35983846cbbbc8f05cae4e44674 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Fri, 3 Feb 2023 17:01:21 -0500 Subject: [PATCH 102/111] docs --- src/transformers/generation/flax_logits_process.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/generation/flax_logits_process.py b/src/transformers/generation/flax_logits_process.py index a18bbcd5f1e3d0..64567d7afde99a 100644 --- a/src/transformers/generation/flax_logits_process.py +++ b/src/transformers/generation/flax_logits_process.py @@ -361,6 +361,7 @@ class FlaxWhisperTimeStampLogitsProcessor(FlaxLogitsProcessor): r""" Whisper specific Processor. This processor can be used to force a list of tokens. The processor will set their log probs to `inf` so that they are sampled at their corresponding index. + Args: generate_config (`GenerateConfig`): The generate config used to generate the output. 
The following parameters are required: From d36cd2c5200df67e40cbd89584e99751942c8e49 Mon Sep 17 00:00:00 2001 From: Andy Ehrenberg <32784181+andyehrenberg@users.noreply.github.com> Date: Tue, 14 Feb 2023 14:12:02 -0800 Subject: [PATCH 103/111] Update src/transformers/generation/flax_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/generation/flax_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 058720dfccdc09..5147b5482ee10e 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -517,9 +517,8 @@ def _get_logits_processor( else begin_index + 1 ) if generation_config.forced_decoder_ids is not None and len(generation_config.forced_decoder_ids) > 0: - begin_index += generation_config.forced_decoder_ids[-1][ - 0 - ] # generation starts after the last token that is forced + # generation starts after the last token that is forced + begin_index += generation_config.forced_decoder_ids[-1][0] processors.append( FlaxSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index) ) From ab01cfcc0318d552d3a49bed8d2911a99e05d2fb Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Tue, 14 Feb 2023 17:14:28 -0500 Subject: [PATCH 104/111] docs --- src/transformers/generation/flax_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/generation/flax_utils.py b/src/transformers/generation/flax_utils.py index 5147b5482ee10e..5ced61434d30e0 100644 --- a/src/transformers/generation/flax_utils.py +++ b/src/transformers/generation/flax_utils.py @@ -293,6 +293,10 @@ def generate( considerably slower runtime. params (`Dict[str, jnp.ndarray]`, *optional*): Optionally the model parameters can be passed. Can be useful for parallelized generation. + logits_processor (`FlaxLogitsProcessorList `, *optional*): + Custom logits processors that complement the default logits processors built from arguments and + generation config. If a logit processor is passed that is already created with the arguments or a + generation config an error is thrown. This feature is intended for advanced users. kwargs: Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be forwarded to the `forward` function of the model. 
If the model is an encoder-decoder model, encoder From 62d172a4760d868d5149659bcea825dfbcce1a28 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Tue, 14 Feb 2023 17:19:24 -0500 Subject: [PATCH 105/111] add timestamp flax test --- .../whisper/test_modeling_flax_whisper.py | 68 +++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index aab9e3c15fc291..ca652c6b56fda2 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -640,3 +640,71 @@ def test_tiny_en_batched_generation(self): transcript = processor.batch_decode(generated_ids, skip_special_tokens=True) self.assertListEqual(transcript, EXPECTED_TRANSCRIPT) + + @slow + def test_tiny_timestamp_generation(self): + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") + model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny") + + input_speech = np.concatenate(self._load_datasamples(4)) + input_features = processor.feature_extractor( + raw_speech=input_speech, return_tensors="jax" + ).input_features + + generated_ids = model.generate( + input_features, max_length=448, return_timestamps=True + ) + + # fmt: off + EXPECTED_OUTPUT = np.array([50258, 50259, 50359, 50364, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50692, 50692, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50926, 50926, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 51208, 51208, 949, 505, 11, 14138, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51552, 51552, 634, 575, 12525, 22618, 1968, 6144, 35617, 7354, 1292, 6, 589, 307, 534, 10281, 934, 439, 11, 293, 51836, 51836, 50257]) + # fmt: on + + self.assertTrue(np.allclose(generated_ids, EXPECTED_OUTPUT)) + + EXPECTED_TRANSCRIPT = [ + { + "text": ( + " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel. Nor is" + " Mr. Quilter's manner less interesting than his matter. He tells us that at this festive season" + " of the year, with Christmas and roast beef looming before us, similarly drawn from eating and" + " its results occur most readily to the mind. He has grave doubts whether Sir Frederick Latins'" + " work is really Greek after all, and" + ), + "offsets": [ + { + "text": ( + " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." + ), + "timestamp": (0.0, 6.5600000000000005), + }, + { + "text": " Nor is Mr. Quilter's manner less interesting than his matter.", + "timestamp": (6.5600000000000005, 11.24), + }, + { + "text": ( + " He tells us that at this festive season of the year, with Christmas and roast beef" + " looming" + ), + "timestamp": (11.24, 16.88), + }, + { + "text": ( + " before us, similarly drawn from eating and its results occur most readily to the mind." 
+ ), + "timestamp": (16.88, 23.76), + }, + { + "text": ( + " He has grave doubts whether Sir Frederick Latins' work is really Greek after all, and" + ), + "timestamp": (23.76, 29.44), + }, + ], + } + ] + + transcript = processor.batch_decode( + generated_ids, skip_special_tokens=True, output_offsets=True + ) + self.assertEqual(transcript, EXPECTED_TRANSCRIPT) From 455b8bfd061d9e94076f880bfe3875c1b2cf1bd9 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Tue, 14 Feb 2023 17:22:53 -0500 Subject: [PATCH 106/111] jit for timestamps --- tests/models/whisper/test_modeling_flax_whisper.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index ca652c6b56fda2..b11320f9749385 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -651,10 +651,12 @@ def test_tiny_timestamp_generation(self): raw_speech=input_speech, return_tensors="jax" ).input_features - generated_ids = model.generate( - input_features, max_length=448, return_timestamps=True + generate_fn = jax.jit( + functools.partial(model.generate, max_length=448, return_timestamps=True) ) + generated_ids = generate_fn(input_features) + # fmt: off EXPECTED_OUTPUT = np.array([50258, 50259, 50359, 50364, 2221, 13, 2326, 388, 391, 307, 264, 50244, 295, 264, 2808, 5359, 11, 293, 321, 366, 5404, 281, 2928, 702, 14943, 13, 50692, 50692, 6966, 307, 2221, 13, 2326, 388, 391, 311, 9060, 1570, 1880, 813, 702, 1871, 13, 50926, 50926, 634, 5112, 505, 300, 412, 341, 42729, 3196, 295, 264, 1064, 11, 365, 5272, 293, 12904, 9256, 450, 10539, 51208, 51208, 949, 505, 11, 14138, 10117, 490, 3936, 293, 1080, 3542, 5160, 881, 26336, 281, 264, 1575, 13, 51552, 51552, 634, 575, 12525, 22618, 1968, 6144, 35617, 7354, 1292, 6, 589, 307, 534, 10281, 934, 439, 11, 293, 51836, 51836, 50257]) # fmt: on From 89658d00ce342c50c1550221bdaefe7ee1636cd6 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Tue, 14 Feb 2023 17:29:13 -0500 Subject: [PATCH 107/111] formatting --- tests/models/whisper/test_modeling_flax_whisper.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/tests/models/whisper/test_modeling_flax_whisper.py b/tests/models/whisper/test_modeling_flax_whisper.py index b11320f9749385..a102f5d48df0e5 100644 --- a/tests/models/whisper/test_modeling_flax_whisper.py +++ b/tests/models/whisper/test_modeling_flax_whisper.py @@ -647,13 +647,9 @@ def test_tiny_timestamp_generation(self): model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny") input_speech = np.concatenate(self._load_datasamples(4)) - input_features = processor.feature_extractor( - raw_speech=input_speech, return_tensors="jax" - ).input_features + input_features = processor.feature_extractor(raw_speech=input_speech, return_tensors="jax").input_features - generate_fn = jax.jit( - functools.partial(model.generate, max_length=448, return_timestamps=True) - ) + generate_fn = jax.jit(functools.partial(model.generate, max_length=448, return_timestamps=True)) generated_ids = generate_fn(input_features) @@ -706,7 +702,5 @@ def test_tiny_timestamp_generation(self): } ] - transcript = processor.batch_decode( - generated_ids, skip_special_tokens=True, output_offsets=True - ) + transcript = processor.batch_decode(generated_ids, skip_special_tokens=True, output_offsets=True) self.assertEqual(transcript, EXPECTED_TRANSCRIPT) From a75fd038dfd8a937269f77ade0bb0d185d663c0a Mon Sep 
17 00:00:00 2001 From: andyehrenberg Date: Wed, 15 Feb 2023 08:19:33 -0500 Subject: [PATCH 108/111] clean up timestamps processor --- .../generation/flax_logits_process.py | 48 ++++++++----------- 1 file changed, 19 insertions(+), 29 deletions(-) diff --git a/src/transformers/generation/flax_logits_process.py b/src/transformers/generation/flax_logits_process.py index 64567d7afde99a..836d15daf2d1ef 100644 --- a/src/transformers/generation/flax_logits_process.py +++ b/src/transformers/generation/flax_logits_process.py @@ -395,54 +395,44 @@ def __init__(self, generate_config, model_config, decoder_input_length): def __call__(self, input_ids, scores, cur_len): # suppress <|notimestamps|> which is handled by without_timestamps scores = scores.at[:, self.no_timestamps_token_id].set(-float("inf")) - # if input_ids.shape[1] == self.begin_index: - # scores[:, self.timestamp_begin] = 0 def handle_pairs(input_ids_k, scores_k): - last_was_timestamp_1 = jax.lax.cond( - (cur_len - self.begin_index) >= 1, - lambda: True, - lambda: False, + last_was_timestamp = jnp.where( + (cur_len - self.begin_index) >= 1, True, False ) - last_was_timestamp_2 = jax.lax.cond( + last_was_timestamp = jnp.where( input_ids_k[cur_len - 1] >= self.timestamp_begin, - lambda: True, - lambda: False, + True and last_was_timestamp, + False, ) - last_was_timestamp = last_was_timestamp_1 * last_was_timestamp_2 - penultimate_was_timestamp_1 = jax.lax.cond( - (cur_len - self.begin_index) < 2, - lambda: True, - lambda: False, + penultimate_was_timestamp = jnp.where( + (cur_len - self.begin_index) < 2, True, False ) - penultimate_was_timestamp_2 = jax.lax.cond( + penultimate_was_timestamp = jnp.where( input_ids_k[cur_len - 2] >= self.timestamp_begin, - lambda: True, - lambda: False, + True, + penultimate_was_timestamp, ) - penultimate_was_timestamp = penultimate_was_timestamp_1 + penultimate_was_timestamp_2 def if_true(): - return jax.lax.cond( + return jnp.where( penultimate_was_timestamp > 0, - lambda: scores_k.at[self.timestamp_begin :].set(-float("inf")), - lambda: scores_k.at[: self.eos_token_id].set(-float("inf")), + scores_k.at[self.timestamp_begin :].set(-float("inf")), + scores_k.at[: self.eos_token_id].set(-float("inf")), ) - return jax.lax.cond(last_was_timestamp, if_true, lambda: scores_k) + return jnp.where(last_was_timestamp, if_true(), scores_k) scores = jax.vmap(handle_pairs)(input_ids, scores) - apply_max_initial_timestamp = jax.lax.cond( - cur_len == self.begin_index, - lambda: True, - lambda: False, + apply_max_initial_timestamp = jnp.where( + cur_len == self.begin_index, True, False ) - apply_max_initial_timestamp = jax.lax.cond( + apply_max_initial_timestamp = jnp.where( self.max_initial_timestamp_index is not None, - lambda: True and apply_max_initial_timestamp, - lambda: False, + True and apply_max_initial_timestamp, + False, ) def if_true(): From 758d56c89242f1af2d3f85b430e5fc9975cca8cb Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Wed, 15 Feb 2023 08:26:36 -0500 Subject: [PATCH 109/111] formatting --- src/transformers/generation/flax_logits_process.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/transformers/generation/flax_logits_process.py b/src/transformers/generation/flax_logits_process.py index 836d15daf2d1ef..ef245975984878 100644 --- a/src/transformers/generation/flax_logits_process.py +++ b/src/transformers/generation/flax_logits_process.py @@ -397,18 +397,14 @@ def __call__(self, input_ids, scores, cur_len): scores = scores.at[:, 
self.no_timestamps_token_id].set(-float("inf")) def handle_pairs(input_ids_k, scores_k): - last_was_timestamp = jnp.where( - (cur_len - self.begin_index) >= 1, True, False - ) + last_was_timestamp = jnp.where((cur_len - self.begin_index) >= 1, True, False) last_was_timestamp = jnp.where( input_ids_k[cur_len - 1] >= self.timestamp_begin, True and last_was_timestamp, False, ) - penultimate_was_timestamp = jnp.where( - (cur_len - self.begin_index) < 2, True, False - ) + penultimate_was_timestamp = jnp.where((cur_len - self.begin_index) < 2, True, False) penultimate_was_timestamp = jnp.where( input_ids_k[cur_len - 2] >= self.timestamp_begin, True, @@ -426,9 +422,7 @@ def if_true(): scores = jax.vmap(handle_pairs)(input_ids, scores) - apply_max_initial_timestamp = jnp.where( - cur_len == self.begin_index, True, False - ) + apply_max_initial_timestamp = jnp.where(cur_len == self.begin_index, True, False) apply_max_initial_timestamp = jnp.where( self.max_initial_timestamp_index is not None, True and apply_max_initial_timestamp, From f9ac6525b9d2c5f49f3bc2341679c6c4a6539279 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Fri, 17 Feb 2023 13:02:11 -0500 Subject: [PATCH 110/111] remove if_true --- .../generation/flax_logits_process.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/transformers/generation/flax_logits_process.py b/src/transformers/generation/flax_logits_process.py index ef245975984878..4860ae15aafff7 100644 --- a/src/transformers/generation/flax_logits_process.py +++ b/src/transformers/generation/flax_logits_process.py @@ -411,14 +411,15 @@ def handle_pairs(input_ids_k, scores_k): penultimate_was_timestamp, ) - def if_true(): - return jnp.where( + return jnp.where( + last_was_timestamp, + jnp.where( penultimate_was_timestamp > 0, scores_k.at[self.timestamp_begin :].set(-float("inf")), scores_k.at[: self.eos_token_id].set(-float("inf")), - ) - - return jnp.where(last_was_timestamp, if_true(), scores_k) + ), + scores_k, + ) scores = jax.vmap(handle_pairs)(input_ids, scores) @@ -429,13 +430,11 @@ def if_true(): False, ) - def if_true(): - last_allowed = self.timestamp_begin + self.max_initial_timestamp_index - return scores.at[:, last_allowed + 1 :].set(-float("inf")) + last_allowed = self.timestamp_begin + self.max_initial_timestamp_index scores = jnp.where( apply_max_initial_timestamp, - if_true(), + scores.at[:, last_allowed + 1 :].set(-float("inf")), scores, ) From 94a526e52cbf78e6fbacd54f219cc0937ef0cb77 Mon Sep 17 00:00:00 2001 From: andyehrenberg Date: Fri, 17 Feb 2023 13:06:40 -0500 Subject: [PATCH 111/111] cleanup --- src/transformers/generation/flax_logits_process.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/generation/flax_logits_process.py b/src/transformers/generation/flax_logits_process.py index 4860ae15aafff7..c5be0cfca937ce 100644 --- a/src/transformers/generation/flax_logits_process.py +++ b/src/transformers/generation/flax_logits_process.py @@ -379,9 +379,8 @@ def __init__(self, generate_config, model_config, decoder_input_length): self.no_timestamps_token_id = generate_config.no_timestamps_token_id self.timestamp_begin = generate_config.no_timestamps_token_id + 1 - self.begin_index = decoder_input_length + 1 # len(generate_config.forced_decoder_ids) + 1 - # if generate_config.forced_decoder_ids[-1][1] == self.no_timestamps_token_id: - # self.begin_index -= 1 + self.begin_index = decoder_input_length + 1 + if generate_config.is_multilingual: # room for language token and task token 
self.begin_index += 2
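
Because FlaxGenerationMixin.generate now also takes an explicit logits_processor argument (merged with the defaults, and rejected if it duplicates a default processor type), user-defined processors can be passed alongside the built-in Whisper ones. A rough sketch with a toy processor; the class name, the masked token id and the short max_length are arbitrary illustrations rather than part of this patch.

import jax.numpy as jnp
import numpy as np

from transformers import FlaxWhisperForConditionalGeneration, WhisperProcessor
from transformers.generation.flax_logits_process import FlaxLogitsProcessor, FlaxLogitsProcessorList


class FlaxMaskSingleTokenLogitsProcessor(FlaxLogitsProcessor):
    """Toy processor: never sample one particular token id."""

    def __init__(self, token_id: int):
        self.token_id = token_id

    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
        # Same masking idiom used by the Whisper timestamp processor above.
        return scores.at[:, self.token_id].set(-float("inf"))


processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

audio = np.zeros(30 * 16000, dtype=np.float32)  # placeholder waveform at 16 kHz
input_features = processor.feature_extractor(raw_speech=audio, return_tensors="jax").input_features

outputs = model.generate(
    input_features,
    max_length=64,
    logits_processor=FlaxLogitsProcessorList([FlaxMaskSingleTokenLogitsProcessor(token_id=1234)]),
)
print(processor.batch_decode(outputs.sequences, skip_special_tokens=True))
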