diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index 10c307b3b911c..5124262ec0004 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -41,6 +41,8 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h
${MLAS_SRC_DIR}/flashattn.cpp
${MLAS_SRC_DIR}/cast.cpp
+ ${MLAS_SRC_DIR}/rotary_embedding.h
+ ${MLAS_SRC_DIR}/rotary_embedding.cpp
)
target_sources(onnxruntime_mlas PRIVATE
@@ -88,8 +90,11 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
- ${MLAS_SRC_DIR}/fp16_neon_common.cpp
+ ${MLAS_SRC_DIR}/cast_kernel_neon.cpp
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
+ ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
+ ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
+ ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
)
set(mlas_platform_preprocess_srcs
@@ -367,6 +372,8 @@ else()
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
+ ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
+ ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
@@ -384,8 +391,9 @@ else()
${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp
${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp
${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp
- ${MLAS_SRC_DIR}/fp16_neon_common.cpp
+ ${MLAS_SRC_DIR}/cast_kernel_neon.cpp
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
+ ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
@@ -395,8 +403,9 @@ else()
set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
- set_source_files_properties(${MLAS_SRC_DIR}/fp16_neon_common.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
endif()
if(ONNXRUNTIME_MLAS_MULTI_ARCH)
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index b87532debe4bc..6ea3f93cdea12 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -1596,6 +1596,8 @@ This version of the operator has been available since version 1 of the 'com.micr
(Optional) Hardware architecture.
main_context : int
Usually each EPContext node associates with a graph partition. But in some cases, like QNN, a single EPContext node contains all partitions. In that case, the node with ep_cache_context should set main_context=1. Other nodes set main_context=0 and skip ep_cache_context. The path is relative to this Onnx file. Default is 1.
+max_size : int
+Maximum size in the context. Usage depends on the EP.
notes : string
(Optional) Some notes for the model
onnx_model_filename : string
diff --git a/include/onnxruntime/core/framework/kernel_registry.h b/include/onnxruntime/core/framework/kernel_registry.h
index 7b3d04ee66d9e..aaf533135429c 100644
--- a/include/onnxruntime/core/framework/kernel_registry.h
+++ b/include/onnxruntime/core/framework/kernel_registry.h
@@ -8,6 +8,9 @@
#include "core/framework/op_kernel.h"
namespace onnxruntime {
+namespace logging {
+class Logger;
+}
using KernelCreateMap = std::multimap<std::string, KernelCreateInfo>;
using KernelDefHashes = std::vector<std::pair<std::string, HashValue>>;
@@ -33,6 +36,7 @@ class KernelRegistry {
// Kernel matching uses the types from the node and the kernel_type_str_resolver.
Status TryFindKernel(const Node& node, ProviderType exec_provider,
const IKernelTypeStrResolver& kernel_type_str_resolver,
+ const logging::Logger& logger,
const KernelCreateInfo** out) const;
// map of type constraint name to required type
@@ -42,6 +46,7 @@ class KernelRegistry {
// Kernel matching uses the explicit type constraint name to required type map in type_constraints.
Status TryFindKernel(const Node& node, ProviderType exec_provider,
const TypeConstraintMap& type_constraints,
+ const logging::Logger& logger,
const KernelCreateInfo** out) const;
/**
@@ -61,13 +66,15 @@ class KernelRegistry {
std::string_view domain,
int version,
const KernelRegistry::TypeConstraintMap& type_constraints,
+ const logging::Logger& logger,
const KernelCreateInfo** out) const;
static bool HasImplementationOf(const KernelRegistry& r, const Node& node,
ProviderType exec_provider,
- const IKernelTypeStrResolver& kernel_type_str_resolver) {
+ const IKernelTypeStrResolver& kernel_type_str_resolver,
+ const logging::Logger& logger) {
const KernelCreateInfo* info;
- Status st = r.TryFindKernel(node, exec_provider, kernel_type_str_resolver, &info);
+ Status st = r.TryFindKernel(node, exec_provider, kernel_type_str_resolver, logger, &info);
return st.IsOK();
}
@@ -83,6 +90,7 @@ class KernelRegistry {
Status TryFindKernelImpl(const Node& node, ProviderType exec_provider,
const IKernelTypeStrResolver* kernel_type_str_resolver,
const TypeConstraintMap* type_constraints,
+ const logging::Logger& logger,
const KernelCreateInfo** out) const;
// Check whether the types of inputs/outputs of the given node match the extra
diff --git a/include/onnxruntime/core/optimizer/graph_transformer_utils.h b/include/onnxruntime/core/optimizer/graph_transformer_utils.h
index 6cff153c336f0..31b0f22340510 100644
--- a/include/onnxruntime/core/optimizer/graph_transformer_utils.h
+++ b/include/onnxruntime/core/optimizer/graph_transformer_utils.h
@@ -53,6 +53,7 @@ InlinedVector> GenerateTransformers(
TransformerLevel level,
const SessionOptions& session_options,
const IExecutionProvider& execution_provider /*required by constant folding*/,
+ const logging::Logger& logger,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
concurrency::ThreadPool* intra_op_thread_pool = nullptr,
std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors = nullptr);
@@ -84,6 +85,7 @@ InlinedVector> GenerateTransformersForMinimalB
const SessionOptions& session_options,
const SatApplyContextVariant& apply_context,
const IExecutionProvider& cpu_execution_provider,
+ const logging::Logger& logger,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
concurrency::ThreadPool* intra_op_thread_pool = nullptr,
std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors = nullptr);
diff --git a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h
index 3963b80de58a4..d035fd34bd072 100644
--- a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h
+++ b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h
@@ -47,8 +47,20 @@ enum COREMLFlags {
// and SessionOptionsAppendExecutionProvider (C API). For the old API, use COREMLFlags instead.
static const char* const kCoremlProviderOption_MLComputeUnits = "MLComputeUnits";
static const char* const kCoremlProviderOption_ModelFormat = "ModelFormat";
+// same as COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES
static const char* const kCoremlProviderOption_RequireStaticInputShapes = "RequireStaticInputShapes";
static const char* const kCoremlProviderOption_EnableOnSubgraphs = "EnableOnSubgraphs";
+// provided by https://developer.apple.com/documentation/coreml/mloptimizationhints-swift.struct/specializationstrategy-swift.property
+// Core ML segments the model’s compute graph and specializes each segment for the target compute device.
+// This process can affect the model loading time and the prediction latency.
+// Use this option to tailor the specialization strategy for your model.
+static const char* const kCoremlProviderOption_SpecializationStrategy = "SpecializationStrategy";
+// Profile the Core ML MLComputePlan.
+// This logs the hardware each operator is dispatched to and the estimated execution time.
+// Intended for developer usage but provides useful diagnostic information if performance is not as expected.
+static const char* const kCoremlProviderOption_ProfileComputePlan = "ProfileComputePlan";
+// please refer to https://developer.apple.com/documentation/coreml/mlmodelconfiguration/allowlowprecisionaccumulationongpu
+static const char* const kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU = "AllowLowPrecisionAccumulationOnGPU";
#ifdef __cplusplus
extern "C" {
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index b1a79f5921328..a35d975ac8f1b 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -3667,6 +3667,9 @@ struct OrtApi {
* execution provider (typically CPU EP).
* - "0": Default. Disabled. QNN EP will handle quantization and dequantization of graph I/O.
* - "1": Enabled.
+ * "enable_htp_spill_fill_buffer": Enable HTP spill fill buffer setting. The flag is used while generating context binary.
+ * - "0": Default. Disabled.
+ * - "1": Enabled.
*
* SNPE supported keys:
* "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16",
@@ -4612,6 +4615,8 @@ struct OrtApi {
* \param[in] num_keys
*
* \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.17.
*/
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_OpenVINO_V2,
_In_ OrtSessionOptions* options,
@@ -4629,6 +4634,8 @@ struct OrtApi {
* \param[in] num_keys
*
* \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.18.
*/
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_VitisAI,
_In_ OrtSessionOptions* options,
@@ -4642,7 +4649,10 @@ struct OrtApi {
* \param[in] mem_info OrtMemoryInfo instance
* \param[in] count_or_bytes How many bytes is this scratch buffer
* \param[out] out A pointer to the scratch buffer
+ *
* \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.18.
*/
ORT_API2_STATUS(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out);
@@ -4653,6 +4663,8 @@ struct OrtApi {
* \param[out] out A pointer to OrtAllocator
*
* \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.18.
*/
ORT_API2_STATUS(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out);
@@ -4674,6 +4686,8 @@ struct OrtApi {
* \param[in] num_external_initializer_files Number of external files
*
* \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.18.
*/
ORT_API2_STATUS(AddExternalInitializersFromFilesInMemory, _In_ OrtSessionOptions* options,
_In_reads_(num_external_initializer_files) const ORTCHAR_T* const* external_initializer_file_names,
@@ -4696,6 +4710,8 @@ struct OrtApi {
* OrtApi::ReleaseLoraAdapter.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.20.
*/
ORT_API2_STATUS(CreateLoraAdapter, const ORTCHAR_T* adapter_file_path, _In_ OrtAllocator* allocator,
_Outptr_ OrtLoraAdapter** out);
@@ -4714,6 +4730,8 @@ struct OrtApi {
* OrtApi::ReleaseLoraAdapter.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.20.
*/
ORT_API2_STATUS(CreateLoraAdapterFromArray, _In_ const void* bytes, size_t num_bytes, _In_ OrtAllocator* allocator,
_Outptr_ OrtLoraAdapter** out);
@@ -4735,6 +4753,8 @@ struct OrtApi {
* \param[in] adapter OrtLoraAdapter instance
*
* \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.20.
*/
ORT_API2_STATUS(RunOptionsAddActiveLoraAdapter, _Inout_ OrtRunOptions* options, _In_ const OrtLoraAdapter* adapter);
@@ -4753,6 +4773,8 @@ struct OrtApi {
* \param[in] kv_len Number of elements in the keys and values arrays
*
* \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.20.
*/
ORT_API2_STATUS(SetEpDynamicOptions, _Inout_ OrtSession* sess, _In_reads_(kv_len) const char* const* keys,
_In_reads_(kv_len) const char* const* values, _In_ size_t kv_len);
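Taken together, the LoRA entries above describe a small lifecycle: create an adapter from a file (or byte array), mark it active on the run options for the runs that should use it, then release both objects. A minimal C-API sketch of that sequence (status checks omitted; the adapter path is a placeholder and the allocator argument is left null here, assumed to keep the parameters in CPU memory):

#include <onnxruntime_c_api.h>

void RunWithLoraAdapter(OrtSession* session, const ORTCHAR_T* adapter_path) {
  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);

  OrtLoraAdapter* adapter = nullptr;
  api->CreateLoraAdapter(adapter_path, /*allocator*/ nullptr, &adapter);

  OrtRunOptions* run_options = nullptr;
  api->CreateRunOptions(&run_options);
  api->RunOptionsAddActiveLoraAdapter(run_options, adapter);

  // ... call api->Run(session, run_options, ...) as usual; the active adapter's
  // parameters are used for those runs ...

  api->ReleaseRunOptions(run_options);
  api->ReleaseLoraAdapter(adapter);
}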
diff --git a/js/.eslintrc.js b/js/.eslintrc.js
index bd1e9061355f5..462e417df1d66 100644
--- a/js/.eslintrc.js
+++ b/js/.eslintrc.js
@@ -198,19 +198,6 @@ module.exports = {
'_OrtReleaseTensor',
'_OrtRun',
'_OrtRunWithBinding',
- '_OrtTrainingCopyParametersFromBuffer',
- '_OrtTrainingCopyParametersToBuffer',
- '_OrtTrainingCreateSession',
- '_OrtTrainingEvalStep',
- '_OrtTrainingGetModelInputOutputCount',
- '_OrtTrainingGetModelInputOutputName',
- '_OrtTrainingGetParametersSize',
- '_OrtTrainingLazyResetGrad',
- '_OrtTrainingLoadCheckpoint',
- '_OrtTrainingOptimizerStep',
- '_OrtTrainingReleaseCheckpoint',
- '_OrtTrainingReleaseSession',
- '_OrtTrainingRunTrainStep',
],
},
],
diff --git a/js/common/lib/backend.ts b/js/common/lib/backend.ts
index e27e67622aa82..e63f9c6c9147f 100644
--- a/js/common/lib/backend.ts
+++ b/js/common/lib/backend.ts
@@ -3,7 +3,6 @@
import { InferenceSession } from './inference-session.js';
import { OnnxValue } from './onnx-value.js';
-import { TrainingSession } from './training-session.js';
/**
* @ignore
@@ -42,33 +41,6 @@ export interface InferenceSessionHandler extends SessionHandler {
): Promise<SessionHandler.ReturnType>;
}
-/**
- * Represent a handler instance of a training inference session.
- *
- * @ignore
- */
-export interface TrainingSessionHandler extends SessionHandler {
- readonly evalInputNames: readonly string[];
- readonly evalOutputNames: readonly string[];
-
- lazyResetGrad(): Promise;
- runTrainStep(
- feeds: SessionHandler.FeedsType,
- fetches: SessionHandler.FetchesType,
- options: InferenceSession.RunOptions,
- ): Promise;
- runOptimizerStep(options: InferenceSession.RunOptions): Promise;
- runEvalStep(
- feeds: SessionHandler.FeedsType,
- fetches: SessionHandler.FetchesType,
- options: InferenceSession.RunOptions,
- ): Promise;
-
- getParametersSize(trainableOnly: boolean): Promise;
- loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise;
- getContiguousParameters(trainableOnly: boolean): Promise;
-}
-
/**
* Represent a backend that provides implementation of model inferencing.
*
@@ -84,14 +56,6 @@ export interface Backend {
uriOrBuffer: string | Uint8Array,
options?: InferenceSession.SessionOptions,
): Promise<InferenceSessionHandler>;
-
- createTrainingSessionHandler?(
- checkpointStateUriOrBuffer: TrainingSession.UriOrBuffer,
- trainModelUriOrBuffer: TrainingSession.UriOrBuffer,
- evalModelUriOrBuffer: TrainingSession.UriOrBuffer,
- optimizerModelUriOrBuffer: TrainingSession.UriOrBuffer,
- options: InferenceSession.SessionOptions,
- ): Promise;
}
export { registerBackend } from './backend-impl.js';
diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts
index 642a897a90d26..e70f608ad7030 100644
--- a/js/common/lib/env.ts
+++ b/js/common/lib/env.ts
@@ -2,6 +2,7 @@
// Licensed under the MIT License.
import { env as envImpl } from './env-impl.js';
+import { TryGetGlobalType } from './type-helper.js';
export declare namespace Env {
export type WasmPathPrefix = string;
@@ -14,7 +15,6 @@ export declare namespace Env {
* If not modified, the filename of the .wasm file is:
* - `ort-wasm-simd-threaded.wasm` for default build
* - `ort-wasm-simd-threaded.jsep.wasm` for JSEP build (with WebGPU and WebNN)
- * - `ort-training-wasm-simd-threaded.wasm` for training build
*/
wasm?: URL | string;
/**
@@ -25,7 +25,6 @@ export declare namespace Env {
* If not modified, the filename of the .mjs file is:
* - `ort-wasm-simd-threaded.mjs` for default build
* - `ort-wasm-simd-threaded.jsep.mjs` for JSEP build (with WebGPU and WebNN)
- * - `ort-training-wasm-simd-threaded.mjs` for training build
*/
mjs?: URL | string;
}
@@ -200,22 +199,16 @@ export declare namespace Env {
* value will be the GPU adapter that created by the underlying WebGPU backend.
*
* When use with TypeScript, the type of this property is `GPUAdapter` defined in "@webgpu/types".
- * Use `const adapter = env.webgpu.adapter as GPUAdapter;` in TypeScript to access this property with correct type.
- *
- * see comments on {@link Tensor.GpuBufferType}
*/
- adapter: unknown;
+ adapter: TryGetGlobalType<'GPUAdapter'>;
/**
* Get the device for WebGPU.
*
* This property is only available after the first WebGPU inference session is created.
*
* When use with TypeScript, the type of this property is `GPUDevice` defined in "@webgpu/types".
- * Use `const device = env.webgpu.device as GPUDevice;` in TypeScript to access this property with correct type.
- *
- * see comments on {@link Tensor.GpuBufferType} for more details about why not use types defined in "@webgpu/types".
*/
- readonly device: unknown;
+ readonly device: TryGetGlobalType<'GPUDevice'>;
/**
* Set or get whether validate input content.
*
diff --git a/js/common/lib/index.ts b/js/common/lib/index.ts
index 3ed56b3c2e812..d75e6a477258d 100644
--- a/js/common/lib/index.ts
+++ b/js/common/lib/index.ts
@@ -26,4 +26,3 @@ export * from './tensor-factory.js';
export * from './trace.js';
export * from './onnx-model.js';
export * from './onnx-value.js';
-export * from './training-session.js';
diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts
index 547db029471a2..e62c6579e8333 100644
--- a/js/common/lib/inference-session.ts
+++ b/js/common/lib/inference-session.ts
@@ -4,6 +4,7 @@
import { InferenceSession as InferenceSessionImpl } from './inference-session-impl.js';
import { OnnxModelOptions } from './onnx-model.js';
import { OnnxValue, OnnxValueDataLocation } from './onnx-value.js';
+import { TryGetGlobalType } from './type-helper.js';
/* eslint-disable @typescript-eslint/no-redeclare */
@@ -282,7 +283,7 @@ export declare namespace InferenceSession {
extends WebNNExecutionProviderName,
Omit,
Required> {
- context: unknown /* MLContext */;
+ context: TryGetGlobalType<'MLContext'>;
}
/**
@@ -291,8 +292,8 @@ export declare namespace InferenceSession {
* @see https://www.w3.org/TR/webnn/#dom-ml-createcontext-gpudevice
*/
export interface WebNNOptionsWebGpu extends WebNNExecutionProviderName {
- context: unknown /* MLContext */;
- gpuDevice: unknown /* GPUDevice */;
+ context: TryGetGlobalType<'MLContext'>;
+ gpuDevice: TryGetGlobalType<'GPUDevice'>;
}
/**
diff --git a/js/common/lib/tensor.ts b/js/common/lib/tensor.ts
index af918705b97e3..05553bd96662b 100644
--- a/js/common/lib/tensor.ts
+++ b/js/common/lib/tensor.ts
@@ -4,6 +4,7 @@
import { TensorFactory } from './tensor-factory.js';
import { Tensor as TensorImpl } from './tensor-impl.js';
import { TypedTensorUtils } from './tensor-utils.js';
+import { TryGetGlobalType } from './type-helper.js';
/* eslint-disable @typescript-eslint/no-redeclare */
@@ -131,24 +132,19 @@ export declare namespace Tensor {
*/
export type TextureDataTypes = 'float32';
+ type GpuBufferTypeFallback = { size: number; mapState: 'unmapped' | 'pending' | 'mapped' };
/**
* type alias for WebGPU buffer
- *
- * The reason why we don't use type "GPUBuffer" defined in webgpu.d.ts from @webgpu/types is because "@webgpu/types"
- * requires "@types/dom-webcodecs" as peer dependency when using TypeScript < v5.1 and its version need to be chosen
- * carefully according to the TypeScript version being used. This means so far there is not a way to keep every
- * TypeScript version happy. It turns out that we will easily broke users on some TypeScript version.
- *
- * for more info see https://github.com/gpuweb/types/issues/127
*/
- export type GpuBufferType = { size: number; mapState: 'unmapped' | 'pending' | 'mapped' };
+ export type GpuBufferType = TryGetGlobalType<'GPUBuffer', GpuBufferTypeFallback>;
+ type MLTensorTypeFallback = { destroy(): void };
/**
* type alias for WebNN MLTensor
*
* The specification for WebNN's MLTensor is currently in flux.
*/
- export type MLTensorType = unknown;
+ export type MLTensorType = TryGetGlobalType<'MLTensor', MLTensorTypeFallback>;
/**
* supported data types for constructing a tensor from a WebGPU buffer
diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts
deleted file mode 100644
index 21dbe5fe51bb9..0000000000000
--- a/js/common/lib/training-session-impl.ts
+++ /dev/null
@@ -1,273 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-import { resolveBackendAndExecutionProviders } from './backend-impl.js';
-import { SessionHandler, TrainingSessionHandler } from './backend.js';
-import { InferenceSession as InferenceSession } from './inference-session.js';
-import { OnnxValue } from './onnx-value.js';
-import { Tensor } from './tensor.js';
-import { TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions } from './training-session.js';
-
-type SessionOptions = InferenceSession.SessionOptions;
-type FeedsType = InferenceSession.FeedsType;
-type FetchesType = InferenceSession.FetchesType;
-type ReturnType = InferenceSession.ReturnType;
-type RunOptions = InferenceSession.RunOptions;
-
-const noBackendErrMsg: string =
- 'Training backend could not be resolved. ' + "Make sure you're using the correct configuration & WebAssembly files.";
-
-export class TrainingSession implements TrainingSessionInterface {
- private constructor(handler: TrainingSessionHandler, hasOptimizerModel: boolean, hasEvalModel: boolean) {
- this.handler = handler;
- this.hasOptimizerModel = hasOptimizerModel;
- this.hasEvalModel = hasEvalModel;
- }
- private handler: TrainingSessionHandler;
- private hasOptimizerModel: boolean;
- private hasEvalModel: boolean;
-
- get trainingInputNames(): readonly string[] {
- return this.handler.inputNames;
- }
- get trainingOutputNames(): readonly string[] {
- return this.handler.outputNames;
- }
-
- get evalInputNames(): readonly string[] {
- if (this.hasEvalModel) {
- return this.handler.evalInputNames;
- } else {
- throw new Error('This training session has no evalModel loaded.');
- }
- }
- get evalOutputNames(): readonly string[] {
- if (this.hasEvalModel) {
- return this.handler.evalOutputNames;
- } else {
- throw new Error('This training session has no evalModel loaded.');
- }
- }
-
- static async create(
- trainingOptions: TrainingSessionCreateOptions,
- sessionOptions?: SessionOptions,
- ): Promise {
- const evalModel: string | Uint8Array = trainingOptions.evalModel || '';
- const optimizerModel: string | Uint8Array = trainingOptions.optimizerModel || '';
- const options: SessionOptions = sessionOptions || {};
-
- // resolve backend, update session options with validated EPs, and create session handler
- const [backend, optionsWithValidatedEPs] = await resolveBackendAndExecutionProviders(options);
- if (backend.createTrainingSessionHandler) {
- const handler = await backend.createTrainingSessionHandler(
- trainingOptions.checkpointState,
- trainingOptions.trainModel,
- evalModel,
- optimizerModel,
- optionsWithValidatedEPs,
- );
- return new TrainingSession(handler, !!trainingOptions.optimizerModel, !!trainingOptions.evalModel);
- } else {
- throw new Error(noBackendErrMsg);
- }
- }
-
- /**
- * Helper function for runTrainStep and future runStep methods that handles the type-narrowing conversion from
- * the given parameters to SessionHandler.FetchesType and RunOptions.
- *
- * @param inputNames the feeds object is checked that they contain all input names in the provided list of input
- * names.
- * @param outputNames the fetches object is checked that their keys match up with valid names in the list of output
- * names.
- * @param feeds the required input
- * @param arg1 narrowed & converted into the SessionHandler.FetchesType or RunOptions object
- * @param arg2 optional RunOptions object.
- * @returns
- */
- typeNarrowingForRunStep(
- inputNames: readonly string[],
- outputNames: readonly string[],
- feeds: FeedsType,
- arg1?: FetchesType | RunOptions,
- arg2?: RunOptions,
- ): [SessionHandler.FetchesType, RunOptions] {
- const fetches: { [name: string]: OnnxValue | null } = {};
- let options: RunOptions = {};
- // check inputs
- if (typeof feeds !== 'object' || feeds === null || feeds instanceof Tensor || Array.isArray(feeds)) {
- throw new TypeError(
- "'feeds' must be an object that use input names as keys and OnnxValue as corresponding values.",
- );
- }
-
- let isFetchesEmpty = true;
- // determine which override is being used
- if (typeof arg1 === 'object') {
- if (arg1 === null) {
- throw new TypeError('Unexpected argument[1]: cannot be null.');
- }
- if (arg1 instanceof Tensor) {
- throw new TypeError("'fetches' cannot be a Tensor");
- }
-
- if (Array.isArray(arg1)) {
- if (arg1.length === 0) {
- throw new TypeError("'fetches' cannot be an empty array.");
- }
- isFetchesEmpty = false;
- // output names
- for (const name of arg1) {
- if (typeof name !== 'string') {
- throw new TypeError("'fetches' must be a string array or an object.");
- }
- if (outputNames.indexOf(name) === -1) {
- throw new RangeError(`'fetches' contains invalid output name: ${name}.`);
- }
- fetches[name] = null;
- }
-
- if (typeof arg2 === 'object' && arg2 !== null) {
- options = arg2;
- } else if (typeof arg2 !== 'undefined') {
- throw new TypeError("'options' must be an object.");
- }
- } else {
- // decide whether arg1 is fetches or options
- // if any output name is present and its value is valid OnnxValue, we consider it fetches
- let isFetches = false;
- const arg1Keys = Object.getOwnPropertyNames(arg1);
- for (const name of outputNames) {
- if (arg1Keys.indexOf(name) !== -1) {
- const v = (arg1 as InferenceSession.NullableOnnxValueMapType)[name];
- if (v === null || v instanceof Tensor) {
- isFetches = true;
- isFetchesEmpty = false;
- fetches[name] = v;
- }
- }
- }
-
- if (isFetches) {
- if (typeof arg2 === 'object' && arg2 !== null) {
- options = arg2;
- } else if (typeof arg2 !== 'undefined') {
- throw new TypeError("'options' must be an object.");
- }
- } else {
- options = arg1 as RunOptions;
- }
- }
- } else if (typeof arg1 !== 'undefined') {
- throw new TypeError("Unexpected argument[1]: must be 'fetches' or 'options'.");
- }
-
- // check if all inputs are in feed
- for (const name of inputNames) {
- if (typeof feeds[name] === 'undefined') {
- throw new Error(`input '${name}' is missing in 'feeds'.`);
- }
- }
-
- // if no fetches is specified, we use the full output names list
- if (isFetchesEmpty) {
- for (const name of outputNames) {
- fetches[name] = null;
- }
- }
-
- return [fetches, options];
- }
-
- /**
- * Helper method for runTrainStep and any other runStep methods. Takes the ReturnType result from the SessionHandler
- * and changes it into a map of Tensors.
- *
- * @param results
- * @returns
- */
- convertHandlerReturnTypeToMapOfTensors(results: SessionHandler.ReturnType): ReturnType {
- const returnValue: { [name: string]: OnnxValue } = {};
- for (const key in results) {
- if (Object.hasOwnProperty.call(results, key)) {
- const result = results[key];
- if (result instanceof Tensor) {
- returnValue[key] = result;
- } else {
- returnValue[key] = new Tensor(result.type, result.data, result.dims);
- }
- }
- }
- return returnValue;
- }
-
- async lazyResetGrad(): Promise {
- await this.handler.lazyResetGrad();
- }
-
- runTrainStep(feeds: FeedsType, options?: RunOptions): Promise;
- runTrainStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise;
- async runTrainStep(feeds: FeedsType, arg1?: FetchesType | RunOptions, arg2?: RunOptions): Promise {
- const [fetches, options] = this.typeNarrowingForRunStep(
- this.trainingInputNames,
- this.trainingOutputNames,
- feeds,
- arg1,
- arg2,
- );
- const results = await this.handler.runTrainStep(feeds, fetches, options);
- return this.convertHandlerReturnTypeToMapOfTensors(results);
- }
-
- async runOptimizerStep(options?: InferenceSession.RunOptions | undefined): Promise {
- if (this.hasOptimizerModel) {
- await this.handler.runOptimizerStep(options || {});
- } else {
- throw new Error('This TrainingSession has no OptimizerModel loaded.');
- }
- }
-
- runEvalStep(feeds: FeedsType, options?: RunOptions | undefined): Promise;
- runEvalStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions | undefined): Promise;
- async runEvalStep(feeds: FeedsType, arg1?: FetchesType | RunOptions, arg2?: RunOptions): Promise {
- if (this.hasEvalModel) {
- const [fetches, options] = this.typeNarrowingForRunStep(
- this.evalInputNames,
- this.evalOutputNames,
- feeds,
- arg1,
- arg2,
- );
- const results = await this.handler.runEvalStep(feeds, fetches, options);
- return this.convertHandlerReturnTypeToMapOfTensors(results);
- } else {
- throw new Error('This TrainingSession has no EvalModel loaded.');
- }
- }
-
- async getParametersSize(trainableOnly = true): Promise {
- return this.handler.getParametersSize(trainableOnly);
- }
-
- async loadParametersBuffer(array: Uint8Array, trainableOnly = true): Promise {
- const paramsSize = await this.getParametersSize(trainableOnly);
- // checking that the size of the Uint8Array is equivalent to the byte length of a Float32Array of the number
- // of parameters
- if (array.length !== 4 * paramsSize) {
- throw new Error(
- 'Size of the buffer passed into loadParametersBuffer must match the number of parameters in ' +
- 'the model. Please use getParametersSize method to check.',
- );
- }
- return this.handler.loadParametersBuffer(array, trainableOnly);
- }
-
- async getContiguousParameters(trainableOnly = true): Promise {
- return this.handler.getContiguousParameters(trainableOnly);
- }
-
- async release(): Promise {
- return this.handler.dispose();
- }
-}
diff --git a/js/common/lib/training-session.ts b/js/common/lib/training-session.ts
deleted file mode 100644
index 45dcafc46deb5..0000000000000
--- a/js/common/lib/training-session.ts
+++ /dev/null
@@ -1,206 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-import { InferenceSession } from './inference-session.js';
-import { OnnxValue } from './onnx-value.js';
-import { TrainingSession as TrainingSessionImpl } from './training-session-impl.js';
-
-/* eslint-disable @typescript-eslint/no-redeclare */
-
-export declare namespace TrainingSession {
- /**
- * Either URI file path (string) or Uint8Array containing model or checkpoint information.
- */
- type UriOrBuffer = string | Uint8Array;
-}
-
-/**
- * Represent a runtime instance of an ONNX training session,
- * which contains a model that can be trained, and, optionally,
- * an eval and optimizer model.
- */
-export interface TrainingSession {
- // #region run()
-
- /**
- * Lazily resets the gradients of all trainable parameters to zero. Should happen after the invocation of
- * runOptimizerStep.
- */
- lazyResetGrad(): Promise;
-
- /**
- * Run TrainStep asynchronously with the given feeds and options.
- *
- * @param feeds - Representation of the model input. See type description of `InferenceSession.InputType` for
- detail.
- * @param options - Optional. A set of options that controls the behavior of model training.
- * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding values.
- */
- runTrainStep(
- feeds: InferenceSession.FeedsType,
- options?: InferenceSession.RunOptions,
- ): Promise;
-
- /**
- * Run a single train step with the given inputs and options.
- *
- * @param feeds - Representation of the model input.
- * @param fetches - Representation of the model output.
- * detail.
- * @param options - Optional. A set of options that controls the behavior of model training.
- * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
- values.
- */
- runTrainStep(
- feeds: InferenceSession.FeedsType,
- fetches: InferenceSession.FetchesType,
- options?: InferenceSession.RunOptions,
- ): Promise;
-
- /**
- * Runs a single optimizer step, which performs weight updates for the trainable parameters using the optimizer model.
- *
- * @param options - Optional. A set of options that controls the behavior of model optimizing.
- */
- runOptimizerStep(options?: InferenceSession.RunOptions): Promise;
-
- /**
- * Run a single eval step with the given inputs and options using the eval model.
- *
- * @param feeds - Representation of the model input.
- * @param options - Optional. A set of options that controls the behavior of model eval step.
- * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
- values.
- */
- runEvalStep(
- feeds: InferenceSession.FeedsType,
- options?: InferenceSession.RunOptions,
- ): Promise;
-
- /**
- * Run a single eval step with the given inputs and options using the eval model.
- *
- * @param feeds - Representation of the model input.
- * @param fetches - Representation of the model output.
- * detail.
- * @param options - Optional. A set of options that controls the behavior of model eval step.
- * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
- values.
- */
- runEvalStep(
- feeds: InferenceSession.FeedsType,
- fetches: InferenceSession.FetchesType,
- options?: InferenceSession.RunOptions,
- ): Promise;
-
- // #endregion
-
- // #region copy parameters
-
- /**
- * Retrieves the size of all parameters for the training state. Calculates the total number of primitive (datatype of
- * the parameters) elements of all the parameters in the training state.
- *
- * @param trainableOnly - When set to true, the size is calculated for trainable params only. Default value is true.
- */
- getParametersSize(trainableOnly: boolean): Promise;
-
- /**
- * Copies parameter values from the given buffer to the training state. Currently, only supporting models with
- * parameters of type Float32.
- *
- * @param buffer - A Uint8Array representation of Float32 parameters.
- * @param trainableOnly - True if trainable parameters only to be modified, false otherwise. Default value is true.
- */
- loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise;
-
- /**
- * Copies the model parameters to a contiguous buffer. Usually used in the context of Federated Learning.
- * Currently, only supporting models with parameters of type Float32.
- *
- * @param trainableOnly - When set to true, only trainable parameters are copied. Trainable parameters are parameters
- * for which requires_grad is set to true. Default value is true.
- * @returns A promise that resolves to a Float32 OnnxValue of the requested parameters.
- */
- getContiguousParameters(trainableOnly: boolean): Promise;
- // #endregion
-
- // #region release()
-
- /**
- * Release the inference session and the underlying resources.
- */
- release(): Promise;
- // #endregion
-
- // #region metadata
-
- /**
- * Get input names of the loaded training model.
- */
- readonly trainingInputNames: readonly string[];
-
- /**
- * Get output names of the loaded training model.
- */
- readonly trainingOutputNames: readonly string[];
-
- /**
- * Get input names of the loaded eval model. Is an empty array if no eval model is loaded.
- */
- readonly evalInputNames: readonly string[];
-
- /**
- * Get output names of the loaded eval model. Is an empty array if no eval model is loaded.
- */
- readonly evalOutputNames: readonly string[];
-
- // #endregion
-}
-
-/**
- * Represents the optional parameters that can be passed into the TrainingSessionFactory.
- */
-export interface TrainingSessionCreateOptions {
- /**
- * URI or buffer for a .ckpt file that contains the checkpoint for the training model.
- */
- checkpointState: TrainingSession.UriOrBuffer;
- /**
- * URI or buffer for the .onnx training file.
- */
- trainModel: TrainingSession.UriOrBuffer;
- /**
- * Optional. URI or buffer for the .onnx optimizer model file.
- */
- optimizerModel?: TrainingSession.UriOrBuffer;
- /**
- * Optional. URI or buffer for the .onnx eval model file.
- */
- evalModel?: TrainingSession.UriOrBuffer;
-}
-
-/**
- * Defines method overload possibilities for creating a TrainingSession.
- */
-export interface TrainingSessionFactory {
- // #region create()
-
- /**
- * Creates a new TrainingSession and asynchronously loads any models passed in through trainingOptions
- *
- * @param trainingOptions specify models and checkpoints to load into the Training Session
- * @param sessionOptions specify configuration for training session behavior
- *
- * @returns Promise that resolves to a TrainingSession object
- */
- create(
- trainingOptions: TrainingSessionCreateOptions,
- sessionOptions?: InferenceSession.SessionOptions,
- ): Promise;
-
- // #endregion
-}
-
-// eslint-disable-next-line @typescript-eslint/naming-convention
-export const TrainingSession: TrainingSessionFactory = TrainingSessionImpl;
diff --git a/js/common/lib/type-helper.ts b/js/common/lib/type-helper.ts
new file mode 100644
index 0000000000000..845ba3018d443
--- /dev/null
+++ b/js/common/lib/type-helper.ts
@@ -0,0 +1,31 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+/**
+ * A helper type to get certain types if they are declared in global scope.
+ *
+ * For example, if you installed "@webgpu/types" as a dev dependency, then `TryGetGlobalType<'GPUDevice'>` will
+ * be type `GPUDevice`, otherwise it will be type `unknown`.
+ *
+ *
+ * We don't want to introduce "@webgpu/types" as a dependency of this package because:
+ *
+ * (1) For JavaScript users, it's not needed. TypeScript users can install it as a dev dependency themselves.
+ *
+ * (2) Because "@webgpu/types" requires "@types/dom-webcodecs" as a peer dependency when using TypeScript < v5.1, and its
+ * version needs to be chosen carefully according to the TypeScript version being used, there is so far no way to keep
+ * every TypeScript version happy, and we could easily break users on some TypeScript versions.
+ *
+ * For more info see https://github.com/gpuweb/types/issues/127
+ *
+ * Update (2024-08-07): Reason (2) may no longer be valid. Most people should be using TypeScript >= 5.1 by now.
+ * However, we are still not sure whether introducing "@webgpu/types" as a direct dependency is a good idea. We find this
+ * type helper useful for TypeScript users.
+ *
+ * @ignore
+ */
+export type TryGetGlobalType<Name extends string, Fallback = unknown> = typeof globalThis extends {
+ [k in Name]: { prototype: infer T };
+}
+ ? T
+ : Fallback;
diff --git a/js/common/typedoc.json b/js/common/typedoc.json
index 088c7ba4053e6..f9c7e7b19db41 100644
--- a/js/common/typedoc.json
+++ b/js/common/typedoc.json
@@ -1,6 +1,7 @@
{
"entryPoints": ["lib/index.ts"],
"excludeInternal": true,
+ "intentionallyNotExported": ["TryGetGlobalType"],
"name": "ONNX Runtime JavaScript API",
"readme": "none",
"cleanOutputDir": true
diff --git a/js/react_native/android/gradle/wrapper/gradle-wrapper.jar b/js/react_native/android/gradle/wrapper/gradle-wrapper.jar
index 249e5832f090a..e6441136f3d4b 100644
Binary files a/js/react_native/android/gradle/wrapper/gradle-wrapper.jar and b/js/react_native/android/gradle/wrapper/gradle-wrapper.jar differ
diff --git a/js/react_native/android/gradle/wrapper/gradle-wrapper.properties b/js/react_native/android/gradle/wrapper/gradle-wrapper.properties
index a4cb2cd861394..381baa9cef1ec 100644
--- a/js/react_native/android/gradle/wrapper/gradle-wrapper.properties
+++ b/js/react_native/android/gradle/wrapper/gradle-wrapper.properties
@@ -1,6 +1,8 @@
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionSha256Sum=544c35d6bd849ae8a5ed0bcea39ba677dc40f49df7d1835561582da2009b961d
-distributionUrl=https\://services.gradle.org/distributions/gradle-7.5.1-all.zip
+distributionUrl=https\://services.gradle.org/distributions/gradle-8.7-bin.zip
+networkTimeout=10000
+validateDistributionUrl=true
zipStoreBase=GRADLE_USER_HOME
zipStorePath=wrapper/dists
diff --git a/js/react_native/android/gradlew b/js/react_native/android/gradlew
index a69d9cb6c2065..1aa94a4269074 100755
--- a/js/react_native/android/gradlew
+++ b/js/react_native/android/gradlew
@@ -55,7 +55,7 @@
# Darwin, MinGW, and NonStop.
#
# (3) This script is generated from the Groovy template
-# https://github.com/gradle/gradle/blob/master/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt
+# https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt
# within the Gradle project.
#
# You can find Gradle at https://github.com/gradle/gradle/.
@@ -80,13 +80,11 @@ do
esac
done
-APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit
-
-APP_NAME="Gradle"
+# This is normally unused
+# shellcheck disable=SC2034
APP_BASE_NAME=${0##*/}
-
-# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
-DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
+# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036)
+APP_HOME=$( cd "${APP_HOME:-./}" > /dev/null && pwd -P ) || exit
# Use the maximum available, or set MAX_FD != -1 to use that value.
MAX_FD=maximum
@@ -133,22 +131,29 @@ location of your Java installation."
fi
else
JAVACMD=java
- which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
+ if ! command -v java >/dev/null 2>&1
+ then
+ die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
Please set the JAVA_HOME variable in your environment to match the
location of your Java installation."
+ fi
fi
# Increase the maximum file descriptors if we can.
if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then
case $MAX_FD in #(
max*)
+ # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked.
+ # shellcheck disable=SC2039,SC3045
MAX_FD=$( ulimit -H -n ) ||
warn "Could not query maximum file descriptor limit"
esac
case $MAX_FD in #(
'' | soft) :;; #(
*)
+ # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked.
+ # shellcheck disable=SC2039,SC3045
ulimit -n "$MAX_FD" ||
warn "Could not set maximum file descriptor limit to $MAX_FD"
esac
@@ -193,11 +198,15 @@ if "$cygwin" || "$msys" ; then
done
fi
-# Collect all arguments for the java command;
-# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of
-# shell script including quotes and variable substitutions, so put them in
-# double quotes to make sure that they get re-expanded; and
-# * put everything else in single quotes, so that it's not re-expanded.
+
+# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script.
+DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"'
+
+# Collect all arguments for the java command:
+# * DEFAULT_JVM_OPTS, JAVA_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments,
+# and any embedded shellness will be escaped.
+# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be
+# treated as '${Hostname}' itself on the command line.
set -- \
"-Dorg.gradle.appname=$APP_BASE_NAME" \
diff --git a/js/react_native/android/gradlew.bat b/js/react_native/android/gradlew.bat
index f127cfd49d402..25da30dbdeee9 100644
--- a/js/react_native/android/gradlew.bat
+++ b/js/react_native/android/gradlew.bat
@@ -26,6 +26,7 @@ if "%OS%"=="Windows_NT" setlocal
set DIRNAME=%~dp0
if "%DIRNAME%"=="" set DIRNAME=.
+@rem This is normally unused
set APP_BASE_NAME=%~n0
set APP_HOME=%DIRNAME%
@@ -42,11 +43,11 @@ set JAVA_EXE=java.exe
%JAVA_EXE% -version >NUL 2>&1
if %ERRORLEVEL% equ 0 goto execute
-echo.
-echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH.
-echo.
-echo Please set the JAVA_HOME variable in your environment to match the
-echo location of your Java installation.
+echo. 1>&2
+echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 1>&2
+echo. 1>&2
+echo Please set the JAVA_HOME variable in your environment to match the 1>&2
+echo location of your Java installation. 1>&2
goto fail
@@ -56,11 +57,11 @@ set JAVA_EXE=%JAVA_HOME%/bin/java.exe
if exist "%JAVA_EXE%" goto execute
-echo.
-echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME%
-echo.
-echo Please set the JAVA_HOME variable in your environment to match the
-echo location of your Java installation.
+echo. 1>&2
+echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 1>&2
+echo. 1>&2
+echo Please set the JAVA_HOME variable in your environment to match the 1>&2
+echo location of your Java installation. 1>&2
goto fail
diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts
index 81d1b73efc9d4..da8939cd0263a 100644
--- a/js/web/lib/wasm/wasm-core-impl.ts
+++ b/js/web/lib/wasm/wasm-core-impl.ts
@@ -487,7 +487,7 @@ export const prepareInputOutputTensor = (
}
if (location === 'gpu-buffer') {
- const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer;
+ const gpuBuffer = tensor[2].gpuBuffer;
dataByteLength = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(dataType), dims)!;
const registerBuffer = wasm.jsepRegisterBuffer;
diff --git a/js/web/package-lock.json b/js/web/package-lock.json
index 894667ad58933..07c8f0bf3b940 100644
--- a/js/web/package-lock.json
+++ b/js/web/package-lock.json
@@ -861,9 +861,9 @@
}
},
"node_modules/cross-spawn": {
- "version": "6.0.5",
- "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.5.tgz",
- "integrity": "sha512-eTVLrBSt7fjbDygz805pMnstIs2VTBNkRm0qxZd+M7A5XDdxVRWO5MxGBXZhjY4cqLYLdtrGqRf8mBPmzwSpWQ==",
+ "version": "6.0.6",
+ "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.6.tgz",
+ "integrity": "sha512-VqCUuhcd1iB+dsv8gxPttb5iZh/D0iubSP21g36KXdEuf6I5JiioesUVjpCdHV9MZRUfVFlvwtIUyPfxo5trtw==",
"dev": true,
"dependencies": {
"nice-try": "^1.0.4",
@@ -4312,9 +4312,9 @@
}
},
"cross-spawn": {
- "version": "6.0.5",
- "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.5.tgz",
- "integrity": "sha512-eTVLrBSt7fjbDygz805pMnstIs2VTBNkRm0qxZd+M7A5XDdxVRWO5MxGBXZhjY4cqLYLdtrGqRf8mBPmzwSpWQ==",
+ "version": "6.0.6",
+ "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-6.0.6.tgz",
+ "integrity": "sha512-VqCUuhcd1iB+dsv8gxPttb5iZh/D0iubSP21g36KXdEuf6I5JiioesUVjpCdHV9MZRUfVFlvwtIUyPfxo5trtw==",
"dev": true,
"requires": {
"nice-try": "^1.0.4",
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
index cbfd2f0949363..9a6c2af022c91 100644
--- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
@@ -4,6 +4,7 @@
#include "contrib_ops/cpu/bert/rotary_embedding.h"
#include "contrib_ops/cpu/bert/rotary_embedding_helper.h"
+#include "core/mlas/inc/mlas.h"
#include "core/platform/threadpool.h"
using onnxruntime::concurrency::ThreadPool;
@@ -78,31 +79,12 @@ Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters paramete
const T* cos_data = cos_cache + cache_offset;
const T* sin_data = sin_cache + cache_offset;
- int cache_idx = 0;
- bool sign = false;
- int j = 0;
- for (int i = 0; i < rotary_emb_dim; i++) {
- if (interleaved) {
- cache_idx = (i / 2) % half_rotary_emb_dim;
- sign = i & 1;
- j = sign ? i - 1 : i + 1; // i - sign
- } else {
- cache_idx = i % half_rotary_emb_dim;
- sign = (i >= half_rotary_emb_dim);
- j = (i + half_rotary_emb_dim) % rotary_emb_dim;
- }
- float output_data_i = static_cast(input_data[i]) * static_cast(cos_data[cache_idx]);
- float input_data_j = static_cast(input_data[j]);
- float sin_data_cache_idx = static_cast(sin_data[cache_idx]);
- if (sign) {
- output_data_i += input_data_j * sin_data_cache_idx;
- } else {
- output_data_i -= input_data_j * sin_data_cache_idx;
- }
- output_data[i] = static_cast(output_data_i);
- }
- for (int i = rotary_emb_dim; i < head_size; i++) {
- output_data[i] = input_data[i];
+ MlasRotaryEmbedOneRow(input_data, sin_data, cos_data, rotary_emb_dim, interleaved, output_data);
+
+ if (rotary_emb_dim < head_size) {
+ std::memcpy(output_data + rotary_emb_dim,
+ input_data + rotary_emb_dim,
+ (head_size - rotary_emb_dim) * sizeof(T));
}
}
});
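For reference, the rotation that MlasRotaryEmbedOneRow applies is the same computation as the scalar loop removed above: each element is paired with the element half the rotary dimension away (or with its immediate neighbour in interleaved mode) and combined with the cached cos/sin values. A minimal scalar sketch of the non-interleaved float path, mirroring the removed loop rather than the MLAS kernel itself:

void RotaryEmbedOneRowRef(const float* input, const float* sin_data, const float* cos_data,
                          int rotary_emb_dim, float* output) {
  const int half = rotary_emb_dim / 2;
  for (int i = 0; i < rotary_emb_dim; ++i) {
    const int cache_idx = i % half;             // index into the cos/sin caches
    const bool sign = i >= half;                // second half adds the sin term
    const int j = (i + half) % rotary_emb_dim;  // partner element
    float v = input[i] * cos_data[cache_idx];
    v += (sign ? 1.0f : -1.0f) * input[j] * sin_data[cache_idx];
    output[i] = v;
  }
}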
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc
index d675ba742e03b..7757435990a65 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc
@@ -31,6 +31,7 @@ Subgraph::Subgraph(
allocator_(nullptr),
is_output_float16_(false) {
num_implicit_inputs = static_cast(node.ImplicitInputDefs().size());
+ used_implicit_inputs = std::vector(num_implicit_inputs, true);
auto& subgraph_inputs = subgraph.GetInputs();
auto& subgraph_outputs = subgraph.GetOutputs();
@@ -73,8 +74,18 @@ Status Subgraph::Setup(const SessionState& session_state,
// The position_ids, attention_mask, past_0, ... are created by this operator so the name doesn't matter.
feed_names.insert(feed_names.end(), subgraph_input_names.begin(), subgraph_input_names.end());
- for (auto& entry : node.ImplicitInputDefs()) {
- feed_names.push_back(entry->Name());
+ const auto& subgraph_map = subgraph_session_state.GetOrtValueNameIdxMap();
+
+ const auto& implicit_input_defs = node.ImplicitInputDefs();
+ for (size_t i = 0, end = num_implicit_inputs; i < end; ++i) {
+ const auto* entry = implicit_input_defs[i];
+ int idx;
+ if (subgraph_map.GetIdx(entry->Name(), idx).IsOK()) {
+ feed_names.push_back(entry->Name());
+ } else {
+ --num_implicit_inputs;
+ used_implicit_inputs[i] = false;
+ }
}
InlinedVector feed_locations;
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h
index bde591626bb83..8ec9c9cbdc20f 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h
@@ -31,6 +31,7 @@ class Subgraph {
const GraphViewer& subgraph; // The subgraph
int num_implicit_inputs;
+ std::vector<bool> used_implicit_inputs;
int num_subgraph_inputs; // Same as subgraph_input_names.size(), keep it for convenience.
int num_subgraph_outputs; // Same as subgraph_output_names.size()
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
index 9037e58aaf31f..6c66bfc2816e4 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
@@ -281,8 +281,11 @@ Status T5DecoderSubgraph::CreateInitialFeeds(
}
// Pass through implicit inputs.
- for (const auto* entry : implicit_inputs) {
- decoder_feeds.push_back(*entry);
+ for (size_t i = 0; i < implicit_inputs.size(); ++i) {
+ const auto* entry = implicit_inputs[i];
+ if (used_implicit_inputs[i]) {
+ decoder_feeds.push_back(*entry);
+ }
}
return Status::OK();
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc
index 51473c0c931b9..d59db4afac2c2 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc
@@ -145,8 +145,11 @@ Status T5EncoderSubgraph::CreateInitialFeeds(
pinned_allocator,
location));
- for (const auto* entry : implicit_inputs) {
- feeds.push_back(*entry);
+ for (size_t i = 0; i < implicit_inputs.size(); ++i) {
+ const auto* entry = implicit_inputs[i];
+ if (used_implicit_inputs[i]) {
+ feeds.push_back(*entry);
+ }
}
return Status::OK();
diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc
index 5dca4cf6c165b..ecd3960107926 100644
--- a/onnxruntime/core/framework/allocation_planner.cc
+++ b/onnxruntime/core/framework/allocation_planner.cc
@@ -138,7 +138,8 @@ class PlannerImpl {
const SubgraphsKernelCreateInfoMaps& subgraphs_kernel_create_info_maps,
const InlinedHashMap& outer_scope_node_arg_to_location_map,
const OrtValueNameIdxMap& ort_value_name_idx_map,
- const ISequentialPlannerContext& context, SequentialExecutionPlan& plan)
+ const ISequentialPlannerContext& context, SequentialExecutionPlan& plan,
+ const logging::Logger& logger)
: context_(&context),
plan_(plan),
parent_node_(parent_node),
@@ -148,14 +149,15 @@ class PlannerImpl {
kernel_create_info_map_(kernel_create_info_map),
subgraphs_kernel_create_info_maps_(subgraphs_kernel_create_info_maps),
outer_scope_node_arg_to_location_map_(outer_scope_node_arg_to_location_map),
- ort_value_name_idx_map_(ort_value_name_idx_map) {}
+ ort_value_name_idx_map_(ort_value_name_idx_map),
+ logger_(logger) {
+ }
Status CreatePlan(
#ifdef ORT_ENABLE_STREAM
const IStreamCommandHandleRegistry& stream_handle_registry,
#endif
- const PathString& partition_config_file,
- const logging::Logger& logger);
+ const PathString& partition_config_file);
private:
gsl::not_null<const ISequentialPlannerContext*> context_;
@@ -183,6 +185,12 @@ class PlannerImpl {
InlinedHashMap> dependence_graph_;
InlinedHashMap value_node_map_;
+ // logger_ is not currently used in a minimal build
+#if defined(ORT_MINIMAL_BUILD) && !defined(ORT_EXTENDED_MINIMAL_BUILD)
+ [[maybe_unused]]
+#endif
+ const logging::Logger& logger_;
+
// OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation:
struct OrtValueInfo {
const onnxruntime::NodeArg* p_def_site; // the (unique) NodeArg corresponding to the MLValue
@@ -213,6 +221,7 @@ class PlannerImpl {
FreeBufferInfo(OrtValueIndex ort_value, size_t dealloc_point)
: ml_value(ort_value), deallocate_point(dealloc_point) {}
};
+
// freelist_ : a list of ml-values whose buffers are free to be reused, sorted by when
// they became free (more recently freed earlier in the list).
std::list<FreeBufferInfo> freelist_;
@@ -225,7 +234,8 @@ class PlannerImpl {
}
int& UseCount(OrtValueIndex n) {
- ORT_ENFORCE(n >= 0 && static_cast<size_t>(n) < ort_value_info_.size(), "invalid value index: ", n, " against size ", ort_value_info_.size());
+ ORT_ENFORCE(n >= 0 && static_cast<size_t>(n) < ort_value_info_.size(),
+ "invalid value index: ", n, " against size ", ort_value_info_.size());
return ort_value_info_[n].usecount;
}
int& UseCount(const OrtValueName& name) { return UseCount(Index(name)); }
@@ -335,9 +345,9 @@ class PlannerImpl {
// we cannot.
const Node* producer_node = graph.GetProducerNode(p_input_arg->Name());
if (producer_node && HasExternalOutputs(*producer_node)) {
- LOGS_DEFAULT(VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node "
- << producer_node->Name() << " which has external outputs. "
- << "Be cautious the reuse MUST be a read-only usage.";
+ LOGS(logger_, VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node "
+ << producer_node->Name() << " which has external outputs. "
+ << "Be cautious the reuse MUST be a read-only usage.";
}
#endif
*reusable_input = Index(p_input_arg->Name());
@@ -361,9 +371,9 @@ class PlannerImpl {
// we cannot.
const Node* producer_node = graph.GetProducerNode(p_input_arg->Name());
if (producer_node && HasExternalOutputs(*producer_node)) {
- LOGS_DEFAULT(VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node "
- << producer_node->Name() << " which has external outputs. "
- << "Be cautious the reuse MUST be a read-only usage.";
+ LOGS(logger_, VERBOSE) << "Be noted Node " << node.Name() << " is reusing input buffer of node "
+ << producer_node->Name() << " which has external outputs. "
+ << "Be cautious the reuse MUST be a read-only usage.";
}
#endif
*reusable_input = Index(p_input_arg->Name());
@@ -397,8 +407,8 @@ class PlannerImpl {
}
} else {
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
- LOGS_DEFAULT(VERBOSE) << "Node " << node.Name() << " cannot reuse input buffer for node "
- << producer_node->Name() << " as it has external outputs";
+ LOGS(logger_, VERBOSE) << "Node " << node.Name() << " cannot reuse input buffer for node "
+ << producer_node->Name() << " as it has external outputs";
#endif
}
}
@@ -448,8 +458,8 @@ class PlannerImpl {
return true;
} else {
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
- LOGS_DEFAULT(VERBOSE) << "Node " << node.Name() << " cannot reuse strided output buffer for node "
- << producer_node->Name() << " as it has external outputs.";
+ LOGS(logger_, VERBOSE) << "Node " << node.Name() << " cannot reuse strided output buffer for node "
+ << producer_node->Name() << " as it has external outputs.";
#endif
}
}
@@ -1198,9 +1208,9 @@ class PlannerImpl {
// Otherwise, we cannot reuse the buffer.
const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name());
if (producer_node && HasExternalOutputs(*producer_node)) {
- LOGS_DEFAULT(VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node "
- << producer_node->Name() << " which has external outputs is reused. "
- << "Be cautious the reuse MUST be a read-only usage.";
+ LOGS(logger_, VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node "
+ << producer_node->Name() << " which has external outputs is reused. "
+ << "Be cautious the reuse MUST be a read-only usage.";
}
#endif
@@ -1241,9 +1251,9 @@ class PlannerImpl {
// Otherwise, we cannot reuse the buffer.
const Node* producer_node = graph_viewer.GetProducerNode(p_input_arg->Name());
if (producer_node && HasExternalOutputs(*producer_node)) {
- LOGS_DEFAULT(VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node "
- << producer_node->Name() << " which has external outputs is reused. "
- << "Be cautious the reuse MUST be a read-only usage.";
+ LOGS(logger_, VERBOSE) << "Be noted input buffer " << p_output_arg->Name() << " of node "
+ << producer_node->Name() << " which has external outputs is reused. "
+ << "Be cautious the reuse MUST be a read-only usage.";
}
#endif
@@ -1290,8 +1300,8 @@ class PlannerImpl {
}
} else {
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
- LOGS_DEFAULT(VERBOSE) << "Node " << node->Name() << " cannot reuse input buffer for node "
- << producer_node->Name() << " as it has external outputs";
+ LOGS(logger_, VERBOSE) << "Node " << node->Name() << " cannot reuse input buffer for node "
+ << producer_node->Name() << " as it has external outputs";
#endif
}
}
@@ -1869,8 +1879,7 @@ class PlannerImpl {
}
#ifndef ORT_ENABLE_STREAM
- void PartitionIntoStreams(const logging::Logger& /*logger*/,
- const ExecutionProviders& /*execution_providers*/,
+ void PartitionIntoStreams(const ExecutionProviders& /*execution_providers*/,
const PathString& /*partition_config_file*/) {
if (graph_viewer_.NumberOfNodes() > 0) {
stream_nodes_.push_back({});
@@ -1915,11 +1924,11 @@ class PlannerImpl {
#else
- void
- PartitionIntoStreams(const logging::Logger& logger, const ExecutionProviders& execution_providers,
- const PathString& partition_config_file) {
- auto partitioner = IGraphPartitioner::CreateGraphPartitioner(logger, partition_config_file);
- auto status = partitioner->PartitionGraph(graph_viewer_, execution_providers, stream_nodes_, context_->GetExecutionOrder());
+ void PartitionIntoStreams(const ExecutionProviders& execution_providers,
+ const PathString& partition_config_file) {
+ auto partitioner = IGraphPartitioner::CreateGraphPartitioner(logger_, partition_config_file);
+ auto status = partitioner->PartitionGraph(graph_viewer_, execution_providers, stream_nodes_,
+ context_->GetExecutionOrder());
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
plan_.node_stream_map_.resize(SafeInt<size_t>(graph_viewer_.MaxNodeIndex()) + 1);
for (size_t i = 0; i < stream_nodes_.size(); ++i) {
@@ -2282,10 +2291,9 @@ Status PlannerImpl::CreatePlan(
#ifdef ORT_ENABLE_STREAM
const IStreamCommandHandleRegistry& stream_handle_registry,
#endif
- const PathString& partition_config_file,
- const logging::Logger& logger) {
+ const PathString& partition_config_file) {
// 1. partition graph into streams
- PartitionIntoStreams(logger, execution_providers_, this->parent_node_ ? PathString{} : partition_config_file);
+ PartitionIntoStreams(execution_providers_, parent_node_ ? PathString{} : partition_config_file);
// 2. initialize the plan based on stream partition result
int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1;
@@ -2354,14 +2362,13 @@ Status SequentialPlanner::CreatePlan(
PlannerImpl planner(parent_node, graph_viewer, outer_scope_node_args, providers,
kernel_create_info_map, subgraphs_kernel_create_info_maps,
outer_scope_node_arg_to_location_map,
- ort_value_name_idx_map, context, *plan);
+ ort_value_name_idx_map, context, *plan, logger);
return planner.CreatePlan(
#ifdef ORT_ENABLE_STREAM
stream_handle_registry,
#endif
- partition_config_file,
- logger);
+ partition_config_file);
}
#ifdef ORT_ENABLE_STREAM
diff --git a/onnxruntime/core/framework/fallback_cpu_capability.cc b/onnxruntime/core/framework/fallback_cpu_capability.cc
index ef68b88187e08..1eb7420b44d2c 100644
--- a/onnxruntime/core/framework/fallback_cpu_capability.cc
+++ b/onnxruntime/core/framework/fallback_cpu_capability.cc
@@ -41,7 +41,8 @@ static bool IsSmallInitializer(const onnxruntime::GraphViewer& graph, const Node
std::unordered_set<NodeIndex> GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph,
const IExecutionProvider::IKernelLookup& kernel_lookup,
- gsl::span<const NodeIndex> tentative_nodes) {
+ gsl::span<const NodeIndex> tentative_nodes,
+ const logging::Logger& logger) {
// automatic conversion from const std::vector<NodeIndex>&
const auto& ordered_nodes = graph.GetNodesInTopologicalOrder();
InlinedVector node_id_to_order_map(graph.MaxNodeIndex());
@@ -83,7 +84,7 @@ std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewe
auto consumer_nodes = graph.GetConsumerNodes(node_arg.Name());
for (auto& consumer_node : consumer_nodes) {
candidates.push(consumer_node->Index());
- LOGS_DEFAULT(INFO) << "Candidate for fallback CPU execution: " << consumer_node->Name();
+ LOGS(logger, INFO) << "Candidate for fallback CPU execution: " << consumer_node->Name();
}
}
return Status::OK();
@@ -159,9 +160,9 @@ std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewe
if (place_in_cpu) {
cpu_nodes.insert(cur);
- LOGS_DEFAULT(INFO) << "ORT optimization- Force fallback to CPU execution for node: " << node->Name()
- << " because the CPU execution path is deemed faster than overhead involved with execution on other EPs "
- << " capable of executing this node";
+ LOGS(logger, INFO) << "ORT optimization- Force fallback to CPU execution for node: " << node->Name()
+ << " because the CPU execution path is deemed faster than overhead involved with execution "
+ "on other EPs capable of executing this node";
for (auto* output : node->OutputDefs()) {
cpu_output_args.insert(output);
}
diff --git a/onnxruntime/core/framework/fallback_cpu_capability.h b/onnxruntime/core/framework/fallback_cpu_capability.h
index c5bcd22888b7c..bca75adbfd5a7 100644
--- a/onnxruntime/core/framework/fallback_cpu_capability.h
+++ b/onnxruntime/core/framework/fallback_cpu_capability.h
@@ -9,6 +9,9 @@
#include "core/graph/graph_viewer.h"
namespace onnxruntime {
+namespace logging {
+class Logger;
+}
/**
Returns a list of nodes that are preferred on CPU.
@@ -19,6 +22,7 @@ namespace onnxruntime {
*/
std::unordered_set<NodeIndex> GetCpuPreferredNodes(const GraphViewer& graph,
const IExecutionProvider::IKernelLookup& kernel_lookup,
- gsl::span<const NodeIndex> tentative_nodes);
+ gsl::span<const NodeIndex> tentative_nodes,
+ const logging::Logger& logger);
} // namespace onnxruntime
diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc
index 6174122cf3cb4..406fc1b15effc 100644
--- a/onnxruntime/core/framework/graph_partitioner.cc
+++ b/onnxruntime/core/framework/graph_partitioner.cc
@@ -149,13 +149,13 @@ auto get_capabilities = [](const IExecutionProvider& ep,
};
} // namespace
-static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) {
+static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const logging::Logger& logger) {
auto& current_ep = params.current_ep.get();
const auto& ep_type = current_ep.Type();
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
if (current_ep.GetPreferredLayout() == DataLayout::NHWC && !params.transform_layout.get()) {
- LOGS_DEFAULT(WARNING) << ep_type << " cannot be used with this model due to its ONNX opset not being supported by "
+ LOGS(logger, WARNING) << ep_type << " cannot be used with this model due to its ONNX opset not being supported by "
"the layout transformer.";
return Status::OK();
}
@@ -165,7 +165,8 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) {
const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type);
const KernelLookup kernel_lookup{ep_type,
kernel_registries_for_ep,
- kernel_registry_mgr.GetKernelTypeStrResolver()};
+ kernel_registry_mgr.GetKernelTypeStrResolver(),
+ logger};
auto& graph = params.graph.get();
auto& capabilities = params.capabilities.get();
@@ -248,13 +249,15 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) {
static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer,
const KernelRegistryManager& kernel_registry_mgr,
const IExecutionProvider& current_ep,
+ const logging::Logger& logger,
std::vector<std::unique_ptr<ComputeCapability>>& capabilities) {
const auto& ep_type = current_ep.Type();
const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type);
const KernelLookup kernel_lookup{ep_type,
kernel_registries_for_ep,
- kernel_registry_mgr.GetKernelTypeStrResolver()};
+ kernel_registry_mgr.GetKernelTypeStrResolver(),
+ logger};
// TODO: Provide EP with a capability to look inside the functions.
capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup);
@@ -359,7 +362,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
GraphPartitioner::Mode mode,
int& fused_node_unique_id,
const layout_transformation::TransformLayoutFunction& transform_layout_fn,
- const layout_transformation::DebugGraphFn& debug_graph_fn) {
+ const layout_transformation::DebugGraphFn& debug_graph_fn,
+ const logging::Logger& logger) {
// handle testing edge case where optimizers or constant lifting results in graph with no nodes.
// doing it here saves all providers checking for this in GetCapability
if (graph.NumberOfNodes() == 0) {
@@ -373,7 +377,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
// we pass through the FuncManager from the top level graph
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr,
fused_kernel_registry, current_ep, mode, fused_node_unique_id,
- transform_layout_fn, debug_graph_fn));
+ transform_layout_fn, debug_graph_fn, logger));
}
}
@@ -398,7 +402,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
std::cref(transform_layout_fn),
std::cref(debug_graph_fn)};
- ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params));
+ ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger));
if (capabilities.empty()) {
return Status::OK();
}
@@ -425,7 +429,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
Node* n = PlaceNode(graph, *capability->sub_graph, fusion_style, type, mode, fused_node_unique_id);
if (n != nullptr) {
// searching in kernel registries, if no kernel registered for the fused_node, use compile approach
- if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *n, type)) {
+ if (!KernelRegistryManager::HasImplementationOf(kernel_registry_mgr, *n, type, logger)) {
nodes_to_compile.push_back(n);
capabilities_to_compile.push_back(std::move(capability));
} else {
@@ -559,6 +563,7 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) {
static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_providers,
const KernelRegistryManager& kernel_registry_mgr,
Graph& graph,
+ const logging::Logger& logger,
InlinedHashSet& not_inlined,
size_t& inlined_count) {
// handle testing edge case where optimizers or constant lifting results in graph with no nodes.
@@ -574,6 +579,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers,
kernel_registry_mgr,
*subgraph,
+ logger,
not_inlined,
inlined_count));
}
@@ -597,7 +603,8 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
InlinedHashSet claimed_by_ep;
for (const auto& ep : execution_providers) {
std::vector<std::unique_ptr<ComputeCapability>> capabilities;
- ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, capabilities));
+ ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, logger,
+ capabilities));
for (auto& capability : capabilities) {
const auto& nodes = capability->sub_graph->nodes;
if (nodes.size() == 1) {
@@ -727,7 +734,8 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, GraphPartitioner::Mode mode,
const ExecutionProviders& execution_providers,
- KernelRegistryManager& kernel_registry_manager) {
+ KernelRegistryManager& kernel_registry_manager,
+ const logging::Logger& logger) {
bool modified_graph = false;
auto& graph = partition_params.graph.get();
@@ -742,7 +750,8 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params,
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(graph, func_mgr, kernel_registry_manager,
fused_kernel_registry, *ep, mode, fused_node_unique_id,
transform_layout_function,
- partition_params.debug_graph_fn));
+ partition_params.debug_graph_fn,
+ logger));
}
// expand any nodes that have an ONNX function definition but no matching ORT kernel.
@@ -762,7 +771,8 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params,
static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_params,
KernelRegistryManager& kernel_registry_mgr,
- IExecutionProvider& current_ep) {
+ IExecutionProvider& current_ep,
+ const logging::Logger& logger) {
// handle testing edge case where optimizers or constant lifting results in graph with no nodes.
// doing it here saves all providers checking for this in GetCapability
auto& graph = partition_params.graph.get();
@@ -776,7 +786,8 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param
auto& subgraph = *entry.second;
PartitionParams subgraph_partition_params = partition_params;
subgraph_partition_params.graph = std::ref(subgraph);
- ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(subgraph_partition_params, kernel_registry_mgr, current_ep));
+ ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(subgraph_partition_params, kernel_registry_mgr, current_ep,
+ logger));
}
}
@@ -795,7 +806,7 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param
};
// clang-format on
- ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params));
+ ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger));
if (capabilities.empty()) {
return Status::OK();
}
@@ -876,10 +887,11 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param
// Simplified partitioning where custom EPs may produce compiled nodes.
static Status PartitionOrtFormatModel(const PartitionParams& partition_params,
const ExecutionProviders& execution_providers,
- KernelRegistryManager& kernel_registry_manager) {
+ KernelRegistryManager& kernel_registry_manager,
+ const logging::Logger& logger) {
// process full graph with each EP
for (const auto& ep : execution_providers) {
- ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(partition_params, kernel_registry_manager, *ep));
+ ORT_RETURN_IF_ERROR(PartitionOrtFormatModelImpl(partition_params, kernel_registry_manager, *ep, logger));
}
return Status::OK();
@@ -906,6 +918,7 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model,
ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers,
kernel_registry_manager,
graph,
+ logger,
not_inlined,
inlined_count));
@@ -977,8 +990,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
if (mode == Mode::kNormal || mode == Mode::kAssignOnly) {
#if !defined(ORT_MINIMAL_BUILD)
- ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode,
- providers_, kernel_registry_mgr_));
+ ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, kernel_registry_mgr_, logger));
bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1";
std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
@@ -991,8 +1003,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX models are not supported in this build.");
#endif //! defined(ORT_MINIMAL_BUILD)
} else {
- ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(partition_params,
- providers_, kernel_registry_mgr_));
+ ORT_RETURN_IF_ERROR(PartitionOrtFormatModel(partition_params, providers_, kernel_registry_mgr_, logger));
}
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
diff --git a/onnxruntime/core/framework/kernel_lookup.h b/onnxruntime/core/framework/kernel_lookup.h
index 0dd17d2f4a624..fac43bad0fefb 100644
--- a/onnxruntime/core/framework/kernel_lookup.h
+++ b/onnxruntime/core/framework/kernel_lookup.h
@@ -21,17 +21,19 @@ class KernelLookup final : public IExecutionProvider::IKernelLookup {
public:
KernelLookup(ProviderType provider_type,
gsl::span<const gsl::not_null<const KernelRegistry*>> kernel_registries,
- const IKernelTypeStrResolver& kernel_type_str_resolver)
+ const IKernelTypeStrResolver& kernel_type_str_resolver,
+ const logging::Logger& logger)
: provider_type_{provider_type},
kernel_registries_{kernel_registries},
- kernel_type_str_resolver_{kernel_type_str_resolver} {
+ kernel_type_str_resolver_{kernel_type_str_resolver},
+ logger_{logger} {
ORT_ENFORCE(!provider_type_.empty(), "provider_type must be specified.");
}
const KernelCreateInfo* LookUpKernel(const Node& node) const override {
const KernelCreateInfo* kernel_create_info{};
for (const auto& registry : kernel_registries_) {
- const auto lookup_status = registry->TryFindKernel(node, provider_type_, kernel_type_str_resolver_,
+ const auto lookup_status = registry->TryFindKernel(node, provider_type_, kernel_type_str_resolver_, logger_,
&kernel_create_info);
if (lookup_status.IsOK() && kernel_create_info != nullptr) {
return kernel_create_info;
@@ -45,6 +47,7 @@ class KernelLookup final : public IExecutionProvider::IKernelLookup {
ProviderType provider_type_;
const gsl::span<const gsl::not_null<const KernelRegistry*>> kernel_registries_;
const IKernelTypeStrResolver& kernel_type_str_resolver_;
+ const logging::Logger& logger_;
};
} // namespace onnxruntime
diff --git a/onnxruntime/core/framework/kernel_registry.cc b/onnxruntime/core/framework/kernel_registry.cc
index cbbe0f86b8b7e..8602a3b4004ff 100644
--- a/onnxruntime/core/framework/kernel_registry.cc
+++ b/onnxruntime/core/framework/kernel_registry.cc
@@ -183,6 +183,7 @@ Status KernelRegistry::TryFindKernelImpl(const Node& node,
ProviderType exec_provider,
const IKernelTypeStrResolver* kernel_type_str_resolver,
const TypeConstraintMap* type_constraints,
+ const logging::Logger& logger,
const KernelCreateInfo** out) const {
const auto& node_provider = node.GetExecutionProviderType();
const auto& expected_provider = (node_provider.empty() ? exec_provider : node_provider);
@@ -215,7 +216,7 @@ Status KernelRegistry::TryFindKernelImpl(const Node& node,
std::ostream_iterator<std::string>(oss, "\n"));
oss << ")";
- VLOGS_DEFAULT(2) << "TryFindKernel failed, Reason: " << oss.str();
+ VLOGS(logger, 2) << "TryFindKernel failed, Reason: " << oss.str();
return Status(common::ONNXRUNTIME, common::FAIL, oss.str());
}
@@ -224,14 +225,16 @@ Status KernelRegistry::TryFindKernelImpl(const Node& node,
Status KernelRegistry::TryFindKernel(const Node& node, ProviderType exec_provider,
const IKernelTypeStrResolver& kernel_type_str_resolver,
+ const logging::Logger& logger,
const KernelCreateInfo** out) const {
- return TryFindKernelImpl(node, exec_provider, &kernel_type_str_resolver, nullptr, out);
+ return TryFindKernelImpl(node, exec_provider, &kernel_type_str_resolver, nullptr, logger, out);
}
Status KernelRegistry::TryFindKernel(const Node& node, ProviderType exec_provider,
const TypeConstraintMap& type_constraints,
+ const logging::Logger& logger,
const KernelCreateInfo** out) const {
- return TryFindKernelImpl(node, exec_provider, nullptr, &type_constraints, out);
+ return TryFindKernelImpl(node, exec_provider, nullptr, &type_constraints, logger, out);
}
static bool KernelDefCompatible(int version, const KernelDef& kernel_def,
@@ -261,6 +264,7 @@ Status KernelRegistry::TryFindKernel(ProviderType exec_provider,
std::string_view domain,
int version,
const KernelRegistry::TypeConstraintMap& type_constraints,
+ const logging::Logger& logger,
const KernelCreateInfo** out) const {
auto range = kernel_creator_fn_map_.equal_range(GetMapKey(op_type, domain, exec_provider));
if (out) *out = nullptr;
@@ -289,7 +293,7 @@ Status KernelRegistry::TryFindKernel(ProviderType exec_provider,
std::ostream_iterator<std::string>(oss, "\n"));
oss << ")";
- VLOGS_DEFAULT(2) << "TryFindKernel failed, Reason: " << oss.str();
+ VLOGS(logger, 2) << "TryFindKernel failed, Reason: " << oss.str();
return Status(common::ONNXRUNTIME, common::FAIL, oss.str());
}
diff --git a/onnxruntime/core/framework/kernel_registry_manager.cc b/onnxruntime/core/framework/kernel_registry_manager.cc
index f8ccdb8fb0238..721353854a474 100644
--- a/onnxruntime/core/framework/kernel_registry_manager.cc
+++ b/onnxruntime/core/framework/kernel_registry_manager.cc
@@ -57,7 +57,7 @@ void KernelRegistryManager::RegisterKernelRegistry(std::shared_ptrTryFindKernel(node, std::string(), GetKernelTypeStrResolver(), kernel_create_info);
+ status = registry->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), logger, kernel_create_info);
if (status.IsOK()) {
return status;
}
@@ -95,7 +95,7 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node,
}
if (p != nullptr) {
- status = p->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), kernel_create_info);
+ status = p->TryFindKernel(node, std::string(), GetKernelTypeStrResolver(), logger, kernel_create_info);
if (status.IsOK()) {
return status;
}
@@ -104,10 +104,14 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node,
return Status(ONNXRUNTIME, NOT_IMPLEMENTED, create_error_message("Failed to find kernel for "));
}
-bool KernelRegistryManager::HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type) {
+bool KernelRegistryManager::HasImplementationOf(const KernelRegistryManager& r,
+ const Node& node,
+ const std::string& provider_type,
+ const logging::Logger& logger) {
const auto kernel_registries = r.GetKernelRegistriesByProviderType(provider_type);
return std::any_of(kernel_registries.begin(), kernel_registries.end(), [&](const KernelRegistry* kernel_registry) {
- return KernelRegistry::HasImplementationOf(*kernel_registry, node, provider_type, r.GetKernelTypeStrResolver());
+ return KernelRegistry::HasImplementationOf(*kernel_registry, node, provider_type, r.GetKernelTypeStrResolver(),
+ logger);
});
}
diff --git a/onnxruntime/core/framework/kernel_registry_manager.h b/onnxruntime/core/framework/kernel_registry_manager.h
index 1da73208cb536..72f0ed3c6268a 100644
--- a/onnxruntime/core/framework/kernel_registry_manager.h
+++ b/onnxruntime/core/framework/kernel_registry_manager.h
@@ -67,13 +67,14 @@ class KernelRegistryManager {
// This function assumes the node is already assigned to an execution provider
// Don't call this function before graph partition is done
- Status SearchKernelRegistry(const Node& node,
+ Status SearchKernelRegistry(const Node& node, const logging::Logger& logger,
/*out*/ const KernelCreateInfo** kernel_create_info) const;
/**
* Whether this node can be run on this provider
*/
- static bool HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type);
+ static bool HasImplementationOf(const KernelRegistryManager& r, const Node& node, const std::string& provider_type,
+ const logging::Logger& logger);
Status CreateKernel(const Node& node,
const IExecutionProvider& execution_provider,
diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc
index 0d0b22ff61e01..0ac2271ba09f1 100644
--- a/onnxruntime/core/framework/session_state.cc
+++ b/onnxruntime/core/framework/session_state.cc
@@ -178,7 +178,7 @@ Status SessionState::PopulateKernelCreateInfo(const KernelRegistryManager& kerne
bool saving_ort_format) {
for (auto& node : graph_.Nodes()) {
const KernelCreateInfo* kci = nullptr;
- auto status = kernel_registry_manager.SearchKernelRegistry(node, &kci);
+ auto status = kernel_registry_manager.SearchKernelRegistry(node, logger_, &kci);
if (!status.IsOK() && saving_ort_format) {
// if we didn't find the kernel and are saving to ORT format an EP that compiles nodes is enabled.
// in that case we assigned the node to that EP but do not compile it into a fused node.
@@ -187,7 +187,7 @@ Status SessionState::PopulateKernelCreateInfo(const KernelRegistryManager& kerne
// at runtime when the model is loaded in a minimal build, the compiling EP will replace this node if possible.
// if that's not possible for some reason we can fallback to the CPU EP implementation.
node.SetExecutionProviderType(kCpuExecutionProvider);
- status = kernel_registry_manager.SearchKernelRegistry(node, &kci);
+ status = kernel_registry_manager.SearchKernelRegistry(node, logger_, &kci);
}
ORT_RETURN_IF_ERROR(status);
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 09a4a77780916..c7a0793c4748f 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3335,6 +3335,11 @@ void RegisterContribSchemas() {
AttributeProto::STRING,
OPTIONAL_VALUE)
.Attr("notes", "(Optional) Some notes for the model", AttributeProto::STRING, OPTIONAL_VALUE)
+ .Attr(
+ "max_size",
+ "max size in the context. Usage depend on the EP.",
+ AttributeProto::INT,
+ static_cast(0))
.AllowUncheckedAttributes()
.Input(
0,
diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
index 6f1f1c831d191..5a3cd86b04492 100644
--- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc
@@ -9,7 +9,7 @@
#include "core/graph/constants.h"
#include "core/graph/contrib_ops/contrib_defs.h"
#include "core/graph/contrib_ops/shape_inference_functions.h"
-#include "onnx/onnx-ml.pb.h" // ?
+#include "core/graph/onnx_protobuf.h"
// Suppress a warning: global initializer calls a non-constexpr function 'symbol' which is from
// ONNX_OPERATOR_SET_SCHEMA_EX macro and only happens in debug build
@@ -23,7 +23,7 @@ void convTransposeShapeInference(InferenceContext& ctx);
void convPoolShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, bool use_dilation, bool require_kernel_shape,
int input1Idx, int input2Idx);
namespace defs::math::utils {
- void MatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int input1Idx, int input2Idx);
+void MatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int input1Idx, int input2Idx);
}
} // namespace ONNX_NAMESPACE
@@ -822,10 +822,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
}
}
- if (all_lengths_known) {
- output_shape->mutable_dim(axis)->set_dim_value(total_length);
- }
- }));
+ if (all_lengths_known) {
+ output_shape->mutable_dim(axis)->set_dim_value(total_length);
+ }
+ }));
ONNX_MS_OPERATOR_SET_SCHEMA(QLinearWhere, 1, OpSchema()
.SetDoc("Return elements, either from X or Y, depending on condition.")
@@ -955,7 +955,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
AttributeProto::INT, static_cast(0))
.Attr("do_rotary", "Whether to use rotary position embedding. Default value is 0.",
AttributeProto::INT, OPTIONAL_VALUE)
- .Attr("past_present_share_buffer", "Corresponding past and present are same tensor, its shape is "
+ .Attr("past_present_share_buffer",
+ "Corresponding past and present are same tensor, its shape is "
"(2, batch_size, num_heads, max_sequence_length, head_size)",
AttributeProto::INT, OPTIONAL_VALUE)
.Attr("mask_filter_value",
diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h
index 28ae64c4d5b3e..207c058d899b4 100644
--- a/onnxruntime/core/mlas/inc/mlas.h
+++ b/onnxruntime/core/mlas/inc/mlas.h
@@ -1435,6 +1435,29 @@ MLAS_FP16* Destination,
size_t Count
);
+/**
+ * @brief rotary embedding for one hidden state vector
+ *
+ * @tparam T: data type of input, sin, cos and output. Currently only float32/16 are supported.
+ * @param input: input tensor, of shape [dim]
+ * @param sin: sin tensor, of shape [dim/2]
+ * @param cos: cos tensor, of shape [dim/2]
+ * @param dim: dimension of rotary embedding
+ * @param interleaved: whether the real part and imaginary parts are interleaved
+ * @param output: output tensor, of shape [dim]
+ */
+template <typename T>
+void
+MLASCALL
+MlasRotaryEmbedOneRow(
+ const T* input,
+ const T* sin,
+ const T* cos,
+ size_t dim,
+ bool interleaved,
+ T* output
+);
+
/**
* @brief Whether current CPU supports FP16 acceleration.
*/
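As context for the new public MlasRotaryEmbedOneRow API declared above, the following is a minimal caller sketch (not part of the patch). The function name ApplyRopeToOneRow and the buffer sizes are made up for illustration; only the contract from the doc comment is assumed (sin/cos hold dim/2 values each, dim must be even).

    #include <cstddef>
    #include <vector>
    #include "mlas.h"  // assumed to be on the caller's include path

    void ApplyRopeToOneRow() {
        const size_t dim = 8;                          // rotary dimension, must be even
        std::vector<float> row(dim, 1.0f);             // one hidden-state vector
        std::vector<float> sin_cache(dim / 2, 0.0f);   // sin values for this position
        std::vector<float> cos_cache(dim / 2, 1.0f);   // cos values for this position
        std::vector<float> out(dim);

        // Non-interleaved layout: the first dim/2 elements are the "real" half,
        // the last dim/2 elements are the "imaginary" half.
        MlasRotaryEmbedOneRow<float>(row.data(), sin_cache.data(), cos_cache.data(),
                                     dim, /*interleaved=*/false, out.data());
    }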
diff --git a/onnxruntime/core/mlas/lib/fp16_neon_common.cpp b/onnxruntime/core/mlas/lib/cast_kernel_neon.cpp
similarity index 99%
rename from onnxruntime/core/mlas/lib/fp16_neon_common.cpp
rename to onnxruntime/core/mlas/lib/cast_kernel_neon.cpp
index 29734c2277667..8a385c9c61751 100644
--- a/onnxruntime/core/mlas/lib/fp16_neon_common.cpp
+++ b/onnxruntime/core/mlas/lib/cast_kernel_neon.cpp
@@ -6,7 +6,7 @@ Licensed under the MIT License.
Module Name:
- fp16_neon_common.cpp
+ cast_kernel_neon.cpp
Abstract:
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
index 0533a5e49b0bb..100d7d47751aa 100644
--- a/onnxruntime/core/mlas/lib/mlasi.h
+++ b/onnxruntime/core/mlas/lib/mlasi.h
@@ -1049,6 +1049,13 @@ extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512;
extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni;
+//
+// Rotary embedding dispatch structure.
+//
+struct MLAS_ROPE_DISPATCH;
+extern const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon;
+
+
//
// Quantized depthwise convolution kernels.
//
@@ -1208,6 +1215,8 @@ struct MLAS_PLATFORM {
MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel;
MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel;
+
+ const MLAS_ROPE_DISPATCH* RopeDispatch{nullptr};
};
inline
diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp
index b3c9461293fce..ec572a4150292 100644
--- a/onnxruntime/core/mlas/lib/platform.cpp
+++ b/onnxruntime/core/mlas/lib/platform.cpp
@@ -543,6 +543,7 @@ Return Value:
this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon;
this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon;
this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon;
+ this->RopeDispatch = &MlasRopeDispatchNeon;
//
// Check if the processor supports ASIMD dot product instructions.
diff --git a/onnxruntime/core/mlas/lib/rotary_embedding.cpp b/onnxruntime/core/mlas/lib/rotary_embedding.cpp
new file mode 100644
index 0000000000000..1f8f7b240694c
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/rotary_embedding.cpp
@@ -0,0 +1,101 @@
+/*++
+
+Copyright (c) Intel Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+ rotary_embedding.cpp
+
+Abstract:
+
+ This module implements rotary embedding kernels for fp32/16.
+
+--*/
+
+#include "rotary_embedding.h"
+
+namespace {
+
+template <typename T>
+void
+MLASCALL
+MlasRotaryEmbedOneRow_FallBack(
+ const T* input_data,
+ const T* sin_data,
+ const T* cos_data,
+ size_t rotary_emb_dim,
+ bool interleaved,
+ T* output_data
+) {
+ const size_t half_rotary_emb_dim = rotary_emb_dim / 2;
+ size_t cache_idx = 0;
+ bool sign = false;
+ size_t j = 0;
+ for (size_t i = 0; i < rotary_emb_dim; i++) {
+ if (interleaved) {
+ cache_idx = (i / 2) % half_rotary_emb_dim;
+ sign = i & 1;
+ j = sign ? i - 1 : i + 1; // i - sign
+ } else {
+ cache_idx = i % half_rotary_emb_dim;
+ sign = (i >= half_rotary_emb_dim);
+ j = (i + half_rotary_emb_dim) % rotary_emb_dim;
+ }
+ float output_data_i = static_cast<float>(input_data[i]) * static_cast<float>(cos_data[cache_idx]);
+ float input_data_j = static_cast<float>(input_data[j]);
+ float sin_data_cache_idx = static_cast<float>(sin_data[cache_idx]);
+ if (sign) {
+ output_data_i += input_data_j * sin_data_cache_idx;
+ } else {
+ output_data_i -= input_data_j * sin_data_cache_idx;
+ }
+ output_data[i] = static_cast<T>(output_data_i);
+ }
+}
+
+} // namespace
+
+
+template <>
+void
+MLASCALL
+MlasRotaryEmbedOneRow<float>(
+ const float* input,
+ const float* sin,
+ const float* cos,
+ size_t dim,
+ bool interleaved,
+ float* output
+) {
+ const auto* dispatch = GetMlasPlatform().RopeDispatch;
+
+ if (dispatch == nullptr || dispatch->SRope == nullptr) {
+ MlasRotaryEmbedOneRow_FallBack(input, sin, cos, dim, interleaved, output);
+ return;
+ }
+
+ dispatch->SRope(input, sin, cos, dim, interleaved, output);
+}
+
+template <>
+void
+MLASCALL
+MlasRotaryEmbedOneRow<MLAS_FP16>(
+ const MLAS_FP16* input,
+ const MLAS_FP16* sin,
+ const MLAS_FP16* cos,
+ size_t dim,
+ bool interleaved,
+ MLAS_FP16* output
+) {
+ const auto* dispatch = GetMlasPlatform().RopeDispatch;
+
+ if (dispatch == nullptr || dispatch->HRope == nullptr) {
+ MlasRotaryEmbedOneRow_FallBack(input, sin, cos, dim, interleaved, output);
+ return;
+ }
+
+ dispatch->HRope(input, sin, cos, dim, interleaved, output);
+}
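The scalar fallback above is the reference semantics for both layouts. Written out for the non-interleaved case with d = dim/2 (read directly off the loop above), each pair is rotated as:

    out[i]     = x[i]     * cos[i] - x[i + d] * sin[i]
    out[i + d] = x[i + d] * cos[i] + x[i]     * sin[i]        for 0 <= i < d

In the interleaved case the same rotation is applied to each adjacent pair (x[2k], x[2k+1]) using sin[k] and cos[k]; the platform kernels selected through RopeDispatch implement this same rotation with vector instructions.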
diff --git a/onnxruntime/core/mlas/lib/rotary_embedding.h b/onnxruntime/core/mlas/lib/rotary_embedding.h
new file mode 100644
index 0000000000000..352dddccf1025
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/rotary_embedding.h
@@ -0,0 +1,46 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+ rotary_embedding.h
+
+Abstract:
+
+ This module includes kernel function prototypes and helper functions for
+ implementing rotary embedding.
+
+--*/
+
+#pragma once
+
+#include "mlasi.h"
+
+struct MLAS_ROPE_DISPATCH {
+ // rotary embedding kernel for fp32
+ typedef void(SRope_Fn)(
+ const float* input,
+ const float* sin,
+ const float* cos,
+ size_t dim,
+ bool interleaved,
+ float* output
+ );
+
+ SRope_Fn* SRope = nullptr;
+
+ // rotary embedding kernel for fp16
+ typedef void(HRope_Fn)(
+ const MLAS_FP16* input,
+ const MLAS_FP16* sin,
+ const MLAS_FP16* cos,
+ size_t dim,
+ bool interleaved,
+ MLAS_FP16* output
+ );
+
+ HRope_Fn* HRope = nullptr;
+};
diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.cpp b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.cpp
new file mode 100644
index 0000000000000..e59a95cd9ee4e
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.cpp
@@ -0,0 +1,32 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+ rotary_embedding_kernel_neon.cpp
+
+Abstract:
+
+ This module implements the rotary embedding kernels for ARM NEON.
+
+--*/
+
+#include "rotary_embedding.h"
+#include "rotary_embedding_kernel_neon.h"
+
+//
+// Kernel dispatch structure definition.
+//
+const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon = []() {
+ MLAS_ROPE_DISPATCH d;
+
+#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)
+ if (MlasFp16AccelerationSupported()) {
+ d.HRope = rope_neon::RopeKernel_Fp16;
+ }
+#endif
+ return d;
+}();
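As wired in this patch, the NEON dispatch only registers the fp16 kernel (HRope); SRope stays null, so fp32 calls fall through to the scalar fallback on ARM. The standalone model below (illustrative names only, not ORT code) shows the same pattern: a per-platform table of function pointers filled by an immediately invoked lambda and consulted with a null check.

    #include <cstddef>

    struct RopeDispatch {
        using Fn = void (*)(const float* in, float* out, size_t n);
        Fn SRope = nullptr;  // stays null when the platform registers no specialized kernel
    };

    static void ScalarFallback(const float* in, float* out, size_t n) {
        for (size_t i = 0; i < n; ++i) out[i] = in[i];  // stand-in for the real rotation math
    }

    static const RopeDispatch kDispatch = []() {
        RopeDispatch d;
        // d.SRope = &SomeNeonKernel;  // a platform build would assign its kernel here
        return d;
    }();

    void Run(const float* in, float* out, size_t n) {
        (kDispatch.SRope != nullptr ? kDispatch.SRope : &ScalarFallback)(in, out, n);
    }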
diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.h b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.h
new file mode 100644
index 0000000000000..8153f65650f7d
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon.h
@@ -0,0 +1,37 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+ rotary_embedding_kernel_neon.h
+
+Abstract:
+
+ This module includes function declarations and common helper functions for
+ rotary embedding on ARM cpu.
+
+--*/
+
+#pragma once
+
+#include
+
+#include "mlasi.h"
+
+namespace rope_neon {
+
+// Rotary embedding kernel for fp16. Embed one hidden state vector.
+void
+RopeKernel_Fp16(
+ const MLAS_FP16* input,
+ const MLAS_FP16* sin,
+ const MLAS_FP16* cos,
+ size_t dim,
+ bool interleaved,
+ MLAS_FP16* output
+);
+
+} // namespace rope_neon
diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp
new file mode 100644
index 0000000000000..3e2eb8fee0e6e
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp
@@ -0,0 +1,253 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+ rotary_embedding_kernel_neon_fp16.cpp
+
+Abstract:
+
+ This module implements the fp16 rotary embedding kernels for ARM NEON.
+
+--*/
+
+#include <arm_neon.h>
+#include <cassert>
+
+#include "fp16_common.h"
+#include "rotary_embedding.h"
+#include "rotary_embedding_kernel_neon.h"
+
+namespace rope_neon {
+
+namespace {
+
+template <bool interleaved>
+void
+RopeKernel_Fp16_Impl(
+ const _mlas_fp16_* input,
+ const _mlas_fp16_* sin,
+ const _mlas_fp16_* cos,
+ size_t dim,
+ _mlas_fp16_* output
+);
+
+template <>
+void
+RopeKernel_Fp16_Impl<false>(
+ const _mlas_fp16_* input,
+ const _mlas_fp16_* sin,
+ const _mlas_fp16_* cos,
+ size_t dim,
+ _mlas_fp16_* output
+) {
+ const size_t half_dim = dim >> 1;
+ size_t i = 0, j = half_dim;
+ for (; i + 7 < half_dim; i += 8, j += 8) {
+ float16x8_t real = MlasLoadFloat16x8(input + i);
+ float16x8_t imag = MlasLoadFloat16x8(input + j);
+ float16x8_t sin_val = MlasLoadFloat16x8(sin + i);
+ float16x8_t cos_val = MlasLoadFloat16x8(cos + i);
+ float16x8_t real_out = vfmsq_f16(vmulq_f16(real, cos_val), imag, sin_val);
+ float16x8_t imag_out = vfmaq_f16(vmulq_f16(real, sin_val), imag, cos_val);
+ MlasStoreFloat16x8(output + i, real_out);
+ MlasStoreFloat16x8(output + j, imag_out);
+ }
+ for (; i + 3 < half_dim; i += 4, j += 4) {
+ float16x4_t real = MlasLoadFloat16x4(input + i);
+ float16x4_t imag = MlasLoadFloat16x4(input + j);
+ float16x4_t sin_val = MlasLoadFloat16x4(sin + i);
+ float16x4_t cos_val = MlasLoadFloat16x4(cos + i);
+ float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
+ float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
+ MlasStoreFloat16x4(output + i, real_out);
+ MlasStoreFloat16x4(output + j, imag_out);
+ }
+ if (half_dim - i == 3) {
+ float16x4_t real = MlasZeroFloat16x4();
+ float16x4_t imag = MlasZeroFloat16x4();
+ float16x4_t sin_val = MlasZeroFloat16x4();
+ float16x4_t cos_val = MlasZeroFloat16x4();
+ real = MlasLoadLaneFloat16x4<0>(input + i, real);
+ real = MlasLoadLaneFloat16x4<1>(input + i + 1, real);
+ real = MlasLoadLaneFloat16x4<2>(input + i + 2, real);
+ imag = MlasLoadLaneFloat16x4<0>(input + j, imag);
+ imag = MlasLoadLaneFloat16x4<1>(input + j + 1, imag);
+ imag = MlasLoadLaneFloat16x4<2>(input + j + 2, imag);
+ sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
+ sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val);
+ sin_val = MlasLoadLaneFloat16x4<2>(sin + i + 2, sin_val);
+ cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
+ cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val);
+ cos_val = MlasLoadLaneFloat16x4<2>(cos + i + 2, cos_val);
+ float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
+ float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
+ MlasStoreLaneFloat16x4<0>(output + i, real_out);
+ MlasStoreLaneFloat16x4<1>(output + i + 1, real_out);
+ MlasStoreLaneFloat16x4<2>(output + i + 2, real_out);
+ MlasStoreLaneFloat16x4<0>(output + j, imag_out);
+ MlasStoreLaneFloat16x4<1>(output + j + 1, imag_out);
+ MlasStoreLaneFloat16x4<2>(output + j + 2, imag_out);
+ } else if (half_dim - i == 2) {
+ float16x4_t real = MlasZeroFloat16x4();
+ float16x4_t imag = MlasZeroFloat16x4();
+ float16x4_t sin_val = MlasZeroFloat16x4();
+ float16x4_t cos_val = MlasZeroFloat16x4();
+ real = MlasLoadLaneFloat16x4<0>(input + i, real);
+ real = MlasLoadLaneFloat16x4<1>(input + i + 1, real);
+ imag = MlasLoadLaneFloat16x4<0>(input + j, imag);
+ imag = MlasLoadLaneFloat16x4<1>(input + j + 1, imag);
+ sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
+ sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val);
+ cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
+ cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val);
+ float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
+ float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
+ MlasStoreLaneFloat16x4<0>(output + i, real_out);
+ MlasStoreLaneFloat16x4<1>(output + i + 1, real_out);
+ MlasStoreLaneFloat16x4<0>(output + j, imag_out);
+ MlasStoreLaneFloat16x4<1>(output + j + 1, imag_out);
+ } else if (half_dim - i == 1) {
+ float16x4_t real = MlasZeroFloat16x4();
+ float16x4_t imag = MlasZeroFloat16x4();
+ float16x4_t sin_val = MlasZeroFloat16x4();
+ float16x4_t cos_val = MlasZeroFloat16x4();
+ real = MlasLoadLaneFloat16x4<0>(input + i, real);
+ imag = MlasLoadLaneFloat16x4<0>(input + j, imag);
+ sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
+ cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
+ float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
+ float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
+ MlasStoreLaneFloat16x4<0>(output + i, real_out);
+ MlasStoreLaneFloat16x4<0>(output + j, imag_out);
+ }
+}
+
+template <>
+void
+RopeKernel_Fp16_Impl<true>(
+ const _mlas_fp16_* input,
+ const _mlas_fp16_* sin,
+ const _mlas_fp16_* cos,
+ size_t dim,
+ _mlas_fp16_* output
+) {
+ size_t i = 0;
+ for (; i + 15 < dim; i += 16) {
+ float16x8_t x0 = MlasLoadFloat16x8(input + i);
+ float16x8_t x1 = MlasLoadFloat16x8(input + i + 8);
+ float16x8_t real = vuzp1q_f16(x0, x1);
+ float16x8_t imag = vuzp2q_f16(x0, x1);
+ float16x8_t sin_val = MlasLoadFloat16x8(sin + i);
+ float16x8_t cos_val = MlasLoadFloat16x8(cos + i);
+ float16x8_t real_out = vfmsq_f16(vmulq_f16(real, cos_val), imag, sin_val);
+ float16x8_t imag_out = vfmaq_f16(vmulq_f16(real, sin_val), imag, cos_val);
+ float16x8_t y0 = vzip1q_f16(real_out, imag_out);
+ float16x8_t y1 = vzip2q_f16(real_out, imag_out);
+ MlasStoreFloat16x8(output + i, y0);
+ MlasStoreFloat16x8(output + i + 8, y1);
+ }
+ for (; i + 7 < dim; i += 8) {
+ float16x4_t x0 = MlasLoadFloat16x4(input + i);
+ float16x4_t x1 = MlasLoadFloat16x4(input + i + 4);
+ float16x4_t real = vuzp1_f16(x0, x1);
+ float16x4_t imag = vuzp2_f16(x0, x1);
+ float16x4_t sin_val = MlasLoadFloat16x4(sin + i);
+ float16x4_t cos_val = MlasLoadFloat16x4(cos + i);
+ float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
+ float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
+ float16x4_t y0 = vzip1_f16(real_out, imag_out);
+ float16x4_t y1 = vzip2_f16(real_out, imag_out);
+ MlasStoreFloat16x4(output + i, y0);
+ MlasStoreFloat16x4(output + i + 4, y1);
+ }
+ if (dim - i == 6) {
+ float16x4_t real = MlasZeroFloat16x4();
+ float16x4_t imag = MlasZeroFloat16x4();
+ float16x4_t sin_val = MlasZeroFloat16x4();
+ float16x4_t cos_val = MlasZeroFloat16x4();
+ real = MlasLoadLaneFloat16x4<0>(input + i, real);
+ imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag);
+ real = MlasLoadLaneFloat16x4<1>(input + i + 2, real);
+ imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag);
+ real = MlasLoadLaneFloat16x4<2>(input + i + 4, real);
+ imag = MlasLoadLaneFloat16x4<2>(input + i + 5, imag);
+ sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
+ sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val);
+ sin_val = MlasLoadLaneFloat16x4<2>(sin + i + 2, sin_val);
+ cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
+ cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val);
+ cos_val = MlasLoadLaneFloat16x4<2>(cos + i + 2, cos_val);
+ float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
+ float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
+ MlasStoreLaneFloat16x4<0>(output + i, real_out);
+ MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out);
+ MlasStoreLaneFloat16x4<1>(output + i + 2, real_out);
+ MlasStoreLaneFloat16x4<1>(output + i + 3, imag_out);
+ MlasStoreLaneFloat16x4<2>(output + i + 4, real_out);
+ MlasStoreLaneFloat16x4<2>(output + i + 5, imag_out);
+ } else if (dim - i == 4) {
+ float16x4_t real = MlasZeroFloat16x4();
+ float16x4_t imag = MlasZeroFloat16x4();
+ float16x4_t sin_val = MlasZeroFloat16x4();
+ float16x4_t cos_val = MlasZeroFloat16x4();
+ real = MlasLoadLaneFloat16x4<0>(input + i, real);
+ imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag);
+ real = MlasLoadLaneFloat16x4<1>(input + i + 2, real);
+ imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag);
+ sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
+ sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val);
+ cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
+ cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val);
+ float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
+ float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
+ MlasStoreLaneFloat16x4<0>(output + i, real_out);
+ MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out);
+ MlasStoreLaneFloat16x4<1>(output + i + 2, real_out);
+ MlasStoreLaneFloat16x4<1>(output + i + 3, imag_out);
+ } else if (dim - i == 2) {
+ float16x4_t real = MlasZeroFloat16x4();
+ float16x4_t imag = MlasZeroFloat16x4();
+ float16x4_t sin_val = MlasZeroFloat16x4();
+ float16x4_t cos_val = MlasZeroFloat16x4();
+ real = MlasLoadLaneFloat16x4<0>(input + i, real);
+ imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag);
+ sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
+ cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
+ float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
+ float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
+ MlasStoreLaneFloat16x4<0>(output + i, real_out);
+ MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out);
+ }
+}
+
+} // namespace
+
+void
+RopeKernel_Fp16(
+ const MLAS_FP16* input,
+ const MLAS_FP16* sin,
+ const MLAS_FP16* cos,
+ size_t dim,
+ bool interleaved,
+ MLAS_FP16* output
+) {
+ // real part and imaginary part must be paired
+ assert(dim % 2 == 0);
+
+ const auto* input_impl = reinterpret_cast<const _mlas_fp16_*>(input);
+ const auto* sin_impl = reinterpret_cast<const _mlas_fp16_*>(sin);
+ const auto* cos_impl = reinterpret_cast<const _mlas_fp16_*>(cos);
+ auto* output_impl = reinterpret_cast<_mlas_fp16_*>(output);
+
+ if (interleaved) {
+ RopeKernel_Fp16_Impl<true>(input_impl, sin_impl, cos_impl, dim, output_impl);
+ } else {
+ RopeKernel_Fp16_Impl<false>(input_impl, sin_impl, cos_impl, dim, output_impl);
+ }
+}
+
+} // namespace rope_neon
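In the interleaved specialization above, vuzp1q_f16/vuzp2q_f16 split each loaded block into its even-indexed (real) and odd-indexed (imaginary) lanes before the rotation, and vzip1q_f16/vzip2q_f16 weave the rotated halves back together. The snippet below is a scalar sketch of just that lane shuffle, not the NEON code itself; the values are arbitrary.

    #include <array>
    #include <cstddef>

    int main() {
        // One interleaved block: r0, i0, r1, i1, ...
        std::array<float, 8> x = {1, 10, 2, 20, 3, 30, 4, 40};
        std::array<float, 4> real{}, imag{};
        for (size_t k = 0; k < 4; ++k) {   // what vuzp1/vuzp2 do: de-interleave even/odd lanes
            real[k] = x[2 * k];
            imag[k] = x[2 * k + 1];
        }
        // ... the kernel rotates real/imag here with vfms/vfma ...
        std::array<float, 8> y{};
        for (size_t k = 0; k < 4; ++k) {   // what vzip1/vzip2 do: re-interleave the lanes
            y[2 * k] = real[k];
            y[2 * k + 1] = imag[k];
        }
        return static_cast<int>(y[0]);     // keeps the shuffle observable to the compiler
    }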
diff --git a/onnxruntime/core/optimizer/constant_folding.cc b/onnxruntime/core/optimizer/constant_folding.cc
index 1466de51d0b99..e755b4bfa6364 100644
--- a/onnxruntime/core/optimizer/constant_folding.cc
+++ b/onnxruntime/core/optimizer/constant_folding.cc
@@ -227,11 +227,12 @@ Status ConstantFolding::ApplyImpl(Graph& graph, bool& modified, int graph_level,
#if !defined(DISABLE_SPARSE_TENSORS)
// Create execution frame for executing constant nodes.
OptimizerExecutionFrame::Info info({node}, constant_inputs, graph.ModelPath(), execution_provider_,
- is_sparse_initializer_check);
+ is_sparse_initializer_check, logger);
#else
// Create execution frame for executing constant nodes.
- OptimizerExecutionFrame::Info info({node}, constant_inputs, graph.ModelPath(), execution_provider_,
- [](std::string const&) { return false; });
+ OptimizerExecutionFrame::Info info(
+ {node}, constant_inputs, graph.ModelPath(), execution_provider_, [](const std::string&) { return false; },
+ logger);
#endif
std::vector<int> fetch_mlvalue_idxs;
diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index 2f2524420dc44..ba2b87b5aa0ca 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
TransformerLevel level,
const SessionOptions& session_options,
const IExecutionProvider& cpu_execution_provider, /*required by constant folding*/
+ const logging::Logger& logger,
const InlinedHashSet& rules_and_transformers_to_disable,
[[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool,
std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors) {
@@ -404,7 +405,8 @@ InlinedVector> GenerateTransformers(
}
auto cpu_registry = cpu_execution_provider.GetKernelRegistry();
- auto nhwc_transformer = std::make_unique<NhwcTransformer>(std::move(cpu_allocator), std::move(cpu_registry));
+ auto nhwc_transformer = std::make_unique<NhwcTransformer>(std::move(cpu_allocator), std::move(cpu_registry),
+ logger);
if (nhwc_transformer->IsActive()) {
transformers.emplace_back(std::move(nhwc_transformer));
}
@@ -437,6 +439,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
const SessionOptions& session_options,
const SatApplyContextVariant& apply_context,
const IExecutionProvider& cpu_execution_provider,
+ const logging::Logger& logger,
const InlinedHashSet& rules_and_transformers_to_disable,
[[maybe_unused]] concurrency::ThreadPool* intra_op_thread_pool,
std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors) {
@@ -490,7 +493,8 @@ InlinedVector> GenerateTransformersForMinimalB
#ifndef DISABLE_CONTRIB_OPS
AllocatorPtr cpu_allocator = std::make_shared<CPUAllocator>();
auto cpu_registry = cpu_execution_provider.GetKernelRegistry();
- auto nhwc_transformer = std::make_unique<NhwcTransformer>(std::move(cpu_allocator), std::move(cpu_registry));
+ auto nhwc_transformer = std::make_unique<NhwcTransformer>(std::move(cpu_allocator), std::move(cpu_registry),
+ logger);
if (nhwc_transformer->IsActive()) {
transformers.emplace_back(std::move(nhwc_transformer));
}
diff --git a/onnxruntime/core/optimizer/insert_cast_transformer.cc b/onnxruntime/core/optimizer/insert_cast_transformer.cc
index 67ebc22dab41d..b1665c7172549 100644
--- a/onnxruntime/core/optimizer/insert_cast_transformer.cc
+++ b/onnxruntime/core/optimizer/insert_cast_transformer.cc
@@ -84,7 +84,9 @@ static bool NodeNeedsInputCastToFp32(const onnxruntime::Node& node) {
// going to a node that will need a Cast.
//
// Return true if all the fp16 inputs and outputs are connected to nodes that will be cast to fp32.
-static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry) {
+static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::Graph& graph,
+ const KernelRegistry& cpu_kernel_registry,
+ const logging::Logger& logger) {
// we can check if it's an isolated fp16 node
// if node has input coming from other nodes (only consuming graph inputs or initializers if it doesn't),
// does not have a subgraph (would have to alter subgraph inputs if we cast the input to this node),
@@ -211,7 +213,7 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::
const KernelCreateInfo* kernel_create_info{};
const auto lookup_status = cpu_kernel_registry.TryFindKernel(
kCpuExecutionProvider, node.OpType(), node.Domain(),
- node.SinceVersion(), type_constraint_map, &kernel_create_info);
+ node.SinceVersion(), type_constraint_map, logger, &kernel_create_info);
if (lookup_status.IsOK() && kernel_create_info != nullptr) {
return true;
}
@@ -220,9 +222,10 @@ static bool IsIsolatedFp16NodeOnCpu(const onnxruntime::Node& node, onnxruntime::
return false;
}
-static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry) {
+static Status ForceSingleNodeCPUFloat16ToFloat32(onnxruntime::Graph& graph, const KernelRegistry& cpu_kernel_registry,
+ const logging::Logger& logger) {
for (auto& node : graph.Nodes()) {
- if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry)) {
+ if (IsIsolatedFp16NodeOnCpu(node, graph, cpu_kernel_registry, logger)) {
// unassign the node so that NeedInsertCast will return true for it, forcing it to fp32
node.SetExecutionProviderType("");
}
@@ -319,7 +322,8 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
return dst_bit_length <= src_bit_length;
}
- if ((*src_type == "tensor(float16)" && *dst_type == "tensor(bfloat16)") || (*src_type == "tensor(bfloat16)" && *dst_type == "tensor(float16)")) {
+ if ((*src_type == "tensor(float16)" && *dst_type == "tensor(bfloat16)") ||
+ (*src_type == "tensor(bfloat16)" && *dst_type == "tensor(float16)")) {
return true;
}
@@ -453,7 +457,7 @@ class RemoveDuplicateCastTransformer : public GraphTransformer {
Status InsertCastTransformer::ApplyImpl(onnxruntime::Graph& graph, bool& modified, int graph_level,
const logging::Logger& logger) const {
if (force_cpu_fp32_)
- ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph, *cpu_kernel_registries_));
+ ORT_RETURN_IF_ERROR(ForceSingleNodeCPUFloat16ToFloat32(graph, *cpu_kernel_registries_, logger));
GraphViewer graph_viewer(graph);
auto& order = graph_viewer.GetNodesInTopologicalOrder();
diff --git a/onnxruntime/core/optimizer/nhwc_transformer.cc b/onnxruntime/core/optimizer/nhwc_transformer.cc
index ee79fa620374e..cd654991c92d5 100644
--- a/onnxruntime/core/optimizer/nhwc_transformer.cc
+++ b/onnxruntime/core/optimizer/nhwc_transformer.cc
@@ -44,7 +44,9 @@ NhwcConvLookup(
return &(iter->second);
}
-NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<KernelRegistry> cpu_kernel_registry) noexcept
+NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator,
+ std::shared_ptr<KernelRegistry> cpu_kernel_registry,
+ const logging::Logger& logger) noexcept
: GraphTransformer("NhwcTransformer"), cpu_allocator_(std::move(cpu_allocator)) {
if (!cpu_kernel_registry) {
// This is a CPU op nodes optimizer, not useful if cpu EP is not available.
@@ -64,7 +66,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel(
kCpuExecutionProvider, qconv_int8.op_type_, qconv_int8.domain_,
- qconv_int8.version_, qconv_int8.type_constraints_, &kernel_create_info);
+ qconv_int8.version_, qconv_int8.type_constraints_, logger, &kernel_create_info);
if (status.IsOK() && kernel_create_info != nullptr) {
kernel_create_info = nullptr;
conv_table_.emplace(
@@ -83,7 +85,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel(
kCpuExecutionProvider, qconv_uint8.op_type_, qconv_uint8.domain_,
- qconv_uint8.version_, qconv_uint8.type_constraints_, &kernel_create_info);
+ qconv_uint8.version_, qconv_uint8.type_constraints_, logger, &kernel_create_info);
if (status.IsOK() && kernel_create_info != nullptr) {
kernel_create_info = nullptr;
conv_table_.emplace(
@@ -103,7 +105,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel(
kCpuExecutionProvider, nhwc_conv_fp16.op_type_, nhwc_conv_fp16.domain_,
- nhwc_conv_fp16.version_, nhwc_conv_fp16.type_constraints_, &kernel_create_info);
+ nhwc_conv_fp16.version_, nhwc_conv_fp16.type_constraints_, logger, &kernel_create_info);
if (status.IsOK() && kernel_create_info != nullptr) {
kernel_create_info = nullptr;
conv_table_.emplace(
@@ -123,7 +125,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel(
kCpuExecutionProvider, nhwc_maxpool_fp16.op_type_, nhwc_maxpool_fp16.domain_,
- nhwc_maxpool_fp16.version_, nhwc_maxpool_fp16.type_constraints_, &kernel_create_info);
+ nhwc_maxpool_fp16.version_, nhwc_maxpool_fp16.type_constraints_, logger, &kernel_create_info);
if (status.IsOK() && kernel_create_info != nullptr) {
kernel_create_info = nullptr;
conv_table_.emplace(
@@ -140,7 +142,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel(
kCpuExecutionProvider, nhwc_avgpool_fp16.op_type_, nhwc_avgpool_fp16.domain_,
- nhwc_avgpool_fp16.version_, nhwc_avgpool_fp16.type_constraints_, &kernel_create_info);
+ nhwc_avgpool_fp16.version_, nhwc_avgpool_fp16.type_constraints_, logger, &kernel_create_info);
if (status.IsOK() && kernel_create_info != nullptr) {
kernel_create_info = nullptr;
conv_table_.emplace(
@@ -157,7 +159,7 @@ NhwcTransformer::NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptrTryFindKernel(
kCpuExecutionProvider, nhwc_gavgpool_fp16.op_type_, nhwc_gavgpool_fp16.domain_,
- nhwc_gavgpool_fp16.version_, nhwc_gavgpool_fp16.type_constraints_, &kernel_create_info);
+ nhwc_gavgpool_fp16.version_, nhwc_gavgpool_fp16.type_constraints_, logger, &kernel_create_info);
if (status.IsOK() && kernel_create_info != nullptr) {
kernel_create_info = nullptr;
conv_table_.emplace(
diff --git a/onnxruntime/core/optimizer/nhwc_transformer.h b/onnxruntime/core/optimizer/nhwc_transformer.h
index 000732060b889..c65f851fdab9d 100644
--- a/onnxruntime/core/optimizer/nhwc_transformer.h
+++ b/onnxruntime/core/optimizer/nhwc_transformer.h
@@ -75,7 +75,8 @@ and inserts nodes to transpose tensors as needed.
class NhwcTransformer : public GraphTransformer {
private:
public:
- explicit NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<KernelRegistry> cpu_kernel_registry) noexcept;
+ explicit NhwcTransformer(AllocatorPtr cpu_allocator, std::shared_ptr<KernelRegistry> cpu_kernel_registry,
+ const logging::Logger& logger) noexcept;
/**
* @brief Usually called right after constructor, it shows whether
diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.cc b/onnxruntime/core/optimizer/optimizer_execution_frame.cc
index ed7d5feb2beb3..b2e8e491c361c 100644
--- a/onnxruntime/core/optimizer/optimizer_execution_frame.cc
+++ b/onnxruntime/core/optimizer/optimizer_execution_frame.cc
@@ -32,9 +32,11 @@ OptimizerExecutionFrame::Info::Info(const std::vector<const Node*>& nodes,
const InitializedTensorSet& initialized_tensor_set,
const std::filesystem::path& model_path,
const IExecutionProvider& execution_provider,
- const std::function<bool(const std::string&)>& is_sparse_initializer_func)
+ const std::function<bool(const std::string&)>& is_sparse_initializer_func,
+ const logging::Logger& logger)
: execution_provider_(execution_provider),
- is_sparse_initializer_func_(is_sparse_initializer_func) {
+ is_sparse_initializer_func_(is_sparse_initializer_func),
+ logger_(logger) {
allocator_ptr_ = std::make_shared<CPUAllocator>();
ORT_ENFORCE(allocator_ptr_, "Failed to get allocator for optimizer");
@@ -79,9 +81,11 @@ OptimizerExecutionFrame::Info::Info(const std::vector<const Node*>& nodes,
const std::unordered_map<std::string, OrtValue>& initialized_tensor_set,
const std::filesystem::path& /* model_path */,
const IExecutionProvider& execution_provider,
- const std::function<bool(const std::string&)>& is_sparse_initializer_func)
+ const std::function<bool(const std::string&)>& is_sparse_initializer_func,
+ const logging::Logger& logger)
: execution_provider_(execution_provider),
- is_sparse_initializer_func_(is_sparse_initializer_func) {
+ is_sparse_initializer_func_(is_sparse_initializer_func),
+ logger_(logger) {
allocator_ptr_ = std::make_shared<CPUAllocator>();
ORT_ENFORCE(allocator_ptr_, "Failed to get allocator for optimizer");
@@ -117,7 +121,7 @@ OptimizerExecutionFrame::Info::Info(const std::vector<const Node*>& nodes,
Status OptimizerExecutionFrame::Info::TryFindKernel(const Node* node, const KernelCreateInfo** out) const {
std::shared_ptr<KernelRegistry> kernel_registry = execution_provider_.GetKernelRegistry();
const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{};
- return kernel_registry->TryFindKernel(*node, execution_provider_.Type(), kernel_type_str_resolver, out);
+ return kernel_registry->TryFindKernel(*node, execution_provider_.Type(), kernel_type_str_resolver, logger_, out);
}
static Status TryCreateKernel(const Node& node,
@@ -128,10 +132,11 @@ static Status TryCreateKernel(const Node& node,
FuncManager& funcs_mgr,
const DataTransferManager& data_transfer_mgr,
const ConfigOptions& config_options,
+ const logging::Logger& logger,
/*out*/ std::unique_ptr<OpKernel>& op_kernel) {
const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{};
const KernelCreateInfo* kernel_create_info = nullptr;
- ORT_RETURN_IF_ERROR(kernel_registry.TryFindKernel(node, execution_provider.Type(), kernel_type_str_resolver,
+ ORT_RETURN_IF_ERROR(kernel_registry.TryFindKernel(node, execution_provider.Type(), kernel_type_str_resolver, logger,
&kernel_create_info));
static const AllocatorMap dummy_allocators;
@@ -154,7 +159,7 @@ OptimizerExecutionFrame::Info::CreateKernel(const Node* node, const ConfigOption
std::shared_ptr<KernelRegistry> kernel_registry = execution_provider_.GetKernelRegistry();
FuncManager func;
auto status = TryCreateKernel(*node, *kernel_registry, execution_provider_, initializers_,
- ort_value_name_idx_map_, func, data_transfer_mgr_, config_options,
+ ort_value_name_idx_map_, func, data_transfer_mgr_, config_options, logger_,
op_kernel);
// Kernel found in the CPU kernel registry
diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.h b/onnxruntime/core/optimizer/optimizer_execution_frame.h
index b0f7f461661b5..24a23312feba9 100644
--- a/onnxruntime/core/optimizer/optimizer_execution_frame.h
+++ b/onnxruntime/core/optimizer/optimizer_execution_frame.h
@@ -27,13 +27,15 @@ class OptimizerExecutionFrame final : public IExecutionFrame {
const InitializedTensorSet& initialized_tensor_set,
const std::filesystem::path& model_path,
const IExecutionProvider& execution_provider,
- const std::function<bool(const std::string&)>& is_sparse_initializer_func);
+ const std::function<bool(const std::string&)>& is_sparse_initializer_func,
+ const logging::Logger& logger);
Info(const std::vector<const Node*>& nodes,
const std::unordered_map<std::string, OrtValue>& initialized_tensor_set,
const std::filesystem::path& model_path,
const IExecutionProvider& execution_provider,
- const std::function<bool(const std::string&)>& is_sparse_initializer_func);
+ const std::function<bool(const std::string&)>& is_sparse_initializer_func,
+ const logging::Logger& logger);
~Info() = default;
@@ -76,6 +78,7 @@ class OptimizerExecutionFrame final : public IExecutionFrame {
std::unique_ptr<NodeIndexInfo> node_index_info_;
const IExecutionProvider& execution_provider_;
const std::function<bool(const std::string&)>& is_sparse_initializer_func_;
+ const logging::Logger& logger_;
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Info);
};
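The recurring change in these optimizer files is to stop relying on the process-wide LOGS_DEFAULT macros and instead accept a const logging::Logger&, keep it as a reference member (logger_), and hand it on to TryFindKernel/CreateKernel. Below is a minimal, self-contained C++ sketch of that pattern, using a stand-in Logger type and a hypothetical Info/TryFindKernel pair rather than the actual ORT classes:

    #include <iostream>
    #include <string>
    #include <utility>

    struct Logger {  // stand-in for onnxruntime::logging::Logger
      void Log(const std::string& severity, const std::string& msg) const {
        std::cout << "[" << severity << "] " << msg << "\n";
      }
    };

    class Info {  // hypothetical holder, mirroring how Info keeps logger_
     public:
      Info(std::string op_type, const Logger& logger)
          : op_type_(std::move(op_type)), logger_(logger) {}

      void TryFindKernel() const {
        // stands in for: LOGS(logger_, VERBOSE) << ...  (instead of LOGS_DEFAULT(VERBOSE) << ...)
        logger_.Log("VERBOSE", "looking up kernel for " + op_type_);
      }

     private:
      std::string op_type_;
      const Logger& logger_;  // reference member: Info must not outlive the session logger
    };

    int main() {
      Logger session_logger;
      Info info("QLinearConv", session_logger);
      info.TryFindKernel();
      return 0;
    }

Storing a reference member ties the object to the caller-owned session logger, which is why the constructors here take the logger once instead of each call site fetching a default.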
diff --git a/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc b/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc
index 18e462c04dff3..5538aa54801cc 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.cc
@@ -36,7 +36,7 @@ static inline bool MatchesOpSinceVersion(
return std::find(versions.begin(), versions.end(), node.SinceVersion()) != versions.end();
}
-static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) {
+static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph, const logging::Logger& logger) {
constexpr size_t w_idx = 1;
constexpr size_t w_zp_idx = 9;
constexpr size_t r_idx = 2;
@@ -60,7 +60,7 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) {
if (!graph_utils::NodeArgIsConstant(graph, *input_defs[r_idx]) ||
!graph.GetInitializedTensor(input_defs[r_idx]->Name(), r_tensor_proto) ||
r_tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8) {
- LOGS_DEFAULT(WARNING) << "Unable transforming DynamicQuantizeLSTM operator,"
+ LOGS(logger, WARNING) << "Unable transforming DynamicQuantizeLSTM operator,"
<< " cannot locate recurrence tensor of const int8 type,"
<< " int8 overflow might impact precision !";
return false;
@@ -86,7 +86,7 @@ static bool TryConvertDynamicQuantizeLSTM(Node& op_node, Graph& graph) {
if (!graph_utils::NodeArgIsConstant(graph, *input_defs[r_zp_idx]) ||
!graph.GetInitializedTensor(input_defs[r_zp_idx]->Name(), r_zp_tensor_proto) ||
r_zp_tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8) {
- LOGS_DEFAULT(WARNING) << "Unable transforming DynamicQuantizeLSTM operator,"
+ LOGS(logger, WARNING) << "Unable transforming DynamicQuantizeLSTM operator,"
<< " unable to locate recurrence tensor or its zero point value,"
<< " int8 overflow might impact precision !";
return false;
@@ -171,7 +171,7 @@ Status Avx2WeightS8ToU8Transformer::ApplyImpl(Graph& graph, bool& modified, int
if (graph_utils::IsSupportedOptypeVersionAndDomain(
op_node, "DynamicQuantizeLSTM", {1}, kMSDomain)) {
// This one has two set of quantized arguments
- modified |= TryConvertDynamicQuantizeLSTM(op_node, graph);
+ modified |= TryConvertDynamicQuantizeLSTM(op_node, graph, logger);
continue; // go on to next operator node
}
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc
index d2240b5d50194..81305f7effa16 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc
@@ -291,7 +291,8 @@ SelectorManager::SelectorManager() {
InitializeSelectorsMap();
}
-std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& graph_viewer) const {
+std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& graph_viewer,
+ const logging::Logger& logger) const {
std::vector<NodeGroup> qdq_selections;
for (auto index : graph_viewer.GetNodesInTopologicalOrder()) {
const auto* node = graph_viewer.GetNode(index);
@@ -313,7 +314,7 @@ std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& grap
const auto& versions = op_versions_and_selector.op_versions_map.find(node->OpType())->second;
if (!versions.empty()) {
if (std::find(versions.cbegin(), versions.cend(), node->SinceVersion()) == versions.cend()) {
- LOGS_DEFAULT(VERBOSE) << "Op version is not supported for" << node->OpType();
+ LOGS(logger, VERBOSE) << "Op version is not supported for" << node->OpType();
continue;
}
}
@@ -329,7 +330,7 @@ std::vector<NodeGroup> SelectorManager::GetQDQSelections(const GraphViewer& grap
}
std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
-GetAllNodeUnits(const GraphViewer& graph_viewer) {
+GetAllNodeUnits(const GraphViewer& graph_viewer, const logging::Logger& logger) {
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
@@ -342,7 +343,7 @@ GetAllNodeUnits(const GraphViewer& graph_viewer) {
// Get QDQ NodeUnits first
QDQ::SelectorManager selector_mgr;
- const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer);
+ const auto qdq_selections = selector_mgr.GetQDQSelections(graph_viewer, logger);
for (const auto& qdq_selection : qdq_selections) {
auto qdq_unit = std::make_unique<NodeUnit>(graph_viewer, qdq_selection);
diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h
index f388206551172..ccc1844e3e985 100644
--- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h
+++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.h
@@ -15,7 +15,9 @@
#endif
namespace onnxruntime {
-
+namespace logging {
+class Logger;
+}
class GraphViewer;
class Node;
@@ -65,7 +67,7 @@ class SelectorManager {
// Methods that finds and returns a vector of QDQ::NodeGroup in a given graph
// Can be used in QDQ support in different EPs
- std::vector<NodeGroup> GetQDQSelections(const GraphViewer& graph_viewer) const;
+ std::vector<NodeGroup> GetQDQSelections(const GraphViewer& graph_viewer, const logging::Logger& logger) const;
private:
Selectors qdq_selectors_;
@@ -88,7 +90,7 @@ class SelectorManager {
// We currently have a bit of a mess with generic things like this to get all the node units being in the optimizer
// library whereas it should be able to be used by an EP with no dependency on optimizers.
std::pair<std::vector<std::unique_ptr<NodeUnit>>, std::unordered_map<const Node*, const NodeUnit*>>
-GetAllNodeUnits(const GraphViewer& graph_viewer);
+GetAllNodeUnits(const GraphViewer& graph_viewer, const logging::Logger& logger);
} // namespace QDQ
} // namespace onnxruntime
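For reference, the per-node check inside GetQDQSelections that emits the VERBOSE message above is a plain membership test of Node::SinceVersion() against the selector's supported opset versions. A self-contained sketch of that check, with made-up op names and version lists:

    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <vector>

    int main() {
      // op type -> opset versions a selector supports (hypothetical values)
      const std::unordered_map<std::string, std::vector<int>> op_versions_map{
          {"Conv", {1, 11}},
          {"MatMul", {1, 9, 13}},
      };

      struct NodeStub { std::string op_type; int since_version; };
      const std::vector<NodeStub> nodes{{"Conv", 11}, {"MatMul", 10}};

      for (const auto& node : nodes) {
        const auto it = op_versions_map.find(node.op_type);
        if (it == op_versions_map.end()) continue;  // no selector registered for this op
        const auto& versions = it->second;
        if (!versions.empty() &&
            std::find(versions.cbegin(), versions.cend(), node.since_version) == versions.cend()) {
          // mirrors: LOGS(logger, VERBOSE) << "Op version is not supported for " << node->OpType();
          std::cout << "Op version is not supported for " << node.op_type << "\n";
          continue;
        }
        std::cout << node.op_type << " v" << node.since_version << " accepted for QDQ selection\n";
      }
      return 0;
    }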
diff --git a/onnxruntime/core/optimizer/transformer_memcpy.cc b/onnxruntime/core/optimizer/transformer_memcpy.cc
index 6a5a85ce0ff31..8c0136c495403 100644
--- a/onnxruntime/core/optimizer/transformer_memcpy.cc
+++ b/onnxruntime/core/optimizer/transformer_memcpy.cc
@@ -17,13 +17,22 @@ class TransformerMemcpyImpl {
TransformerMemcpyImpl(onnxruntime::Graph& graph, const std::string& provider)
: graph_(graph), provider_(provider) {}
- bool ModifyGraph(const KernelRegistryManager& schema_registries, const logging::Logger& logger, int& copy_node_counter);
+ bool ModifyGraph(const KernelRegistryManager& schema_registries,
+ const logging::Logger& logger,
+ int& copy_node_counter);
private:
- void ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, InitializedTensorSet& initializers_consumed);
- void BuildDefsMapping(const onnxruntime::NodeArg* arg, const KernelRegistryManager& kernel_registries);
+ void ProcessDefs(onnxruntime::Node& node,
+ const KernelRegistryManager& kernel_registries,
+ InitializedTensorSet& initializers_consumed,
+ const logging::Logger& logger);
+ void BuildDefsMapping(const onnxruntime::NodeArg* arg,
+ const KernelRegistryManager& kernel_registries,
+ const logging::Logger& logger);
void AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger);
- bool ProcessInitializers(const KernelRegistryManager& kernel_registries, const InitializedTensorSet& initializers_consumed);
+ bool ProcessInitializers(const KernelRegistryManager& kernel_registries,
+ const InitializedTensorSet& initializers_consumed,
+ const logging::Logger& logger);
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TransformerMemcpyImpl);
@@ -130,21 +139,21 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi
// find defs that require copy
for (auto& node : graph_.Nodes()) {
// as we process the defs, collect all the initializers consumed at the current graph level
- ProcessDefs(node, kernel_registries, initializers_consumed);
+ ProcessDefs(node, kernel_registries, initializers_consumed, logger);
}
// for initializers shared by different providers, create dups
- if (ProcessInitializers(kernel_registries, initializers_consumed))
+ if (ProcessInitializers(kernel_registries, initializers_consumed, logger))
modified = true;
for (auto arg : graph_.GetInputs())
- BuildDefsMapping(arg, kernel_registries);
+ BuildDefsMapping(arg, kernel_registries, logger);
for (auto arg : non_provider_input_defs_)
- BuildDefsMapping(arg, kernel_registries);
+ BuildDefsMapping(arg, kernel_registries, logger);
for (auto arg : non_provider_output_defs_)
- BuildDefsMapping(arg, kernel_registries);
+ BuildDefsMapping(arg, kernel_registries, logger);
for (auto arg : graph_.GetInputs())
// For inputs we need to create a copy node only when the input is connected to both provider
@@ -202,8 +211,10 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi
return modified;
}
-void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries,
- InitializedTensorSet& initializers_consumed) {
+void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node,
+ const KernelRegistryManager& kernel_registries,
+ InitializedTensorSet& initializers_consumed,
+ const logging::Logger& logger) {
auto node_provider_type = node.GetExecutionProviderType();
if ((node_provider_type == provider_) ||
(node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) ||
@@ -211,7 +222,7 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg
provider_nodes_.insert(&node);
// note KernelCreateInfo might be nullptr for custom kernel
const KernelCreateInfo* kci = nullptr;
- ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(node, &kci));
+ ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(node, logger, &kci));
bool is_implicit_input = false;
auto process_inputs =
@@ -278,7 +289,9 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg
}
// for non_provider defs, collect the nodes that expect it is provider tensor as input/output.
-void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, const KernelRegistryManager& kernel_registries) {
+void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg,
+ const KernelRegistryManager& kernel_registries,
+ const logging::Logger& logger) {
for (auto& it : graph_.Nodes()) {
if (it.OpType() == "MemcpyFromHost" || it.OpType() == "MemcpyToHost") continue;
auto input_it =
@@ -296,7 +309,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, co
(node_provider_type == kCudaExecutionProvider && kTensorrtExecutionProvider == provider_) ||
(node_provider_type == kRocmExecutionProvider && kMIGraphXExecutionProvider == provider_)) {
const KernelCreateInfo* kci = nullptr;
- ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(it, &kci));
+ ORT_IGNORE_RETURN_VALUE(kernel_registries.SearchKernelRegistry(it, logger, &kci));
if (arg_input_index != -1) {
if (!kci || !utils::IsInputOnCpu(it, kci, arg_input_index)) provider_input_nodes_[arg].insert(&it);
}
@@ -351,7 +364,9 @@ static const onnxruntime::NodeArg* FindNodeArg(const NodeArgSetType& def_set, co
// We duplicate any initializer that is used by both provider nodes and non-provider nodes
// to ensure that provider nodes and non-provider nodes don't share initializers, as they
// need to stay in different memory locations.
-bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& kernel_registries, const InitializedTensorSet& initializers_consumed) {
+bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& kernel_registries,
+ const InitializedTensorSet& initializers_consumed,
+ const logging::Logger& logger) {
std::map<const onnxruntime::NodeArg*, onnxruntime::NodeArg*> replacements;
for (const auto& pair : initializers_consumed) {
const auto& name = pair.first;
@@ -383,7 +398,7 @@ bool TransformerMemcpyImpl::ProcessInitializers(const KernelRegistryManager& ker
auto dup_replacements = replacements;
const KernelCreateInfo* kci = nullptr;
- auto status = kernel_registries.SearchKernelRegistry(*p_node, &kci);
+ auto status = kernel_registries.SearchKernelRegistry(*p_node, logger, &kci);
ORT_ENFORCE(status.IsOK(), status.ErrorMessage());
if (kci == nullptr) continue;
if (kci->kernel_def == nullptr) continue;
diff --git a/onnxruntime/core/platform/windows/hardware_core_enumerator.cc b/onnxruntime/core/platform/windows/hardware_core_enumerator.cc
index bf3b53afbd7d3..7464ab4c57d01 100644
--- a/onnxruntime/core/platform/windows/hardware_core_enumerator.cc
+++ b/onnxruntime/core/platform/windows/hardware_core_enumerator.cc
@@ -1,7 +1,8 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "hardware_core_enumerator.h"
+#include "core/platform/windows/env.h"
#include
#include
#include
@@ -83,6 +84,38 @@ uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() {
// # of physical cores = # of P cores + # of E Cores + # of Soc Cores.
// # of logical cores = # of P cores x 2 (if hyper threading is enabled) + # of E cores + # of Soc Cores.
auto cores = GetCoreInfo();
+#if !defined(_M_ARM64EC) && !defined(_M_ARM64) && !defined(__aarch64__)
+ const int kVendorID_Intel[3] = {0x756e6547, 0x6c65746e, 0x49656e69}; // "GenuntelineI"
+ bool isIntelSpecifiedPlatform = false;
+ const int kVendorID_IntelSpecifiedPlatformIDs[3] = {
+ // ExtendedModel, ExtendedFamily, Family Code, and Model Number
+ 0xa06a, // MTL
+ 0xc065, // ARL-H
+ 0xb065 // ARL-U
+ };
+
+ int regs_leaf0[4];
+ int regs_leaf1[4];
+ __cpuid(regs_leaf0, 0);
+ __cpuid(regs_leaf1, 0x1);
+
+ auto isIntel = (kVendorID_Intel[0] == regs_leaf0[1]) && (kVendorID_Intel[1] == regs_leaf0[2]) && (kVendorID_Intel[2] == regs_leaf0[3]);
+
+ for (int intelSpecifiedPlatform : kVendorID_IntelSpecifiedPlatformIDs) {
+ if ((regs_leaf1[0] >> 4) == intelSpecifiedPlatform) {
+ isIntelSpecifiedPlatform = true;
+ }
+ }
+
+ if (isIntel) {
+ if (isIntelSpecifiedPlatform) {
+ // We want to exclude cores without an LLC
+ return cores.LLCCores;
+ } else {
+ return cores.PhysicalCores;
+ }
+ }
+#endif
return cores.LLCCores;
}
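The added block above reads CPUID leaf 0 to detect an Intel CPU and leaf 1 to derive a platform id (the extended family/model and family/model fields, i.e. EAX with the 4 stepping bits shifted away) that is compared against the MTL/ARL ids. A standalone sketch of just that id computation, using a hypothetical raw EAX value in place of a real __cpuid call:

    #include <cstdint>
    #include <iostream>

    int main() {
      // ids from the block above: MTL, ARL-H, ARL-U
      const uint32_t kIntelSpecifiedPlatformIDs[3] = {0xa06a, 0xc065, 0xb065};

      const uint32_t leaf1_eax = 0xa06a4;           // hypothetical raw CPUID.1:EAX (id 0xa06a, stepping 4)
      const uint32_t platform_id = leaf1_eax >> 4;  // drop the 4 stepping bits

      bool is_specified_platform = false;
      for (uint32_t id : kIntelSpecifiedPlatformIDs) {
        if (platform_id == id) is_specified_platform = true;
      }

      // On the listed parts the heuristic keeps only the cores that share the LLC;
      // on other Intel parts it falls back to the physical core count.
      std::cout << (is_specified_platform ? "use LLC core count" : "use physical core count") << "\n";
      return 0;
    }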
diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc
index a799ed743ef52..f954baf3eabae 100644
--- a/onnxruntime/core/providers/cann/cann_execution_provider.cc
+++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc
@@ -1288,15 +1288,15 @@ CANNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewe
const KernelCreateInfo* cann_kernel_def = kernel_lookup.LookUpKernel(node);
if (cann_kernel_def == nullptr) {
- LOGS_DEFAULT(INFO) << "CANN kernel not found in registries for Op type: " << node.OpType()
- << " node name: " << node.Name();
+ LOGS(*GetLogger(), INFO) << "CANN kernel not found in registries for Op type: " << node.OpType()
+ << " node name: " << node.Name();
continue;
}
candidates.push_back(node.Index());
}
- auto cpu_nodes = GetCpuPreferredNodes(graph_viewer, kernel_lookup, candidates);
+ auto cpu_nodes = GetCpuPreferredNodes(graph_viewer, kernel_lookup, candidates, *GetLogger());
for (auto& node_index : candidates) {
if (cpu_nodes.count(node_index) > 0)
continue;
diff --git a/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc
index cc68fa6ec399a..442194cb31cbc 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc
@@ -151,7 +151,7 @@ bool BatchNormalizationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBu
return false;
}
-#if defined(TARGET_OS_IOS) && defined(TARGET_CPU_X86_64)
+#if defined(TARGET_OS_IOS) && defined(TARGET_CPU_X86_64) && TARGET_OS_IOS && TARGET_CPU_X86_64
// To Pass IOS pipeline https://dev.azure.com/onnxruntime/onnxruntime/_build?definitionId=134&_a=summary
auto input_dtype = input_defs[0]->TypeAsProto()->tensor_type().elem_type();
if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 && input_params.coreml_version < 7) {
diff --git a/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc
index f161b309a2425..d533b867bd454 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc
@@ -133,9 +133,8 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInpu
return false;
}
-#if defined(TARGET_OS_IOS) && defined(TARGET_CPU_X86_64)
- // to pass https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=1563483&view=logs&j=f7cc61a9-cc70-56e7-b06c-4668ca17e426
- // ReductionOpTest.ReduceSum_half_bert
+#if defined(TARGET_OS_IOS) && defined(TARGET_CPU_X86_64) && TARGET_OS_IOS && TARGET_CPU_X86_64
+ // skip ReductionOpTest.ReduceSum_half_bert because reduce_sum will output all zeros
int32_t input_type;
GetType(*input_defs[0], input_type, logger);
if (node.OpType() == "ReduceSum" && input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
diff --git a/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc
index c8df7c1a43f65..a1b3a18265c70 100644
--- a/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc
@@ -13,6 +13,10 @@
#include "core/optimizer/initializer.h"
#include "core/providers/cpu/tensor/unsqueeze.h"
+#ifdef __APPLE__
+#include <TargetConditionals.h>
+#endif
+
namespace onnxruntime {
namespace coreml {
@@ -54,32 +58,50 @@ void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const
}
}
+#if defined(COREML_ENABLE_MLPROGRAM)
+void HandleX86ArchUnsqueezeScalarInput(ModelBuilder& model_builder,
+ const Node& node, const logging::Logger& logger) {
+ const auto& input_defs(node.InputDefs());
+ TensorShapeVector axes;
+ GetAxes(model_builder, node, axes);
+
+ std::vector<int64_t> input_shape;
+ GetShape(*input_defs[0], input_shape, logger);
+ auto op = model_builder.CreateOperation(node, "reshape");
+ AddOperationInput(*op, "x", input_defs[0]->Name());
+ TensorShapeVector output_shape = UnsqueezeBase::ComputeOutputShape(TensorShape(input_shape), axes);
+ AddOperationInput(*op, "shape", model_builder.AddConstant(op->type(), "shape", AsSpan(output_shape)));
+ AddOperationOutput(*op, *node.OutputDefs()[0]);
+ model_builder.AddOperation(std::move(op));
+}
+#endif
+
Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const Node& node,
[[maybe_unused]] const logging::Logger& logger) const {
std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);
- const auto& input_defs(node.InputDefs());
auto* coreml_squeeze = layer->mutable_squeeze();
TensorShapeVector axes;
GetAxes(model_builder, node, axes);
- std::vector<int64_t> input_shape;
- GetShape(*input_defs[0], input_shape, logger);
#if defined(COREML_ENABLE_MLPROGRAM)
+ const auto& input_defs(node.InputDefs());
if (model_builder.CreateMLProgram()) {
using namespace CoreML::Specification::MILSpec;
- std::string_view coreml_op_type = node.OpType() == "Squeeze" ? "squeeze" : "reshape";
+#if defined(TARGET_CPU_X86_64) && TARGET_CPU_X86_64
+ // expand_dims has limited requirements for static shapes; however, x86_64 has a bug where it cannot handle scalar input
+ if (node.OpType() == "Unsqueeze" && input_defs[0]->Shape()->dim_size() < 2) {
+ HandleX86ArchUnsqueezeScalarInput(model_builder, node, logger);
+ return Status::OK();
+ }
+#endif
+ std::string_view coreml_op_type = node.OpType() == "Squeeze" ? "squeeze" : "expand_dims";
std::unique_ptr<Operation> op = model_builder.CreateOperation(node, coreml_op_type);
AddOperationInput(*op, "x", input_defs[0]->Name());
- if (coreml_op_type == "squeeze") {
- if (!axes.empty()) {
- // coreml squeeze op does support negative axes
- AddOperationInput(*op, "axes", model_builder.AddConstant(op->type(), "axes", AsSpan(axes)));
- }
- } else {
- TensorShapeVector output_shape = UnsqueezeBase::ComputeOutputShape(TensorShape(input_shape), axes);
- AddOperationInput(*op, "shape", model_builder.AddConstant(op->type(), "shape", AsSpan(output_shape)));
+ if (!axes.empty()) {
+ // coreml supports negative axes
+ AddOperationInput(*op, "axes", model_builder.AddConstant(op->type(), "axes", AsSpan(axes)));
}
AddOperationOutput(*op, *node.OutputDefs()[0]);
model_builder.AddOperation(std::move(op));
diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.cc b/onnxruntime/core/providers/coreml/builders/model_builder.cc
index 2a02c1f4124f6..6486942199df7 100644
--- a/onnxruntime/core/providers/coreml/builders/model_builder.cc
+++ b/onnxruntime/core/providers/coreml/builders/model_builder.cc
@@ -408,7 +408,7 @@ ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logge
: graph_viewer_(graph_viewer),
logger_(logger),
coreml_version_(coreml_version),
- coreml_compute_unit_(coreml_options.ComputeUnits()),
+ coreml_options_(coreml_options),
create_ml_program_(coreml_options.CreateMLProgram()),
model_output_path_(GetModelOutputPath(create_ml_program_)),
onnx_input_names_(std::move(onnx_input_names)),
@@ -989,7 +989,7 @@ Status ModelBuilder::LoadModel(std::unique_ptr<Model>& model) {
get_sanitized_io_info(std::move(input_output_info_)),
std::move(scalar_outputs_),
std::move(int64_outputs_),
- logger_, coreml_compute_unit_);
+ logger_, coreml_options_);
} else
#endif
{
@@ -999,7 +999,7 @@ Status ModelBuilder::LoadModel(std::unique_ptr<Model>& model) {
std::move(input_output_info_),
std::move(scalar_outputs_),
std::move(int64_outputs_),
- logger_, coreml_compute_unit_);
+ logger_, coreml_options_);
}
return model->LoadModel(); // load using CoreML API, including compilation
diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.h b/onnxruntime/core/providers/coreml/builders/model_builder.h
index af47869f7e1c3..e19597cf0dc2e 100644
--- a/onnxruntime/core/providers/coreml/builders/model_builder.h
+++ b/onnxruntime/core/providers/coreml/builders/model_builder.h
@@ -7,6 +7,7 @@
#include "core/graph/graph_viewer.h"
#include "core/providers/coreml/builders/coreml_spec.h"
#include "core/providers/coreml/model/model.h"
+#include "core/providers/coreml/coreml_options.h"
#if defined(COREML_ENABLE_MLPROGRAM)
// coremltools classes
@@ -22,8 +23,6 @@ class StorageWriter;
#endif
namespace onnxruntime {
-class CoreMLOptions;
-
namespace coreml {
class IOpBuilder;
@@ -218,7 +217,7 @@ class ModelBuilder {
const GraphViewer& graph_viewer_;
const logging::Logger& logger_;
const int32_t coreml_version_;
- const uint32_t coreml_compute_unit_;
+ CoreMLOptions coreml_options_;
const bool create_ml_program_; // ML Program (CoreML5, iOS 15+, macOS 12+) or NeuralNetwork (old)
const std::string model_output_path_; // create_ml_program_ ? dir for mlpackage : filename for mlmodel
diff --git a/onnxruntime/core/providers/coreml/coreml_options.cc b/onnxruntime/core/providers/coreml/coreml_options.cc
index df78f74383871..4ec780208e528 100644
--- a/onnxruntime/core/providers/coreml/coreml_options.cc
+++ b/onnxruntime/core/providers/coreml/coreml_options.cc
@@ -63,11 +63,14 @@ void CoreMLOptions::ValidateAndParseProviderOption(const ProviderOptions& option
{"MLProgram", COREML_FLAG_CREATE_MLPROGRAM},
{"NeuralNetwork", COREML_FLAG_USE_NONE},
};
- std::unordered_set<std::string_view> valid_options = {
+ const std::unordered_set<std::string_view> valid_options = {
kCoremlProviderOption_MLComputeUnits,
kCoremlProviderOption_ModelFormat,
kCoremlProviderOption_RequireStaticInputShapes,
kCoremlProviderOption_EnableOnSubgraphs,
+ kCoremlProviderOption_SpecializationStrategy,
+ kCoremlProviderOption_ProfileComputePlan,
+ kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU,
};
// Validate the options
for (const auto& option : options) {
@@ -90,6 +93,16 @@ void CoreMLOptions::ValidateAndParseProviderOption(const ProviderOptions& option
require_static_shape_ = option.second == "1";
} else if (kCoremlProviderOption_EnableOnSubgraphs == option.first) {
enable_on_subgraph_ = option.second == "1";
+ } else if (kCoremlProviderOption_SpecializationStrategy == option.first) {
+ if (option.second != "Default" && option.second != "FastPrediction") {
+ ORT_THROW("Invalid value for option ", option.first, ": ", option.second,
+ ". Valid values are Default and FastPrediction.");
+ }
+ strategy_ = option.second;
+ } else if (kCoremlProviderOption_ProfileComputePlan == option.first) {
+ profile_compute_plan_ = option.second == "1";
+ } else if (kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU == option.first) {
+ allow_low_precision_accumulation_on_gpu_ = option.second == "1";
}
}
}
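The validation pattern above is: reject unknown keys, then parse the recognized ones, restricting SpecializationStrategy to Default/FastPrediction and treating the boolean options as set when the value is "1". A standalone sketch of that flow; the key strings below are stand-ins mirroring the kCoremlProviderOption_* constant names, not necessarily the exact strings the provider accepts:

    #include <iostream>
    #include <stdexcept>
    #include <string>
    #include <unordered_map>
    #include <unordered_set>

    void ValidateAndParse(const std::unordered_map<std::string, std::string>& options) {
      const std::unordered_set<std::string> valid_options = {
          "ModelFormat", "MLComputeUnits", "RequireStaticInputShapes", "EnableOnSubgraphs",
          "SpecializationStrategy", "ProfileComputePlan", "AllowLowPrecisionAccumulationOnGPU"};

      std::string strategy;
      bool profile_compute_plan = false;
      bool allow_low_precision = false;

      for (const auto& [key, value] : options) {
        if (valid_options.count(key) == 0) {
          throw std::invalid_argument("Unknown CoreML provider option: " + key);
        }
        if (key == "SpecializationStrategy") {
          if (value != "Default" && value != "FastPrediction") {
            throw std::invalid_argument("Invalid value for " + key + ": " + value);
          }
          strategy = value;
        } else if (key == "ProfileComputePlan") {
          profile_compute_plan = (value == "1");
        } else if (key == "AllowLowPrecisionAccumulationOnGPU") {
          allow_low_precision = (value == "1");
        }
      }
      std::cout << "strategy=" << strategy << " profile=" << profile_compute_plan
                << " low_precision_gpu=" << allow_low_precision << "\n";
    }

    int main() {
      ValidateAndParse({{"SpecializationStrategy", "FastPrediction"}, {"ProfileComputePlan", "1"}});
      return 0;
    }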
diff --git a/onnxruntime/core/providers/coreml/coreml_options.h b/onnxruntime/core/providers/coreml/coreml_options.h
index 8bb748fcd69c9..fd05c96927bd1 100644
--- a/onnxruntime/core/providers/coreml/coreml_options.h
+++ b/onnxruntime/core/providers/coreml/coreml_options.h
@@ -14,6 +14,9 @@ class CoreMLOptions {
bool create_mlprogram_{false};
bool enable_on_subgraph_{false};
uint32_t compute_units_{0};
+ std::string strategy_;
+ bool profile_compute_plan_{false};
+ bool allow_low_precision_accumulation_on_gpu_{false};
public:
explicit CoreMLOptions(uint32_t coreml_flags);
@@ -25,6 +28,9 @@ class CoreMLOptions {
bool CreateMLProgram() const { return create_mlprogram_; }
bool EnableOnSubgraph() const { return enable_on_subgraph_; }
uint32_t ComputeUnits(uint32_t specific_flag = 0xffffffff) const { return compute_units_ & specific_flag; }
+ bool AllowLowPrecisionAccumulationOnGPU() const { return allow_low_precision_accumulation_on_gpu_; }
+ bool UseStrategy(std::string_view strategy) const { return strategy_ == strategy; }
+ bool ProfileComputePlan() const { return profile_compute_plan_ && create_mlprogram_; }
private:
void ValidateAndParseProviderOption(const ProviderOptions& options);
diff --git a/onnxruntime/core/providers/coreml/model/model.h b/onnxruntime/core/providers/coreml/model/model.h
index 68ecbe5fb80c4..84b7d741b4714 100644
--- a/onnxruntime/core/providers/coreml/model/model.h
+++ b/onnxruntime/core/providers/coreml/model/model.h
@@ -18,6 +18,7 @@
#endif
namespace onnxruntime {
+class CoreMLOptions;
namespace coreml {
class Execution;
@@ -53,7 +54,7 @@ class Model {
std::unordered_map<std::string, OnnxTensorInfo>&& input_output_info,
std::unordered_set<std::string>&& scalar_outputs,
std::unordered_set<std::string>&& int64_outputs,
- const logging::Logger& logger, uint32_t coreml_compute_unit);
+ const logging::Logger& logger, const CoreMLOptions& coreml_options);
~Model();
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Model);
diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm
index c8edb64ff55d7..755dbfbd6e68c 100644
--- a/onnxruntime/core/providers/coreml/model/model.mm
+++ b/onnxruntime/core/providers/coreml/model/model.mm
@@ -25,6 +25,7 @@
#include "core/providers/coreml/model/host_utils.h"
#include "core/providers/coreml/model/objc_str_utils.h"
#include "core/providers/coreml/shape_utils.h"
+#include "core/providers/coreml/coreml_options.h"
// force the linker to create a dependency on the CoreML framework so that in MAUI usage we don't need
// to manually do this
@@ -300,6 +301,53 @@ Status GetMLMultiArrayCopyInfo(const MLMultiArray* _Nonnull array,
return Status::OK();
}
+// MLComputePlan is only declared when building with __clang_major__ >= 15 (it was introduced in macOS 14.4 / iOS 17.4).
+// We already ensure at runtime that the macOS/iOS version is at least 14.4/17.4 via the @available check,
+// and the API_AVAILABLE macro would also work.
+// Without this compile-time guard, older compilers complain that `MLComputePlan` is not defined.
+// __clang_analyzer__ is checked to bypass static analysis.
+void ProfileComputePlan(NSURL* compileUrl, MLModelConfiguration* config) {
+#if defined(__APPLE__) && defined(__clang__) && __clang_major__ >= 15 && !defined(__clang_analyzer__)
+ if (@available(macOS 14.4, iOS 17.4, *)) {
+ [MLComputePlan loadContentsOfURL:compileUrl
+ configuration:config
+ completionHandler:^(MLComputePlan* _Nullable computePlan, NSError* _Nullable error) {
+ if (!computePlan) {
+ NSLog(@"Error loading compute plan: %@", error);
+ // Handle error.
+ return;
+ }
+ MLModelStructureProgram* program = computePlan.modelStructure.program;
+ if (!program) {
+ NSLog(@"Error loading program from compute plan., this is not a mlprogram model");
+ return;
+ }
+
+ MLModelStructureProgramFunction* mainFunction = program.functions[@"main"];
+ if (!mainFunction) {
+ NSLog(@"Error loading main function from program");
+ return;
+ }
+
+ NSArray* operations = mainFunction.block.operations;
+ NSLog(@"Number of operations, 'const' node is included. : %lu", operations.count);
+ for (MLModelStructureProgramOperation* operation in operations) {
+ // Get the compute device usage for the operation.
+ MLComputePlanDeviceUsage* computeDeviceUsage = [computePlan computeDeviceUsageForMLProgramOperation:operation];
+ id preferredDevice = computeDeviceUsage.preferredComputeDevice;
+ // Get the estimated cost of executing the operation.
+ MLComputePlanCost* estimatedCost = [computePlan estimatedCostOfMLProgramOperation:operation];
+ if (![operation.operatorName isEqualToString:@"const"]) {
+ NSLog(@"Operation: %@, Device Usage: %@, Estimated Cost: %f", operation.operatorName, preferredDevice, estimatedCost.weight);
+ }
+ }
+ }];
+ } else {
+ NSLog(@"iOS 17.4+/macOS 14.4+ or later is required to use the compute plan API");
+ }
+#endif
+}
+
// Internal Execution class
// This class is part of the model class and handles the calls into CoreML. Specifically, it performs
// 1. Compile the model by given path for execution
@@ -307,7 +355,7 @@ Status GetMLMultiArrayCopyInfo(const MLMultiArray* _Nonnull array,
// 3. The compiled model will be removed in dealloc or removed using cleanup function
class Execution {
public:
- Execution(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags);
+ Execution(const std::string& path, const logging::Logger& logger, const CoreMLOptions& coreml_options);
~Execution();
Status LoadModel();
@@ -320,13 +368,13 @@ Status Predict(const std::unordered_map<std::string, OnnxTensorData>& inputs,
NSString* coreml_model_path_{nil};
NSString* compiled_model_path_{nil};
const logging::Logger& logger_;
- uint32_t coreml_compute_unit_{0};
+ CoreMLOptions coreml_options_;
MLModel* model_{nil};
};
-Execution::Execution(const std::string& path, const logging::Logger& logger, uint32_t coreml_compute_unit)
+Execution::Execution(const std::string& path, const logging::Logger& logger, const CoreMLOptions& coreml_options)
: logger_(logger),
- coreml_compute_unit_(coreml_compute_unit) {
+ coreml_options_(coreml_options) {
@autoreleasepool {
coreml_model_path_ = util::Utf8StringToNSString(path.c_str());
}
@@ -395,17 +443,41 @@ Status Predict(const std::unordered_map<std::string, OnnxTensorData>& inputs,
compiled_model_path_ = [compileUrl path];
MLModelConfiguration* config = [[MLModelConfiguration alloc] init];
-
- if (coreml_compute_unit_ & COREML_FLAG_USE_CPU_ONLY) {
+ uint32_t coreml_compute_unit = coreml_options_.ComputeUnits();
+ if (coreml_compute_unit & COREML_FLAG_USE_CPU_ONLY) {
config.computeUnits = MLComputeUnitsCPUOnly;
- } else if (coreml_compute_unit_ & COREML_FLAG_USE_CPU_AND_GPU) {
+ } else if (coreml_compute_unit & COREML_FLAG_USE_CPU_AND_GPU) {
config.computeUnits = MLComputeUnitsCPUAndGPU;
- } else if (coreml_compute_unit_ & COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) {
+ } else if (coreml_compute_unit & COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) {
config.computeUnits = MLComputeUnitsCPUAndNeuralEngine; // Apple Neural Engine
} else {
config.computeUnits = MLComputeUnitsAll;
}
+ if (coreml_options_.AllowLowPrecisionAccumulationOnGPU()) {
+ config.allowLowPrecisionAccumulationOnGPU = YES;
+ }
+
+// Set the specialization strategy to FastPrediction for macOS 10.15+.
+// optimizationHints is only declared when building with __clang_major__ >= 15;
+// see the comments above ProfileComputePlan for why __clang_major__ is checked.
+// __clang_analyzer__ is checked to bypass static analysis.
+#if defined(__APPLE__) && defined(__clang__) && __clang_major__ >= 15 && !defined(__clang_analyzer__)
+ if (HAS_COREML8_OR_LATER) {
+ MLOptimizationHints* optimizationHints = [[MLOptimizationHints alloc] init];
+ if (coreml_options_.UseStrategy("FastPrediction")) {
+ optimizationHints.specializationStrategy = MLSpecializationStrategyFastPrediction;
+ config.optimizationHints = optimizationHints;
+ } else if (coreml_options_.UseStrategy("Default")) {
+ optimizationHints.specializationStrategy = MLSpecializationStrategyDefault;
+ config.optimizationHints = optimizationHints;
+ }
+ }
+#endif
+ if (coreml_options_.ProfileComputePlan()) {
+ ProfileComputePlan(compileUrl, config);
+ }
+
model_ = [MLModel modelWithContentsOfURL:compileUrl configuration:config error:&error];
if (error != nil || model_ == nil) {
@@ -524,8 +596,8 @@ Status Predict(const std::unordered_map<std::string, OnnxTensorData>& inputs,
std::unordered_set<std::string>&& scalar_outputs,
std::unordered_set<std::string>&& int64_outputs,
const logging::Logger& logger,
- uint32_t coreml_flags)
- : execution_(std::make_unique<Execution>(path, logger, coreml_flags)),
+ const CoreMLOptions& coreml_options)
+ : execution_(std::make_unique<Execution>(path, logger, coreml_options)),
model_input_names_(std::move(model_input_names)),
model_output_names_(std::move(model_output_names)),
input_output_info_(std::move(input_output_info)),
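Execution::LoadModel (above) maps the ORT compute-unit bitmask onto MLComputeUnits before applying the new options. A small C++ sketch of that mapping, with stand-in flag values and an enum in place of the CoreML type:

    #include <cstdint>
    #include <iostream>

    enum class ComputeUnits { CPUOnly, CPUAndGPU, CPUAndNeuralEngine, All };

    // stand-in flag values; the real ones are the COREML_FLAG_* constants
    constexpr uint32_t kFlagUseCpuOnly = 0x1;
    constexpr uint32_t kFlagUseCpuAndGpu = 0x2;
    constexpr uint32_t kFlagOnlyEnableDeviceWithAne = 0x4;

    ComputeUnits MapComputeUnits(uint32_t flags) {
      if (flags & kFlagUseCpuOnly) return ComputeUnits::CPUOnly;
      if (flags & kFlagUseCpuAndGpu) return ComputeUnits::CPUAndGPU;
      if (flags & kFlagOnlyEnableDeviceWithAne) return ComputeUnits::CPUAndNeuralEngine;
      return ComputeUnits::All;  // default: let CoreML schedule across CPU/GPU/ANE
    }

    int main() {
      std::cout << static_cast<int>(MapComputeUnits(kFlagUseCpuAndGpu)) << "\n";  // prints 1
      return 0;
    }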
diff --git a/onnxruntime/core/providers/coreml/model/model_stub.cc b/onnxruntime/core/providers/coreml/model/model_stub.cc
index c6f2e7401ea1e..e9036e2fc7e1a 100644
--- a/onnxruntime/core/providers/coreml/model/model_stub.cc
+++ b/onnxruntime/core/providers/coreml/model/model_stub.cc
@@ -4,6 +4,7 @@
#include "core/providers/coreml/model/model.h"
namespace onnxruntime {
+class CoreMLOptions;
namespace coreml {
class Execution {};
@@ -15,7 +16,7 @@ Model::Model(const std::string& /*path*/,
std::unordered_set<std::string>&& scalar_outputs,
std::unordered_set<std::string>&& int64_outputs,
const logging::Logger& /*logger*/,
- uint32_t /*coreml_flags*/)
+ const CoreMLOptions& /*coreml_flags*/)
: execution_(std::make_unique<Execution>()),
model_input_names_(std::move(model_input_names)),
model_output_names_(std::move(model_output_names)),
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index 8396e2629d2bf..d4013a7dc3d57 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -2693,7 +2693,7 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
// For CUDA EP, exclude the subgraph that is preferred to be placed in CPU
// These are usually shape related computation subgraphs
// Following logic can be extended for other EPs
- auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes);
+ auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger);
std::vector<std::unique_ptr<ComputeCapability>> result;
for (auto& node_index : candidates) {
if (cpu_nodes.count(node_index) > 0)
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp
index 35a2c451a49a5..9f95818501dac 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp
@@ -62,7 +62,8 @@ namespace Dml
const auto kernel_type_str_resolver = onnxruntime::OpSchemaKernelTypeStrResolver{};
const auto kernel_lookup = onnxruntime::KernelLookup{provider_type,
gsl::make_span(&registry, 1),
- kernel_type_str_resolver};
+ kernel_type_str_resolver,
+ logger};
std::vector<std::shared_ptr<CompiledPartitionInfo>> compiledPartitionInfos;
std::vector additionalSplittingNodes;
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp
index 6318b0d5e2865..b9b90d6bc17bd 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp
@@ -54,7 +54,8 @@ namespace Dml
const auto kernelLookup = onnxruntime::KernelLookup(
providerType,
gsl::make_span(&registry, 1),
- kernelTypeStrResolver);
+ kernelTypeStrResolver,
+ logger);
onnxruntime::GraphViewer graphViewer(graph);
const auto& nodeTopologyList = graphViewer.GetNodesInTopologicalOrder();
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp
index 228dfeb123175..826f48b5f7a68 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp
@@ -95,7 +95,7 @@ namespace Dml
const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup) const
{
#ifdef ENABLE_GRAPH_COMPILATION
- return m_impl->GetCapability(graph, kernel_lookup);
+ return m_impl->GetCapability(graph, kernel_lookup, *GetLogger());
#else
return onnxruntime::IExecutionProvider::GetCapability(graph, kernel_lookup);
#endif
@@ -876,7 +876,8 @@ namespace Dml
std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>
ExecutionProviderImpl::GetCapability(
const onnxruntime::GraphViewer& graph,
- const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup) const
+ const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup,
+ const onnxruntime::logging::Logger& logger) const
{
uint32_t deviceDataTypeMask = GetSupportedDeviceDataTypeMask(); // Each bit corresponds to each DML_TENSOR_DATA_TYPE.
@@ -900,7 +901,7 @@ namespace Dml
}
// Get the list of nodes that should stay on the CPU
- auto cpuPreferredNodes = GetCpuPreferredNodes(graph, kernel_lookup, tentativeNodes);
+ auto cpuPreferredNodes = GetCpuPreferredNodes(graph, kernel_lookup, tentativeNodes, logger);
for (size_t nodeIndex : toplogicalOrder)
{
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h
index 32a5b9add35a0..e7d859c5764de 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h
@@ -88,7 +88,8 @@ namespace Dml
std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>
GetCapability(
const onnxruntime::GraphViewer& graph,
- const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup
+ const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup,
+ const onnxruntime::logging::Logger& logger
) const;
uint32_t GetSupportedDeviceDataTypeMask() const;
diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc
index 4cb40ec8bf5fd..c1a8b373bed84 100644
--- a/onnxruntime/core/providers/js/js_execution_provider.cc
+++ b/onnxruntime/core/providers/js/js_execution_provider.cc
@@ -818,7 +818,7 @@ std::vector<std::unique_ptr<ComputeCapability>> JsExecutionProvider::GetCapabili
candidates.push_back(node.Index());
tenative_candidates.push_back(node.Index());
}
- auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates);
+ auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates, *GetLogger());
std::vector<std::unique_ptr<ComputeCapability>> result;
for (auto& node_index : candidates) {
if (cpu_nodes.count(node_index) > 0) {
diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc
index 12416ea0c121b..e4bee6f959a01 100644
--- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc
+++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc
@@ -32,8 +32,16 @@ namespace nnapi {
ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const NnApi& nnapi_handle,
gsl::span<const DeviceWrapper> nnapi_target_devices,
- TargetDeviceOption target_device_option)
- : nnapi_(nnapi_handle), graph_viewer_(graph_viewer), nnapi_model_{std::make_unique<Model>(nnapi_handle)}, shaper_{graph_viewer}, nnapi_target_devices_(nnapi_target_devices), target_device_option_(target_device_option), nnapi_effective_feature_level_(GetNNAPIEffectiveFeatureLevel(nnapi_handle, nnapi_target_devices_)) {
+ TargetDeviceOption target_device_option,
+ const logging::Logger& logger)
+ : nnapi_(nnapi_handle),
+ graph_viewer_(graph_viewer),
+ nnapi_model_{std::make_unique<Model>(nnapi_handle)},
+ shaper_{graph_viewer},
+ nnapi_target_devices_(nnapi_target_devices),
+ target_device_option_(target_device_option),
+ nnapi_effective_feature_level_(GetNNAPIEffectiveFeatureLevel(nnapi_handle, nnapi_target_devices_)),
+ logger_(logger) {
nnapi_model_->nnapi_effective_feature_level_ = nnapi_effective_feature_level_;
}
@@ -136,7 +144,7 @@ const NodeUnit& ModelBuilder::GetNodeUnit(const Node* node) const {
}
void ModelBuilder::PreprocessNodeUnits() {
- std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_);
+ std::tie(node_unit_holder_, node_unit_map_) = QDQ::GetAllNodeUnits(graph_viewer_, logger_);
}
// Help to get all quantized operators' input and the NodeUnit(s) using the input
diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h
index b2118150dd304..4db335afa98b0 100644
--- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h
+++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.h
@@ -14,7 +14,9 @@
struct NnApi;
namespace onnxruntime {
-
+namespace logging {
+class Logger;
+}
class GraphViewer;
enum class DataLayout;
class NodeUnit;
@@ -31,7 +33,8 @@ class ModelBuilder {
using Shape = Shaper::Shape;
ModelBuilder(const GraphViewer& graph_viewer, const NnApi& nnapi_handle,
- gsl::span<const DeviceWrapper> nnapi_target_devices, TargetDeviceOption target_device_option);
+ gsl::span<const DeviceWrapper> nnapi_target_devices, TargetDeviceOption target_device_option,
+ const logging::Logger& logger);
common::Status Compile(std::unique_ptr<Model>& model);
@@ -173,6 +176,9 @@ class ModelBuilder {
// <1,1> <1,2> <1,3>
InlinedVector> operations_recorder_;
#endif
+
+ const logging::Logger& logger_;
+
// Convert the ONNX model to ANeuralNetworksModel
common::Status Prepare();
diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc
index fca52396a190c..f92c9592742d5 100644
--- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc
+++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc
@@ -81,6 +81,7 @@ NnapiExecutionProvider::~NnapiExecutionProvider() {}
std::vector<std::unique_ptr<ComputeCapability>>
NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer,
const IKernelLookup& /*kernel_lookup*/) const {
+ const auto& logger = *GetLogger();
std::vector<std::unique_ptr<ComputeCapability>> result;
// TODO: Task 812756: NNAPI EP, add support for subgraph (If and Loop operators)
@@ -101,7 +102,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
return ORT_NNAPI_MAX_SUPPORTED_API_LEVEL;
#endif
}();
- LOGS_DEFAULT(VERBOSE) << "Effective NNAPI feature level: " << android_feature_level;
+ LOGS(logger, VERBOSE) << "Effective NNAPI feature level: " << android_feature_level;
const nnapi::OpSupportCheckParams params{
android_feature_level,
@@ -109,7 +110,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
};
if (params.android_feature_level < ORT_NNAPI_MIN_API_LEVEL) {
- LOGS_DEFAULT(WARNING) << "All ops will fallback to CPU EP, because system NNAPI feature level ["
+ LOGS(logger, WARNING) << "All ops will fallback to CPU EP, because system NNAPI feature level ["
<< params.android_feature_level
<< "] is lower than minimal supported NNAPI API feature level ["
<< ORT_NNAPI_MIN_API_LEVEL
@@ -121,7 +122,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
- std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
+ std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger);
// This holds the result of whether a NodeUnit is supported or not,
// to prevent nodes in a NodeUnit to be checked for multiple times
@@ -150,7 +151,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
node_unit_supported_result[node_unit] = supported;
}
- LOGS_DEFAULT(VERBOSE) << "Node supported: [" << supported
+ LOGS(logger, VERBOSE) << "Node supported: [" << supported
<< "] Operator type: [" << node.OpType()
<< "] index: [" << node.Index()
<< "] name: [" << node.Name()
@@ -224,9 +225,9 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view
// If the graph is partitioned in multiple subgraphs, and this may impact performance,
// we want to give users a summary message at warning level.
if (num_of_partitions > 1) {
- LOGS_DEFAULT(WARNING) << summary_msg;
+ LOGS(logger, WARNING) << summary_msg;
} else {
- LOGS_DEFAULT(INFO) << summary_msg;
+ LOGS(logger, INFO) << summary_msg;
}
return result;
@@ -273,11 +274,13 @@ static Status GetOutputBuffer(Ort::KernelContext& context,
common::Status NnapiExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs) {
using namespace android::nn::wrapper;
+ const auto& logger = *GetLogger();
+
for (const auto& fused_node_and_graph : fused_nodes_and_graphs) {
Node& fused_node = fused_node_and_graph.fused_node;
const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph);
- nnapi::ModelBuilder builder(graph_viewer, *nnapi_handle_, nnapi_target_devices_, target_device_option_);
+ nnapi::ModelBuilder builder(graph_viewer, *nnapi_handle_, nnapi_target_devices_, target_device_option_, logger);
builder.SetUseNCHW(nnapi_flags_ & NNAPI_FLAG_USE_NCHW);
builder.SetUseFp16(nnapi_flags_ & NNAPI_FLAG_USE_FP16);
diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc
index f1df1abf4c49a..decfe91c598be 100644
--- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc
+++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc
@@ -687,7 +687,7 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph,
// Get all the NodeUnits in the graph_viewer
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
- std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(&src_graph);
+ std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(&src_graph, logger);
std::unordered_set<const NodeUnit*> seen_node_units;
const auto& node_indices = src_graph.GetNodesInTopologicalOrder();
diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
index 57ae8c354abb7..79674fd706151 100644
--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
+++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
@@ -87,7 +87,8 @@ Status CreateNodeArgs(const std::vector<std::string>& names,
Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
- QnnModelLookupTable& qnn_models) {
+ QnnModelLookupTable& qnn_models,
+ int64_t max_spill_fill_size) {
ORT_RETURN_IF_NOT(EPCONTEXT_OP == main_context_node.OpType(), "Should only filter in the EPContext node.");
NodeAttrHelper node_helper(main_context_node);
bool is_embed_mode = node_helper.Get(EMBED_MODE, true);
@@ -96,7 +97,8 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast<char*>(context_binary.c_str()),
static_cast<uint64_t>(context_binary.length()),
main_context_node.Name(),
- qnn_models);
+ qnn_models,
+ max_spill_fill_size);
}
std::filesystem::path folder_path = std::filesystem::path(ctx_onnx_model_path).parent_path();
@@ -145,17 +147,46 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(),
static_cast<uint64_t>(buffer_size),
main_context_node.Name(),
- qnn_models);
+ qnn_models,
+ max_spill_fill_size);
+}
+
+Status TryGetMaxSpillFillSize(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
+ uint32_t total_context_size,
+ int64_t& max_spill_fill_size,
+ std::vector<int>& main_context_pos_list) {
+ max_spill_fill_size = 0;
+ int max_size_index = 0;
+ for (uint32_t i = 0; i < total_context_size; ++i) {
+ auto index = main_context_pos_list[i];
+ const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[index].filtered_graph);
+ ORT_RETURN_IF(main_ctx_graph_viewer.NumberOfNodes() != 1, "One filtered graph should has only one EPContext node!");
+ const auto& ep_context_node = main_ctx_graph_viewer.Nodes().begin();
+ NodeAttrHelper node_helper(*ep_context_node);
+ int64_t max_size = node_helper.Get(MAX_SIZE, static_cast<int64_t>(0));
+ if (max_size > max_spill_fill_size) {
+ max_spill_fill_size = max_size;
+ max_size_index = i;
+ }
+ }
+ if (0 != max_size_index) {
+ int tmp_index = main_context_pos_list[0];
+ main_context_pos_list[0] = main_context_pos_list[max_size_index];
+ main_context_pos_list[max_size_index] = tmp_index;
+ }
+
+ return Status::OK();
}
Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModelLookupTable& qnn_models,
- const logging::Logger& logger) {
+ const logging::Logger& logger,
+ int64_t max_spill_fill_size) {
ORT_RETURN_IF(graph_viewer.NumberOfNodes() != 1, "One filtered graph should has only one EPContext node!");
Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path, qnn_backend_manager,
- qnn_models);
+ qnn_models, max_spill_fill_size);
// This is the protocol with customer that status with INVALID_GRAPH will be generated if failed to load context model
if (!status.IsOK()) {
@@ -196,6 +227,7 @@ Status CreateEPContextNodes(Model* model,
const QnnModelLookupTable& qnn_models,
const onnxruntime::PathString& context_cache_path,
bool qnn_context_embed_mode,
+ uint64_t max_spill_fill_buffer_size,
const logging::Logger& logger) {
auto& graph = model->MainGraph();
@@ -238,6 +270,7 @@ Status CreateEPContextNodes(Model* model,
}
of_stream.write(reinterpret_cast(buffer), buffer_size);
ep_node.AddAttribute(EP_CACHE_CONTEXT, context_cache_name);
+ ep_node.AddAttribute(MAX_SIZE, static_cast<int64_t>(max_spill_fill_buffer_size));
}
} else {
ep_node.AddAttribute(MAIN_CONTEXT, static_cast<int64_t>(0));
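TryGetMaxSpillFillSize (above) walks the main EPContext nodes, records the largest max_size attribute it sees, and moves that context to the front of the load order so the largest spill-fill requirement is handled first. A self-contained sketch of that selection, with made-up sizes in place of real node attributes:

    #include <cstdint>
    #include <iostream>
    #include <utility>
    #include <vector>

    int main() {
      // max_size attribute read from each main EPContext node (hypothetical values)
      const std::vector<int64_t> max_size_per_context{64, 512, 128};
      std::vector<int> main_context_pos_list{0, 1, 2};  // order the contexts would be loaded in

      int64_t max_spill_fill_size = 0;
      size_t max_size_index = 0;
      for (size_t i = 0; i < main_context_pos_list.size(); ++i) {
        const int64_t max_size = max_size_per_context[main_context_pos_list[i]];
        if (max_size > max_spill_fill_size) {
          max_spill_fill_size = max_size;
          max_size_index = i;
        }
      }
      if (max_size_index != 0) {
        std::swap(main_context_pos_list[0], main_context_pos_list[max_size_index]);
      }

      std::cout << "max spill-fill size: " << max_spill_fill_size
                << ", load context " << main_context_pos_list[0] << " first\n";
      return 0;
    }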
diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
index f308a7456d46c..92c5391b40f09 100644
--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
+++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
@@ -28,6 +28,7 @@ static const std::string EP_CACHE_CONTEXT = "ep_cache_context";
static const std::string EP_SDK_VER = "ep_sdk_version";
static const std::string PARTITION_NAME = "partition_name";
static const std::string SOURCE = "source";
+static const std::string MAX_SIZE = "max_size";
bool GraphHasEpContextNode(const onnxruntime::GraphViewer& graph_viewer);
@@ -49,13 +50,20 @@ bool ValidateContextCacheFilePath(bool is_qnn_ctx_model,
Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
- QnnModelLookupTable& qnn_models);
+ QnnModelLookupTable& qnn_models,
+ int64_t max_spill_fill_size);
+
+Status TryGetMaxSpillFillSize(const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
+ uint32_t total_context_size,
+ int64_t& max_spill_fill_size,
+ std::vector<int>& main_context_pos_list);
Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModelLookupTable& qnn_models,
- const logging::Logger& logger);
+ const logging::Logger& logger,
+ int64_t max_spill_fill_size);
Status CreateEPContextNodes(Model* model,
unsigned char* buffer,
@@ -65,6 +73,7 @@ Status CreateEPContextNodes(Model* model,
const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
const onnxruntime::PathString& context_cache_path,
bool qnn_context_embed_mode,
+ uint64_t max_spill_fill_buffer_size,
const logging::Logger& logger);
} // namespace qnn
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
index f37c91aa0413b..8a717c3f29ff9 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
@@ -8,6 +8,7 @@
#include
#include "QnnOpDef.h"
#include "HTP/QnnHtpPerfInfrastructure.h"
+#include "HTP/QnnHtpSystemContext.h"
#include "CPU/QnnCpuCommon.h"
// TODO: not exist for Windows yet
// #include "GPU/QnnGpuCommon.h"
@@ -532,11 +533,11 @@ Status QnnBackendManager::CreateContext() {
}
QnnContext_Config_t context_config_weight_sharing = QNN_CONTEXT_CONFIG_INIT;
- QnnHtpContext_CustomConfig_t customConfig;
- customConfig.option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED;
- customConfig.weightSharingEnabled = enable_htp_weight_sharing_;
+ QnnHtpContext_CustomConfig_t custom_config;
+ custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED;
+ custom_config.weightSharingEnabled = enable_htp_weight_sharing_;
context_config_weight_sharing.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM;
- context_config_weight_sharing.customConfig = &customConfig;
+ context_config_weight_sharing.customConfig = &custom_config;
QnnContext_Config_t context_priority_config = QNN_CONTEXT_CONFIG_INIT;
ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, context_priority_config));
@@ -615,9 +616,71 @@ std::unique_ptr QnnBackendManager::GetContextBinaryBuffer(uint6
return context_buffer;
}
+Status QnnBackendManager::GetMaxSpillFillBufferSize(unsigned char* buffer,
+ uint64_t buffer_length,
+ uint64_t& max_spill_fill_buffer_size) {
+ bool result = nullptr == qnn_sys_interface_.systemContextCreate ||
+ nullptr == qnn_sys_interface_.systemContextGetBinaryInfo ||
+ nullptr == qnn_sys_interface_.systemContextFree;
+ ORT_RETURN_IF(result, "Failed to get valid function pointer.");
+
+ QnnSystemContext_Handle_t sys_ctx_handle = nullptr;
+ auto rt = qnn_sys_interface_.systemContextCreate(&sys_ctx_handle);
+ ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create system handle.");
+
+ const QnnSystemContext_BinaryInfo_t* binary_info = nullptr;
+ Qnn_ContextBinarySize_t binary_info_size{0};
+ rt = qnn_sys_interface_.systemContextGetBinaryInfo(sys_ctx_handle,
+ static_cast<void*>(buffer),
+ buffer_length,
+ &binary_info,
+ &binary_info_size);
+ ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to get context binary info.");
+
+ // binary_info life cycle is here
+ // Binary info to graph info
+ // retrieve Qnn graph info from binary info
+ ORT_RETURN_IF(nullptr == binary_info, "Qnn cached binary info is nullptr.");
+ uint32_t graph_count = 0;
+ QnnSystemContext_GraphInfo_t* graphs_info = nullptr;
+ if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) {
+ graph_count = binary_info->contextBinaryInfoV3.numGraphs;
+ graphs_info = binary_info->contextBinaryInfoV3.graphs;
+ } else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) {
+ graph_count = binary_info->contextBinaryInfoV2.numGraphs;
+ graphs_info = binary_info->contextBinaryInfoV2.graphs;
+ } else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) {
+ graph_count = binary_info->contextBinaryInfoV1.numGraphs;
+ graphs_info = binary_info->contextBinaryInfoV1.graphs;
+ } else {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported context binary info version.");
+ }
+
+ for (uint32_t i = 0; i < graph_count; ++i) {
+ if (graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_3) {
+ auto htp_graph_info = reinterpret_cast<QnnHtpSystemContext_GraphBlobInfo_t*>(graphs_info[i].graphInfoV3.graphBlobInfo);
+ if (htp_graph_info->version == QNN_SYSTEM_CONTEXT_HTP_GRAPH_INFO_BLOB_VERSION_V1) {
+ auto spill_fill_buffer_size = htp_graph_info->contextBinaryGraphBlobInfoV1.spillFillBufferSize;
+ max_spill_fill_buffer_size = spill_fill_buffer_size > max_spill_fill_buffer_size ? spill_fill_buffer_size : max_spill_fill_buffer_size;
+ } else {
+ LOGS(*logger_, VERBOSE) << "Unknown context binary graph info blob version.";
+ }
+ } else if (graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_2 ||
+ graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_1) {
+ LOGS(*logger_, VERBOSE) << "Skip retrieve spill file buffer size, it is not supported with graph info v1 & v2.";
+ } else {
+ LOGS(*logger_, VERBOSE) << "Unknown context binary graph info version.";
+ }
+ }
+
+ LOGS(*logger_, VERBOSE) << "Get max spill fill buffer size completed.";
+ return Status::OK();
+}
+
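On the generation path, this helper is meant to run right after the context binary has been serialized and before the EPContext nodes are written, so the max_size attribute reflects the largest spillFillBufferSize across all graphs in that binary. A short caller sketch that mirrors the Compile()-side change later in this patch (error handling trimmed):

    uint64_t buffer_size = 0;
    auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size);

    uint64_t max_spill_fill_buffer_size = 0;
    if (enable_spill_fill_buffer_) {
      // Parses the freshly generated binary with the QnnSystemContext APIs above.
      ORT_RETURN_IF_ERROR(qnn_backend_manager_->GetMaxSpillFillBufferSize(context_buffer.get(),
                                                                          buffer_size,
                                                                          max_spill_fill_buffer_size));
    }
    // max_spill_fill_buffer_size is then forwarded to CreateEPContextNodes.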
Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
std::string node_name,
- QnnModelLookupTable& qnn_models) {
+ QnnModelLookupTable& qnn_models,
+ int64_t max_spill_fill_size) {
bool result = nullptr == qnn_sys_interface_.systemContextCreate ||
nullptr == qnn_sys_interface_.systemContextGetBinaryInfo ||
nullptr == qnn_sys_interface_.systemContextFree;
@@ -638,7 +701,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
// binary_info life cycle is here
// Binary info to graph info
- // retrieve Qnn graph infor from binary info
+ // retrieve Qnn graph info from binary info
ORT_RETURN_IF(nullptr == binary_info, "Qnn cached binary info is nullptr.");
uint32_t graph_count = 0;
QnnSystemContext_GraphInfo_t* graphs_info = nullptr;
@@ -658,13 +721,33 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
ORT_RETURN_IF(graph_count < 1 || graphs_info == nullptr, "Failed to get graph info from Qnn cached context.");
LOGS(*logger_, VERBOSE) << "Graph count from QNN context: " << graph_count;
- ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary,
- "Invalid function pointer for contextCreateFromBinary.");
-
QnnContext_Config_t qnn_context_config = QNN_CONTEXT_CONFIG_INIT;
ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, qnn_context_config));
- const QnnContext_Config_t* context_configs[] = {&qnn_context_config, nullptr};
+ // Register the spill fill buffer shared across multiple contexts
+ QnnContext_Config_t spill_fill_config = QNN_CONTEXT_CONFIG_INIT;
+
+ // The spill fill buffer is available since QNN SDK 2.28, whose API version starts from 2.21
+#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 21)
+ QnnHtpContext_CustomConfig_t custom_config;
+ custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_REGISTER_MULTI_CONTEXTS;
+ QnnHtpContext_GroupRegistration_t group_info;
+ size_t current_contexts_size = GetQnnContextSize();
+ // set to 0x0 (new group) if this is the first context, otherwise point to the first context handle
+ // note that we already moved the context with the max spill fill size to the beginning of the list
+ group_info.firstGroupHandle = (max_spill_fill_size > 0 && current_contexts_size > 0) ? GetQnnContext(0) : 0x0;
+ group_info.maxSpillFillBuffer = max_spill_fill_size; // Max spill-fill buffer across contexts. Must be >0
+ custom_config.groupRegistration = group_info;
+ spill_fill_config.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM;
+ spill_fill_config.customConfig = &custom_config;
+#endif
+ QnnContext_Config_t* spill_fill_config_pointer = max_spill_fill_size > 0 ? &spill_fill_config : nullptr;
+ LOGS(*logger_, VERBOSE) << "Max spill fill buffer size:" << max_spill_fill_size;
+
+ const QnnContext_Config_t* context_configs[] = {&qnn_context_config, spill_fill_config_pointer, nullptr};
+
+ ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary,
+ "Invalid function pointer for contextCreateFromBinary.");
Qnn_ContextHandle_t context = nullptr;
rt = qnn_interface_.contextCreateFromBinary(backend_handle_,
device_handle_,
@@ -673,7 +756,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
buffer_length,
&context,
profile_backend_handle_);
- ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary.");
+ ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary. Error code: ", rt);
contexts_.push_back(context);
if (1 == graph_count) {
// in case the EPContext node is generated from script
@@ -699,7 +782,11 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
return Status::OK();
}
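The group-registration flow above works as follows: the first context loaded starts a new spill-fill group (firstGroupHandle == 0x0), and every later context joins that group by passing the handle of the first context, with maxSpillFillBuffer set to the group-wide maximum; this is also why the main-context list is reordered so the largest context loads first. A minimal sequencing sketch for two cached binaries, assuming the accessors shown in this patch (buffers and node names are placeholders):

    // First binary: contexts_ is still empty, so firstGroupHandle stays 0x0 and
    // QNN creates a new spill-fill group sized by max_spill_fill_size.
    ORT_RETURN_IF_ERROR(LoadCachedQnnContextFromBuffer(buffer0, buffer0_length, "ctx_node_0",
                                                       qnn_models, max_spill_fill_size));

    // Second binary: GetQnnContextSize() is now 1, so firstGroupHandle becomes
    // GetQnnContext(0) and this context shares the group's spill-fill buffer.
    ORT_RETURN_IF_ERROR(LoadCachedQnnContextFromBuffer(buffer1, buffer1_length, "ctx_node_1",
                                                       qnn_models, max_spill_fill_size));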
-Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_from_cached_context) {
+// Need to load the QNN system lib if loading from a Qnn context binary,
+// or if Qnn context binary generation is enabled -- to get the max spill fill buffer size
+Status QnnBackendManager::SetupBackend(const logging::Logger& logger,
+ bool load_from_cached_context,
+ bool need_load_system_lib) {
std::lock_guard lock(logger_mutex_);
if (backend_setup_completed_) {
LOGS(logger, VERBOSE) << "Backend setup already!";
@@ -714,7 +801,7 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_
LOGS(logger, VERBOSE) << "LoadBackend succeed.";
- if (load_from_cached_context) {
+ if (load_from_cached_context || need_load_system_lib) {
ORT_RETURN_IF_ERROR(LoadQnnSystemLib());
}
@@ -933,20 +1020,6 @@ Status QnnBackendManager::SetRpcControlLatency(uint32_t htp_power_config_client_
return Status::OK();
}
-void QnnBackendManager::Split(std::vector<std::string>& split_string,
- const std::string& tokenized_string,
- const char separator) {
- split_string.clear();
- std::istringstream tokenized_string_stream(tokenized_string);
- while (!tokenized_string_stream.eof()) {
- std::string value;
- getline(tokenized_string_stream, value, separator);
- if (!value.empty()) {
- split_string.push_back(value);
- }
- }
-}
-
Status QnnBackendManager::DestroyHTPPowerConfigID(uint32_t htp_power_config_id) {
QnnDevice_Infrastructure_t qnn_device_infra = nullptr;
auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra);
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
index 43007d4a5c244..b145f2a2cd724 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
@@ -93,9 +93,10 @@ class QnnBackendManager {
Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
std::string node_name,
- std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models);
+ std::unordered_map<std::string, std::unique_ptr<qnn::QnnModel>>& qnn_models,
+ int64_t max_spill_fill_size);
- Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context);
+ Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context, bool need_load_system_lib);
Status CreateHtpPowerCfgId(uint32_t deviceId, uint32_t coreId, uint32_t& htp_power_config_id);
@@ -112,6 +113,10 @@ class QnnBackendManager {
return contexts_[index];
}
+ size_t GetQnnContextSize() {
+ return contexts_.size();
+ }
+
const Qnn_BackendHandle_t& GetQnnBackendHandle() { return backend_handle_; }
const Qnn_ProfileHandle_t& GetQnnProfileHandle() { return profile_backend_handle_; }
@@ -145,8 +150,6 @@ class QnnBackendManager {
void ReleaseResources();
- void Split(std::vector<std::string>& split_string, const std::string& tokenized_string, const char separator);
-
Status ExtractBackendProfilingInfo();
Status ExtractProfilingSubEvents(QnnProfile_EventId_t profile_event_id, std::ofstream& outfile,
bool backendSupportsExtendedEventData, bool tracelogging_provider_ep_enabled);
@@ -163,6 +166,10 @@ class QnnBackendManager {
Status DestroyHTPPowerConfigID(uint32_t htp_power_config_id);
+ Status GetMaxSpillFillBufferSize(unsigned char* buffer,
+ uint64_t buffer_length,
+ uint64_t& max_spill_fill_buffer_size);
+
private:
void* LoadLib(const char* file_name, int flags, std::string& error_msg);
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc
index 88fa6429fc01e..75973c7031d62 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc
@@ -104,7 +104,7 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer,
// valid throughout the lifetime of the ModelBuilder
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
- std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
+ std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger);
// This name must be same with the EPContext node name
const auto& graph_name = fused_node.Name();
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
index 6735528bebbf9..060bbd4f79bf2 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -363,20 +363,24 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_fp16_precision: " << enable_HTP_FP16_precision_;
}
+ bool enable_htp_weight_sharing = false;
static const std::string QNN_HTP_WEIGHT_SHARING_ENABLED = "enable_htp_weight_sharing";
auto htp_weight_sharing_enabled_pos = provider_options_map.find(QNN_HTP_WEIGHT_SHARING_ENABLED);
if (htp_weight_sharing_enabled_pos != provider_options_map.end()) {
if ("1" == htp_weight_sharing_enabled_pos->second) {
- enable_htp_weight_sharing_ = true;
+ enable_htp_weight_sharing = true;
} else if ("0" == htp_weight_sharing_enabled_pos->second) {
- enable_htp_weight_sharing_ = false;
+ enable_htp_weight_sharing = false;
} else {
- LOGS_DEFAULT(VERBOSE) << "Invalid enable_htp_weight_sharing: " << enable_htp_weight_sharing_
+ LOGS_DEFAULT(VERBOSE) << "Invalid enable_htp_weight_sharing: " << enable_htp_weight_sharing
<< " only 0 or 1 allowed. Set to 0.";
}
- LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_weight_sharing: " << enable_htp_weight_sharing_;
+ LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_weight_sharing: " << enable_htp_weight_sharing;
}
+ // Add this option because this feature requires the QnnSystem lib and it's not supported on the Windows x86_64 platform
+ enable_spill_fill_buffer_ = ParseBoolOption("enable_htp_spill_fill_buffer", false, provider_options_map);
+
model_settings_.offload_graph_io_quantization = ParseBoolOption("offload_graph_io_quantization", false,
provider_options_map);
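The new option is opt-in and only takes effect together with EPContext generation/loading. A minimal application-side sketch of enabling it through the public C++ API (paths and option values are illustrative):

    #include <onnxruntime_cxx_api.h>
    #include <string>
    #include <unordered_map>

    int main() {
      Ort::Env env;
      Ort::SessionOptions session_options;
      std::unordered_map<std::string, std::string> qnn_options{
          {"backend_path", "QnnHtp.dll"},          // HTP backend (Windows library name shown)
          {"enable_htp_spill_fill_buffer", "1"},   // option added in this patch
      };
      session_options.AppendExecutionProvider("QNN", qnn_options);
      Ort::Session session(env, ORT_TSTR("model_ctx.onnx"), session_options);
      return 0;
    }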
@@ -396,7 +400,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio
device_id_,
htp_arch,
soc_model,
- enable_htp_weight_sharing_);
+ enable_htp_weight_sharing);
#ifdef _WIN32
auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance();
@@ -686,7 +690,8 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
// It will load the QnnSystem lib if is_qnn_ctx_model=true, and
// delay the Qnn context creation to Compile() using the cached context binary
- auto rt = qnn_backend_manager_->SetupBackend(logger, is_qnn_ctx_model);
+ // or when context cache generation is enabled, which needs the QnnSystem lib to parse the binary and get the max spill fill buffer size
+ auto rt = qnn_backend_manager_->SetupBackend(logger, is_qnn_ctx_model, context_cache_enabled_ && enable_spill_fill_buffer_);
if (Status::OK() != rt) {
LOGS(logger, ERROR) << "QNN SetupBackend failed " << rt.ErrorMessage();
return result;
@@ -713,7 +718,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer
std::vector<std::unique_ptr<NodeUnit>> node_unit_holder;
std::unordered_map<const Node*, const NodeUnit*> node_unit_map;
- std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer);
+ std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer, logger);
// remove is_qnn_ctx_model related code
const auto supported_nodes = GetSupportedNodes(graph_viewer, node_unit_map,
@@ -934,6 +939,16 @@ Status QNNExecutionProvider::Compile(const std::vector& fused
std::vector main_context_pos_list;
ORT_RETURN_IF_ERROR(qnn::GetMainContextNode(fused_nodes_and_graphs, main_context_pos_list));
+ uint32_t total_context_size = SafeInt<uint32_t>(main_context_pos_list.size());
+
+ int64_t max_spill_fill_size = 0;
+
+ // Adjust main_context_pos_list so the entry with the max spill fill buffer size comes first
+ // The HTP spill fill buffer only works across multiple QNN contexts generated with QNN SDK 2.28 or later
+ if (total_context_size > 1) {
+ ORT_RETURN_IF_ERROR(qnn::TryGetMaxSpillFillSize(fused_nodes_and_graphs, total_context_size,
+ max_spill_fill_size, main_context_pos_list));
+ }
for (auto main_context_pos : main_context_pos_list) {
const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[main_context_pos].filtered_graph);
@@ -942,7 +957,8 @@ Status QNNExecutionProvider::Compile(const std::vector& fused
context_cache_path,
qnn_backend_manager_.get(),
qnn_models,
- logger));
+ logger,
+ max_spill_fill_size));
}
for (auto fused_node_and_graph : fused_nodes_and_graphs) {
@@ -984,6 +1000,13 @@ Status QNNExecutionProvider::Compile(const std::vector& fused
// All partitioned graph share single QNN context, included in the same context binary
uint64_t buffer_size(0);
auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size);
+ // Get max spill fill buffer size
+ uint64_t max_spill_fill_buffer_size = 0;
+ if (enable_spill_fill_buffer_) {
+ ORT_RETURN_IF_ERROR(qnn_backend_manager_->GetMaxSpillFillBufferSize(context_buffer.get(),
+ buffer_size,
+ max_spill_fill_buffer_size));
+ }
qnn_ep_context_model_ = std::make_unique<Model>("qnn_ep_context_model", false, logger);
ORT_RETURN_IF_ERROR(qnn::CreateEPContextNodes(qnn_ep_context_model_.get(),
context_buffer.get(),
@@ -993,6 +1016,7 @@ Status QNNExecutionProvider::Compile(const std::vector& fused
qnn_models_,
context_cache_path,
qnn_context_embed_mode_,
+ max_spill_fill_buffer_size,
logger));
}
return Status::OK();
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h
index 35c061de6132c..a0577e8fd87f2 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h
@@ -141,7 +141,6 @@ class QNNExecutionProvider : public IExecutionProvider {
std::string context_node_name_prefix_ = "";
bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session.
bool qnn_context_embed_mode_ = true;
- bool enable_htp_weight_sharing_ = false;
int32_t vtcm_size_in_mb_ = 0;
std::unique_ptr qnn_ep_context_model_;
ModelMetadefIdGenerator metadef_id_generator_;
@@ -150,6 +149,7 @@ class QNNExecutionProvider : public IExecutionProvider {
uint32_t default_rpc_control_latency_ = 0;
bool enable_HTP_FP16_precision_ = true;
bool share_ep_contexts_ = false;
+ bool enable_spill_fill_buffer_ = false;
#ifdef _WIN32
onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_ = nullptr;
#endif
diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
index 75b8ac7e134f3..0a427b146dcaa 100644
--- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
+++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
@@ -2493,7 +2493,7 @@ ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
// For ROCM EP, exclude the subgraph that is preferred to be placed in CPU
// These are usually shape related computation subgraphs
// Following logic can be extended for other EPs
- auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes);
+ auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tentative_nodes, logger);
std::vector<std::unique_ptr<ComputeCapability>> result;
for (auto& node_index : candidates) {
if (cpu_nodes.count(node_index) > 0)
diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h
index b84825236a453..45f81ed22b7f7 100644
--- a/onnxruntime/core/providers/shared_library/provider_api.h
+++ b/onnxruntime/core/providers/shared_library/provider_api.h
@@ -294,7 +294,8 @@ std::unique_ptr CreateGPUDataTransfer();
std::unordered_set