
Support token suppression, forced tokens (besides eos and bos), and decoder prompting for flax generation #20539

Closed
andyehrenberg opened this issue Dec 1, 2022 · 4 comments

Comments

@andyehrenberg
Contributor

Feature request

Add logits processors for token suppression and for forcing specific tokens at specific generation indices (see the sketch after this list).
Enable prompting the decoder of encoder-decoder models with decoder_input_ids.
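
For concreteness, here is a minimal sketch (in plain JAX, not the actual transformers API) of the logits-processing logic behind both processors; the function names, the force_token_array layout with -1 as a "no constraint" sentinel, and the (scores, cur_len) calling convention are illustrative assumptions:

```python
import jax.numpy as jnp


def suppress_tokens(scores: jnp.ndarray, suppress_token_ids) -> jnp.ndarray:
    """Set the logits of the given token ids to -inf so they can never be sampled."""
    suppress = jnp.asarray(suppress_token_ids)
    return scores.at[:, suppress].set(-jnp.inf)


def force_token(scores: jnp.ndarray, cur_len: int, force_token_array: jnp.ndarray) -> jnp.ndarray:
    """force_token_array[i] holds the token id to force at step i, or -1 for no constraint.
    When a token is forced, every other logit is masked to -inf."""
    token_id = force_token_array[cur_len]
    forced = jnp.full_like(scores, -jnp.inf).at[:, token_id].set(0.0)
    # Branchless select keeps the function usable inside jit / lax.while_loop.
    return jnp.where(token_id >= 0, forced, scores)


# Toy usage: a batch of 2 over a vocab of 8, suppressing tokens 3 and 5
# and forcing token 7 at generation step 1.
scores = jnp.zeros((2, 8))
scores = suppress_tokens(scores, [3, 5])
force_array = jnp.full((4,), -1).at[1].set(7)
scores = force_token(scores, cur_len=1, force_token_array=force_array)
```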

Motivation

Currently, the Flax generation utilities do not support token suppression, forcing specific tokens at specific output indices, or prompting the decoder (helpful for models like Whisper that accept decoder prompts; Flax Whisper is implemented in #20479). Adding these features would move the Flax utilities closer to feature parity with the PyTorch generation utilities and would fully unlock a Flax implementation of Whisper inference.

Your contribution

I already have these features implemented in a branch of my fork - happy to open a PR!

@andyehrenberg
Contributor Author

@sanchit-gandhi
Contributor

Did we decide to implement these features in the Flax Whisper PR in the end? cc @ArthurZucker

@andyehrenberg
Contributor Author

@sanchit-gandhi @ArthurZucker I just added these back into the Flax Whisper PR

@sanchit-gandhi
Contributor

Cool! Closing this issue in favour of PR #20479.
