Cherry pick 1.17.3 - Round 2 #20178

Merged
merged 4 commits on Apr 3, 2024
46 changes: 45 additions & 1 deletion onnxruntime/core/platform/windows/env.cc
@@ -32,6 +32,9 @@
#include "core/common/span_utils.h"
#include "core/platform/env.h"
#include "core/platform/scoped_resource.h"
#if defined(_M_X64) && !defined(_M_ARM64EC)
#include "core/platform/windows/hardware_core_enumerator.h"
#endif
#include <unsupported/Eigen/CXX11/ThreadPool>
#include <wil/Resource.h>

@@ -248,12 +251,53 @@
Sleep(static_cast<DWORD>(micros) / 1000);
}

// EIGEN_NO_CPUID is not defined in any C/C++ source code. It is a compile option.
#if defined(_M_X64) && !defined(_M_ARM64EC) && !defined(EIGEN_NO_CPUID)
static constexpr std::array<int, 3> kVendorID_Intel = {0x756e6547, 0x6c65746e, 0x49656e69}; // "GenuntelineI"
#endif
int WindowsEnv::DefaultNumCores() {
return std::max(1, static_cast<int>(std::thread::hardware_concurrency() / 2));
}

int WindowsEnv::GetNumPhysicalCpuCores() const {
return cores_.empty() ? DefaultNumCores() : static_cast<int>(cores_.size());
// EIGEN_NO_CPUID is not defined in any C/C++ source code. It is a compile option.
#if defined(_M_X64) && !defined(_M_ARM64EC) && !defined(EIGEN_NO_CPUID)
// The following code is a temporary fix for a perf problem on Intel's Meteor Lake CPUs. The Intel compute platform has
// a hybrid architecture in which some CPU cores run significantly slower than others. If we distribute our compute work
// evenly to all CPU cores, the slowest CPU core will drag the performance down. So, instead, we reduce the total number
// of threads to exclude the slowest cores.
// The following code is based on assumptions that:
// 1. All Intel hybrid CPUs should have 3 levels of cache.
// 2. If a CPU core is only associated with two levels of cache, it should be a low performance CPU core and should
// not be used.
// Since we don't know what the next Intel hybrid CPU would be like, later on we may need to rework the following code.
// However, no matter what, the code should not cause any crash. At worst it might return 1, in which case
// thread pools will not be created, which is just a perf issue and does not impact usability.
// TODO: detect if CPUID instruction is available per instructions at https://wiki.osdev.org/CPUID#Checking_CPUID_availability
int regs[4];
__cpuid(regs, 0);
bool bIsIntel =
(kVendorID_Intel[0] == regs[1]) &&
(kVendorID_Intel[1] == regs[2]) &&
(kVendorID_Intel[2] == regs[3]);
if (bIsIntel && regs[0] >= 7) {
// Query Structured Extended Feature Flags Enumeration Leaf
__cpuid(regs, 0x7);
// Bit 15 of EDX indicates whether the processor is identified as a hybrid part.
bool ishybrid = regs[3] & (1 << 15);
if (ishybrid) {
// NOTE: even if ishybrid is true, it doesn't mean the processor must have P-cores and E-cores.
// On Intel CPUs we assume the HardwareCoreEnumerator::DefaultIntraOpNumThreads function would never fail.
// NOTE: due to resource restrictions, we cannot test this branch in our CI build pipelines.
return std::max(static_cast<uint32_t>(1), HardwareCoreEnumerator::DefaultIntraOpNumThreads());
} else {
return cores_.empty() ? DefaultNumCores() : static_cast<int>(cores_.size());
}
} else
#endif
{
return cores_.empty() ? DefaultNumCores() : static_cast<int>(cores_.size());
}
}

std::vector<LogicalProcessors> WindowsEnv::GetDefaultThreadAffinities() const {
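For context on the CPUID checks added above: leaf 0 returns the vendor string split across EBX, EDX, and ECX, which is why the constant array spells "GenuntelineI" when its elements are read in EBX/ECX/EDX order, and leaf 7 reports the hybrid flag in bit 15 of EDX. A minimal standalone sketch (not part of this PR) illustrating the same detection with the MSVC intrinsic:

#include <intrin.h>
#include <cstdio>
#include <cstring>

int main() {
  int regs[4];  // regs = {EAX, EBX, ECX, EDX}
  __cpuid(regs, 0);
  // Reassemble the vendor string: EBX="Genu", EDX="ineI", ECX="ntel" -> "GenuineIntel".
  char vendor[13] = {};
  std::memcpy(vendor + 0, &regs[1], 4);
  std::memcpy(vendor + 4, &regs[3], 4);
  std::memcpy(vendor + 8, &regs[2], 4);
  bool is_intel = std::strcmp(vendor, "GenuineIntel") == 0;
  bool is_hybrid = false;
  if (is_intel && regs[0] >= 7) {
    __cpuid(regs, 0x7);                       // Structured Extended Feature Flags leaf
    is_hybrid = (regs[3] & (1 << 15)) != 0;   // EDX bit 15: hybrid part
  }
  std::printf("vendor=%s hybrid=%d\n", vendor, static_cast<int>(is_hybrid));
  return 0;
}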
90 changes: 90 additions & 0 deletions onnxruntime/core/platform/windows/hardware_core_enumerator.cc
@@ -0,0 +1,90 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "hardware_core_enumerator.h"

Check warning on line 4 in onnxruntime/core/platform/windows/hardware_core_enumerator.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Include the directory when naming header files [build/include_subdir] [4] Raw Output: onnxruntime/core/platform/windows/hardware_core_enumerator.cc:4: Include the directory when naming header files [build/include_subdir] [4]
#include <memory>
#include <Windows.h>
#include <assert.h>

namespace onnxruntime {

struct LogicalProcessorInformation {
std::unique_ptr<char[]> Buffer;
size_t Length;
};

struct CoreCounter {
uint32_t PhysicalCores = 0;
uint32_t LLCCores = 0;
};

static LogicalProcessorInformation GetLogicalProcessorInfos(LOGICAL_PROCESSOR_RELATIONSHIP relationship) {
DWORD length = 0;
DWORD rc = GetLogicalProcessorInformationEx(relationship, nullptr, &length);

assert(rc == FALSE);

auto processorInformationBytes = std::make_unique<char[]>(length);

rc = GetLogicalProcessorInformationEx(
relationship, reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(processorInformationBytes.get()), &length);

assert(rc == TRUE);

return {std::move(processorInformationBytes), length};
}

uint32_t CountSetBits(DWORD input) {
uint32_t c;
for (c = 0; input; c++) {
input &= input - 1;
}
return c;
}

static CoreCounter GetCoreInfo() {
auto logicalProcessorInformation = GetLogicalProcessorInfos(RelationAll);

CoreCounter cores;
DWORD dwLevel2GroupMask = 0;
DWORD dwLevel3GroupMask = 0;
size_t read = 0;
PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX currentProcessorInfo = NULL;

while ((read + FIELD_OFFSET(SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX, Processor)) < logicalProcessorInformation.Length) {
currentProcessorInfo =
reinterpret_cast<PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX>(logicalProcessorInformation.Buffer.get() + read);
if ((read + currentProcessorInfo->Size) > logicalProcessorInformation.Length) {
break;
}

switch (currentProcessorInfo->Relationship) {
case RelationProcessorCore:
cores.PhysicalCores++;
break;
case RelationCache:
if (currentProcessorInfo->Cache.Level == 2) {
dwLevel2GroupMask |= currentProcessorInfo->Cache.GroupMask.Mask;
} else if (currentProcessorInfo->Cache.Level == 3) {
dwLevel3GroupMask |= currentProcessorInfo->Cache.GroupMask.Mask;
}
break;
}

read += currentProcessorInfo->Size;
}
// Cores with L2 and LLC cache levels = # Physical Cores - # logical cores without LLC
cores.LLCCores = cores.PhysicalCores - CountSetBits(dwLevel2GroupMask & ~dwLevel3GroupMask);

return cores;
}

uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() {
// # of physical cores = # of P cores + # of E Cores + # of Soc Cores.
// # of logical cores = # of P cores x 2 (if hyper threading is enabled) + # of E cores + # of Soc Cores.
auto cores = GetCoreInfo();

return cores.LLCCores;
}

} // namespace onnxruntime
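To make the LLCCores arithmetic concrete, here is a small worked example for a hypothetical hybrid layout (6 P-cores with SMT, 8 E-cores, and 2 low-power E-cores that sit outside the last-level cache); the core counts and masks are invented purely for illustration:

#include <cstdint>
#include <cstdio>

static uint32_t CountSetBits64(uint64_t v) {
  uint32_t c = 0;
  for (; v; ++c) v &= v - 1;  // same Kernighan trick as CountSetBits above
  return c;
}

int main() {
  // Hypothetical layout: 6 P-cores (SMT -> 12 logical CPUs), 8 E-cores, 2 LP E-cores.
  uint32_t physical_cores = 6 + 8 + 2;       // 16 physical cores
  uint64_t level2_mask = (1ull << 22) - 1;   // every logical CPU sits under some L2
  uint64_t level3_mask = (1ull << 20) - 1;   // the 2 LP E-cores do not share the LLC
  uint32_t no_llc = CountSetBits64(level2_mask & ~level3_mask);  // 2
  std::printf("DefaultIntraOpNumThreads would be %u\n", physical_cores - no_llc);  // 14
  return 0;
}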
12 changes: 12 additions & 0 deletions onnxruntime/core/platform/windows/hardware_core_enumerator.h
@@ -0,0 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <stdint.h>

namespace onnxruntime {
struct HardwareCoreEnumerator {
HardwareCoreEnumerator() = delete;
static uint32_t DefaultIntraOpNumThreads();
};
} // namespace onnxruntime
19 changes: 14 additions & 5 deletions onnxruntime/core/util/thread_utils.cc
@@ -93,22 +93,31 @@ static std::unique_ptr<ThreadPool>
CreateThreadPoolHelper(Env* env, OrtThreadPoolParams options) {
ThreadOptions to;
if (options.thread_pool_size <= 0) { // default
auto default_affinities = Env::Default().GetDefaultThreadAffinities();
if (default_affinities.size() <= 1) {
return nullptr;
}
options.thread_pool_size = static_cast<int>(default_affinities.size());
if (options.auto_set_affinity) {
#ifdef _WIN32
// Only set thread affinity on Server with auto affinity.
// On client best to let OS scheduler handle.
// On big (P-Core) / little (E-Core) CPU designs affinity overrides QoS and has high power usage
if (IsWindowsServer()) {
auto default_affinities = Env::Default().GetDefaultThreadAffinities();
if (default_affinities.size() <= 1) {
return nullptr;
}
options.thread_pool_size = static_cast<int>(default_affinities.size());
to.affinities = std::move(default_affinities);
} else {
options.thread_pool_size = Env::Default().GetNumPhysicalCpuCores();
}
#else
auto default_affinities = Env::Default().GetDefaultThreadAffinities();
if (default_affinities.size() <= 1) {
return nullptr;
}
options.thread_pool_size = static_cast<int>(default_affinities.size());
to.affinities = std::move(default_affinities);
#endif
} else {
options.thread_pool_size = Env::Default().GetNumPhysicalCpuCores();
}
}
if (options.thread_pool_size <= 1) {
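Since CreateThreadPoolHelper only applies this default heuristic when options.thread_pool_size <= 0, an application that wants a specific thread count can still override it through the session options. A minimal sketch of that override (the model path and thread count are placeholders):

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "example");
  Ort::SessionOptions opts;
  opts.SetIntraOpNumThreads(4);  // bypasses the GetNumPhysicalCpuCores()-based default
  Ort::Session session(env, L"model.onnx", opts);  // placeholder model path (wide string on Windows)
  return 0;
}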
51 changes: 41 additions & 10 deletions onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
@@ -20,6 +20,14 @@
# 4) Install the latest ONNX Runtime version
#
# $ pip install onnxruntime-gpu
#
# 5) Install flash attention v2
#
# $ pip install flash-attn --no-build-isolation
#
# 6) Install bitsandbytes
#
# $ pip install bitsandbytes

from __future__ import annotations

Expand All @@ -38,22 +46,44 @@
import torch
from benchmark_helper import setup_logger
from llama_inputs import add_io_bindings_as_tensors, get_initial_inputs_and_outputs
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

import onnxruntime as ort

logger = logging.getLogger(__name__)


def get_model(args):
def get_model(args: argparse.Namespace):
if args.benchmark_type in {"pt-eager", "pt-compile"}:
model = AutoModelForCausalLM.from_pretrained(
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
torch_dtype=args.torch_dtype,
use_auth_token=args.auth,
use_cache=True,
).to(args.target_device)
model = None
if args.onnx_precision == "int4" and args.device == "cuda":
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
torch_dtype=args.torch_dtype,
use_auth_token=args.auth,
use_cache=True,
attn_implementation="flash_attention_2",
quantization_config=bnb_config,
max_memory={args.device_id: "80GB"},
)
else:
model = AutoModelForCausalLM.from_pretrained(
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
torch_dtype=args.torch_dtype,
use_auth_token=args.auth,
use_cache=True,
attn_implementation=("flash_attention_2" if args.device == "cuda" else "sdpa"),
).to(args.target_device)

model.eval()

if args.benchmark_type == "pt-compile":
@@ -223,7 +253,7 @@ def get_args():
parser.add_argument(
"-s",
"--prompt-lengths",
default="32 64 128 256 512",
default="16 64 256 1024",
)

parser.add_argument(
@@ -277,6 +307,7 @@ def get_args():
args.prompt_lengths = args.prompt_lengths.split(" ")

# Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models
setattr(args, "onnx_precision", args.precision) # noqa: B010
args.precision = (
"fp32" if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.device == "cpu") else "fp16"
)
Original file line number Diff line number Diff line change
@@ -151,7 +151,7 @@ def get_args(argv: list[str]):
parser.add_argument(
"-m",
"--model_name",
required=True,
required=False,
help="Model name in Hugging Face",
)

Original file line number Diff line number Diff line change
@@ -506,7 +506,13 @@ def main(argv=None):
# Wrap parity check in try-except to allow export to continue in case this produces an error
try:
with torch.no_grad():
max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device)
# Verify batched decoding with prompts for whisper openai implementation
if args.model_impl == "openai" and args.use_forced_decoder_ids:
max_diff = WhisperHelper.verify_onnx(
args.model_name_or_path, cache_dir, ort_session, device, batch_size=2, prompt_mode=True
)
else:
max_diff = WhisperHelper.verify_onnx(args.model_name_or_path, cache_dir, ort_session, device)
if max_diff > 1e-4:
logger.warning("PyTorch and ONNX Runtime results are NOT close")
else:
Original file line number Diff line number Diff line change
@@ -51,8 +51,8 @@ def chain_model(args):
decoder_model = onnx.load_model(args.decoder_path, load_external_data=True)
decoder_model.graph.name = "decoder subgraph"

config = WhisperConfig.from_pretrained(args.model_name_or_path)
tokenizer = WhisperTokenizer.from_pretrained(args.model_name_or_path)
config = WhisperConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
tokenizer = WhisperTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)

# Create inputs/outputs for WhisperBeamSearch op
temperature_name = "temperature_fp16" if args.precision == Precision.FLOAT16 else "temperature"
Original file line number Diff line number Diff line change
@@ -126,6 +126,7 @@ def create_dummy(
device: torch.device,
float16: bool = False,
use_int32_inputs: bool = False,
model_impl: str = "hf",
): # -> WhisperDecoderInputs:
"""Create dummy inputs for WhisperDecoder.

@@ -170,7 +171,7 @@ def create_dummy(
cross_attention_past_shape = [
batch_size,
num_attention_heads,
encode_sequence_length,
encode_sequence_length if model_impl == "hf" else past_decode_sequence_length,
head_size,
]

@@ -228,6 +229,7 @@ def export_onnx(
past_decode_sequence_length=6 if isinstance(decoder, WhisperDecoder) else 0,
device=device,
use_int32_inputs=use_int32_inputs,
model_impl=decoder.model_impl,
)
input_list = inputs.to_list()

Original file line number Diff line number Diff line change
@@ -4,7 +4,6 @@
# license information.
# --------------------------------------------------------------------------

import copy
import logging
import os
import tempfile
@@ -51,12 +50,15 @@ def forward(
self,
encoder_input_ids: torch.Tensor,
decoder_input_ids: torch.Tensor = None,
remove_hooks: bool = False,
):
encoder_hidden_states: torch.FloatTensor = self.whisper_encoder(encoder_input_ids)
# Decoder out: (logits, past_key_values, encoder_hidden_state)
if self.model_impl == "openai":
encoder_hidden_states.unsqueeze(0)
decinit_out, present = self.whisper_decoder_openai_init(decoder_input_ids, encoder_hidden_states)
decinit_out, present = self.whisper_decoder_openai_init(
decoder_input_ids, encoder_hidden_states, remove_hooks=remove_hooks
)
return decinit_out, encoder_hidden_states, present
else:
decinit_out = self.whisper_decoder_init(decoder_input_ids, encoder_hidden_states)
@@ -131,9 +133,7 @@ def export_onnx(
)
input_list = inputs.to_list()

# TODO : Investigate whether copy of model if needed
cloned_model = copy.deepcopy(model).to(device)
out = cloned_model(inputs.encoder_input_ids, inputs.decoder_input_ids)
out = model(inputs.encoder_input_ids, inputs.decoder_input_ids, remove_hooks=True)
present = out[2]
present_names = PastKeyValuesHelper.get_input_names(present, encoder=True)
