diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc
index 5f099073792c1..0b2db3f33f4b9 100644
--- a/onnxruntime/core/platform/windows/env.cc
+++ b/onnxruntime/core/platform/windows/env.cc
@@ -32,6 +32,9 @@ limitations under the License.
 #include "core/common/span_utils.h"
 #include "core/platform/env.h"
 #include "core/platform/scoped_resource.h"
+#if defined(_M_X64) && !defined(_M_ARM64EC)
+#include "core/platform/windows/hardware_core_enumerator.h"
+#endif
 #include
 #include
 
@@ -248,12 +251,53 @@ void WindowsEnv::SleepForMicroseconds(int64_t micros) const {
   Sleep(static_cast(micros) / 1000);
 }
 
+// EIGEN_NO_CPUID is not defined in any C/C++ source code. It is a compile option.
+#if defined(_M_X64) && !defined(_M_ARM64EC) && !defined(EIGEN_NO_CPUID)
+static constexpr std::array kVendorID_Intel = {0x756e6547, 0x6c65746e, 0x49656e69};  // "GenuntelineI"
+#endif
 int WindowsEnv::DefaultNumCores() {
   return std::max(1, static_cast<int>(std::thread::hardware_concurrency() / 2));
 }
 
 int WindowsEnv::GetNumPhysicalCpuCores() const {
-  return cores_.empty() ? DefaultNumCores() : static_cast<int>(cores_.size());
+// EIGEN_NO_CPUID is not defined in any C/C++ source code. It is a compile option.
+#if defined(_M_X64) && !defined(_M_ARM64EC) && !defined(EIGEN_NO_CPUID)
+  // The following code is a temporary fix for a perf problem on Intel's Meteor Lake CPUs. That compute platform has
+  // a hybrid architecture in which some CPU cores run significantly slower than the others. If we distribute our work
+  // evenly across all CPU cores, the slowest cores will drag the overall performance down. So, instead, we reduce the
+  // total number of threads so that the slowest cores are excluded.
+  // The following code is based on the assumptions that:
+  // 1. All Intel hybrid CPUs should have 3 levels of cache.
+  // 2. If a CPU core is only associated with two levels of cache, it should be a low-performance CPU core and should
+  //    not be used.
+  // Since we don't know what the next Intel hybrid CPU will look like, we may need to rework the following code later.
+  // However, no matter what, the code should not cause any crash. The worst case is that it returns 1 so that
+  // thread pools will not be created, which is only a perf issue and does not impact usability.
+  // TODO: detect if the CPUID instruction is available per the instructions at https://wiki.osdev.org/CPUID#Checking_CPUID_availability
+  int regs[4];
+  __cpuid(regs, 0);
+  bool bIsIntel =
+      (kVendorID_Intel[0] == regs[1]) &&
+      (kVendorID_Intel[1] == regs[2]) &&
+      (kVendorID_Intel[2] == regs[3]);
+  if (bIsIntel && regs[0] >= 7) {
+    // Query the Structured Extended Feature Flags Enumeration Leaf
+    __cpuid(regs, 0x7);
+    // Bit 15 of EDX indicates whether the processor is identified as a hybrid part.
+    bool ishybrid = regs[3] & (1 << 15);
+    if (ishybrid) {
+      // NOTE: even if ishybrid is true, it doesn't mean the processor must have P-cores and E-cores.
+      // On Intel CPUs we assume the HardwareCoreEnumerator::DefaultIntraOpNumThreads function never fails.
+      // NOTE: due to resource restrictions, we cannot test this branch in our CI build pipelines.
+      return std::max(static_cast<uint32_t>(1), HardwareCoreEnumerator::DefaultIntraOpNumThreads());
+    } else {
+      return cores_.empty() ? DefaultNumCores() : static_cast<int>(cores_.size());
+    }
+  } else
+#endif
+  {
+    return cores_.empty() ? DefaultNumCores() : static_cast<int>(cores_.size());
+  }
 }
 
 std::vector WindowsEnv::GetDefaultThreadAffinities() const {
diff --git a/onnxruntime/core/platform/windows/hardware_core_enumerator.cc b/onnxruntime/core/platform/windows/hardware_core_enumerator.cc
new file mode 100644
index 0000000000000..bf3b53afbd7d3
--- /dev/null
+++ b/onnxruntime/core/platform/windows/hardware_core_enumerator.cc
@@ -0,0 +1,90 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "hardware_core_enumerator.h"
+#include
+#include
+#include
+
+namespace onnxruntime {
+
+struct LogicalProcessorInformation {
+  std::unique_ptr<char[]> Buffer;
+  size_t Length;
+};
+
+struct CoreCounter {
+  uint32_t PhysicalCores = 0;
+  uint32_t LLCCores = 0;
+};
+
+static LogicalProcessorInformation GetLogicalProcessorInfos(LOGICAL_PROCESSOR_RELATIONSHIP relationship) {
+  DWORD length = 0;
+  DWORD rc = GetLogicalProcessorInformationEx(relationship, nullptr, &length);
+
+  assert(rc == FALSE);
+
+  auto processorInformationBytes = std::make_unique<char[]>(length);
+
+  rc = GetLogicalProcessorInformationEx(
+      relationship, reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(processorInformationBytes.get()), &length);
+
+  assert(rc == TRUE);
+
+  return {std::move(processorInformationBytes), length};
+}
+
+uint32_t CountSetBits(DWORD input) {
+  uint32_t c;
+  for (c = 0; input; c++) {
+    input &= input - 1;
+  }
+  return c;
+}
+
+static CoreCounter GetCoreInfo() {
+  auto logicalProcessorInformation = GetLogicalProcessorInfos(RelationAll);
+
+  CoreCounter cores;
+  DWORD dwLevel2GroupMask = 0;
+  DWORD dwLevel3GroupMask = 0;
+  size_t read = 0;
+  PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX currentProcessorInfo = NULL;
+
+  while ((read + FIELD_OFFSET(SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX, Processor)) < logicalProcessorInformation.Length) {
+    currentProcessorInfo =
+        reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(logicalProcessorInformation.Buffer.get() + read);
+    if ((read + currentProcessorInfo->Size) > logicalProcessorInformation.Length) {
+      break;
+    }
+
+    switch (currentProcessorInfo->Relationship) {
+      case RelationProcessorCore:
+        cores.PhysicalCores++;
+        break;
+      case RelationCache:
+        if (currentProcessorInfo->Cache.Level == 2) {
+          dwLevel2GroupMask |= currentProcessorInfo->Cache.GroupMask.Mask;
+        } else if (currentProcessorInfo->Cache.Level == 3) {
+          dwLevel3GroupMask |= currentProcessorInfo->Cache.GroupMask.Mask;
+        }
+        break;
+    }
+
+    read += currentProcessorInfo->Size;
+  }
+  // Cores with both L2 and LLC (L3) caches = # physical cores - # logical cores that have L2 but no LLC
+  cores.LLCCores = cores.PhysicalCores - CountSetBits(dwLevel2GroupMask & ~dwLevel3GroupMask);
+
+  return cores;
+}
+
+uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() {
+  // # of physical cores = # of P cores + # of E Cores + # of Soc Cores.
+  // # of logical cores = # of P cores x 2 (if hyper threading is enabled) + # of E cores + # of Soc Cores.
+  auto cores = GetCoreInfo();
+
+  return cores.LLCCores;
+}
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/platform/windows/hardware_core_enumerator.h b/onnxruntime/core/platform/windows/hardware_core_enumerator.h
new file mode 100644
index 0000000000000..93b50f452afcd
--- /dev/null
+++ b/onnxruntime/core/platform/windows/hardware_core_enumerator.h
@@ -0,0 +1,12 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include
+
+namespace onnxruntime {
+struct HardwareCoreEnumerator {
+  HardwareCoreEnumerator() = delete;
+  static uint32_t DefaultIntraOpNumThreads();
+};
+}  // namespace onnxruntime
\ No newline at end of file
diff --git a/onnxruntime/core/util/thread_utils.cc b/onnxruntime/core/util/thread_utils.cc
index a5a165e150cf1..2a6c14ff1b058 100644
--- a/onnxruntime/core/util/thread_utils.cc
+++ b/onnxruntime/core/util/thread_utils.cc
@@ -93,22 +93,31 @@ static std::unique_ptr CreateThreadPoolHelper(Env* env,
                                               OrtThreadPoolParams options) {
   ThreadOptions to;
   if (options.thread_pool_size <= 0) {  // default
-    auto default_affinities = Env::Default().GetDefaultThreadAffinities();
-    if (default_affinities.size() <= 1) {
-      return nullptr;
-    }
-    options.thread_pool_size = static_cast<int>(default_affinities.size());
     if (options.auto_set_affinity) {
 #ifdef _WIN32
       // Only set thread affinity on Server with auto affinity.
       // On client best to let OS scheduler handle.
      // On big (P-Core) / little (E-Core) CPU designs affinity overrides QoS and has high power usage
       if (IsWindowsServer()) {
+        auto default_affinities = Env::Default().GetDefaultThreadAffinities();
+        if (default_affinities.size() <= 1) {
+          return nullptr;
+        }
+        options.thread_pool_size = static_cast<int>(default_affinities.size());
         to.affinities = std::move(default_affinities);
+      } else {
+        options.thread_pool_size = Env::Default().GetNumPhysicalCpuCores();
       }
 #else
+      auto default_affinities = Env::Default().GetDefaultThreadAffinities();
+      if (default_affinities.size() <= 1) {
+        return nullptr;
+      }
+      options.thread_pool_size = static_cast<int>(default_affinities.size());
       to.affinities = std::move(default_affinities);
 #endif
+    } else {
+      options.thread_pool_size = Env::Default().GetNumPhysicalCpuCores();
     }
   }
   if (options.thread_pool_size <= 1) {
diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
index 47b7f35cbdd7c..b69bd229745c6 100644
--- a/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
+++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
@@ -20,6 +20,14 @@
 # 4) Install the latest ONNX Runtime version
 #
 # $ pip install onnxruntime-gpu
+#
+# 5) Install flash attention v2
+#
+# $ pip install flash-attn --no-build-isolation
+#
+# 6) Install bitsandbytes
+#
+# $ pip install bitsandbytes
 
 from __future__ import annotations
 
@@ -38,22 +46,44 @@
 import torch
 from benchmark_helper import setup_logger
 from llama_inputs import add_io_bindings_as_tensors, get_initial_inputs_and_outputs
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 
 import onnxruntime as ort
 
 logger = logging.getLogger(__name__)
 
 
-def get_model(args):
+def get_model(args: argparse.Namespace):
     if args.benchmark_type in {"pt-eager", "pt-compile"}:
-        model = AutoModelForCausalLM.from_pretrained(
-            args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
-            cache_dir=args.cache_dir,
-            torch_dtype=args.torch_dtype,
-            use_auth_token=args.auth,
-            use_cache=True,
-        ).to(args.target_device)
+        model = None
+        if args.onnx_precision == "int4" and args.device == "cuda":
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=torch.float16,
+            )
+
+            model = AutoModelForCausalLM.from_pretrained(
+                args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
+                cache_dir=args.cache_dir,
+                torch_dtype=args.torch_dtype,
+                use_auth_token=args.auth,
+                use_cache=True,
+                attn_implementation="flash_attention_2",
+                quantization_config=bnb_config,
+                max_memory={args.device_id: "80GB"},
+            )
+        else:
+            model = AutoModelForCausalLM.from_pretrained(
+                args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
+                cache_dir=args.cache_dir,
+                torch_dtype=args.torch_dtype,
+                use_auth_token=args.auth,
+                use_cache=True,
+                attn_implementation=("flash_attention_2" if args.device == "cuda" else "sdpa"),
+            ).to(args.target_device)
+
         model.eval()
 
     if args.benchmark_type == "pt-compile":
@@ -223,7 +253,7 @@ def get_args():
     parser.add_argument(
         "-s",
         "--prompt-lengths",
-        default="32 64 128 256 512",
+        default="16 64 256 1024",
     )
 
     parser.add_argument(
@@ -277,6 +307,7 @@ def get_args():
     args.prompt_lengths = args.prompt_lengths.split(" ")
 
     # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
+    setattr(args, "onnx_precision", args.precision)  # noqa: B010
    args.precision = (
         "fp32" if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.device == "cpu") else "fp16"
     )
diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
index 0a034aaf7a2fc..386abd5fa6d59 100644
--- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
+++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py
@@ -151,7 +151,7 @@ def get_args(argv: list[str]):
     parser.add_argument(
         "-m",
         "--model_name",
-        required=True,
+        required=False,
         help="Model name in Hugging Face",
     )
 
diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
index 35211aab272e4..112e2fbb1bfdf 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py
@@ -506,7 +506,13 @@ def main(argv=None):
     # Wrap parity check in try-except to allow export to continue in case this produces an error
     try:
         with torch.no_grad():
-            max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device)
+            # Verify batched decoding with prompts for whisper openai implementation
+            if args.model_impl == "openai" and args.use_forced_decoder_ids:
+                max_diff = WhisperHelper.verify_onnx(
+                    args.model_name_or_path, cache_dir, ort_session, device, batch_size=2, prompt_mode=True
+                )
+            else:
+                max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device)
         if max_diff > 1e-4:
             logger.warning("PyTorch and ONNX Runtime results are NOT close")
         else:
diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
index 0b128f122e0f4..be05ebc9d5dac 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py
@@ -51,8 +51,8 @@ def chain_model(args):
     decoder_model = onnx.load_model(args.decoder_path, load_external_data=True)
     decoder_model.graph.name = "decoder subgraph"
 
-    config = WhisperConfig.from_pretrained(args.model_name_or_path)
-    tokenizer = WhisperTokenizer.from_pretrained(args.model_name_or_path)
+    config = WhisperConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
+    tokenizer = WhisperTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
 
     # Create inputs/outputs for WhisperBeamSearch op
     temperature_name = "temperature_fp16" if args.precision == Precision.FLOAT16 else "temperature"
diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py
index 93fd64c9eb7d3..5da235d72ca0b 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py
@@ -126,6 +126,7 @@ def create_dummy(
         device: torch.device,
         float16: bool = False,
         use_int32_inputs: bool = False,
+        model_impl: str = "hf",
     ):  # -> WhisperDecoderInputs:
         """Create dummy inputs for WhisperDecoder.
 
@@ -170,7 +171,7 @@ def create_dummy(
         cross_attention_past_shape = [
             batch_size,
             num_attention_heads,
-            encode_sequence_length,
+            encode_sequence_length if model_impl == "hf" else past_decode_sequence_length,
             head_size,
         ]
 
@@ -228,6 +229,7 @@ def export_onnx(
             past_decode_sequence_length=6 if isinstance(decoder, WhisperDecoder) else 0,
             device=device,
             use_int32_inputs=use_int32_inputs,
+            model_impl=decoder.model_impl,
         )
         input_list = inputs.to_list()
 
diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py
index 832f692e9980d..fab2a2aa4c8a8 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py
@@ -4,7 +4,6 @@
 # license information.
 # --------------------------------------------------------------------------
 
-import copy
 import logging
 import os
 import tempfile
@@ -51,12 +50,15 @@ def forward(
         self,
         encoder_input_ids: torch.Tensor,
         decoder_input_ids: torch.Tensor = None,
+        remove_hooks: bool = False,
     ):
         encoder_hidden_states: torch.FloatTensor = self.whisper_encoder(encoder_input_ids)
         # Decoder out: (logits, past_key_values, encoder_hidden_state)
         if self.model_impl == "openai":
             encoder_hidden_states.unsqueeze(0)
-            decinit_out, present = self.whisper_decoder_openai_init(decoder_input_ids, encoder_hidden_states)
+            decinit_out, present = self.whisper_decoder_openai_init(
+                decoder_input_ids, encoder_hidden_states, remove_hooks=remove_hooks
+            )
             return decinit_out, encoder_hidden_states, present
         else:
             decinit_out = self.whisper_decoder_init(decoder_input_ids, encoder_hidden_states)
@@ -131,9 +133,7 @@ def export_onnx(
         )
         input_list = inputs.to_list()
 
-        # TODO : Investigate whether copy of model if needed
-        cloned_model = copy.deepcopy(model).to(device)
-        out = cloned_model(inputs.encoder_input_ids, inputs.decoder_input_ids)
+        out = model(inputs.encoder_input_ids, inputs.decoder_input_ids, remove_hooks=True)
         present = out[2]
         present_names = PastKeyValuesHelper.get_input_names(present, encoder=True)
 
diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py
index a1d0d7fb3deeb..1a066f5ad4ac7 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py
@@ -314,22 +314,13 @@ def optimize_onnx(
         m.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True)
 
     @staticmethod
-    def verify_onnx(
-        model_name_or_path: str,
-        cache_dir: str,
-        ort_session: InferenceSession,
+    def pt_transcription_for_verify_onnx(
+        processor: WhisperProcessor,
+        pt_model: torch.nn.Module,
         device: torch.device,
+        batch_size: int = 1,
+        prompt_mode: bool = False,
     ):
-        """Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good."""
-        extra_kwargs = {}
-        if version.parse(transformers_version) >= version.parse("4.36.0"):
-            extra_kwargs["attn_implementation"] = "eager"
-        pt_model = WhisperForConditionalGeneration.from_pretrained(
-            model_name_or_path, cache_dir=cache_dir, **extra_kwargs
-        ).to(device)
-        processor = WhisperProcessor.from_pretrained(model_name_or_path)
-        config = WhisperConfig.from_pretrained(model_name_or_path)
-
         # Try to import `datasets` pip package
         try:
             from datasets import load_dataset
@@ -342,14 +333,18 @@ def verify_onnx(
             from datasets import load_dataset  # noqa: F811
 
         ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
-        input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features
-
-        start_id = [config.decoder_start_token_id]  # ex: [50258]
-        prompt_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")
-        prompt_ids = list(map(lambda token: token[1], prompt_ids))  # ex: [50259, 50358, 50363]
-        forced_decoder_ids = start_id + prompt_ids  # ex: [50258, 50259, 50358, 50363]
-
-        batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 30, 0, 1, 1
+        input_features_ = []
+        if batch_size == 1:
+            input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features
+        else:
+            input_features_ = [
+                processor([ds[3]["audio"]["array"]], return_tensors="pt").input_features,
+                processor([ds[3]["audio"]["array"]], return_tensors="pt").input_features,
+            ]
+            assert len(input_features_) == batch_size
+            input_features = torch.cat((input_features_[0], input_features_[1]))
+
+        max_length, min_length, num_beams, num_return_sequences = 30, 0, 1, 1
         length_penalty, repetition_penalty = 1.0, 1.0
         inputs = {
             "input_features": input_features.to(device),
@@ -362,10 +357,97 @@ def verify_onnx(
             "early_stopping": True,
             "use_cache": True,
         }
-        pt_outputs = pt_model.generate(**inputs).detach().cpu().numpy()
+        if prompt_mode:
+            prompts = ["John has doubts", "Maria has grave doubts"]
+            prompt_ids = [processor.get_prompt_ids(p) for p in prompts]
+            pt_transcription = []
+            pt_outputs = []
+            # The looping over model.generate is necessary here due to the limitation described at
+            # https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate.prompt_ids
+            # The prompt_ids input requires a tensor of rank 1.
+            for i in range(batch_size):
+                inputs["prompt_ids"] = torch.from_numpy(prompt_ids[i])
+                inputs["input_features"] = input_features_[i].to(device)
+                pt_output = pt_model.generate(**inputs).detach().cpu().numpy()
+                pt_outputs.append(pt_output)
+                pt_transcription.append(processor.batch_decode(pt_output, skip_special_tokens=True)[0])
+            inputs["input_features"] = input_features
+            del inputs["prompt_ids"]
+        else:
+            prompt_ids = []
+            pt_outputs = pt_model.generate(**inputs).detach().cpu().numpy()
+            pt_transcription = [processor.batch_decode(pt_outputs, skip_special_tokens=True)[0]]
+            pt_outputs = list(pt_outputs)
         del inputs["early_stopping"]
         del inputs["use_cache"]
+        return inputs, pt_transcription, pt_outputs, prompt_ids
+
+    @staticmethod
+    def select_transcription_options(
+        batch_size: int,
+        prompt_mode: bool,
+    ):
+        if batch_size > 1 and prompt_mode:
+            expected_transcription_no_comma_prompt1 = " John has doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky I"
+            expected_transcription_misspelled_prompt1 = " John has doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I"
+            expected_transcription_no_comma_prompt2 = " Maria has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky"
+            expected_transcription_misspelled_prompt2 = " Maria has grave doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I"
+            expected_transcription_options = {
+                expected_transcription_no_comma_prompt1,
+                expected_transcription_no_comma_prompt2,
+                expected_transcription_misspelled_prompt1,
+                expected_transcription_misspelled_prompt2,
+            }
+        else:
+            expected_transcription_no_comma = (
+                " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
+            )
+            expected_transcription_with_comma = (
+                " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
+            )
+            expected_transcription_with_quote_and_comma = (
+                ' "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
+            )
+            expected_transcription_options = {
+                expected_transcription_no_comma,
+                expected_transcription_with_comma,
+                expected_transcription_with_quote_and_comma,
+            }
+        return expected_transcription_options
+
+    @staticmethod
+    def verify_onnx(
+        model_name_or_path: str,
+        cache_dir: str,
+        ort_session: InferenceSession,
+        device: torch.device,
+        batch_size: int = 1,
+        prompt_mode: bool = False,
+    ):
+        """Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good."""
+        extra_kwargs = {}
+        if version.parse(transformers_version) >= version.parse("4.36.0"):
+            extra_kwargs["attn_implementation"] = "eager"
+        pt_model = WhisperForConditionalGeneration.from_pretrained(
+            model_name_or_path, cache_dir=cache_dir, **extra_kwargs
+        ).to(device)
+        processor = WhisperProcessor.from_pretrained(model_name_or_path, cache_dir=cache_dir)
+        config = WhisperConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)
+
+        inputs, pt_transcription, pt_outputs, decoder_prompt_ids = WhisperHelper.pt_transcription_for_verify_onnx(
+            processor,
+            pt_model,
+            device,
+            batch_size=batch_size,
+            prompt_mode=prompt_mode,
+        )
+
+        start_id = [config.decoder_start_token_id]  # ex: [50258]
+        prompt_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")
+        prompt_ids = list(map(lambda token: token[1], prompt_ids))  # ex: [50259, 50358, 50363]
+        forced_decoder_ids = start_id + prompt_ids  # ex: [50258, 50259, 50358, 50363]
+
         ort_names = list(map(lambda entry: entry.name, ort_session.get_inputs()))
         ort_dtypes = list(map(lambda entry: entry.type, ort_session.get_inputs()))
         ort_to_np = {
@@ -386,8 +468,24 @@ def verify_onnx(
             elif name == "prefix_vocab_mask":
                 inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype])
             elif name == "decoder_input_ids":
-                raw_input_ids = [start_id] if use_extra_decoding_ids else [forced_decoder_ids]
-                inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype])
+                if not prompt_mode:
+                    raw_input_ids = [start_id] if use_extra_decoding_ids else [forced_decoder_ids]
+                    inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype])
+                else:
+                    # This logic handles the scenario where prompts are not of the same size.
+                    # For example, if our prompt ids are [p1_id_1, p1_id_2] and [p2_id_1],
+                    # the final decoder_input_ids will look as follows after padding:
+                    # [prev_token, p1_id_1, p1_id_2, start_token, lang_token, transcribe_token]
+                    # [prev_token, p2_id_1, PAD_TOKEN, start_token, lang_token, transcribe_token]
+                    ort_prompts = []
+                    for i in range(batch_size):
+                        ort_prompts.append(decoder_prompt_ids[i].tolist())
+                    max_len = max(len(p) for p in ort_prompts)
+                    padded_prompts = []
+                    for p in ort_prompts:
+                        padded_prompt = [*p, *([config.pad_token_id] * (max_len - len(p)))]
+                        padded_prompts.append(padded_prompt + forced_decoder_ids)
+                    inputs[name] = np.array(padded_prompts, dtype=ort_to_np[dtype])
             elif name == "logits_processor":
                 inputs[name] = np.array([1], dtype=ort_to_np[dtype])
             elif name == "cross_qk_layer_head":
@@ -398,36 +496,26 @@ def verify_onnx(
                 inputs[name] = np.array([1.0], dtype=ort_to_np[dtype])
             else:
                 inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype])
-        ort_outputs = ort_session.run(None, inputs)[0][0]
-
-        expected_transcription_no_comma = (
-            " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
-        )
-        expected_transcription_with_comma = (
-            " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
-        )
-        expected_transcription_with_quote_and_comma = (
-            ' "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
-        )
-        expected_transcription_options = {
-            expected_transcription_no_comma,
-            expected_transcription_with_comma,
-            expected_transcription_with_quote_and_comma,
-        }
-        pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True)[0]
-        ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True)[0]
-
-        parity = (
-            pt_transcription in expected_transcription_options and ort_transcription in expected_transcription_options
-        )
+        ort_outputs = ort_session.run(None, inputs)[0][:, 0, :]
+        ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True)
+        expected_transcription_options = WhisperHelper.select_transcription_options(batch_size, prompt_mode)
+
+        parity = 1
+        for i in range(batch_size):
+            parity *= (
+                pt_transcription[i] in expected_transcription_options
+                and ort_transcription[i] in expected_transcription_options
+            )
 
         max_diff = 0
         if not parity:
-            if pt_outputs.shape != ort_outputs.shape:
-                diff = pt_outputs - ort_outputs[:, : len(pt_outputs[0])]
-            else:
-                diff = pt_outputs - ort_outputs
-            max_diff = max(diff.min(), diff.max(), key=abs)
+            for i in range(batch_size):
+                if pt_outputs[i].shape != ort_outputs[i].shape:
+                    diff = pt_outputs[i] - ort_outputs[i][:, : len(pt_outputs[i])]
+                else:
+                    diff = pt_outputs[i] - ort_outputs[i]
+                max_diff_i = max(diff.min(), diff.max(), key=abs)
+                max_diff = max(max_diff, max_diff_i)
 
         if max_diff != 0:
             logger.warning(f"PyTorch outputs: {pt_transcription}")
diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py
index 941f61cf7cc29..849c3059f21f7 100644
--- a/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py
+++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py
@@ -30,6 +30,7 @@ def forward(
         tokens,
         audio_features,
         past=None,
+        remove_hooks=False,
     ):
         # Create a kv_cache for past_values
         past_kv_cache = dict()
@@ -44,8 +45,9 @@ def forward(
                 past_kv_cache[block.cross_attn.key] = past[2 * idx + half_idx]
                 past_kv_cache[block.cross_attn.value] = past[2 * idx + half_idx + 1]
 
+        hooks = None
         if not self.kv_cache:
-            self.kv_cache, _ = self.whisper_model.install_kv_cache_hooks()
+            self.kv_cache, hooks = self.whisper_model.install_kv_cache_hooks()
 
         logits = self.whisper_decoder(tokens, audio_features, kv_cache=past_kv_cache)
 
@@ -73,4 +75,10 @@ def forward(
         present_self = [
             present_val.reshape(present_val.shape[:2] + (-1, 64)).transpose(1, 2) for present_val in present_self
         ]
+
+        # Remove the forward hooks to avoid the model cloning step
+        if hooks is not None and remove_hooks:
+            self.kv_cache = {}
+            for hook in hooks:
+                hook.remove()
         return logits, present_self
diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py
index 1484186b8a331..85583e11f5930 100644
--- a/tools/ci_build/build.py
+++ b/tools/ci_build/build.py
@@ -1503,8 +1503,10 @@ def generate_build_tree(
             # The "/profile" flag implies "/DEBUG:FULL /DEBUGTYPE:cv,fixup /OPT:REF /OPT:NOICF /INCREMENTAL:NO /FIXED:NO". We set it for satisfying a Microsoft internal compliance requirement. External users
             # do not need to have it.
             ldflags = ["/profile", "/DYNAMICBASE"]
-            if args.enable_qspectre:
-                cflags += ["/Qspectre"]
+            # Address Sanitizer libs do not have a Qspectre version, so the two cannot both be enabled.
+            if not args.enable_address_sanitizer:
+                # Also enable a special perf patch that was made for Intel Meteor Lake mobile CPUs
+                cflags += ["/Qspectre", "/DONNXRUNTIME_ENABLE_INTEL_METEOR_LAKE_MOBILE_PLATFORM_PERF_PATCH"]
             if config == "Release":
                 cflags += ["/O2", "/Ob2", "/DNDEBUG"]
             elif config == "RelWithDebInfo":
@@ -1606,9 +1608,11 @@ def generate_build_tree(
                 [
                     *temp_cmake_args,
                     f"-DCMAKE_BUILD_TYPE={config}",
-                    f"-DCMAKE_PREFIX_PATH={build_dir}/{config}/installed"
-                    if preinstalled_dir.exists() and not (args.arm64 or args.arm64ec or args.arm)
-                    else "",
+                    (
+                        f"-DCMAKE_PREFIX_PATH={build_dir}/{config}/installed"
+                        if preinstalled_dir.exists() and not (args.arm64 or args.arm64ec or args.arm)
+                        else ""
+                    ),
                 ],
                 cwd=config_build_dir,
                 cuda_home=cuda_home,
diff --git a/winml/lib/Api/HardwareCoreEnumerator.cpp b/winml/lib/Api/HardwareCoreEnumerator.cpp
index b6b44690f4f6c..edec83b5a10cb 100644
--- a/winml/lib/Api/HardwareCoreEnumerator.cpp
+++ b/winml/lib/Api/HardwareCoreEnumerator.cpp
@@ -14,7 +14,7 @@ struct LogicalProcessorInformation {
 
 struct CoreCounter {
   uint32_t PhysicalCores = 0;
-  uint32_t Num2CacheCores = 0;
+  uint32_t LLCCores = 0;
 };
 
 static LogicalProcessorInformation GetLogicalProcessorInfos(LOGICAL_PROCESSOR_RELATIONSHIP relationship) {
@@ -42,7 +42,7 @@ uint32_t CountSetBits(DWORD input) {
   return c;
 }
 
-static CoreCounter GetNumberOPhysicalAndEngineeringCores() {
+static CoreCounter GetCoreInfo() {
   auto logicalProcessorInformation = GetLogicalProcessorInfos(RelationAll);
 
   CoreCounter cores;
@@ -64,6 +64,7 @@ static CoreCounter GetNumberOPhysicalAndEngineeringCores() {
         cores.PhysicalCores++;
         break;
       case RelationCache:
+        // Cache-level masks count logical processors
        if (currentProcessorInfo->Cache.Level == 2) {
           dwLevel2GroupMask |= currentProcessorInfo->Cache.GroupMask.Mask;
         } else if (currentProcessorInfo->Cache.Level == 3) {
@@ -75,14 +76,15 @@ static CoreCounter GetNumberOPhysicalAndEngineeringCores() {
     read += currentProcessorInfo->Size;
   }
 
-  cores.Num2CacheCores = CountSetBits(dwLevel2GroupMask & ~dwLevel3GroupMask);
+  cores.LLCCores = cores.PhysicalCores - CountSetBits(dwLevel2GroupMask & ~dwLevel3GroupMask);
+
   return cores;
 }
 
 uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() {
   // # of physical cores = # of P cores + # of E Cores + # of Soc Cores.
   // # of logical cores = # of P cores x 2 (if hyper threading is enabled) + # of E cores + # of Soc Cores.
-  auto cores = GetNumberOPhysicalAndEngineeringCores();
+  auto cores = GetCoreInfo();
 
 #if !defined(_M_ARM64) && !defined(__aarch64__)
   const int kVendorID_Intel[3] = {0x756e6547, 0x6c65746e, 0x49656e69};  // "GenuntelineI"
@@ -97,9 +99,8 @@ uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() {
   auto isHybrid = (regs_leaf7[3] & (1 << 15));
 
   if (isIntel && isHybrid) {
-    // We want to use the number of physical cores, but exclude soc cores
-    // On Intel Hybrid processors, numSocCores == cores.Num2CacheCores
-    return cores.PhysicalCores - cores.Num2CacheCores;
+    // We want to use the number of physical cores, but exclude cores without an LLC
+    return cores.LLCCores;
   }
 #endif
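Reviewer note (not part of the diff): the patch combines two pieces of logic that live in different files — the CPUID hybrid check in env.cc and the L2/L3 group-mask arithmetic in hardware_core_enumerator.cc / HardwareCoreEnumerator.cpp. A minimal, standalone C++ sketch of the combined idea is given below so the arithmetic can be sanity-checked on a local machine. It only mirrors the approach under stated assumptions: all names are illustrative, it assumes a single processor group (GroupMask covers group 0 only, as the PR's code also does), and it uses ULONG_PTR-wide masks.

// Sketch: suggest an intra-op thread count by excluding logical processors that
// have an L2 cache but no L3 (LLC), which on Intel hybrid parts approximates the
// low-power cores. Windows-only; compile with MSVC.
#include <Windows.h>
#include <intrin.h>
#include <cstdint>
#include <cstdio>
#include <memory>

static uint32_t CountSetBits(ULONG_PTR mask) {
  uint32_t count = 0;
  for (; mask; ++count) mask &= mask - 1;  // clear the lowest set bit each pass
  return count;
}

int main() {
  // 1) Detect an Intel hybrid part: leaf 0 vendor string, leaf 7 EDX bit 15.
  int regs[4];
  __cpuid(regs, 0);
  const bool is_intel =
      regs[1] == 0x756e6547 && regs[3] == 0x49656e69 && regs[2] == 0x6c65746e;  // "GenuineIntel"
  bool is_hybrid = false;
  if (is_intel && regs[0] >= 7) {
    __cpuid(regs, 7);
    is_hybrid = (regs[3] & (1 << 15)) != 0;
  }

  // 2) Walk SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX records: count physical cores
  //    and accumulate the L2 / L3 cache group masks.
  DWORD length = 0;
  GetLogicalProcessorInformationEx(RelationAll, nullptr, &length);  // expected to fail, returns needed size
  auto buffer = std::make_unique<char[]>(length);
  if (!GetLogicalProcessorInformationEx(
          RelationAll, reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(buffer.get()), &length)) {
    return 1;
  }

  uint32_t physical_cores = 0;
  KAFFINITY l2_mask = 0, l3_mask = 0;
  const size_t len = length;
  for (size_t offset = 0; offset < len;) {
    auto* info = reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(buffer.get() + offset);
    if (info->Relationship == RelationProcessorCore) {
      ++physical_cores;
    } else if (info->Relationship == RelationCache) {
      if (info->Cache.Level == 2) l2_mask |= info->Cache.GroupMask.Mask;
      if (info->Cache.Level == 3) l3_mask |= info->Cache.GroupMask.Mask;
    }
    offset += info->Size;
  }

  // Logical processors with L2 but no L3 are treated as cores without LLC access.
  const uint32_t llc_cores = physical_cores - CountSetBits(l2_mask & ~l3_mask);
  std::printf("hybrid=%d physical=%u suggested intra-op threads=%u\n",
              is_hybrid ? 1 : 0, physical_cores, is_hybrid ? llc_cores : physical_cores);
  return 0;
}

The design choice being illustrated: logical processors that report an L2 cache but no L3 are assumed to be the cores without access to the LLC (on Meteor Lake, the low-power SoC cores), so "physical cores minus L2-only logical processors" approximates the set of cores worth scheduling intra-op work on, without needing the CPUID hybrid core-type leaf.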