From 01768370d89a7da8e89d7d48b249976c46c22198 Mon Sep 17 00:00:00 2001 From: mercury233 Date: Tue, 7 Mar 2023 11:22:03 +0800 Subject: [PATCH 1/2] add always_use_initial_prompt --- whisper/transcribe.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 8e1240bd6..cd1c9727f 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -46,6 +46,7 @@ def transcribe( no_speech_threshold: Optional[float] = 0.6, condition_on_previous_text: bool = True, initial_prompt: Optional[str] = None, + always_use_initial_prompt: bool = False, word_timestamps: bool = False, prepend_punctuations: str = "\"'“¿([{-", append_punctuations: str = "\"'.。,,!!??::”)]}、", @@ -102,6 +103,11 @@ def transcribe( "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns to make it more likely to predict those word correctly. + always_use_initial_prompt: bool + if True, the initial_prompt will be used to all windows, and condition_on_previous_text + will be ignored. Enabling this may make the text more consistent if the audio is long + and you set the initial_prompt properly. + decode_options: dict Keyword arguments to construct `DecodingOptions` instances @@ -275,7 +281,11 @@ def new_segment( segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype) - decode_options["prompt"] = all_tokens[prompt_reset_since:] + if always_use_initial_prompt: + decode_options["prompt"] = initial_prompt_tokens + else: + decode_options["prompt"] = all_tokens[prompt_reset_since:] + result: DecodingResult = decode_with_fallback(mel_segment) tokens = torch.tensor(result.tokens) @@ -530,6 +540,7 @@ def valid_model_name(name): parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations") parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.") parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop") + parser.add_argument("--always_use_initial_prompt", type=str2bool, default=False, help="if True, the initial_prompt will be used to all windows, and condition_on_previous_text will be ignored. Enabling this may make the text more consistent if if the audio is long and you set the initial_prompt properly.") parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default") parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below") From daf8ba8fe1aa2b7020d3630a5afaaa49e565d1bb Mon Sep 17 00:00:00 2001 From: mercury233 Date: Tue, 7 Mar 2023 11:24:38 +0800 Subject: [PATCH 2/2] typo --- whisper/transcribe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index cd1c9727f..674e450de 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -540,7 +540,7 @@ def valid_model_name(name): parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations") parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.") parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop") - parser.add_argument("--always_use_initial_prompt", type=str2bool, default=False, help="if True, the initial_prompt will be used to all windows, and condition_on_previous_text will be ignored. Enabling this may make the text more consistent if if the audio is long and you set the initial_prompt properly.") + parser.add_argument("--always_use_initial_prompt", type=str2bool, default=False, help="if True, the initial_prompt will be used to all windows, and condition_on_previous_text will be ignored. Enabling this may make the text more consistent if the audio is long and you set the initial_prompt properly.") parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default") parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")