Whisper Prompting #22395

Closed
sanchit-gandhi opened this issue Mar 27, 2023 · 11 comments · Fixed by #22496

Comments

@sanchit-gandhi
Contributor

sanchit-gandhi commented Mar 27, 2023

Feature request

Add prompting for the Whisper model to control the style/formatting of the generated text.

Motivation

During training, Whisper can be fed a "previous context window" to condition on longer passages of text.

The original OpenAI Whisper implementation gives the user the option of passing an initial_prompt to the model. This prompt replaces the "previous context window" during inference.

By passing the prompt as the "previous context window", the Whisper model conditions its generation on whatever text is passed as the prompt. This allows the user to control aspects of the generation, such as spellings of named entities and punctuation formatting (see openai/whisper#963 (comment)).
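To make the "previous context window" idea concrete, here is a minimal sketch of how the decoder input could be laid out when a prompt is supplied. The token ids are toy values, the helper name `build_decoder_prefix` is hypothetical, and the layout (a `<|startofprev|>` marker, then the prompt, then the usual start-of-transcript sequence) reflects my understanding of the original implementation rather than code taken from it.

```python
from typing import List

# Toy ids for illustration only; real values come from the Whisper tokenizer.
START_OF_PREV = 9001        # stands in for <|startofprev|>
START_OF_TRANSCRIPT = 9002  # stands in for <|startoftranscript|>


def build_decoder_prefix(
    prompt_ids: List[int], sot_sequence: List[int], max_prompt_len: int
) -> List[int]:
    """Place the prompt ahead of the transcript start as 'previous context'."""
    if not prompt_ids or max_prompt_len <= 0:
        # No prompt: the decoder starts from the usual sequence.
        return list(sot_sequence)
    # If the prompt is too long, keep only its most recent tokens.
    kept = prompt_ids[-max_prompt_len:]
    return [START_OF_PREV] + kept + list(sot_sequence)


prefix = build_decoder_prefix(
    prompt_ids=[11, 12, 13],
    sot_sequence=[START_OF_TRANSCRIPT, 21, 22],
    max_prompt_len=2,
)
print(prefix)
```

The model then continues generating after this prefix, so the prompt text influences spelling and formatting without being part of the transcript itself.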

This is possibly a cheaper way of adapting the Whisper model to specific decoding constraints than fine-tuning.

This notebook demonstrates prompting with the original codebase, and explains how this can be achieved for HF's Whisper: https://colab.research.google.com/drive/14FSeaoRvgs5arOTfiMQBnQ5NaLyma7Tq?usp=sharing

The proposed API for prompting would look something like the following:

  1. Encode prompt text to prompt token ids (processor.get_prompt_ids) - this method is a wrapper around processor.tokenizer.__call__ that doesn't add the special token ids:
     prompt = "IR, Newswire"
     prompt_ids = processor.get_prompt_ids(prompt)
  2. Pass the input audio and prompt token ids to the .generate method to get the predicted ids:
     pred_ids = model.generate(input_features, prompt_ids=prompt_ids)
  3. Decode the predicted ids and 'slice' off the prompt (we can do this by passing the prompt_ids):
     pred_str = processor.batch_decode(pred_ids, prompt_ids=prompt_ids)

=> We would need to wrap all of this forced_decoder_ids logic into the generate method and update the processor/tokenizer accordingly.
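As a rough illustration of step 3, the 'slicing' could be sketched on plain lists of token ids as below. This assumes the generated sequences start with exactly the prompt ids that were passed in; the helper names `slice_off_prompt` and `batch_slice_off_prompt` are hypothetical and are not part of the processor or tokenizer.

```python
from typing import List


def slice_off_prompt(pred_ids: List[int], prompt_ids: List[int]) -> List[int]:
    """Drop the prompt from the front of one predicted sequence, if present."""
    n = len(prompt_ids)
    if n and pred_ids[:n] == prompt_ids:
        return pred_ids[n:]
    # Sequence does not start with the prompt: leave it untouched.
    return pred_ids


def batch_slice_off_prompt(
    batch: List[List[int]], prompt_ids: List[int]
) -> List[List[int]]:
    """Apply slice_off_prompt to every sequence in a batch."""
    return [slice_off_prompt(seq, prompt_ids) for seq in batch]


prompt_ids = [11, 12]
pred_ids = [[11, 12, 30, 31], [11, 12, 40]]
print(batch_slice_off_prompt(pred_ids, prompt_ids))
```

Only the ids after the prompt would then be handed to the normal decoding step, so the returned string contains the transcription and not the prompt text.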

Your contribution

Happy to guide the integration and review any PRs!

@sanchit-gandhi
Contributor Author

cc @hollance

@pmollerus23
Contributor

Hello, I'd like to pick up this issue!

@sanchit-gandhi
Contributor Author

Hey @mollerup23! Super cool! We would first need to update the generate modelling code to slide the forced decoder ids as explained in the notebook:

And then add a new method in the tokenizer to ignore the prompt ids. Does this sound good to you?
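One way to read "slide the forced decoder ids" is that every forced position moves back by the number of tokens placed in front of it. The sketch below assumes the forced decoder ids are (position, token id) pairs and uses toy values; `shift_forced_decoder_ids` is a hypothetical helper, not an existing method, and the notebook remains the reference for the actual approach.

```python
from typing import List, Tuple


def shift_forced_decoder_ids(
    forced_decoder_ids: List[Tuple[int, int]], offset: int
) -> List[Tuple[int, int]]:
    """Move each forced position back by `offset` decoder steps."""
    return [(position + offset, token_id) for position, token_id in forced_decoder_ids]


# Toy forced ids: positions 1..3 carry language / task / timestamp tokens.
forced = [(1, 101), (2, 102), (3, 103)]

# Suppose 4 tokens (a marker plus a 3-token prompt) now come first.
shifted = shift_forced_decoder_ids(forced, offset=4)
print(shifted)
```

With the positions shifted this way, the forced tokens still land right after the start-of-transcript token instead of overwriting the prompt.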

@connor-henderson
Contributor

Hey @mollerup23 @sanchit-gandhi. Apologies, I'm not sure how picking these up works, I started working on it cause I saw there was no assignee and now have something I think is ready for review. Should I just keep it locally or push it up?

Totally fine with whatever, @mollerup23 commented first.

@pmollerus23
Contributor

@connor-henderson @sanchit-gandhi I have not yet started on this issue, feel free to push your commits and pick it up!

@pmollerus23
Contributor

I will continue to look into what @sanchit-gandhi mentioned in the meantime.

@connor-henderson
Contributor

Sounds good, thanks

@sanchit-gandhi
Contributor Author

Closed via #22496

@romitjain

Hi @sanchit-gandhi and @connor-henderson
I saw the PR, but I was wondering whether always_use_initial_prompt and condition_on_previous_text were also integrated into the API? If not, is there any active work going towards it?
Thanks

@sanchit-gandhi
Contributor Author

Hey @romitjain - we're working on integrating the OpenAI Whisper algorithm into Transformers, which will provide more support for these fine-grained decoding parameters! c.f. #27492

@M-Ali-ML

> Hey @romitjain - we're working on integrating the OpenAI Whisper algorithm into Transformers, which will provide more support for these fine-grained decoding parameters! c.f. #27492

Are contributions allowed here? I'd like to help with that.
