-
Notifications
You must be signed in to change notification settings - Fork 27.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add flax whisper implementation #20479
Conversation
Thank you for the PR. However, a pull request should focus on a single objective/goal, rather than changing multiple things at the same time which are not absolutely coupled. Please
The goal of this PR is to add Flax implementation of Whisper. For other changes, it's better to open issue tickets, and if we all agree with the proposals, a PR could proceed :-) Thank you! |
I see a few other instances in this repo where the pytorch implementation computes Happy to remove the changes to the generation stuff and open a separate PR for that - will definitely do this to make flax Whisper generation work! |
I wasn't aware of that inconsistency, thank you for pointing out. This is a good question! But I don't think that's a very serious problem so far - the most important thing is the different frameworks produce the same outputs when feeding the same (supported) inputs + the API on the top model levels being consistent. (The internal computation could be somehow different - if there is good reason) In any case, this could be discussed in an issue and we can proceed with a PR once decided :-) |
BTW, there is some issue for triggering CircleCI. The message is Could not find a usable config.yml, you may have revoked the CircleCI OAuth app.
Please sign out of CircleCI and log back in with your VCS before triggering a new pipeline. Do you use some IDE to push the commits? Could you try to push the commit with a commandline tool or some git GUI tools instead? |
The documentation is not available anymore as the PR was closed or merged. |
Also cc @sanchit-gandhi |
Hey! Thanks for opening the follow PR 🤗 I don't think I agree with @ydshieh here, adding the Will have a look at the PR 😉 |
You are right! I am not aware of those generation features are introduced when you added Whisper @ArthurZucker . Sorry about that, @andyehrenberg ! |
Super excited by this PR! 🚀 Feel free to tag me with questions / review requests as well @andyehrenberg 🤗 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice work there!
Not really think we are gonna push for the scan
methods, but it is debatable. @sgugger correct me if I am wrong
if attention_mask is not None: | ||
if position_ids is None: | ||
position_ids = attention_mask.cumsum(-1) - 1 | ||
if position_ids is None: | ||
batch_size, sequence_length = input_ids.shape | ||
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would be great if we could follow the simple logic that we have in the pytorch version where we use the input_ids
with self.embed_positions(input_ids, past_key_values_length=past_key_values_length
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should stick with computing position_ids
to keep a similar api to the other flax models, and because this better handles the scenario where we have a batch to run generation for with different decoder prompt lengths. The pytorch version ends up just using past_key_values_length
to compute something akin to position_ids
, but we can just use the attention_mask
to figure them out. I'd actually argue we should change the pytorch whisper implementation to use position_ids
, because as it currently stands it'll fail to decode batches of varying decoder prompt lengths - it should take more inspiration from the decoder-only models that compute position_ids
as opposed to the encoder-decoder models that don't assume decoder prefixes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with @andyehrenberg that we should use the Flax implementation here. However, it would be better still in terms of Flax compatibility if this logic went under the decode
's __call__
method, rather than under FlaxWhisperDecoder (as we do in Flax MBart for example)
Also sorry! We just modified Whisper quit a bit 😅 |
@ArthurZucker - Doesn't actually look too bad to catch up with those changes! Can do that soon-ish. I already have a jax timestamp processor that's compilable. |
Oh no - sorry you have to iterate again here @andyehrenberg! Feel free to ping me with any questions / discussions - more than happy to help with the final sprint of the integration! Otherwise super excited to review a final time before merge! 🚀 |
@sanchit-gandhi - I think this is ready for another look - the recent commits (I think) get us to feature parity with the torch version. |
@sanchit-gandhi Bump |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wow! Very clean, thanks a lot for the long work! I just left 1 comment on testing the timestamp
generation but should be good to merge otherwise! cc @sanchit-gandhi
# fmt: on | ||
|
||
transcript = processor.batch_decode(generated_ids, skip_special_tokens=True) | ||
self.assertListEqual(transcript, EXPECTED_TRANSCRIPT) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add the test_tiny_timestamp_generation
where you can test if jit compile produces the correct timestamps?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is just to make sure that the logit processor correctly predicts them. I speak from TF experience, my code worked but when compiling it started failing 😓
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just added - some local sanity checks were working for me under jit compilation at least!
@sanchit-gandhi @ArthurZucker - Addressed Arthur's comments and cleaned up the timestamp logits processor a bit. Hopefully we're close to getting this merged! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice @andyehrenberg! Thanks for iterating here - reviewed the new changes and the PR is looking super clean. Last request from me is if we can avoid defining the if_true()
functions if possible and just add the code explicitly! Good for merge otherwise :)
For sure, made those changes :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks again for your contribution!
* add flax whisper implementation * rever change to setup * remove unused imports * revert generation changes * flax whisper docs * docs * import order * import sorting * isort * add dummy objects * doc formatting * formatting * remove trailing whitespaces * fix flax whisper docs * add generation logic to unlock flax whisper * remove scans * give credits to Flax Bart implementation * remove unused imports * add license * remove assert * more credits to Bart * fix style * formatting * support left padding * add flax whisper generation test * remove copied from comments whenever not a full copy * fix docstrings for logits processors * revert change to FlaxForceTokensLogitsProcessor * revert doc changes * improve generation docs * reorganize * formatting * cleanup docs * add tests * handle empty list case * fix forced decoder ids in flax tests * add flax whisper to inits * upate dummy objects * docs for FlaxAutoModelForSpeechSeq2Seq * fix decoder_position_ids computation in pretrained model decode/__call__ fns * add Copied from statements as necessary * compute position_ids only in __call__ and decode methods of pretrained model subclasses * improve readabilityof compute positional embeddings * check dimensionality of input_features instead of hidden_states * copied from statement for init_cache * formatting * fix copies * fix copies * pass attention mask to encoder layers * fix decoder module outputs * set dtype Co-authored-by: Sanchit Gandhi <[email protected]> * smaller flax model for whisper test * Update src/transformers/generation/flax_utils.py Co-authored-by: Sylvain Gugger <[email protected]> * Update src/transformers/models/whisper/modeling_flax_whisper.py Co-authored-by: Sylvain Gugger <[email protected]> * Update tests/models/whisper/test_modeling_flax_whisper.py Co-authored-by: Sylvain Gugger <[email protected]> * cleanup Co-authored-by: Sylvain Gugger <[email protected]> * Update src/transformers/models/whisper/modeling_flax_whisper.py Co-authored-by: Sylvain Gugger <[email protected]> * bias cleanup * doc fix * align style for force tokens processor * readability * fix input shape in tests * revert FlaxGenerationMixin docstring * formatting * fix tests * fix imports * consistent encoder hidden states * consistent hidden states * input shapes * typo * partial class trick * partial class for input shape * base_class with correct input shape * partial base classes * match by name * set main_input_name * compare on names * formatting * remove unused import * safer position ids computation * safer position id computation * Update src/transformers/models/whisper/modeling_flax_whisper.py Co-authored-by: Sanchit Gandhi <[email protected]> * Update src/transformers/models/whisper/modeling_flax_whisper.py Co-authored-by: Sanchit Gandhi <[email protected]> * remove identical inherited tests * fix prompt ids in tests * use generation config * use jnp array * better var names * more explicit bias use * import transformers * formatting * test formatting * remove unused imports * remove unused imports * formatting * isort * docs * fix ln orders for encoder hidden states * whisper unique generation stuff * flake * use finfo for attention bias * docs * Update src/transformers/generation/flax_utils.py Co-authored-by: Arthur <[email protected]> * docs * add timestamp flax test * jit for timestamps * formatting * clean up timestamps processor * formatting * remove if_true * cleanup --------- Co-authored-by: Sanchit Gandhi <[email protected]> Co-authored-by: Sylvain Gugger <[email protected]> Co-authored-by: Arthur <[email protected]>
Is there any instructions to open the google cloud TPU port, admin? |
Adds Flax whisper implementations, and adjusts flax generation utils to support it.
@ydshieh @ArthurZucker
See discussion in #19512