From 7400dcf6813ff99b4b97d74a512b2c4012e5c5be Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Sat, 20 Jan 2024 01:32:51 +0000 Subject: [PATCH 01/19] Add temperature to model and clean up scripts --- docs/ContribOperators.md | 4 +- docs/OperatorKernels.md | 4 +- .../transformers/beam_search_parameters.cc | 13 + .../cuda/transformers/beam_search.cc | 1 + .../core/graph/contrib_ops/contrib_defs.cc | 1 + .../transformers/models/whisper/README.md | 9 + .../transformers/models/whisper/benchmark.py | 16 +- .../models/whisper/convert_to_onnx.py | 278 ++++++++++-------- .../models/whisper/requirements.txt | 16 + .../models/whisper/whisper_chain.py | 228 +++++++------- .../models/whisper/whisper_helper.py | 7 +- .../transformers/torch_onnx_export_helper.py | 3 +- .../python/transformers/test_generation.py | 17 +- 13 files changed, 340 insertions(+), 257 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/models/whisper/requirements.txt diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 45c0e6f822ce9..eaa2160c83735 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -5755,7 +5755,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape
-#### Inputs (5 - 14) +#### Inputs (5 - 15)
input_ids : F
@@ -5786,6 +5786,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect allits shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]
extra_decoding_ids (optional) : I
Part of the decoder_input_ids that we need cross qk for it. it is of shape (batch_size, extra_decoding_ids_len).In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) are treated as stop of the extra_decoding_ids for corresponding batch.
+
temperature (optional) : T
+
Temperature value to apply to logits processing during this execution's decoding. Shape is (1)
#### Outputs (1 - 5) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 394bd7ad2abae..d6322de18e7ef 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -499,7 +499,7 @@ Do not modify directly.* |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int64)| |Unique|*in* x:**T**
*out* y:**T**
*out* idx:**tensor(int64)**
*out* counts:**tensor(int64)**|1+|**T** = tensor(float)| -|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float)| +|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*in* temperature:**T**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float)| |WordConvEmbedding|*in* Sequence:**T**
*in* W:**T1**
*in* B:**T1**
*in* C:**T1**
*out* Y:**T1**|1+|**T** = tensor(int32)
**T1** = tensor(float)| | | | | @@ -876,7 +876,7 @@ Do not modify directly.* |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |UnfoldTensor|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float), tensor(float16)| +|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*in* temperature:**T**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float), tensor(float16)| | | | | diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index 3962486d5b5eb..d6c4084d4d28b 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -123,6 +123,19 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { logits_processor = logits_processor_tensor ? static_cast(*logits_processor_tensor->Data()) : 0; ORT_ENFORCE(logits_processor >= 0, "logits_processor shall be a non-negative integer, got ", logits_processor); + + if (this->model_type == IGenerationParameters::kModelTypeWhisper) { + auto* temperature_tensor = context->Input(14); + if (temperature_tensor) { + if (temperature_tensor->IsDataType()) { + temperature = *temperature_tensor->Data(); + } else { + temperature = static_cast(*temperature_tensor->Data()); + } + } else { + temperature = 1.0f; + } + } } void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) { diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc index 2a90e4911f286..08cbb145a6f65 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc @@ -49,6 +49,7 @@ ONNX_OPERATOR_KERNEL_EX( .InputMemoryType(OrtMemTypeCPUInput, 9) // 'attention_mask' needs to be on CPU .InputMemoryType(OrtMemTypeCPUInput, 10) // 'decoder_input_ids' needs to be on CPU .InputMemoryType(OrtMemTypeCPUInput, 11) // 'logits_processor' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 14) // 'temperature' needs to be on CPU .OutputMemoryType(OrtMemTypeCPUOutput, 0) // 'sequences' output on CPU .OutputMemoryType(OrtMemTypeCPUOutput, 1) // 'sequences_scores' output on CPU .TypeConstraint("T", {DataTypeImpl::GetTensorType(), diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 982e8fd834b76..27c968a59eb91 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1231,6 +1231,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, "In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) " "are treated as stop of the extra_decoding_ids for corresponding batch.", "I", OpSchema::Optional) + .Input(14, "temperature", "Temperature value to apply to logits processing during this execution's decoding. Shape is (1)", "T", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I") .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional) .Output(2, "scores", diff --git a/onnxruntime/python/tools/transformers/models/whisper/README.md b/onnxruntime/python/tools/transformers/models/whisper/README.md index 8ff5c8a6e1de0..f2b870fd5c2ad 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/README.md +++ b/onnxruntime/python/tools/transformers/models/whisper/README.md @@ -75,6 +75,15 @@ $ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whis $ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --precision int8 --quantize_embedding_layer ``` +To see all available options +``` +# From source: +$ python3 -m models.whisper.convert_to_onnx --help + +# From wheel: +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx --help +``` + ## Benchmark Whisper Here are some examples of how you can benchmark Whisper across various end-to-end (E2E) implementations. diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py index 759ae6d14f184..df1a6d073fce2 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py @@ -54,6 +54,8 @@ def load_via_numpy(): inputs["decoder_input_ids"] = np.array([args.decoder_input_ids], dtype=np.int32) if args.has_logits_processor: inputs["logits_processor"] = np.array([args.logits_processor], dtype=np.int32) + if args.has_temperature: + inputs["temperature"] = np.array([args.temperature], dtype=np.float32) # Measure time taken to load audio file logger.info(f"Load audio: {args.audio_path}") @@ -163,6 +165,7 @@ def get_model(args: argparse.Namespace): def time_fn(args, fn, inputs): warmup_inputs = inputs[0] if type(inputs) is tuple else inputs benchmark_inputs = inputs[1] if type(inputs) is tuple else inputs + torch_device = torch.device(args.target_device) # Warm up warmup_range = ( @@ -180,7 +183,7 @@ def time_fn(args, fn, inputs): # Benchmark if args.device != "cpu": - torch.cuda.synchronize() + torch.cuda.synchronize(torch_device) start_time = time.time() bench_range = ( @@ -192,7 +195,7 @@ def time_fn(args, fn, inputs): fn(benchmark_inputs) if args.device != "cpu": - torch.cuda.synchronize() + torch.cuda.synchronize(torch_device) end_time = time.time() # Newline print after trange in order to print metrics on new lines without progress bar on same line @@ -500,7 +503,13 @@ def parse_args(): "--logits-processor", type=int, default=1, - help="Type of logits processor to use. See `BeamSearch` in https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/contrib_ops/contrib_defs.cc for details.", + help="Type of logits processor to use. See `WhisperBeamSearch` in https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/contrib_ops/contrib_defs.cc for details.", + ) + parser.add_argument( + "--temperature", + type=float, + default=1.0, + help="Temperature value for generation.", ) # Args for accessing detailed info @@ -581,6 +590,7 @@ def main(): args.has_audio_stream = "audio_stream" in ort_model_inputs setattr(args, "has_decoder_input_ids", "decoder_input_ids" in ort_model_inputs) # noqa: B010 setattr(args, "has_logits_processor", "logits_processor" in ort_model_inputs) # noqa: B010 + setattr(args, "has_temperature", "temperature" in ort_model_inputs) # noqa: B010 if args.decoder_input_ids == []: args.decoder_input_ids = [config.decoder_start_token_id] diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index 50637b772c233..975290277a7b1 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -28,17 +28,25 @@ def parse_arguments(argv=None): parser = argparse.ArgumentParser() - pretrained_models = PRETRAINED_WHISPER_MODELS - parser.add_argument( + conversion_args = parser.add_argument_group("Conversion Process Args") + optional_inputs = parser.add_argument_group("Optional Inputs (for WhisperBeamSearch op)") + optional_outputs = parser.add_argument_group("Optional Outputs (for WhisperBeamSearch op)") + quant_args = parser.add_argument_group("INT8 Quantization Args") + + ################################# + # Conversion options for Whisper + ################################# + + conversion_args.add_argument( "-m", "--model_name_or_path", required=False, default=PRETRAINED_WHISPER_MODELS[0], type=str, - help="Model path, or pretrained model name in the list: " + ", ".join(pretrained_models), + help="Model path, or pretrained model name in the list: " + ", ".join(PRETRAINED_WHISPER_MODELS), ) - parser.add_argument( + conversion_args.add_argument( "--cache_dir", required=False, type=str, @@ -46,7 +54,7 @@ def parse_arguments(argv=None): help="Directory to cache pre-trained models", ) - parser.add_argument( + conversion_args.add_argument( "--output", required=False, type=str, @@ -54,19 +62,24 @@ def parse_arguments(argv=None): help="Output directory", ) - parser.add_argument( + conversion_args.add_argument( "-o", "--optimize_onnx", required=False, action="store_true", help="Use optimizer.py to optimize onnx model", ) - parser.set_defaults(optimize_onnx=False) + conversion_args.set_defaults(optimize_onnx=False) - parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference") - parser.set_defaults(use_gpu=False) + conversion_args.add_argument( + "--use_gpu", + required=False, + action="store_true", + help="Use GPU for model inference", + ) + conversion_args.set_defaults(use_gpu=False) - parser.add_argument( + conversion_args.add_argument( "-p", "--precision", required=False, @@ -76,221 +89,226 @@ def parse_arguments(argv=None): help="Precision of model to run. fp32 for full precision, fp16 for half precision, int8 for quantization", ) - parser.add_argument("--verbose", required=False, action="store_true") - parser.set_defaults(verbose=False) - - parser.add_argument("-e", "--use_external_data_format", required=False, action="store_true") - parser.set_defaults(use_external_data_format=False) - - parser.add_argument( - "-s", - "--use_decoder_start_token", + conversion_args.add_argument( + "--use_int64_inputs", required=False, action="store_true", - help="Use config.decoder_start_token_id. Otherwise, add an extra graph input to \ - the encoder-decoder-init subgraph for decoder_input_ids.", + help="Use int64 instead of int32 for input_ids and attention_mask.", ) - parser.set_defaults(use_decoder_start_token=False) + conversion_args.set_defaults(use_int64_inputs=False) - parser.add_argument( - "-f", - "--use_forced_decoder_ids", + conversion_args.add_argument( + "--disable_auto_mixed_precision", required=False, action="store_true", - help="Use decoder_input_ids as an extra graph input to the beam search op", + help="Use pure fp16 instead of mixed precision", ) - parser.set_defaults(use_forced_decoder_ids=False) + conversion_args.set_defaults(disable_auto_mixed_precision=False) - parser.add_argument( - "-l", - "--use_logits_processor", + conversion_args.add_argument( + "-r", + "--provider", required=False, - action="store_true", - help="Use logits_processor as an extra graph input to enable specific logits processing", + type=str, + default="cpu", + choices=list(PROVIDERS.keys()), + help="Provider to benchmark. Default is CPUExecutionProvider.", ) - parser.set_defaults(use_specific_logits_processor=False) - parser.add_argument( - "-v", - "--use_vocab_mask", + conversion_args.add_argument( + "--verbose", required=False, action="store_true", - help="Use vocab_mask as an extra graph input to enable specific logits processing", + help="Enable verbose logging", ) - parser.set_defaults(use_vocab_mask=False) + conversion_args.set_defaults(verbose=False) - parser.add_argument( - "-u", - "--use_prefix_vocab_mask", + conversion_args.add_argument( + "-e", + "--use_external_data_format", required=False, action="store_true", - help="Use prefix_vocab_mask as an extra graph input to enable specific logits processing", + help="Save weights in external file. Necessary for 'small', 'medium', and 'large' models. Optional for 'tiny' and 'base' models.", ) - parser.set_defaults(use_prefix_vocab_mask=False) + conversion_args.set_defaults(use_external_data_format=False) - parser.add_argument( + conversion_args.add_argument( "-w", "--overwrite", required=False, action="store_true", - help="overwrite existing ONNX model", + help="Overwrite existing ONNX model", ) - parser.set_defaults(overwrite=False) + conversion_args.set_defaults(overwrite=False) - parser.add_argument( - "--disable_auto_mixed_precision", + conversion_args.add_argument( + "--separate_encoder_and_decoder_init", required=False, action="store_true", - help="use pure fp16 instead of mixed precision", + help="Do not merge encoder and decoder init to initialize past KV caches. Output 3 instead of 2 ONNX models.", ) - parser.set_defaults(disable_auto_mixed_precision=False) + conversion_args.set_defaults(separate_encoder_and_decoder_init=False) - parser.add_argument( - "--separate_encoder_and_decoder_init", + conversion_args.add_argument( + "--no_beam_search_op", required=False, action="store_true", - help="Do not merge encode and decoder init. Output 3 instead of 2 onnx models.", + help="Do not produce model with WhisperBeamSearch op, which chains encdecinit and decoder models into one op.", ) - parser.set_defaults(separate_encoder_and_decoder_init=False) + conversion_args.set_defaults(no_beam_search_op=False) - parser.add_argument( - "--use_int64_inputs", + conversion_args.add_argument( + "--state_dict_path", + type=str, + default="", + help="Filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)", + ) + + ############################################################# + # Optional inputs for Whisper + # (listed below in the order that WhisperBeamSearch expects) + ############################################################# + + optional_inputs.add_argument( + "-v", + "--use_vocab_mask", required=False, action="store_true", - help="Use int64 instead of int32 for input_ids, position_ids and attention_mask.", + help="Use vocab_mask as an extra graph input to enable specific logits processing", ) - parser.set_defaults(use_int64_inputs=False) + optional_inputs.set_defaults(use_vocab_mask=False) - parser.add_argument( - "--chain_model", + optional_inputs.add_argument( + "-u", + "--use_prefix_vocab_mask", required=False, action="store_true", - help="Produce beam search model with chained encdecinit and decoder.", + help="Use prefix_vocab_mask as an extra graph input to enable specific logits processing", ) - parser.set_defaults(chain_model=True) - - parser.add_argument( - "--use_whisper_beamsearch", + optional_inputs.set_defaults(use_prefix_vocab_mask=False) + + optional_inputs.add_argument( + "-f", + "--use_forced_decoder_ids", required=False, action="store_true", - help="When chain_model, using WhisperBeamSearch operator rather than BeamSearch operator. \ - It will be set to true when collect_cross_qk, extra_decoding_ids or output_no_speech_probs is set.", + help="Use decoder_input_ids as an extra graph input to the beam search op", ) - parser.set_defaults(use_whisper_beamsearch=False) + optional_inputs.set_defaults(use_forced_decoder_ids=False) - parser.add_argument( - "--extra_decoding_ids", + optional_inputs.add_argument( + "-l", + "--use_logits_processor", required=False, action="store_true", - help="Need extra starting decoding ids for some feature like cross qk. Default if false.", + help="Use logits_processor as an extra graph input to enable specific logits processing", ) - parser.set_defaults(extra_decoding_ids=False) + optional_inputs.set_defaults(use_specific_logits_processor=False) - parser.add_argument( + optional_inputs.add_argument( "--collect_cross_qk", required=False, action="store_true", help="Beam search model collect stacked cross QK.", ) - parser.set_defaults(collect_cross_qk=False) + optional_inputs.set_defaults(collect_cross_qk=False) - parser.add_argument( - "--output_cross_qk", + optional_inputs.add_argument( + "--extra_decoding_ids", required=False, action="store_true", - help="Beam search model output collected qk as output. Also hint collect_cross_qk", + help="Need extra starting decoding ids for some feature like cross qk. Default if false.", ) - parser.set_defaults(output_cross_qk=False) + optional_inputs.set_defaults(extra_decoding_ids=False) - parser.add_argument( - "--no_speech_token_id", - default=50362, + optional_inputs.add_argument( + "-t", + "--use_temperature", + required=False, + action="store_true", + help="Use temperature as an extra graph input for the WhisperBeamSearch op", + ) + optional_inputs.set_defaults(use_temperature=False) + + optional_inputs.add_argument( + "--no_repeat_ngram_size", type=int, - help="specify no_speech_token_id. Default is 50362. if >= 0, will be add into beam search attr. \ - Note that default value maybe different between the multilingual and English-only models.", + default=0, + help="default to 0", ) - parser.add_argument( - "--output_no_speech_probs", + ############################################################# + # Optional outputs for Whisper + # (listed below in the order that WhisperBeamSearch expects) + ############################################################# + + optional_outputs.add_argument( + "--output_sequence_scores", required=False, action="store_true", - help="Beam search model output no speech probs which is computed from the encoder/context-decoder graph.", + help="Beam search model output scores for each generated sequence.", ) - parser.set_defaults(output_no_speech_probs=False) + optional_outputs.set_defaults(output_sequence_scores=False) - parser.add_argument( + optional_outputs.add_argument( "--output_scores", required=False, action="store_true", help="Beam search model output scores over vocab per generated token.", ) - parser.set_defaults(output_scores=False) + optional_outputs.set_defaults(output_scores=False) - parser.add_argument( - "--output_sequence_scores", + optional_outputs.add_argument( + "--output_cross_qk", required=False, action="store_true", - help="Beam search model output scores for each generated sequence.", + help="Beam search model output collected qk as output. Also hint collect_cross_qk", ) - parser.set_defaults(output_sequence_scores=False) + optional_outputs.set_defaults(output_cross_qk=False) - parser.add_argument( + optional_outputs.add_argument( "--cross_qk_onnx_model", required=False, type=str, default=None, - help="the model which consume cross_qk.", + help="The model which consumes cross_qk outputs.", ) - parser.add_argument( - "--beam_output_model", - type=str, - default="whisper_beamsearch.onnx", - help="default name is whisper_beamsearch.onnx.", + optional_outputs.add_argument( + "--output_no_speech_probs", + required=False, + action="store_true", + help="Beam search model output no speech probs which is computed from the encoder/context-decoder graph.", ) + optional_outputs.set_defaults(output_no_speech_probs=False) + + ################################### + # Quantization options for Whisper + ################################### - parser.add_argument( + quant_args.add_argument( "--quantize_embedding_layer", required=False, action="store_true", help="Quantize MatMul, GEMM, and Gather.", ) - parser.set_defaults(quantize_embedding_layer=False) + quant_args.set_defaults(quantize_embedding_layer=False) - parser.add_argument( + quant_args.add_argument( "--quantize_per_channel", required=False, action="store_true", help="Quantize weights per each channel.", ) - parser.set_defaults(quantize_per_channel=False) + quant_args.set_defaults(quantize_per_channel=False) - parser.add_argument( + quant_args.add_argument( "--quantize_reduce_range", required=False, action="store_true", help="Quantize weights with 7 bits.", ) - parser.set_defaults(quantize_reduce_range=False) - - parser.add_argument("--no_repeat_ngram_size", type=int, default=0, help="default to 0") - - parser.add_argument( - "--state_dict_path", - type=str, - default="", - help="filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)", - ) - - parser.add_argument( - "-r", - "--provider", - required=False, - type=str, - default="cpu", - choices=list(PROVIDERS.keys()), - help="Provider to benchmark. Default is CPUExecutionProvider.", - ) + quant_args.set_defaults(quantize_reduce_range=False) args = parser.parse_args(argv) args.collect_cross_qk = args.collect_cross_qk or args.output_cross_qk @@ -307,7 +325,7 @@ def export_onnx_models( optimize_onnx, precision, verbose, - use_decoder_start_token: bool = False, + use_forced_decoder_ids: bool = False, merge_encoder_and_decoder_init: bool = True, overwrite: bool = False, disable_auto_mixed_precision: bool = False, @@ -352,7 +370,7 @@ def export_onnx_models( onnx_path, verbose, use_external_data_format, - use_decoder_input_ids=not use_decoder_start_token, + use_decoder_input_ids=use_forced_decoder_ids, use_int32_inputs=use_int32_inputs, ) else: @@ -396,7 +414,7 @@ def export_onnx_models( extra_options={"MatMulConstBOnly": True}, ) else: - logger.info(f"Skip optimizing: existed ONNX model {onnx_path}") + logger.info(f"Skip optimizing: existing ONNX model {onnx_path}") else: output_path = onnx_path @@ -438,7 +456,7 @@ def main(argv=None): args.optimize_onnx, args.precision, args.verbose, - args.use_decoder_start_token, + args.use_forced_decoder_ids, not args.separate_encoder_and_decoder_init, args.overwrite, args.disable_auto_mixed_precision, @@ -451,7 +469,7 @@ def main(argv=None): ) max_diff = 0 - if args.chain_model: + if not args.no_beam_search_op: logger.info("Chaining model ... :") args.beam_model_output_dir = WhisperHelper.get_onnx_path( output_dir, diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt new file mode 100644 index 0000000000000..720d1f8f21b03 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt @@ -0,0 +1,16 @@ +# NOTE: In addition to the below packages, you will need to install `ffmpeg` on your machine. +# Visit https://ffmpeg.org/ for details. You can also install it natively using package managers. +# Ex: +# Linux: sudo apt-get install ffmpeg +# MacOS: sudo brew install ffmpeg +# Windows: Download from website + +torch>=1.13.0 +transformers>=4.24.0 +openai-whisper +ffmpeg-python +datasets +onnxruntime>=1.15.1 +onnxruntime-extensions>=0.9.0 +protobuf==3.20.2 +numpy==1.23.3 \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index 33958e55f8c38..fde8a362ed6d2 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -9,7 +9,7 @@ update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha, ) from onnx import TensorProto, helper -from transformers import WhisperConfig +from transformers import WhisperConfig, WhisperTokenizer logger = logging.getLogger(__name__) @@ -23,11 +23,22 @@ def verify_inputs(beam_inputs, graph_inputs): assert graph_input.name in beam_input +def clean_list(arr, remove_all_strings=True): + if remove_all_strings: + # Remove all empty strings in list + return list(filter(lambda elm: elm != "", arr)) + + # Remove empty strings at end of list + while len(arr) > 0: + if arr[-1] == "": + arr.pop() + else: + break + return arr + + def chain_model(args): - # Load encoder/decoder and insert necessary (but unused) graph inputs expected by BeamSearch op or WhisperBeamSearch op - args.use_whisper_beamsearch = ( - args.use_whisper_beamsearch or args.collect_cross_qk or args.output_no_speech_probs or args.extra_decoding_ids - ) + # Load encoder/decoder and insert necessary (but unused) graph inputs expected by WhisperBeamSearch op encoder_model = onnx.load_model(args.encoder_path, load_external_data=True) encoder_model.graph.name = "encoderdecoderinit subgraph" @@ -35,7 +46,10 @@ def chain_model(args): decoder_model.graph.name = "decoder subgraph" config = WhisperConfig.from_pretrained(args.model_name_or_path) + tokenizer = WhisperTokenizer.from_pretrained(args.model_name_or_path) + # Create inputs/outputs for WhisperBeamSearch op + temperature_name = "temperature_fp16" if args.precision == Precision.FLOAT16 else "temperature" beam_inputs = [ "input_features_fp16" if args.precision == Precision.FLOAT16 else "input_features", "max_length", @@ -44,37 +58,26 @@ def chain_model(args): "num_return_sequences", "length_penalty_fp16" if args.precision == Precision.FLOAT16 else "length_penalty", "repetition_penalty_fp16" if args.precision == Precision.FLOAT16 else "repetition_penalty", - "vocab_mask" if args.use_prefix_vocab_mask else "", + "vocab_mask" if args.use_vocab_mask else "", "prefix_vocab_mask" if args.use_prefix_vocab_mask else "", "", # attention mask "decoder_input_ids" if args.use_forced_decoder_ids else "", "logits_processor" if args.use_logits_processor else "", + "cross_qk_layer_head" if args.collect_cross_qk else "", + "extra_decoding_ids" if args.extra_decoding_ids else "", + temperature_name if args.use_temperature else "", ] - beam_outputs = ["sequences"] - if args.output_sequence_scores: - beam_outputs.append("sequence_scores") - if args.output_scores: - beam_outputs.append("scores") - - if args.use_whisper_beamsearch: - assert len(beam_inputs) == 12 - beam_inputs.extend( - [ - "cross_qk_layer_head" if args.collect_cross_qk else "", - "extra_decoding_ids" if args.extra_decoding_ids else "", - ] - ) - if args.collect_cross_qk: - while len(beam_outputs) < 3: - beam_outputs.extend([""]) - beam_outputs.extend(["cross_qk"]) - if args.output_no_speech_probs: - while len(beam_outputs) < 4: - beam_outputs.extend([""]) - beam_outputs.extend(["no_speech_probs_beam"]) + beam_outputs = [ + "sequences", + "sequence_scores" if args.output_sequence_scores else "", + "scores" if args.output_scores else "", + "cross_qk" if args.collect_cross_qk else "", + "no_speech_probs_beam" if args.output_no_speech_probs else "", + ] - input_features_cast_node, len_pen_cast_node, rep_pen_cast_node = None, None, None + input_features_cast_node, len_pen_cast_node, rep_pen_cast_node, temp_cast_node = None, None, None, None + graph_nodes = [] if args.precision == Precision.FLOAT16: input_features_cast_node = helper.make_node( "Cast", @@ -97,26 +100,38 @@ def chain_model(args): name="CastRepetitionPenaltyToFp16", to=TensorProto.FLOAT16, ) - - operator_type = "WhisperBeamSearch" if args.use_whisper_beamsearch else "BeamSearch" - node = helper.make_node(operator_type, inputs=beam_inputs, outputs=beam_outputs, name="BeamSearch_zcode") - node.domain = "com.microsoft" - node.attribute.extend( - [ - helper.make_attribute("eos_token_id", config.eos_token_id), - helper.make_attribute("pad_token_id", config.pad_token_id), - helper.make_attribute("decoder_start_token_id", config.decoder_start_token_id), - helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), - helper.make_attribute("early_stopping", True), - helper.make_attribute("model_type", 2), - ] + graph_nodes.extend([input_features_cast_node, len_pen_cast_node, rep_pen_cast_node]) + if args.use_temperature: + temp_cast_node = helper.make_node( + "Cast", + inputs=["temperature"], + outputs=["temperature_fp16"], + name="temperature_to_fp16", + to=TensorProto.FLOAT16, + ) + graph_nodes.append(temp_cast_node) + + # Create WhisperBeamSearch op + beam_search_attrs = [ + helper.make_attribute("eos_token_id", config.eos_token_id), + helper.make_attribute("pad_token_id", config.pad_token_id), + helper.make_attribute("decoder_start_token_id", config.decoder_start_token_id), + helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), + helper.make_attribute("early_stopping", True), + helper.make_attribute("model_type", 2), + helper.make_attribute("decoder_output_cross_qk", 1) if args.collect_cross_qk else "", + helper.make_attribute("no_speech_token", tokenizer.convert_tokens_to_ids(['<|nospeech|>'])[0]) if args.output_no_speech_probs else "", + ] + node = helper.make_node( + "WhisperBeamSearch", + inputs=clean_list(beam_inputs, remove_all_strings=False), + outputs=clean_list(beam_outputs, remove_all_strings=False), + name="BeamSearch", + domain="com.microsoft", ) - if args.use_whisper_beamsearch: - if args.collect_cross_qk: - node.attribute.extend([helper.make_attribute("decoder_output_cross_qk", 1)]) - if args.no_speech_token_id >= 0: - node.attribute.extend([helper.make_attribute("no_speech_token", args.no_speech_token_id)]) + node.attribute.extend(clean_list(beam_search_attrs, remove_all_strings=True)) + # Graph inputs input_features = helper.make_tensor_value_info( "input_features", TensorProto.FLOAT, ["batch_size", "feature_size", "sequence_length"] ) @@ -126,8 +141,23 @@ def chain_model(args): num_return_sequences = helper.make_tensor_value_info("num_return_sequences", TensorProto.INT32, [1]) length_penalty = helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1]) repetition_penalty = helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1]) - - graph_inputs = [ + vocab_mask = helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [config.vocab_size]) + prefix_vocab_mask = helper.make_tensor_value_info( + "prefix_vocab_mask", TensorProto.INT32, ["batch_size", config.vocab_size] + ) + decoder_input_ids = helper.make_tensor_value_info( + "decoder_input_ids", TensorProto.INT32, ["batch_size", "initial_sequence_length"] + ) + logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1]) + cross_qk_layer_head = helper.make_tensor_value_info( + "cross_qk_layer_head", TensorProto.INT32, ["num_layer_head", 2] + ) + extra_decoding_ids = helper.make_tensor_value_info( + "extra_decoding_ids", TensorProto.INT32, ["batch_size", "extra_decoding_ids_len"] + ) + temperature = helper.make_tensor_value_info("temperature", TensorProto.FLOAT, [1]) + + graph_inputs = clean_list([ input_features, max_length, min_length, @@ -135,64 +165,37 @@ def chain_model(args): num_return_sequences, length_penalty, repetition_penalty, - ] - if args.use_vocab_mask: - vocab_mask = helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [config.vocab_size]) - graph_inputs.append(vocab_mask) - - if args.use_prefix_vocab_mask: - prefix_vocab_mask = helper.make_tensor_value_info( - "prefix_vocab_mask", TensorProto.INT32, ["batch_size", config.vocab_size] - ) - graph_inputs.append(prefix_vocab_mask) - - if args.use_forced_decoder_ids: - decoder_input_ids = helper.make_tensor_value_info( - "decoder_input_ids", TensorProto.INT32, ["batch_size", "initial_sequence_length"] - ) - graph_inputs.append(decoder_input_ids) - - if args.use_logits_processor: - logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1]) - graph_inputs.append(logits_processor) - - if args.collect_cross_qk: - cross_qk_layer_head = helper.make_tensor_value_info( - "cross_qk_layer_head", TensorProto.INT32, ["num_layer_head", 2] - ) - graph_inputs.append(cross_qk_layer_head) - - if args.extra_decoding_ids: - extra_decoding_ids = helper.make_tensor_value_info( - "extra_decoding_ids", TensorProto.INT32, ["batch_size", "extra_decoding_ids_len"] - ) - graph_inputs.append(extra_decoding_ids) - - # graph outputs + vocab_mask if args.use_vocab_mask else "", + prefix_vocab_mask if args.use_prefix_vocab_mask else "", + decoder_input_ids if args.use_forced_decoder_ids else "", + logits_processor if args.use_logits_processor else "", + cross_qk_layer_head if args.collect_cross_qk else "", + extra_decoding_ids if args.extra_decoding_ids else "", + temperature if args.use_temperature else "", + ]) + + # Graph outputs sequences = helper.make_tensor_value_info( "sequences", TensorProto.INT32, ["batch_size", "num_return_sequences", "max_length"] ) - graph_outputs = [sequences] - if args.output_cross_qk or (not args.cross_qk_onnx_model and args.collect_cross_qk): - cross_qk = helper.make_tensor_value_info( - "cross_qk", - TensorProto.FLOAT, - ["batch_size", "num_return_sequences", "num_layer_head_cross_qk", "max_length", "frames"], - ) - graph_outputs.extend([cross_qk]) - - if args.output_no_speech_probs: - no_speech_probs = helper.make_tensor_value_info("no_speech_probs", TensorProto.FLOAT, ["batch_size"]) - graph_outputs.extend([no_speech_probs]) - - if args.output_sequence_scores: - sequence_scores = helper.make_tensor_value_info("sequence_scores", TensorProto.FLOAT, ["batch_size"]) - graph_outputs.extend([sequence_scores]) + sequence_scores = helper.make_tensor_value_info("sequence_scores", TensorProto.FLOAT, ["batch_size"]) + scores = helper.make_tensor_value_info("scores", TensorProto.FLOAT, ["batch_size"]) + cross_qk = helper.make_tensor_value_info( + "cross_qk", + TensorProto.FLOAT, + ["batch_size", "num_return_sequences", "num_layer_head_cross_qk", "max_length", "frames"], + ) + no_speech_probs = helper.make_tensor_value_info("no_speech_probs", TensorProto.FLOAT, ["batch_size"]) - if args.output_scores: - scores = helper.make_tensor_value_info("scores", TensorProto.FLOAT, ["batch_size"]) - graph_outputs.extend([scores]) + graph_outputs = clean_list([ + sequences, + sequence_scores if args.output_sequence_scores else "", + scores if args.output_scores else "", + cross_qk if args.output_cross_qk or (not args.cross_qk_onnx_model and args.collect_cross_qk) else "", + no_speech_probs if args.output_no_speech_probs else "", + ]) + # Replace MultiHeadAttention with DecoderMaskedMultiHeadAttention for CUDA EP inference if hasattr(args, "use_gpu") and args.use_gpu: if update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(decoder_model.graph): logger.info("Updated whisper decoder subgraph to use DecoderMaskedMultiHeadAttention successfully!") @@ -213,11 +216,7 @@ def chain_model(args): opset_import = [helper.make_opsetid(domain="com.microsoft", version=1), helper.make_opsetid(domain="", version=17)] - graph_nodes = ( - [input_features_cast_node, len_pen_cast_node, rep_pen_cast_node, node] - if args.precision == Precision.FLOAT16 - else [node] - ) + graph_nodes.append(node) if args.output_no_speech_probs: prob_cast_node = helper.make_node( "Cast", @@ -226,9 +225,16 @@ def chain_model(args): name="no_speech_probs_cast_to_fp32", to=TensorProto.FLOAT, ) - graph_nodes.extend([prob_cast_node]) - - beam_graph = helper.make_graph(graph_nodes, "beam-search-test", graph_inputs, graph_outputs, initializers) + graph_nodes.append(prob_cast_node) + + # Make graph with WhisperBeamSearch op + beam_graph = helper.make_graph( + graph_nodes, + name="WhisperBeamSearch Graph", + inputs=graph_inputs, + outputs=graph_outputs, + initializer=initializers, + ) beam_graph_input_names = [gi.name for gi in graph_inputs] beam_graph_output_names = [go.name for go in graph_outputs] @@ -262,10 +268,12 @@ def chain_model(args): ir_version=decoder_model.ir_version, ) + # Save WhisperBeamSearch graph and external data if os.path.isfile(args.beam_model_output_dir): logger.info(f"Overwriting {args.beam_model_output_dir} and {args.beam_model_output_dir + '.data'}") os.remove(args.beam_model_output_dir) os.remove(args.beam_model_output_dir + ".data") + onnx.save( beam_model, args.beam_model_output_dir, diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index 8c22cd5e745b3..5295d4310dec2 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -29,14 +29,15 @@ PRETRAINED_WHISPER_MODELS = [ "whisper-tiny", "whisper-tiny.en", + "whisper-base", + "whisper-base.en", "whisper-small", "whisper-small.en", "whisper-medium", "whisper-medium.en", - "whisper-base", - "whisper-base.en", "whisper-large", "whisper-large-v2", + "whisper-large-v3", ] @@ -333,6 +334,8 @@ def verify_onnx( inputs[name] = np.array([[0, 0]], dtype=ort_to_np[dtype]) elif name == "extra_decoding_ids": inputs[name] = np.repeat(np.array([[50259, 50359, 50363]], dtype=ort_to_np[dtype]), batch_size, 0) + elif name == "temperature": + inputs[name] = np.array([1.0], dtype=ort_to_np[dtype]) else: inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype]) ort_outputs = ort_session.run(None, inputs)[0][0] diff --git a/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py b/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py index f3e67930adbff..66f24c47f6cdb 100644 --- a/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py +++ b/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py @@ -4,6 +4,7 @@ # -------------------------------------------------------------------------- import torch +from torch._C._onnx import OperatorExportTypes TrainingMode = torch.onnx.TrainingMode from packaging.version import Version # noqa: E402 @@ -18,7 +19,7 @@ def torch_onnx_export( training=TrainingMode.EVAL, input_names=None, output_names=None, - operator_export_type=None, + operator_export_type=OperatorExportTypes.ONNX, opset_version=None, _retain_param_name=None, do_constant_folding=True, diff --git a/onnxruntime/test/python/transformers/test_generation.py b/onnxruntime/test/python/transformers/test_generation.py index c9db1fbc02931..8372e29236068 100644 --- a/onnxruntime/test/python/transformers/test_generation.py +++ b/onnxruntime/test/python/transformers/test_generation.py @@ -380,18 +380,19 @@ def test_logits_processor(self): @pytest.mark.slow def test_cross_qk_overall(self): - decoder_input_ids = [ - "--chain_model", - "--collect_cross_qk", - "--output_cross_qk", - "--use_forced_decoder_ids", - "--extra_decoding_ids", - "--output_no_speech_probs", + cross_qk_input_args = [ "--use_vocab_mask", "--use_prefix_vocab_mask", + "--use_forced_decoder_ids", "--use_logits_processor", + "--collect_cross_qk", + "--extra_decoding_ids", ] - self.run_configs(decoder_input_ids) + cross_qk_output_args = [ + "--output_cross_qk", + "--output_no_speech_probs", + ] + self.run_configs(cross_qk_input_args + cross_qk_output_args) if __name__ == "__main__": From 10cf102062bc4580b80ad4256da470c3f9fb3925 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Fri, 26 Jan 2024 06:24:04 +0000 Subject: [PATCH 02/19] Add task token ids to WhisperBeamSearch --- .../transformers/beam_search_parameters.cc | 2 + .../cpu/transformers/generation_shared.h | 2 + .../cpu/transformers/logits_processor.cc | 7 --- .../cpu/transformers/logits_processor.h | 31 ++++++++-- .../transformers/generation_device_helper.cc | 5 +- .../core/graph/contrib_ops/contrib_defs.cc | 2 + .../models/whisper/convert_to_onnx.py | 1 - .../models/whisper/whisper_chain.py | 2 + .../whisper/whisper_encoder_decoder_init.py | 6 +- .../models/whisper/whisper_helper.py | 58 ++++++++++--------- 10 files changed, 72 insertions(+), 44 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index d6c4084d4d28b..9c28fb6eb3663 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -154,6 +154,8 @@ void WhisperBeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) model_type = static_cast(info.GetAttrOrDefault("model_type", IGenerationParameters::kModelTypeWhisper)); ORT_ENFORCE(model_type == IGenerationParameters::kModelTypeWhisper); + transcribe_token_id = static_cast(info.GetAttrOrDefault("transcribe_token_id", -1LL)); + translate_token_id = static_cast(info.GetAttrOrDefault("translate_token_id", -1LL)); no_speech_token = static_cast(info.GetAttrOrDefault("no_speech_token", -1LL)); cross_qk_layer_head_input_id = 12; extra_decoding_ids_input_id = 13; diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index f6faf2e325f8f..286f71f38dd77 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -180,6 +180,8 @@ struct IGenerationParameters { // Parameters for whisper model bool decoder_output_cross_qk = false; gsl::span extra_decoding_ids; + int32_t transcribe_token_id = -1; + int32_t translate_token_id = -1; int32_t no_speech_token = -1; void* no_speech_probs = nullptr; diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index f39f090c78b0c..5dbdd0a0c5420 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -17,13 +17,6 @@ namespace onnxruntime { namespace contrib { namespace transformers { -#ifdef DEBUG_GENERATION -template -void DumpScores(const char* name, const NextTokenScores& next_token_scores) { - std::cout << name << std::endl; - ORT_UNUSED_PARAMETER(next_token_scores); -} -#endif // Interface for all scorers for beam search or beam sample. template diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 4688ff272cee9..c6646bb746219 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -10,6 +10,7 @@ #include "contrib_ops/cpu/transformers/greedy_search_parameters.h" #include "contrib_ops/cpu/transformers/sampling_parameters.h" #include "contrib_ops/cpu/transformers/generation_shared.h" +#include namespace onnxruntime { namespace contrib { @@ -34,6 +35,14 @@ struct NextTokenScores { } }; +#ifdef DEBUG_GENERATION +template +void DumpScores(const char* name, const NextTokenScores& next_token_scores) { + std::cout << name << std::endl; + ORT_UNUSED_PARAMETER(next_token_scores); +} +#endif + // Interface for all scorers for beam search or beam sample. template class ILogitsProcessor { @@ -150,18 +159,23 @@ class PresencePenaltyLogitsProcessor : public ILogitsProcessor { template class TimestampLogitsProcessor : public ILogitsProcessor { public: - TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index) - : eos_token_id_(eos_token_id), max_initial_timestamp_index_(max_initial_timestamp_index) {} + TimestampLogitsProcessor(int eos_token_id, + int transcribe_token_id, + int translate_token_id, + int max_initial_timestamp_index) + : eos_token_id_(eos_token_id), + transcribe_token_id_(transcribe_token_id), + translate_token_id_(translate_token_id), + max_initial_timestamp_index_(max_initial_timestamp_index) {} void Process(const ISequences* sequences, NextTokenScores& next_token_scores) override { - // TODO: translate_token_id_ and transcribe_token_id_ need to support both multilingual and English-only models. const int beg_token_id_ = eos_token_id_ + 107; const int not_token_id_ = eos_token_id_ + 106; const int solm_token_id_ = eos_token_id_ + 105; const int sot_token_id_ = eos_token_id_ + 1; - constexpr int translate_token_id_ = 50358; - constexpr int transcribe_token_id_ = 50359; + std::cout << "Transcribe token id: " << transcribe_token_id_ << std::endl; + std::cout << "Translate token id: " << translate_token_id_ << std::endl; const int batch_beam_size = next_token_scores.batch_beam_size; const int vocab_size = next_token_scores.vocab_size; @@ -273,6 +287,8 @@ class TimestampLogitsProcessor : public ILogitsProcessor { private: int eos_token_id_; + int transcribe_token_id_; + int translate_token_id_; int max_initial_timestamp_index_; }; @@ -334,7 +350,10 @@ class LogitsProcessorList : public ILogitsProcessorList { // Add timestamp processor for whisper model if (parameters.model_type == IGenerationParameters::kModelTypeWhisper && parameters.logits_processor == IGenerationParameters::kLogitsProcessorTypeWhisper) { constexpr int max_initial_timestamp_index = 50; - timestamp_processor_ = std::make_unique>(parameters.eos_token_id, max_initial_timestamp_index); + timestamp_processor_ = std::make_unique>(parameters.eos_token_id, + parameters.transcribe_token_id, + parameters.translate_token_id, + max_initial_timestamp_index); processor_list_.push_back(timestamp_processor_.get()); } diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 380d561bbb23c..777b6682e9d94 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -469,7 +469,10 @@ Status ProcessLogits(const OrtValue& logits, // cudaMemcpyDeviceToHost, cuda_stream)); constexpr int max_initial_timestamp_index = 50; - onnxruntime::contrib::transformers::TimestampLogitsProcessor time_logit_processor(parameters->eos_token_id, max_initial_timestamp_index); + onnxruntime::contrib::transformers::TimestampLogitsProcessor time_logit_processor(parameters->eos_token_id, + parameters->transcribe_token_id, + parameters->translate_token_id, + max_initial_timestamp_index); onnxruntime::contrib::transformers::NextTokenScores next_token_scores_timestamp({cpu_next_token_scores_span, batch_beam_size, vocab_size}); CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 27c968a59eb91..1c4684c893355 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1189,6 +1189,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, .Attr("eos_token_id", "The id of the end-of-sequence token", AttributeProto::INT) .Attr("pad_token_id", "The id of the padding token", AttributeProto::INT) .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts.", AttributeProto::INT, static_cast(-1)) + .Attr("transcribe_token_id", "The id of the transcribe task", AttributeProto::INT) + .Attr("translate_token_id", "The id of the translate task", AttributeProto::INT) .Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast(0)) .Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast(0)) .Attr("model_type", "Must be 2 for whisper", AttributeProto::INT, static_cast(2)) diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index 975290277a7b1..aec908f36df53 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -370,7 +370,6 @@ def export_onnx_models( onnx_path, verbose, use_external_data_format, - use_decoder_input_ids=use_forced_decoder_ids, use_int32_inputs=use_int32_inputs, ) else: diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index fde8a362ed6d2..71ecf4283dbdc 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -116,6 +116,8 @@ def chain_model(args): helper.make_attribute("eos_token_id", config.eos_token_id), helper.make_attribute("pad_token_id", config.pad_token_id), helper.make_attribute("decoder_start_token_id", config.decoder_start_token_id), + helper.make_attribute("transcribe_token_id", tokenizer.convert_tokens_to_ids(['<|transcribe|>'])[0]), + helper.make_attribute("translate_token_id", tokenizer.convert_tokens_to_ids(['<|translate|>'])[0]), helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), helper.make_attribute("early_stopping", True), helper.make_attribute("model_type", 2), diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py index a145178dbf37e..1c0c0d7be7088 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py @@ -63,7 +63,7 @@ def create_dummy( config: WhisperConfig, batch_size: int, encode_sequence_length: int, - use_decoder_input_ids: int, + use_decoder_input_ids: bool, device: torch.device, use_int32_inputs: bool = False, ): # -> WhisperEncoderDecoderInitInputs: @@ -114,7 +114,7 @@ def export_onnx( model.config, batch_size=2, encode_sequence_length=3000, - use_decoder_input_ids=use_decoder_input_ids, + use_decoder_input_ids=True, device=device, use_int32_inputs=use_int32_inputs, ) @@ -146,7 +146,7 @@ def export_onnx( hidden_size = str(model.config.d_model) head_size = str(model.config.d_model // model.config.encoder_attention_heads) dynamic_axes = { - "encoder_input_ids": {0: "batch_size", 1: "encode_sequence_length"}, + "encoder_input_ids": {0: "batch_size", 1: "feature_size"}, "encoder_hidden_states": { 0: "batch_size", 1: "encode_sequence_length", diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index 5295d4310dec2..e049d90bf4858 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -285,7 +285,12 @@ def verify_onnx( ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features - batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 26, 0, 5, 1 + start_id = [config.decoder_start_token_id] # ex: [50258] + prompt_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe") + prompt_ids = list(map(lambda token: token[1], prompt_ids)) # ex: [50259, 50358, 50363] + forced_decoder_ids = start_id + prompt_ids # ex: [50258, 50259, 50358, 50363] + + batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 30, 0, 1, 1 length_penalty, repetition_penalty = 1.0, 1.0 inputs = { "input_features": input_features.to(device), @@ -322,45 +327,46 @@ def verify_onnx( elif name == "prefix_vocab_mask": inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype]) elif name == "decoder_input_ids": - raw_input_ids = ( - [[config.decoder_start_token_id]] - if use_extra_decoding_ids - else [[config.decoder_start_token_id, 50259, 50359, 50363]] - ) + raw_input_ids = [start_id] if use_extra_decoding_ids else [forced_decoder_ids] inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype]) elif name == "logits_processor": inputs[name] = np.array([1], dtype=ort_to_np[dtype]) elif name == "cross_qk_layer_head": inputs[name] = np.array([[0, 0]], dtype=ort_to_np[dtype]) elif name == "extra_decoding_ids": - inputs[name] = np.repeat(np.array([[50259, 50359, 50363]], dtype=ort_to_np[dtype]), batch_size, 0) + inputs[name] = np.repeat(np.array([prompt_ids], dtype=ort_to_np[dtype]), batch_size, 0) elif name == "temperature": inputs[name] = np.array([1.0], dtype=ort_to_np[dtype]) else: inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype]) ort_outputs = ort_session.run(None, inputs)[0][0] + logger.warning(ort_outputs) - if pt_outputs.shape != ort_outputs.shape: - logger.warning("PyTorch and ONNX Runtime outputs do not have the same shape") + expected_transcription_no_comma = ( + " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel." + ) + expected_transcription_with_comma = ( + " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." + ) + expected_transcription_options = {expected_transcription_no_comma, expected_transcription_with_comma} + pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True)[0] + ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True)[0] - diff = pt_outputs - ort_outputs - max_diff = max(diff.min(), diff.max(), key=abs) + max_diff = 0 + parity = ( + pt_transcription in expected_transcription_no_comma and ort_transcription in expected_transcription_with_comma + ) - if max_diff > 0: - # For ONNX Runtime INT8 model - pt_expected_transcription = ( - " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel." - ) - pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True) - ort_expected_transcription = ( - " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." - ) - ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True) + if pt_outputs.shape != ort_outputs.shape: + logger.warning("PyTorch and ONNX Runtime outputs do not have the same shape") + else: + diff = pt_outputs - ort_outputs[] + max_diff = max(diff.min(), diff.max(), key=abs) - parity = ( - pt_expected_transcription == pt_transcription[0] and ort_expected_transcription == ort_transcription[0] - ) - if parity: - max_diff = 0 + if parity: + max_diff = 0 + else: + logger.warning(f"PyTorch outputs: {pt_transcription}") + logger.warning(f"ONNX Runtime outputs: {ort_transcription}") return max_diff From 9fbca34fcdbf5e9c2e911fbd89bcea98eb993b0c Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Fri, 26 Jan 2024 18:33:55 +0000 Subject: [PATCH 03/19] Fix parity check --- .../models/whisper/whisper_helper.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index e049d90bf4858..07d83c5f269bd 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -340,7 +340,6 @@ def verify_onnx( else: inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype]) ort_outputs = ort_session.run(None, inputs)[0][0] - logger.warning(ort_outputs) expected_transcription_no_comma = ( " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel." @@ -352,20 +351,19 @@ def verify_onnx( pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True)[0] ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True)[0] - max_diff = 0 parity = ( pt_transcription in expected_transcription_no_comma and ort_transcription in expected_transcription_with_comma ) + max_diff = 0 - if pt_outputs.shape != ort_outputs.shape: - logger.warning("PyTorch and ONNX Runtime outputs do not have the same shape") - else: - diff = pt_outputs - ort_outputs[] + if not parity: + if pt_outputs.shape != ort_outputs.shape: + diff = pt_outputs - ort_outputs[:, : len(pt_outputs[0])] + else: + diff = pt_outputs - ort_outputs max_diff = max(diff.min(), diff.max(), key=abs) - if parity: - max_diff = 0 - else: + if max_diff != 0: logger.warning(f"PyTorch outputs: {pt_transcription}") logger.warning(f"ONNX Runtime outputs: {ort_transcription}") From 272258250bb4090c34ad41ce08e88b1b557fc2b9 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 1 Feb 2024 19:37:13 +0000 Subject: [PATCH 04/19] Add packages to requirements file --- .../python/tools/transformers/models/whisper/requirements.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt index 720d1f8f21b03..3fcdbdfd56ea8 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt @@ -10,6 +10,9 @@ transformers>=4.24.0 openai-whisper ffmpeg-python datasets +soundfile +librosa +optimum onnxruntime>=1.15.1 onnxruntime-extensions>=0.9.0 protobuf==3.20.2 From f44e4277d1430266a243878e64ec5b5be7142018 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Mon, 12 Feb 2024 06:29:46 +0000 Subject: [PATCH 05/19] Fix token ids in timestamps processor --- .../contrib_ops/cpu/transformers/logits_processor.h | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index c6646bb746219..7a6b8544bae2d 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -170,12 +170,10 @@ class TimestampLogitsProcessor : public ILogitsProcessor { void Process(const ISequences* sequences, NextTokenScores& next_token_scores) override { - const int beg_token_id_ = eos_token_id_ + 107; - const int not_token_id_ = eos_token_id_ + 106; - const int solm_token_id_ = eos_token_id_ + 105; - const int sot_token_id_ = eos_token_id_ + 1; - std::cout << "Transcribe token id: " << transcribe_token_id_ << std::endl; - std::cout << "Translate token id: " << translate_token_id_ << std::endl; + const int beg_token_id_ = eos_token_id_ + 107; // <|0.00|> + const int not_token_id_ = eos_token_id_ + 106; // <|notimestamps|> + const int solm_token_id_ = eos_token_id_ + 103; // <|startoflm|> + const int sot_token_id_ = eos_token_id_ + 1; // <|startoftranscript|> const int batch_beam_size = next_token_scores.batch_beam_size; const int vocab_size = next_token_scores.vocab_size; From 6c991731990ab30df8013f9bb0d93a5d01abab80 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Mon, 12 Feb 2024 19:36:21 +0000 Subject: [PATCH 06/19] Convert other token ids to attrs --- .../transformers/beam_search_impl_whisper.h | 4 +- .../transformers/beam_search_parameters.cc | 8 +++- .../cpu/transformers/generation_shared.h | 9 ++++- .../cpu/transformers/logits_processor.h | 38 ++++++++++++++----- .../transformers/generation_device_helper.cc | 9 ++++- .../core/graph/contrib_ops/contrib_defs.cc | 15 +++++--- .../models/whisper/whisper_chain.py | 9 +++-- 7 files changed, 65 insertions(+), 27 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h index 72e6d3930a548..af0904b7d6e4b 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h @@ -134,8 +134,8 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe TensorShape no_speech_probs_shape{parameters->batch_size}; Tensor* no_speech_probs = this->context_.Output(parameters->no_speech_probs_output_id, no_speech_probs_shape); if (no_speech_probs && no_speech_probs->MutableData()) { - ORT_ENFORCE(parameters->no_speech_token >= 0 && parameters->no_speech_token < parameters->vocab_size, - "no_speech_token id out of range, it is ", parameters->no_speech_token, + ORT_ENFORCE(parameters->no_speech_token_id >= 0 && parameters->no_speech_token_id < parameters->vocab_size, + "no_speech_token_id is out of range, it is ", parameters->no_speech_token_id, ", vocab_size is ", parameters->vocab_size); this->parameters_->no_speech_probs = (void*)no_speech_probs->MutableData(); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index b8beb98e3d1dd..93837e785b4a4 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -153,9 +153,13 @@ void WhisperBeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) model_type = static_cast(info.GetAttrOrDefault("model_type", IGenerationParameters::kModelTypeWhisper)); ORT_ENFORCE(model_type == IGenerationParameters::kModelTypeWhisper); - transcribe_token_id = static_cast(info.GetAttrOrDefault("transcribe_token_id", -1LL)); + // Token ids are defined below in the order that they appear in the tokenizer translate_token_id = static_cast(info.GetAttrOrDefault("translate_token_id", -1LL)); - no_speech_token = static_cast(info.GetAttrOrDefault("no_speech_token", -1LL)); + transcribe_token_id = static_cast(info.GetAttrOrDefault("transcribe_token_id", -1LL)); + start_of_lm_token_id = static_cast(info.GetAttrOrDefault("start_of_lm_token_id", -1LL)); + no_speech_token_id = static_cast(info.GetAttrOrDefault("no_speech_token_id", -1LL)); + no_timestamps_token_id = static_cast(info.GetAttrOrDefault("no_timestamps_token_id", -1LL)); + beginning_timestamp_token_id = static_cast(info.GetAttrOrDefault("beginning_timestamp_token_id", -1LL)); cross_qk_layer_head_input_id = 12; extra_decoding_ids_input_id = 13; cross_qk_output_id = 3; diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 4c6346afa7118..b1dd55eb20f34 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -183,9 +183,14 @@ struct IGenerationParameters { // Parameters for whisper model bool decoder_output_cross_qk = false; gsl::span extra_decoding_ids; - int32_t transcribe_token_id = -1; + + // Token ids are defined below in the order that they appear in the tokenizer int32_t translate_token_id = -1; - int32_t no_speech_token = -1; + int32_t transcribe_token_id = -1; + int32_t start_of_lm_token_id = -1; + int32_t no_speech_token_id = -1; + int32_t no_timestamps_token_id = -1; + int32_t beginning_timestamp_token_id = -1; void* no_speech_probs = nullptr; int cross_qk_layer_head_input_id = -1; diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 7a6b8544bae2d..2f11e6c5dd4d8 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -159,21 +159,30 @@ class PresencePenaltyLogitsProcessor : public ILogitsProcessor { template class TimestampLogitsProcessor : public ILogitsProcessor { public: - TimestampLogitsProcessor(int eos_token_id, - int transcribe_token_id, - int translate_token_id, + TimestampLogitsProcessor(int eos_token_id, // <|endoftext|> + int sot_token_id, // <|startoftranscript|> + int translate_token_id, // <|translate|> + int transcribe_token_id, // <|transcribe|> + int solm_token_id, // <|startoflm|> + int not_token_id, // <|notimestamps|> + int beg_token_id, // <|0.00|> int max_initial_timestamp_index) : eos_token_id_(eos_token_id), - transcribe_token_id_(transcribe_token_id), + sot_token_id_(sot_token_id), translate_token_id_(translate_token_id), + transcribe_token_id_(transcribe_token_id), + solm_token_id_(solm_token_id), + not_token_id_(not_token_id), + beg_token_id_(beg_token_id), max_initial_timestamp_index_(max_initial_timestamp_index) {} void Process(const ISequences* sequences, NextTokenScores& next_token_scores) override { - const int beg_token_id_ = eos_token_id_ + 107; // <|0.00|> - const int not_token_id_ = eos_token_id_ + 106; // <|notimestamps|> - const int solm_token_id_ = eos_token_id_ + 103; // <|startoflm|> - const int sot_token_id_ = eos_token_id_ + 1; // <|startoftranscript|> + // TODO: convert below token ids to attrs in WhisperBeamSearch because they differ in whisper-tiny vs whisper-large-v3 + // const int beg_token_id_ = eos_token_id_ + 107; // <|0.00|> + // const int not_token_id_ = eos_token_id_ + 106; // <|notimestamps|> + // const int solm_token_id_ = eos_token_id_ + 103; // <|startoflm|> + // const int sot_token_id_ = eos_token_id_ + 1; // <|startoftranscript|> const int batch_beam_size = next_token_scores.batch_beam_size; const int vocab_size = next_token_scores.vocab_size; @@ -285,8 +294,12 @@ class TimestampLogitsProcessor : public ILogitsProcessor { private: int eos_token_id_; - int transcribe_token_id_; + int sot_token_id_; int translate_token_id_; + int transcribe_token_id_; + int solm_token_id_; + int not_token_id_; + int beg_token_id_; int max_initial_timestamp_index_; }; @@ -348,9 +361,14 @@ class LogitsProcessorList : public ILogitsProcessorList { // Add timestamp processor for whisper model if (parameters.model_type == IGenerationParameters::kModelTypeWhisper && parameters.logits_processor == IGenerationParameters::kLogitsProcessorTypeWhisper) { constexpr int max_initial_timestamp_index = 50; + // Token ids are passed below in the order that they appear in the tokenizer timestamp_processor_ = std::make_unique>(parameters.eos_token_id, - parameters.transcribe_token_id, + parameters.decoder_start_token_id, parameters.translate_token_id, + parameters.transcribe_token_id, + parameters.start_of_lm_token_id, + parameters.no_timestamps_token_id, + parameters.beginning_timestamp_token_id, max_initial_timestamp_index); processor_list_.push_back(timestamp_processor_.get()); } diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 2d4d7afd47d69..7adc2fe0a67ea 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -424,7 +424,7 @@ Status ProcessLogits(const OrtValue& logits, // const bool is_whisper_model = (parameters->model_type == onnxruntime::contrib::transformers::IGenerationParameters::kModelTypeWhisper); if (step == 1 && is_whisper_model && parameters->no_speech_probs) { cuda::LaunchSaveNoSpeechProbs( - (T*)parameters->no_speech_probs, Y_data, batch_size, num_beams, vocab_size, parameters->no_speech_token, cuda_stream); + (T*)parameters->no_speech_probs, Y_data, batch_size, num_beams, vocab_size, parameters->no_speech_token_id, cuda_stream); } // NOTE: currently we treat extra decoding ids are same @@ -469,9 +469,14 @@ Status ProcessLogits(const OrtValue& logits, // cudaMemcpyDeviceToHost, cuda_stream)); constexpr int max_initial_timestamp_index = 50; + // Token ids are passed below in the order that they appear in the tokenizer onnxruntime::contrib::transformers::TimestampLogitsProcessor time_logit_processor(parameters->eos_token_id, - parameters->transcribe_token_id, + parameters->decoder_start_token_id, parameters->translate_token_id, + parameters->transcribe_token_id, + parameters->start_of_lm_token_id, + parameters->no_timestamps_token_id, + parameters->beginning_timestamp_token_id, max_initial_timestamp_index); onnxruntime::contrib::transformers::NextTokenScores next_token_scores_timestamp({cpu_next_token_scores_span, batch_beam_size, vocab_size}); diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 1c4684c893355..5b429c22442ff 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1188,9 +1188,15 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, .SetDoc("Beam Search for whisper model, especiall with cross_qk features etc.") .Attr("eos_token_id", "The id of the end-of-sequence token", AttributeProto::INT) .Attr("pad_token_id", "The id of the padding token", AttributeProto::INT) - .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts.", AttributeProto::INT, static_cast(-1)) - .Attr("transcribe_token_id", "The id of the transcribe task", AttributeProto::INT) + .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts (i.e. the start of transcription token id)", AttributeProto::INT, static_cast(-1)) .Attr("translate_token_id", "The id of the translate task", AttributeProto::INT) + .Attr("transcribe_token_id", "The id of the transcribe task", AttributeProto::INT) + .Attr("start_of_lm_token_id", "The id of the token that indicates LM starts", AttributeProto::INT) + .Attr("no_speech_token_id", + "The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.", + AttributeProto::INT, OPTIONAL_VALUE) + .Attr("no_timestamps_token_id", "The id of the token that indicates no timestamps", AttributeProto::INT) + .Attr("beginning_timestamp_token_id", "The id of the first timestamp", AttributeProto::INT) .Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast(0)) .Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast(0)) .Attr("model_type", "Must be 2 for whisper", AttributeProto::INT, static_cast(2)) @@ -1205,9 +1211,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, "If not provided, it will be inferred from the decoder subgraph's output shape", AttributeProto::INT, static_cast(-1)) .Attr("decoder_output_cross_qk", "If nozero, decoder subgraph contains output Q*K from cross attentions. Default 0.", AttributeProto::INT, OPTIONAL_VALUE) - .Attr("no_speech_token", - "The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.", - AttributeProto::INT, OPTIONAL_VALUE) .Input(0, "input_ids", "The sequence used as a prompt for the generation in the encoder subgraph. Shape is (batch_size, sequence_length)", "F") .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) @@ -1248,7 +1251,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, "If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]", "V", OpSchema::Optional) .Output(4, "non_speech_probs", - "For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token." + "For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token_id." "Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph." "The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]", "T", OpSchema::Optional) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index acf0c8f75e2f0..d3552d3e07f15 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -139,14 +139,17 @@ def chain_model(args): beam_search_attrs = [ helper.make_attribute("eos_token_id", config.eos_token_id), helper.make_attribute("pad_token_id", config.pad_token_id), - helper.make_attribute("decoder_start_token_id", config.decoder_start_token_id), - helper.make_attribute("transcribe_token_id", tokenizer.convert_tokens_to_ids(['<|transcribe|>'])[0]), + helper.make_attribute("decoder_start_token_id", config.decoder_start_token_id), # same as tokenizer.convert_tokens_to_ids(['<|startoftranscript|>'])[0] helper.make_attribute("translate_token_id", tokenizer.convert_tokens_to_ids(['<|translate|>'])[0]), + helper.make_attribute("transcribe_token_id", tokenizer.convert_tokens_to_ids(['<|transcribe|>'])[0]), + helper.make_attribute("start_of_lm_token_id", tokenizer.convert_tokens_to_ids(['<|startoflm|>'])[0]), + helper.make_attribute("no_speech_token_id", tokenizer.convert_tokens_to_ids(['<|nospeech|>'])[0]) if args.output_no_speech_probs else "", + helper.make_attribute("no_timestamps_token_id", tokenizer.convert_tokens_to_ids(['<|notimestamps|>'])[0]), + helper.make_attribute("beginning_timestamp_token_id", tokenizer.convert_tokens_to_ids(['<|0.00|>'])[0]), helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), helper.make_attribute("early_stopping", True), helper.make_attribute("model_type", 2), helper.make_attribute("decoder_output_cross_qk", 1) if args.collect_cross_qk else "", - helper.make_attribute("no_speech_token", tokenizer.convert_tokens_to_ids(['<|nospeech|>'])[0]) if args.output_no_speech_probs else "", ] node = helper.make_node( "WhisperBeamSearch", From 3df0429607cb538899a7f379bbfe5bd817313e11 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Tue, 13 Feb 2024 19:06:16 +0000 Subject: [PATCH 07/19] Fix timestamps test case bugs --- .../contrib_ops/cpu/transformers/logits_processor.h | 8 ++++++++ .../transformers/test_whisper_timestamp_processor.py | 4 ++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 2f11e6c5dd4d8..f5f4536cb6241 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -184,6 +184,14 @@ class TimestampLogitsProcessor : public ILogitsProcessor { // const int solm_token_id_ = eos_token_id_ + 103; // <|startoflm|> // const int sot_token_id_ = eos_token_id_ + 1; // <|startoftranscript|> + std::cout << "<|endoftext|> - " << eos_token_id_ << std::endl; + std::cout << "<|startoftranscript|> - " << sot_token_id_ << std::endl; + std::cout << "<|translate|> - " << translate_token_id_ << std::endl; + std::cout << "<|transcribe|> - " << transcribe_token_id_ << std::endl; + std::cout << "<|startoflm|> - " << solm_token_id_ << std::endl; + std::cout << "<|notimestamps|> - " << not_token_id_ << std::endl; + std::cout << "<|0.00|> - " << beg_token_id_ << std::endl; + const int batch_beam_size = next_token_scores.batch_beam_size; const int vocab_size = next_token_scores.vocab_size; for (int i = 0; i < batch_beam_size; i++) { diff --git a/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py b/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py index 77ce09d7e793b..7892000ae45a0 100644 --- a/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py +++ b/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py @@ -50,7 +50,7 @@ def run_timestamp(self, provider: str): ort_out = sess.run(None, ort_inputs) ort_out_tensor = torch.from_numpy(ort_out[0]) ort_transcription = processor.batch_decode( - ort_out_tensor[0][0].view(1, -1), skip_special_tokens=True, output_offsets=True + ort_out_tensor[0][0].view(1, -1), skip_special_tokens=True, output_offsets=True, decode_with_timestamps=True ) print(ort_transcription) expected_transcription = [ @@ -58,7 +58,7 @@ def run_timestamp(self, provider: str): "text": "<|0.00|> Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|5.44|>", "offsets": [ { - "text": "<|0.00|> Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|5.44|>", + "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", "timestamp": (0.0, 5.44), } ], From 29ced94589870e399a35f8b075399646b972946c Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Tue, 13 Feb 2024 19:18:59 +0000 Subject: [PATCH 08/19] Cleaning up comments --- .../contrib_ops/cpu/transformers/logits_processor.h | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index f5f4536cb6241..b04de851c2c93 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -178,19 +178,6 @@ class TimestampLogitsProcessor : public ILogitsProcessor { void Process(const ISequences* sequences, NextTokenScores& next_token_scores) override { - // TODO: convert below token ids to attrs in WhisperBeamSearch because they differ in whisper-tiny vs whisper-large-v3 - // const int beg_token_id_ = eos_token_id_ + 107; // <|0.00|> - // const int not_token_id_ = eos_token_id_ + 106; // <|notimestamps|> - // const int solm_token_id_ = eos_token_id_ + 103; // <|startoflm|> - // const int sot_token_id_ = eos_token_id_ + 1; // <|startoftranscript|> - - std::cout << "<|endoftext|> - " << eos_token_id_ << std::endl; - std::cout << "<|startoftranscript|> - " << sot_token_id_ << std::endl; - std::cout << "<|translate|> - " << translate_token_id_ << std::endl; - std::cout << "<|transcribe|> - " << transcribe_token_id_ << std::endl; - std::cout << "<|startoflm|> - " << solm_token_id_ << std::endl; - std::cout << "<|notimestamps|> - " << not_token_id_ << std::endl; - std::cout << "<|0.00|> - " << beg_token_id_ << std::endl; const int batch_beam_size = next_token_scores.batch_beam_size; const int vocab_size = next_token_scores.vocab_size; From e723bb0c0953ad314905cfcb4844a328d97b1cd3 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Tue, 13 Feb 2024 19:24:50 +0000 Subject: [PATCH 09/19] Add changes suggested by linter --- .../cpu/transformers/logits_processor.cc | 1 - .../cpu/transformers/logits_processor.h | 15 ++-- .../models/whisper/convert_to_onnx.py | 6 +- .../models/whisper/whisper_chain.py | 76 ++++++++++--------- .../models/whisper/whisper_helper.py | 2 +- 5 files changed, 52 insertions(+), 48 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index 5dbdd0a0c5420..9391b39dd0146 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -17,7 +17,6 @@ namespace onnxruntime { namespace contrib { namespace transformers { - // Interface for all scorers for beam search or beam sample. template MinLengthLogitsProcessor::MinLengthLogitsProcessor(int min_length, int eos_token_id) diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index b04de851c2c93..4501cec515db3 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -159,13 +159,13 @@ class PresencePenaltyLogitsProcessor : public ILogitsProcessor { template class TimestampLogitsProcessor : public ILogitsProcessor { public: - TimestampLogitsProcessor(int eos_token_id, // <|endoftext|> - int sot_token_id, // <|startoftranscript|> - int translate_token_id, // <|translate|> - int transcribe_token_id, // <|transcribe|> - int solm_token_id, // <|startoflm|> - int not_token_id, // <|notimestamps|> - int beg_token_id, // <|0.00|> + TimestampLogitsProcessor(int eos_token_id, // <|endoftext|> + int sot_token_id, // <|startoftranscript|> + int translate_token_id, // <|translate|> + int transcribe_token_id, // <|transcribe|> + int solm_token_id, // <|startoflm|> + int not_token_id, // <|notimestamps|> + int beg_token_id, // <|0.00|> int max_initial_timestamp_index) : eos_token_id_(eos_token_id), sot_token_id_(sot_token_id), @@ -178,7 +178,6 @@ class TimestampLogitsProcessor : public ILogitsProcessor { void Process(const ISequences* sequences, NextTokenScores& next_token_scores) override { - const int batch_beam_size = next_token_scores.batch_beam_size; const int vocab_size = next_token_scores.vocab_size; for (int i = 0; i < batch_beam_size; i++) { diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index c1e0032db7932..121a01210d7d7 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -186,7 +186,7 @@ def parse_arguments(argv=None): help="Use prefix_vocab_mask as an extra graph input to enable specific logits processing", ) optional_inputs.set_defaults(use_prefix_vocab_mask=False) - + optional_inputs.add_argument( "-f", "--use_forced_decoder_ids", @@ -238,7 +238,7 @@ def parse_arguments(argv=None): ) ############################################################# - # Optional outputs for Whisper + # Optional outputs for Whisper # (listed below in the order that WhisperBeamSearch expects) ############################################################# @@ -281,7 +281,7 @@ def parse_arguments(argv=None): help="Beam search model output no speech probs which is computed from the encoder/context-decoder graph.", ) optional_outputs.set_defaults(output_no_speech_probs=False) - + ################################### # Quantization options for Whisper ################################### diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index d3552d3e07f15..54293a0178676 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -104,7 +104,7 @@ def chain_model(args): to=TensorProto.FLOAT16, ) graph_nodes.extend([input_features_cast_node, len_pen_cast_node, rep_pen_cast_node]) - + if args.use_temperature: temp_cast_node = helper.make_node( "Cast", @@ -139,13 +139,17 @@ def chain_model(args): beam_search_attrs = [ helper.make_attribute("eos_token_id", config.eos_token_id), helper.make_attribute("pad_token_id", config.pad_token_id), - helper.make_attribute("decoder_start_token_id", config.decoder_start_token_id), # same as tokenizer.convert_tokens_to_ids(['<|startoftranscript|>'])[0] - helper.make_attribute("translate_token_id", tokenizer.convert_tokens_to_ids(['<|translate|>'])[0]), - helper.make_attribute("transcribe_token_id", tokenizer.convert_tokens_to_ids(['<|transcribe|>'])[0]), - helper.make_attribute("start_of_lm_token_id", tokenizer.convert_tokens_to_ids(['<|startoflm|>'])[0]), - helper.make_attribute("no_speech_token_id", tokenizer.convert_tokens_to_ids(['<|nospeech|>'])[0]) if args.output_no_speech_probs else "", - helper.make_attribute("no_timestamps_token_id", tokenizer.convert_tokens_to_ids(['<|notimestamps|>'])[0]), - helper.make_attribute("beginning_timestamp_token_id", tokenizer.convert_tokens_to_ids(['<|0.00|>'])[0]), + helper.make_attribute( + "decoder_start_token_id", config.decoder_start_token_id + ), # same as tokenizer.convert_tokens_to_ids(['<|startoftranscript|>'])[0] + helper.make_attribute("translate_token_id", tokenizer.convert_tokens_to_ids(["<|translate|>"])[0]), + helper.make_attribute("transcribe_token_id", tokenizer.convert_tokens_to_ids(["<|transcribe|>"])[0]), + helper.make_attribute("start_of_lm_token_id", tokenizer.convert_tokens_to_ids(["<|startoflm|>"])[0]), + helper.make_attribute("no_speech_token_id", tokenizer.convert_tokens_to_ids(["<|nospeech|>"])[0]) + if args.output_no_speech_probs + else "", + helper.make_attribute("no_timestamps_token_id", tokenizer.convert_tokens_to_ids(["<|notimestamps|>"])[0]), + helper.make_attribute("beginning_timestamp_token_id", tokenizer.convert_tokens_to_ids(["<|0.00|>"])[0]), helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), helper.make_attribute("early_stopping", True), helper.make_attribute("model_type", 2), @@ -178,30 +182,30 @@ def chain_model(args): "decoder_input_ids", TensorProto.INT32, ["batch_size", "initial_sequence_length"] ) logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1]) - cross_qk_layer_head = helper.make_tensor_value_info( - "cross_qk_layer_head", TensorProto.INT32, ["num_layer_head", 2] - ) + cross_qk_layer_head = helper.make_tensor_value_info("cross_qk_layer_head", TensorProto.INT32, ["num_layer_head", 2]) extra_decoding_ids = helper.make_tensor_value_info( "extra_decoding_ids", TensorProto.INT32, ["batch_size", "extra_decoding_ids_len"] ) temperature = helper.make_tensor_value_info("temperature", TensorProto.FLOAT, [1]) - - graph_inputs = clean_list([ - input_features, - max_length, - min_length, - num_beams, - num_return_sequences, - length_penalty, - repetition_penalty, - vocab_mask if args.use_vocab_mask else "", - prefix_vocab_mask if args.use_prefix_vocab_mask else "", - decoder_input_ids if args.use_forced_decoder_ids else "", - logits_processor if args.use_logits_processor else "", - cross_qk_layer_head if args.collect_cross_qk else "", - extra_decoding_ids if args.extra_decoding_ids else "", - temperature if args.use_temperature else "", - ]) + + graph_inputs = clean_list( + [ + input_features, + max_length, + min_length, + num_beams, + num_return_sequences, + length_penalty, + repetition_penalty, + vocab_mask if args.use_vocab_mask else "", + prefix_vocab_mask if args.use_prefix_vocab_mask else "", + decoder_input_ids if args.use_forced_decoder_ids else "", + logits_processor if args.use_logits_processor else "", + cross_qk_layer_head if args.collect_cross_qk else "", + extra_decoding_ids if args.extra_decoding_ids else "", + temperature if args.use_temperature else "", + ] + ) # Graph outputs sequences = helper.make_tensor_value_info( @@ -216,13 +220,15 @@ def chain_model(args): ) no_speech_probs = helper.make_tensor_value_info("no_speech_probs", TensorProto.FLOAT, ["batch_size"]) - graph_outputs = clean_list([ - sequences, - sequence_scores if args.output_sequence_scores else "", - scores if args.output_scores else "", - cross_qk if args.output_cross_qk or (not args.cross_qk_onnx_model and args.collect_cross_qk) else "", - no_speech_probs if args.output_no_speech_probs else "", - ]) + graph_outputs = clean_list( + [ + sequences, + sequence_scores if args.output_sequence_scores else "", + scores if args.output_scores else "", + cross_qk if args.output_cross_qk or (not args.cross_qk_onnx_model and args.collect_cross_qk) else "", + no_speech_probs if args.output_no_speech_probs else "", + ] + ) # Replace MultiHeadAttention with DecoderMaskedMultiHeadAttention for CUDA EP inference if hasattr(args, "use_gpu") and args.use_gpu: diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index e127dc4a85f72..69800d818155f 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -363,7 +363,7 @@ def verify_onnx( ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True)[0] parity = ( - pt_transcription in expected_transcription_no_comma and ort_transcription in expected_transcription_with_comma + pt_transcription in expected_transcription_options and ort_transcription in expected_transcription_options ) max_diff = 0 From 781cab75a477faf70ae5939510336fa93b17631b Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Tue, 13 Feb 2024 22:32:02 +0000 Subject: [PATCH 10/19] Fix CodeQL warnings --- .../python/tools/transformers/models/whisper/whisper_chain.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index 54293a0178676..6d25fab577143 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -78,8 +78,6 @@ def chain_model(args): "no_speech_probs_beam" if args.output_no_speech_probs else "", ] - input_features_cast_node, len_pen_cast_node, rep_pen_cast_node, temp_cast_node = None, None, None, None - output_sequence_scores_cast_node, output_scores_cast_node = None, None graph_nodes = [] if args.precision == Precision.FLOAT16: input_features_cast_node = helper.make_node( From 6df843003885280874eb43bc807ec0c84692ea17 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Tue, 13 Feb 2024 22:39:06 +0000 Subject: [PATCH 11/19] Add copyright to Whisper scripts --- .../python/tools/transformers/models/whisper/benchmark.py | 6 ++++++ .../tools/transformers/models/whisper/benchmark_all.py | 6 ++++++ .../tools/transformers/models/whisper/whisper_chain.py | 6 ++++++ 3 files changed, 18 insertions(+) diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py index df1a6d073fce2..2b7de46c0aed9 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py @@ -1,3 +1,9 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + import argparse import ast import datetime diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py index d205a2d340721..814b0dd1ef6ac 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py @@ -1,3 +1,9 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + import argparse import datetime import json diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index 6d25fab577143..14691da4ad643 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -1,3 +1,9 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + import logging import os From 75a575bc763c502e1d7c5aad83721db31eccd6f3 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Tue, 13 Feb 2024 23:47:14 +0000 Subject: [PATCH 12/19] Add updated contrib ops doc --- docs/ContribOperators.md | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index e7b537d6894c8..49f91267dada7 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -5743,12 +5743,14 @@ This version of the operator has been available since version 1 of the 'com.micr #### Attributes
+
beginning_timestamp_token_id : int (required)
+
The id of the first timestamp
decoder : graph (required)
Decoder subgraph to execute in a loop.
decoder_output_cross_qk : int
If nozero, decoder subgraph contains output Q*K from cross attentions. Default 0.
decoder_start_token_id : int
-
The id of the token that indicates decoding starts.
+
The id of the token that indicates decoding starts (i.e. the start of transcription token id)
early_stopping : int
early stop or not
encoder : graph
@@ -5761,10 +5763,18 @@ This version of the operator has been available since version 1 of the 'com.micr
Must be 2 for whisper
no_repeat_ngram_size : int
no repeat ngrams size
-
no_speech_token : int
+
no_speech_token_id : int
The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.
+
no_timestamps_token_id : int (required)
+
The id of the token that indicates no timestamps
pad_token_id : int (required)
The id of the padding token
+
start_of_lm_token_id : int (required)
+
The id of the token that indicates LM starts
+
transcribe_token_id : int (required)
+
The id of the transcribe task
+
translate_token_id : int (required)
+
The id of the translate task
vocab_size : int
Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape
@@ -5816,7 +5826,7 @@ This version of the operator has been available since version 1 of the 'com.micr
cross_qk (optional) : V
Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers,B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F].If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]
non_speech_probs (optional) : T
-
For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token.Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph.The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]
+
For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token_id.Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph.The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]
#### Type Constraints From c009bbcff4a95a5cd3d9063bdccfcc2193b0f686 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Wed, 14 Feb 2024 05:47:06 +0000 Subject: [PATCH 13/19] Make token id attributes optional --- onnxruntime/core/graph/contrib_ops/contrib_defs.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 5b429c22442ff..11cd27e0c8b36 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1189,14 +1189,14 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, .Attr("eos_token_id", "The id of the end-of-sequence token", AttributeProto::INT) .Attr("pad_token_id", "The id of the padding token", AttributeProto::INT) .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts (i.e. the start of transcription token id)", AttributeProto::INT, static_cast(-1)) - .Attr("translate_token_id", "The id of the translate task", AttributeProto::INT) - .Attr("transcribe_token_id", "The id of the transcribe task", AttributeProto::INT) - .Attr("start_of_lm_token_id", "The id of the token that indicates LM starts", AttributeProto::INT) + .Attr("translate_token_id", "The id of the translate task", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("transcribe_token_id", "The id of the transcribe task", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("start_of_lm_token_id", "The id of the token that indicates LM starts", AttributeProto::INT, OPTIONAL_VALUE) .Attr("no_speech_token_id", "The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.", AttributeProto::INT, OPTIONAL_VALUE) - .Attr("no_timestamps_token_id", "The id of the token that indicates no timestamps", AttributeProto::INT) - .Attr("beginning_timestamp_token_id", "The id of the first timestamp", AttributeProto::INT) + .Attr("no_timestamps_token_id", "The id of the token that indicates no timestamps", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("beginning_timestamp_token_id", "The id of the first timestamp", AttributeProto::INT, OPTIONAL_VALUE) .Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast(0)) .Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast(0)) .Attr("model_type", "Must be 2 for whisper", AttributeProto::INT, static_cast(2)) From 3c63a51d78d254cca5a9e006bc27b7be72f4db2e Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Wed, 14 Feb 2024 06:29:20 +0000 Subject: [PATCH 14/19] Add updated contrib ops doc --- docs/ContribOperators.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 49f91267dada7..98f1cf4220e65 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -5743,7 +5743,7 @@ This version of the operator has been available since version 1 of the 'com.micr #### Attributes
-
beginning_timestamp_token_id : int (required)
+
beginning_timestamp_token_id : int
The id of the first timestamp
decoder : graph (required)
Decoder subgraph to execute in a loop.
@@ -5765,15 +5765,15 @@ This version of the operator has been available since version 1 of the 'com.micr
no repeat ngrams size
no_speech_token_id : int
The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.
-
no_timestamps_token_id : int (required)
+
no_timestamps_token_id : int
The id of the token that indicates no timestamps
pad_token_id : int (required)
The id of the padding token
-
start_of_lm_token_id : int (required)
+
start_of_lm_token_id : int
The id of the token that indicates LM starts
-
transcribe_token_id : int (required)
+
transcribe_token_id : int
The id of the transcribe task
-
translate_token_id : int (required)
+
translate_token_id : int
The id of the translate task
vocab_size : int
Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape
From b89c13c4f21b447a7e6f663e59fd75744ab0f690 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 15 Feb 2024 02:14:54 +0000 Subject: [PATCH 15/19] Address PR feedback --- .../cpu/transformers/logits_processor.h | 62 +++++++++---------- .../core/graph/contrib_ops/contrib_defs.cc | 27 ++++---- .../transformers/models/whisper/README.md | 34 +++++----- .../transformers/models/whisper/benchmark.py | 2 +- .../models/whisper/requirements.txt | 2 +- 5 files changed, 63 insertions(+), 64 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 1c656bf61b440..231eb17d1a947 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -159,21 +159,21 @@ class PresencePenaltyLogitsProcessor : public ILogitsProcessor { template class TimestampLogitsProcessor : public ILogitsProcessor { public: - TimestampLogitsProcessor(int eos_token_id, // <|endoftext|> - int sot_token_id, // <|startoftranscript|> - int translate_token_id, // <|translate|> - int transcribe_token_id, // <|transcribe|> - int solm_token_id, // <|startoflm|> - int not_token_id, // <|notimestamps|> - int beg_token_id, // <|0.00|> + TimestampLogitsProcessor(int end_of_text_token_id, // <|endoftext|> + int start_of_transcript_token_id, // <|startoftranscript|> + int translate_token_id, // <|translate|> + int transcribe_token_id, // <|transcribe|> + int start_of_lm_token_id, // <|startoflm|> + int no_timestamps_token_id, // <|notimestamps|> + int beginning_timestamp_token_id, // <|0.00|> int max_initial_timestamp_index) - : eos_token_id_(eos_token_id), - sot_token_id_(sot_token_id), + : end_of_text_token_id_(end_of_text_token_id), + start_of_transcript_token_id_(start_of_transcript_token_id), translate_token_id_(translate_token_id), transcribe_token_id_(transcribe_token_id), - solm_token_id_(solm_token_id), - not_token_id_(not_token_id), - beg_token_id_(beg_token_id), + start_of_lm_token_id_(start_of_lm_token_id), + no_timestamps_token_id_(no_timestamps_token_id), + beginning_timestamp_token_id_(beginning_timestamp_token_id), max_initial_timestamp_index_(max_initial_timestamp_index) {} void Process(const ISequences* sequences, @@ -189,7 +189,7 @@ class TimestampLogitsProcessor : public ILogitsProcessor { size_t sample_begin = 0; for (size_t j = 0; j < seq_length; j++) { sample_begin++; - if (sequence[j] >= beg_token_id_) { + if (sequence[j] >= beginning_timestamp_token_id_) { break; } } @@ -197,30 +197,30 @@ class TimestampLogitsProcessor : public ILogitsProcessor { // Suppress tokens for (int j = 0; j < vocab_size; j++) { // Suppress notimestamps and solm tokens - if (j == not_token_id_ || j == solm_token_id_) { + if (j == no_timestamps_token_id_ || j == start_of_lm_token_id_) { beam_token_scores[j] = std::numeric_limits::lowest(); } // Suppress sot, translate and transcribe tokens if (seq_length > sample_begin) { - if (j == sot_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) { + if (j == start_of_transcript_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) { beam_token_scores[j] = std::numeric_limits::lowest(); } } } // Timestamps should be in pair except the first one - const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beg_token_id_; - const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beg_token_id_; + const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beginning_timestamp_token_id_; + const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beginning_timestamp_token_id_; if (last_was_timestamp) { if (penultimate_was_timestamp) { // If timestamps show up in pair, or it's the first timestamp, no more timestamp is generated - for (int j = beg_token_id_; j < vocab_size; j++) { + for (int j = beginning_timestamp_token_id_; j < vocab_size; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } } else { // If timestamp doesn't show up in pair, generate timestamp - for (int j = 0; j < eos_token_id_; j++) { + for (int j = 0; j < end_of_text_token_id_; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } } @@ -229,7 +229,7 @@ class TimestampLogitsProcessor : public ILogitsProcessor { // Find timestamp tokens std::vector timestamps; for (const auto& word_id : sequence) { - if (word_id >= beg_token_id_) { + if (word_id >= beginning_timestamp_token_id_) { timestamps.push_back(word_id); } } @@ -246,13 +246,13 @@ class TimestampLogitsProcessor : public ILogitsProcessor { timestamp_last = timestamps.back() + 1; } - for (int j = beg_token_id_; j < timestamp_last; j++) { + for (int j = beginning_timestamp_token_id_; j < timestamp_last; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } } if (seq_length == sample_begin) { - const int last_allowed = beg_token_id_ + max_initial_timestamp_index_; + const int last_allowed = beginning_timestamp_token_id_ + max_initial_timestamp_index_; for (int j = last_allowed + 1; j < vocab_size; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } @@ -262,8 +262,8 @@ class TimestampLogitsProcessor : public ILogitsProcessor { float timestamp_logprob = std::numeric_limits::lowest(); { float logsumexp = 0.0f; - const float logprob_max = *std::max_element(beam_token_scores.begin() + beg_token_id_, beam_token_scores.end()); - for (int j = beg_token_id_; j < vocab_size; ++j) { + const float logprob_max = *std::max_element(beam_token_scores.begin() + beginning_timestamp_token_id_, beam_token_scores.end()); + for (int j = beginning_timestamp_token_id_; j < vocab_size; ++j) { if (beam_token_scores[j] > std::numeric_limits::lowest()) { logsumexp += expf(beam_token_scores[j] - logprob_max); } @@ -273,9 +273,9 @@ class TimestampLogitsProcessor : public ILogitsProcessor { } } - const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beg_token_id_); + const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beginning_timestamp_token_id_); if (timestamp_logprob > max_text_token_logprob) { - for (int j = 0; j < beg_token_id_; ++j) { + for (int j = 0; j < beginning_timestamp_token_id_; ++j) { beam_token_scores[j] = std::numeric_limits::lowest(); } } @@ -283,13 +283,13 @@ class TimestampLogitsProcessor : public ILogitsProcessor { } private: - int eos_token_id_; - int sot_token_id_; + int end_of_text_token_id_; + int start_of_transcript_token_id_; int translate_token_id_; int transcribe_token_id_; - int solm_token_id_; - int not_token_id_; - int beg_token_id_; + int start_of_lm_token_id_; + int no_timestamps_token_id_; + int beginning_timestamp_token_id_; int max_initial_timestamp_index_; }; diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 11cd27e0c8b36..e33ce20737f80 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1163,7 +1163,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, "Shape is (1,)", "T", OpSchema::Optional) .Input(6, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional) + .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "M", OpSchema::Optional) .Input(8, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional) .Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(10, "decoder_input_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional) @@ -1217,18 +1217,18 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, .Input(3, "num_beams", "Number of beams for beam search. 1 means no beam search. Shape is (1)", "I") .Input(4, "num_return_sequences", "The number of returned sequences in the batch. Shape is (1)", "I") .Input(5, "length_penalty", - "Exponential penalty to the length. Default value 1.0 means no penalty." - "Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences." + "Exponential penalty to the length. Default value 1.0 means no penalty. " + "Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences. " "Shape is (1,)", "T", OpSchema::Optional) .Input(6, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional) + .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "M", OpSchema::Optional) .Input(8, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional) .Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(10, "decoder_input_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional) .Input(11, "logits_processor", "Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)", "I", OpSchema::Optional) .Input(12, "cross_qk_layer_head", - "Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all" + "Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all " "its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]", "I", OpSchema::Optional) .Input(13, "extra_decoding_ids", @@ -1240,20 +1240,19 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I") .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional) .Output(2, "scores", - "Processed beam scores for each vocabulary token at each generation step." - "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam." + "Processed beam scores for each vocabulary token at each generation step. " + "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam. " "Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)", "T", OpSchema::Optional) .Output(3, "cross_qk", "Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, " - "F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers," - "B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]." + "F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers, " + "B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]. " "If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]", "V", OpSchema::Optional) .Output(4, "non_speech_probs", - "For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token_id." - "Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph." - "The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]", + "For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token_id. " + "The shape of non_speech_probs is [B]", "T", OpSchema::Optional) .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain to float tensors.") .TypeConstraint("F", {"tensor(float)", "tensor(int32)", "tensor(float16)"}, "Constrain input type to float or int tensors.") @@ -1327,7 +1326,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GreedySearch, 1, .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) .Input(3, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "I", OpSchema::Optional) + .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "I", OpSchema::Optional) .Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) .Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)", "I") @@ -1368,7 +1367,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1, .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) .Input(3, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "I", OpSchema::Optional) + .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "I", OpSchema::Optional) .Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) .Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(7, "presence_mask", "Presence penalty mask. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) diff --git a/onnxruntime/python/tools/transformers/models/whisper/README.md b/onnxruntime/python/tools/transformers/models/whisper/README.md index 3219697afdf41..cbd1bc8d88c19 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/README.md +++ b/onnxruntime/python/tools/transformers/models/whisper/README.md @@ -10,10 +10,10 @@ There are several ways to export Whisper with beam search (using Whisper tiny as # From source $ git clone https://github.com/microsoft/onnxruntime $ cd onnxruntime/onnxruntime/python/tools/transformers/ -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format # From wheel -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format ``` ### Option 2: end-to-end model from [Olive](https://github.com/microsoft/Olive/tree/main/examples/whisper) @@ -39,49 +39,49 @@ model.save_pretrained(model_name.split("/")[-1] + "-onnx") Here are some additional examples for exporting Whisper with beam search. -Export with Forced Decoder Input Ids +To see all available options ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --use_forced_decoder_ids +$ python3 -m models.whisper.convert_to_onnx --help # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --use_forced_decoder_ids +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx --help ``` -Export + Optimize for FP32 +Export with Forced Decoder Input Ids ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp32 +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --use_forced_decoder_ids # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp32 +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --use_forced_decoder_ids ``` -Export + Optimize for FP16 and GPU +Export + Optimize for FP32 ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision fp32 # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision fp32 ``` -Export + Quantize for INT8 +Export + Optimize for FP16 and GPU ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --precision int8 --quantize_embedding_layer +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --precision int8 --quantize_embedding_layer +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision ``` -To see all available options +Export + Quantize for INT8 ``` # From source: -$ python3 -m models.whisper.convert_to_onnx --help +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --precision int8 --quantize_embedding_layer # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx --help +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --precision int8 --quantize_embedding_layer ``` ## Benchmark Whisper diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py index 2b7de46c0aed9..e57385aa6db8f 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py @@ -509,7 +509,7 @@ def parse_args(): "--logits-processor", type=int, default=1, - help="Type of logits processor to use. See `WhisperBeamSearch` in https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/contrib_ops/contrib_defs.cc for details.", + help="Whether to use timestamps logits processor or not (0 for false, 1 for true).", ) parser.add_argument( "--temperature", diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt index 3fcdbdfd56ea8..c81e1f0a7c374 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt @@ -13,7 +13,7 @@ datasets soundfile librosa optimum -onnxruntime>=1.15.1 +onnxruntime-gpu>=1.17.0 onnxruntime-extensions>=0.9.0 protobuf==3.20.2 numpy==1.23.3 \ No newline at end of file From 55787452a24678ae0d3b0e455d09386ece74b613 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 15 Feb 2024 03:00:06 +0000 Subject: [PATCH 16/19] Add updated docs and fix linter warnings --- docs/ContribOperators.md | 18 +++++++++--------- docs/OperatorKernels.md | 4 ++-- .../models/whisper/whisper_helper.py | 11 ++++------- 3 files changed, 15 insertions(+), 18 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 98f1cf4220e65..f523e97293427 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -461,7 +461,7 @@ This version of the operator has been available since version 1 of the 'com.micr
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : M
-
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : M
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -2252,7 +2252,7 @@ This version of the operator has been available since version 1 of the 'com.micr
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : I
-
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : I
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -5154,7 +5154,7 @@ This version of the operator has been available since version 1 of the 'com.micr
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : I
-
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : I
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -5793,11 +5793,11 @@ This version of the operator has been available since version 1 of the 'com.micr
num_return_sequences : I
The number of returned sequences in the batch. Shape is (1)
length_penalty (optional) : T
-
Exponential penalty to the length. Default value 1.0 means no penalty.Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences.Shape is (1,)
+
Exponential penalty to the length. Default value 1.0 means no penalty. Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences. Shape is (1,)
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : M
-
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : M
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -5807,7 +5807,7 @@ This version of the operator has been available since version 1 of the 'com.micr
logits_processor (optional) : I
Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)
cross_qk_layer_head (optional) : I
-
Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect allits shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]
+
Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]
extra_decoding_ids (optional) : I
Part of the decoder_input_ids that we need cross qk for it. it is of shape (batch_size, extra_decoding_ids_len).In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) are treated as stop of the extra_decoding_ids for corresponding batch.
temperature (optional) : T
@@ -5822,11 +5822,11 @@ This version of the operator has been available since version 1 of the 'com.micr
sequences_scores (optional) : T
Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)
scores (optional) : T
-
Processed beam scores for each vocabulary token at each generation step.Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam.Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)
+
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam. Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)
cross_qk (optional) : V
-
Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers,B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F].If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]
+
Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers, B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]. If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]
non_speech_probs (optional) : T
-
For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token_id.Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph.The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]
+
For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token_id. The shape of non_speech_probs is [B]
#### Type Constraints diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 2ea557b7d61fe..8ff2135c6b1f6 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -765,7 +765,7 @@ Do not modify directly.* |Sigmoid|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |Sign|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|SimplifiedLayerNormalization|*in* X:**T**
*in* scale:**V**
*out* Y:**V**
*out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float)
**V** = tensor(double), tensor(float), tensor(float16)| +|SimplifiedLayerNormalization|*in* X:**T**
*in* scale:**V**
*out* Y:**V**
*out* inv_std_var:**U**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float)
**V** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(double), tensor(float), tensor(float16)| |Size|*in* data:**T**
*out* size:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| @@ -784,7 +784,7 @@ Do not modify directly.* |||[13, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[2, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Sqrt|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|Sqrt|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |Squeeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* squeezed:**T**

or

*in* data:**T**
*out* squeezed:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index 48f907acd9008..0c647d836a20f 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -6,12 +6,14 @@ import logging import os -import sys from pathlib import Path from typing import Dict, Tuple, Union import numpy as np import torch +from float16 import float_to_float16_max_diff +from onnx_model import OnnxModel +from optimizer import optimize_model from packaging import version from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor from transformers import __version__ as transformers_version @@ -21,11 +23,6 @@ from onnxruntime import InferenceSession -sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) -from float16 import float_to_float16_max_diff -from onnx_model import OnnxModel -from optimizer import optimize_model - logger = logging.getLogger(__name__) PRETRAINED_WHISPER_MODELS = [ @@ -342,7 +339,7 @@ def verify_onnx( logger.warning(f"Could not import `datasets`. Attempting to install `datasets` via `{install_cmd}`.") os.system(install_cmd) - from datasets import load_dataset + from datasets import load_dataset # noqa: F811 ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features From 9f9d8839de2a2a32b746e2279d9881670dbc2f34 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 15 Feb 2024 03:18:20 +0000 Subject: [PATCH 17/19] Remove noqa to pass CI lintrunner and ignore local lintrunner failure --- .../python/tools/transformers/models/whisper/whisper_helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index 0c647d836a20f..5565cf4a65250 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -339,7 +339,7 @@ def verify_onnx( logger.warning(f"Could not import `datasets`. Attempting to install `datasets` via `{install_cmd}`.") os.system(install_cmd) - from datasets import load_dataset # noqa: F811 + from datasets import load_dataset ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features From 8bcb5591fa3f0244294510666add18561636f765 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Thu, 15 Feb 2024 23:17:52 +0000 Subject: [PATCH 18/19] Update README and fix export --- .../tools/transformers/models/whisper/README.md | 17 +++++++++++++++++ .../models/whisper/requirements-cpu.txt | 2 ++ .../models/whisper/requirements-cuda.txt | 4 ++++ .../models/whisper/requirements.txt | 8 -------- .../models/whisper/whisper_decoder.py | 2 +- .../test/python/transformers/test_generation.py | 2 +- 6 files changed, 25 insertions(+), 10 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt create mode 100644 onnxruntime/python/tools/transformers/models/whisper/requirements-cuda.txt diff --git a/onnxruntime/python/tools/transformers/models/whisper/README.md b/onnxruntime/python/tools/transformers/models/whisper/README.md index cbd1bc8d88c19..7a678f2734ade 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/README.md +++ b/onnxruntime/python/tools/transformers/models/whisper/README.md @@ -1,5 +1,22 @@ # Whisper +## Prerequisites + +Please note the package versions needed for using Whisper in the `requirements.txt` file that fits your scenario. +- `requirements-cpu.txt` + - For running Whisper on CPU +- `requirements-cuda.txt` + - For running Whisper on CUDA + - Note that `torch` with CUDA enabled is not installed automatically. This is because `torch` should be installed with the CUDA version used on your machine. Please visit [the PyTorch website](https://pytorch.org/get-started/locally/) to download the `torch` version that is used with the CUDA version installed on your machine and satisfies the requirement listed in the file. +- `requirements.txt` + - Package versions needed in each of the above files + +In addition to the above packages, you will need to install `ffmpeg` on your machine. Visit the [FFmpeg website](https://ffmpeg.org/) for details. You can also install it natively using package managers. + +- Linux: `sudo apt-get install ffmpeg` +- MacOS: `sudo brew install ffmpeg` +- Windows: Download from website + ## Exporting Whisper with Beam Search There are several ways to export Whisper with beam search (using Whisper tiny as an example). diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt new file mode 100644 index 0000000000000..db2cd95324328 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt @@ -0,0 +1,2 @@ +-r requirements.txt +onnxruntime>=1.17.1 \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements-cuda.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements-cuda.txt new file mode 100644 index 0000000000000..9bd215de9bc09 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements-cuda.txt @@ -0,0 +1,4 @@ +-r requirements.txt +# Please manually install torch>=1.13.0 with CUDA enabled for the CUDA version installed in your system. +# Instructions can be found here: https://pytorch.org/get-started/locally/ +onnxruntime-gpu>=1.17.1 diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt index c81e1f0a7c374..c307a3665f8a0 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt @@ -1,10 +1,3 @@ -# NOTE: In addition to the below packages, you will need to install `ffmpeg` on your machine. -# Visit https://ffmpeg.org/ for details. You can also install it natively using package managers. -# Ex: -# Linux: sudo apt-get install ffmpeg -# MacOS: sudo brew install ffmpeg -# Windows: Download from website - torch>=1.13.0 transformers>=4.24.0 openai-whisper @@ -13,7 +6,6 @@ datasets soundfile librosa optimum -onnxruntime-gpu>=1.17.0 onnxruntime-extensions>=0.9.0 protobuf==3.20.2 numpy==1.23.3 \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py index 0d69960a095ac..93fd64c9eb7d3 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py @@ -170,7 +170,7 @@ def create_dummy( cross_attention_past_shape = [ batch_size, num_attention_heads, - past_decode_sequence_length, + encode_sequence_length, head_size, ] diff --git a/onnxruntime/test/python/transformers/test_generation.py b/onnxruntime/test/python/transformers/test_generation.py index 3ac0aa50e3ad8..33ec1bd7728fe 100644 --- a/onnxruntime/test/python/transformers/test_generation.py +++ b/onnxruntime/test/python/transformers/test_generation.py @@ -397,7 +397,7 @@ def test_cross_qk_overall(self): @pytest.mark.slow def test_openai_impl_whisper(self): - optional_args = ["--model_impl", "openai", "--chain_model", "--use_whisper_beamsearch"] + optional_args = ["--model_impl", "openai"] self.run_configs(optional_args) From af38716dd4038736eda3457d6742a708fd3c9722 Mon Sep 17 00:00:00 2001 From: Kunal Vaishnavi Date: Fri, 16 Feb 2024 02:42:20 +0000 Subject: [PATCH 19/19] Add another expected transcription --- .../tools/transformers/models/whisper/whisper_helper.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index 5565cf4a65250..1b47b9426d983 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -406,7 +406,14 @@ def verify_onnx( expected_transcription_with_comma = ( " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." ) - expected_transcription_options = {expected_transcription_no_comma, expected_transcription_with_comma} + expected_transcription_with_quote_and_comma = ( + ' "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' + ) + expected_transcription_options = { + expected_transcription_no_comma, + expected_transcription_with_comma, + expected_transcription_with_quote_and_comma, + } pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True)[0] ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True)[0]