Merge remote-tracking branch 'origin/main' into Cjian/rn-0.69.1
# Conflicts:
#	js/react_native/android/gradle/wrapper/gradle-wrapper.properties
jchen351 committed Dec 10, 2024
2 parents e5eadf5 + 5f7b9d0 commit 192e2d5
Showing 137 changed files with 1,626 additions and 1,044 deletions.
15 changes: 12 additions & 3 deletions cmake/onnxruntime_mlas.cmake
@@ -41,6 +41,8 @@ onnxruntime_add_static_library(onnxruntime_mlas
${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h
${MLAS_SRC_DIR}/flashattn.cpp
${MLAS_SRC_DIR}/cast.cpp
+ ${MLAS_SRC_DIR}/rotary_embedding.h
+ ${MLAS_SRC_DIR}/rotary_embedding.cpp
)

target_sources(onnxruntime_mlas PRIVATE
@@ -88,8 +90,11 @@ function(setup_mlas_source_for_windows)
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
- ${MLAS_SRC_DIR}/fp16_neon_common.cpp
+ ${MLAS_SRC_DIR}/cast_kernel_neon.cpp
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
+ ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
+ ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
+ ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
)

set(mlas_platform_preprocess_srcs
@@ -367,6 +372,8 @@ else()
${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp
${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
+ ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h
+ ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp
PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod")
@@ -384,8 +391,9 @@ else()
${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp
${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp
${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp
- ${MLAS_SRC_DIR}/fp16_neon_common.cpp
+ ${MLAS_SRC_DIR}/cast_kernel_neon.cpp
${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp
+ ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
@@ -395,8 +403,9 @@ else()
set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
- set_source_files_properties(${MLAS_SRC_DIR}/fp16_neon_common.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
endif()

if(ONNXRUNTIME_MLAS_MULTI_ARCH)
2 changes: 2 additions & 0 deletions docs/ContribOperators.md
@@ -1596,6 +1596,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>(Optional) Hardware architecture.</dd>
<dt><tt>main_context</tt> : int</dt>
<dd>Usually each single EPContext associate with a graph partition.But for some case like QNN, it has single EPContext contains all partitions.In that case, the node with ep_cache_context should set main_context=1. Other nodes set main_context=0 and skip ep_cache_context.The path is relative to this Onnx file. Default is 1.</dd>
+ <dt><tt>max_size</tt> : int</dt>
+ <dd>max size in the context. Usage depend on the EP.</dd>
<dt><tt>notes</tt> : string</dt>
<dd>(Optional) Some notes for the model</dd>
<dt><tt>onnx_model_filename</tt> : string</dt>
12 changes: 10 additions & 2 deletions include/onnxruntime/core/framework/kernel_registry.h
@@ -8,6 +8,9 @@
#include "core/framework/op_kernel.h"

namespace onnxruntime {
+ namespace logging {
+ class Logger;
+ }

using KernelCreateMap = std::multimap<std::string, KernelCreateInfo>;
using KernelDefHashes = std::vector<std::pair<std::string, HashValue>>;
@@ -33,6 +36,7 @@ class KernelRegistry {
// Kernel matching uses the types from the node and the kernel_type_str_resolver.
Status TryFindKernel(const Node& node, ProviderType exec_provider,
const IKernelTypeStrResolver& kernel_type_str_resolver,
+ const logging::Logger& logger,
const KernelCreateInfo** out) const;

// map of type constraint name to required type
@@ -42,6 +46,7 @@ class KernelRegistry {
// Kernel matching uses the explicit type constraint name to required type map in type_constraints.
Status TryFindKernel(const Node& node, ProviderType exec_provider,
const TypeConstraintMap& type_constraints,
+ const logging::Logger& logger,
const KernelCreateInfo** out) const;

/**
@@ -61,13 +66,15 @@ class KernelRegistry {
std::string_view domain,
int version,
const KernelRegistry::TypeConstraintMap& type_constraints,
+ const logging::Logger& logger,
const KernelCreateInfo** out) const;

static bool HasImplementationOf(const KernelRegistry& r, const Node& node,
ProviderType exec_provider,
- const IKernelTypeStrResolver& kernel_type_str_resolver) {
+ const IKernelTypeStrResolver& kernel_type_str_resolver,
+ const logging::Logger& logger) {
const KernelCreateInfo* info;
- Status st = r.TryFindKernel(node, exec_provider, kernel_type_str_resolver, &info);
+ Status st = r.TryFindKernel(node, exec_provider, kernel_type_str_resolver, logger, &info);
return st.IsOK();
}

@@ -83,6 +90,7 @@ class KernelRegistry {
Status TryFindKernelImpl(const Node& node, ProviderType exec_provider,
const IKernelTypeStrResolver* kernel_type_str_resolver,
const TypeConstraintMap* type_constraints,
+ const logging::Logger& logger,
const KernelCreateInfo** out) const;

// Check whether the types of inputs/outputs of the given node match the extra
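The effect of this change on callers: kernel lookup now takes the session logger explicitly. A minimal sketch of an adapted call site, assuming the caller already owns a session-scoped logging::Logger; the helper name FindCpuKernel is illustrative, while kCpuExecutionProvider is the existing provider-name constant:

```cpp
#include "core/framework/kernel_registry.h"
#include "core/graph/constants.h"

namespace onnxruntime {

// Sketch: resolve a CPU kernel for `node`, forwarding the session logger so
// kernel-matching diagnostics are attributed to the session rather than the
// default logger.
Status FindCpuKernel(const KernelRegistry& registry, const Node& node,
                     const IKernelTypeStrResolver& resolver,
                     const logging::Logger& logger,
                     const KernelCreateInfo** out) {
  return registry.TryFindKernel(node, kCpuExecutionProvider, resolver, logger, out);
}

}  // namespace onnxruntime
```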
2 changes: 2 additions & 0 deletions include/onnxruntime/core/optimizer/graph_transformer_utils.h
@@ -53,6 +53,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
TransformerLevel level,
const SessionOptions& session_options,
const IExecutionProvider& execution_provider /*required by constant folding*/,
+ const logging::Logger& logger,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
concurrency::ThreadPool* intra_op_thread_pool = nullptr,
std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors = nullptr);
@@ -84,6 +85,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
const SessionOptions& session_options,
const SatApplyContextVariant& apply_context,
const IExecutionProvider& cpu_execution_provider,
+ const logging::Logger& logger,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {},
concurrency::ThreadPool* intra_op_thread_pool = nullptr,
std::unordered_map<std::string, std::unique_ptr<Tensor>>* p_buffered_tensors = nullptr);
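Call sites of GenerateTransformers pick up the same logger-threading pattern. A hedged sketch of the corresponding call-site update (the wrapper name is illustrative; the optimizer_utils namespace qualification follows the existing ONNX Runtime layout and is an assumption here):

```cpp
#include "core/optimizer/graph_transformer_utils.h"

namespace onnxruntime {

// Sketch: build the Level1 transformer list, forwarding the session logger.
// session_options, cpu_ep, and logger are assumed to be owned by the caller,
// as they are inside InferenceSession.
InlinedVector<std::unique_ptr<GraphTransformer>> BuildLevel1Transformers(
    const SessionOptions& session_options, const IExecutionProvider& cpu_ep,
    const logging::Logger& logger) {
  return optimizer_utils::GenerateTransformers(TransformerLevel::Level1,
                                               session_options, cpu_ep, logger);
}

}  // namespace onnxruntime
```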
include/onnxruntime/core/providers/coreml/coreml_provider_factory.h
@@ -47,8 +47,20 @@ enum COREMLFlags {
// and SessionOptionsAppendExecutionProvider (C API). For the old API, use COREMLFlags instead.
static const char* const kCoremlProviderOption_MLComputeUnits = "MLComputeUnits";
static const char* const kCoremlProviderOption_ModelFormat = "ModelFormat";
+ // same as COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES
static const char* const kCoremlProviderOption_RequireStaticInputShapes = "RequireStaticInputShapes";
static const char* const kCoremlProviderOption_EnableOnSubgraphs = "EnableOnSubgraphs";
+ // provided by https://developer.apple.com/documentation/coreml/mloptimizationhints-swift.struct/specializationstrategy-swift.property
+ // Core ML segments the model’s compute graph and specializes each segment for the target compute device.
+ // This process can affect the model loading time and the prediction latency.
+ // Use this option to tailor the specialization strategy for your model.
+ static const char* const kCoremlProviderOption_SpecializationStrategy = "SpecializationStrategy";
+ // Profile the Core ML MLComputePlan.
+ // This logs the hardware each operator is dispatched to and the estimated execution time.
+ // Intended for developer usage but provide useful diagnostic information if performance is not as expected.
+ static const char* const kCoremlProviderOption_ProfileComputePlan = "ProfileComputePlan";
+ // please refer to https://developer.apple.com/documentation/coreml/mlmodelconfiguration/allowlowprecisionaccumulationongpu
+ static const char* const kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU = "AllowLowPrecisionAccumulationOnGPU";

#ifdef __cplusplus
extern "C" {
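These keys are consumed through SessionOptionsAppendExecutionProvider, as the comment at the top of the block notes. A minimal C++ sketch, assuming a build with the Core ML EP enabled; the model path and the option values shown ("MLProgram", "ALL", "FastPrediction") are illustrative, not a recommendation:

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;
  Ort::SessionOptions session_options;

  // Append the Core ML EP by name with the new string-keyed options.
  session_options.AppendExecutionProvider("CoreML", {
      {"ModelFormat", "MLProgram"},
      {"MLComputeUnits", "ALL"},
      {"SpecializationStrategy", "FastPrediction"},
      {"ProfileComputePlan", "1"},  // log per-op dispatch and time estimates
      {"AllowLowPrecisionAccumulationOnGPU", "1"},
  });

  Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);
  return 0;
}
```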
22 changes: 22 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -3667,6 +3667,9 @@ struct OrtApi {
* execution provider (typically CPU EP).
* - "0": Default. Disabled. QNN EP will handle quantization and dequantization of graph I/O.
* - "1": Enabled.
+ * "enable_htp_spill_fill_buffer": Enable HTP spill fill buffer setting. The flag is used while generating context binary.
+ * - "0": Default. Disabled.
+ * - "1": Enabled.
*
* SNPE supported keys:
* "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16",
@@ -4612,6 +4615,8 @@ struct OrtApi {
* \param[in] num_keys
*
* \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.17.
*/
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_OpenVINO_V2,
_In_ OrtSessionOptions* options,
@@ -4629,6 +4634,8 @@ struct OrtApi {
* \param[in] num_keys
*
* \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.18.
*/
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_VitisAI,
_In_ OrtSessionOptions* options,
@@ -4642,7 +4649,10 @@ struct OrtApi {
* \param[in] mem_info OrtMemoryInfo instance
* \param[in] count_or_bytes How many bytes is this scratch buffer
* \param[out] out A pointer to the scrach buffer
+ *
* \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.18.
*/
ORT_API2_STATUS(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out);

@@ -4653,6 +4663,8 @@ struct OrtApi {
* \param[out] out A pointer to OrtAllocator
*
* \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.18.
*/
ORT_API2_STATUS(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out);

@@ -4674,6 +4686,8 @@ struct OrtApi {
* \param[in] num_external_initializer_files Number of external files
*
* \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.18.
*/
ORT_API2_STATUS(AddExternalInitializersFromFilesInMemory, _In_ OrtSessionOptions* options,
_In_reads_(num_external_initializer_files) const ORTCHAR_T* const* external_initializer_file_names,
@@ -4696,6 +4710,8 @@ struct OrtApi {
* OrtApi::ReleaseLoraAdapter.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.20.
*/
ORT_API2_STATUS(CreateLoraAdapter, const ORTCHAR_T* adapter_file_path, _In_ OrtAllocator* allocator,
_Outptr_ OrtLoraAdapter** out);
@@ -4714,6 +4730,8 @@ struct OrtApi {
* OrtApi::ReleaseLoraAdapter.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.20.
*/
ORT_API2_STATUS(CreateLoraAdapterFromArray, _In_ const void* bytes, size_t num_bytes, _In_ OrtAllocator* allocator,
_Outptr_ OrtLoraAdapter** out);
@@ -4735,6 +4753,8 @@ struct OrtApi {
* \param[in] adapter OrtLoraAdapter instance
*
* \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.20.
*/
ORT_API2_STATUS(RunOptionsAddActiveLoraAdapter, _Inout_ OrtRunOptions* options, _In_ const OrtLoraAdapter* adapter);

@@ -4753,6 +4773,8 @@ struct OrtApi {
* \param[in] kv_len Number of elements in the keys and values arrays
*
* \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.20.
*/
ORT_API2_STATUS(SetEpDynamicOptions, _Inout_ OrtSession* sess, _In_reads_(kv_len) const char* const* keys,
_In_reads_(kv_len) const char* const* values, _In_ size_t kv_len);
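Taken together, the LoRA entries above form a load/activate/release sequence. A sketch against the C API declarations shown here, with error handling elided and the adapter file name purely illustrative:

```cpp
#include <onnxruntime_c_api.h>

// Sketch: load a LoRA adapter once, then activate it for a particular Run.
void RunWithLoraAdapter(const OrtApi* api, OrtSession* session) {
  OrtLoraAdapter* adapter = NULL;
  // A NULL allocator keeps the adapter parameters on CPU; per the docs above,
  // a device allocator may be supplied to copy them to device memory.
  api->CreateLoraAdapter(ORT_TSTR("model.onnx_adapter"), NULL, &adapter);

  OrtRunOptions* run_options = NULL;
  api->CreateRunOptions(&run_options);
  api->RunOptionsAddActiveLoraAdapter(run_options, adapter);

  // ... api->Run(session, run_options, /* inputs/outputs */ ...) ...
  (void)session;

  api->ReleaseRunOptions(run_options);
  api->ReleaseLoraAdapter(adapter);
}
```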
13 changes: 0 additions & 13 deletions js/.eslintrc.js
@@ -198,19 +198,6 @@ module.exports = {
'_OrtReleaseTensor',
'_OrtRun',
'_OrtRunWithBinding',
- '_OrtTrainingCopyParametersFromBuffer',
- '_OrtTrainingCopyParametersToBuffer',
- '_OrtTrainingCreateSession',
- '_OrtTrainingEvalStep',
- '_OrtTrainingGetModelInputOutputCount',
- '_OrtTrainingGetModelInputOutputName',
- '_OrtTrainingGetParametersSize',
- '_OrtTrainingLazyResetGrad',
- '_OrtTrainingLoadCheckpoint',
- '_OrtTrainingOptimizerStep',
- '_OrtTrainingReleaseCheckpoint',
- '_OrtTrainingReleaseSession',
- '_OrtTrainingRunTrainStep',
],
},
],
36 changes: 0 additions & 36 deletions js/common/lib/backend.ts
@@ -3,7 +3,6 @@

import { InferenceSession } from './inference-session.js';
import { OnnxValue } from './onnx-value.js';
- import { TrainingSession } from './training-session.js';

/**
* @ignore
@@ -42,33 +41,6 @@ export interface InferenceSessionHandler extends SessionHandler {
): Promise<SessionHandler.ReturnType>;
}

- /**
- * Represent a handler instance of a training inference session.
- *
- * @ignore
- */
- export interface TrainingSessionHandler extends SessionHandler {
- readonly evalInputNames: readonly string[];
- readonly evalOutputNames: readonly string[];
-
- lazyResetGrad(): Promise<void>;
- runTrainStep(
- feeds: SessionHandler.FeedsType,
- fetches: SessionHandler.FetchesType,
- options: InferenceSession.RunOptions,
- ): Promise<SessionHandler.ReturnType>;
- runOptimizerStep(options: InferenceSession.RunOptions): Promise<void>;
- runEvalStep(
- feeds: SessionHandler.FeedsType,
- fetches: SessionHandler.FetchesType,
- options: InferenceSession.RunOptions,
- ): Promise<SessionHandler.ReturnType>;
-
- getParametersSize(trainableOnly: boolean): Promise<number>;
- loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise<void>;
- getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue>;
- }

/**
* Represent a backend that provides implementation of model inferencing.
*
@@ -84,14 +56,6 @@ export interface Backend {
uriOrBuffer: string | Uint8Array,
options?: InferenceSession.SessionOptions,
): Promise<InferenceSessionHandler>;

- createTrainingSessionHandler?(
- checkpointStateUriOrBuffer: TrainingSession.UriOrBuffer,
- trainModelUriOrBuffer: TrainingSession.UriOrBuffer,
- evalModelUriOrBuffer: TrainingSession.UriOrBuffer,
- optimizerModelUriOrBuffer: TrainingSession.UriOrBuffer,
- options: InferenceSession.SessionOptions,
- ): Promise<TrainingSessionHandler>;
}

export { registerBackend } from './backend-impl.js';
13 changes: 3 additions & 10 deletions js/common/lib/env.ts
@@ -2,6 +2,7 @@
// Licensed under the MIT License.

import { env as envImpl } from './env-impl.js';
+ import { TryGetGlobalType } from './type-helper.js';

export declare namespace Env {
export type WasmPathPrefix = string;
@@ -14,7 +15,6 @@ export declare namespace Env {
* If not modified, the filename of the .wasm file is:
* - `ort-wasm-simd-threaded.wasm` for default build
* - `ort-wasm-simd-threaded.jsep.wasm` for JSEP build (with WebGPU and WebNN)
- * - `ort-training-wasm-simd-threaded.wasm` for training build
*/
wasm?: URL | string;
/**
@@ -25,7 +25,6 @@ export declare namespace Env {
* If not modified, the filename of the .mjs file is:
* - `ort-wasm-simd-threaded.mjs` for default build
* - `ort-wasm-simd-threaded.jsep.mjs` for JSEP build (with WebGPU and WebNN)
- * - `ort-training-wasm-simd-threaded.mjs` for training build
*/
mjs?: URL | string;
}
@@ -200,22 +199,16 @@ export declare namespace Env {
* value will be the GPU adapter that created by the underlying WebGPU backend.
*
* When use with TypeScript, the type of this property is `GPUAdapter` defined in "@webgpu/types".
- * Use `const adapter = env.webgpu.adapter as GPUAdapter;` in TypeScript to access this property with correct type.
- *
- * see comments on {@link Tensor.GpuBufferType}
*/
- adapter: unknown;
+ adapter: TryGetGlobalType<'GPUAdapter'>;
/**
* Get the device for WebGPU.
*
* This property is only available after the first WebGPU inference session is created.
*
* When use with TypeScript, the type of this property is `GPUDevice` defined in "@webgpu/types".
- * Use `const device = env.webgpu.device as GPUDevice;` in TypeScript to access this property with correct type.
- *
- * see comments on {@link Tensor.GpuBufferType} for more details about why not use types defined in "@webgpu/types".
*/
- readonly device: unknown;
+ readonly device: TryGetGlobalType<'GPUDevice'>;
/**
* Set or get whether validate input content.
*
1 change: 0 additions & 1 deletion js/common/lib/index.ts
@@ -26,4 +26,3 @@ export * from './tensor-factory.js';
export * from './trace.js';
export * from './onnx-model.js';
export * from './onnx-value.js';
- export * from './training-session.js';
