Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi head attention #22143

Closed
wants to merge 83 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
83 commits
Select commit Hold shift + click to select a range
4037bd4
[WIP] WebGPU EP initial commit
fs-eire Aug 28, 2024
9c36250
update C-API
fs-eire Aug 28, 2024
3a0756d
fix build break
fs-eire Aug 28, 2024
5199e98
add an empty symbols.txt file
fs-eire Aug 28, 2024
1c68dbd
fix an error in doc
fs-eire Aug 29, 2024
7db03de
remove string_join.h in favor of absl::StrJoin
fs-eire Aug 29, 2024
6a373c2
fix DLL copy
fs-eire Aug 29, 2024
ee42bba
update doc: require --skip_tests
fs-eire Aug 29, 2024
5fac202
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-ep
fs-eire Aug 29, 2024
3f46e5c
update dawn version
fs-eire Aug 29, 2024
9f61279
disable Tint tests
fs-eire Aug 29, 2024
6bb6335
fix one build break in Linux
fs-eire Aug 29, 2024
d839dbc
remove unused variables
fs-eire Aug 30, 2024
b70943d
make webgpu build on linux and known to most tools (#21937)
guschmue Aug 30, 2024
c33ac2e
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-ep
fs-eire Aug 30, 2024
8437267
revert type of ShaderVariable::rank_ to int
fs-eire Aug 30, 2024
3caf032
output Impl() for variables
fs-eire Aug 30, 2024
84494c4
code formatting
fs-eire Aug 30, 2024
aa70163
better format of Uniform
fs-eire Aug 30, 2024
d772db7
revise document
fs-eire Aug 30, 2024
6ef3dad
more build fix for linux
fs-eire Aug 31, 2024
a56f6c3
apply formatter
fs-eire Aug 31, 2024
12cd79d
simple test runner
fs-eire Aug 31, 2024
14c8966
Program macros update - allow extend
fs-eire Aug 31, 2024
4fff35f
fix BucketCacheManager
fs-eire Sep 1, 2024
4fd8ad1
add a method to get logger from ComputeContext
fs-eire Sep 1, 2024
3bd92ad
add verbose log for cache key
fs-eire Sep 1, 2024
6a1bbfe
revise suite test
fs-eire Sep 1, 2024
947aee1
device lost handler
fs-eire Sep 1, 2024
99b2578
add '-a' and '-t' to test runner
fs-eire Sep 1, 2024
aa7b3f5
atol/rtol 0.0001 -> 0.001
fs-eire Sep 1, 2024
e659acd
Fix uniform
fs-eire Sep 2, 2024
6ad89c5
add some unary ops
fs-eire Sep 2, 2024
8361fc3
various of fixes
fs-eire Sep 2, 2024
c89159d
fix workgroup_size, cache key stringnify and indices type
fs-eire Sep 3, 2024
5ea5936
shape_uniforms preparation
fs-eire Sep 3, 2024
7d83054
allow uniforms of input/output shape/stride being added automatically
fs-eire Sep 3, 2024
7a64cc7
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-ep
fs-eire Sep 3, 2024
1d53ac8
fix build (linux)
fs-eire Sep 3, 2024
4d52602
fix stride
fs-eire Sep 3, 2024
3761aad
fix "{res_name}_bi2o_{name}"
fs-eire Sep 3, 2024
351da84
Add Expand operator (#21933)
qjia7 Sep 3, 2024
0b7ce77
support onnxruntime_test_all
fs-eire Sep 3, 2024
33726b1
reflect change in WebGpuProviderFactoryCreator::Create signature (#21…
guschmue Sep 3, 2024
50ea9eb
compare the content of WEBGPU_BUFFER, not the address (#21967)
guschmue Sep 3, 2024
d6f6148
fix tanh
fs-eire Sep 3, 2024
626edaf
support size==0 for element wise operators
fs-eire Sep 4, 2024
8913da1
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-ep
fs-eire Sep 4, 2024
bacc54c
use shared ComputeBroadcastOutputShape()
fs-eire Sep 4, 2024
7ecc5bb
add workgroup_idx
fs-eire Sep 4, 2024
ae836b1
expose name for shader variable
fs-eire Sep 4, 2024
243078b
add uniform for 1D variable
fs-eire Sep 5, 2024
4d48d28
fix GetElementAt with uniform
fs-eire Sep 5, 2024
dbe673b
document update folder
fs-eire Sep 5, 2024
38f182e
fix adapter/device creating: add toggles
fs-eire Sep 5, 2024
eb80f7c
more strict shape&stride usage check
fs-eire Sep 6, 2024
39d5509
fix vector realloc
fs-eire Sep 6, 2024
cd961c3
simplify cache hint interface.
fs-eire Sep 6, 2024
ddc2fbb
revise expand
fs-eire Sep 6, 2024
e8be835
revise unary
fs-eire Sep 6, 2024
bd7d592
Elu/Relu/LeakyRelu/ThresholdedRelu/Gelu
fs-eire Sep 6, 2024
eecac18
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-ep
fs-eire Sep 6, 2024
601e50f
remove unused field in class Gelu
fs-eire Sep 6, 2024
8f36da2
remove out-of-dated comments
fs-eire Sep 6, 2024
72ebd85
Clip
fs-eire Sep 7, 2024
a3244ae
fix rank in shader helper
fs-eire Sep 7, 2024
5a2ae8c
fix shader variable
fs-eire Sep 9, 2024
aa54ff8
move components number from variable to program
fs-eire Sep 9, 2024
969384d
mark components in cache key
fs-eire Sep 9, 2024
6b82486
Add FastGelu op (#21991)
qjia7 Sep 10, 2024
2b3e7c2
use 'set/add' as prefix for some functions
fs-eire Sep 10, 2024
ef0d53b
remove unnecessary cache hint for FastGelu
fs-eire Sep 10, 2024
c4ca47f
revise unary - expose consts in header
fs-eire Sep 10, 2024
8806d57
use path for header file
fs-eire Sep 10, 2024
0568e2b
a few revises to the code (#22047)
fs-eire Sep 10, 2024
b7a9c0e
use OrtMutex
fs-eire Sep 11, 2024
f65ade9
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-ep
fs-eire Sep 11, 2024
d4a963d
[webgpu-native] Add transpose op (#21986)
axinging Sep 11, 2024
8b61532
PushErrorScope and PopErrorScope
fs-eire Sep 11, 2024
dce0f18
placeholder for setting proc table
fs-eire Sep 12, 2024
8978d89
Revert "placeholder for setting proc table"
fs-eire Sep 12, 2024
43ccaf4
allow setting "ValidationMode"
fs-eire Sep 12, 2024
409ac5c
webgpu: support MultiHeadAttention operator
xhcao Sep 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ option(onnxruntime_TVM_USE_LLVM "Build TVM with LLVM. Set customized path to llv
option(onnxruntime_TVM_USE_HASH "Build ipp-crypto library for support hash algorithm. It is defined for TVM only")
option(onnxruntime_USE_XNNPACK "Build with XNNPACK support. Provides an alternative math library on ARM, WebAssembly and x86." OFF)
option(onnxruntime_USE_WEBNN "Build with WebNN support. Enable hardware acceleration in web browsers." OFF)
option(onnxruntime_USE_WEBGPU "Build with WebGPU support. Enable WebGPU via C/C++ interface." OFF)

# Options related to reducing the binary size produced by the build
# XNNPACK EP requires the internal NHWC contrib ops to be available, so this option must be OFF when onnxruntime_USE_XNNPACK is ON
Expand Down Expand Up @@ -910,6 +911,11 @@ if (onnxruntime_USE_WEBNN)
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_WEBNN=1)
list(APPEND ONNXRUNTIME_PROVIDER_NAMES webnn)
endif()
if (onnxruntime_USE_WEBGPU)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_WEBGPU=1)
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_WEBGPU=1)
list(APPEND ONNXRUNTIME_PROVIDER_NAMES webgpu)
endif()
if (onnxruntime_USE_CANN)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_CANN=1)
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_CANN=1)
Expand Down
14 changes: 14 additions & 0 deletions cmake/external/onnxruntime_external_deps.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,20 @@ if (onnxruntime_USE_COREML)
FetchContent_Populate(coremltools)
endif()

if (onnxruntime_USE_WEBGPU)
FetchContent_Declare(
dawn
URL ${DEP_URL_dawn}
URL_HASH SHA1=${DEP_SHA1_dawn}
)
set(DAWN_FETCH_DEPENDENCIES ON)
set(DAWN_ENABLE_INSTALL ON)
set(TINT_BUILD_TESTS OFF)
set(DAWN_USE_BUILT_DXC ON)
set(DAWN_DXC_ENABLE_ASSERTS_IN_NDEBUG OFF)
onnxruntime_fetchcontent_makeavailable(dawn)
endif()

message(STATUS "Finished fetching external dependencies")

set(onnxruntime_LINK_DIRS )
Expand Down
3 changes: 2 additions & 1 deletion cmake/onnxruntime.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ function(get_c_cxx_api_headers HEADERS_VAR)

# need to add header files for enabled EPs
foreach(f ${ONNXRUNTIME_PROVIDER_NAMES})
# The header files in include/onnxruntime/core/providers/cuda directory cannot be flattened to the same directory
# The header files in include/onnxruntime/core/providers/cuda directory cannot be flattened to the same directory
# with onnxruntime_c_api.h . Most other EPs probably also do not work in this way.
if((NOT f STREQUAL cuda) AND (NOT f STREQUAL rocm))
file(GLOB _provider_headers CONFIGURE_DEPENDS
Expand Down Expand Up @@ -200,6 +200,7 @@ set(onnxruntime_INTERNAL_LIBRARIES
${PROVIDERS_RKNPU}
${PROVIDERS_VSINPU}
${PROVIDERS_XNNPACK}
${PROVIDERS_WEBGPU}
${PROVIDERS_WEBNN}
${PROVIDERS_AZURE}
${PROVIDERS_INTERNAL_TESTING}
Expand Down
7 changes: 7 additions & 0 deletions cmake/onnxruntime_providers.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ endif()
if(onnxruntime_USE_WEBNN)
set(PROVIDERS_WEBNN onnxruntime_providers_webnn)
endif()
if(onnxruntime_USE_WEBGPU)
set(PROVIDERS_WEBGPU onnxruntime_providers_webgpu)
endif()
if (onnxruntime_USE_CANN)
set(PROVIDERS_CANN onnxruntime_providers_cann)
endif()
Expand Down Expand Up @@ -151,6 +154,10 @@ if (onnxruntime_USE_WEBNN)
include(onnxruntime_providers_webnn.cmake)
endif()

if (onnxruntime_USE_WEBGPU)
include(onnxruntime_providers_webgpu.cmake)
endif()

if (onnxruntime_USE_NNAPI_BUILTIN)
include(onnxruntime_providers_nnapi.cmake)
endif()
Expand Down
7 changes: 6 additions & 1 deletion cmake/onnxruntime_providers_cpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ file(GLOB_RECURSE onnxruntime_js_contrib_ops_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/contrib_ops/js/*.cc"
)

file(GLOB_RECURSE onnxruntime_webgpu_contrib_ops_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/*.h"
"${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/*.cc"
)

file(GLOB onnxruntime_providers_common_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/*.cc"
Expand All @@ -60,7 +65,7 @@ if(NOT onnxruntime_DISABLE_CONTRIB_OPS)
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/aten_ops/aten_op_executor.cc"
)
endif()
set(onnxruntime_cpu_neural_speed_srcs
set(onnxruntime_cpu_neural_speed_srcs
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_wrapper.h"
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_defs.h"
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.cc"
Expand Down
37 changes: 37 additions & 0 deletions cmake/onnxruntime_providers_webgpu.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD)
message(FATAL_ERROR "WebGPU EP can not be used in a basic minimal build. Please build with '--minimal_build extended'")
endif()

# find_package(Dawn REQUIRED)

add_compile_definitions(USE_WEBGPU=1)
if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS)
add_definitions(-DENABLE_WEBASSEMBLY_THREADS=1)
endif()
file(GLOB_RECURSE onnxruntime_providers_webgpu_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.cc"
# "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h"
# "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc"
)
if(NOT onnxruntime_DISABLE_CONTRIB_OPS)
source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_webgpu_contrib_ops_cc_srcs})
list(APPEND onnxruntime_providers_webgpu_cc_srcs ${onnxruntime_webgpu_contrib_ops_cc_srcs})
endif()

source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_webgpu_cc_srcs})
onnxruntime_add_static_library(onnxruntime_providers_webgpu ${onnxruntime_providers_webgpu_cc_srcs})
onnxruntime_add_include_to_target(onnxruntime_providers_webgpu onnxruntime_common onnx onnx_proto flatbuffers::flatbuffers Boost::mp11 safeint_interface)
target_link_libraries(onnxruntime_providers_webgpu dawn::webgpu_dawn)

# Copy webgpu_dawn.dll to the output directory
add_custom_command(
TARGET onnxruntime_providers_webgpu
POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_if_different "$<TARGET_FILE:dawn::webgpu_dawn>" "$<TARGET_FILE_DIR:onnxruntime_providers_webgpu>"
VERBATIM )

set_target_properties(onnxruntime_providers_webgpu PROPERTIES FOLDER "ONNXRuntime")
2 changes: 1 addition & 1 deletion cmake/onnxruntime_providers_webnn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@

add_dependencies(onnxruntime_providers_webnn onnx ${onnxruntime_EXTERNAL_DEPENDENCIES})
set_target_properties(onnxruntime_providers_webnn PROPERTIES FOLDER "ONNXRuntime")
set_target_properties(onnxruntime_providers_webnn PROPERTIES LINKER_LANGUAGE CXX)
set_target_properties(onnxruntime_providers_webnn PROPERTIES LINKER_LANGUAGE CXX)
1 change: 1 addition & 0 deletions cmake/onnxruntime_python.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ target_link_libraries(onnxruntime_pybind11_state PRIVATE
${PROVIDERS_ACL}
${PROVIDERS_ARMNN}
${PROVIDERS_XNNPACK}
${PROVIDERS_WEBGPU}
${PROVIDERS_AZURE}
${PROVIDERS_QNN}
onnxruntime_optimizer
Expand Down
28 changes: 28 additions & 0 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,10 @@ if(onnxruntime_USE_JSEP)
list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_js)
endif()

if(onnxruntime_USE_WEBGPU)
list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_webgpu)
endif()

if(onnxruntime_USE_RKNPU)
list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_rknpu)
endif()
Expand Down Expand Up @@ -598,6 +602,7 @@ set(ONNXRUNTIME_TEST_LIBS
${PROVIDERS_NNAPI}
${PROVIDERS_VSINPU}
${PROVIDERS_JS}
${PROVIDERS_WEBGPU}
${PROVIDERS_QNN}
${PROVIDERS_SNPE}
${PROVIDERS_RKNPU}
Expand Down Expand Up @@ -658,6 +663,13 @@ if(onnxruntime_USE_JSEP)
list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_js)
endif()

if(onnxruntime_USE_WEBGPU)
list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/webgpu/*)
list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_webgpu)
list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_webgpu)
list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_webgpu)
endif()

# QNN EP tests require CPU EP op implementations for accuracy evaluation, so disable on minimal
# or reduced op builds.
if(onnxruntime_USE_QNN AND NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_REDUCED_OPS_BUILD)
Expand Down Expand Up @@ -1112,6 +1124,22 @@ if (NOT IOS)
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
BUNDLE DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})

## TODO: remove this when merging to main branch
#
# should support better test runner
#
if (onnxruntime_USE_WEBGPU)
add_custom_command(
TARGET onnx_test_runner
POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_if_different
"${ONNXRUNTIME_ROOT}/test/providers/webgpu/test_webgpu.js"
"${ONNXRUNTIME_ROOT}/test/providers/webgpu/test_webgpu.bat"
"$<TARGET_FILE_DIR:onnx_test_runner>"
VERBATIM )
endif()

endif()

if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
Expand Down
1 change: 1 addition & 0 deletions include/onnxruntime/core/graph/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ constexpr const char* kSnpeExecutionProvider = "SNPEExecutionProvider";
constexpr const char* kTvmExecutionProvider = "TvmExecutionProvider";
constexpr const char* kXnnpackExecutionProvider = "XnnpackExecutionProvider";
constexpr const char* kWebNNExecutionProvider = "WebNNExecutionProvider";
constexpr const char* kWebGpuExecutionProvider = "WebGpuExecutionProvider";
constexpr const char* kCannExecutionProvider = "CANNExecutionProvider";
constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider";
constexpr const char* kVSINPUExecutionProvider = "VSINPUExecutionProvider";
Expand Down
63 changes: 63 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,38 @@ typedef struct OrtMIGraphXProviderOptions {
bool migraphx_exhaustive_tune; // migraphx tuned compile Default = false
} OrtMIGraphXProviderOptions;

/** \brief WebGPU Execution Provider Options
*
* When a user wants to use WebGPU as the execution provider, there are 2 ways to specify the WebGPU device:
*
* 1. Use the default WebGPU device. The default WebGPU device is managed by WebGPU EP internally. The user doesn't
* need to provide any device information in this case. All the fields should be set to nullptr or 0.
*
* 2. Use a custom WebGPU device. The user should create their own handles of `WGPUInstance`, `WGPUAdapter`, and
* `WGPUDevice` and use arbitrary number in [1..65536) as the device id. The user should provide the handles
* and the device id in the options.
*
* When specifying an existing Device ID, the user should provide the handles of `WGPUInstance`, `WGPUAdapter`, and
* `WGPUDevice` in the options. The device id should be the same as the one used previously.
*
* It's user's responsibility to manage the lifecycle of the handles and ensure the handles are valid during the
* lifetime of the inference session.
*
* About DawnProcTable:
*
* When using an ONNX Runtime build that is not directly linked dawn during the build, a pointer to the runtime memory
* address of the DawnProcTable should be provided. Otherwise, keep it as nullptr.
*
* \see OrtApi::SessionOptionsAppendExecutionProvider_WGPU
*/
typedef struct OrtWGPUProviderOptions {
int device_id; // WebGPU device id.
void* instance_handle; // WebGPU instance handle.
void* adapter_handle; // WebGPU adapter handle.
void* device_handle; // WebGPU device handle.
void* dawn_proc_table; // DawnProcTable pointer.
} OrtWGPUProviderOptions;

/** \brief OpenVINO Provider Options
*
* \see OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO
Expand Down Expand Up @@ -4670,6 +4702,37 @@ struct OrtApi {
_In_reads_(num_external_initializer_files) char* const* external_initializer_file_buffer_array,
_In_reads_(num_external_initializer_files) const size_t* external_initializer_file_lengths,
size_t num_external_initializer_files);

/** \brief Append WebGPU execution provider to session options
*
* If WebGPU is not available, this function will return failure.
*
* \param[in] options
* \param[in] wgpu_options - specify the WebGPU provider options.
* \param[in] string_options_keys - keys to configure the string options
* \param[in] string_options_values - values to configure the string options
* \param[in] num_keys - number of keys passed in
*
* Supported keys are listed as below. All entries are optional.
*
* | Key | Possible Values | Default Value |
* | ------------------------------ | ---------------------------------------------- | -------------- |
* | "preferredLayout" | "NHWC" or "NCHW" | "NHWC" |
* | "enableGraphCapture" | "1" or "0" | "0" |
* | "storageBufferCacheMode" | "disabled", "lazyRelease", "simple", "bucket" | "bucket" |
* | "uniformBufferCacheMode" | "disabled", "lazyRelease", "simple", "bucket" | "lazyRelease" |
* | "queryResolveBufferCacheMode" | "disabled", "lazyRelease", "simple", "bucket" | "disabled" |
* | "defaultBufferCacheMode" | "disabled", "lazyRelease", "simple", "bucket" | "disabled" |
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.20.
*/
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_WGPU,
_In_ OrtSessionOptions* options, _In_ const OrtWGPUProviderOptions* wgpu_options,
_In_reads_(num_keys) const char* const* string_options_keys,
_In_reads_(num_keys) const char* const* string_options_values,
_In_ size_t num_keys);
};

/*
Expand Down
3 changes: 3 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,9 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl<T> {
SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options);
///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl
SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options);
///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_WGPU
SessionOptionsImpl& AppendExecutionProvider_WGPU(const OrtWGPUProviderOptions& wgpu_options,
const std::unordered_map<std::string, std::string>& string_options = {});
/// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports QNN, SNPE and XNNPACK.
SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name,
const std::unordered_map<std::string, std::string>& provider_options = {});
Expand Down
19 changes: 19 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,25 @@ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_MIG
return *this;
}

template <typename T>
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_WGPU(const OrtWGPUProviderOptions& wgpu_options,
const std::unordered_map<std::string, std::string>& string_options) {
auto num_entries = string_options.size();
std::vector<const char*> keys, values;
if (num_entries > 0) {
keys.reserve(num_entries);
values.reserve(num_entries);

for (const auto& entry : string_options) {
keys.push_back(entry.first.c_str());
values.push_back(entry.second.c_str());
}
}

ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_WGPU(this->p_, &wgpu_options, keys.data(), values.data(), num_entries));
return *this;
}

template <typename T>
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options) {
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options));
Expand Down
Loading