From 4037bd4c8ee908bcb5f896744963869ef1b6b2ad Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 28 Aug 2024 14:53:36 -0700 Subject: [PATCH 01/77] [WIP] WebGPU EP initial commit --- cmake/CMakeLists.txt | 6 + cmake/deps.txt | 1 + .../external/onnxruntime_external_deps.cmake | 11 + cmake/onnxruntime_providers.cmake | 7 + cmake/onnxruntime_providers_cpu.cmake | 7 +- cmake/onnxruntime_providers_webgpu.cmake | 30 + cmake/onnxruntime_providers_webnn.cmake | 2 +- cmake/onnxruntime_unittests.cmake | 22 + include/onnxruntime/core/common/string_join.h | 61 ++ include/onnxruntime/core/graph/constants.h | 1 + .../core/session/onnxruntime_c_api.h | 57 ++ .../core/session/onnxruntime_cxx_api.h | 3 + .../core/session/onnxruntime_cxx_inline.h | 19 + .../webgpu/webgpu_contrib_kernels.cc | 70 ++ .../webgpu/webgpu_contrib_kernels.h | 17 + .../core/providers/get_execution_providers.cc | 8 + .../providers/provider_factory_creators.h | 4 + .../webgpu/How_to_Write_WebGPU_EP_Kernel.md | 156 ++++ onnxruntime/core/providers/webgpu/README.md | 104 +++ .../core/providers/webgpu/allocator.cc | 38 + onnxruntime/core/providers/webgpu/allocator.h | 34 + .../core/providers/webgpu/buffer_manager.cc | 362 ++++++++ .../core/providers/webgpu/buffer_manager.h | 96 ++ .../core/providers/webgpu/compute_context.cc | 37 + .../core/providers/webgpu/compute_context.h | 97 ++ .../core/providers/webgpu/data_transfer.cc | 48 + .../core/providers/webgpu/data_transfer.h | 28 + .../webgpu/math/unary_elementwise_ops.cc | 68 ++ .../webgpu/math/unary_elementwise_ops.h | 28 + onnxruntime/core/providers/webgpu/program.cc | 196 ++++ onnxruntime/core/providers/webgpu/program.h | 491 ++++++++++ .../providers/webgpu/program_cache_key.cc | 90 ++ .../core/providers/webgpu/program_cache_key.h | 16 + .../core/providers/webgpu/program_manager.cc | 188 ++++ .../core/providers/webgpu/program_manager.h | 71 ++ .../core/providers/webgpu/shader_helper.cc | 204 +++++ .../core/providers/webgpu/shader_helper.h | 161 ++++ .../core/providers/webgpu/shader_macros.h | 66 ++ .../core/providers/webgpu/shader_variable.cc | 277 ++++++ .../core/providers/webgpu/shader_variable.h | 263 ++++++ .../core/providers/webgpu/webgpu_context.cc | 349 ++++++++ .../core/providers/webgpu/webgpu_context.h | 124 +++ .../webgpu/webgpu_execution_provider.cc | 837 ++++++++++++++++++ .../webgpu/webgpu_execution_provider.h | 77 ++ .../core/providers/webgpu/webgpu_kernel.h | 42 + .../webgpu/webgpu_provider_factory.cc | 144 +++ .../webgpu/webgpu_provider_factory_creator.h | 18 + .../webgpu/webgpu_provider_options.h | 40 + .../providers/webgpu/webgpu_supported_types.h | 34 + onnxruntime/core/session/inference_session.cc | 8 +- onnxruntime/core/session/onnxruntime_c_api.cc | 2 + onnxruntime/core/session/ort_apis.h | 7 + .../core/session/provider_registration.cc | 60 ++ onnxruntime/test/onnx/main.cc | 62 +- tools/ci_build/build.py | 5 + 55 files changed, 5246 insertions(+), 8 deletions(-) create mode 100644 cmake/onnxruntime_providers_webgpu.cmake create mode 100644 include/onnxruntime/core/common/string_join.h create mode 100644 onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc create mode 100644 onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h create mode 100644 onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md create mode 100644 onnxruntime/core/providers/webgpu/README.md create mode 100644 onnxruntime/core/providers/webgpu/allocator.cc create mode 100644 onnxruntime/core/providers/webgpu/allocator.h create mode 100644 onnxruntime/core/providers/webgpu/buffer_manager.cc create mode 100644 onnxruntime/core/providers/webgpu/buffer_manager.h create mode 100644 onnxruntime/core/providers/webgpu/compute_context.cc create mode 100644 onnxruntime/core/providers/webgpu/compute_context.h create mode 100644 onnxruntime/core/providers/webgpu/data_transfer.cc create mode 100644 onnxruntime/core/providers/webgpu/data_transfer.h create mode 100644 onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc create mode 100644 onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h create mode 100644 onnxruntime/core/providers/webgpu/program.cc create mode 100644 onnxruntime/core/providers/webgpu/program.h create mode 100644 onnxruntime/core/providers/webgpu/program_cache_key.cc create mode 100644 onnxruntime/core/providers/webgpu/program_cache_key.h create mode 100644 onnxruntime/core/providers/webgpu/program_manager.cc create mode 100644 onnxruntime/core/providers/webgpu/program_manager.h create mode 100644 onnxruntime/core/providers/webgpu/shader_helper.cc create mode 100644 onnxruntime/core/providers/webgpu/shader_helper.h create mode 100644 onnxruntime/core/providers/webgpu/shader_macros.h create mode 100644 onnxruntime/core/providers/webgpu/shader_variable.cc create mode 100644 onnxruntime/core/providers/webgpu/shader_variable.h create mode 100644 onnxruntime/core/providers/webgpu/webgpu_context.cc create mode 100644 onnxruntime/core/providers/webgpu/webgpu_context.h create mode 100644 onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc create mode 100644 onnxruntime/core/providers/webgpu/webgpu_execution_provider.h create mode 100644 onnxruntime/core/providers/webgpu/webgpu_kernel.h create mode 100644 onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc create mode 100644 onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h create mode 100644 onnxruntime/core/providers/webgpu/webgpu_provider_options.h create mode 100644 onnxruntime/core/providers/webgpu/webgpu_supported_types.h diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 2e9a50e522171..db0f1ac6ba080 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -149,6 +149,7 @@ option(onnxruntime_TVM_USE_LLVM "Build TVM with LLVM. Set customized path to llv option(onnxruntime_TVM_USE_HASH "Build ipp-crypto library for support hash algorithm. It is defined for TVM only") option(onnxruntime_USE_XNNPACK "Build with XNNPACK support. Provides an alternative math library on ARM, WebAssembly and x86." OFF) option(onnxruntime_USE_WEBNN "Build with WebNN support. Enable hardware acceleration in web browsers." OFF) +option(onnxruntime_USE_WEBGPU "Build with WebGPU support. Enable WebGPU via C/C++ interface." OFF) # Options related to reducing the binary size produced by the build # XNNPACK EP requires the internal NHWC contrib ops to be available, so this option must be OFF when onnxruntime_USE_XNNPACK is ON @@ -907,6 +908,11 @@ if (onnxruntime_USE_WEBNN) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_WEBNN=1) list(APPEND ONNXRUNTIME_PROVIDER_NAMES webnn) endif() +if (onnxruntime_USE_WEBGPU) + list(APPEND ORT_PROVIDER_FLAGS -DUSE_WEBGPU=1) + list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_WEBGPU=1) + list(APPEND ONNXRUNTIME_PROVIDER_NAMES webgpu) +endif() if (onnxruntime_USE_CANN) list(APPEND ORT_PROVIDER_FLAGS -DUSE_CANN=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_CANN=1) diff --git a/cmake/deps.txt b/cmake/deps.txt index 2487ea144227d..2ab00cdbeb30c 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -59,3 +59,4 @@ extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d839 composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/204da9c522cebec5220bba52cd3542ebcaf99e7a.zip;1827348efd47831c13074245274d41b7cae8a557 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.5.2.zip;11071a47594b20f00af09aad83e0d5203ccf6029 +dawn;https://github.com/google/dawn/archive/9a912d8162d5a837950de14f8849230212e3f51c.zip;7f2cad3db905e2d846d8f2422623850a4463915f diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 4e52707474052..2dad3479c3c0f 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -585,6 +585,17 @@ if (onnxruntime_USE_COREML) FetchContent_Populate(coremltools) endif() +if (onnxruntime_USE_WEBGPU) + FetchContent_Declare( + dawn + URL ${DEP_URL_dawn} + URL_HASH SHA1=${DEP_SHA1_dawn} + ) + set(DAWN_FETCH_DEPENDENCIES ON) + set(DAWN_ENABLE_INSTALL ON) + onnxruntime_fetchcontent_makeavailable(dawn) +endif() + message("Finished fetching external dependencies") set(onnxruntime_LINK_DIRS ) diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 05a50a55db409..9666877cdc206 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -110,6 +110,9 @@ endif() if(onnxruntime_USE_WEBNN) set(PROVIDERS_WEBNN onnxruntime_providers_webnn) endif() +if(onnxruntime_USE_WEBGPU) + set(PROVIDERS_WEBGPU onnxruntime_providers_webgpu) +endif() if (onnxruntime_USE_CANN) set(PROVIDERS_CANN onnxruntime_providers_cann) endif() @@ -151,6 +154,10 @@ if (onnxruntime_USE_WEBNN) include(onnxruntime_providers_webnn.cmake) endif() +if (onnxruntime_USE_WEBGPU) + include(onnxruntime_providers_webgpu.cmake) +endif() + if (onnxruntime_USE_NNAPI_BUILTIN) include(onnxruntime_providers_nnapi.cmake) endif() diff --git a/cmake/onnxruntime_providers_cpu.cmake b/cmake/onnxruntime_providers_cpu.cmake index bbcc709b144a0..219fb97536351 100644 --- a/cmake/onnxruntime_providers_cpu.cmake +++ b/cmake/onnxruntime_providers_cpu.cmake @@ -40,6 +40,11 @@ file(GLOB_RECURSE onnxruntime_js_contrib_ops_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/contrib_ops/js/*.cc" ) +file(GLOB_RECURSE onnxruntime_webgpu_contrib_ops_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/*.h" + "${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/*.cc" +) + file(GLOB onnxruntime_providers_common_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/*.h" "${ONNXRUNTIME_ROOT}/core/providers/*.cc" @@ -60,7 +65,7 @@ if(NOT onnxruntime_DISABLE_CONTRIB_OPS) "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/aten_ops/aten_op_executor.cc" ) endif() - set(onnxruntime_cpu_neural_speed_srcs + set(onnxruntime_cpu_neural_speed_srcs "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_wrapper.h" "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_defs.h" "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.cc" diff --git a/cmake/onnxruntime_providers_webgpu.cmake b/cmake/onnxruntime_providers_webgpu.cmake new file mode 100644 index 0000000000000..303ab9483c38a --- /dev/null +++ b/cmake/onnxruntime_providers_webgpu.cmake @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + + if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD) + message(FATAL_ERROR "WebGPU EP can not be used in a basic minimal build. Please build with '--minimal_build extended'") + endif() + + # find_package(Dawn REQUIRED) + + add_compile_definitions(USE_WEBGPU=1) + if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) + add_definitions(-DENABLE_WEBASSEMBLY_THREADS=1) + endif() + file(GLOB_RECURSE onnxruntime_providers_webgpu_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.cc" + # "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" + # "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" + ) + if(NOT onnxruntime_DISABLE_CONTRIB_OPS) + source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_webgpu_contrib_ops_cc_srcs}) + list(APPEND onnxruntime_providers_webgpu_cc_srcs ${onnxruntime_webgpu_contrib_ops_cc_srcs}) + endif() + + source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_webgpu_cc_srcs}) + onnxruntime_add_static_library(onnxruntime_providers_webgpu ${onnxruntime_providers_webgpu_cc_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_webgpu onnxruntime_common onnx onnx_proto flatbuffers::flatbuffers Boost::mp11 safeint_interface) + target_link_libraries(onnxruntime_providers_webgpu dawn::webgpu_dawn) + + set_target_properties(onnxruntime_providers_webgpu PROPERTIES FOLDER "ONNXRuntime") diff --git a/cmake/onnxruntime_providers_webnn.cmake b/cmake/onnxruntime_providers_webnn.cmake index 05c63c22244db..39ca476810f41 100644 --- a/cmake/onnxruntime_providers_webnn.cmake +++ b/cmake/onnxruntime_providers_webnn.cmake @@ -22,4 +22,4 @@ add_dependencies(onnxruntime_providers_webnn onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) set_target_properties(onnxruntime_providers_webnn PROPERTIES FOLDER "ONNXRuntime") - set_target_properties(onnxruntime_providers_webnn PROPERTIES LINKER_LANGUAGE CXX) \ No newline at end of file + set_target_properties(onnxruntime_providers_webnn PROPERTIES LINKER_LANGUAGE CXX) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index d7f4a0675e118..5434ead12f65d 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -557,6 +557,10 @@ if(onnxruntime_USE_JSEP) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_js) endif() +if(onnxruntime_USE_WEBGPU) + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_webgpu) +endif() + if(onnxruntime_USE_RKNPU) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_rknpu) endif() @@ -598,6 +602,7 @@ set(ONNXRUNTIME_TEST_LIBS ${PROVIDERS_NNAPI} ${PROVIDERS_VSINPU} ${PROVIDERS_JS} + ${PROVIDERS_WEBGPU} ${PROVIDERS_QNN} ${PROVIDERS_SNPE} ${PROVIDERS_RKNPU} @@ -658,6 +663,13 @@ if(onnxruntime_USE_JSEP) list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_js) endif() +if(onnxruntime_USE_WEBGPU) + list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/webgpu/*) + list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_webgpu) + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_webgpu) + list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_webgpu) +endif() + # QNN EP tests require CPU EP op implementations for accuracy evaluation, so disable on minimal # or reduced op builds. if(onnxruntime_USE_QNN AND NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_REDUCED_OPS_BUILD) @@ -1088,6 +1100,11 @@ if (NOT IOS) endif() set_target_properties(onnx_test_runner PROPERTIES FOLDER "ONNXRuntimeTest") + add_custom_command(TARGET onnx_test_runner POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy $ $ + COMMAND_EXPAND_LISTS + ) + if (onnxruntime_USE_TVM) if (WIN32) target_link_options(onnx_test_runner PRIVATE "/STACK:4000000") @@ -1218,6 +1235,11 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() endif() + add_custom_command(TARGET onnxruntime_perf_test POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy $ $ + COMMAND_EXPAND_LISTS + ) + if (onnxruntime_BUILD_SHARED_LIB) #It will dynamically link to onnxruntime. So please don't add onxruntime_graph/onxruntime_framework/... here. #onnxruntime_common is kind of ok because it is thin, tiny and totally stateless. diff --git a/include/onnxruntime/core/common/string_join.h b/include/onnxruntime/core/common/string_join.h new file mode 100644 index 0000000000000..2c2181d4ad048 --- /dev/null +++ b/include/onnxruntime/core/common/string_join.h @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/common/make_string.h" + +namespace onnxruntime { + +namespace detail { + +template +inline void StringJoinImpl(const Separator& separator, std::ostringstream& ss) noexcept { +} + +template +inline void StringJoinImpl(const Separator& separator, std::ostringstream& ss, const T& t) noexcept { + ss << separator << t; +} + +template +inline void StringJoinImpl(const Separator& separator, std::ostringstream& ss, const T& t, const Args&... args) noexcept { + StringJoinImpl(separator, ss, t); + StringJoinImpl(separator, ss, args...); +} + +template +inline std::string StringJoinImpl(const Separator& separator, const Args&... args) noexcept { + std::ostringstream ss; + ss.imbue(std::locale::classic()); + StringJoinImpl(separator, ss, args...); + return ss.str(); +} +} // namespace detail + +/** + * Makes a string by concatenating string representations of the arguments using the specified separator. + * Uses std::locale::classic() + */ +template +std::string StringJoin(const Separator& separator, const Args&... args) { + return detail::StringJoinImpl(separator, detail::if_char_array_make_ptr_t(args)...); +} + +// StringJoin versions for already-a-string types. + +template +inline std::string StringJoin(const Separator& /* separator */, const std::string& str) { + return str; +} + +template +inline std::string StringJoin(const Separator& /* separator */, const char* cstr) { + return cstr; +} + +} // namespace onnxruntime diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index 39acb6b4f2aa4..f072badd199ba 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -50,6 +50,7 @@ constexpr const char* kSnpeExecutionProvider = "SNPEExecutionProvider"; constexpr const char* kTvmExecutionProvider = "TvmExecutionProvider"; constexpr const char* kXnnpackExecutionProvider = "XnnpackExecutionProvider"; constexpr const char* kWebNNExecutionProvider = "WebNNExecutionProvider"; +constexpr const char* kWebGpuExecutionProvider = "WebGpuExecutionProvider"; constexpr const char* kCannExecutionProvider = "CANNExecutionProvider"; constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider"; constexpr const char* kVSINPUExecutionProvider = "VSINPUExecutionProvider"; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 4674db42fb1c9..9e5d9339bffe9 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -624,6 +624,32 @@ typedef struct OrtMIGraphXProviderOptions { bool migraphx_exhaustive_tune; // migraphx tuned compile Default = false } OrtMIGraphXProviderOptions; +/** \brief WebGPU Execution Provider Options + * + * When a user wants to use WebGPU as the execution provider, there are 2 ways to specify the WebGPU device: + * + * 1. Use the default WebGPU device. The default WebGPU device is managed by WebGPU EP internally. The user doesn't + * need to provide any device information in this case. All the fields should be set to nullptr or 0. + * + * 2. Use a custom WebGPU device. The user should create their own handles of `WGPUInstance`, `WGPUAdapter`, and + * `WGPUDevice` and use arbitrary number in [1..65536) as the device id. The user should provide the handles + * and the device id in the options. + * + * When specifying an existing Device ID, the user should provide the handles of `WGPUInstance`, `WGPUAdapter`, and + * `WGPUDevice` in the options. The device id should be the same as the one used previously. + * + * It's user's responsibility to manage the lifecycle of the handles and ensure the handles are valid during the + * lifetime of the inference session. + * + * \see OrtApi::SessionOptionsAppendExecutionProvider_WebGPU + */ +typedef struct OrtWebGPUProviderOptions { + int device_id; // WebGPU device id. + void* instance_handle; // WebGPU instance handle. + void* adapter_handle; // WebGPU adapter handle. + void* device_handle; // WebGPU device handle. +} OrtWebGPUProviderOptions; + /** \brief OpenVINO Provider Options * * \see OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO @@ -4667,6 +4693,37 @@ struct OrtApi { _In_reads_(num_external_initializer_files) char* const* external_initializer_file_buffer_array, _In_reads_(num_external_initializer_files) const size_t* external_initializer_file_lengths, size_t num_external_initializer_files); + + /** \brief Append WebGPU execution provider to session options + * + * If WebGPU is not available, this function will return failure. + * + * \param[in] options + * \param[in] webgpu_options - specify the WebGPU provider options. + * \param[in] string_options_keys - keys to configure the string options + * \param[in] string_options_values - values to configure the string options + * \param[in] num_keys - number of keys passed in + * + * Supported keys are listed as below. All entries are optional. + * + * | Key | Possible Values | Default Value | + * | ------------------------------ | ---------------------------------------------- | -------------- | + * | "preferredLayout" | "NHWC" or "NCHW" | "NHWC" | + * | "enableGraphCapture" | "1" or "0" | "0" | + * | "storageBufferCacheMode" | "disabled", "lazyRelease", "simple", "bucket" | "bucket" | + * | "uniformBufferCacheMode" | "disabled", "lazyRelease", "simple", "bucket" | "lazyRelease" | + * | "queryResolveBufferCacheMode" | "disabled", "lazyRelease", "simple", "bucket" | "disabled" | + * | "defaultBufferCacheMode" | "disabled", "lazyRelease", "simple", "bucket" | "disabled" | + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.20. + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_WebGPU, + _In_ OrtSessionOptions* options, _In_ const OrtWebGPUProviderOptions* webgpu_options, + _In_reads_(num_keys) const char* const* string_options_keys, + _In_reads_(num_keys) const char* const* string_options_values, + _In_ size_t num_keys); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 29a229f427163..cf30584e18a4a 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -890,6 +890,9 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options); + ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_WebGPU + SessionOptionsImpl& AppendExecutionProvider_WebGPU(const OrtWebGPUProviderOptions& webgpu_options, + const std::unordered_map& string_options = {}); /// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports QNN, SNPE and XNNPACK. SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name, const std::unordered_map& provider_options = {}); diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index d3a8cade4d28f..e5c84395ad95b 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -838,6 +838,25 @@ inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_MIG return *this; } +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_WebGPU(const OrtWebGPUProviderOptions& webgpu_options, + const std::unordered_map& string_options) { + auto num_entries = provider_options.size(); + std::vector keys, values; + if (num_entries > 0) { + keys.reserve(num_entries); + values.reserve(num_entries); + + for (const auto& entry : provider_options) { + keys.push_back(entry.first.c_str()); + values.push_back(entry.second.c_str()); + } + } + + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_WebGPU(this->p_, &provider_options, keys.data(), values.data(), num_entries)); + return *this; +} + template inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options) { ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options)); diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc new file mode 100644 index 0000000000000..91f51df588fca --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +#include "core/framework/op_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Attention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention); +// LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipSimplifiedLayerNormalization); + +// template <> +// KernelCreateInfo BuildKernelCreateInfo() { +// KernelCreateInfo info; +// return info; +// } + +Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { + static const BuildKernelCreateInfoFn function_table[] = { + BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo + }; + + for (auto& function_table_entry : function_table) { + KernelCreateInfo info = function_table_entry(); + if (info.kernel_def != nullptr) { // filter disabled entries where type is void + ORT_RETURN_IF_ERROR(kernel_registry.Register(std::move(info))); + } + } + return Status::OK(); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h new file mode 100644 index 0000000000000..6cdf7382804f9 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/op_kernel.h" +#include "core/framework/kernel_registry.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry); + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/get_execution_providers.cc b/onnxruntime/core/providers/get_execution_providers.cc index 61c035bc29ed5..d2a72c3a38b03 100644 --- a/onnxruntime/core/providers/get_execution_providers.cc +++ b/onnxruntime/core/providers/get_execution_providers.cc @@ -162,6 +162,14 @@ constexpr ProviderInfo kProvidersInPriorityOrder[] = true, #else false, +#endif + }, + { + kWebGpuExecutionProvider, +#ifdef USE_WEBGPU + true, +#else + false, #endif }, { diff --git a/onnxruntime/core/providers/provider_factory_creators.h b/onnxruntime/core/providers/provider_factory_creators.h index 47d3f2f793d7c..41e418d9eb97f 100644 --- a/onnxruntime/core/providers/provider_factory_creators.h +++ b/onnxruntime/core/providers/provider_factory_creators.h @@ -94,6 +94,10 @@ #include "core/providers/webnn/webnn_provider_factory_creator.h" #endif +#if defined(USE_WEBGPU) +#include "core/providers/webgpu/webgpu_provider_factory_creator.h" +#endif + #if defined(USE_CANN) #include "core/providers/cann/cann_provider_factory_creator.h" #endif diff --git a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md new file mode 100644 index 0000000000000..a5a71fd94bf47 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md @@ -0,0 +1,156 @@ +# How to Write WebGPU EP Kernel + +This document describes how to write a WebGPU EP kernel for ONNX Runtime. + +The following document will assume the operator name is `Example`, and you will see class `ExampleProgram` and `ExampleOpKernel` in the examples. Replace `Example` with the actual operator name you are implementing. + +Follow the following steps to create a WebGPU kernel: + +## 1. Decide _filename_ and _cateogory_, and create a new file at: + +`onnxruntime/core/providers/webgpu/{category}/{filename}.cc` + +- filename is usually a snake_case_name of the operator name, or a descriptive name if it includes multiple operators (eg. binary_elementwise_ops.cc) +- category is the subfolder representing the operator category (eg. math/nn/controlflow) + + see folder structure under onnxruntime/core/providers/cpu/ or onnxruntime/core/providers/cuda/ for examples + +## 2. Declare a new Program class + +### 2.1. The Program class should inherit from Program: + +```c++ +class ExampleProgram : public Program { +// ... +} +``` + +### 2.2. The Program class can define the following information: + +There are 3 types of definitions described as below. All of them are optional. If not specified, it is treated as empty. Those definitions are defined as static const members to ensure they don't depend on any runtime information. + +#### **constants** + +constants are declaration of values that are never changes in the shader code. They are inserted into the WGSL source code like this: + +```wgsl +const A : u32 = 64; +``` + +Use macro `WEBGPU_PROGRAM_DEFINE_CONSTANTS` to define constants in your Program class. + +#### **overridable constants** + +overridable constants are similar to constants, but they can be overridden before the compute pipeline is created. Overridable constants may or may not have a default value. They are inserted into the WGSL source code like this: + +```wgsl +override B : u32 = 64; +override C : f32; +``` + +Use macro `WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS` to define overridable constants in your Program class. + +#### **uniform definitions** + +uniform definitions are declaration of uniform varables. Their names and type must be defined and cannot be changed. Their values(including length) can be set at runtime. + +Use macro `WEBGPU_PROGRAM_DEFINE_UNIFORMS` to define uniform definitions in your Program class. + +### 2.3. The Program class should override the `GenerateShaderCode` method: + +```c++ +Status GenerateShaderCode(ShaderHelper& sh) const override; +``` + +In the function implementation, `sh` is an instance of `ShaderHelper` which provides a set of helper functions to generate shader code. + +Example: + +```c++ +Status UnaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddVariable(ProgramVariableScope::Input, + "x", + ToProgramVariableDataType(Inputs()[0].tensor->GetElementType(), 4), + 1); + const auto& output = shader.AddVariable(ProgramVariableScope::Output, + "y", + ToProgramVariableDataType(Outputs()[0]->GetElementType(), 4), + 1); + shader.AppendImplementation(additional_impl_); + shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), + "let a = ", input.GetByOffset("global_idx"), ";\n", + output.SetByOffset("global_idx", expression_)); + + return Status::OK(); +} +``` + +`ShaderHelper::AddVariable` creates an instace of `ShaderVariable`. The class `ShaderVariable` is similar to `IndicesHelper` in onnxruntime-web. It provides a set of helper functions as value/indices/offset getter/setter. + +`ShaderHelper::AppendImplementation` inserts additional implementation code into the shader code. It will be put before the main function. + +`ShaderHelper::MainFunctionBody` generates the main function body. It accepts arbitrary number of arguments and concatenates them into the main function body. + +### 2.3. Lifecycle of the Program class + +For each calls into the `ExampleOpKernel::ComputeInternal()` method, a new instance of the `ExampleProgram` class should be created as local variable (The detail will be explained in `ExampleOpKernel` as below). The Program instance is destroyed when reaching the end of scope. + +A few functions can be called on the Program instance: + +- call `ProgramBase::Inputs` and `ProgramBase::Outputs` to set input/output tensor info. +- call `ProgramBase::CacheHint` to set the cache hint. +- call `ProgramBase::UniformsVariables`(optional) and `ProgramBase::OverridableConstants`(optional) to set runtime info of uniforms and overridable constants. They need to match the corresponding definitions described above. +- call `ProgramBase::DispatchGroupSize` and `ProgramBase::WorkgroupSize`(optional) to set the dispatch group size and workgroup size. + +## 3. Declare a new OpKernel class + +### 3.1. The OpKernel class should inherit from WebGpuKernel: + +```c++ +class ExampleOpKernel : public WebGpuKernel { +// ... +} +``` + +### 3.2. The OpKernel class should override the `ComputeInternal` method: + +```c++ +Status ComputeInternal(ComputeContext& context) const override; +``` + +Usually, in the implementation, we do 3 things: +- Create a local variable of the Program class. +- Set a few runtime info of the Program instance. +- Call `context.RunProgram(program)` to run the program and return the status. + +Complicated operators may do more things. Check header files and existing implementations for more details. + +## 4. Register the operator + +Register the operator just like any EP does. Check existing implementations for more details. + +Please note that registration is composed of 2 parts: +- Use macros like `ONNX_OPERATOR_KERNEL_EX` or `ONNX_OPERATOR_VERSIONED_KERNEL_EX` (or wrap a new macro as what we usually do) to register the operator in kernel source code file. +- Add the operator to onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc + +## 5. Write tests + +This section is WIP. + +## 6. Build and test + +use `build.bat --use_webgpu` to build the WebGPU EP. For Release build, append `--config Release` or `--config RelWithDebInfo` to the command line. + +to test, find the "onnx_test_runner.exe" in your build folder. run it like: +``` +onnx_test_runner.exe -v -e webgpu -C "session.disable_cpu_ep_fallback|1" --model_path=C:\code\onnxruntime\js\test\data\node\opset17\test_abs +``` + +> Assume C:\code\onnxruntime is the root of your onnxruntime repo +> +> if it does not exist, run the following in your onnxruntime repo root: +> ``` +> cd js +> npm ci +> npm run prepare-node-tests +> ``` diff --git a/onnxruntime/core/providers/webgpu/README.md b/onnxruntime/core/providers/webgpu/README.md new file mode 100644 index 0000000000000..d9c4313c8bf3f --- /dev/null +++ b/onnxruntime/core/providers/webgpu/README.md @@ -0,0 +1,104 @@ +# WebGPU Execution Provider + +This folder is for the WebGPU execution provider(WebGPU EP). Currently, WebGPU EP is working in progress. + +## Build WebGPU EP + +Just append `--use_webgpu` to the `build.bat` command line. + +Currently only works on Windows. + +## Troubleshooting + +TODO: add solutions to common problems. + +## Development Guide + +See [How to write WebGPU EP kernel](./How_to_Write_WebGPU_EP_Kernel.md) for more information. + +## Convention + +### Use "webgpu" other than "wgpu" in this folder + +This is referring to the naming convention of variables, classes and namespace. + +ORT C API is using "wgpu". + +Let's keep it "webgpu" for this folder for now. I have a very good reason to do so: + +- search for "webgpu" in the code base shows the WebGPU EP related code and search for "wgpu" shows the WebGPU API related code. This helps me easier to find the code I want to look at. + +And anyway, it's not hard to change it back to "wgpu" if we want to. (but it's harder to change it from "wgpu" to "webgpu") + +### Use macros defined in shader_macros.h + +Take `SS` as example. It's a macro defined in `shader_macros.h` and it's used to concatenate strings. It's just make the `std::ostream::operator<<` to be used in a function call style. + +I prefer to use the macro because I feel like it's easier to read. Check the following code: + +```cpp +ss << "vec4(" << type << ">(" << value1 << ", " << value2 << ", " << value3 << ", " << value4 << ")"; +``` + +vs. + +```cpp +SS("vec4<", type, ">(", value1, ", ", value2, ", ", value3, ", ", value4, ")"); +``` + +### Use the subfolder for kernel implementation + +Operator implementation source code need to be put under a subfolder like "math"/"nn"/"tensor". + +See folder structure under onnxruntime/core/providers/cpu/ or onnxruntime/core/providers/cuda/ for examples. + +## Best Practices + +### Always use std::ostringstream to generate shader code if possible + +This helps to the performance of code generation. + +For example: + +```cpp +ss << "var " << name << " = " << value << ";\n"; +``` + +is better than + +```cpp +ss << ("var " + name + " = " + value + ";\n"); +``` + +### Avoid creating template class for kernel using data type as template parameter. + +This basically means that we should define class like this: + +```cpp +class Abs : public WebGpuKernel { + ... +}; +``` + +instead of + +```cpp + +template // T is tensor element type +class Abs : public WebGpuKernel { + ... +}; +``` + +This is because we don't really read and use `Tensor::Data()`. Tensor stores a handle to a WebGPU buffer but not a pointer to the data. Using template for data type only increases the binary size with no real benefit. + +## TODO items + +The following items are not yet implemented: + +- [ ] Validation Switch (allows to change the behavior of whether perform specific validation checks) +- [ ] pushErrorScope/popErrorScope +- [ ] Graph Capture +- [ ] Profiling supported by WebGPU Query Buffer +- [ ] WebGPU resources tracking (mainly for buffers) +- [ ] Global hanlders( unhandled exceptions and device lost ) diff --git a/onnxruntime/core/providers/webgpu/allocator.cc b/onnxruntime/core/providers/webgpu/allocator.cc new file mode 100644 index 0000000000000..8e27acdc285d4 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/allocator.cc @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include "core/framework/session_state.h" +#include "core/providers/webgpu/allocator.h" +#include "core/providers/webgpu/webgpu_context.h" + +namespace onnxruntime { +namespace webgpu { + +void* GpuBufferAllocator::Alloc(size_t size) { + if (size == 0) { + return nullptr; + } + + auto buffer = context_.BufferManager().Create(size); + + stats_.num_allocs++; + return buffer; +} + +void GpuBufferAllocator::Free(void* p) { + if (p != nullptr) { + context_.BufferManager().Release(static_cast(p)); + stats_.num_allocs--; + } +} + +void GpuBufferAllocator::GetStats(AllocatorStats* stats) { + *stats = stats_; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/allocator.h b/onnxruntime/core/providers/webgpu/allocator.h new file mode 100644 index 0000000000000..51ca65a8b4822 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/allocator.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/allocator.h" +#include "core/framework/ortdevice.h" + +namespace onnxruntime { +namespace webgpu { + +class WebGpuContext; + +class GpuBufferAllocator : public IAllocator { + public: + GpuBufferAllocator(const WebGpuContext& context) + : IAllocator( + OrtMemoryInfo(WEBGPU_BUFFER, OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), + 0, OrtMemTypeDefault)), + context_{context} { + } + + virtual void* Alloc(size_t size) override; + virtual void Free(void* p) override; + void GetStats(AllocatorStats* stats) override; + + private: + AllocatorStats stats_; + const WebGpuContext& context_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.cc b/onnxruntime/core/providers/webgpu/buffer_manager.cc new file mode 100644 index 0000000000000..d69b1210ade4b --- /dev/null +++ b/onnxruntime/core/providers/webgpu/buffer_manager.cc @@ -0,0 +1,362 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/buffer_manager.h" +#include "core/providers/webgpu/webgpu_context.h" + +static int xx = 1; + +namespace onnxruntime { +namespace webgpu { + +size_t NormalizeBufferSize(size_t size) { + return (size + 15) / 16 * 16; +} + +class DisabledCacheManager : public IBufferCacheManager { + size_t CalculateBufferSize(size_t request_size) override { + return NormalizeBufferSize(request_size); + } + + WGPUBuffer TryAcquireCachedBuffer(size_t /*buffer_size*/) override { + // always return empty buffer + return nullptr; + } + void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override { + // no-op + } + void ReleaseBuffer(WGPUBuffer buffer) override { + wgpuBufferRelease(buffer); + } + + void OnRefresh() override { + // no-op + } +}; + +class LazyReleaseCacheManager : public IBufferCacheManager { + size_t CalculateBufferSize(size_t request_size) override { + return NormalizeBufferSize(request_size); + } + + WGPUBuffer TryAcquireCachedBuffer(size_t /*buffer_size*/) override { + return nullptr; + } + + void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override { + // no-op + } + + void ReleaseBuffer(WGPUBuffer buffer) override { + pending_buffers_.emplace_back(buffer); + } + + void OnRefresh() override { + for (auto& buffer : pending_buffers_) { + wgpuBufferRelease(buffer); + } + pending_buffers_.clear(); + } + + std::vector pending_buffers_; +}; + +class SimpleCacheManager : public IBufferCacheManager { + size_t CalculateBufferSize(size_t request_size) override { + return NormalizeBufferSize(request_size); + } + + WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) override { + auto it = buffers_.find(buffer_size); + if (it != buffers_.end() && !it->second.empty()) { + auto buffer = it->second.back(); + it->second.pop_back(); + return buffer; + } + + return nullptr; + } + + void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override { + // no-op + } + + void ReleaseBuffer(WGPUBuffer buffer) override { + pending_buffers_.emplace_back(buffer); + } + + void OnRefresh() override { + for (auto& buffer : pending_buffers_) { + buffers_[wgpuBufferGetSize(buffer)].push_back(buffer); + } + pending_buffers_.clear(); + } + + std::map> buffers_; + std::vector pending_buffers_; +}; + +// TODO: maybe use different bucket size for storage and uniform buffers? +constexpr std::initializer_list> BUCKET_DEFAULT_LIMIT_TABLE = { + {64, 250}, + {128, 200}, + {256, 200}, + {512, 200}, + {2048, 230}, + {4096, 200}, + {8192, 50}, + {16384, 50}, + {32768, 50}, + {65536, 50}, + {131072, 50}, + {262144, 50}, + {524288, 50}, + {1048576, 50}, + {2097152, 30}, + {4194304, 20}, + {8388608, 10}, + {12582912, 10}, + {16777216, 10}, + {26214400, 15}, + {33554432, 22}, + {44236800, 2}, + {58982400, 6}, + // we don't want to cache the bucket sizes below but not caching them + // results in some major performance hits for models like sd-turbo. + {67108864, 6}, + {134217728, 6}, + {167772160, 6}, +}; + +class BucketCacheManager : public IBufferCacheManager { + public: + BucketCacheManager() : buckets_limit_{BUCKET_DEFAULT_LIMIT_TABLE} { + Initialize(); + } + BucketCacheManager(std::unordered_map&& buckets_limit) : buckets_limit_{buckets_limit} { + Initialize(); + } + + size_t CalculateBufferSize(size_t request_size) override { + // binary serch size + auto it = std::lower_bound(buckets_keys_.begin(), buckets_keys_.end(), request_size); + if (it == buckets_keys_.end()) { + return NormalizeBufferSize(request_size); + } else { + return *it; + } + } + + WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) override { + auto it = buckets_.find(buffer_size); + if (it != buckets_.end() && !it->second.empty()) { + auto buffer = it->second.back(); + it->second.pop_back(); + return buffer; + } + return nullptr; + } + + void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override { + // no-op + } + + void ReleaseBuffer(WGPUBuffer buffer) override { + pending_buffers_.emplace_back(buffer); + } + + void OnRefresh() override { + // TODO: consider graph capture. currently not supported + + for (auto& buffer : pending_buffers_) { + auto buffer_size = wgpuBufferGetSize(buffer); + + auto it = buckets_.find(buffer_size); + if (it != buckets_.end() && it->second.size() < buckets_limit_[buffer_size]) { + it->second.push_back(buffer); + } else { + wgpuBufferRelease(buffer); + } + } + } + + protected: + void Initialize() { + buckets_keys_.reserve(buckets_limit_.size()); + buckets_.reserve(buckets_limit_.size()); + for (const auto& pair : buckets_limit_) { + buckets_keys_.push_back(pair.first); + buckets_.emplace(pair.first, std::vector()); + } + std::sort(buckets_keys_.begin(), buckets_keys_.end()); + +#ifndef NDEBUG // if debug build + for (size_t i = 0; i < buckets_keys_.size(); ++i) { + ORT_ENFORCE(buckets_keys_[i] % 16 == 0, "Bucket sizes must be multiples of 16."); + } + + for (size_t i = 1; i < buckets_keys_.size(); ++i) { + ORT_ENFORCE(buckets_keys_[i] > buckets_keys_[i - 1], "Bucket sizes must be in increasing order."); + } +#endif + } + std::unordered_map buckets_limit_; + std::unordered_map> buckets_; + std::vector pending_buffers_; + std::vector buckets_keys_; +}; + +std::unique_ptr CreateBufferCacheManager(BufferCacheMode cache_mode) { + switch (cache_mode) { + case BufferCacheMode::Disabled: + return std::make_unique(); + case BufferCacheMode::LazyRelease: + return std::make_unique(); + case BufferCacheMode::Simple: + return std::make_unique(); + case BufferCacheMode::Bucket: + return std::make_unique(); + default: + ORT_NOT_IMPLEMENTED("Unsupported buffer cache mode"); + } +} + +std::ostream& operator<<(std::ostream& os, BufferCacheMode mode) { + switch (mode) { + case BufferCacheMode::Disabled: + os << "Disabled"; + break; + case BufferCacheMode::LazyRelease: + os << "LazyRelease"; + break; + case BufferCacheMode::Simple: + os << "Simple"; + break; + case BufferCacheMode::Bucket: + os << "Bucket"; + break; + default: + os << "Unknown(" << static_cast(mode) << ")"; + } + return os; +} + +BufferManager::BufferManager(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode) + : context_{context}, + storage_cache_{std::move(CreateBufferCacheManager(storage_buffer_cache_mode))}, + uniform_cache_{std::move(CreateBufferCacheManager(uniform_buffer_cache_mode))}, + query_resolve_cache_{std::move(CreateBufferCacheManager(query_resolve_buffer_cache_mode))}, + default_cache_{std::move(CreateBufferCacheManager(BufferCacheMode::Disabled))} { +} + +void BufferManager::Upload(void* src, WGPUBuffer dst, size_t size) { + auto buffer_size = NormalizeBufferSize(size); + + wgpu::BufferDescriptor desc{}; + desc.size = buffer_size; + desc.usage = wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite; + desc.mappedAtCreation = true; + + auto staging_buffer = context_.Device().CreateBuffer(&desc); + auto mapped_data = staging_buffer.GetMappedRange(); + memcpy(mapped_data, src, size); + staging_buffer.Unmap(); + + auto& command_encoder = context_.GetCommandEncoder(); + context_.EndComputePass(); + command_encoder.CopyBufferToBuffer(staging_buffer, 0, dst, 0, buffer_size); + pending_staging_buffers_.push_back(staging_buffer); +} + +void BufferManager::MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) { + ORT_ENFORCE(src != dst, "Source and destination buffers must be different."); + + auto buffer_size = NormalizeBufferSize(size); + ORT_ENFORCE(buffer_size <= wgpuBufferGetSize(src) && buffer_size <= wgpuBufferGetSize(dst), + "Source and destination buffers must have enough space for the copy operation. src_size=", + wgpuBufferGetSize(src), ", dst_size=", wgpuBufferGetSize(dst), ", copy_size=", buffer_size, "."); + + auto& command_encoder = context_.GetCommandEncoder(); + context_.EndComputePass(); + command_encoder.CopyBufferToBuffer(src, 0, dst, 0, buffer_size); +} + +WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) { + auto& cache = GetCacheManager(static_cast(usage)); + auto buffer_size = cache.CalculateBufferSize(size); + + auto buffer = cache.TryAcquireCachedBuffer(buffer_size); + if (buffer) { + return buffer; + } + + // cache miss, create a new buffer + wgpu::BufferDescriptor desc{}; + desc.size = buffer_size; + desc.usage = usage; + // desc.label = std::to_string(xx++).c_str(); + buffer = context_.Device().CreateBuffer(&desc).MoveToCHandle(); + + ORT_ENFORCE(buffer, "Failed to create GPU buffer: size=", buffer_size, ", usage=", uint64_t(usage), "."); + + cache.RegisterBuffer(buffer, size); + return buffer; +} + +void BufferManager::Release(WGPUBuffer buffer) { + GetCacheManager(buffer).ReleaseBuffer(buffer); +} + +void BufferManager::Download(WGPUBuffer src, void* dst, size_t size) { + auto buffer_size = NormalizeBufferSize(size); + + wgpu::BufferDescriptor desc{}; + desc.size = buffer_size; + desc.usage = wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead; + + auto staging_buffer = context_.Device().CreateBuffer(&desc); + auto& command_encoder = context_.GetCommandEncoder(); + context_.EndComputePass(); + command_encoder.CopyBufferToBuffer(src, 0, staging_buffer, 0, buffer_size); + context_.Flush(); + + // TODO: revise wait in whole project + + ORT_ENFORCE(context_.Wait(staging_buffer.MapAsync(wgpu::MapMode::Read, 0, buffer_size, wgpu::CallbackMode::WaitAnyOnly, [](wgpu::MapAsyncStatus status, const char* message) { + ORT_ENFORCE(status == wgpu::MapAsyncStatus::Success, "Failed to download data from buffer: ", message); + })) == Status::OK()); + + auto mapped_data = staging_buffer.GetConstMappedRange(); + memcpy(dst, mapped_data, size); +} + +void BufferManager::RefreshPendingBuffers() { + pending_staging_buffers_.clear(); + storage_cache_->OnRefresh(); + uniform_cache_->OnRefresh(); + query_resolve_cache_->OnRefresh(); + default_cache_->OnRefresh(); +} + +IBufferCacheManager& BufferManager::GetCacheManager(WGPUBufferUsage usage) const { + if (usage & WGPUBufferUsage_Storage) { + return *storage_cache_; + } else if (usage & WGPUBufferUsage_Uniform) { + return *uniform_cache_; + } else if (usage & WGPUBufferUsage_QueryResolve) { + return *query_resolve_cache_; + } else { + return *default_cache_; + } +} + +IBufferCacheManager& BufferManager::GetCacheManager(WGPUBuffer buffer) const { + return GetCacheManager(wgpuBufferGetUsage(buffer)); +} + +std::unique_ptr BufferManagerFactory::Create(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode) { + return std::make_unique(context, storage_buffer_cache_mode, uniform_buffer_cache_mode, query_resolve_buffer_cache_mode); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.h b/onnxruntime/core/providers/webgpu/buffer_manager.h new file mode 100644 index 0000000000000..c94f77b6b5fa0 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/buffer_manager.h @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "core/framework/data_transfer.h" +#include "core/framework/execution_provider.h" + +namespace onnxruntime { +namespace webgpu { + +class WebGpuContext; + +enum class BufferCacheMode { + Disabled, + LazyRelease, + Simple, + Bucket +}; +std::ostream& operator<<(std::ostream& os, BufferCacheMode mode); + +// +// IBufferCacheManager is an interface for buffer cache management. +// +// By implementing this interface, we can have different buffer cache management strategies. +// Currently, we have 3 strategies: +// - Disabled: no cache. always allocate a new buffer and release it immediately after use. +// - LazyRelease: no cache. the difference from Disabled is that it delays the release of buffers until the next refresh. +// - Simple: a simple cache that always keeps buffers. when a buffer is requested, it tries to find a buffer in the cache. +// - Bucket: a cache that keeps buffers in different buckets based on the buffer size, with a maximum number of buffers in each bucket. +// +class IBufferCacheManager { + public: + virtual ~IBufferCacheManager() = default; + + // calculate actual buffer size to allocate based on the requested size. + virtual size_t CalculateBufferSize(size_t request_size) = 0; + + // return a buffer if available in cache. otherwise empty. + virtual WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) = 0; + + // register a newly created buffer + virtual void RegisterBuffer(WGPUBuffer buffer, size_t request_size) = 0; + + // release a buffer + virtual void ReleaseBuffer(WGPUBuffer buffer) = 0; + + // when a stream refresh is requested + virtual void OnRefresh() = 0; +}; + +// +// BufferManager manages operations on buffers. +// +class BufferManager { + public: + BufferManager(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode); + + void Upload(void* src, WGPUBuffer dst, size_t size); + void MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size); + WGPUBuffer Create(size_t size, wgpu::BufferUsage usage = wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst); + void Release(WGPUBuffer buffer); + void Download(WGPUBuffer src, void* dst, size_t size); + void RefreshPendingBuffers(); + + private: + IBufferCacheManager& GetCacheManager(WGPUBufferUsage usage) const; + IBufferCacheManager& GetCacheManager(WGPUBuffer buffer) const; + + WebGpuContext& context_; + std::unique_ptr storage_cache_; + std::unique_ptr uniform_cache_; + std::unique_ptr query_resolve_cache_; + std::unique_ptr default_cache_; + + std::vector pending_staging_buffers_; +}; + +class BufferManagerFactory { + public: + static std::unique_ptr Create(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode); + + private: + BufferManagerFactory() {} +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/compute_context.cc b/onnxruntime/core/providers/webgpu/compute_context.cc new file mode 100644 index 0000000000000..67c55f823d78a --- /dev/null +++ b/onnxruntime/core/providers/webgpu/compute_context.cc @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/op_kernel.h" + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/webgpu_context.h" + +namespace onnxruntime { +namespace webgpu { +ComputeContext::ComputeContext(OpKernelContext& kernel_context) + : webgpu_context_{WebGpuContextFactory::GetContext(kernel_context.GetDeviceId())}, + kernel_context_{kernel_context} { +} + +const wgpu::AdapterInfo& ComputeContext::AdapterInfo() const { + return webgpu_context_.AdapterInfo(); +} + +const wgpu::Limits& ComputeContext::DeviceLimits() const { + return webgpu_context_.DeviceLimits(); +} + +int ComputeContext::InputCount() const { + return kernel_context_.InputCount(); +} + +int ComputeContext::OutputCount() const { + return kernel_context_.OutputCount(); +} + +Status ComputeContext::RunProgram(const ProgramBase& program) { + return webgpu_context_.Run(*this, program); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h new file mode 100644 index 0000000000000..d7aeae240101a --- /dev/null +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include + +#include "core/framework/execution_provider.h" + +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { + +class Tensor; +class OpKernelContext; + +namespace webgpu { + +class WebGpuContext; + +class ComputeContext { + public: + ComputeContext(OpKernelContext& kernel_context); + + virtual ~ComputeContext() = default; + + // + // Get various information from the context. + // + + const wgpu::AdapterInfo& AdapterInfo() const; + const wgpu::Limits& DeviceLimits() const; + + // + // Get input tensor. + // + template + const T* Input(int index) const { + return kernel_context_.Input(index); + } + + // + // Get input count. + // + int InputCount() const; + + // + // Set output tensor. + // + template + Tensor* Output(int index, TensorShapeType&& shape) { + return kernel_context_.Output(index, std::forward(shape)); + } + + // + // Get output count. + // + int OutputCount() const; + + // + // Create CPU tensor. + // + template + Tensor CreateCPUTensor(MLDataType data_type, TensorShapeType&& shape) { + AllocatorPtr allocator; + ORT_THROW_IF_ERROR(kernel_context_.GetTempSpaceCPUAllocator(&allocator)); + return {data_type, std::forward(shape)..., allocator}; + } + + // + // Create GPU tensor. + // + template + Tensor CreateGPUTensor(MLDataType data_type, TensorShapeType&& shape) { + AllocatorPtr allocator; + ORT_THROW_IF_ERROR(kernel_context_.GetTempSpaceAllocator(&allocator)); + return {data_type, std::forward(shape)..., allocator}; + } + + // + // Run a compute shader program. + // + Status RunProgram(const ProgramBase& program); + + protected: + WebGpuContext& webgpu_context_; + OpKernelContext& kernel_context_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/data_transfer.cc b/onnxruntime/core/providers/webgpu/data_transfer.cc new file mode 100644 index 0000000000000..615ae11175782 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/data_transfer.cc @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include "core/providers/webgpu/data_transfer.h" +#include "core/providers/webgpu/webgpu_context.h" + +namespace onnxruntime { +namespace webgpu { + +bool DataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { + return (dst_device.Type() == OrtDevice::GPU && src_device.Type() == OrtDevice::CPU) || + (dst_device.Type() == OrtDevice::GPU && src_device.Type() == OrtDevice::GPU) || + (dst_device.Type() == OrtDevice::CPU && src_device.Type() == OrtDevice::GPU); +} + +common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { + size_t bytes = src.SizeInBytes(); + if (bytes > 0) { + void const* src_data = src.DataRaw(); + void* dst_data = dst.MutableDataRaw(); + + auto& src_device = src.Location().device; + auto& dst_device = dst.Location().device; + + if (dst_device.Type() == OrtDevice::GPU) { + if (src_device.Type() == OrtDevice::GPU) { + // copy from GPU to GPU + context_.BufferManager().MemCpy(static_cast(const_cast(src_data)), + static_cast(dst_data), bytes); + } else { + // copy from CPU to GPU + context_.BufferManager().Upload(const_cast(src_data), static_cast(dst_data), bytes); + } + } else /* if (src_device.Type() == OrtDevice::GPU) */ { + // copy from GPU to CPU + context_.BufferManager().Download(static_cast(const_cast(src_data)), dst_data, bytes); + } + } + + return Status::OK(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/data_transfer.h b/onnxruntime/core/providers/webgpu/data_transfer.h new file mode 100644 index 0000000000000..f9949576aa60b --- /dev/null +++ b/onnxruntime/core/providers/webgpu/data_transfer.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/data_transfer.h" +#include "core/framework/execution_provider.h" + +namespace onnxruntime { +namespace webgpu { + +class WebGpuContext; + +class DataTransfer : public IDataTransfer { + public: + DataTransfer(const WebGpuContext& context) : context_{context} {}; + ~DataTransfer() {}; + + bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; + + common::Status CopyTensor(const Tensor& src, Tensor& dst) const override; + + private: + const WebGpuContext& context_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc new file mode 100644 index 0000000000000..5c774df84638e --- /dev/null +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/math/unary_elementwise_ops.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { +Status UnaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddVariable(ProgramVariableScope::Input, + "x", + ToProgramVariableDataType(Inputs()[0].tensor->GetElementType(), 4), + 1); + const auto& output = shader.AddVariable(ProgramVariableScope::Output, + "y", + ToProgramVariableDataType(Outputs()[0]->GetElementType(), 4), + 1); + shader.AppendImplementation(additional_impl_); + shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), + "let a = ", input.GetByOffset("global_idx"), ";\n", + output.SetByOffset("global_idx", expression_)); + + return Status::OK(); +} + +#define WEBGPU_ELEMENTWISE_IMPL(OP_TYPE, ...) \ + class OP_TYPE final : public WebGpuKernel { \ + public: \ + OP_TYPE(const OpKernelInfo& info) : WebGpuKernel{info} {} \ + \ + protected: \ + Status ComputeInternal(ComputeContext& context) const override { \ + const auto* input_tensor = context.Input(0); \ + auto* output_tensor = context.Output(0, input_tensor->Shape()); \ + SafeInt vec_size = (input_tensor->Shape().Size() + 3) / 4; \ + UnaryElementwiseProgram program{#OP_TYPE, __VA_ARGS__}; \ + program \ + .Inputs({{input_tensor, ProgramInputTensorDependency::Type}}) \ + .Outputs({output_tensor}) \ + .DispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) \ + .UniformVariables({ \ + {static_cast(vec_size)}, \ + }); \ + return context.RunProgram(program); \ + } \ + }; + +#define WEBGPU_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE), \ + KERNEL_CLASS); + +#define WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE), \ + KERNEL_CLASS); + +WEBGPU_ELEMENTWISE_IMPL(Abs, "abs(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Abs, 6, 12, Abs, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Abs, 13, Abs, WebGpuSupportedFloatTypes()) + +// TODO: add other unary elementwise ops + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h new file mode 100644 index 0000000000000..837f66af30dde --- /dev/null +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class UnaryElementwiseProgram final : public Program { + public: + UnaryElementwiseProgram(const std::string& kernel_name, const std::string& expression, const std::string& additional_impl = "") + : Program{kernel_name}, expression_{expression}, additional_impl_{additional_impl} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + + private: + std::string expression_; + std::string additional_impl_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc new file mode 100644 index 0000000000000..8ba33bcafb316 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -0,0 +1,196 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/session/onnxruntime_c_api.h" + +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +ProgramUniformVariableValue::ProgramUniformVariableValue() + : length{0}, data_type{} {} // representing an empty uniform variable + +ProgramUniformVariableValue::ProgramUniformVariableValue(float value) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Float32, &value, sizeof(float)) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(uint32_t value) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Uint32, &value, sizeof(uint32_t)) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(int32_t value) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Int32, &value, sizeof(int32_t)) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(MLFloat16 value) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Float16, &value, sizeof(MLFloat16)) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(gsl::span values) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Float32, values.data(), sizeof(float), values.size()) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(gsl::span values) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Uint32, values.data(), sizeof(uint32_t), values.size()) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(gsl::span values) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Int32, values.data(), sizeof(int32_t), values.size()) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(gsl::span values) + : ProgramUniformVariableValue(ProgramUniformVariableDataType::Float16, values.data(), sizeof(MLFloat16), values.size()) {} + +ProgramUniformVariableValue::ProgramUniformVariableValue(ProgramUniformVariableDataType data_type, + const void* ptr, + size_t element_byte_size, + size_t length /* = 1 */) + : length{length}, data_type{data_type} { + ORT_ENFORCE(length > 0, "number of element of uniform variable must be greater than 0"); + + data.resize(length * element_byte_size); + memcpy(data.data(), ptr, length * element_byte_size); +} + +std::ostream& operator<<(std::ostream& os, ProgramUniformVariableDataType type) { + os << ProgramUniformVariableDataTypeName[static_cast(type)]; + return os; +} + +std::ostream& operator<<(std::ostream& os, ProgramConstantDataType type) { + os << ProgramConstantDataTypeName[static_cast(type)]; + return os; +} + +std::ostream& operator<<(std::ostream& os, ProgramInputTensorDependency dep) { + bool first = true; + if ((dep & ProgramInputTensorDependency::Type) == ProgramInputTensorDependency::Type) { + os << "Type"; + first = false; + } + if ((dep & ProgramInputTensorDependency::Rank) == ProgramInputTensorDependency::Rank) { + if (!first) os << "|"; + os << "Rank"; + first = false; + } + if ((dep & ProgramInputTensorDependency::Shape) == ProgramInputTensorDependency::Shape) { + if (!first) os << "|"; + os << "Shape"; + first = false; + } + if (first) { + os << "None"; + } + + return os; +} + +ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int component /* = 1 */) { + if (component == 1) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return ProgramVariableDataType::Float32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + return ProgramVariableDataType::Float16; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return ProgramVariableDataType::Int32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + return ProgramVariableDataType::Uint32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + return ProgramVariableDataType::Int64; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + return ProgramVariableDataType::Uint64; + default: + return ProgramVariableDataType::InvalidType; + } + } else if (component == 2) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return ProgramVariableDataType::Vec2Float32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + return ProgramVariableDataType::Vec2Float16; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return ProgramVariableDataType::Vec2Int32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + return ProgramVariableDataType::Vec2Uint32; + default: + return ProgramVariableDataType::InvalidType; + } + } else if (component == 4) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return ProgramVariableDataType::Vec4Float32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + return ProgramVariableDataType::Vec4Float16; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + return ProgramVariableDataType::Vec4Int32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + return ProgramVariableDataType::Vec4Uint32; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + return ProgramVariableDataType::Vec4Bool; + default: + return ProgramVariableDataType::InvalidType; + } + } else { + return ProgramVariableDataType::InvalidType; + } +} + +ProgramBase::ProgramBase(const std::string& name) + : name_{name}, + dispatch_group_size_x_{0}, + dispatch_group_size_y_{0}, + dispatch_group_size_z_{0}, + workgroup_size_x_{WORKGROUP_SIZE}, + workgroup_size_y_{1}, + workgroup_size_z_{1} { +} + +ProgramBase& ProgramBase::Inputs(std::initializer_list inputs) { + inputs_.assign(inputs.begin(), inputs.end()); + return *this; +} + +ProgramBase& ProgramBase::Outputs(std::initializer_list outputs) { + outputs_.assign(outputs.begin(), outputs.end()); + return *this; +} + +ProgramBase& ProgramBase::DispatchGroupSize(uint32_t x) { + return DispatchGroupSize(x, 1, 1); +} + +ProgramBase& ProgramBase::DispatchGroupSize(uint32_t x, uint32_t y) { + return DispatchGroupSize(x, y, 1); +} + +ProgramBase& ProgramBase::DispatchGroupSize(uint32_t x, uint32_t y, uint32_t z) { + dispatch_group_size_x_ = x; + dispatch_group_size_y_ = y; + dispatch_group_size_z_ = z; + return *this; +} + +ProgramBase& ProgramBase::WorkgroupSize(uint32_t x) { + return WorkgroupSize(x, 1, 1); +} + +ProgramBase& ProgramBase::WorkgroupSize(uint32_t x, uint32_t y) { + return WorkgroupSize(x, y, 1); +} + +ProgramBase& ProgramBase::WorkgroupSize(uint32_t x, uint32_t y, uint32_t z) { + workgroup_size_x_ = x; + workgroup_size_y_ = y; + workgroup_size_z_ = z; + return *this; +} + +ProgramBase& ProgramBase::UniformVariables(std::initializer_list variables) { + variables_.insert(variables_.end(), variables.begin(), variables.end()); + return *this; +} + +ProgramBase& ProgramBase::OverridableConstants(std::initializer_list overridable_constants) { + overridable_constants_.insert(overridable_constants_.end(), overridable_constants.begin(), overridable_constants.end()); + return *this; +} + +} // namespace webgpu +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h new file mode 100644 index 0000000000000..6df918e2f7f71 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program.h @@ -0,0 +1,491 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/common/common.h" +#include "core/common/string_join.h" +#include "core/common/safeint.h" +#include "core/framework/tensor.h" + +namespace onnxruntime { +namespace webgpu { +class ShaderHelper; +class ComputeContext; +class WebGpuContext; + +// data type of uniform variable +enum class ProgramUniformVariableDataType { + Float32, + Float16, + Uint32, + Int32, +}; +std::ostream& operator<<(std::ostream& os, ProgramUniformVariableDataType); + +constexpr size_t ProgramUniformVariableDataTypeSize[] = {sizeof(float), sizeof(uint16_t), sizeof(uint32_t), sizeof(int32_t)}; + +constexpr std::string_view ProgramUniformVariableDataTypeName[] = {"f32", "f16", "u32", "i32"}; + +// represents a runtime value of a uniform variable +struct ProgramUniformVariableValue { + ProgramUniformVariableValue(); // representing an empty uniform variable + ProgramUniformVariableValue(float value); + ProgramUniformVariableValue(uint32_t value); + ProgramUniformVariableValue(int32_t value); + ProgramUniformVariableValue(MLFloat16 value); + ProgramUniformVariableValue(gsl::span values); + ProgramUniformVariableValue(gsl::span values); + ProgramUniformVariableValue(gsl::span values); + ProgramUniformVariableValue(gsl::span values); + + size_t length; + ProgramUniformVariableDataType data_type; + std::vector data; + + private: + ProgramUniformVariableValue(ProgramUniformVariableDataType data_type, const void* ptr, size_t element_byte_size, size_t length = 1); +}; + +// represents a uniform variable definition +struct ProgramUniformVariableDefinition { + std::string_view name; + ProgramUniformVariableDataType data_type; +}; + +// data type of constant +enum class ProgramConstantDataType { + Float32, + Float16, + Uint32, + Int32, + Bool +}; +std::ostream& operator<<(std::ostream& os, ProgramConstantDataType); + +constexpr std::string_view ProgramConstantDataTypeName[] = {"f32", "f16", "u32", "i32", "bool"}; + +// represents a constant in a program +struct ProgramConstant { + constexpr ProgramConstant(std::string_view name, float value) : name{name}, type{ProgramConstantDataType::Float32}, f32{value} {} + constexpr ProgramConstant(std::string_view name, uint32_t value) : name{name}, type{ProgramConstantDataType::Uint32}, u32{value} {} + constexpr ProgramConstant(std::string_view name, int32_t value) : name{name}, type{ProgramConstantDataType::Int32}, i32{value} {} + constexpr ProgramConstant(std::string_view name, MLFloat16 value) : name{name}, type{ProgramConstantDataType::Float16}, f16{value} {} + constexpr ProgramConstant(std::string_view name, bool value) : name{name}, type{ProgramConstantDataType::Bool}, boolean{value} {} + + std::string_view name; + ProgramConstantDataType type; + union { + float f32; + uint32_t u32; + int32_t i32; + MLFloat16 f16; + bool boolean; + }; +}; + +// represents a runtime value of an overridable constant +struct ProgramOverridableConstantValue { + constexpr ProgramOverridableConstantValue() : type{}, u32{}, has_value{false} {} // representing not overriding + constexpr ProgramOverridableConstantValue(float value) : type{ProgramConstantDataType::Float32}, f32{value}, has_value{true} {} + constexpr ProgramOverridableConstantValue(uint32_t value) : type{ProgramConstantDataType::Uint32}, u32{value}, has_value{true} {} + constexpr ProgramOverridableConstantValue(int32_t value) : type{ProgramConstantDataType::Int32}, i32{value}, has_value{true} {} + constexpr ProgramOverridableConstantValue(MLFloat16 value) : type{ProgramConstantDataType::Float16}, f16{value}, has_value{true} {} + constexpr ProgramOverridableConstantValue(bool value) : type{ProgramConstantDataType::Bool}, boolean{value}, has_value{true} {} + + ProgramConstantDataType type; + union { + float f32; + uint32_t u32; + int32_t i32; + MLFloat16 f16; + bool boolean; + }; + bool has_value; +}; + +// represents an overridable constant definition. may or may not have a default value. +struct ProgramOverridableConstantDefinition { + constexpr ProgramOverridableConstantDefinition(std::string_view name, ProgramConstantDataType type) + : name{name}, type{type}, u32{}, has_default_value{false} {} + constexpr ProgramOverridableConstantDefinition(std::string_view name, float value) + : name{name}, type{ProgramConstantDataType::Float32}, f32{value}, has_default_value{true} {} + constexpr ProgramOverridableConstantDefinition(std::string_view name, uint32_t value) + : name{name}, type{ProgramConstantDataType::Uint32}, u32{value}, has_default_value{true} {} + constexpr ProgramOverridableConstantDefinition(std::string_view name, int32_t value) + : name{name}, type{ProgramConstantDataType::Int32}, i32{value}, has_default_value{true} {} + constexpr ProgramOverridableConstantDefinition(std::string_view name, MLFloat16 value) + : name{name}, type{ProgramConstantDataType::Float16}, f16{value}, has_default_value{true} {} + constexpr ProgramOverridableConstantDefinition(std::string_view name, bool value) + : name{name}, type{ProgramConstantDataType::Bool}, boolean{value}, has_default_value{true} {} + + std::string_view name; + ProgramConstantDataType type; + union { + float f32; + uint32_t u32; + int32_t i32; + MLFloat16 f16; + bool boolean; + }; + bool has_default_value; +}; + +// represents whether the program shader depends on the type, rank, or shape of an input/output tensor +enum class ProgramInputTensorDependency : int { + None = 0, + Type = 1, + Rank = 2, + Shape = 4, + TypeAndRank = Type | Rank, + TypeAndShape = Type | Shape, +}; +std::ostream& operator<<(std::ostream& os, ProgramInputTensorDependency); + +inline ProgramInputTensorDependency operator|(ProgramInputTensorDependency a, ProgramInputTensorDependency b) { + return (ProgramInputTensorDependency)((int&)a | (int&)b); +} +inline ProgramInputTensorDependency operator&(ProgramInputTensorDependency a, ProgramInputTensorDependency b) { + return (ProgramInputTensorDependency)((int&)a & (int&)b); +} +inline ProgramInputTensorDependency& operator|=(ProgramInputTensorDependency& a, ProgramInputTensorDependency b) { + return (ProgramInputTensorDependency&)((int&)a |= (int&)b); +} +inline ProgramInputTensorDependency& operator&=(ProgramInputTensorDependency& a, ProgramInputTensorDependency b) { + return (ProgramInputTensorDependency&)((int&)a &= (int&)b); +} + +struct ProgramInput { + const Tensor* tensor; + ProgramInputTensorDependency dependency; +}; + +constexpr SafeInt WORKGROUP_SIZE = 64; + +// represents the scope of a variable in a shader program. +// +// this is not a full list of all possible variable scopes in shader programs. +// it only includes what are used in WebGPU EP. +enum class ProgramVariableScope { + Input = 0, // storage buffer variable with access mode "read" + Output = 1, // storage buffer variable with access mode "read_write" + Local = 2, // local variable + + Count // should always be the last element +}; + +// data type of variable +// +// this is not a full list of all possible data types in shader programs. +// it only includes what are used in WebGPU EP. +enum class ProgramVariableDataType { + InvalidType = -1, + Float32, + Vec2Float32, + Vec4Float32, + Float16, + Vec2Float16, + Vec4Float16, + Int32, + Vec2Int32, + Vec4Int32, + Uint32, + Vec2Uint32, + Vec4Uint32, + Int64, + Uint64, + Vec4Bool, +}; + +ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int component = 1); + +namespace detail { +class ProgramWrapper; +} + +struct ProgramMetadata; + +class ProgramBase { + public: + // + // chain-style methods for setting properties + // + + // set the cache hint for the program + template + ProgramBase& CacheHint(CacheHintArgs&&... args) { + cache_hint_ = StringJoin("|", std::forward(args)...); + } + + // set one or more program inputs + ProgramBase& Inputs(std::initializer_list inputs); + // set one or more program outputs + ProgramBase& Outputs(std::initializer_list outputs); + + // set the size of dispatch groups. Y and Z are 1 if not specified. + ProgramBase& DispatchGroupSize(uint32_t x); + // set the size of dispatch groups. Z is 1 if not specified. + ProgramBase& DispatchGroupSize(uint32_t x, uint32_t y); + // set the size of dispatch groups. + ProgramBase& DispatchGroupSize(uint32_t x, uint32_t y, uint32_t z); + + // set the size of a workgroup grid. Y and Z are 1 if not specified. + ProgramBase& WorkgroupSize(uint32_t x); + // set the size of a workgroup grid. Z is 1 if not specified. + ProgramBase& WorkgroupSize(uint32_t x, uint32_t y); + // set the size of a workgroup grid. + ProgramBase& WorkgroupSize(uint32_t x, uint32_t y, uint32_t z); + + // set the uniform variables. + // + // the specified uniform variables should match the uniform definition in the class, + // specified by macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES. + ProgramBase& UniformVariables(std::initializer_list variables); + + // set the overridable constants + // + // the specified overridable constants should match the overridable constant definition in the class, + // specified by macro WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS. + ProgramBase& OverridableConstants(std::initializer_list overridable_constants); + + // + // shader code generation + // + + virtual Status GenerateShaderCode(ShaderHelper& shader) const = 0; + + // + // abstract methods for getting metadata + // + // A derived class may contain any of the following static members: + // + // \code{.cpp} + // // define a list of constant that used in the shader program + // static constexpr const ProgramConstant constants[] = { ... }; + // + // // define a list of overridable constant that used in the shader program + // static constexpr const ProgramOverridableConstantDefinition overridable_constants[] = { ... }; + // + // // define a list of uniform variables that used in the shader program + // static constexpr const ProgramUniformVariableDefinition uniform_variables[] = { ... }; + // \endcode + // + // If those static members exist, the value of them will be used to generate the metadata. + virtual ProgramMetadata GetMetadata() const = 0; + + // + // Properties Getters + // + + inline const std::string& Name() const { return name_; } + inline const std::string& CacheHint() const { return cache_hint_; } + inline const std::vector& Inputs() const { return inputs_; } + inline const std::vector& Outputs() const { return outputs_; } + inline uint32_t DispatchGroupSizeX() const { return dispatch_group_size_x_; } + inline uint32_t DispatchGroupSizeY() const { return dispatch_group_size_y_; } + inline uint32_t DispatchGroupSizeZ() const { return dispatch_group_size_z_; } + inline uint32_t WorkgroupSizeX() const { return workgroup_size_x_; } + inline uint32_t WorkgroupSizeY() const { return workgroup_size_y_; } + inline uint32_t WorkgroupSizeZ() const { return workgroup_size_z_; } + inline const std::vector& UniformVariables() const { return variables_; } + inline const std::vector& OverridableConstants() const { return overridable_constants_; } + + protected: + virtual ~ProgramBase() = default; + + private: + // Make the constructor private to prevent direct instantiation or inheritance from this class + // Use the Program template class as base class to create a new program class + explicit ProgramBase(const std::string& name); + + std::string name_; + std::string cache_hint_; + std::vector inputs_; + std::vector outputs_; + + uint32_t dispatch_group_size_x_; + uint32_t dispatch_group_size_y_; + uint32_t dispatch_group_size_z_; + + uint32_t workgroup_size_x_; + uint32_t workgroup_size_y_; + uint32_t workgroup_size_z_; + + std::vector variables_; + std::vector overridable_constants_; + + friend class detail::ProgramWrapper; +}; + +namespace detail { +// class ProgramWrapper is for accessing private constructor of ProgramBase. +// only ProgramWrapper can access the constructor of ProgramBase because ProgramWrapper is the only friend class of +// ProgramBase. This design is used to prevent direct instantiation or inheritance from ProgramBase. +class ProgramWrapper : public ProgramBase { + protected: + template + ProgramWrapper(Args&&... args) : ProgramBase{std::forward(args)...} {} +}; + +#if defined(ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK) +#error "macro ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK is already defined" +#endif + +#define ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(identifier, element_type) \ + private: \ + template \ + static auto test_has_##identifier(int)->decltype(U::identifier, std::true_type{}); /* checks if member exists */ \ + template \ + static auto test_has_##identifier(...)->std::false_type; \ + \ + template && /* - is array */ \ + std::is_const_v && /* - has "const" modifier */ \ + std::is_convertible_v && /* - can convert to a const pointer */ \ + !std::is_member_pointer_v>> /* - is static */ \ + static auto test_has_##identifier##_with_correct_type(int)->std::true_type; \ + template \ + static auto test_has_##identifier##_with_correct_type(...)->std::false_type; \ + \ + public: \ + static constexpr bool has_##identifier = decltype(test_has_##identifier(0))::value; \ + static constexpr bool has_##identifier##_with_correct_type = decltype(test_has_##identifier##_with_correct_type(0))::value + +// the following template class checks whether certain static members exist in the derived class (SFINAE) +template +class DerivedProgramClassTypeCheck { + ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(constants, ProgramConstant); + ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(overridable_constants, ProgramOverridableConstantDefinition); + ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(uniform_variables, ProgramUniformVariableDefinition); +}; + +// compile-time tests for the type check +namespace test { + +struct TestClass_Empty {}; +struct TestClass_0 { + int b; +}; +struct TestClass_1 { + int a; +}; +struct TestClass_2 { + const int a; +}; +struct TestClass_3 { + const int a[2]; +}; +struct TestClass_4 { + static constexpr int a[] = {0}; +}; +struct TestClass_5 { + static int a[]; +}; +struct TestClass_6 { + static const int a[]; +}; + +template +class TestTypeCheck { + ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(a, int); +}; + +static_assert(!TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); +static_assert(!TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); +static_assert(TestTypeCheck::has_a); +static_assert(TestTypeCheck::has_a_with_correct_type); +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); +static_assert(TestTypeCheck::has_a); +static_assert(TestTypeCheck::has_a_with_correct_type); + +} // namespace test + +#undef ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK + +} // namespace detail + +struct ProgramMetadata { + gsl::span constants; + gsl::span overridable_constants; + gsl::span uniform_variables; +}; + +template +class Program : public detail::ProgramWrapper { + public: + template + Program(Args&&... args) : detail::ProgramWrapper{std::forward(args)...} {} + + virtual ProgramMetadata GetMetadata() const final { + ProgramMetadata metadata; + if constexpr (detail::DerivedProgramClassTypeCheck::has_constants) { + constexpr const ProgramConstant* ptr = T::constants; + constexpr size_t len = sizeof(T::constants) / sizeof(ProgramConstant); + + static_assert(detail::DerivedProgramClassTypeCheck::has_constants_with_correct_type && + sizeof(T::constants) % sizeof(ProgramConstant) == 0, + "Derived class of \"Program\" has member \"constants\" but its type is incorrect. " + "Please use macro WEBGPU_PROGRAM_DEFINE_CONSTANTS() to declare constants."); + + metadata.constants = {ptr, len}; + } else { + metadata.constants = {}; + } + + if constexpr (detail::DerivedProgramClassTypeCheck::has_overridable_constants) { + constexpr const ProgramOverridableConstantDefinition* ptr = T::overridable_constants; + constexpr size_t len = sizeof(T::overridable_constants) / sizeof(ProgramOverridableConstantDefinition); + + static_assert(detail::DerivedProgramClassTypeCheck::has_overridable_constants_with_correct_type && + sizeof(T::overridable_constants) % sizeof(ProgramOverridableConstantDefinition) == 0, + "Derived class of \"Program\" has member \"overridable_constants\" but its type is incorrect. " + "Please use macro WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS() to declare overridable constants."); + + metadata.overridable_constants = {ptr, len}; + } else { + metadata.overridable_constants = {}; + } + + if constexpr (detail::DerivedProgramClassTypeCheck::has_uniform_variables) { + constexpr const ProgramUniformVariableDefinition* ptr = T::uniform_variables; + constexpr size_t len = sizeof(T::uniform_variables) / sizeof(ProgramUniformVariableDefinition); + + static_assert(detail::DerivedProgramClassTypeCheck::has_uniform_variables_with_correct_type && + sizeof(T::uniform_variables) % sizeof(ProgramUniformVariableDefinition) == 0, + "Derived class of \"Program\" has member \"uniform_variables\" but its type is incorrect. " + "Please use macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES() to declare uniform variables."); + + metadata.uniform_variables = {ptr, len}; + } else { + metadata.uniform_variables = {}; + } + + return metadata; + } +}; + +#define WEBGPU_PROGRAM_DEFINE_CONSTANTS(...) \ + static constexpr const onnxruntime::webgpu::ProgramConstant constants[] = {__VA_ARGS__} + +#define WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS(...) \ + static constexpr const onnxruntime::webgpu::ProgramOverridableConstantDefinition overridable_constants[] = {__VA_ARGS__} + +#define WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(...) \ + static constexpr const onnxruntime::webgpu::ProgramUniformVariableDefinition uniform_variables[] = {__VA_ARGS__} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program_cache_key.cc b/onnxruntime/core/providers/webgpu/program_cache_key.cc new file mode 100644 index 0000000000000..d720c55fb5427 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program_cache_key.cc @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/program_cache_key.h" + +#include "core/providers/webgpu/shader_macros.h" + +namespace onnxruntime { +namespace webgpu { + +std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_dispatch) { + std::ostringstream ss; + ss.imbue(std::locale::classic()); + + // final key format: + // =[]:::: + // + // = ||... + // = ,, + // = + // = ||... + // = + // = ||... + // = ; + ss << program.Name(); + + // append custom cache hint if any + if (auto& hint = program.CacheHint(); !hint.empty()) { + ss << "[" D("CacheHint=") << hint << "]"; + } + + // append workgroup size if overridden + if (auto x = program.WorkgroupSizeX(), y = program.WorkgroupSizeY(), z = program.WorkgroupSizeZ(); + x != 0 || y != 0 || z != 0) { + ss << ":" D("WorkgroupSize="); + // only append non-zero values. zero values are considered as use default + if (x > 0) { + ss << x; + } + ss << ","; + if (y > 0) { + ss << y; + } + ss << ","; + if (z > 0) { + ss << z; + } + } + + ss << ":" D("DispatchDim=") << is_1d_dispatch ? "1" : "3"; + ss << ":" D("UniformSizes="); + bool first = true; + for (const auto& uniform : program.UniformVariables()) { + if (first) { + first = false; + } else { + ss << "|"; + } + if (uniform.length > 0) { + ss << uniform.length; + } + } + ss << ":" D("Inputs="); + first = true; + for (const auto& input : program.Inputs()) { + if (first) { + first = false; + } else { + ss << "|"; + } + if ((input.dependency & ProgramInputTensorDependency::Type) == ProgramInputTensorDependency::Type) { +#ifndef NDEBUG // if debug build + ss << DataTypeImpl::ToString(input.tensor->DataType()); +#else + ss << input.tensor->GetElementType(); +#endif + } + ss << ";"; + if ((input.dependency & ProgramInputTensorDependency::Rank) == ProgramInputTensorDependency::Rank) { + ss D("Rank=") << input.tensor->Shape().NumDimensions(); + } else if ((input.dependency & ProgramInputTensorDependency::Shape) == ProgramInputTensorDependency::Shape) { + ss D("Dims=") << input.tensor->Shape().ToString(); + } + } + + return ss.str(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program_cache_key.h b/onnxruntime/core/providers/webgpu/program_cache_key.h new file mode 100644 index 0000000000000..22ba19ebd0f25 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program_cache_key.h @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_dispatch); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program_manager.cc b/onnxruntime/core/providers/webgpu/program_manager.cc new file mode 100644 index 0000000000000..de228a038b7db --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program_manager.cc @@ -0,0 +1,188 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/common.h" +#include "core/common/safeint.h" + +#include "core/common/common.h" +#include "core/common/logging/logging.h" + +#include "core/providers/webgpu/program_manager.h" +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace webgpu { + +ProgramArtifact::ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline) + : name{program.Name()}, compute_pipeline{compute_pipeline} { + // prepare uniform info + size_t current_offset = 0; + for (const auto& uniform : program.UniformVariables()) { + bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16; + size_t length = uniform.length; + size_t element_size = ProgramUniformVariableDataTypeSize[static_cast(uniform.data_type)]; + // https://www.w3.org/TR/WGSL/#alignof + size_t base_alignment = is_f16 + ? (length > 4 ? 16 : length > 2 ? 8 + : length * element_size) + : (length > 2 ? 16 : length * element_size); + size_t struct_size = is_f16 && length <= 4 ? length * element_size : 16; + + current_offset = (current_offset + base_alignment - 1) / base_alignment * base_alignment; + uniforms.push_back({uniform.data_type, current_offset, length}); + + // For non-float16 type, when length > 4, the uniform variable is of type array,N>, where + // N = ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * SizeOf(vec4). + // For float16 type, when length > 4, the uniform variable is of type array,N>, where + // N = ceil(data.length / 8) and SizeOf(mat2x4) = 16. The total byte length is N * SizeOf(mat2x4). + size_t element_per_struct = is_f16 ? 8 : 4; + current_offset += + length > 4 ? (length + element_per_struct - 1) / element_per_struct * struct_size : length * element_size; + } + + // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set + // max_alignment_of_field to 16 since the underlying buffer has been rounded up to 16. + const int max_alignment_of_field = 16; + uniform_total_size = (current_offset + max_alignment_of_field - 1) / max_alignment_of_field * max_alignment_of_field; +} + +Status ProgramManager::NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint32_t& z) const { + ORT_RETURN_IF(x == 0 || y == 0 || z == 0, "Invalid dispatch group size (", x, ", ", y, ", ", z, ")"); + + auto limit_per_dimension = limits_.maxComputeWorkgroupsPerDimension; + if (x > limit_per_dimension || y > limit_per_dimension || z > limit_per_dimension) { + auto size = static_cast(x) * static_cast(y) * static_cast(z); + SafeInt dispatch_avg = std::ceil(std::sqrt(size)); + if (dispatch_avg > limit_per_dimension) { + dispatch_avg = std::ceil(std::cbrt(size)); + ORT_RETURN_IF(dispatch_avg > limit_per_dimension, "The dispatch group size exceeds WebGPU maximum."); + x = y = z = dispatch_avg; + } else { + x = y = dispatch_avg; + z = 1; + } + } + return Status::OK(); +} + +Status ProgramManager::Build(const ProgramBase& program, + const ProgramMetadata& program_metadata, +#ifndef NDEBUG // if debug build + const std::string& program_key, +#endif + uint32_t normalized_dispatch_x, + uint32_t normalized_dispatch_y, + uint32_t normalized_dispatch_z, + wgpu::ComputePipeline& compute_pipeline) const { + ShaderHelper shader_helper{program, + program_metadata, + device_, + limits_, + normalized_dispatch_x, + normalized_dispatch_y, + normalized_dispatch_z}; + ORT_RETURN_IF_ERROR(shader_helper.Init()); + + ORT_RETURN_IF_ERROR(program.GenerateShaderCode(shader_helper)); + + // code is a large std::string that contains the final shader code + auto code = shader_helper.GetFinalSourceCode(); + + LOGS_DEFAULT(VERBOSE) << "\n=== WebGPU Shader code [" << program.Name() +#ifndef NDEBUG // if debug build + << ", Key=\"" << program_key << "\"" +#endif + << "] Start ===\n\n" + << code + << "\n=== WebGPU Shader code [" << program.Name() +#ifndef NDEBUG // if debug build + << ", Key=\"" << program_key << "\"" +#endif + << "] End ===\n"; + + wgpu::ShaderModuleWGSLDescriptor wgsl_descriptor{}; + wgsl_descriptor.code = code.c_str(); + + wgpu::ShaderModuleDescriptor descriptor{}; + descriptor.nextInChain = &wgsl_descriptor; + + auto shader_module = device_.CreateShaderModule(&descriptor); + + // process overridable constants if available + size_t constant_count = program.OverridableConstants().size(); + + // making a copy of the constant names is required because they are stored as std::string_view in the program + // metadata. A value of std::string_view is not guaranteed to be a C-stlye string (null-terminated) and hence + // cannot be used directly in the WebGPU API (which expects a const char*). + std::vector constant_names; + constant_names.reserve(constant_count); + std::vector constant_entries; + constant_entries.reserve(constant_count); + for (size_t i = 0; i < constant_count; ++i) { + const auto& constant_override = program.OverridableConstants()[i]; + const auto& constant_def = program_metadata.overridable_constants[i]; + + if (constant_override.has_value) { + double value = 0; + switch (constant_override.type) { + case ProgramConstantDataType::Bool: + value = constant_override.boolean ? 1 : 0; + break; + case ProgramConstantDataType::Float16: + // convert f16(MLFloat16) -> f32(float) -> f64(double) + // because the value of a constant must be a double in WebGPU API, it is expensive to use f16 overridable constants. + value = constant_override.f16.ToFloat(); + break; + case ProgramConstantDataType::Float32: + value = constant_override.f32; + break; + case ProgramConstantDataType::Int32: + value = constant_override.i32; + break; + case ProgramConstantDataType::Uint32: + value = constant_override.u32; + break; + } + + const auto& name_string = constant_names.emplace_back(constant_def.name); + wgpu::ConstantEntry entry{}; + entry.key = name_string.c_str(); + entry.value = value; + constant_entries.push_back(std::move(entry)); + } + } + + wgpu::ProgrammableStageDescriptor compute_stage{}; + compute_stage.module = shader_module; + compute_stage.entryPoint = "main"; + if (!constant_entries.empty()) { + compute_stage.constants = constant_entries.data(); + compute_stage.constantCount = constant_entries.size(); + } + + wgpu::ComputePipelineDescriptor pipeline_descriptor{}; + pipeline_descriptor.compute = compute_stage; +#ifndef NDEBUG // if debug build + pipeline_descriptor.label = program.Name().c_str(); +#endif + + compute_pipeline = device_.CreateComputePipeline(&pipeline_descriptor); + + return Status(); +} + +const ProgramArtifact* ProgramManager::Get(const std::string& key) const { + auto result = programs_.find(key); + if (result != programs_.end()) { + return &result->second; + } + + return nullptr; +} + +const ProgramArtifact* ProgramManager::Set(const std::string& key, ProgramArtifact&& program) { + return &(programs_.emplace(key, std::move(program)).first->second); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program_manager.h b/onnxruntime/core/providers/webgpu/program_manager.h new file mode 100644 index 0000000000000..9d1b7655c8640 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/program_manager.h @@ -0,0 +1,71 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include + +#include + +#include "core/common/common.h" + +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +class Tensor; + +namespace webgpu { + +struct ProgramUniformInfo { + ProgramUniformVariableDataType data_type; + size_t offset; + size_t length; +}; + +class ProgramArtifact { + public: + ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline); + + std::string name; + wgpu::ComputePipeline compute_pipeline; + std::vector uniforms; + size_t uniform_total_size; + + ProgramArtifact(ProgramArtifact&&) = default; + ProgramArtifact& operator=(ProgramArtifact&&) = default; + + private: + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ProgramArtifact); +}; + +class ProgramManager { + public: + ProgramManager(const wgpu::Device& device, const wgpu::Limits& limits) : device_(device), limits_(limits) {} + + Status NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint32_t& z) const; + + Status Build(const ProgramBase& program, + const ProgramMetadata& metadata, +#ifndef NDEBUG // if debug build + const std::string& program_key, +#endif + uint32_t normalized_dispatch_x, + uint32_t normalized_dispatch_y, + uint32_t normalized_dispatch_z, + wgpu::ComputePipeline& compute_pipeline) const; + const ProgramArtifact* Get(const std::string& key) const; + const ProgramArtifact* Set(const std::string& key, ProgramArtifact&& program); + + private: + std::unordered_map programs_; + const wgpu::Device& device_; + const wgpu::Limits& limits_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc new file mode 100644 index 0000000000000..203f11ff90000 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -0,0 +1,204 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include + +#include "core/session/onnxruntime_c_api.h" + +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace webgpu { + +ShaderHelper::ShaderHelper(const ProgramBase& program, + const ProgramMetadata& program_metadata, + const wgpu::Device& device, + const wgpu::Limits& limits, + uint32_t dispatch_group_size_x, + uint32_t dispatch_group_size_y, + uint32_t dispatch_group_size_z) + : device_{device}, + limits_{limits}, + dispatch_group_size_x_{dispatch_group_size_x}, + dispatch_group_size_y_{dispatch_group_size_y}, + dispatch_group_size_z_{dispatch_group_size_z}, + program_{program}, + program_metadata_{program_metadata}, + use_f16_{false} { +} + +Status ShaderHelper::Init() { + // dispatch group size is normalized so no need to validate it here + + // validate workgroup size + auto workgroup_size_x = program_.WorkgroupSizeX(); + auto workgroup_size_y = program_.WorkgroupSizeY(); + auto workgroup_size_z = program_.WorkgroupSizeZ(); + + ORT_RETURN_IF_NOT(workgroup_size_x > 0 && workgroup_size_y > 0 && workgroup_size_z > 0, + "Workgroup size must be greater than 0"); + ORT_RETURN_IF_NOT(workgroup_size_x <= limits_.maxComputeWorkgroupSizeX && + workgroup_size_y <= limits_.maxComputeWorkgroupSizeY && + workgroup_size_z <= limits_.maxComputeWorkgroupSizeZ, + "Workgroup size exceeds the maximum allowed size [", + limits_.maxComputeWorkgroupSizeX, ", ", + limits_.maxComputeWorkgroupSizeY, ", ", + limits_.maxComputeWorkgroupSizeZ, "]"); + + ORT_RETURN_IF_NOT(workgroup_size_x * workgroup_size_y * workgroup_size_z <= limits_.maxComputeInvocationsPerWorkgroup, + "Workgroup size exceeds the maximum allowed invocations ", limits_.maxComputeInvocationsPerWorkgroup); + + // init body string stream + bool is_1d_dispatch = dispatch_group_size_y_ == 1 && dispatch_group_size_z_ == 1; + body_.imbue(std::locale::classic()); + + // append header for main function so it is ready for user to append main function body + body_ << "@compute @workgroup_size(workgroup_size_x, workgroup_size_y, workgroup_size_z)\n" + "fn main(@builtin(global_invocation_id) global_id : vec3,\n" + " @builtin(workgroup_id) workgroup_id : vec3,\n" + " @builtin(local_invocation_id) local_id : vec3"; + if (!is_1d_dispatch) { + body_ << ",\n" + " @builtin(local_invocation_index) local_idx : u32,\n" + " @builtin(num_workgroups) num_workgroups : vec3"; + } + body_ << ") {\n"; + if (is_1d_dispatch) { + body_ << " let global_idx = global_id.x;\n" + " let local_idx = local_id.x;\n"; + } else { + body_ << " let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] + workgroup_id.y * num_workgroups[0] + workgroup_id.x)\n" + " * (workgroup_size_x * workgroup_size_y * workgroup_size_z) + local_idx;\n"; + } + + // init additional implementation string stream + additional_implementation_.imbue(std::locale::classic()); + + return Status::OK(); +} + +std::string ShaderHelper::GetFinalSourceCode() { + std::ostringstream ss; + ss.imbue(std::locale::classic()); + + // + // Section feature enabling + // + if (use_f16_) { + ORT_ENFORCE(device_.HasFeature(wgpu::FeatureName::ShaderF16), "Program ", program_.Name(), " requires f16 but the device does not support it."); + ss << "enable f16;\n"; + } + + // + // Section constants + // + ss << "\nconst workgroup_size_x: u32 = " << program_.WorkgroupSizeX() + << ";\nconst workgroup_size_y: u32 = " << program_.WorkgroupSizeY() + << ";\nconst workgroup_size_z: u32 = " << program_.WorkgroupSizeZ() << ";\n"; + + for (const auto& constant : program_metadata_.constants) { + ss << "const " << constant.name << ": " << constant.type << " = "; + WriteConstantValue(ss, constant); + ss << ";\n"; + } + + size_t override_constant_count = program_metadata_.overridable_constants.size(); + for (size_t i = 0; i < override_constant_count; ++i) { + // size and type are previously checked to match + const auto& constant_def = program_metadata_.overridable_constants[i]; + const auto& constant_override = program_.OverridableConstants()[i]; + + ss << "override " << constant_def.name << ": " << constant_def.type << " = "; + if (constant_override.has_value) { + WriteConstantValue(ss, constant_override); + } else { + WriteConstantValue(ss, constant_def); + } + ss << ";\n"; + } + + // + // Input/output variables + // + int variable_count = 0; + for (const auto& input : vars_[static_cast(ProgramVariableScope::Input)]) { + ss << "@group(0) @binding(" << variable_count++ << ") var " << input.name_ << ": array<" << input.StorageType() << ">;\n"; + } + for (const auto& output : vars_[static_cast(ProgramVariableScope::Output)]) { + ss << "@group(0) @binding(" << variable_count++ << ") var " << output.name_ << ": array<" << output.StorageType() << ">;\n"; + } + + // + // uniform variables + // + if (std::any_of(program_.UniformVariables().cbegin(), + program_.UniformVariables().cend(), + [](const ProgramUniformVariableValue& x) { return x.length > 0; })) { + bool first = true; + ss << "struct Uniforms {\n"; + + size_t uniform_count = program_.UniformVariables().size(); + for (size_t i = 0; i < uniform_count; i++) { + const auto& uniform_def = program_metadata_.uniform_variables[i]; + const auto& uniform_value = program_.UniformVariables()[i]; + + const auto& name = uniform_def.name; + const auto& data_type = uniform_def.data_type; + const auto length = uniform_value.length; + + if (first) { + first = false; + } else { + ss << ",\n"; + } + + auto alignment = (data_type == ProgramUniformVariableDataType::Float16 && length > 4) ? "@align(16) " : ""; + ss << " " << alignment << name << ": "; + if (length > 4) { + if (data_type == ProgramUniformVariableDataType::Float16) { + size_t array_size = (length + 7) / 8; + ss << "array, " << array_size << ">"; + } else { + size_t array_size = (length + 3) / 4; + ss << "array, " << array_size << ">"; + } + } else if (length > 1) { + ss << "vec" << length << "<" << data_type << ">"; + } else { + ss << data_type; + } + } + + ss << "};\n" + "@group(0) @binding(" + << variable_count << ") var uniforms: Uniforms;\n"; + } + + // + // Indices helper + // + ss << "\n"; + // for (const auto& group : vars_) { + // } + + // + // Additional Implementation + // + ss << additional_implementation_.str(); + additional_implementation_.str(""); + + // + // Main Function Body + // + ss << body_.str(); + body_.str(""); + ss << "\n" + "}\n"; + + return ss.str(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_helper.h b/onnxruntime/core/providers/webgpu/shader_helper.h new file mode 100644 index 0000000000000..ac6dfebfef816 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/shader_helper.h @@ -0,0 +1,161 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "core/common/safeint.h" +#include "core/framework/tensor_shape.h" + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_variable.h" + +namespace onnxruntime { +namespace webgpu { + +class ShaderHelper final { + // The content of a shader code is composed of the following parts: + // + // ** + // ** section: feature sets definition + // ** + // // this sections enable features like "enable f16;". need to be defined at the beginning of the shader. + // + // ** + // ** section: constants and overridable constants + // ** + // // this section defines constants and overridable constants. + // - constants are defined as "const a:f32 = 1.0;". It's hard coded in the shader. + // - overridable constants are defined as "override a:f32 = 1.0;" (may override or not) + // or "override b:u32;" (must override) + // the value can be overriden by pipeline creation config. + // + // ** + // ** section: inputs and outputs + // ** + // // this section defines input and output variables. + // user can call shader_helper.AddVariable() to add input and output variables. + // + // ** + // ** section: uniforms + // ** + // // this section defines uniform type and variables. + // + // ** + // ** section: indices helper generated utility functions + // ** + // // this section defines utility functions to calculate indices. + // + // ** + // ** section: additional implementation + // ** + // // this section contains additional implementation provided by the user. + // user can call shader_helper.AppendImplementation() to append additional implementation. + // + // ** + // ** section: main function + // ** + // // this section contains the main function of the shader. + // user can call shader_helper.MainFunctionBody() to set the main function body. + // + + public: + ShaderHelper(const ProgramBase& program, + const ProgramMetadata& program_metadata, + const wgpu::Device& device, + const wgpu::Limits& limits, + uint32_t dispatch_group_size_x, + uint32_t dispatch_group_size_y, + uint32_t dispatch_group_size_z); + + Status Init(); + + const ShaderVariable& AddVariable(ProgramVariableScope scope, const std::string& name, ProgramVariableDataType type, int rank = 1) { + return AddVariableImpl(scope, name, type, rank); + } + const ShaderVariable& AddVariable(ProgramVariableScope scope, const std::string& name, ProgramVariableDataType type, const TensorShape& dims) { + return AddVariableImpl(scope, name, type, dims); + } + + template + inline std::ostringstream& AppendImplementation(Strs&&... impl) { + onnxruntime::detail::MakeStringImpl(additional_implementation_, std::forward(impl)...); + return additional_implementation_; + } + + template + inline std::ostringstream& MainFunctionBody(Strs&&... body) { + onnxruntime::detail::MakeStringImpl(body_, std::forward(body)...); + return body_; + } + + std::string GuardAgainstOutOfBoundsWorkgroupSizes(const std::string& size) const { + return " if (global_idx >= " + size + ") { return; }\n"; + } + + private: + template // T is one of {int, const TensorShape&} + const ShaderVariable& AddVariableImpl(ProgramVariableScope scope, const std::string& name, ProgramVariableDataType type, T&& arg) { + ORT_ENFORCE((scope == ProgramVariableScope::Input || scope == ProgramVariableScope::Output) && + vars_[static_cast(ProgramVariableScope::Input)].size() + vars_[static_cast(ProgramVariableScope::Output)].size() < limits_.maxStorageBuffersPerShaderStage, + "Too many storage buffers in shader. Max is ", limits_.maxStorageBuffersPerShaderStage); + + if (type == ProgramVariableDataType::Float16 || type == ProgramVariableDataType::Vec2Float16 || type == ProgramVariableDataType::Vec4Float16) { + use_f16_ = true; + } + + return vars_[static_cast(scope)].emplace_back(name, type, std::forward(arg)); + } + + template // ConstantType is one of {ProgramConstant, ProgramOverridableConstantValue, ProgramOverridableConstantDefinition} + void WriteConstantValue(std::ostringstream& ss, const ConstantType& constant) const { + switch (constant.type) { + case ProgramConstantDataType::Float16: + ss << constant.f16.ToFloat(); + break; + case ProgramConstantDataType::Float32: + ss << constant.f32; + break; + case ProgramConstantDataType::Int32: + ss << constant.i32; + break; + case ProgramConstantDataType::Uint32: + ss << constant.u32; + break; + case ProgramConstantDataType::Bool: + ss << (constant.boolean ? "true" : "false"); + break; + default: + ORT_THROW("Invalid constant type", constant.type); + } + } + + std::string GetFinalSourceCode(); + friend class ProgramManager; + + const wgpu::Device& device_; + const wgpu::Limits& limits_; + uint32_t dispatch_group_size_x_; + uint32_t dispatch_group_size_y_; + uint32_t dispatch_group_size_z_; + + const ProgramBase& program_; + const ProgramMetadata& program_metadata_; + + std::array, static_cast(ProgramVariableScope::Count)> vars_; + std::ostringstream ss2; + std::ostringstream additional_implementation_; + std::ostringstream body_; + + bool use_f16_ = false; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_macros.h b/onnxruntime/core/providers/webgpu/shader_macros.h new file mode 100644 index 0000000000000..a1c61950e6a10 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/shader_macros.h @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +// macro "D": append to the ostream only in debug build +// +// Usage example: +// +// ss << "error code: " << err_code D(" (") << D(err_msg) D(")"); +// +// This resolves to: (debug build) +// ss << "error code: " << err_code << " (" << err_msg << ")"; +// +// This resolves to: (release build) +// ss << "error code: " << err_code; + +#ifdef D +#undef D +#endif + +#ifndef NDEBUG // if debug build +#define D(str) << str +#else +#define D(str) +#endif + +// macro "DSS" append to the ostream only in debug build +// (assume variable "ss" is in scope) +// +// Usage example: +// +// DSS << "detail error message: " << err_msg; +// +// This resolves to: (debug build) +// ss << "detail error message: " << err_msg; +// +// This resolves to: (release build) +// if constexpr (false) ss << "detail error message: " << err_msg; // no-op + +#ifdef DSS +#undef DSS +#endif + +#ifndef NDEBUG // if debug build +#define DSS ss +#else +#define DSS \ + if constexpr (false) ss +#endif + +// macro "SS" - use function call style to append to the ostream +// (assume variable "ss" is in scope) +// +// Usage example: +// +// SS("error code: ", err_code, " (", err_msg, ")"); +// +// This resolves to: +// ss << "error code: " << err_code << " (" << err_msg << ")"; + +#ifdef SS +#undef SS +#endif + +#define SS(...) ::onnxruntime::detail::MakeStringImpl(ss, __VA_ARGS__) diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc new file mode 100644 index 0000000000000..d49d76c1ee858 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -0,0 +1,277 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include + +#include "core/providers/webgpu/shader_variable.h" + +#include "core/providers/webgpu/shader_macros.h" + +namespace onnxruntime { +namespace webgpu { + +ShaderVariable::ShaderVariable(const std::string& name, ProgramVariableDataType type, int rank) + : name_(name), type_(type), rank_(rank), usage_(UseUniform) { + Init(); +} + +ShaderVariable::ShaderVariable(const std::string& name, ProgramVariableDataType type, const TensorShape& dims) + : name_(name), type_(type), rank_(static_cast(dims.NumDimensions())), dims_(dims), usage_(None) { + Init(); +} + +void ShaderVariable::Init() { + ORT_ENFORCE(type_ != ProgramVariableDataType::InvalidType, "Invalid type for variable ", name_); +} + +void ShaderVariable::Impl(std::ostringstream& ss) { + // Start generating code + + const std::string value_t = name_ + "_value_t"; + const std::string indices_t = name_ + "_indices_t"; + + const std::string shape = (usage_ & UseUniform) ? "uniforms." + name_ + "_shape" : name_ + "_shape"; + const std::string stride = (usage_ & UseUniform) ? "uniforms." + name_ + "_stride" : name_ + "_stride"; + + // Types + SS("alias ", value_t, " = ", ValueType(), ";\n"); + SS("alias ", indices_t, " = ", IndicesType(), ";\n"); + + // Need shape and strides when (not use uniform) and (any other usage is enabled) + if (!(usage_ & UseUniform) && (usage_ & ~UseUniform)) { + SS("const ", shape, " = ", indices_t, "("); + + bool first = true; + for (auto dim : dims_.GetDims()) { + if (!first) { + ss << ","; + } + + ss << dim; + first = false; + } + ss << ");\n"; + + SS("const ", stride, " = ", indices_t, "("); + first = true; + for (int i = rank_ - 1; i >= 0; i--) { + if (!first) { + ss << ","; + } + ss << dims_.SizeToDimension(i); + first = false; + } + ss << ");\n"; + } + + // Implementation of "fn o2i_{name}" + if (usage_ & UseOffsetToIndices) { + if (rank_ >= 2) { + SS("fn o2i_", name_, "(offset : u32)->", indices_t, " {\n"); + SS(" var indices: ", indices_t, ";\n"); + SS(" var current = offset;\n"); + for (size_t i = 0; i < rank_ - 1; i++) { + auto current_stride = GetElementAt(stride, i, rank_); + SS(" let dim", i, " = current / ", current_stride, ";\n"); + SS(" let rest", i, " = current % ", current_stride, ";\n"); + SS(" indices[", i, "] = dim", i, ";\n"); + SS(" current = rest", i, ";\n"); + } + SS(" indices[", rank_ - 1, "] = current;\n"); + SS(" return indices;\n"); + SS("}\n"); + } + } + + // Implementation of "fn i2o_{name}" + if (usage_ & UseIndicesToOffset) { + if (rank_ >= 2) { + SS("fn i2o_", name_, "(indices : ", indices_t, ")->u32 {\n"); + SS(" return "); + for (size_t i = 0; i < rank_ - 1; i++) { + SS("indices[", i, "] * ", GetElementAt(stride, i, rank_), " + "); + } + SS("indices[", rank_ - 1, "];\n"); + SS("}\n"); + } + } + + // Implementation of "fn {res_name}_bi2o_{name}" + if (usage_ & UseBroadcastedIndicesToOffset) { + // TODO: do we need this if rank < 2? + for (const auto& iter : broadcasted_to_) { + const auto& broadcasted_result = iter.get(); + SS("fn ", broadcasted_result.name_, "_bi2o_", name_, "(indices : ", broadcasted_result.IndicesType(), ")->u32 {\n"); + if (rank_ == 0) { + SS(" return 0;\n"); + } else { + SS(" return "); + for (size_t i = 0; i < rank_ - 1; i++) { + auto idx = broadcasted_result.IndicesGet("indices", i + broadcasted_result.rank_ - rank_); + SS(IndicesGet(stride, i), " * (", idx, " % ", IndicesGet(shape, i), ") + "); + } + SS(broadcasted_result.IndicesGet("indices", broadcasted_result.rank_ - 1), " % ", IndicesGet(shape, rank_ - 1), ";\n"); + } + SS("}\n"); + } + } + + // Implementation of "fn set_{name}" + if (usage_ & UseSet) { + if (rank_ >= 2) { + SS("fn set_", name_, "(d0: u32"); + for (size_t i = 1; i < rank_; i++) { + SS(", d", i, ": u32"); + } + SS(", value: ", value_t, ") {\n"); + SS(" set_", name_, "_by_indices(d0"); + for (size_t i = 1; i < rank_; i++) { + SS(", d", i); + } + SS(", value);\n"); + SS("}\n"); + } + } + + // Implementation of "fn set_{name}_by_indices" + if (usage_ & UseSetByIndices) { + if (rank_ >= 2) { + SS("fn set_", name_, "_by_indices(indices: ", indices_t, ", value: ", value_t, ") {\n"); + SS(" ", SetByOffset("i2o_" + name_ + "(indices)", "value"), "\n"); + SS("}\n"); + } + } + + // Implementation of "fn get_{name}" + if (usage_ & UseGet) { + if (rank_ >= 2) { + SS("fn get_", name_, "(d0: u32"); + for (size_t i = 1; i < rank_; i++) { + SS(", d", i, ": u32"); + } + SS(")->", value_t, " {\n"); + SS(" return get_", name_, "_by_indices(d0"); + for (size_t i = 1; i < rank_; i++) { + SS(", d", i); + } + SS(");\n"); + SS("}\n"); + } + } + + // Implementation of "fn get_{name}_by_indices" + if (usage_ & UseGetByIndices) { + if (rank_ >= 2) { + SS("fn get_", name_, "_by_indices(indices: ", indices_t, ")->", value_t, " {\n"); + SS(" return ", GetByOffset("i2o_" + name_ + "(indices)"), ";\n"); + SS("}\n"); + } + } +} + +std::string ShaderVariable::GetByOffsetImpl(const std::string& offset) const { + std::ostringstream ss; + ss.imbue(std::locale::classic()); + + switch (type_) { + case onnxruntime::webgpu::ProgramVariableDataType::InvalidType: + ORT_THROW("Invalid type"); + break; + case onnxruntime::webgpu::ProgramVariableDataType::Int64: + ss << "i32(" << name_ << "[" << offset << "].x)"; + break; + case onnxruntime::webgpu::ProgramVariableDataType::Uint64: + ss << "u32(" << name_ << "[" << offset << "].x)"; + break; + case onnxruntime::webgpu::ProgramVariableDataType::Vec4Bool: + ss << "vec4(bool(" + << name_ << "[" << offset << "] & 0xFFu), bool(" + << name_ << "[" << offset << "] & 0xFF00u), bool(" + << name_ << "[" << offset << "] & 0xFF0000u), bool(" + << name_ << "[" << offset << "] & 0xFF000000u))"; + break; + default: + ss << name_ << "[" << offset << "]"; + } + + return ss.str(); +} + +std::string ShaderVariable::SetByOffsetImpl(const std::string& offset, const std::string& value) const { + std::ostringstream ss; + ss.imbue(std::locale::classic()); + + switch (type_) { + case onnxruntime::webgpu::ProgramVariableDataType::InvalidType: + ORT_THROW("Invalid type"); + break; + case onnxruntime::webgpu::ProgramVariableDataType::Int64: + ss << name_ << "[" << offset << "]=vec2(u32(" << value << "), select(0u, 0xFFFFFFFFu, " << value << " < 0));"; + break; + case onnxruntime::webgpu::ProgramVariableDataType::Uint64: + ss << name_ << "[" << offset << "]=vec2(u32(" << value << "), 0u);"; + break; + case onnxruntime::webgpu::ProgramVariableDataType::Vec4Bool: + ss << name_ << "[" << offset << "]=dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(" << value << "));"; + break; + default: + ss << name_ << "[" << offset << "]=" << value << ";"; + } + + return ss.str(); +} + +std::string_view ShaderVariable::StorageType() const { + constexpr static const std::string_view STORAGE_TYPE[] = { + "f32", // f32 + "vec2", // vec2f32 + "vec4", // vec4f32 + "f16", // f16 + "vec2", // vec2f16 + "vec4", // vec4f16 + "i32", // i32 + "vec2", // vec2i32 + "vec4", // vec4i32 + "u32", // u32 + "vec2", // vec2u32 + "vec4", // vec4u32 + "vec2", // int64 + "vec2", // uint64 + "u32", // vec4bool + }; + + return STORAGE_TYPE[static_cast(type_)]; +} + +std::string_view ShaderVariable::ValueType() const { + constexpr static const std::string_view VALUE_TYPE[] = { + "f32", // f32 + "f32", // vec2f32 + "f32", // vec4f32 + "f16", // f16 + "f16", // vec2f16 + "f16", // vec4f16 + "i32", // i32 + "i32", // vec2i32 + "i32", // vec4i32 + "u32", // u32 + "u32", // vec2u32 + "u32", // vec4u32 + "i32", // int64 (trancated to i32) + "u32", // uint64 (trancated to u32) + "vec4", // vec4bool + }; + + return VALUE_TYPE[static_cast(type_)]; +} + +std::string ShaderVariable::IndicesType() const { + return rank_ < 2 ? "u32" + : (rank_ < 4 ? MakeStringWithClassicLocale("vec", rank_, "") + : MakeStringWithClassicLocale("array")); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h new file mode 100644 index 0000000000000..0a5cad8237871 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -0,0 +1,263 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "core/common/safeint.h" +#include "core/framework/tensor_shape.h" + +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +template +std::string GetElementAt(const std::string& var, const TIdx& idx, int rank, bool is_f16 = false) { + // "std::string::rfind(str, 0) == 0" is equivalent to "std::string::starts_with(str)" before C++20. + if (var.rfind("uniform.", 0) == 0) { + if (rank > 4) { + if constexpr (std::is_integral_v) { + if (is_f16) { + return MakeStringWithClassicLocale(var, "[", idx / 8, "][", (idx % 8) / 4, "][", (idx % 8) % 4, "]"); + } else { + return MakeStringWithClassicLocale(var, "[", idx / 4, "][", idx % 4, "]"); + } + } else { + if (is_f16) { + return MakeStringWithClassicLocale(var, "[(", idx, ") / 8][(", idx, ") % 8 / 4][(", idx, ") % 8 % 4]"); + } else { + return MakeStringWithClassicLocale(var, "[(", idx, ") / 4][(", idx, ") % 4]"); + } + } + } else { + return rank > 1 ? MakeStringWithClassicLocale(var, "[", idx, "]") : var; + } + } else { + return rank > 1 ? MakeStringWithClassicLocale(var, "[", idx, "]") : var; + } +} + +class ShaderVariable { + public: + ShaderVariable(const std::string& name, ProgramVariableDataType type, int rank); + ShaderVariable(const std::string& name, ProgramVariableDataType type, const TensorShape& dims); + + ShaderVariable(ShaderVariable&&) = default; + ShaderVariable& operator=(ShaderVariable&&) = default; + + // create a WGSL expression ({varname}_indices_t) for getting indices from offset. + // \param offset: a WGSL expression (u32) representing the offset. + inline std::string OffsetToIndices(const std::string& offset_expr) const; + + // create a WGSL expression (u32) for getting offset from indices. + // \param indices: a WGSL expression ({varname}_indices_t) representing the indices. + inline std::string IndicesToOffset(const std::string& indices_expr) const; + + // create a WGSL expression (u32) for getting original offset from broadcasted indices. + // \param indices: a WGSL expression ({broadcasted_result_varname}_indices_t) representing the broadcasted indices. + // \param broadcasted_result: the broadcasted result variable. + inline std::string BroadcastedIndicesToOffset(const std::string& indices_expr, const ShaderVariable& broadcasted_result) const; + + // create a WGSL expression ({varname}_indices_t) as an indices literal + // \param init: a list of indices values. + template + inline std::string Indices(TIndices&&... indices_args) const; + + // create a WGSL statement for setting value of the specified dimension of the indices. + // \param indices_var: name of the indices variable ({varname}_indices_t). + // \param idx: the index (i32|u32) of the dimension to set. + // \param value: the value (u32) to set. + template + inline std::string IndicesSet(const std::string& indices_var, const TIdx& idx_expr, const TVal& value) const; + + // create a WGSL expression (u32) for getting value of the specified dimension of the indices. + // \param indices_var: name of the indices variable ({varname}_indices_t). + // \param idx: the index (i32|u32) of the dimension to get. + template + inline std::string IndicesGet(const std::string& indices_var, const TIdx& idx_expr) const; + + // create a WGSL statement for setting data at the given indices. + // \param args: a list of indices values (u32) followed by a value ({varname}_value_t). + template + inline std::string Set(TIndicesAndValue&&... args) const; + + // create a WGSL statement for setting data at the given indices. + // \param indices_var: name of the indices variable ({varname}_indices_t). + // \param value: the value ({varname}_value_t) to set. + inline std::string SetByIndices(const std::string& indices_var, const std::string& value) const; + + // create a WGSL statement for setting data at the given offset. + // \param offset: a WGSL expression (u32) representing the offset. + // \param value: the value ({varname}_value_t) to set. + template + inline std::string SetByOffset(TOffset&& offset, TValue&& value) const; + + // create a WGSL expression ({varname}_value_t) for getting data at the given indices. + // \param indices: a list of indices values (u32). + template + inline std::string Get(TIndices&&... indices) const; + + // create a WGSL expression ({varname}_value_t) for getting data at the given indices. + // \param indices_var: name of the indices variable ({varname}_indices_t). + inline std::string GetByIndices(const std::string& indices_var) const; + + // create a WGSL expression ({varname}_value_t) for getting data at the given offset. + // \param offset: a WGSL expression (u32) representing the offset. + template + inline std::string GetByOffset(TOffset&& offset) const; + + private: + enum Usage : uint32_t { + None = 0, + UseOffsetToIndices = 1, + UseIndicesToOffset = 2, + UseBroadcastedIndicesToOffset = 4, + UseSet = 8, + UseSetByIndices = 16, + UseGet = 32, + UseGetByIndices = 64, + UseUniform = 128, + }; + + friend ShaderVariable::Usage operator|(ShaderVariable::Usage a, ShaderVariable::Usage b); + friend ShaderVariable::Usage operator&(ShaderVariable::Usage a, ShaderVariable::Usage b); + friend ShaderVariable::Usage& operator|=(ShaderVariable::Usage& a, ShaderVariable::Usage b); + friend ShaderVariable::Usage& operator&=(ShaderVariable::Usage& a, ShaderVariable::Usage b); + + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderVariable); + + void Init(); + void Impl(std::ostringstream& ss); + + std::string ShaderVariable::GetByOffsetImpl(const std::string& offset) const; + std::string SetByOffsetImpl(const std::string& offset, const std::string& value) const; + + std::string_view StorageType() const; + std::string_view ValueType() const; + std::string IndicesType() const; + + std::string name_; + ProgramVariableDataType type_; + int rank_; + TensorShape dims_; + + mutable Usage usage_; + mutable std::vector> broadcasted_to_; + + friend class ShaderHelper; +}; + +inline ShaderVariable::Usage operator|(ShaderVariable::Usage a, ShaderVariable::Usage b) { + return (ShaderVariable::Usage)((uint32_t&)a | (uint32_t&)b); +} +inline ShaderVariable::Usage operator&(ShaderVariable::Usage a, ShaderVariable::Usage b) { + return (ShaderVariable::Usage)((uint32_t&)a & (uint32_t&)b); +} +inline ShaderVariable::Usage& operator|=(ShaderVariable::Usage& a, ShaderVariable::Usage b) { + return (ShaderVariable::Usage&)((uint32_t&)a |= (uint32_t&)b); +} +inline ShaderVariable::Usage& operator&=(ShaderVariable::Usage& a, ShaderVariable::Usage b) { + return (ShaderVariable::Usage&)((uint32_t&)a &= (uint32_t&)b); +} + +namespace detail { +template >> +std::string pass_as_string(T&& v) { + return std::to_string(std::forward(v)); +} +template +std::string pass_as_string(const T& v) { + return v; +} +} // namespace detail + +inline std::string ShaderVariable::OffsetToIndices(const std::string& offset_expr) const { + usage_ |= UseOffsetToIndices; + return rank_ < 2 ? offset_expr : MakeStringWithClassicLocale("o2i_", name_, '(', offset_expr, ')'); +} + +inline std::string ShaderVariable::IndicesToOffset(const std::string& indices_expr) const { + usage_ |= UseIndicesToOffset; + return rank_ < 2 ? indices_expr : MakeStringWithClassicLocale("i2o_", name_, '(', indices_expr, ')'); +} + +inline std::string ShaderVariable::BroadcastedIndicesToOffset(const std::string& indices_expr, const ShaderVariable& broadcasted_result) const { + usage_ |= UseBroadcastedIndicesToOffset; + broadcasted_to_.push_back(broadcasted_result); + return MakeStringWithClassicLocale(broadcasted_result.name_, "_bi2o_", name_, '(', indices_expr, ')'); +} + +template +inline std::string ShaderVariable::Indices(TIndices&&... indices_args) const { + return rank_ == 0 ? "" : MakeStringWithClassicLocale(name_, "_indices_t(", onnxruntime::detail::StringJoinImpl(", ", std::forward(indices_args)...), ')'); +} + +template +inline std::string ShaderVariable::IndicesSet(const std::string& indices_var, const TIdx& idx_expr, const TVal& value) const { + return rank_ < 2 ? MakeStringWithClassicLocale(indices_var, '=', value, ';') + : MakeStringWithClassicLocale(GetElementAt(indices_var, idx_expr, rank_), '=', value, ';'); +} + +template +inline std::string ShaderVariable::IndicesGet(const std::string& indices_var, const TIdx& idx_expr) const { + return rank_ < 2 ? indices_var : GetElementAt(indices_var, idx_expr, rank_); +} + +template +inline std::string ShaderVariable::SetByOffset(TOffset&& offset, TValue&& value) const { + return SetByOffsetImpl(detail::pass_as_string(offset), detail::pass_as_string(value)); +} + +template +inline std::string ShaderVariable::Set(TIndicesAndValue&&... args) const { + ORT_ENFORCE(sizeof...(TIndicesAndValue) == rank_ + 1, "Number of arguments should be ", rank_ + 1, "(rank + 1)"); + if constexpr (sizeof...(TIndicesAndValue) == 1) { + return SetByOffset("0", std::forward(args)...); + } else if constexpr (sizeof...(TIndicesAndValue) == 2) { + return SetByOffset(std::forward(args)...); + } else { + usage_ |= UseSet | UseSetByIndices | UseIndicesToOffset; + return MakeStringWithClassicLocale("set_", name_, '(', onnxruntime::detail::StringJoinImpl(", ", std::forward(args)...), ");"); + } +} + +inline std::string ShaderVariable::SetByIndices(const std::string& indices_var, const std::string& value) const { + if (rank_ < 2) { + return SetByOffset(indices_var, value); + } else { + usage_ |= UseSetByIndices | UseIndicesToOffset; + return MakeStringWithClassicLocale("set_", name_, "_by_indices(", indices_var, ", ", value, ");"); + } +} + +template +inline std::string ShaderVariable::GetByOffset(TOffset&& offset) const { + return GetByOffsetImpl(detail::pass_as_string(offset)); +} + +template +inline std::string ShaderVariable::Get(TIndices&&... indices) const { + ORT_ENFORCE(sizeof...(TIndices) == rank_, "Number of arguments should be ", rank_, "(rank)"); + if constexpr (sizeof...(TIndices) == 0) { + return GetByOffset("0"); + } else if constexpr (sizeof...(TIndices) == 1) { + return GetByOffset(std::forward(indices)...); + } else { + usage_ |= UseGet | UseGetByIndices | UseIndicesToOffset; + return MakeStringWithClassicLocale("get_", name_, '(', onnxruntime::detail::StringJoinImpl(", ", std::forward(indices)...), ')'); + } +} + +inline std::string ShaderVariable::GetByIndices(const std::string& indices_var) const { + if (rank_ < 2) { + return GetByOffset(indices_var); + } else { + usage_ |= UseGetByIndices | UseIndicesToOffset; + return MakeStringWithClassicLocale("get_", name_, "_by_indices(", indices_var, ")"); + } +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc new file mode 100644 index 0000000000000..a891f5a8a5516 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -0,0 +1,349 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include "core/common/common.h" + +#include "core/providers/webgpu/webgpu_context.h" +#include "core/providers/webgpu/buffer_manager.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/program_cache_key.h" +#include "core/providers/webgpu/program_manager.h" + +namespace onnxruntime { +namespace webgpu { + +std::vector GetAvailableRequiredFeatures(const wgpu::Adapter& adapter) { + std::vector required_features; + constexpr wgpu::FeatureName features[]{ + wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses, + wgpu::FeatureName::TimestampQuery, + wgpu::FeatureName::ShaderF16}; + for (auto feature : features) { + if (adapter.HasFeature(feature)) { + required_features.push_back(feature); + } + } + return required_features; +} + +wgpu::RequiredLimits GetAvailableRequiredLimits(const wgpu::Adapter& adapter) { + wgpu::RequiredLimits required_limits{}; + wgpu::SupportedLimits adapter_limits; + ORT_ENFORCE(adapter.GetLimits(&adapter_limits)); + + required_limits.limits.maxBindGroups = adapter_limits.limits.maxBindGroups; + required_limits.limits.maxComputeWorkgroupStorageSize = adapter_limits.limits.maxComputeWorkgroupStorageSize; + required_limits.limits.maxComputeWorkgroupsPerDimension = adapter_limits.limits.maxComputeWorkgroupsPerDimension; + required_limits.limits.maxStorageBufferBindingSize = adapter_limits.limits.maxStorageBufferBindingSize; + required_limits.limits.maxBufferSize = adapter_limits.limits.maxBufferSize; + required_limits.limits.maxComputeInvocationsPerWorkgroup = adapter_limits.limits.maxComputeInvocationsPerWorkgroup; + required_limits.limits.maxComputeWorkgroupSizeX = adapter_limits.limits.maxComputeWorkgroupSizeX; + required_limits.limits.maxComputeWorkgroupSizeY = adapter_limits.limits.maxComputeWorkgroupSizeY; + required_limits.limits.maxComputeWorkgroupSizeZ = adapter_limits.limits.maxComputeWorkgroupSizeZ; + + return required_limits; +} + +void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info) { + std::call_once(init_flag_, [this, &webgpu_ep_info]() { + // Initialization.Step.1 - Create wgpu::Instance + if (instance_ == nullptr) { + wgpu::InstanceDescriptor instance_desc{}; + instance_desc.features.timedWaitAnyEnable = true; + instance_ = wgpu::CreateInstance(&instance_desc); + + ORT_ENFORCE(instance_ != nullptr, "Failed to create wgpu::Instance."); + } + + // Initialization.Step.2 - Create wgpu::Adapter + if (adapter_ == nullptr) { + wgpu::RequestAdapterOptions req_adapter_options = {}; + wgpu::RequestAdapterCallbackInfo req_adapter_callback_info = {}; + req_adapter_callback_info.mode = wgpu::CallbackMode::WaitAnyOnly; + req_adapter_callback_info.callback = [](WGPURequestAdapterStatus status, + WGPUAdapter adapter, const char* message, + void* userdata) { + ORT_ENFORCE(status == WGPURequestAdapterStatus_Success, "Failed to get a WebGPU adapter: ", message); + *static_cast(userdata) = wgpu::Adapter::Acquire(adapter); + }; + req_adapter_callback_info.userdata = &adapter_; + ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(instance_.RequestAdapter(&req_adapter_options, req_adapter_callback_info), UINT64_MAX)); + ORT_ENFORCE(adapter_ != nullptr, "Failed to get a WebGPU adapter."); + } + + // Initialization.Step.3 - Create wgpu::Device + if (device_ == nullptr) { + wgpu::DeviceDescriptor device_desc = {}; + std::vector required_features = GetAvailableRequiredFeatures(adapter_); + if (required_features.size() > 0) { + device_desc.requiredFeatures = required_features.data(); + } + wgpu::RequiredLimits required_limits = GetAvailableRequiredLimits(adapter_); + device_desc.requiredLimits = &required_limits; + + // TODO: temporary error handling + device_desc.SetUncapturedErrorCallback([](const wgpu::Device& /*device*/, wgpu::ErrorType type, const char* message) { + LOGS_DEFAULT(ERROR) << "WebGPU device error(" << int(type) << "): " << message; + }); + + wgpu::RequestDeviceCallbackInfo req_device_callback_info = {}; + req_device_callback_info.mode = wgpu::CallbackMode::WaitAnyOnly; + req_device_callback_info.callback = [](WGPURequestDeviceStatus status, WGPUDevice device, char const* message, void* userdata) { + ORT_ENFORCE(status == WGPURequestAdapterStatus_Success, "Failed to get a WebGPU device: ", message); + *static_cast(userdata) = wgpu::Device::Acquire(device); + }; + req_device_callback_info.userdata = &device_; + ORT_ENFORCE(wgpu::WaitStatus::Success == instance_.WaitAny(adapter_.RequestDevice(&device_desc, req_device_callback_info), UINT64_MAX)); + ORT_ENFORCE(device_ != nullptr, "Failed to get a WebGPU device."); + } + + // cache adapter info + ORT_ENFORCE(Adapter().GetInfo(&adapter_info_)); + // cache device limits + wgpu::SupportedLimits device_supported_limits; + ORT_ENFORCE(Device().GetLimits(&device_supported_limits)); + device_limits_ = device_supported_limits.limits; + + // create buffer manager + buffer_mgr_ = BufferManagerFactory::Create(*this, webgpu_ep_info.storage_buffer_cache_mode, webgpu_ep_info.uniform_buffer_cache_mode, webgpu_ep_info.query_resolve_buffer_cache_mode); + + // create program manager + program_mgr_ = std::make_unique(Device(), DeviceLimits()); + }); +} + +Status WebGpuContext::Wait(wgpu::Future f) { + auto status = instance_.WaitAny(f, UINT64_MAX); + if (status == wgpu::WaitStatus::Success) { + return Status::OK(); + } + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait for the operation:", uint32_t(status)); +} + +Status WebGpuContext::Run(const ComputeContext& /*context*/, const ProgramBase& program) { + const auto& inputs = program.Inputs(); + const auto& outputs = program.Outputs(); + +#ifndef NDEBUG // if debug build + ORT_ENFORCE(std::all_of(inputs.begin(), inputs.end(), [](const ProgramInput& input) { + const auto* tensor = input.tensor; + return tensor != nullptr && + tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault && + tensor->Location().device.Type() == OrtDevice::GPU && + tensor->Location().name == WEBGPU_BUFFER; + }), + "All inputs must be tensors on WebGPU buffers."); + + ORT_ENFORCE(std::all_of(outputs.begin(), outputs.end(), [](Tensor* tensor) { + return tensor != nullptr && + tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault && + tensor->Location().device.Type() == OrtDevice::GPU && + tensor->Location().name == WEBGPU_BUFFER; + }), + "All outputs must be tensors on WebGPU buffers."); +#endif + + if (outputs.size() == 0) { + return Status::OK(); + } + + ProgramMetadata metadata = program.GetMetadata(); + + // validate program metadata + { + const auto& [constants, overridable_constants, uniform_variables] = metadata; + + // check overridable constants + ORT_RETURN_IF(program.OverridableConstants().size() != overridable_constants.size(), + "Size of overridable constants mismatch in program \"", program.Name(), + "\", Expected: ", overridable_constants.size(), + ", Actual: ", program.OverridableConstants().size()); + size_t num_overridable_constants = program.OverridableConstants().size(); + for (size_t i = 0; i < num_overridable_constants; ++i) { + const auto& override_value = program.OverridableConstants()[i]; + const auto& definition = overridable_constants[i]; + ORT_RETURN_IF(override_value.has_value && override_value.type != definition.type, + "Overridable override_value[", i, "] (", definition.name, ") data type mismatch in program \"", program.Name(), + "\", Expected: ", definition.type, + ", Actual: ", override_value.type); + ORT_RETURN_IF(!override_value.has_value && !definition.has_default_value, + "Overridable override_value[", i, "] (", definition.name, ") no override_value specified in program \"", program.Name(), + "\""); + } + + // check uniform variables + ORT_RETURN_IF(program.UniformVariables().size() != uniform_variables.size(), + "Size of uniform_value variables mismatch in program \"", program.Name(), + "\", Expected: ", uniform_variables.size(), + ", Actual: ", program.UniformVariables().size()); + size_t num_uniform_variables = program.UniformVariables().size(); + for (size_t i = 0; i < num_uniform_variables; ++i) { + const auto& uniform_value = program.UniformVariables()[i]; + const auto& definition = uniform_variables[i]; + ORT_RETURN_IF(uniform_value.length > 0 && uniform_value.data_type != definition.data_type, + "Uniform variable[", i, "] (", definition.name, ") data type mismatch in program \"", program.Name(), + "\", Expected: ", definition.data_type, + ", Actual: ", uniform_value.data_type); + } + } + + uint32_t x = program.DispatchGroupSizeX(); + uint32_t y = program.DispatchGroupSizeY(); + uint32_t z = program.DispatchGroupSizeZ(); + ORT_RETURN_IF_ERROR(program_mgr_->NormalizeDispatchGroupSize(x, y, z)); + + bool is_1d_dispatch = (y == 1 && z == 1); + + auto key = CalculateProgramCacheKey(program, is_1d_dispatch); + + const auto* program_artifact = program_mgr_->Get(key); + if (program_artifact == nullptr) { + wgpu::ComputePipeline compute_pipeline; + auto status = program_mgr_->Build(program, + metadata, +#ifndef NDEBUG // if debug build + key, +#endif + x, + y, + z, + compute_pipeline); + ORT_RETURN_IF_ERROR(status); + program_artifact = program_mgr_->Set(key, ProgramArtifact{program, std::move(compute_pipeline)}); +#ifndef NDEBUG // if debug build + ORT_ENFORCE(program_artifact != nullptr, "Program artifact should not be nullptr."); +#endif + } + + std::vector input_buffers; + input_buffers.reserve(inputs.size()); + for (const auto& input : inputs) { + input_buffers.push_back(reinterpret_cast(const_cast(input.tensor->DataRaw()))); + } + + std::vector output_buffers; + output_buffers.reserve(outputs.size()); + for (const auto& output : outputs) { + output_buffers.push_back(reinterpret_cast(output->MutableDataRaw())); + } + + WGPUBuffer uniform_buffer = nullptr; + auto uniform_buffer_size = program_artifact->uniform_total_size; + if (uniform_buffer_size > 0) { + auto num_uniforms = program.UniformVariables().size(); + ORT_ENFORCE(program_artifact->uniforms.size() == num_uniforms, + "Uniforms size mismatch. Artifact: ", program_artifact->uniforms.size(), ", Current: ", num_uniforms); + + std::vector uniform_data(uniform_buffer_size); + + for (size_t i = 0; i < num_uniforms; ++i) { + const auto& uniform = program.UniformVariables()[i]; + const auto& artifact_uniform = program_artifact->uniforms[i]; + + ORT_ENFORCE(uniform.data_type == artifact_uniform.data_type, + "Uniform[", i, "] data type mismatch. Artifact: ", artifact_uniform.data_type, + ", Current: ", uniform.data_type); + ORT_ENFORCE(uniform.length == artifact_uniform.length, + "Uniform[", i, "] elements number mismatch. Artifact: ", artifact_uniform.length, ", Current: ", uniform.length); + ORT_ENFORCE(uniform.data.size() == artifact_uniform.length * ProgramUniformVariableDataTypeSize[static_cast(uniform.data_type)], + "Uniform[", i, "] data size mismatch. Artifact: ", artifact_uniform.length * ProgramUniformVariableDataTypeSize[static_cast(uniform.data_type)], + ", Current: ", uniform.data.size()); + + auto offset = artifact_uniform.offset; + auto size = uniform.data.size(); + memcpy(uniform_data.data() + offset, uniform.data.data(), size); + } + + uniform_buffer = buffer_mgr_->Create(uniform_buffer_size, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform); + device_.GetQueue().WriteBuffer(uniform_buffer, 0, uniform_data.data(), uniform_buffer_size); + } + + const auto& compute_pass_encoder = GetComputePassEncoder(); + + // TODO: write timestamp query + + uint32_t entry_index = 0; + std::vector bind_group_entries; + for (const auto& input : inputs) { + bind_group_entries.push_back({nullptr, entry_index++, reinterpret_cast(const_cast(input.tensor->DataRaw()))}); + } + for (const auto& output : outputs) { + bind_group_entries.push_back({nullptr, entry_index++, reinterpret_cast(output->MutableDataRaw())}); + } + if (uniform_buffer) { + bind_group_entries.push_back({nullptr, entry_index++, uniform_buffer}); + } + + wgpu::BindGroupDescriptor bind_group_desc{}; + bind_group_desc.layout = program_artifact->compute_pipeline.GetBindGroupLayout(0); + bind_group_desc.entryCount = bind_group_entries.size(); + bind_group_desc.entries = bind_group_entries.data(); + bind_group_desc.label = program_artifact->name.c_str(); + + auto bind_group = Device().CreateBindGroup(&bind_group_desc); + + // TODO support graph capture + + compute_pass_encoder.SetPipeline(program_artifact->compute_pipeline); + compute_pass_encoder.SetBindGroup(0, bind_group); + compute_pass_encoder.DispatchWorkgroups(x, y, z); + + if (uniform_buffer) { + buffer_mgr_->Release(uniform_buffer); + } + + // TODO: write timestamp query + + ++num_pending_dispatches_; + + // if (querytype == at-passes) { EndComputePass(); } + + if (num_pending_dispatches_ >= max_num_pending_dispatches_) { + Flush(); + } + + return Status::OK(); +} + +std::unordered_map> WebGpuContextFactory::contexts_; +std::mutex WebGpuContextFactory::mutex_; + +WebGpuContext& WebGpuContextFactory::CreateContext(int context_id, WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device) { + if (context_id == 0) { + // context ID is preserved for the default context. User cannot use context ID 0 as a custom context. + ORT_ENFORCE(instance == nullptr && adapter == nullptr && device == nullptr, + "WebGPU EP default context (contextId=0) must not have custom WebGPU instance, adapter or device."); + } else { + // for context ID > 0, user must provide custom WebGPU instance, adapter and device. + ORT_ENFORCE(instance != nullptr && adapter != nullptr && device != nullptr, + "WebGPU EP custom context (contextId>0) must have custom WebGPU instance, adapter and device."); + } + + std::lock_guard lock(mutex_); + + auto it = contexts_.find(context_id); + if (it == contexts_.end()) { + auto context = std::unique_ptr(new WebGpuContext(instance, adapter, device)); + it = contexts_.emplace(context_id, std::move(context)).first; + } else if (context_id != 0) { + ORT_ENFORCE(it->second->instance_.Get() == instance && it->second->adapter_.Get() == adapter && it->second->device_.Get() == device, + "WebGPU EP context ID ", context_id, " is already created with different WebGPU instance, adapter or device."); + } + return *it->second; +} + +WebGpuContext& WebGpuContextFactory::GetContext(int context_id) { + std::lock_guard lock(mutex_); + + auto it = contexts_.find(context_id); + ORT_ENFORCE(it != contexts_.end(), "WebGPU EP context ID ", context_id, " is not found."); + + return *it->second; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h new file mode 100644 index 0000000000000..d8b0c2b48b067 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -0,0 +1,124 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include +#include + +#include + +#include "core/common/common.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/buffer_manager.h" +#include "core/providers/webgpu/program_manager.h" + +namespace onnxruntime { +class Tensor; + +namespace webgpu { +class WebGpuContext; +class ComputeContext; +class ProgramBase; + +class WebGpuContextFactory { + public: + static WebGpuContext& CreateContext(int context_id, WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device); + static WebGpuContext& GetContext(int context_id); + + private: + WebGpuContextFactory() {} + + static std::unordered_map> contexts_; + static std::mutex mutex_; +}; + +// Class WebGpuContext includes all necessary resources for the context. +class WebGpuContext final { + public: + void Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info); + + Status Wait(wgpu::Future f); + + const wgpu::Adapter& Adapter() const { return adapter_; } + const wgpu::Device& Device() const { return device_; } + + const wgpu::AdapterInfo& AdapterInfo() const { return adapter_info_; } + const wgpu::Limits& DeviceLimits() const { return device_limits_; } + + const wgpu::CommandEncoder& GetCommandEncoder() { + if (!current_command_encoder_) { + current_command_encoder_ = device_.CreateCommandEncoder(); + } + return current_command_encoder_; + } + + const wgpu::ComputePassEncoder& GetComputePassEncoder() { + if (!current_compute_pass_encoder_) { + auto& command_encoder = GetCommandEncoder(); + + wgpu::ComputePassDescriptor compute_pass_desc{}; + + // TODO: add support for GPU Query + + current_compute_pass_encoder_ = command_encoder.BeginComputePass(&compute_pass_desc); + } + return current_compute_pass_encoder_; + } + + void EndComputePass() { + if (current_compute_pass_encoder_) { + current_compute_pass_encoder_.End(); + current_compute_pass_encoder_ = nullptr; + } + } + + void Flush() { + if (!current_command_encoder_) { + return; + } + + EndComputePass(); + + // TODO: add support for GPU Query + + auto command_buffer = current_command_encoder_.Finish(); + Device().GetQueue().Submit(1, &command_buffer); + BufferManager().RefreshPendingBuffers(); + current_command_encoder_ = nullptr; + } + + webgpu::BufferManager& BufferManager() const { return *buffer_mgr_; } + + Status Run(const ComputeContext& context, const ProgramBase& program); + + private: + WebGpuContext(WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device) : instance_{instance}, adapter_{adapter}, device_{device} {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuContext); + + std::once_flag init_flag_; + + wgpu::Instance instance_; + wgpu::Adapter adapter_; + wgpu::Device device_; + + wgpu::AdapterInfo adapter_info_; + wgpu::Limits device_limits_; + + wgpu::CommandEncoder current_command_encoder_; + wgpu::ComputePassEncoder current_compute_pass_encoder_; + + std::unique_ptr buffer_mgr_; + std::unique_ptr program_mgr_; + friend class WebGpuContextFactory; + + int num_pending_dispatches_ = 0; + const int max_num_pending_dispatches_ = 16; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc new file mode 100644 index 0000000000000..e7688d1fafb94 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -0,0 +1,837 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/webgpu_execution_provider.h" + +#ifdef __EMSCRIPTEN__ +#include +#endif +#include +#include +#include +#include +#include + +#ifndef DISABLE_CONTRIB_OPS +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#endif + +#include "allocator.h" +#include "core/framework/compute_capability.h" +#include "core/framework/data_transfer_manager.h" +#include "core/framework/fallback_cpu_capability.h" +#include "core/framework/kernel_registry.h" +#include "core/graph/function_utils.h" +#include "core/graph/indexed_sub_graph.h" +#include "data_transfer.h" + +namespace onnxruntime { + +namespace webgpu { +template <> +KernelCreateInfo BuildKernelCreateInfo() { + KernelCreateInfo info; + return info; +} + +class Memcpy final : public OpKernel { + public: + Memcpy(const OpKernelInfo& info) : OpKernel(info) {} + + Status Compute(OpKernelContext* ctx) const override { + const auto* X = ctx->Input(0); + Tensor* Y = ctx->Output(0, X->Shape()); + return Info().GetDataTransferManager().CopyTensor(*X, *Y); + } +}; + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, MemcpyFromHost); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, MemcpyToHost); + +ONNX_OPERATOR_KERNEL_EX( + MemcpyFromHost, + kOnnxDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPU, 0) + .ExecQueueId(0) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), + Memcpy); + +ONNX_OPERATOR_KERNEL_EX( + MemcpyToHost, + kOnnxDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .OutputMemoryType(OrtMemTypeCPU, 0) + .ExecQueueId(1) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), + Memcpy); + +#define KERNEL_CREATE_INFO_VERSIONED(Start, End, Op) \ + BuildKernelCreateInfo< \ + ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, Start, End, Op)> + +#define KERNEL_CREATE_INFO(Start, Op) \ + BuildKernelCreateInfo< \ + ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, Start, Op)> + +#define KERNEL_CREATE_INFO_TYPED(Start, type, Op) \ + BuildKernelCreateInfo< \ + ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, Start, type, Op)> + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Abs); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Abs); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Neg); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Neg); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Floor); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Floor); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Ceil); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Ceil); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Reciprocal); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Reciprocal); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Sqrt); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Sqrt); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Exp); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Exp); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 12, Erf); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Erf); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Sigmoid); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Sigmoid); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, HardSigmoid); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Log); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Log); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Sin); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Cos); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Tan); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Asin); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Acos); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, Atan); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, Sinh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, Cosh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, Asinh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, Acosh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, Atanh); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Tanh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Tanh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, Not); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 8, Cast); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 12, Cast); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, Cast); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Cast); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 10, Clip); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, Clip); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, Clip); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Clip); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, Elu); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Relu); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Relu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Relu); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 15, LeakyRelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, LeakyRelu); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, ReduceMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, ReduceMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceMax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceMax); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceMean); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceMean); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceMean); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceMean); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, ReduceMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, ReduceMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceMin); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceMin); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceProd); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceProd); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceProd); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceProd); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceSum); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceSum); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, ReduceSum); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceL1); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceL1); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceL1); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceL1); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceL2); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceL2); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceL2); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceL2); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceLogSum); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceLogSum); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceSumSquare); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceSumSquare); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceLogSumExp); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceLogSumExp); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, ThresholdedRelu); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 12, Add); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Add); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Add); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 12, Sub); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Sub); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Sub); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 12, Mul); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Mul); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Mul); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 12, Div); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Div); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Div); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 11, Pow); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 12, Pow); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 14, Pow); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 15, Pow); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 10, Equal); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Equal); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, Equal); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Equal); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, Greater); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 12, Greater); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Greater); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 15, GreaterOrEqual); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, Less); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 12, Less); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Less); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, 15, LessOrEqual); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, LessOrEqual); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Shape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 14, Shape); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 15, Shape); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 5, 12, Reshape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Reshape); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Reshape); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Squeeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Squeeze); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Squeeze); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Unsqueeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Unsqueeze); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Unsqueeze); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 15, Where); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, Where); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Transpose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Transpose); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, DepthToSpace); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, DepthToSpace); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, 12, DepthToSpace); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 13, DepthToSpace); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, Conv); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, 10, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, Conv); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ConvTranspose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, ConvTranspose); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, 10, ConvTranspose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 7, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 8, 9, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 10, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, MaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, 7, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 8, 9, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 10, 10, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 9, AveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 10, AveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, AveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 7, 9, AveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 10, 10, AveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, GlobalAveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, GlobalAveragePool); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, GlobalMaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, GlobalMaxPool); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, Gemm); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 10, Gemm); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Gemm); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Gemm); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, MatMul); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, MatMul); + +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, float, ArgMax); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, float, ArgMax); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMin); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, float, ArgMin); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Softmax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Softmax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Softmax); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 3, Concat); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 4, 10, Concat); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Concat); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Concat); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 1, Split); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 2, 10, Split); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Split); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, Split); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, Split); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 8, 12, Expand); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Expand); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 10, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, 18, Resize); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 10, 10, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, 12, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 13, 17, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 18, 18, Resize); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 19, Resize); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Gather); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Gather); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Gather); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, GatherElements); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 9, Slice); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 10, Slice); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Slice); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Slice); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 8, Flatten); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 10, Flatten); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Flatten); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Flatten); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 12, Tile); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Tile); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 17, LayerNormalization); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, InstanceNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 6, InstanceNormalization); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, Range); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 12, float, Einsum); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 2, 10, Pad); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Pad); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, Pad); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, 18, Pad); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, Pad); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, If); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, If); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, If); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, If); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 8, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 9, 13, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, 14, BatchNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 15, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 7, 8, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 9, 13, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 14, 14, BatchNormalization); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 15, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 13, CumSum); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, CumSum); + +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 12, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 12, int8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 12, int32_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 18, int32_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, int8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, int32_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, int8_t, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, int32_t, DequantizeLinear); + +std::unique_ptr RegisterKernels() { + auto kernel_registry = std::make_unique(); + + static const BuildKernelCreateInfoFn function_table[] = { + BuildKernelCreateInfo, // default entry to avoid the list becoming empty after ops-reducing + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // element-wise operators + // unary - math + KERNEL_CREATE_INFO_VERSIONED(6, 12, Abs), + KERNEL_CREATE_INFO(13, Abs), + // KERNEL_CREATE_INFO_VERSIONED(6, 12, Neg), + // KERNEL_CREATE_INFO(13, Neg), + // KERNEL_CREATE_INFO_VERSIONED(6, 12, Floor), + // KERNEL_CREATE_INFO(13, Floor), + // KERNEL_CREATE_INFO_VERSIONED(6, 12, Ceil), + // KERNEL_CREATE_INFO(13, Ceil), + // KERNEL_CREATE_INFO_VERSIONED(6, 12, Reciprocal), + // KERNEL_CREATE_INFO(13, Reciprocal), + // KERNEL_CREATE_INFO_VERSIONED(6, 12, Sqrt), + // KERNEL_CREATE_INFO(13, Sqrt), + // KERNEL_CREATE_INFO_VERSIONED(6, 12, Exp), + // KERNEL_CREATE_INFO(13, Exp), + // KERNEL_CREATE_INFO_VERSIONED(9, 12, Erf), + // KERNEL_CREATE_INFO(13, Erf), + // KERNEL_CREATE_INFO_VERSIONED(6, 12, Sigmoid), + // KERNEL_CREATE_INFO(13, Sigmoid), + // KERNEL_CREATE_INFO(6, HardSigmoid), + // KERNEL_CREATE_INFO_VERSIONED(6, 12, Log), + // KERNEL_CREATE_INFO(13, Log), + + // KERNEL_CREATE_INFO(7, Sin), + // KERNEL_CREATE_INFO(7, Cos), + // KERNEL_CREATE_INFO(7, Tan), + // KERNEL_CREATE_INFO(7, Asin), + // KERNEL_CREATE_INFO(7, Acos), + // KERNEL_CREATE_INFO(7, Atan), + // KERNEL_CREATE_INFO(9, Sinh), + // KERNEL_CREATE_INFO(9, Cosh), + // KERNEL_CREATE_INFO(9, Asinh), + // KERNEL_CREATE_INFO(9, Acosh), + // KERNEL_CREATE_INFO(9, Atanh), + // KERNEL_CREATE_INFO_VERSIONED(6, 12, Tanh), + // KERNEL_CREATE_INFO(13, Tanh), + // KERNEL_CREATE_INFO(1, Not), + + // KERNEL_CREATE_INFO_VERSIONED(6, 8, Cast), + // KERNEL_CREATE_INFO_VERSIONED(9, 12, Cast), + // KERNEL_CREATE_INFO_VERSIONED(13, 18, Cast), + // KERNEL_CREATE_INFO(19, Cast), + + // // activations + // KERNEL_CREATE_INFO_VERSIONED(6, 10, Clip), + // KERNEL_CREATE_INFO_VERSIONED(11, 11, Clip), + // KERNEL_CREATE_INFO_VERSIONED(12, 12, Clip), + // KERNEL_CREATE_INFO(13, Clip), + // KERNEL_CREATE_INFO(6, Elu), + // KERNEL_CREATE_INFO_VERSIONED(6, 12, Relu), + // KERNEL_CREATE_INFO_VERSIONED(13, 13, Relu), + // KERNEL_CREATE_INFO(14, Relu), + // KERNEL_CREATE_INFO_VERSIONED(6, 15, LeakyRelu), + // KERNEL_CREATE_INFO(16, LeakyRelu), + // KERNEL_CREATE_INFO(10, ThresholdedRelu), + + // // binary - math + // KERNEL_CREATE_INFO_VERSIONED(7, 12, Add), + // KERNEL_CREATE_INFO_VERSIONED(13, 13, Add), + // KERNEL_CREATE_INFO(14, Add), + // KERNEL_CREATE_INFO_VERSIONED(7, 12, Sub), + // KERNEL_CREATE_INFO_VERSIONED(13, 13, Sub), + // KERNEL_CREATE_INFO(14, Sub), + // KERNEL_CREATE_INFO_VERSIONED(7, 12, Mul), + // KERNEL_CREATE_INFO_VERSIONED(13, 13, Mul), + // KERNEL_CREATE_INFO(14, Mul), + // KERNEL_CREATE_INFO_VERSIONED(7, 12, Div), + // KERNEL_CREATE_INFO_VERSIONED(13, 13, Div), + // KERNEL_CREATE_INFO(14, Div), + // KERNEL_CREATE_INFO_VERSIONED(7, 11, Pow), + // KERNEL_CREATE_INFO_VERSIONED(12, 12, Pow), + // KERNEL_CREATE_INFO_VERSIONED(13, 14, Pow), + // KERNEL_CREATE_INFO(15, Pow), + // KERNEL_CREATE_INFO_VERSIONED(7, 10, Equal), + // KERNEL_CREATE_INFO_VERSIONED(11, 12, Equal), + // KERNEL_CREATE_INFO_VERSIONED(13, 18, Equal), + // KERNEL_CREATE_INFO(19, Equal), + // KERNEL_CREATE_INFO_VERSIONED(7, 8, Greater), + // KERNEL_CREATE_INFO_VERSIONED(9, 12, Greater), + // KERNEL_CREATE_INFO(13, Greater), + // KERNEL_CREATE_INFO_VERSIONED(12, 15, GreaterOrEqual), + // KERNEL_CREATE_INFO(16, GreaterOrEqual), + // KERNEL_CREATE_INFO_VERSIONED(7, 8, Less), + // KERNEL_CREATE_INFO_VERSIONED(9, 12, Less), + // KERNEL_CREATE_INFO(13, Less), + // KERNEL_CREATE_INFO_VERSIONED(12, 15, LessOrEqual), + // KERNEL_CREATE_INFO(16, LessOrEqual), + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // KERNEL_CREATE_INFO_VERSIONED(9, 15, Where), + // KERNEL_CREATE_INFO(16, Where), + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + }; + + for (auto& function_table_entry : function_table) { + KernelCreateInfo info = function_table_entry(); + if (info.kernel_def != nullptr) { // filter disabled entries where type is void + ORT_THROW_IF_ERROR(kernel_registry->Register(std::move(info))); + } + } + +#ifndef DISABLE_CONTRIB_OPS + Status status = ::onnxruntime::contrib::webgpu::RegisterWebGpuContribKernels(*kernel_registry); + ORT_ENFORCE(status.IsOK(), "Failed to register WebGPU contrib kernels: " + status.ErrorMessage()); +#endif + + return kernel_registry; +} + +} // namespace webgpu + +using namespace webgpu; + +WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, + const WebGpuContext& context, + const WebGpuExecutionProviderInfo& info) + : IExecutionProvider{kWebGpuExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)}, + context_id_{context_id}, + context_{context}, + preferred_data_layout_{info.data_layout}, + enable_graph_capture_{info.enable_graph_capture} { +} + +std::vector WebGpuExecutionProvider::CreatePreferredAllocators() { + AllocatorCreationInfo gpuBufferAllocatorCreationInfo([&](int) { + return std::make_unique(context_); + }, + 0, false); + return std::vector{CreateAllocator(gpuBufferAllocatorCreationInfo)}; +} + +std::vector> WebGpuExecutionProvider::GetCapability( + const onnxruntime::GraphViewer& graph, + const IKernelLookup& kernel_lookup) const { + InlinedVector candidates; + // `tenative_candidates` is a subset of `candidates`. + InlinedVector tenative_candidates; + for (auto& node_index : graph.GetNodesInTopologicalOrder()) { + const auto* p_node = graph.GetNode(node_index); + if (p_node == nullptr) + continue; + + const auto& node = *p_node; + if (!node.GetExecutionProviderType().empty()) { + // If the node was added by layout transformer, do not move it to CPU + if (node.GetExecutionProviderType() == kWebGpuExecutionProvider) { + candidates.push_back(node.Index()); + } + continue; + } + + const KernelCreateInfo* webgpu_kernel_def = kernel_lookup.LookUpKernel(node); + // none of the provided registries has a webgpu kernel for this node + if (webgpu_kernel_def == nullptr) { + LOGS(*GetLogger(), INFO) << "webgpu kernel not found in registries for Op type: " + << node.OpType() << " node name: " << node.Name(); + continue; + } + candidates.push_back(node.Index()); + tenative_candidates.push_back(node.Index()); + } + auto cpu_nodes = GetCpuPreferredNodes(graph, kernel_lookup, tenative_candidates); + std::vector> result; + for (auto& node_index : candidates) { + if (cpu_nodes.count(node_index) > 0) { + continue; + } + + auto sub_graph = std::make_unique(); + sub_graph->nodes.push_back(node_index); + result.emplace_back(std::make_unique(std::move(sub_graph))); + } + return result; +} + +std::shared_ptr WebGpuExecutionProvider::GetKernelRegistry() const { + static std::shared_ptr registry = webgpu::RegisterKernels(); + + return registry; +} + +std::unique_ptr WebGpuExecutionProvider::GetDataTransfer() const { + return std::make_unique(context_); +} + +WebGpuExecutionProvider::~WebGpuExecutionProvider() { +} + +Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { + if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { + ORT_NOT_IMPLEMENTED("graph capture not implemented"); + } + return Status::OK(); +} + +Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxruntime::RunOptions& /*run_options*/) { + if (IsGraphCaptureEnabled() && !IsGraphCaptured(0)) { + if (IsGraphCaptureAllowed()) { + ORT_NOT_IMPLEMENTED("graph capture not implemented"); + // is_graph_captured_ = true; + } else { + IncrementRegularRunCountBeforeGraphCapture(); + } + } + + return Status::OK(); +} + +bool WebGpuExecutionProvider::IsGraphCaptureEnabled() const { + return enable_graph_capture_; +} + +bool WebGpuExecutionProvider::IsGraphCaptured(int) const { + return is_graph_captured_; +} + +Status WebGpuExecutionProvider::ReplayGraph(int) { + ORT_ENFORCE(IsGraphCaptured(0)); + ORT_ENFORCE(false); + return Status::OK(); +} + +bool WebGpuExecutionProvider::IsGraphCaptureAllowed() const { + return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; +} + +void WebGpuExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { + ++regular_run_count_before_graph_capture_; +} +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h new file mode 100644 index 0000000000000..6fb2381637a67 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2019, NXP Semiconductor, Inc. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/execution_provider.h" +#include "core/framework/session_options.h" +#include "core/graph/constants.h" +#include "core/providers/providers.h" + +struct pthreadpool; +namespace onnxruntime { +namespace webgpu { + +// forward declaration for this EP's namespace. +template +KernelCreateInfo BuildKernelCreateInfo(); + +class WebGpuContext; +enum class BufferCacheMode; +} // namespace webgpu + +struct WebGpuExecutionProviderInfo { + DataLayout data_layout; + bool enable_graph_capture; + webgpu::BufferCacheMode storage_buffer_cache_mode; + webgpu::BufferCacheMode uniform_buffer_cache_mode; + webgpu::BufferCacheMode query_resolve_buffer_cache_mode; + webgpu::BufferCacheMode default_buffer_cache_mode; +}; + +class WebGpuExecutionProvider : public IExecutionProvider { + public: + WebGpuExecutionProvider(int context_id, const webgpu::WebGpuContext& context, const WebGpuExecutionProviderInfo& info); + ~WebGpuExecutionProvider() override; + + std::vector> GetCapability( + const onnxruntime::GraphViewer& graph_viewer, + const IKernelLookup& /*kernel_lookup*/) const override; + + std::shared_ptr GetKernelRegistry() const override; + std::unique_ptr GetDataTransfer() const override; + + DataLayout GetPreferredLayout() const override { return preferred_data_layout_; } + + FusionStyle GetFusionStyle() const override { return FusionStyle::FilteredGraphViewer; } + + // WebGPU EP disallow concurrent run because actual implementation (eg. WebGPU backend) relies on global states to + // work, and concurrent run with async function may mess up the states and cause undefined behavior. + bool ConcurrentRunSupported() const override { return false; } + + std::vector CreatePreferredAllocators() override; + + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; + + // WebGPU EP reuses the Device ID as the key to get the WebGpuContext instance. + int GetDeviceId() const override { return context_id_; } + + bool IsGraphCaptureEnabled() const override; + bool IsGraphCaptured(int graph_annotation_id) const override; + Status ReplayGraph(int graph_annotation_id) override; + + private: + bool IsGraphCaptureAllowed() const; + void IncrementRegularRunCountBeforeGraphCapture(); + int context_id_; + const webgpu::WebGpuContext& context_; + DataLayout preferred_data_layout_; + bool enable_graph_capture_ = false; + bool is_graph_captured_ = false; + int regular_run_count_before_graph_capture_ = 0; + const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.h b/onnxruntime/core/providers/webgpu/webgpu_kernel.h new file mode 100644 index 0000000000000..6486987501d14 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/compute_context.h" + +#include "core/framework/op_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +// ----------------------------------------------------------------------- +// Base class for WebGPU kernels +// ----------------------------------------------------------------------- +class WebGpuKernel : public OpKernel { + public: + explicit WebGpuKernel(const OpKernelInfo& info) + : OpKernel(info) { + } + + Status Compute(OpKernelContext* p_op_kernel_context) const override { + ComputeContext context{*p_op_kernel_context}; + auto s = ComputeInternal(context); + // use this to precisely locate the node where CUDA failure comes from + // if (cudaSuccess != cudaDeviceSynchronize()) + // __debugbreak(); + // if (s.IsOK()) { + // auto err = cudaGetLastError(); + // if (err != cudaSuccess) { + // return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA error ", cudaGetErrorName(err), ":", cudaGetErrorString(err)); + // } + // } + return s; + } + + virtual Status ComputeInternal(ComputeContext& context) const = 0; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc new file mode 100644 index 0000000000000..93258b84c5112 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -0,0 +1,144 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/framework/error_code_helper.h" +#include "core/providers/webgpu/buffer_manager.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" +#include "core/providers/webgpu/webgpu_provider_factory_creator.h" +#include "core/providers/webgpu/webgpu_context.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/ort_apis.h" + +#include "core/providers/webgpu/webgpu_provider_options.h" +using namespace onnxruntime::webgpu::options; + +namespace onnxruntime { + +struct WebGpuProviderFactory : IExecutionProviderFactory { + WebGpuProviderFactory(int context_id, const webgpu::WebGpuContext& context, const WebGpuExecutionProviderInfo& webgpu_ep_info) + : context_id_{context_id}, context_{context}, info_{webgpu_ep_info} { + } + + std::unique_ptr CreateProvider() override { + return std::make_unique(context_id_, context_, info_); + } + + private: + int context_id_; + const webgpu::WebGpuContext& context_; + WebGpuExecutionProviderInfo info_; +}; + +std::shared_ptr WebGpuProviderFactoryCreator::Create(const SessionOptions* session_options) { + // + // STEP.1 - prepare WebGpuExecutionProviderInfo + // + WebGpuExecutionProviderInfo webgpu_ep_info{ + // preferred layout is NHWC by default + DataLayout::NHWC, + // graph capture feature is disabled by default + false, + }; + + std::string preferred_layout_str; + if (session_options->config_options.TryGetConfigEntry(kPreferredLayout, preferred_layout_str)) { + if (preferred_layout_str == kPreferredLayout_NHWC) { + webgpu_ep_info.data_layout = DataLayout::NHWC; + } else if (preferred_layout_str == kPreferredLayout_NCHW) { + webgpu_ep_info.data_layout = DataLayout::NCHW; + } else { + ORT_THROW("Invalid preferred layout: ", preferred_layout_str); + } + } + LOGS_DEFAULT(VERBOSE) << "WebGPU EP preferred layout: " << int(webgpu_ep_info.data_layout) << " (parsed from \"" + << preferred_layout_str << "\")"; + + std::string enable_graph_capture_str; + if (session_options->config_options.TryGetConfigEntry(kEnableGraphCapture, enable_graph_capture_str)) { + if (enable_graph_capture_str == kkEnableGraphCapture_ON) { + webgpu_ep_info.enable_graph_capture = true; + } else if (enable_graph_capture_str == kkEnableGraphCapture_OFF) { + webgpu_ep_info.enable_graph_capture = false; + } else { + ORT_THROW("Invalid enable graph capture: ", enable_graph_capture_str); + } + } + LOGS_DEFAULT(VERBOSE) << "WebGPU EP graph capture enable: " << webgpu_ep_info.enable_graph_capture; + + auto parse_buffer_cache_mode = [session_options](const std::string& config_entry_str, webgpu::BufferCacheMode default) -> webgpu::BufferCacheMode { + std::string buffer_cache_mode_str; + if (session_options->config_options.TryGetConfigEntry(config_entry_str, buffer_cache_mode_str)) { + if (buffer_cache_mode_str == kBufferCacheMode_Disabled) { + return webgpu::BufferCacheMode::Disabled; + } else if (buffer_cache_mode_str == kBufferCacheMode_LazyRelease) { + return webgpu::BufferCacheMode::LazyRelease; + } else if (buffer_cache_mode_str == kBufferCacheMode_Simple) { + return webgpu::BufferCacheMode::Simple; + } else if (buffer_cache_mode_str == kBufferCacheMode_Bucket) { + return webgpu::BufferCacheMode::Bucket; + } else { + ORT_THROW("Invalid buffer cache mode: ", config_entry_str); + } + } else { + return default; + } + }; + + webgpu_ep_info.storage_buffer_cache_mode = parse_buffer_cache_mode(kStorageBufferCacheMode, webgpu::BufferCacheMode::Bucket); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP storage buffer cache mode: " << webgpu_ep_info.storage_buffer_cache_mode; + + webgpu_ep_info.uniform_buffer_cache_mode = parse_buffer_cache_mode(kUniformBufferCacheMode, webgpu::BufferCacheMode::LazyRelease); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP uniform buffer cache mode: " << webgpu_ep_info.uniform_buffer_cache_mode; + + webgpu_ep_info.query_resolve_buffer_cache_mode = parse_buffer_cache_mode(kQueryResolveBufferCacheMode, webgpu::BufferCacheMode::Disabled); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP query resolve buffer cache mode: " << webgpu_ep_info.query_resolve_buffer_cache_mode; + + webgpu_ep_info.default_buffer_cache_mode = parse_buffer_cache_mode(kDefaultBufferCacheMode, webgpu::BufferCacheMode::Disabled); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP default buffer cache mode: " << webgpu_ep_info.default_buffer_cache_mode; + + // + // STEP.2 - prepare WebGpuContext + // + int context_id = 0; + std::string context_id_str; + if (session_options->config_options.TryGetConfigEntry(kDeviceId, context_id_str)) { + ORT_ENFORCE(std::errc{} == + std::from_chars(context_id_str.data(), context_id_str.data() + context_id_str.size(), context_id).ec); + } + + size_t webgpu_instance = 0; + std::string webgpu_instance_str; + if (session_options->config_options.TryGetConfigEntry(kWebGpuInstance, webgpu_instance_str)) { + static_assert(sizeof(WGPUInstance) == sizeof(size_t), "WGPUInstance size mismatch"); + ORT_ENFORCE(std::errc{} == + std::from_chars(webgpu_instance_str.data(), webgpu_instance_str.data() + webgpu_instance_str.size(), webgpu_instance).ec); + } + + size_t webgpu_adapter = 0; + std::string webgpu_adapter_str; + if (session_options->config_options.TryGetConfigEntry(kWebGpuAdapter, webgpu_adapter_str)) { + static_assert(sizeof(WGPUAdapter) == sizeof(size_t), "WGPUAdapter size mismatch"); + ORT_ENFORCE(std::errc{} == + std::from_chars(webgpu_adapter_str.data(), webgpu_adapter_str.data() + webgpu_adapter_str.size(), webgpu_adapter).ec); + } + + size_t webgpu_device = 0; + std::string webgpu_device_str; + if (session_options->config_options.TryGetConfigEntry(kWebGpuDevice, webgpu_device_str)) { + static_assert(sizeof(WGPUDevice) == sizeof(size_t), "WGPUDevice size mismatch"); + ORT_ENFORCE(std::errc{} == + std::from_chars(webgpu_device_str.data(), webgpu_device_str.data() + webgpu_device_str.size(), webgpu_device).ec); + } + + auto& context = webgpu::WebGpuContextFactory::CreateContext(context_id, + reinterpret_cast(webgpu_instance), + reinterpret_cast(webgpu_adapter), + reinterpret_cast(webgpu_device)); + context.Initialize(webgpu_ep_info); + + return std::make_shared(context_id, context, webgpu_ep_info); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h b/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h new file mode 100644 index 0000000000000..7fac9234b949b --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "core/framework/provider_options.h" +#include "core/providers/providers.h" + +namespace onnxruntime { +struct SessionOptions; + +struct WebGpuProviderFactoryCreator { + static std::shared_ptr Create(const SessionOptions* session_options); +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_options.h b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h new file mode 100644 index 0000000000000..65ccbd800b122 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { +namespace webgpu { +namespace options { + +// The following are the options that can be set in the WebGPU provider options. + +constexpr const char* kPreferredLayout = "preferredLayout"; +constexpr const char* kEnableGraphCapture = "enableGraphCapture"; + +constexpr const char* kDeviceId = "deviceId"; +constexpr const char* kWebGpuInstance = "webgpuInstance"; +constexpr const char* kWebGpuAdapter = "webgpuAdapter"; +constexpr const char* kWebGpuDevice = "webgpuDevice"; + +constexpr const char* kStorageBufferCacheMode = "storageBufferCacheMode"; +constexpr const char* kUniformBufferCacheMode = "uniformBufferCacheMode"; +constexpr const char* kQueryResolveBufferCacheMode = "queryResolveBufferCacheMode"; +constexpr const char* kDefaultBufferCacheMode = "defaultBufferCacheMode"; + +// The following are the possible values for the provider options. + +constexpr const char* kPreferredLayout_NCHW = "NCHW"; +constexpr const char* kPreferredLayout_NHWC = "NHWC"; + +constexpr const char* kkEnableGraphCapture_ON = "1"; +constexpr const char* kkEnableGraphCapture_OFF = "0"; + +constexpr const char* kBufferCacheMode_Disabled = "disabled"; +constexpr const char* kBufferCacheMode_LazyRelease = "lazyRelease"; +constexpr const char* kBufferCacheMode_Simple = "simple"; +constexpr const char* kBufferCacheMode_Bucket = "bucket"; + +} // namespace options +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_supported_types.h b/onnxruntime/core/providers/webgpu/webgpu_supported_types.h new file mode 100644 index 0000000000000..fccaef2c53575 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_supported_types.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/cpu/tensor/shape_op.h" + +namespace onnxruntime { +namespace webgpu { + +using SupportedTypes = + TypeList< + float, + MLFloat16, + int32_t, + uint32_t>; + +using SupportedFloats = + TypeList< + float, + MLFloat16>; + +inline const std::vector& WebGpuSupportedDataTypes() { + static const std::vector supportedDataTypes = BuildKernelDefConstraintsFromTypeList(); + return supportedDataTypes; +} + +inline const std::vector& WebGpuSupportedFloatTypes() { + static const std::vector supportedDataTypes = BuildKernelDefConstraintsFromTypeList(); + return supportedDataTypes; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index b9e017df5baa3..dced1cf0e1464 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -759,12 +759,12 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr // Some session option values (default or user provided) may not work with some EPs. // Rather than put the onus on the user to know these, make the appropriate change while logging the change. - if (provider_type == onnxruntime::kDmlExecutionProvider) { - // DML's memory is not byte addressable and hence mem pattern doesn't work. + if (provider_type == onnxruntime::kDmlExecutionProvider || provider_type == onnxruntime::kWebGpuExecutionProvider) { + // DML and WebGPU memory is not byte addressable and hence mem pattern doesn't work. if (session_options_.enable_mem_pattern) { LOGS(*session_logger_, INFO) - << "Having memory pattern enabled is not supported while using the DML Execution Provider. " - << "So disabling it for this session since it uses the DML Execution Provider."; + << "Having memory pattern enabled is not supported while using " << provider_type << ". " + << "So disabling it for this session since it uses " << provider_type << "."; session_options_.enable_mem_pattern = false; } diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 1a5484ddc0055..f231b0148b37e 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2730,6 +2730,8 @@ static constexpr OrtApi ort_api_1_to_20 = { &OrtApis::KernelInfoGetAllocator, &OrtApis::AddExternalInitializersFromFilesInMemory, // End of Version 18 - DO NOT MODIFY ABOVE (see above text for more information) + + &OrtApis::SessionOptionsAppendExecutionProvider_WebGPU, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index fcae173e6c162..fd765feae6ad0 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -384,6 +384,13 @@ ORT_API_STATUS_IMPL(InvokeOp, ORT_API(void, ReleaseOp, _Frees_ptr_opt_ OrtOp* op); +ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_WebGPU, + _In_ OrtSessionOptions* options, + _In_ const OrtWebGPUProviderOptions* webgpu_options, + _In_reads_(num_keys) const char* const* string_options_keys, + _In_reads_(num_keys) const char* const* string_options_values, + _In_ size_t num_keys); + ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider, _In_ OrtSessionOptions* options, _In_ const char* provider_name, diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index db8b97f6d2c13..d2f8579fef7ec 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include #include "core/common/common.h" #include "core/common/logging/logging.h" @@ -131,6 +132,12 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, options->provider_factories.push_back(WebNNProviderFactoryCreator::Create(provider_options)); #else status = create_not_supported_status(); +#endif + } else if (strcmp(provider_name, "WebGPU") == 0) { +#if defined(USE_WEBGPU) + options->provider_factories.push_back(WebGpuProviderFactoryCreator::Create(&(options->value))); +#else + status = create_not_supported_status(); #endif } else if (strcmp(provider_name, "AZURE") == 0) { #if defined(USE_AZURE) @@ -158,6 +165,59 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_WebGPU, + _In_ OrtSessionOptions* options, + _In_ const OrtWebGPUProviderOptions* webgpu_options, + _In_reads_(num_keys) const char* const* string_options_keys, + _In_reads_(num_keys) const char* const* string_options_values, + _In_ size_t num_keys) { + API_IMPL_BEGIN + std::vector options_keys; + options_keys.reserve(num_keys + 4); + std::vector options_values; + options_values.reserve(num_keys + 4); + + // the following code uses std::to_chars() to convert int/size_t to string. + // unlike std::to_string(), std::to_chars() is guaranteed locale-independent. + // + // uint64_t to string is no more than 20 characters, and + // int32_t to string is no more than 11 characters. + static_assert(sizeof(size_t) == 4 || sizeof(size_t) == 8); + char buffer[sizeof(size_t) == 4 ? 11 : 20]; + + auto res = std::to_chars(buffer, buffer + sizeof(buffer), webgpu_options->device_id); + ORT_ENFORCE(res.ec == std::errc(), "Failed to convert device_id to string"); + std::string device_id(buffer, res.ptr - buffer); + options_keys.push_back("deviceId"); + options_values.push_back(device_id.c_str()); + + res = std::to_chars(buffer, buffer + sizeof(buffer), reinterpret_cast(webgpu_options->instance_handle)); + ORT_ENFORCE(res.ec == std::errc(), "Failed to convert instance_handle to string"); + std::string instance_handle(buffer, res.ptr - buffer); + options_keys.push_back("webgpuInstance"); + options_values.push_back(instance_handle.c_str()); + + res = std::to_chars(buffer, buffer + sizeof(buffer), reinterpret_cast(webgpu_options->adapter_handle)); + ORT_ENFORCE(res.ec == std::errc(), "Failed to convert adapter_handle to string"); + std::string adapter_handle(buffer, res.ptr - buffer); + options_keys.push_back("webgpuAdapter"); + options_values.push_back(adapter_handle.c_str()); + + res = std::to_chars(buffer, buffer + sizeof(buffer), reinterpret_cast(webgpu_options->device_handle)); + ORT_ENFORCE(res.ec == std::errc(), "Failed to convert device_handle to string"); + std::string device_handle(buffer, res.ptr - buffer); + options_keys.push_back("webgpuDevice"); + options_values.push_back(device_handle.c_str()); + + for (size_t i = 0; i != num_keys; ++i) { + options_keys.push_back(string_options_keys[i]); + options_values.push_back(string_options_values[i]); + } + + return OrtApis::SessionOptionsAppendExecutionProvider(options, "WebGPU", options_keys.data(), options_values.data(), options_keys.size()); + API_IMPL_END +} + #if defined(__APPLE__) || defined(ORT_MINIMAL_BUILD) static OrtStatus* CreateNotEnabledStatus(const std::string& ep) { return OrtApis::CreateStatus(ORT_FAIL, (ep + " execution provider is not enabled in this build. ").c_str()); diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 0397bba90438b..17b5cce6a4d6e 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -47,13 +47,16 @@ void usage() { "\t-v: verbose\n" "\t-n [test_case_name]: Specifies a single test case to run.\n" "\t-e [EXECUTION_PROVIDER]: EXECUTION_PROVIDER could be 'cpu', 'cuda', 'dnnl', 'tensorrt', 'vsinpu'" - "'openvino', 'rocm', 'migraphx', 'acl', 'armnn', 'xnnpack', 'nnapi', 'qnn', 'snpe' or 'coreml'. " + "'openvino', 'rocm', 'migraphx', 'acl', 'armnn', 'xnnpack', 'webgpu', 'nnapi', 'qnn', 'snpe' or 'coreml'. " "Default: 'cpu'.\n" "\t-p: Pause after launch, can attach debugger and continue\n" "\t-x: Use parallel executor, default (without -x): sequential executor.\n" "\t-d [device_id]: Specifies the device id for multi-device (e.g. GPU). The value should > 0\n" "\t-t: Specify custom relative tolerance values for output value comparison. default: 1e-5\n" "\t-a: Specify custom absolute tolerance values for output value comparison. default: 1e-5\n" + "\t-C: Specify session configuration entries as key-value pairs: -C \"| |\" \n" + "\t Refer to onnxruntime_session_options_config_keys.h for valid keys and values. \n" + "\t [Example] -C \"session.disable_cpu_ep_fallback|1 ep.context_enable|1\" \n" "\t-i: Specify EP specific runtime options as key value pairs. Different runtime options available are: \n" "\t [QNN only] [backend_path]: QNN backend path. e.g '/folderpath/libQnnHtp.so', '/folderpath/libQnnCpu.so'.\n" "\t [QNN only] [profiling_level]: QNN profiling level, options: 'basic', 'detailed', default 'off'.\n" @@ -123,6 +126,39 @@ static TestTolerances LoadTestTolerances(bool enable_cuda, bool enable_openvino, overrides_json["atol_default"], overrides_json["rtol_default"], absolute_overrides, relative_overrides); } +static bool ParseSessionConfigs(const std::string& configs_string, + std::unordered_map& session_configs) { + std::istringstream ss(configs_string); + std::string token; + + while (ss >> token) { + if (token == "") { + continue; + } + + std::string_view token_sv(token); + + auto pos = token_sv.find("|"); + if (pos == std::string_view::npos || pos == 0 || pos == token_sv.length()) { + // Error: must use a '|' to separate the key and value for session configuration entries. + return false; + } + + std::string key(token_sv.substr(0, pos)); + std::string value(token_sv.substr(pos + 1)); + + auto it = session_configs.find(key); + if (it != session_configs.end()) { + // Error: specified duplicate session configuration entry: {key} + return false; + } + + session_configs.insert(std::make_pair(std::move(key), std::move(value))); + } + + return true; +} + #ifdef _WIN32 int GetNumCpuCores() { SYSTEM_LOGICAL_PROCESSOR_INFORMATION buffer[256]; @@ -179,6 +215,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) { bool enable_armnn = false; bool enable_rocm = false; bool enable_migraphx = false; + bool enable_webgpu = false; bool enable_xnnpack = false; bool override_tolerance = false; double atol = 1e-5; @@ -188,6 +225,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) { bool user_graph_optimization_level_set = false; bool set_denormal_as_zero = false; std::basic_string ep_runtime_config_string; + std::unordered_map session_config_entries; std::string provider_name = "cpu"; OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_ERROR; @@ -198,7 +236,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) { bool pause = false; { int ch; - while ((ch = getopt(argc, argv, ORT_TSTR("Ac:hj:Mn:r:e:t:a:xvo:d:i:pzfb"))) != -1) { + while ((ch = getopt(argc, argv, ORT_TSTR("Ac:hj:Mn:r:e:t:a:xvo:d:C:i:pzfb"))) != -1) { switch (ch) { case 'A': enable_cpu_mem_arena = false; @@ -267,6 +305,8 @@ int real_main(int argc, char* argv[], Ort::Env& env) { enable_rocm = true; } else if (!CompareCString(optarg, ORT_TSTR("migraphx"))) { enable_migraphx = true; + } else if (!CompareCString(optarg, ORT_TSTR("webgpu"))) { + enable_webgpu = true; } else if (!CompareCString(optarg, ORT_TSTR("xnnpack"))) { enable_xnnpack = true; } else { @@ -323,6 +363,11 @@ int real_main(int argc, char* argv[], Ort::Env& env) { return -1; } break; + case 'C': + if (!ParseSessionConfigs(ToUTF8String(optarg), session_config_entries)) { + return -1; + } + break; case 'i': ep_runtime_config_string = optarg; break; @@ -409,6 +454,10 @@ int real_main(int argc, char* argv[], Ort::Env& env) { if (disable_ep_context_embed_mode) sf.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0"); + for (auto& it : session_config_entries) { + sf.AddConfigEntry(it.first.c_str(), it.second.c_str()); + } + if (enable_tensorrt) { #ifdef USE_TENSORRT Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Tensorrt(sf, device_id)); @@ -698,6 +747,15 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #endif } + if (enable_webgpu) { +#ifdef USE_WEBGPU + sf.AppendExecutionProvider("WebGPU", {}); +#else + fprintf(stderr, "WebGPU is not supported in this build"); + return -1; +#endif + } + if (user_graph_optimization_level_set) { sf.SetGraphOptimizationLevel(graph_optimization_level); } diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 587d035541c45..9e71e35a92909 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -563,6 +563,7 @@ def convert_arg_line_to_args(self, arg_line): "--nnapi_min_api", type=int, help="Minimum Android API level to enable NNAPI, should be no less than 27" ) parser.add_argument("--use_jsep", action="store_true", help="Build with JavaScript kernels.") + parser.add_argument("--use_webgpu", action="store_true", help="Build with WebGPU support.") parser.add_argument("--use_qnn", action="store_true", help="Build with QNN support.") parser.add_argument("--qnn_home", help="Path to QNN SDK dir.") parser.add_argument("--use_rknpu", action="store_true", help="Build with RKNPU.") @@ -1054,6 +1055,7 @@ def generate_build_tree( "-Donnxruntime_ARMNN_RELU_USE_CPU=" + ("OFF" if args.armnn_relu else "ON"), "-Donnxruntime_ARMNN_BN_USE_CPU=" + ("OFF" if args.armnn_bn else "ON"), "-Donnxruntime_USE_JSEP=" + ("ON" if args.use_jsep else "OFF"), + "-Donnxruntime_USE_WEBGPU=" + ("ON" if args.use_webgpu else "OFF"), # Training related flags "-Donnxruntime_ENABLE_NVTX_PROFILE=" + ("ON" if args.enable_nvtx_profile else "OFF"), "-Donnxruntime_ENABLE_TRAINING=" + ("ON" if args.enable_training else "OFF"), @@ -1310,6 +1312,9 @@ def generate_build_tree( raise BuildError("WebNN is only available for WebAssembly build.") cmake_args += ["-Donnxruntime_USE_WEBNN=ON"] + if args.use_jsep and args.use_webgpu: + raise BuildError("JSEP (--use_jsep) and WebGPU (--use_webgpu) cannot be enabled at the same time.") + if args.use_snpe: cmake_args += ["-Donnxruntime_USE_SNPE=ON"] From 9c362501db6c8cf645c22cc5b76b14993462ab02 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 28 Aug 2024 16:07:13 -0700 Subject: [PATCH 02/77] update C-API --- .../core/session/onnxruntime_c_api.h | 18 ++++++++++++------ .../core/session/onnxruntime_cxx_api.h | 6 +++--- .../core/session/onnxruntime_cxx_inline.h | 10 +++++----- onnxruntime/core/session/onnxruntime_c_api.cc | 2 +- onnxruntime/core/session/ort_apis.h | 4 ++-- .../core/session/provider_registration.cc | 4 +++- 6 files changed, 26 insertions(+), 18 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 9e5d9339bffe9..e6049b45c8eca 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -641,14 +641,20 @@ typedef struct OrtMIGraphXProviderOptions { * It's user's responsibility to manage the lifecycle of the handles and ensure the handles are valid during the * lifetime of the inference session. * - * \see OrtApi::SessionOptionsAppendExecutionProvider_WebGPU + * About DawnProcTable: + * + * When using an ONNX Runtime build that is not directly linked dawn during the build, a pointer to the runtime memory + * address of the DawnProcTable should be provided. Otherwise, keep it as nullptr. + * + * \see OrtApi::SessionOptionsAppendExecutionProvider_WGPU */ -typedef struct OrtWebGPUProviderOptions { +typedef struct OrtWGPUProviderOptions { int device_id; // WebGPU device id. void* instance_handle; // WebGPU instance handle. void* adapter_handle; // WebGPU adapter handle. void* device_handle; // WebGPU device handle. -} OrtWebGPUProviderOptions; + void* dawn_proc_table; // DawnProcTable pointer. +} OrtWGPUProviderOptions; /** \brief OpenVINO Provider Options * @@ -4699,7 +4705,7 @@ struct OrtApi { * If WebGPU is not available, this function will return failure. * * \param[in] options - * \param[in] webgpu_options - specify the WebGPU provider options. + * \param[in] wgpu_options - specify the WebGPU provider options. * \param[in] string_options_keys - keys to configure the string options * \param[in] string_options_values - values to configure the string options * \param[in] num_keys - number of keys passed in @@ -4719,8 +4725,8 @@ struct OrtApi { * * \since Version 1.20. */ - ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_WebGPU, - _In_ OrtSessionOptions* options, _In_ const OrtWebGPUProviderOptions* webgpu_options, + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_WGPU, + _In_ OrtSessionOptions* options, _In_ const OrtWGPUProviderOptions* wgpu_options, _In_reads_(num_keys) const char* const* string_options_keys, _In_reads_(num_keys) const char* const* string_options_values, _In_ size_t num_keys); diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index cf30584e18a4a..f85dad0a41ea7 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -890,9 +890,9 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options); - ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_WebGPU - SessionOptionsImpl& AppendExecutionProvider_WebGPU(const OrtWebGPUProviderOptions& webgpu_options, - const std::unordered_map& string_options = {}); + ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_WGPU + SessionOptionsImpl& AppendExecutionProvider_WGPU(const OrtWGPUProviderOptions& wgpu_options, + const std::unordered_map& string_options = {}); /// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports QNN, SNPE and XNNPACK. SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name, const std::unordered_map& provider_options = {}); diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index e5c84395ad95b..b675ff04268fc 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -839,21 +839,21 @@ inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_MIG } template -inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_WebGPU(const OrtWebGPUProviderOptions& webgpu_options, - const std::unordered_map& string_options) { - auto num_entries = provider_options.size(); +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_WGPU(const OrtWGPUProviderOptions& wgpu_options, + const std::unordered_map& string_options) { + auto num_entries = string_options.size(); std::vector keys, values; if (num_entries > 0) { keys.reserve(num_entries); values.reserve(num_entries); - for (const auto& entry : provider_options) { + for (const auto& entry : string_options) { keys.push_back(entry.first.c_str()); values.push_back(entry.second.c_str()); } } - ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_WebGPU(this->p_, &provider_options, keys.data(), values.data(), num_entries)); + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_WGPU(this->p_, &wgpu_options, keys.data(), values.data(), num_entries)); return *this; } diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index f231b0148b37e..3e787bb17aee5 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2731,7 +2731,7 @@ static constexpr OrtApi ort_api_1_to_20 = { &OrtApis::AddExternalInitializersFromFilesInMemory, // End of Version 18 - DO NOT MODIFY ABOVE (see above text for more information) - &OrtApis::SessionOptionsAppendExecutionProvider_WebGPU, + &OrtApis::SessionOptionsAppendExecutionProvider_WGPU, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index fd765feae6ad0..86cb3f3122d66 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -384,9 +384,9 @@ ORT_API_STATUS_IMPL(InvokeOp, ORT_API(void, ReleaseOp, _Frees_ptr_opt_ OrtOp* op); -ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_WebGPU, +ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_WGPU, _In_ OrtSessionOptions* options, - _In_ const OrtWebGPUProviderOptions* webgpu_options, + _In_ const OrtWGPUProviderOptions* wgpu_options, _In_reads_(num_keys) const char* const* string_options_keys, _In_reads_(num_keys) const char* const* string_options_values, _In_ size_t num_keys); diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index d2f8579fef7ec..1938ea3fd2c10 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -165,7 +165,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_WebGPU, +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_WGPU, _In_ OrtSessionOptions* options, _In_ const OrtWebGPUProviderOptions* webgpu_options, _In_reads_(num_keys) const char* const* string_options_keys, @@ -209,6 +209,8 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_WebGPU, options_keys.push_back("webgpuDevice"); options_values.push_back(device_handle.c_str()); + // TODO: dawn proc table + for (size_t i = 0; i != num_keys; ++i) { options_keys.push_back(string_options_keys[i]); options_values.push_back(string_options_values[i]); From 3a0756d4f2cc091dea9c4e04d9053f6931b5e0ac Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 28 Aug 2024 16:54:41 -0700 Subject: [PATCH 03/77] fix build break --- onnxruntime/core/session/provider_registration.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 1938ea3fd2c10..da97cdc25ab12 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -167,7 +167,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_WGPU, _In_ OrtSessionOptions* options, - _In_ const OrtWebGPUProviderOptions* webgpu_options, + _In_ const OrtWGPUProviderOptions* webgpu_options, _In_reads_(num_keys) const char* const* string_options_keys, _In_reads_(num_keys) const char* const* string_options_values, _In_ size_t num_keys) { From 5199e9858993b85b9f7809c94054e65a2c4ed5e3 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 28 Aug 2024 16:59:44 -0700 Subject: [PATCH 04/77] add an empty symbols.txt file --- onnxruntime/core/providers/webgpu/symbols.txt | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 onnxruntime/core/providers/webgpu/symbols.txt diff --git a/onnxruntime/core/providers/webgpu/symbols.txt b/onnxruntime/core/providers/webgpu/symbols.txt new file mode 100644 index 0000000000000..e69de29bb2d1d From 1c68dbd361157e331923668e0c5c04b1d0d17864 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 28 Aug 2024 17:18:40 -0700 Subject: [PATCH 05/77] fix an error in doc --- .../core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md index a5a71fd94bf47..87309f6673bbc 100644 --- a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md +++ b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md @@ -143,7 +143,7 @@ use `build.bat --use_webgpu` to build the WebGPU EP. For Release build, append ` to test, find the "onnx_test_runner.exe" in your build folder. run it like: ``` -onnx_test_runner.exe -v -e webgpu -C "session.disable_cpu_ep_fallback|1" --model_path=C:\code\onnxruntime\js\test\data\node\opset17\test_abs +onnx_test_runner.exe -v -e webgpu -C "session.disable_cpu_ep_fallback|1" C:\code\onnxruntime\js\test\data\node\opset17\test_abs ``` > Assume C:\code\onnxruntime is the root of your onnxruntime repo From 7db03de2ccdd60b246405d88c34a02934e70f0f2 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 28 Aug 2024 17:31:49 -0700 Subject: [PATCH 06/77] remove string_join.h in favor of absl::StrJoin --- include/onnxruntime/core/common/string_join.h | 61 ------------------- onnxruntime/core/providers/webgpu/program.h | 5 +- .../core/providers/webgpu/shader_variable.h | 14 ++++- 3 files changed, 14 insertions(+), 66 deletions(-) delete mode 100644 include/onnxruntime/core/common/string_join.h diff --git a/include/onnxruntime/core/common/string_join.h b/include/onnxruntime/core/common/string_join.h deleted file mode 100644 index 2c2181d4ad048..0000000000000 --- a/include/onnxruntime/core/common/string_join.h +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include -#include - -#include "core/common/make_string.h" - -namespace onnxruntime { - -namespace detail { - -template -inline void StringJoinImpl(const Separator& separator, std::ostringstream& ss) noexcept { -} - -template -inline void StringJoinImpl(const Separator& separator, std::ostringstream& ss, const T& t) noexcept { - ss << separator << t; -} - -template -inline void StringJoinImpl(const Separator& separator, std::ostringstream& ss, const T& t, const Args&... args) noexcept { - StringJoinImpl(separator, ss, t); - StringJoinImpl(separator, ss, args...); -} - -template -inline std::string StringJoinImpl(const Separator& separator, const Args&... args) noexcept { - std::ostringstream ss; - ss.imbue(std::locale::classic()); - StringJoinImpl(separator, ss, args...); - return ss.str(); -} -} // namespace detail - -/** - * Makes a string by concatenating string representations of the arguments using the specified separator. - * Uses std::locale::classic() - */ -template -std::string StringJoin(const Separator& separator, const Args&... args) { - return detail::StringJoinImpl(separator, detail::if_char_array_make_ptr_t(args)...); -} - -// StringJoin versions for already-a-string types. - -template -inline std::string StringJoin(const Separator& /* separator */, const std::string& str) { - return str; -} - -template -inline std::string StringJoin(const Separator& /* separator */, const char* cstr) { - return cstr; -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index 6df918e2f7f71..277c00e089017 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -7,8 +7,9 @@ #include #include +#include + #include "core/common/common.h" -#include "core/common/string_join.h" #include "core/common/safeint.h" #include "core/framework/tensor.h" @@ -218,7 +219,7 @@ class ProgramBase { // set the cache hint for the program template ProgramBase& CacheHint(CacheHintArgs&&... args) { - cache_hint_ = StringJoin("|", std::forward(args)...); + cache_hint_ = absl::StrJoin(std::forward_as_tuple(std::forward(args)...), "|"); } // set one or more program inputs diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index 0a5cad8237871..65a015c8e7ba2 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -191,7 +191,11 @@ inline std::string ShaderVariable::BroadcastedIndicesToOffset(const std::string& template inline std::string ShaderVariable::Indices(TIndices&&... indices_args) const { - return rank_ == 0 ? "" : MakeStringWithClassicLocale(name_, "_indices_t(", onnxruntime::detail::StringJoinImpl(", ", std::forward(indices_args)...), ')'); + return rank_ == 0 + ? "" + : MakeStringWithClassicLocale(name_, "_indices_t(", + absl::StrJoin(std::forward_as_tuple(std::forward(indices_args)...), ", "), + ')'); } template @@ -219,7 +223,9 @@ inline std::string ShaderVariable::Set(TIndicesAndValue&&... args) const { return SetByOffset(std::forward(args)...); } else { usage_ |= UseSet | UseSetByIndices | UseIndicesToOffset; - return MakeStringWithClassicLocale("set_", name_, '(', onnxruntime::detail::StringJoinImpl(", ", std::forward(args)...), ");"); + return MakeStringWithClassicLocale("set_", name_, '(', + absl::StrJoin(std::forward_as_tuple(std::forward(args)...), ", "), + ");"); } } @@ -246,7 +252,9 @@ inline std::string ShaderVariable::Get(TIndices&&... indices) const { return GetByOffset(std::forward(indices)...); } else { usage_ |= UseGet | UseGetByIndices | UseIndicesToOffset; - return MakeStringWithClassicLocale("get_", name_, '(', onnxruntime::detail::StringJoinImpl(", ", std::forward(indices)...), ')'); + return MakeStringWithClassicLocale("get_", name_, '(', + absl::StrJoin(std::forward_as_tuple(std::forward(indices)...), ", "), + ')'); } } From 6a373c231ea048799e5359a929170a3fffda0a6c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 28 Aug 2024 21:26:46 -0700 Subject: [PATCH 07/77] fix DLL copy --- cmake/onnxruntime_providers_webgpu.cmake | 7 +++++++ cmake/onnxruntime_unittests.cmake | 10 ---------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/cmake/onnxruntime_providers_webgpu.cmake b/cmake/onnxruntime_providers_webgpu.cmake index 303ab9483c38a..587c4b2c1ff2c 100644 --- a/cmake/onnxruntime_providers_webgpu.cmake +++ b/cmake/onnxruntime_providers_webgpu.cmake @@ -27,4 +27,11 @@ onnxruntime_add_include_to_target(onnxruntime_providers_webgpu onnxruntime_common onnx onnx_proto flatbuffers::flatbuffers Boost::mp11 safeint_interface) target_link_libraries(onnxruntime_providers_webgpu dawn::webgpu_dawn) + # Copy webgpu_dawn.dll to the output directory + add_custom_command( + TARGET onnxruntime_providers_webgpu + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different "$" "$" + VERBATIM ) + set_target_properties(onnxruntime_providers_webgpu PROPERTIES FOLDER "ONNXRuntime") diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 5434ead12f65d..511c25dd6d15e 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1100,11 +1100,6 @@ if (NOT IOS) endif() set_target_properties(onnx_test_runner PROPERTIES FOLDER "ONNXRuntimeTest") - add_custom_command(TARGET onnx_test_runner POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy $ $ - COMMAND_EXPAND_LISTS - ) - if (onnxruntime_USE_TVM) if (WIN32) target_link_options(onnx_test_runner PRIVATE "/STACK:4000000") @@ -1235,11 +1230,6 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) endif() endif() - add_custom_command(TARGET onnxruntime_perf_test POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy $ $ - COMMAND_EXPAND_LISTS - ) - if (onnxruntime_BUILD_SHARED_LIB) #It will dynamically link to onnxruntime. So please don't add onxruntime_graph/onxruntime_framework/... here. #onnxruntime_common is kind of ok because it is thin, tiny and totally stateless. From ee42bba8a2e19030c8d8cea3e0d0d08c279082d6 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 29 Aug 2024 01:11:14 -0700 Subject: [PATCH 08/77] update doc: require --skip_tests --- .../core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md | 2 +- onnxruntime/core/providers/webgpu/README.md | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md index 87309f6673bbc..3c20130ae2cef 100644 --- a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md +++ b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md @@ -139,7 +139,7 @@ This section is WIP. ## 6. Build and test -use `build.bat --use_webgpu` to build the WebGPU EP. For Release build, append `--config Release` or `--config RelWithDebInfo` to the command line. +use `build.bat --use_webgpu --skip_tests` to build the WebGPU EP. For Release build, append `--config Release` or `--config RelWithDebInfo` to the command line. to test, find the "onnx_test_runner.exe" in your build folder. run it like: ``` diff --git a/onnxruntime/core/providers/webgpu/README.md b/onnxruntime/core/providers/webgpu/README.md index d9c4313c8bf3f..20864d3609145 100644 --- a/onnxruntime/core/providers/webgpu/README.md +++ b/onnxruntime/core/providers/webgpu/README.md @@ -4,7 +4,9 @@ This folder is for the WebGPU execution provider(WebGPU EP). Currently, WebGPU E ## Build WebGPU EP -Just append `--use_webgpu` to the `build.bat` command line. +Just append `--use_webgpu --skip_tests` to the `build.bat` command line. + +NOTE: `--skip_tests` is required for now. All existing tests are for CPU EP anyway so no need to run them. Currently only works on Windows. From 3f46e5c6e6aa8311540f3dbfb01a45dce2e35bab Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 29 Aug 2024 02:10:44 -0700 Subject: [PATCH 09/77] update dawn version --- cmake/deps.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/deps.txt b/cmake/deps.txt index 2ab00cdbeb30c..597c051b5f477 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -59,4 +59,4 @@ extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d839 composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/204da9c522cebec5220bba52cd3542ebcaf99e7a.zip;1827348efd47831c13074245274d41b7cae8a557 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.5.2.zip;11071a47594b20f00af09aad83e0d5203ccf6029 -dawn;https://github.com/google/dawn/archive/9a912d8162d5a837950de14f8849230212e3f51c.zip;7f2cad3db905e2d846d8f2422623850a4463915f +dawn;https://github.com/google/dawn/archive/511eb80847afe6bded34ec491a38d5d78ba2d604.zip;c493f5aca5586f6634e25d0121c85df71189fb99 From 9f61279361e33e3cec0891a8cd95869d841bc17a Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 29 Aug 2024 13:11:38 -0700 Subject: [PATCH 10/77] disable Tint tests --- cmake/external/onnxruntime_external_deps.cmake | 1 + 1 file changed, 1 insertion(+) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 2dad3479c3c0f..6640609aa71dd 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -593,6 +593,7 @@ if (onnxruntime_USE_WEBGPU) ) set(DAWN_FETCH_DEPENDENCIES ON) set(DAWN_ENABLE_INSTALL ON) + set(TINT_BUILD_TESTS OFF) onnxruntime_fetchcontent_makeavailable(dawn) endif() From 6bb6335a71bdfbeb45f0f2ed9bd20ebb004967ab Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 29 Aug 2024 16:02:16 -0700 Subject: [PATCH 11/77] fix one build break in Linux --- onnxruntime/core/providers/webgpu/compute_context.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index d7aeae240101a..4d567b088fc1a 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -70,7 +70,7 @@ class ComputeContext { Tensor CreateCPUTensor(MLDataType data_type, TensorShapeType&& shape) { AllocatorPtr allocator; ORT_THROW_IF_ERROR(kernel_context_.GetTempSpaceCPUAllocator(&allocator)); - return {data_type, std::forward(shape)..., allocator}; + return {data_type, std::forward(shape), allocator}; } // @@ -80,7 +80,7 @@ class ComputeContext { Tensor CreateGPUTensor(MLDataType data_type, TensorShapeType&& shape) { AllocatorPtr allocator; ORT_THROW_IF_ERROR(kernel_context_.GetTempSpaceAllocator(&allocator)); - return {data_type, std::forward(shape)..., allocator}; + return {data_type, std::forward(shape), allocator}; } // From d839dbc213e8402b75b595173fa6592f8e3cc021 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 29 Aug 2024 19:52:47 -0700 Subject: [PATCH 12/77] remove unused variables --- onnxruntime/core/providers/webgpu/webgpu_context.cc | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index a891f5a8a5516..5f09223b2271b 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -219,18 +219,6 @@ Status WebGpuContext::Run(const ComputeContext& /*context*/, const ProgramBase& #endif } - std::vector input_buffers; - input_buffers.reserve(inputs.size()); - for (const auto& input : inputs) { - input_buffers.push_back(reinterpret_cast(const_cast(input.tensor->DataRaw()))); - } - - std::vector output_buffers; - output_buffers.reserve(outputs.size()); - for (const auto& output : outputs) { - output_buffers.push_back(reinterpret_cast(output->MutableDataRaw())); - } - WGPUBuffer uniform_buffer = nullptr; auto uniform_buffer_size = program_artifact->uniform_total_size; if (uniform_buffer_size > 0) { From b70943d92b0c225eae7b9576a22d6633342aff6e Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Fri, 30 Aug 2024 14:02:58 -0700 Subject: [PATCH 13/77] make webgpu build on linux and known to most tools (#21937) Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- cmake/onnxruntime.cmake | 3 ++- cmake/onnxruntime_python.cmake | 1 + onnxruntime/core/providers/webgpu/compute_context.h | 2 +- onnxruntime/core/providers/webgpu/shader_variable.h | 4 ++-- .../core/providers/webgpu/webgpu_provider_factory.cc | 5 +++-- onnxruntime/python/onnxruntime_pybind_state.cc | 4 ++++ onnxruntime/test/perftest/command_args_parser.cc | 6 ++++-- onnxruntime/test/perftest/ort_test_session.cc | 7 +++++++ onnxruntime/test/util/default_providers.cc | 8 ++++++++ onnxruntime/test/util/include/default_providers.h | 1 + tools/ci_build/gen_def.py | 1 + 11 files changed, 34 insertions(+), 8 deletions(-) diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 927b4ac84b037..52b6bda346862 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -38,7 +38,7 @@ function(get_c_cxx_api_headers HEADERS_VAR) # need to add header files for enabled EPs foreach(f ${ONNXRUNTIME_PROVIDER_NAMES}) - # The header files in include/onnxruntime/core/providers/cuda directory cannot be flattened to the same directory + # The header files in include/onnxruntime/core/providers/cuda directory cannot be flattened to the same directory # with onnxruntime_c_api.h . Most other EPs probably also do not work in this way. if((NOT f STREQUAL cuda) AND (NOT f STREQUAL rocm)) file(GLOB _provider_headers CONFIGURE_DEPENDS @@ -200,6 +200,7 @@ set(onnxruntime_INTERNAL_LIBRARIES ${PROVIDERS_RKNPU} ${PROVIDERS_VSINPU} ${PROVIDERS_XNNPACK} + ${PROVIDERS_WEBGPU} ${PROVIDERS_WEBNN} ${PROVIDERS_AZURE} ${PROVIDERS_INTERNAL_TESTING} diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index b2dbe4b3da5e8..c5ba544217233 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -178,6 +178,7 @@ target_link_libraries(onnxruntime_pybind11_state PRIVATE ${PROVIDERS_ACL} ${PROVIDERS_ARMNN} ${PROVIDERS_XNNPACK} + ${PROVIDERS_WEBGPU} ${PROVIDERS_AZURE} ${PROVIDERS_QNN} onnxruntime_optimizer diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index 4d567b088fc1a..9c352d3d76dd9 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -14,11 +14,11 @@ #include "core/framework/execution_provider.h" #include "core/providers/webgpu/program.h" +#include "core/framework/op_kernel.h" namespace onnxruntime { class Tensor; -class OpKernelContext; namespace webgpu { diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index 65a015c8e7ba2..ef95e26e6df74 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -131,7 +131,7 @@ class ShaderVariable { void Init(); void Impl(std::ostringstream& ss); - std::string ShaderVariable::GetByOffsetImpl(const std::string& offset) const; + std::string GetByOffsetImpl(const std::string& offset) const; std::string SetByOffsetImpl(const std::string& offset, const std::string& value) const; std::string_view StorageType() const; @@ -140,7 +140,7 @@ class ShaderVariable { std::string name_; ProgramVariableDataType type_; - int rank_; + size_t rank_; TensorShape dims_; mutable Usage usage_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index 93258b84c5112..e871b66f1dc92 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -67,7 +67,8 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( } LOGS_DEFAULT(VERBOSE) << "WebGPU EP graph capture enable: " << webgpu_ep_info.enable_graph_capture; - auto parse_buffer_cache_mode = [session_options](const std::string& config_entry_str, webgpu::BufferCacheMode default) -> webgpu::BufferCacheMode { + auto parse_buffer_cache_mode = [session_options](const std::string& config_entry_str, + webgpu::BufferCacheMode default_value) -> webgpu::BufferCacheMode { std::string buffer_cache_mode_str; if (session_options->config_options.TryGetConfigEntry(config_entry_str, buffer_cache_mode_str)) { if (buffer_cache_mode_str == kBufferCacheMode_Disabled) { @@ -82,7 +83,7 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( ORT_THROW("Invalid buffer cache mode: ", config_entry_str); } } else { - return default; + return default_value; } }; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 47b8d75f22aea..036585586d9ac 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1207,6 +1207,10 @@ std::unique_ptr CreateExecutionProviderInstance( return onnxruntime::XnnpackProviderFactoryCreator::Create( cit == provider_options_map.end() ? ProviderOptions{} : cit->second, &session_options) ->CreateProvider(); +#endif + } else if (type == kWebGpuExecutionProvider) { +#if defined(USE_WEBGPU) + return onnxruntime::WebGpuProviderFactoryCreator::Create(&session_options)->CreateProvider(); #endif } else if (type == kCannExecutionProvider) { #ifdef USE_CANN diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 84c3bc16346f3..0b8e291ec7fbc 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -37,8 +37,8 @@ namespace perftest { "\t-A: Disable memory arena\n" "\t-I: Generate tensor input binding (Free dimensions are treated as 1.)\n" "\t-c [parallel runs]: Specifies the (max) number of runs to invoke simultaneously. Default:1.\n" - "\t-e [cpu|cuda|dnnl|tensorrt|openvino|dml|acl|nnapi|coreml|qnn|snpe|rocm|migraphx|xnnpack|vitisai]: Specifies the provider 'cpu','cuda','dnnl','tensorrt', " - "'openvino', 'dml', 'acl', 'nnapi', 'coreml', 'qnn', 'snpe', 'rocm', 'migraphx', 'xnnpack' or 'vitisai'. " + "\t-e [cpu|cuda|dnnl|tensorrt|openvino|dml|acl|nnapi|coreml|qnn|snpe|rocm|migraphx|xnnpack|vitisai:webgpu]: Specifies the provider 'cpu','cuda','dnnl','tensorrt', " + "'openvino', 'dml', 'acl', 'nnapi', 'coreml', 'qnn', 'snpe', 'rocm', 'migraphx', 'xnnpack', 'vitisai' or 'webgpu'. " "Default:'cpu'.\n" "\t-b [tf|ort]: backend to use. Default:ort\n" "\t-r [repeated_times]: Specifies the repeated times if running in 'times' test mode.Default:1000.\n" @@ -279,6 +279,8 @@ static bool ParseSessionConfigs(const std::string& configs_string, test_config.machine_config.provider_type_name = onnxruntime::kXnnpackExecutionProvider; } else if (!CompareCString(optarg, ORT_TSTR("vitisai"))) { test_config.machine_config.provider_type_name = onnxruntime::kVitisAIExecutionProvider; + } else if (!CompareCString(optarg, ORT_TSTR("webgpu"))) { + test_config.machine_config.provider_type_name = onnxruntime::kWebGpuExecutionProvider; } else { return false; } diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index fc1bdb10d7453..57a20e2d03ee9 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -551,6 +551,13 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); "XNNPACK", {{"intra_op_num_threads", std::to_string(performance_test_config.run_config.intra_op_num_threads)}}); #else ORT_THROW("Xnnpack is not supported in this build\n"); +#endif + } else if (provider_name_ == onnxruntime::kWebGpuExecutionProvider) { +#ifdef USE_WEBGPU + session_options.AppendExecutionProvider( + "WebGPU", {{"intra_op_num_threads", std::to_string(performance_test_config.run_config.intra_op_num_threads)}}); +#else + ORT_THROW("WebGpu is not supported in this build\n"); #endif } else if (provider_name_ == onnxruntime::kVitisAIExecutionProvider) { #ifdef USE_VITISAI diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 1feba20e32bbb..871285269daf4 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -301,6 +301,14 @@ std::unique_ptr DefaultXnnpackExecutionProvider() { #endif } +std::unique_ptr DefaultWebGpuExecutionProvider() { +#ifdef USE_WEBGPU + return WebGpuProviderFactoryCreator::Create(nullptr)->CreateProvider(); +#else + return nullptr; +#endif +} + std::unique_ptr DefaultCannExecutionProvider() { #ifdef USE_CANN OrtCANNProviderOptions provider_options{}; diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h index 606dfc068d399..610b5b4ced68d 100644 --- a/onnxruntime/test/util/include/default_providers.h +++ b/onnxruntime/test/util/include/default_providers.h @@ -62,6 +62,7 @@ std::unique_ptr DefaultQnnExecutionProvider(); std::unique_ptr QnnExecutionProviderWithOptions(const ProviderOptions& options, const SessionOptions* session_options = nullptr); std::unique_ptr DefaultXnnpackExecutionProvider(); +std::unique_ptr DefaultWebGpuExecutionProvider(); std::unique_ptr DefaultCannExecutionProvider(); std::unique_ptr DefaultDmlExecutionProvider(); diff --git a/tools/ci_build/gen_def.py b/tools/ci_build/gen_def.py index c4add6f0e8910..765e9d135b7f0 100755 --- a/tools/ci_build/gen_def.py +++ b/tools/ci_build/gen_def.py @@ -80,6 +80,7 @@ def parse_arguments(): "dnnl", "tensorrt", "azure", + "webgpu" ): file.write(f"#include \n") file.write("void* GetFunctionEntryByName(const char* name){\n") From 843726753f8bb35bb2a900add041a53f4e4245e8 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 30 Aug 2024 14:45:28 -0700 Subject: [PATCH 14/77] revert type of ShaderVariable::rank_ to int --- onnxruntime/core/providers/webgpu/shader_variable.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index ef95e26e6df74..15d2259c34a93 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -140,7 +140,7 @@ class ShaderVariable { std::string name_; ProgramVariableDataType type_; - size_t rank_; + int rank_; TensorShape dims_; mutable Usage usage_; From 3caf032a9848d922428e4e2a5f798be1649b3c72 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 30 Aug 2024 15:49:26 -0700 Subject: [PATCH 15/77] output Impl() for variables --- onnxruntime/core/providers/webgpu/shader_helper.cc | 8 ++++++-- onnxruntime/core/providers/webgpu/shader_variable.cc | 2 +- onnxruntime/core/providers/webgpu/shader_variable.h | 2 +- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index 203f11ff90000..d3466b6d611ac 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -180,8 +180,12 @@ std::string ShaderHelper::GetFinalSourceCode() { // Indices helper // ss << "\n"; - // for (const auto& group : vars_) { - // } + for (const auto& var_group : vars_) { + for (const auto& var : var_group) { + var.Impl(ss); + } + ss << "\n"; + } // // Additional Implementation diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index d49d76c1ee858..4bff31e9dd30b 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -26,7 +26,7 @@ void ShaderVariable::Init() { ORT_ENFORCE(type_ != ProgramVariableDataType::InvalidType, "Invalid type for variable ", name_); } -void ShaderVariable::Impl(std::ostringstream& ss) { +void ShaderVariable::Impl(std::ostringstream& ss) const { // Start generating code const std::string value_t = name_ + "_value_t"; diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index 15d2259c34a93..fbdb6590a7359 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -129,7 +129,7 @@ class ShaderVariable { ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderVariable); void Init(); - void Impl(std::ostringstream& ss); + void Impl(std::ostringstream& ss) const; std::string GetByOffsetImpl(const std::string& offset) const; std::string SetByOffsetImpl(const std::string& offset, const std::string& value) const; From 84494c4344027b15951f47246be941e7c72a3604 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 30 Aug 2024 16:03:16 -0700 Subject: [PATCH 16/77] code formatting --- tools/ci_build/gen_def.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ci_build/gen_def.py b/tools/ci_build/gen_def.py index 765e9d135b7f0..2b7790ec4e683 100755 --- a/tools/ci_build/gen_def.py +++ b/tools/ci_build/gen_def.py @@ -80,7 +80,7 @@ def parse_arguments(): "dnnl", "tensorrt", "azure", - "webgpu" + "webgpu", ): file.write(f"#include \n") file.write("void* GetFunctionEntryByName(const char* name){\n") From aa70163a7a0431402b46408570fb97b41558bb27 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 30 Aug 2024 16:04:07 -0700 Subject: [PATCH 17/77] better format of Uniform --- onnxruntime/core/providers/webgpu/shader_helper.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index d3466b6d611ac..3986b13e0a7d7 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -137,7 +137,7 @@ std::string ShaderHelper::GetFinalSourceCode() { program_.UniformVariables().cend(), [](const ProgramUniformVariableValue& x) { return x.length > 0; })) { bool first = true; - ss << "struct Uniforms {\n"; + ss << "struct Uniforms {"; size_t uniform_count = program_.UniformVariables().size(); for (size_t i = 0; i < uniform_count; i++) { @@ -151,11 +151,11 @@ std::string ShaderHelper::GetFinalSourceCode() { if (first) { first = false; } else { - ss << ",\n"; + ss << ","; } auto alignment = (data_type == ProgramUniformVariableDataType::Float16 && length > 4) ? "@align(16) " : ""; - ss << " " << alignment << name << ": "; + ss << "\n " << alignment << name << ": "; if (length > 4) { if (data_type == ProgramUniformVariableDataType::Float16) { size_t array_size = (length + 7) / 8; @@ -171,7 +171,7 @@ std::string ShaderHelper::GetFinalSourceCode() { } } - ss << "};\n" + ss << "\n};\n" "@group(0) @binding(" << variable_count << ") var uniforms: Uniforms;\n"; } From d772db7ae7a2710a0f6ca6f9338186029dcb1e3c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 30 Aug 2024 16:06:02 -0700 Subject: [PATCH 18/77] revise document --- .../core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md | 4 ++-- onnxruntime/core/providers/webgpu/README.md | 7 +++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md index 3c20130ae2cef..a7123ac4a580d 100644 --- a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md +++ b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md @@ -146,9 +146,9 @@ to test, find the "onnx_test_runner.exe" in your build folder. run it like: onnx_test_runner.exe -v -e webgpu -C "session.disable_cpu_ep_fallback|1" C:\code\onnxruntime\js\test\data\node\opset17\test_abs ``` -> Assume C:\code\onnxruntime is the root of your onnxruntime repo +> Assume `C:\code\onnxruntime` is the root of your onnxruntime repo > -> if it does not exist, run the following in your onnxruntime repo root: +> if folder `C:\code\onnxruntime\js\test\data` does not exist, run the following in your onnxruntime repo root: > ``` > cd js > npm ci diff --git a/onnxruntime/core/providers/webgpu/README.md b/onnxruntime/core/providers/webgpu/README.md index 20864d3609145..999f1fecbda76 100644 --- a/onnxruntime/core/providers/webgpu/README.md +++ b/onnxruntime/core/providers/webgpu/README.md @@ -4,11 +4,14 @@ This folder is for the WebGPU execution provider(WebGPU EP). Currently, WebGPU E ## Build WebGPU EP -Just append `--use_webgpu --skip_tests` to the `build.bat` command line. +Just append `--use_webgpu --skip_tests` to the `build.bat`/`build.sh` command line. NOTE: `--skip_tests` is required for now. All existing tests are for CPU EP anyway so no need to run them. -Currently only works on Windows. +For linux, a few dependencies need to be installed: +```sh +apt-get install libx11-dev libxrandr-dev libxinerama-dev libxcursor-dev libxi-dev libx11-dev libx11-xcb-dev +``` ## Troubleshooting From 6ef3dadfa5c4290922f9d3874858c071ed07f36c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 30 Aug 2024 21:20:29 -0700 Subject: [PATCH 19/77] more build fix for linux --- .../core/providers/webgpu/buffer_manager.cc | 2 -- onnxruntime/core/providers/webgpu/program.h | 1 + .../core/providers/webgpu/program_cache_key.cc | 2 +- .../core/providers/webgpu/shader_variable.cc | 14 +++++++------- .../core/providers/webgpu/webgpu_context.cc | 2 +- .../providers/webgpu/webgpu_execution_provider.h | 7 +++++++ 6 files changed, 17 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.cc b/onnxruntime/core/providers/webgpu/buffer_manager.cc index d69b1210ade4b..e1f065b65f13e 100644 --- a/onnxruntime/core/providers/webgpu/buffer_manager.cc +++ b/onnxruntime/core/providers/webgpu/buffer_manager.cc @@ -4,8 +4,6 @@ #include "core/providers/webgpu/buffer_manager.h" #include "core/providers/webgpu/webgpu_context.h" -static int xx = 1; - namespace onnxruntime { namespace webgpu { diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index 277c00e089017..812e44e014ee6 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -220,6 +220,7 @@ class ProgramBase { template ProgramBase& CacheHint(CacheHintArgs&&... args) { cache_hint_ = absl::StrJoin(std::forward_as_tuple(std::forward(args)...), "|"); + return *this; } // set one or more program inputs diff --git a/onnxruntime/core/providers/webgpu/program_cache_key.cc b/onnxruntime/core/providers/webgpu/program_cache_key.cc index d720c55fb5427..a4530910944d4 100644 --- a/onnxruntime/core/providers/webgpu/program_cache_key.cc +++ b/onnxruntime/core/providers/webgpu/program_cache_key.cc @@ -47,7 +47,7 @@ std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_disp } } - ss << ":" D("DispatchDim=") << is_1d_dispatch ? "1" : "3"; + ss << ":" D("DispatchDim=") << (is_1d_dispatch ? "1" : "3"); ss << ":" D("UniformSizes="); bool first = true; for (const auto& uniform : program.UniformVariables()) { diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index 4bff31e9dd30b..9483ab19036c4 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -72,7 +72,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { SS("fn o2i_", name_, "(offset : u32)->", indices_t, " {\n"); SS(" var indices: ", indices_t, ";\n"); SS(" var current = offset;\n"); - for (size_t i = 0; i < rank_ - 1; i++) { + for (int i = 0; i < rank_ - 1; i++) { auto current_stride = GetElementAt(stride, i, rank_); SS(" let dim", i, " = current / ", current_stride, ";\n"); SS(" let rest", i, " = current % ", current_stride, ";\n"); @@ -90,7 +90,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { if (rank_ >= 2) { SS("fn i2o_", name_, "(indices : ", indices_t, ")->u32 {\n"); SS(" return "); - for (size_t i = 0; i < rank_ - 1; i++) { + for (int i = 0; i < rank_ - 1; i++) { SS("indices[", i, "] * ", GetElementAt(stride, i, rank_), " + "); } SS("indices[", rank_ - 1, "];\n"); @@ -108,7 +108,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { SS(" return 0;\n"); } else { SS(" return "); - for (size_t i = 0; i < rank_ - 1; i++) { + for (int i = 0; i < rank_ - 1; i++) { auto idx = broadcasted_result.IndicesGet("indices", i + broadcasted_result.rank_ - rank_); SS(IndicesGet(stride, i), " * (", idx, " % ", IndicesGet(shape, i), ") + "); } @@ -122,12 +122,12 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { if (usage_ & UseSet) { if (rank_ >= 2) { SS("fn set_", name_, "(d0: u32"); - for (size_t i = 1; i < rank_; i++) { + for (int i = 1; i < rank_; i++) { SS(", d", i, ": u32"); } SS(", value: ", value_t, ") {\n"); SS(" set_", name_, "_by_indices(d0"); - for (size_t i = 1; i < rank_; i++) { + for (int i = 1; i < rank_; i++) { SS(", d", i); } SS(", value);\n"); @@ -148,12 +148,12 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { if (usage_ & UseGet) { if (rank_ >= 2) { SS("fn get_", name_, "(d0: u32"); - for (size_t i = 1; i < rank_; i++) { + for (int i = 1; i < rank_; i++) { SS(", d", i, ": u32"); } SS(")->", value_t, " {\n"); SS(" return get_", name_, "_by_indices(d0"); - for (size_t i = 1; i < rank_; i++) { + for (int i = 1; i < rank_; i++) { SS(", d", i); } SS(");\n"); diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 5f09223b2271b..049a729f5c988 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -93,7 +93,7 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info wgpu::RequestDeviceCallbackInfo req_device_callback_info = {}; req_device_callback_info.mode = wgpu::CallbackMode::WaitAnyOnly; req_device_callback_info.callback = [](WGPURequestDeviceStatus status, WGPUDevice device, char const* message, void* userdata) { - ORT_ENFORCE(status == WGPURequestAdapterStatus_Success, "Failed to get a WebGPU device: ", message); + ORT_ENFORCE(status == WGPURequestDeviceStatus_Success, "Failed to get a WebGPU device: ", message); *static_cast(userdata) = wgpu::Device::Acquire(device); }; req_device_callback_info.userdata = &device_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index 6fb2381637a67..4b2d2882b6ec2 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -22,6 +22,13 @@ enum class BufferCacheMode; } // namespace webgpu struct WebGpuExecutionProviderInfo { + WebGpuExecutionProviderInfo(DataLayout data_layout1, bool enable_graph_capture1) + : data_layout{data_layout1} + , enable_graph_capture{enable_graph_capture1} + , storage_buffer_cache_mode{} + , uniform_buffer_cache_mode{} + , query_resolve_buffer_cache_mode{} + , default_buffer_cache_mode{} {} DataLayout data_layout; bool enable_graph_capture; webgpu::BufferCacheMode storage_buffer_cache_mode; From a56f6c3edae7ad7c15cffc4032cd13c7e4b2f452 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 30 Aug 2024 21:23:36 -0700 Subject: [PATCH 20/77] apply formatter --- .../providers/webgpu/webgpu_execution_provider.h | 12 ++++++------ test_webgpu.bat | 12 ++++++++++++ test_webgpu_cases.txt | 1 + 3 files changed, 19 insertions(+), 6 deletions(-) create mode 100644 test_webgpu.bat create mode 100644 test_webgpu_cases.txt diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index 4b2d2882b6ec2..5f27fad14afc6 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -23,12 +23,12 @@ enum class BufferCacheMode; struct WebGpuExecutionProviderInfo { WebGpuExecutionProviderInfo(DataLayout data_layout1, bool enable_graph_capture1) - : data_layout{data_layout1} - , enable_graph_capture{enable_graph_capture1} - , storage_buffer_cache_mode{} - , uniform_buffer_cache_mode{} - , query_resolve_buffer_cache_mode{} - , default_buffer_cache_mode{} {} + : data_layout{data_layout1}, + enable_graph_capture{enable_graph_capture1}, + storage_buffer_cache_mode{}, + uniform_buffer_cache_mode{}, + query_resolve_buffer_cache_mode{}, + default_buffer_cache_mode{} {} DataLayout data_layout; bool enable_graph_capture; webgpu::BufferCacheMode storage_buffer_cache_mode; diff --git a/test_webgpu.bat b/test_webgpu.bat new file mode 100644 index 0000000000000..feec724c1a7d0 --- /dev/null +++ b/test_webgpu.bat @@ -0,0 +1,12 @@ +rem @echo off +:: if file js\test\data\node\__generated_onnx_node_tests not found, generate it +if not exist "%~dp0js\test\data\node\__generated_onnx_node_tests" ( + pushd "%~dp0js" + call npm ci + call npm run prepare-node-tests + popd +) + +for /F "tokens=*" %%A in (%~dp0test_webgpu_cases.txt) do ( + echo %%A +) diff --git a/test_webgpu_cases.txt b/test_webgpu_cases.txt new file mode 100644 index 0000000000000..4cc29f5b13ed8 --- /dev/null +++ b/test_webgpu_cases.txt @@ -0,0 +1 @@ +test_abs From 12cd79d6742e8967e697fddf144fcd55dcf1c5cc Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sat, 31 Aug 2024 01:38:11 -0700 Subject: [PATCH 21/77] simple test runner --- cmake/onnxruntime_unittests.cmake | 16 +++ .../webgpu/How_to_Write_WebGPU_EP_Kernel.md | 13 ++- .../test/providers/webgpu/test_webgpu.bat | 3 + .../test/providers/webgpu/test_webgpu.js | 98 +++++++++++++++++++ test_webgpu.bat | 12 --- test_webgpu_cases.txt | 1 - 6 files changed, 129 insertions(+), 14 deletions(-) create mode 100644 onnxruntime/test/providers/webgpu/test_webgpu.bat create mode 100644 onnxruntime/test/providers/webgpu/test_webgpu.js delete mode 100644 test_webgpu.bat delete mode 100644 test_webgpu_cases.txt diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index b050698a5570e..6c43680ecc75b 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1120,6 +1120,22 @@ if (NOT IOS) LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} BUNDLE DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) + + ## TODO: remove this when merging to main branch + # + # should support better test runner + # + if (onnxruntime_USE_WEBGPU) + add_custom_command( + TARGET onnx_test_runner + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${ONNXRUNTIME_ROOT}/test/providers/webgpu/test_webgpu.js" + "${ONNXRUNTIME_ROOT}/test/providers/webgpu/test_webgpu.bat" + "$" + VERBATIM ) + endif() + endif() if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) diff --git a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md index a7123ac4a580d..a27a7b3131bd0 100644 --- a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md +++ b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md @@ -141,7 +141,18 @@ This section is WIP. use `build.bat --use_webgpu --skip_tests` to build the WebGPU EP. For Release build, append `--config Release` or `--config RelWithDebInfo` to the command line. -to test, find the "onnx_test_runner.exe" in your build folder. run it like: +to test, find the "test_webgpu.bat" in your build folder. run it for tests: +``` +# run all tests +test_webgpu.bat + +# run a specific test +test_webgpu.bat test_abs +``` + + + +to test or debug a single test, find the "onnx_test_runner.exe" in your build folder. run it like: ``` onnx_test_runner.exe -v -e webgpu -C "session.disable_cpu_ep_fallback|1" C:\code\onnxruntime\js\test\data\node\opset17\test_abs ``` diff --git a/onnxruntime/test/providers/webgpu/test_webgpu.bat b/onnxruntime/test/providers/webgpu/test_webgpu.bat new file mode 100644 index 0000000000000..fad6569c24570 --- /dev/null +++ b/onnxruntime/test/providers/webgpu/test_webgpu.bat @@ -0,0 +1,3 @@ +@echo off + +node "%~dp0test_webgpu.js" %* diff --git a/onnxruntime/test/providers/webgpu/test_webgpu.js b/onnxruntime/test/providers/webgpu/test_webgpu.js new file mode 100644 index 0000000000000..111f321ccbbd2 --- /dev/null +++ b/onnxruntime/test/providers/webgpu/test_webgpu.js @@ -0,0 +1,98 @@ +const HELP = ` + Call onnx_test_runner to test WebGPU EP. + + Usage: node test_webgpu.js [options] + + Options: + -h Print this help message. + -t= Path of the test data folder (eg. "../../../js/test/data/node") + -v Verbose mode. + -m= ';' separated list of test names (eg. test_abs) +`; + +const DEFAULT_TESTS = [ + 'test_abs', +]; + +const path = require('path'); +const fs = require('fs'); +const { spawnSync } = require('child_process'); + +const ONNX_TEST_RUNNER_FILENAME = path.join(__dirname, + 'onnx_test_runner' + (process.platform === 'win32' ? '.exe' : '')); + +if (process.argv.includes('-h')) { + console.log(HELP); + process.exit(0); +} + +const VERBOSE = process.argv.includes('-v'); +let test_data_path = process.argv.find(arg => arg.startsWith('-t=')); +if (!test_data_path) { + test_data_path = path.join(__dirname, (process.platform === 'win32' ? '../' : '') + '../../../js/test/data/node'); +} else { + test_data_path = test_data_path.substring(3); +} + +const test_models = []; +const test_model_list = process.argv.find(arg => arg.startsWith('-m=')); +if (test_model_list) { + test_model_list.substring(3).split(';').forEach(test_model => { + test_models.push(test_model); + }); +} +const tests = new Set(test_model_list ? test_models : DEFAULT_TESTS); +const test_cases = []; +fs.readdirSync(test_data_path, { withFileTypes: true }).forEach(dirent => { + if (dirent.isDirectory()) { + const opset = dirent.name; + fs.readdirSync(path.join(test_data_path, opset), { withFileTypes: true }).forEach(dirent => { + if (dirent.isDirectory()) { + const name = dirent.name; + if (tests.has(name)) { + test_cases.push(path.join(test_data_path, opset, name)); + } + } + }); + } +}); + +let passed = []; +let not_implemented = []; +let failed = []; +test_cases.forEach(test_case => { + process.stdout.write(`Running test case: "${test_case}"...`); + const args = [ + '-e', 'webgpu', '-C', '"session.disable_cpu_ep_fallback|1"', test_case, + ]; + if (VERBOSE) { + args.unshift('-v'); + } + const p = spawnSync(ONNX_TEST_RUNNER_FILENAME, args, { shell: true, stdio: ['ignore', 'pipe', 'pipe'] }); + if (p.status !== 0) { + process.stdout.write('Failed\n'); + failed.push(test_case); + } else if (!p.stdout.toString().includes('Not implemented: 0')) { + process.stdout.write('Not Implemented\n'); + not_implemented.push(test_case); + } else { + process.stdout.write('OK\n'); + passed.push(test_case); + } +}); + +console.log(`\n${passed.length} tests passed.`); +console.log(`\n${not_implemented.length} tests not implemented:`); +not_implemented.slice(0, 3).forEach(test_case => { + console.log(` ${test_case}`); +}); +if (not_implemented.length > 3) { + console.log(` ...`); +} +console.log(`\n${failed.length} tests failed:`); +failed.slice(0, 3).forEach(test_case => { + console.log(` ${test_case}`); +}); +if (failed.length > 3) { + console.log(` ...`); +} diff --git a/test_webgpu.bat b/test_webgpu.bat deleted file mode 100644 index feec724c1a7d0..0000000000000 --- a/test_webgpu.bat +++ /dev/null @@ -1,12 +0,0 @@ -rem @echo off -:: if file js\test\data\node\__generated_onnx_node_tests not found, generate it -if not exist "%~dp0js\test\data\node\__generated_onnx_node_tests" ( - pushd "%~dp0js" - call npm ci - call npm run prepare-node-tests - popd -) - -for /F "tokens=*" %%A in (%~dp0test_webgpu_cases.txt) do ( - echo %%A -) diff --git a/test_webgpu_cases.txt b/test_webgpu_cases.txt deleted file mode 100644 index 4cc29f5b13ed8..0000000000000 --- a/test_webgpu_cases.txt +++ /dev/null @@ -1 +0,0 @@ -test_abs From 14c89661ae854f502ee7ef2f85b15915bd123af9 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sat, 31 Aug 2024 16:38:53 -0700 Subject: [PATCH 22/77] Program macros update - allow extend --- .../webgpu/How_to_Write_WebGPU_EP_Kernel.md | 6 +- onnxruntime/core/providers/webgpu/program.h | 219 +++++++++++++----- 2 files changed, 159 insertions(+), 66 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md index a27a7b3131bd0..7ae7e2b37fc29 100644 --- a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md +++ b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md @@ -37,7 +37,7 @@ constants are declaration of values that are never changes in the shader code. T const A : u32 = 64; ``` -Use macro `WEBGPU_PROGRAM_DEFINE_CONSTANTS` to define constants in your Program class. +Use macro `WEBGPU_PROGRAM_DEFINE_CONSTANTS` to define constants in your Program class, or use `WEBGPU_PROGRAM_EXTEND_CONSTANTS` to extend the constants defined in the base class. #### **overridable constants** @@ -48,13 +48,13 @@ override B : u32 = 64; override C : f32; ``` -Use macro `WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS` to define overridable constants in your Program class. +Use macro `WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS` to define overridable constants in your Program class, or use `WEBGPU_PROGRAM_EXTEND_OVERRIDABLE_CONSTANTS` to extend the overridable constants defined in the base class. #### **uniform definitions** uniform definitions are declaration of uniform varables. Their names and type must be defined and cannot be changed. Their values(including length) can be set at runtime. -Use macro `WEBGPU_PROGRAM_DEFINE_UNIFORMS` to define uniform definitions in your Program class. +Use macro `WEBGPU_PROGRAM_DEFINE_UNIFORMS_VARIABLES` to define uniform definitions in your Program class, or use `WEBGPU_PROGRAM_EXTEND_UNIFORMS_VARIABLES` to extend the uniform definitions defined in the base class. ### 2.3. The Program class should override the `GenerateShaderCode` method: diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index 812e44e014ee6..d056ee8577f11 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -54,6 +54,9 @@ struct ProgramUniformVariableValue { // represents a uniform variable definition struct ProgramUniformVariableDefinition { + constexpr ProgramUniformVariableDefinition(std::string_view name, ProgramUniformVariableDataType data_type) + : name{name}, data_type{data_type} {} + std::string_view name; ProgramUniformVariableDataType data_type; }; @@ -337,27 +340,32 @@ class ProgramWrapper : public ProgramBase { #error "macro ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK is already defined" #endif -#define ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(identifier, element_type) \ - private: \ - template \ - static auto test_has_##identifier(int)->decltype(U::identifier, std::true_type{}); /* checks if member exists */ \ - template \ - static auto test_has_##identifier(...)->std::false_type; \ - \ - template && /* - is array */ \ - std::is_const_v && /* - has "const" modifier */ \ - std::is_convertible_v && /* - can convert to a const pointer */ \ - !std::is_member_pointer_v>> /* - is static */ \ - static auto test_has_##identifier##_with_correct_type(int)->std::true_type; \ - template \ - static auto test_has_##identifier##_with_correct_type(...)->std::false_type; \ - \ - public: \ - static constexpr bool has_##identifier = decltype(test_has_##identifier(0))::value; \ +#define ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(identifier, element_type) \ + private: \ + template \ + static auto test_has_##identifier(int)->decltype(U::identifier, std::true_type{}); /* checks if member exists */ \ + template \ + static auto test_has_##identifier(...)->std::false_type; \ + \ + template ::value && /* - is a const std::array */ \ + std::is_const_v && /* - has "const" modifier */ \ + !std::is_member_pointer_v>> /* - is static */ \ + static auto test_has_##identifier##_with_correct_type(int)->std::true_type; \ + template \ + static auto test_has_##identifier##_with_correct_type(...)->std::false_type; \ + \ + public: \ + static constexpr bool has_##identifier = decltype(test_has_##identifier(0))::value; \ static constexpr bool has_##identifier##_with_correct_type = decltype(test_has_##identifier##_with_correct_type(0))::value +// the following template class checks whether the type is a const std::array +template +struct is_const_std_array : std::false_type {}; +template +struct is_const_std_array> : std::true_type {}; + // the following template class checks whether certain static members exist in the derived class (SFINAE) template class DerivedProgramClassTypeCheck { @@ -367,52 +375,90 @@ class DerivedProgramClassTypeCheck { }; // compile-time tests for the type check +// +// TODO: move this to test folder namespace test { +template +class TestTypeCheck { + ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(a, int); +}; + struct TestClass_Empty {}; -struct TestClass_0 { +static_assert(!TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotArray_0 { int b; }; -struct TestClass_1 { +static_assert(!TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotArray_1 { int a; }; -struct TestClass_2 { +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotArray_2 { const int a; }; -struct TestClass_3 { +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotStdArray_0 { const int a[2]; }; -struct TestClass_4 { +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotStdArray_1 { static constexpr int a[] = {0}; }; -struct TestClass_5 { +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotStdArray_2 { static int a[]; }; -struct TestClass_6 { +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); + +struct TestClass_NotStdArray_3 { static const int a[]; }; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); -template -class TestTypeCheck { - ORT_WEBGPU_REGISTER_DERIVED_PROGRAM_CLASS_TYPE_CHECK(a, int); +struct TestClass_StdArray_0 { + std::array a = {1}; }; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); -static_assert(!TestTypeCheck::has_a); -static_assert(!TestTypeCheck::has_a_with_correct_type); -static_assert(!TestTypeCheck::has_a); -static_assert(!TestTypeCheck::has_a_with_correct_type); -static_assert(TestTypeCheck::has_a); -static_assert(!TestTypeCheck::has_a_with_correct_type); -static_assert(TestTypeCheck::has_a); -static_assert(!TestTypeCheck::has_a_with_correct_type); -static_assert(TestTypeCheck::has_a); -static_assert(!TestTypeCheck::has_a_with_correct_type); -static_assert(TestTypeCheck::has_a); -static_assert(TestTypeCheck::has_a_with_correct_type); -static_assert(TestTypeCheck::has_a); -static_assert(!TestTypeCheck::has_a_with_correct_type); -static_assert(TestTypeCheck::has_a); -static_assert(TestTypeCheck::has_a_with_correct_type); +struct TestClass_StdArray_1 { + static constexpr std::array a = {1, 2}; +}; +static_assert(TestTypeCheck::has_a); +static_assert(TestTypeCheck::has_a_with_correct_type); + +struct TestClass_StdArray_2 { + static const std::array a; +}; +static_assert(TestTypeCheck::has_a); +static_assert(TestTypeCheck::has_a_with_correct_type); + +struct TestClass_StdArray_3 { + static constexpr const std::array a = {1, 2, 3, 4}; +}; +static_assert(TestTypeCheck::has_a); +static_assert(TestTypeCheck::has_a_with_correct_type); + +struct TestClass_StdArray_4 { + static std::array a; +}; +static_assert(TestTypeCheck::has_a); +static_assert(!TestTypeCheck::has_a_with_correct_type); } // namespace test @@ -435,13 +481,12 @@ class Program : public detail::ProgramWrapper { virtual ProgramMetadata GetMetadata() const final { ProgramMetadata metadata; if constexpr (detail::DerivedProgramClassTypeCheck::has_constants) { - constexpr const ProgramConstant* ptr = T::constants; - constexpr size_t len = sizeof(T::constants) / sizeof(ProgramConstant); + constexpr const ProgramConstant* ptr = T::constants.data(); + constexpr size_t len = T::constants.size(); - static_assert(detail::DerivedProgramClassTypeCheck::has_constants_with_correct_type && - sizeof(T::constants) % sizeof(ProgramConstant) == 0, + static_assert(detail::DerivedProgramClassTypeCheck::has_constants_with_correct_type, "Derived class of \"Program\" has member \"constants\" but its type is incorrect. " - "Please use macro WEBGPU_PROGRAM_DEFINE_CONSTANTS() to declare constants."); + "Please use macro WEBGPU_PROGRAM_DEFINE_CONSTANTS() or WEBGPU_PROGRAM_EXTEND_CONSTANTS() to declare constants."); metadata.constants = {ptr, len}; } else { @@ -449,13 +494,12 @@ class Program : public detail::ProgramWrapper { } if constexpr (detail::DerivedProgramClassTypeCheck::has_overridable_constants) { - constexpr const ProgramOverridableConstantDefinition* ptr = T::overridable_constants; - constexpr size_t len = sizeof(T::overridable_constants) / sizeof(ProgramOverridableConstantDefinition); + constexpr const ProgramOverridableConstantDefinition* ptr = T::overridable_constants.data(); + constexpr size_t len = T::overridable_constants.size(); - static_assert(detail::DerivedProgramClassTypeCheck::has_overridable_constants_with_correct_type && - sizeof(T::overridable_constants) % sizeof(ProgramOverridableConstantDefinition) == 0, + static_assert(detail::DerivedProgramClassTypeCheck::has_overridable_constants_with_correct_type, "Derived class of \"Program\" has member \"overridable_constants\" but its type is incorrect. " - "Please use macro WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS() to declare overridable constants."); + "Please use macro WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS() or WEBGPU_PROGRAM_EXTEND_OVERRIDABLE_CONSTANTS() to declare overridable constants."); metadata.overridable_constants = {ptr, len}; } else { @@ -463,13 +507,12 @@ class Program : public detail::ProgramWrapper { } if constexpr (detail::DerivedProgramClassTypeCheck::has_uniform_variables) { - constexpr const ProgramUniformVariableDefinition* ptr = T::uniform_variables; - constexpr size_t len = sizeof(T::uniform_variables) / sizeof(ProgramUniformVariableDefinition); + constexpr const ProgramUniformVariableDefinition* ptr = T::uniform_variables.data(); + constexpr size_t len = T::uniform_variables.size(); - static_assert(detail::DerivedProgramClassTypeCheck::has_uniform_variables_with_correct_type && - sizeof(T::uniform_variables) % sizeof(ProgramUniformVariableDefinition) == 0, + static_assert(detail::DerivedProgramClassTypeCheck::has_uniform_variables_with_correct_type, "Derived class of \"Program\" has member \"uniform_variables\" but its type is incorrect. " - "Please use macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES() to declare uniform variables."); + "Please use macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES() or WEBGPU_PROGRAM_EXTEND_UNIFORM_VARIABLES() to declare uniform variables."); metadata.uniform_variables = {ptr, len}; } else { @@ -480,14 +523,64 @@ class Program : public detail::ProgramWrapper { } }; +namespace detail { +// helper function to convert a C-style array to std::array +// +// This is basically the same as std::to_array in C++20. +// +template +constexpr auto _to_std_array_impl(T (&arr)[N], std::index_sequence) -> std::array, N> { + return {{arr[Idx]...}}; +} + +template +constexpr auto _to_std_array(T (&arr)[N]) -> std::array, N> { + return _to_std_array_impl(arr, std::make_index_sequence{}); +} + +// helper function to concatenate a std::array and a C-style array to a std::array +// +template +constexpr std::array, L + R> _concat2_impl(const std::array& lhs, + T (&rhs)[R], + std::index_sequence, + std::index_sequence) { + return {{lhs[IdxL]..., rhs[IdxR]...}}; +} + +template +constexpr std::array, L + R> _concat2(const std::array& lhs, T (&rhs)[R]) { + return _concat2_impl(lhs, rhs, std::make_index_sequence{}, std::make_index_sequence{}); +} + +} // namespace detail +#define WEBGPU_PROGRAM_DEFINE_(identifier, T, ...) \ + static constexpr const T identifier##_own[] = {__VA_ARGS__}; \ + static constexpr const auto identifier = \ + onnxruntime::webgpu::detail::_to_std_array(identifier##_own) + +#define WEBGPU_PROGRAM_EXTEND_(identifier, T, BASE, ...) \ + static constexpr const T identifier##_own[] = {__VA_ARGS__}; \ + static constexpr const auto identifier = \ + onnxruntime::webgpu::detail::_concat2(BASE::identifier, identifier##_own) + #define WEBGPU_PROGRAM_DEFINE_CONSTANTS(...) \ - static constexpr const onnxruntime::webgpu::ProgramConstant constants[] = {__VA_ARGS__} + WEBGPU_PROGRAM_DEFINE_(constants, onnxruntime::webgpu::ProgramConstant, __VA_ARGS__) + +#define WEBGPU_PROGRAM_EXTEND_CONSTANTS(BASE, ...) \ + WEBGPU_PROGRAM_EXTEND_(constants, onnxruntime::webgpu::ProgramConstant, BASE, __VA_ARGS__) #define WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS(...) \ - static constexpr const onnxruntime::webgpu::ProgramOverridableConstantDefinition overridable_constants[] = {__VA_ARGS__} + WEBGPU_PROGRAM_DEFINE_(overridable_constants, onnxruntime::webgpu::ProgramOverridableConstantDefinition, __VA_ARGS__) + +#define WEBGPU_PROGRAM_EXTEND_OVERRIDABLE_CONSTANTS(BASE, ...) \ + WEBGPU_PROGRAM_EXTEND_(overridable_constants, onnxruntime::webgpu::ProgramOverridableConstantDefinition, BASE, __VA_ARGS__) #define WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(...) \ - static constexpr const onnxruntime::webgpu::ProgramUniformVariableDefinition uniform_variables[] = {__VA_ARGS__} + WEBGPU_PROGRAM_DEFINE_(uniform_variables, onnxruntime::webgpu::ProgramUniformVariableDefinition, __VA_ARGS__) + +#define WEBGPU_PROGRAM_EXTEND_UNIFORM_VARIABLES(BASE, ...) \ + WEBGPU_PROGRAM_EXTEND_(uniform_variables, onnxruntime::webgpu::ProgramUniformVariableDefinition, BASE, __VA_ARGS__) } // namespace webgpu } // namespace onnxruntime From 4fff35f99fe26df47c862882179eafe1695a961d Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sat, 31 Aug 2024 21:42:12 -0700 Subject: [PATCH 23/77] fix BucketCacheManager --- onnxruntime/core/providers/webgpu/buffer_manager.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.cc b/onnxruntime/core/providers/webgpu/buffer_manager.cc index e1f065b65f13e..da544e1d1ed60 100644 --- a/onnxruntime/core/providers/webgpu/buffer_manager.cc +++ b/onnxruntime/core/providers/webgpu/buffer_manager.cc @@ -176,6 +176,8 @@ class BucketCacheManager : public IBufferCacheManager { wgpuBufferRelease(buffer); } } + + pending_buffers_.clear(); } protected: From 4fd8ad19327db52b9fa147ccb5bc5a0b8978acc1 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sat, 31 Aug 2024 22:09:51 -0700 Subject: [PATCH 24/77] add a method to get logger from ComputeContext --- .../core/providers/webgpu/compute_context.cc | 20 ------------ .../core/providers/webgpu/compute_context.h | 31 ++++++++++++++----- 2 files changed, 24 insertions(+), 27 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/compute_context.cc b/onnxruntime/core/providers/webgpu/compute_context.cc index 67c55f823d78a..b7a1af5b26ef7 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.cc +++ b/onnxruntime/core/providers/webgpu/compute_context.cc @@ -13,25 +13,5 @@ ComputeContext::ComputeContext(OpKernelContext& kernel_context) kernel_context_{kernel_context} { } -const wgpu::AdapterInfo& ComputeContext::AdapterInfo() const { - return webgpu_context_.AdapterInfo(); -} - -const wgpu::Limits& ComputeContext::DeviceLimits() const { - return webgpu_context_.DeviceLimits(); -} - -int ComputeContext::InputCount() const { - return kernel_context_.InputCount(); -} - -int ComputeContext::OutputCount() const { - return kernel_context_.OutputCount(); -} - -Status ComputeContext::RunProgram(const ProgramBase& program) { - return webgpu_context_.Run(*this, program); -} - } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index 9c352d3d76dd9..ab090956b4d4b 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -34,34 +34,49 @@ class ComputeContext { // Get various information from the context. // - const wgpu::AdapterInfo& AdapterInfo() const; - const wgpu::Limits& DeviceLimits() const; + inline const wgpu::AdapterInfo& AdapterInfo() const { + return webgpu_context_.AdapterInfo(); + } + inline const wgpu::Limits& DeviceLimits() const { + return webgpu_context_.DeviceLimits(); + } + + // + // Get the logger + // + inline const logging::Logger& Logger() const { + return kernel_context_.Logger(); + } // // Get input tensor. // template - const T* Input(int index) const { + inline const T* Input(int index) const { return kernel_context_.Input(index); } // // Get input count. // - int InputCount() const; + inline int InputCount() const { + return kernel_context_.InputCount(); + } // // Set output tensor. // template - Tensor* Output(int index, TensorShapeType&& shape) { + inline Tensor* Output(int index, TensorShapeType&& shape) { return kernel_context_.Output(index, std::forward(shape)); } // // Get output count. // - int OutputCount() const; + inline int OutputCount() const { + return kernel_context_.OutputCount(); + } // // Create CPU tensor. @@ -86,7 +101,9 @@ class ComputeContext { // // Run a compute shader program. // - Status RunProgram(const ProgramBase& program); + inline Status RunProgram(const ProgramBase& program) { + return webgpu_context_.Run(*this, program); + } protected: WebGpuContext& webgpu_context_; From 3bd92adcf6ea537477b6b65171a14e2429c1443e Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sat, 31 Aug 2024 22:33:25 -0700 Subject: [PATCH 25/77] add verbose log for cache key --- onnxruntime/core/providers/webgpu/compute_context.h | 1 + onnxruntime/core/providers/webgpu/webgpu_context.cc | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index ab090956b4d4b..132f629ac745e 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -14,6 +14,7 @@ #include "core/framework/execution_provider.h" #include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_context.h" #include "core/framework/op_kernel.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 049a729f5c988..9e51cc08eec0f 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -6,6 +6,7 @@ #include "core/common/common.h" +#include "core/providers/webgpu/compute_context.h" #include "core/providers/webgpu/webgpu_context.h" #include "core/providers/webgpu/buffer_manager.h" #include "core/providers/webgpu/webgpu_execution_provider.h" @@ -124,7 +125,7 @@ Status WebGpuContext::Wait(wgpu::Future f) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait for the operation:", uint32_t(status)); } -Status WebGpuContext::Run(const ComputeContext& /*context*/, const ProgramBase& program) { +Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& program) { const auto& inputs = program.Inputs(); const auto& outputs = program.Outputs(); @@ -200,6 +201,8 @@ Status WebGpuContext::Run(const ComputeContext& /*context*/, const ProgramBase& auto key = CalculateProgramCacheKey(program, is_1d_dispatch); + LOGS(context.Logger(), INFO) << "Starting program \"" << key << "\" (" << x << ", " << y << ", " << z << ")"; + const auto* program_artifact = program_mgr_->Get(key); if (program_artifact == nullptr) { wgpu::ComputePipeline compute_pipeline; From 6a1bbfe907cf1f5a2b494edde9ebbdcbfe1795d9 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sun, 1 Sep 2024 01:02:31 -0700 Subject: [PATCH 26/77] revise suite test --- .../webgpu/How_to_Write_WebGPU_EP_Kernel.md | 54 +- .../test/providers/webgpu/test_webgpu.js | 1138 ++++++++++++++++- 2 files changed, 1130 insertions(+), 62 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md index 7ae7e2b37fc29..624cfd80dd8f7 100644 --- a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md +++ b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md @@ -119,6 +119,7 @@ Status ComputeInternal(ComputeContext& context) const override; ``` Usually, in the implementation, we do 3 things: + - Create a local variable of the Program class. - Set a few runtime info of the Program instance. - Call `context.RunProgram(program)` to run the program and return the status. @@ -130,6 +131,7 @@ Complicated operators may do more things. Check header files and existing implem Register the operator just like any EP does. Check existing implementations for more details. Please note that registration is composed of 2 parts: + - Use macros like `ONNX_OPERATOR_KERNEL_EX` or `ONNX_OPERATOR_VERSIONED_KERNEL_EX` (or wrap a new macro as what we usually do) to register the operator in kernel source code file. - Add the operator to onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -139,29 +141,59 @@ This section is WIP. ## 6. Build and test +### Build + use `build.bat --use_webgpu --skip_tests` to build the WebGPU EP. For Release build, append `--config Release` or `--config RelWithDebInfo` to the command line. -to test, find the "test_webgpu.bat" in your build folder. run it for tests: +### Prepare test data + +Assume `C:\code\onnxruntime` is the root of your onnxruntime repo in all documents below. + +if folder `C:\code\onnxruntime\js\test\data` does not exist, run the following in your onnxruntime repo root: + +``` +cd js +npm ci +npm run prepare-node-tests +``` + +### Run Suite test (temporary: this may change recently) + +to do suite test, find the "test_webgpu.bat" in your build folder (It's usually in `build\Windows\Debug\Debug`). run it for tests: + ``` # run all tests test_webgpu.bat -# run a specific test -test_webgpu.bat test_abs +# run a test list from args +test_webgpu.bat -m=test_abs;test_cos ``` +To add more tests to the suite list, edit the file at `C:\code\onnxruntime\onnxruntime\test\providers\webgpu\test_webgpu.js`. After editing, run build again otherwise this file will not be copied to the build folder. +> How does it work? +> +> The `test_webgpu.bat` calls `test_webgpu.js` with nodejs. +> +> The `test_webgpu.js` use the test list (either the suite list or from cmd args) to prepare a temporary folder and creates symbolic links to the test data folder (under `C:\code\onnxruntime\js\test\data`). Then it runs `onnx_test_runner` on the temporary folder. + +### Run single test / debug to test or debug a single test, find the "onnx_test_runner.exe" in your build folder. run it like: + ``` onnx_test_runner.exe -v -e webgpu -C "session.disable_cpu_ep_fallback|1" C:\code\onnxruntime\js\test\data\node\opset17\test_abs ``` -> Assume `C:\code\onnxruntime` is the root of your onnxruntime repo -> -> if folder `C:\code\onnxruntime\js\test\data` does not exist, run the following in your onnxruntime repo root: -> ``` -> cd js -> npm ci -> npm run prepare-node-tests -> ``` +The `-C` flag is split by space for each key-value pair. Each key-value pair is separated by `|`. The key is the option name and the value is the option value. + +Some features are useful but if you are troubleshooting and want to rule out the cause, you can: + +- set `storageBufferCacheMode` to `disabled` to disable the storage buffer cache. +- set `-M` and `-A` to disable memory pattern and memory arena. +- set `-j 1` to disable parallel execution (if you have multiple models to test). + +Example: +``` +onnx_test_runner.exe -v -A -M -j 1 -e webgpu -C "session.disable_cpu_ep_fallback|1 storageBufferCacheMode|disabled" C:\code\onnxruntime\js\test\data\node\opset17\test_abs +``` diff --git a/onnxruntime/test/providers/webgpu/test_webgpu.js b/onnxruntime/test/providers/webgpu/test_webgpu.js index 111f321ccbbd2..254bded19ae7c 100644 --- a/onnxruntime/test/providers/webgpu/test_webgpu.js +++ b/onnxruntime/test/providers/webgpu/test_webgpu.js @@ -11,7 +11,1047 @@ const HELP = ` `; const DEFAULT_TESTS = [ - 'test_abs', + "test_abs", + "test_acos_example", + "test_acos", + "test_acosh_example", + "test_acosh", + // // "test_adagrad_multiple", + // // "test_adagrad", + // // "test_adam_multiple", + // // "test_adam", + "test_add_bcast", + // "test_add_uint8", + "test_add", + // "test_and_bcast3v1d", + // "test_and_bcast3v2d", + // "test_and_bcast4v2d", + // "test_and_bcast4v3d", + // "test_and_bcast4v4d", + // "test_and2d", + // "test_and3d", + // "test_and4d", + "test_argmax_default_axis_example_select_last_index", + "test_argmax_default_axis_example", + "test_argmax_default_axis_random_select_last_index", + "test_argmax_default_axis_random", + "test_argmax_keepdims_example_select_last_index", + "test_argmax_keepdims_example", + "test_argmax_keepdims_random_select_last_index", + "test_argmax_keepdims_random", + "test_argmax_negative_axis_keepdims_example_select_last_index", + "test_argmax_negative_axis_keepdims_example", + "test_argmax_negative_axis_keepdims_random_select_last_index", + "test_argmax_negative_axis_keepdims_random", + "test_argmax_no_keepdims_example_select_last_index", + "test_argmax_no_keepdims_example", + "test_argmax_no_keepdims_random_select_last_index", + "test_argmax_no_keepdims_random", + "test_argmin_default_axis_example_select_last_index", + "test_argmin_default_axis_example", + "test_argmin_default_axis_random_select_last_index", + "test_argmin_default_axis_random", + "test_argmin_keepdims_example_select_last_index", + "test_argmin_keepdims_example", + "test_argmin_keepdims_random_select_last_index", + "test_argmin_keepdims_random", + "test_argmin_negative_axis_keepdims_example_select_last_index", + "test_argmin_negative_axis_keepdims_example", + "test_argmin_negative_axis_keepdims_random_select_last_index", + "test_argmin_negative_axis_keepdims_random", + "test_argmin_no_keepdims_example_select_last_index", + "test_argmin_no_keepdims_example", + "test_argmin_no_keepdims_random_select_last_index", + "test_argmin_no_keepdims_random", + "test_asin_example", + "test_asin", + "test_asinh_example", + "test_asinh", + "test_atan_example", + "test_atan", + "test_atanh_example", + "test_atanh", + // "test_averagepool_1d_default", + // "test_averagepool_2d_ceil", + "test_averagepool_2d_default", + "test_averagepool_2d_pads_count_include_pad", + "test_averagepool_2d_pads", + "test_averagepool_2d_precomputed_pads_count_include_pad", + "test_averagepool_2d_precomputed_pads", + "test_averagepool_2d_precomputed_same_upper", + "test_averagepool_2d_precomputed_strides", + "test_averagepool_2d_same_lower", + "test_averagepool_2d_same_upper", + "test_averagepool_2d_strides", + // "test_averagepool_3d_default", + "test_basic_conv_with_padding", + "test_basic_conv_without_padding", + // "test_basic_convinteger", + // "test_batchnorm_epsilon_training_mode", + "test_batchnorm_epsilon", + // "test_batchnorm_example_training_mode", + "test_batchnorm_example", + // // "test_bernoulli_double_expanded", + // // "test_bernoulli_double", + // // "test_bernoulli_expanded", + // // "test_bernoulli_seed_expanded", + // // "test_bernoulli_seed", + // // "test_bernoulli", + // // "test_bitshift_left_uint16", + // // "test_bitshift_left_uint32", + // // "test_bitshift_left_uint64", + // // "test_bitshift_left_uint8", + // // "test_bitshift_right_uint16", + // // "test_bitshift_right_uint32", + // // "test_bitshift_right_uint64", + // // "test_bitshift_right_uint8", + // // "test_blackmanwindow_expanded", + // // "test_blackmanwindow_symmetric_expanded", + // // "test_blackmanwindow_symmetric", + // // "test_blackmanwindow", + // // "test_cast_BFLOAT16_to_FLOAT", + // // "test_cast_DOUBLE_to_FLOAT", + // // "test_cast_DOUBLE_to_FLOAT16", + // // "test_cast_FLOAT_to_BFLOAT16", + // // "test_cast_FLOAT_to_DOUBLE", + // // "test_cast_FLOAT_to_FLOAT16", + // // "test_cast_FLOAT_to_STRING", + // // "test_cast_FLOAT16_to_DOUBLE", + // // "test_cast_FLOAT16_to_FLOAT", + // // "test_cast_STRING_to_FLOAT", + // // "test_castlike_BFLOAT16_to_FLOAT_expanded", + // // "test_castlike_BFLOAT16_to_FLOAT", + // // "test_castlike_DOUBLE_to_FLOAT_expanded", + // // "test_castlike_DOUBLE_to_FLOAT", + // // "test_castlike_DOUBLE_to_FLOAT16_expanded", + // // "test_castlike_DOUBLE_to_FLOAT16", + // // "test_castlike_FLOAT_to_BFLOAT16_expanded", + // // "test_castlike_FLOAT_to_BFLOAT16", + // // "test_castlike_FLOAT_to_DOUBLE_expanded", + // // "test_castlike_FLOAT_to_DOUBLE", + // // "test_castlike_FLOAT_to_FLOAT16_expanded", + // // "test_castlike_FLOAT_to_FLOAT16", + // // "test_castlike_FLOAT_to_STRING_expanded", + // // "test_castlike_FLOAT_to_STRING", + // // "test_castlike_FLOAT16_to_DOUBLE_expanded", + // // "test_castlike_FLOAT16_to_DOUBLE", + // // "test_castlike_FLOAT16_to_FLOAT_expanded", + // // "test_castlike_FLOAT16_to_FLOAT", + // // "test_castlike_STRING_to_FLOAT_expanded", + // // "test_castlike_STRING_to_FLOAT", + "test_ceil_example", + "test_ceil", + // "test_celu_expanded", + // "test_celu", + // "test_clip_default_inbounds", + // "test_clip_default_int8_inbounds", + // "test_clip_default_int8_max", + // "test_clip_default_int8_min", + // "test_clip_default_max", + // "test_clip_default_min", + // "test_clip_example", + // "test_clip_inbounds", + // "test_clip_outbounds", + // "test_clip_splitbounds", + // "test_clip", + // // "test_compress_0", + // // "test_compress_1", + // // "test_compress_default_axis", + // // "test_compress_negative_axis", + "test_concat_1d_axis_0", + "test_concat_1d_axis_negative_1", + "test_concat_2d_axis_0", + "test_concat_2d_axis_1", + "test_concat_2d_axis_negative_1", + "test_concat_2d_axis_negative_2", + "test_concat_3d_axis_0", + "test_concat_3d_axis_1", + "test_concat_3d_axis_2", + "test_concat_3d_axis_negative_1", + "test_concat_3d_axis_negative_2", + "test_concat_3d_axis_negative_3", + "test_conv_with_autopad_same", + "test_conv_with_strides_and_asymmetric_padding", + "test_conv_with_strides_no_padding", + "test_conv_with_strides_padding", + // // "test_convinteger_with_padding", + // // "test_convinteger_without_padding", + "test_convtranspose_1d", + // // "test_convtranspose_3d", + "test_convtranspose_autopad_same", + "test_convtranspose_dilations", + "test_convtranspose_kernel_shape", + "opset{9,17}/test_convtranspose_output_shape", + "test_convtranspose_pad", + "test_convtranspose_pads", + "test_convtranspose_with_kernel", + "test_convtranspose", + "test_cos_example", + "test_cos", + "test_cosh_example", + "test_cosh", + // "test_cumsum_1d_exclusive", + // "test_cumsum_1d_reverse_exclusive", + // "test_cumsum_1d_reverse", + // "test_cumsum_1d", + // "test_cumsum_2d_axis_0", + // "test_cumsum_2d_axis_1", + // "test_cumsum_2d_negative_axis", + "test_depthtospace_crd_mode_example", + "test_depthtospace_crd_mode", + "test_depthtospace_dcr_mode", + "test_depthtospace_example", + "test_depthtospace", + "test_dequantizelinear_axis", + "test_dequantizelinear", + // // "test_det_2d", + // // "test_det_nd", + // // "test_dft_axis", + // // "test_dft_inverse", + // // "test_dft", + "test_div_bcast", + "test_div_example", + // "test_div_uint8", + "test_div", + // // "test_dropout_default_mask_ratio", + // // "test_dropout_default_mask", + // // "test_dropout_default_old", + // // "test_dropout_default_ratio", + // // "test_dropout_default", + // // "test_dropout_random_old", + // // "test_dropout_random", + // // "test_dynamic_slice_default_axes", + // // "test_dynamic_slice_end_out_of_bounds", + // // "test_dynamic_slice_neg", + // // "test_dynamic_slice_start_out_of_bounds", + // // "test_dynamic_slice", + // // "test_dynamicquantizelinear_expanded", + // // "test_dynamicquantizelinear_max_adjusted_expanded", + // // "test_dynamicquantizelinear_max_adjusted", + // // "test_dynamicquantizelinear_min_adjusted_expanded", + // // "test_dynamicquantizelinear_min_adjusted", + // // "test_dynamicquantizelinear", + "test_edge_pad", + // "test_einsum_batch_diagonal", + // "test_einsum_batch_matmul", + // "test_einsum_inner_prod", + // "test_einsum_sum", + // "test_einsum_transpose", + "test_elu_default", + "test_elu_example", + "test_elu", + // "test_equal_bcast", + // "test_equal", + "test_erf", + "test_exp_example", + "test_exp", + "test_expand_dim_changed", + "test_expand_dim_unchanged", + // "test_eyelike_populate_off_main_diagonal", + // "test_eyelike_with_dtype", + // "test_eyelike_without_dtype", + "test_flatten_axis0", + "test_flatten_axis1", + "test_flatten_axis2", + "test_flatten_axis3", + "test_flatten_default_axis", + "test_flatten_negative_axis1", + "test_flatten_negative_axis2", + "test_flatten_negative_axis3", + "test_flatten_negative_axis4", + "test_floor_example", + "test_floor", + "test_gather_0", + "test_gather_1", + "test_gather_2d_indices", + "test_gather_negative_indices", + "test_gather_elements_0", + "test_gather_elements_1", + "test_gather_elements_negative_indices", + // "test_gather_negative_indices", + // // "test_gathernd_example_float32", + // // "test_gathernd_example_int32_batch_dim1", + // // "test_gathernd_example_int32", + "test_gemm_all_attributes", + "test_gemm_alpha", + "test_gemm_beta", + "test_gemm_broadcast", + "test_gemm_default_matrix_bias", + "test_gemm_default_no_bias", + // "test_gemm_default_scalar_bias", + "test_gemm_default_single_elem_vector_bias", + "test_gemm_default_vector_bias", + "test_gemm_default_zero_bias", + "test_gemm_nobroadcast", + "test_gemm_transposeA", + "test_gemm_transposeB", + "test_globalaveragepool_precomputed", + "test_globalaveragepool", + "test_globalmaxpool_precomputed", + "test_globalmaxpool", + "test_greater_bcast", + "test_greater_equal_bcast_expanded", + "test_greater_equal_bcast", + "test_greater_equal_expanded", + "test_greater_equal", + "test_greater", + // // "test_gridsample_aligncorners_true", + // // "test_gridsample_bicubic", + // // "test_gridsample_bilinear", + // // "test_gridsample_border_padding", + // // "test_gridsample_nearest", + // // "test_gridsample_reflection_padding", + // // "test_gridsample_zeros_padding", + // // "test_gridsample", + // // "test_gru_batchwise", + // // "test_gru_defaults", + // // "test_gru_seq_length", + // // "test_gru_with_initial_bias", + // // "test_hammingwindow_expanded", + // // "test_hammingwindow_symmetric_expanded", + // // "test_hammingwindow_symmetric", + // // "test_hammingwindow", + // // "test_hannwindow_expanded", + // // "test_hannwindow_symmetric_expanded", + // // "test_hannwindow_symmetric", + // // "test_hannwindow", + // // "test_hardmax_axis_0", + // // "test_hardmax_axis_1", + // // "test_hardmax_axis_2", + // // "test_hardmax_default_axis", + // // "test_hardmax_example", + // // "test_hardmax_negative_axis", + // // "test_hardmax_one_hot", + "test_hardsigmoid_default", + "test_hardsigmoid_example", + "test_hardsigmoid", + // // "test_hardswish_expanded", + // // "test_hardswish", + "test_if", + // TODO: Uncomment 'test_if_seq' and 'test_if_opt' once the test infra + // supports Sequence and Optional types + // "test_if_seq", + // "test_if_opt", + "test_instancenorm_epsilon", + "test_instancenorm_example", + // "test_isinf_negative", + // "test_isinf_positive", + // "test_isinf", + // "test_isnan", + "test_layer_normalization_2d_axis_negative_1_expanded", + "test_layer_normalization_2d_axis_negative_1", + "test_layer_normalization_2d_axis_negative_2_expanded", + "test_layer_normalization_2d_axis_negative_2", + "test_layer_normalization_2d_axis0_expanded", + "test_layer_normalization_2d_axis0", + "test_layer_normalization_2d_axis1_expanded", + "test_layer_normalization_2d_axis1", + // // "test_layer_normalization_3d_axis_negative_1_epsilon_expanded", + "test_layer_normalization_3d_axis_negative_1_epsilon", + // // "test_layer_normalization_3d_axis_negative_2_epsilon_expanded", + "test_layer_normalization_3d_axis_negative_2_epsilon", + // // "test_layer_normalization_3d_axis_negative_3_epsilon_expanded", + "test_layer_normalization_3d_axis_negative_3_epsilon", + // // "test_layer_normalization_3d_axis0_epsilon_expanded", + "test_layer_normalization_3d_axis0_epsilon", + "test_layer_normalization_3d_axis1_epsilon_expanded", + "test_layer_normalization_3d_axis1_epsilon", + // // "test_layer_normalization_3d_axis2_epsilon_expanded", + "test_layer_normalization_3d_axis2_epsilon", + "test_layer_normalization_4d_axis_negative_1_expanded", + "test_layer_normalization_4d_axis_negative_1", + // // "test_layer_normalization_4d_axis_negative_2_expanded", + "test_layer_normalization_4d_axis_negative_2", + // "test_layer_normalization_4d_axis_negative_3_expanded", + "test_layer_normalization_4d_axis_negative_3", + // "test_layer_normalization_4d_axis_negative_4_expanded", + "test_layer_normalization_4d_axis_negative_4", + "test_layer_normalization_4d_axis0_expanded", + "test_layer_normalization_4d_axis0", + "test_layer_normalization_4d_axis1_expanded", + "test_layer_normalization_4d_axis1", + // // "test_layer_normalization_4d_axis2_expanded", + "test_layer_normalization_4d_axis2", + "test_layer_normalization_4d_axis3_expanded", + "test_layer_normalization_4d_axis3", + "test_layer_normalization_default_axis_expanded", + "test_layer_normalization_default_axis", + "test_leakyrelu_default", + "test_leakyrelu_example", + "test_leakyrelu", + "test_less_bcast", + "test_less_equal_bcast_expanded", + "test_less_equal_bcast", + "test_less_equal_expanded", + "test_less_equal", + "test_less", + "test_log_example", + "test_log", + // // "test_logsoftmax_axis_0_expanded", + // // "test_logsoftmax_axis_0", + // // "test_logsoftmax_axis_1_expanded", + // // "test_logsoftmax_axis_1", + // // "test_logsoftmax_axis_2_expanded", + // // "test_logsoftmax_axis_2", + // // "test_logsoftmax_default_axis_expanded", + // // "test_logsoftmax_default_axis", + // // "test_logsoftmax_example_1_expanded", + // // "test_logsoftmax_example_1", + // // "test_logsoftmax_large_number_expanded", + // // "test_logsoftmax_large_number", + // // "test_logsoftmax_negative_axis_expanded", + // // "test_logsoftmax_negative_axis", + // "test_lrn_default", + // "test_lrn", + // // "test_lstm_batchwise", + // // "test_lstm_defaults", + // // "test_lstm_with_initial_bias", + // // "test_lstm_with_peepholes", + "test_matmul_2d", + "test_matmul_3d", + "test_matmul_4d", + // // "test_matmulinteger", + // "test_max_example", + // "test_max_float16", + // "test_max_float32", + // "test_max_float64", + // "test_max_int16", + // "test_max_int32", + // "test_max_int64", + // "test_max_int8", + // "test_max_one_input", + // "test_max_two_inputs", + // "test_max_uint16", + // "test_max_uint32", + // "test_max_uint64", + // "test_max_uint8", + // "test_maxpool_1d_default", + // "test_maxpool_2d_ceil", + "test_maxpool_2d_default", + // "test_maxpool_2d_dilations", + "test_maxpool_2d_pads", + "test_maxpool_2d_precomputed_pads", + "test_maxpool_2d_precomputed_same_upper", + "test_maxpool_2d_precomputed_strides", + "test_maxpool_2d_same_lower", + "test_maxpool_2d_same_upper", + "test_maxpool_2d_strides", + // "test_maxpool_2d_uint8", + // "test_maxpool_3d_default", + // "test_maxpool_with_argmax_2d_precomputed_pads", + // "test_maxpool_with_argmax_2d_precomputed_strides", + // // "test_maxunpool_export_with_output_shape", + // // "test_maxunpool_export_without_output_shape", + // // "test_mean_example", + // // "test_mean_one_input", + // // "test_mean_two_inputs", + // // "test_melweightmatrix", + // "test_min_example", + // "test_min_float16", + // "test_min_float32", + // "test_min_float64", + // "test_min_int16", + // "test_min_int32", + // "test_min_int64", + // "test_min_int8", + // "test_min_one_input", + // "test_min_two_inputs", + // "test_min_uint16", + // "test_min_uint32", + // "test_min_uint64", + // "test_min_uint8", + // "test_mod_bcast", + // "test_mod_broadcast", + // "test_mod_float_mixed_sign_example", + // "test_mod_fmod_mixed_sign_example", + // "test_mod_int64_fmod", + // "test_mod_int64_mixed_sign_example", + // "test_mod_mixed_sign_float16", + // "test_mod_mixed_sign_float32", + // "test_mod_mixed_sign_float64", + // "test_mod_mixed_sign_int16", + // "test_mod_mixed_sign_int32", + // "test_mod_mixed_sign_int64", + // "test_mod_mixed_sign_int8", + // "test_mod_uint16", + // "test_mod_uint32", + // "test_mod_uint64", + // "test_mod_uint8", + // // "test_momentum_multiple", + // // "test_momentum", + "test_mul_bcast", + "test_mul_example", + // "test_mul_uint8", + "test_mul", + // "test_mvn_expanded", + // "test_mvn", + "test_neg_example", + "test_neg", + // // "test_negative_log_likelihood_loss_iinput_shape_is_NCd1_weight_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_iinput_shape_is_NCd1_weight_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NC_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NC", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_mean_weight_negative_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_mean_weight_negative_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_weight_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1_weight", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_no_weight_reduction_mean_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_no_weight_reduction_mean_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_reduction_mean_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_reduction_mean", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_reduction_sum_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_reduction_sum", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_mean_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_mean", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_sum_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_sum_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_sum_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_sum", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_mean_weight_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_mean_weight", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_none_no_weight_expanded", + // // "test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_none_no_weight", + // // "test_nesterov_momentum", + // // "test_nllloss_NC_expanded", + // // "test_nllloss_NC", + // // "test_nllloss_NCd1_expanded", + // // "test_nllloss_NCd1_ii_expanded", + // // "test_nllloss_NCd1_ii", + // // "test_nllloss_NCd1_mean_weight_negative_ii_expanded", + // // "test_nllloss_NCd1_mean_weight_negative_ii", + // // "test_nllloss_NCd1_weight_expanded", + // // "test_nllloss_NCd1_weight_ii_expanded", + // // "test_nllloss_NCd1_weight_ii", + // // "test_nllloss_NCd1_weight", + // // "test_nllloss_NCd1", + // // "test_nllloss_NCd1d2_expanded", + // // "test_nllloss_NCd1d2_no_weight_reduction_mean_ii_expanded", + // // "test_nllloss_NCd1d2_no_weight_reduction_mean_ii", + // // "test_nllloss_NCd1d2_reduction_mean_expanded", + // // "test_nllloss_NCd1d2_reduction_mean", + // // "test_nllloss_NCd1d2_reduction_sum_expanded", + // // "test_nllloss_NCd1d2_reduction_sum", + // // "test_nllloss_NCd1d2_with_weight_expanded", + // // "test_nllloss_NCd1d2_with_weight_reduction_mean_expanded", + // // "test_nllloss_NCd1d2_with_weight_reduction_mean", + // // "test_nllloss_NCd1d2_with_weight_reduction_sum_expanded", + // // "test_nllloss_NCd1d2_with_weight_reduction_sum_ii_expanded", + // // "test_nllloss_NCd1d2_with_weight_reduction_sum_ii", + // // "test_nllloss_NCd1d2_with_weight_reduction_sum", + // // "test_nllloss_NCd1d2_with_weight", + // // "test_nllloss_NCd1d2", + // // "test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded", + // // "test_nllloss_NCd1d2d3_none_no_weight_negative_ii", + // // "test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded", + // // "test_nllloss_NCd1d2d3_sum_weight_high_ii", + // // "test_nllloss_NCd1d2d3d4d5_mean_weight_expanded", + // // "test_nllloss_NCd1d2d3d4d5_mean_weight", + // // "test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded", + // // "test_nllloss_NCd1d2d3d4d5_none_no_weight", + // "test_nonmaxsuppression_center_point_box_format", + // "test_nonmaxsuppression_flipped_coordinates", + // "test_nonmaxsuppression_identical_boxes", + // "test_nonmaxsuppression_limit_output_size", + // "test_nonmaxsuppression_single_box", + // "test_nonmaxsuppression_suppress_by_IOU_and_scores", + // "test_nonmaxsuppression_suppress_by_IOU", + // "test_nonmaxsuppression_two_batches", + // "test_nonmaxsuppression_two_classes", + // "test_nonzero_example", + "test_not_2d", + "test_not_3d", + "test_not_4d", + // // "test_onehot_negative_indices", + // // "test_onehot_with_axis", + // // "test_onehot_with_negative_axis", + // // "test_onehot_without_axis", + // // "test_optional_get_element_sequence", + // // "test_optional_get_element", + // // "test_optional_has_element_empty", + // // "test_optional_has_element", + // "test_or_bcast3v1d", + // "test_or_bcast3v2d", + // "test_or_bcast4v2d", + // "test_or_bcast4v3d", + // "test_or_bcast4v4d", + // "test_or2d", + // "test_or3d", + // "test_or4d", + "test_pow_bcast_array", + "test_pow_bcast_scalar", + "test_pow_example", + // "test_pow_types_float", + // "test_pow_types_float32_int32", + // "test_pow_types_float32_int64", + // "test_pow_types_float32_uint32", + // "test_pow_types_float32_uint64", + // "test_pow_types_int", + // "test_pow_types_int32_float32", + // "test_pow_types_int32_int32", + // "test_pow_types_int64_float32", + // "test_pow_types_int64_int64", + "test_pow", + // "test_prelu_broadcast", + // "test_prelu_example", + // // "test_qlinearconv", + // // "test_qlinearmatmul_2D", + // // "test_qlinearmatmul_3D", + // // "test_quantizelinear_axis", + // // "test_quantizelinear", + "test_range_float_type_positive_delta_expanded", + "test_range_float_type_positive_delta", + "test_range_int32_type_negative_delta_expanded", + "test_range_int32_type_negative_delta", + "test_reciprocal_example", + "test_reciprocal", + "test_reduce_l1_default_axes_keepdims_example", + "test_reduce_l1_default_axes_keepdims_random", + "test_reduce_l1_do_not_keepdims_example", + "test_reduce_l1_do_not_keepdims_random", + "test_reduce_l1_keep_dims_example", + "test_reduce_l1_keep_dims_random", + "test_reduce_l1_negative_axes_keep_dims_example", + "test_reduce_l1_negative_axes_keep_dims_random", + "test_reduce_l2_default_axes_keepdims_example", + "test_reduce_l2_default_axes_keepdims_random", + "test_reduce_l2_do_not_keepdims_example", + "test_reduce_l2_do_not_keepdims_random", + "test_reduce_l2_keep_dims_example", + "test_reduce_l2_keep_dims_random", + "test_reduce_l2_negative_axes_keep_dims_example", + "test_reduce_l2_negative_axes_keep_dims_random", + "test_reduce_log_sum_asc_axes", + "test_reduce_log_sum_default", + "test_reduce_log_sum_desc_axes", + // tests "test_reduce_log_sum_exp_*" on opset17/opset18 are excluded because they use float64. + "opset{7,8,9}/test_reduce_log_sum_exp_default_axes_keepdims_example", + "opset{7,8,9}/test_reduce_log_sum_exp_default_axes_keepdims_random", + "opset{7,8,9}/test_reduce_log_sum_exp_do_not_keepdims_example", + "opset{7,8,9}/test_reduce_log_sum_exp_do_not_keepdims_random", + "opset{7,8,9}/test_reduce_log_sum_exp_keepdims_example", + "opset{7,8,9}/test_reduce_log_sum_exp_keepdims_random", + "opset11/test_reduce_log_sum_exp_negative_axes_keepdims_example", + "opset11/test_reduce_log_sum_exp_negative_axes_keepdims_random", + "test_reduce_log_sum_negative_axes", + "test_reduce_log_sum", + "test_reduce_max_default_axes_keepdim_example", + "test_reduce_max_default_axes_keepdims_random", + "test_reduce_max_do_not_keepdims_example", + "test_reduce_max_do_not_keepdims_random", + "test_reduce_max_keepdims_example", + "test_reduce_max_keepdims_random", + "test_reduce_max_negative_axes_keepdims_example", + "test_reduce_max_negative_axes_keepdims_random", + "test_reduce_mean_default_axes_keepdims_example", + "test_reduce_mean_default_axes_keepdims_random", + "test_reduce_mean_do_not_keepdims_example", + "test_reduce_mean_do_not_keepdims_random", + "test_reduce_mean_keepdims_example", + "test_reduce_mean_keepdims_random", + "test_reduce_mean_negative_axes_keepdims_example", + "test_reduce_mean_negative_axes_keepdims_random", + "test_reduce_min_default_axes_keepdims_example", + "test_reduce_min_default_axes_keepdims_random", + "test_reduce_min_do_not_keepdims_example", + "test_reduce_min_do_not_keepdims_random", + "test_reduce_min_keepdims_example", + "test_reduce_min_keepdims_random", + "test_reduce_min_negative_axes_keepdims_example", + "test_reduce_min_negative_axes_keepdims_random", + "test_reduce_prod_default_axes_keepdims_example", + "test_reduce_prod_default_axes_keepdims_random", + "test_reduce_prod_do_not_keepdims_example", + "test_reduce_prod_do_not_keepdims_random", + "test_reduce_prod_keepdims_example", + "test_reduce_prod_keepdims_random", + "test_reduce_prod_negative_axes_keepdims_example", + "test_reduce_prod_negative_axes_keepdims_random", + "test_reduce_sum_default_axes_keepdims_example", + "test_reduce_sum_default_axes_keepdims_random", + "test_reduce_sum_do_not_keepdims_example", + "test_reduce_sum_do_not_keepdims_random", + "test_reduce_sum_empty_axes_input_noop_example", + "test_reduce_sum_empty_axes_input_noop_random", + "test_reduce_sum_keepdims_example", + "test_reduce_sum_keepdims_random", + "test_reduce_sum_negative_axes_keepdims_example", + "test_reduce_sum_negative_axes_keepdims_random", + "test_reduce_sum_square_default_axes_keepdims_example", + "test_reduce_sum_square_default_axes_keepdims_random", + "test_reduce_sum_square_do_not_keepdims_example", + "test_reduce_sum_square_do_not_keepdims_random", + "test_reduce_sum_square_keepdims_example", + "test_reduce_sum_square_keepdims_random", + "test_reduce_sum_square_negative_axes_keepdims_example", + "test_reduce_sum_square_negative_axes_keepdims_random", + "test_reflect_pad", + "test_relu", + // "test_reshape_allowzero_reordered", + "test_reshape_extended_dims", + "test_reshape_negative_dim", + "test_reshape_negative_extended_dims", + "test_reshape_one_dim", + "test_reshape_reduced_dims", + "test_reshape_reordered_all_dims", + "test_reshape_reordered_dims", + "test_reshape_reordered_last_dims", + "test_reshape_zero_and_negative_dim", + "test_reshape_zero_dim", + "test_resize_downsample_linear", + "test_resize_downsample_nearest", + "test_resize_downsample_scales_cubic_A_n0p5_exclude_outside", + // "test_resize_downsample_scales_cubic_align_corners", + "test_resize_downsample_scales_cubic", + // "test_resize_downsample_scales_linear_align_corners", + "test_resize_downsample_scales_linear", + "test_resize_downsample_scales_nearest", + "test_resize_downsample_sizes_cubic", + "test_resize_downsample_sizes_linear_pytorch_half_pixel", + "test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn", + "test_resize_downsample_sizes_nearest", + "test_resize_nearest", + "test_resize_tf_crop_and_resize", + "test_resize_upsample_linear", + "test_resize_upsample_nearest", + "test_resize_upsample_scales_cubic_A_n0p5_exclude_outside", + "test_resize_upsample_scales_cubic_align_corners", + "test_resize_upsample_scales_cubic_asymmetric", + "test_resize_upsample_scales_cubic", + "test_resize_upsample_scales_linear_align_corners", + "test_resize_upsample_scales_linear", + "test_resize_upsample_scales_nearest", + "test_resize_upsample_sizes_cubic", + "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_ceil_half_pixel", + "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_floor_align_corners", + "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric", + "test_resize_upsample_sizes_nearest", + // // "test_reversesequence_batch", + // // "test_reversesequence_time", + // // "test_rnn_seq_length", + // // "test_roialign_aligned_false", + // // "test_roialign_aligned_true", + // // "test_roialign", + // // "test_round", + // // "test_scan_sum", + // // "test_scan9_sum", + // // "test_scatter_elements_with_axis", + // // "test_scatter_elements_with_duplicate_indices", + // // "test_scatter_elements_with_negative_indices", + // // "test_scatter_elements_without_axis", + // // "test_scatter_with_axis", + // // "test_scatter_without_axis", + // // "test_scatternd_add", + // // "test_scatternd_multiply", + // // "test_scatternd", + // // "test_sce_mean_3d_expanded", + // // "test_sce_mean_3d_log_prob_expanded", + // // "test_sce_mean_3d_log_prob", + // // "test_sce_mean_3d", + // // "test_sce_mean_expanded", + // // "test_sce_mean_log_prob_expanded", + // // "test_sce_mean_log_prob", + // // "test_sce_mean_no_weight_ii_3d_expanded", + // // "test_sce_mean_no_weight_ii_3d_log_prob_expanded", + // // "test_sce_mean_no_weight_ii_3d_log_prob", + // // "test_sce_mean_no_weight_ii_3d", + // // "test_sce_mean_no_weight_ii_4d_expanded", + // // "test_sce_mean_no_weight_ii_4d_log_prob_expanded", + // // "test_sce_mean_no_weight_ii_4d_log_prob", + // // "test_sce_mean_no_weight_ii_4d", + // // "test_sce_mean_no_weight_ii_expanded", + // // "test_sce_mean_no_weight_ii_log_prob_expanded", + // // "test_sce_mean_no_weight_ii_log_prob", + // // "test_sce_mean_no_weight_ii", + // // "test_sce_mean_weight_expanded", + // // "test_sce_mean_weight_ii_3d_expanded", + // // "test_sce_mean_weight_ii_3d_log_prob_expanded", + // // "test_sce_mean_weight_ii_3d_log_prob", + // // "test_sce_mean_weight_ii_3d", + // // "test_sce_mean_weight_ii_4d_expanded", + // // "test_sce_mean_weight_ii_4d_log_prob_expanded", + // // "test_sce_mean_weight_ii_4d_log_prob", + // // "test_sce_mean_weight_ii_4d", + // // "test_sce_mean_weight_ii_expanded", + // // "test_sce_mean_weight_ii_log_prob_expanded", + // // "test_sce_mean_weight_ii_log_prob", + // // "test_sce_mean_weight_ii", + // // "test_sce_mean_weight_log_prob_expanded", + // // "test_sce_mean_weight_log_prob", + // // "test_sce_mean_weight", + // // "test_sce_mean", + // // "test_sce_NCd1_mean_weight_negative_ii_expanded", + // // "test_sce_NCd1_mean_weight_negative_ii_log_prob_expanded", + // // "test_sce_NCd1_mean_weight_negative_ii_log_prob", + // // "test_sce_NCd1_mean_weight_negative_ii", + // // "test_sce_NCd1d2d3_none_no_weight_negative_ii_expanded", + // // "test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob_expanded", + // // "test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob", + // // "test_sce_NCd1d2d3_none_no_weight_negative_ii", + // // "test_sce_NCd1d2d3_sum_weight_high_ii_expanded", + // // "test_sce_NCd1d2d3_sum_weight_high_ii_log_prob_expanded", + // // "test_sce_NCd1d2d3_sum_weight_high_ii_log_prob", + // // "test_sce_NCd1d2d3_sum_weight_high_ii", + // // "test_sce_NCd1d2d3d4d5_mean_weight_expanded", + // // "test_sce_NCd1d2d3d4d5_mean_weight_log_prob_expanded", + // // "test_sce_NCd1d2d3d4d5_mean_weight_log_prob", + // // "test_sce_NCd1d2d3d4d5_mean_weight", + // // "test_sce_NCd1d2d3d4d5_none_no_weight_expanded", + // // "test_sce_NCd1d2d3d4d5_none_no_weight_log_prob_expanded", + // // "test_sce_NCd1d2d3d4d5_none_no_weight_log_prob", + // // "test_sce_NCd1d2d3d4d5_none_no_weight", + // // "test_sce_none_expanded", + // // "test_sce_none_log_prob_expanded", + // // "test_sce_none_log_prob", + // // "test_sce_none_weights_expanded", + // // "test_sce_none_weights_log_prob_expanded", + // // "test_sce_none_weights_log_prob", + // // "test_sce_none_weights", + // // "test_sce_none", + // // "test_sce_sum_expanded", + // // "test_sce_sum_log_prob_expanded", + // // "test_sce_sum_log_prob", + // // "test_sce_sum", + // "test_selu_default", + // "test_selu_example", + // "test_selu", + // // "test_sequence_insert_at_back", + // // "test_sequence_insert_at_front", + // // "test_sequence_map_add_1_sequence_1_tensor_expanded", + // // "test_sequence_map_add_1_sequence_1_tensor", + // // "test_sequence_map_add_2_sequences_expanded", + // // "test_sequence_map_add_2_sequences", + // // "test_sequence_map_extract_shapes_expanded", + // // "test_sequence_map_extract_shapes", + // // "test_sequence_map_identity_1_sequence_1_tensor_expanded", + // // "test_sequence_map_identity_1_sequence_1_tensor", + // // "test_sequence_map_identity_1_sequence_expanded", + // // "test_sequence_map_identity_1_sequence", + // // "test_sequence_map_identity_2_sequences_expanded", + // // "test_sequence_map_identity_2_sequences", + // "test_shrink_hard", + // "test_shrink_soft", + "test_sigmoid_example", + "test_sigmoid", + // "test_sign", + // "test_simple_rnn_batchwise", + // "test_simple_rnn_defaults", + // "test_simple_rnn_with_initial_bias", + "test_sin_example", + "test_sin", + "test_sinh_example", + "test_sinh", + // // "test_size_example", + // // "test_size", + "test_slice_default_axes", + "test_slice_default_steps", + // "test_slice_end_out_of_bounds", + "test_slice_neg_steps", + "test_slice_neg", + "test_slice_negative_axes", + // "test_slice_start_out_of_bounds", + "test_slice", + // "test_softmax_axis_0_expanded", + // "test_softmax_axis_0", + // "test_softmax_axis_1_expanded", + // "test_softmax_axis_1", + "test_softmax_axis_2_expanded", + "test_softmax_axis_2", + // "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_log_prob_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index_log_prob", + // "test_softmax_cross_entropy_input_shape_is_NCd1_mean_weight_negative_ignore_index", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_log_prob_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_log_prob", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index_log_prob_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index_log_prob", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight_log_prob_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight_log_prob", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_mean_weight", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight_log_prob_expanded", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight_log_prob", + // "test_softmax_cross_entropy_input_shape_is_NCd1d2d3d4d5_none_no_weight", + // "test_softmax_cross_entropy_mean_3d_expanded", + // "test_softmax_cross_entropy_mean_3d_log_prob_expanded", + // "test_softmax_cross_entropy_mean_3d_log_prob", + // "test_softmax_cross_entropy_mean_3d", + // "test_softmax_cross_entropy_mean_expanded", + // "test_softmax_cross_entropy_mean_log_prob_expanded", + // "test_softmax_cross_entropy_mean_log_prob", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_3d_expanded", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_3d_log_prob_expanded", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_3d_log_prob", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_3d", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_4d_expanded", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_4d_log_prob_expanded", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_4d_log_prob", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_4d", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_expanded", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_log_prob_expanded", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index_log_prob", + // "test_softmax_cross_entropy_mean_no_weight_ignore_index", + // "test_softmax_cross_entropy_mean_weight_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_3d_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_3d_log_prob_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_3d_log_prob", + // "test_softmax_cross_entropy_mean_weight_ignore_index_3d", + // "test_softmax_cross_entropy_mean_weight_ignore_index_4d_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_4d_log_prob_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_4d_log_prob", + // "test_softmax_cross_entropy_mean_weight_ignore_index_4d", + // "test_softmax_cross_entropy_mean_weight_ignore_index_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_log_prob_expanded", + // "test_softmax_cross_entropy_mean_weight_ignore_index_log_prob", + // "test_softmax_cross_entropy_mean_weight_ignore_index", + // "test_softmax_cross_entropy_mean_weight_log_prob_expanded", + // "test_softmax_cross_entropy_mean_weight_log_prob", + // "test_softmax_cross_entropy_mean_weight", + // "test_softmax_cross_entropy_mean", + // "test_softmax_cross_entropy_none_expanded", + // "test_softmax_cross_entropy_none_log_prob_expanded", + // "test_softmax_cross_entropy_none_log_prob", + // "test_softmax_cross_entropy_none_weights_expanded", + // "test_softmax_cross_entropy_none_weights_log_prob_expanded", + // "test_softmax_cross_entropy_none_weights_log_prob", + // "test_softmax_cross_entropy_none_weights", + // "test_softmax_cross_entropy_none", + // "test_softmax_cross_entropy_sum_expanded", + // "test_softmax_cross_entropy_sum_log_prob_expanded", + // "test_softmax_cross_entropy_sum_log_prob", + // "test_softmax_cross_entropy_sum", + "opset13/test_softmax_default_axis_expanded", + "opset13/test_softmax_default_axis", + "test_softmax_example_expanded", + "test_softmax_example", + "test_softmax_large_number_expanded", + "test_softmax_large_number", + "test_softmax_negative_axis_expanded", + "test_softmax_negative_axis", + // // "test_softplus_example", + // // "test_softplus", + // // "test_softsign_example", + // // "test_softsign", + // "test_spacetodepth_example", + // "test_spacetodepth", + "test_split_equal_parts_1d", + "test_split_equal_parts_2d", + "test_split_equal_parts_default_axis", + "test_split_variable_parts_1d", + "test_split_variable_parts_2d", + "test_split_variable_parts_default_axis", + "test_split_zero_size_splits", + "test_sqrt_example", + "test_sqrt", + "test_squeeze_negative_axes", + "test_squeeze", + // // "test_stft_with_window", + // // "test_stft", + // // "test_strnormalizer_export_monday_casesensintive_lower", + // // "test_strnormalizer_export_monday_casesensintive_nochangecase", + // // "test_strnormalizer_export_monday_casesensintive_upper", + // // "test_strnormalizer_export_monday_empty_output", + // // "test_strnormalizer_export_monday_insensintive_upper_twodim", + // // "test_strnormalizer_nostopwords_nochangecase", + "test_sub_bcast", + "test_sub_example", + // "test_sub_uint8", + "test_sub", + // "test_sum_example", + // "test_sum_one_input", + // "test_sum_two_inputs", + "test_tan_example", + "test_tan", + "test_tanh_example", + "test_tanh", + // // "test_tfidfvectorizer_tf_batch_onlybigrams_skip0", + // // "test_tfidfvectorizer_tf_batch_onlybigrams_skip5", + // // "test_tfidfvectorizer_tf_batch_uniandbigrams_skip5", + // // "test_tfidfvectorizer_tf_only_bigrams_skip0", + // // "test_tfidfvectorizer_tf_onlybigrams_levelempty", + // // "test_tfidfvectorizer_tf_onlybigrams_skip5", + // // "test_tfidfvectorizer_tf_uniandbigrams_skip5", + "test_thresholdedrelu_default", + "test_thresholdedrelu_example", + "test_thresholdedrelu", + "test_tile_precomputed", + "test_tile", + // // "test_top_k_negative_axis", + // // "test_top_k_smallest", + // // "test_top_k", + // // "test_training_dropout_default_mask", + // // "test_training_dropout_default", + // // "test_training_dropout_mask", + // // "test_training_dropout_zero_ratio_mask", + // // "test_training_dropout_zero_ratio", + // // "test_training_dropout", + "test_transpose_all_permutations_0", + "test_transpose_all_permutations_1", + "test_transpose_all_permutations_2", + "test_transpose_all_permutations_3", + "test_transpose_all_permutations_4", + "test_transpose_all_permutations_5", + "test_transpose_default", + // "test_tril_neg", + // "test_tril_one_row_neg", + // "test_tril_out_neg", + // "test_tril_out_pos", + // "test_tril_pos", + // "test_tril_square_neg", + // "test_tril_square", + // "test_tril_zero", + // "test_tril", + // "test_triu_neg", + // "test_triu_one_row", + // "test_triu_out_neg_out", + // "test_triu_out_pos", + // "test_triu_pos", + // "test_triu_square_neg", + // "test_triu_square", + // "test_triu_zero", + // "test_triu", + // // "test_unique_not_sorted_without_axis", + // // "test_unique_sorted_with_axis_3d", + // // "test_unique_sorted_with_axis", + // // "test_unique_sorted_with_negative_axis", + // // "test_unique_sorted_without_axis", + "test_unsqueeze_axis_0", + "test_unsqueeze_axis_1", + "test_unsqueeze_axis_2", + "test_unsqueeze_axis_3", + "test_unsqueeze_negative_axes", + "test_unsqueeze_three_axes", + "test_unsqueeze_two_axes", + "test_unsqueeze_unsorted_axes", + "test_unsqueeze", + "test_wrap_pad" + // "test_upsample_nearest", + // "test_where_example", + // "test_where_long_example", + // "test_xor_bcast3v1d", + // "test_xor_bcast3v2d", + // "test_xor_bcast4v2d", + // "test_xor_bcast4v3d", + // "test_xor_bcast4v4d", + // "test_xor2d", + // "test_xor3d", + // "test_xor4d" ]; const path = require('path'); @@ -21,6 +1061,12 @@ const { spawnSync } = require('child_process'); const ONNX_TEST_RUNNER_FILENAME = path.join(__dirname, 'onnx_test_runner' + (process.platform === 'win32' ? '.exe' : '')); +if (!fs.existsSync(ONNX_TEST_RUNNER_FILENAME)) { + console.error('Error: onnx_test_runner not found.'); + console.error('Please perform a build and run this script in the build folder.'); + process.exit(1); +} + if (process.argv.includes('-h')) { console.log(HELP); process.exit(0); @@ -34,65 +1080,55 @@ if (!test_data_path) { test_data_path = test_data_path.substring(3); } -const test_models = []; +let test_models = DEFAULT_TESTS; const test_model_list = process.argv.find(arg => arg.startsWith('-m=')); if (test_model_list) { + test_models = []; test_model_list.substring(3).split(';').forEach(test_model => { test_models.push(test_model); }); } -const tests = new Set(test_model_list ? test_models : DEFAULT_TESTS); -const test_cases = []; -fs.readdirSync(test_data_path, { withFileTypes: true }).forEach(dirent => { - if (dirent.isDirectory()) { - const opset = dirent.name; - fs.readdirSync(path.join(test_data_path, opset), { withFileTypes: true }).forEach(dirent => { - if (dirent.isDirectory()) { - const name = dirent.name; - if (tests.has(name)) { - test_cases.push(path.join(test_data_path, opset, name)); - } - } - }); - } -}); +const tests = new Set(test_models); + +const TEST_ROOT = path.join(__dirname, 'webgpu_test_root'); -let passed = []; -let not_implemented = []; -let failed = []; -test_cases.forEach(test_case => { - process.stdout.write(`Running test case: "${test_case}"...`); - const args = [ - '-e', 'webgpu', '-C', '"session.disable_cpu_ep_fallback|1"', test_case, - ]; - if (VERBOSE) { - args.unshift('-v'); +let test_data_ready = false; +const test_list_json_data = JSON.stringify(test_models, null, 2); +const test_list_json_filepath = path.join(TEST_ROOT, 'test_list.json'); +if (fs.existsSync(TEST_ROOT)) { + if (fs.existsSync(test_list_json_filepath)) { + test_data_ready = fs.readFileSync(test_list_json_filepath).toString() == test_list_json_data; } - const p = spawnSync(ONNX_TEST_RUNNER_FILENAME, args, { shell: true, stdio: ['ignore', 'pipe', 'pipe'] }); - if (p.status !== 0) { - process.stdout.write('Failed\n'); - failed.push(test_case); - } else if (!p.stdout.toString().includes('Not implemented: 0')) { - process.stdout.write('Not Implemented\n'); - not_implemented.push(test_case); - } else { - process.stdout.write('OK\n'); - passed.push(test_case); + if (!test_data_ready) { + fs.rmdirSync(TEST_ROOT, { recursive: true }); } -}); +} +if (!test_data_ready) { + fs.mkdirSync(TEST_ROOT); -console.log(`\n${passed.length} tests passed.`); -console.log(`\n${not_implemented.length} tests not implemented:`); -not_implemented.slice(0, 3).forEach(test_case => { - console.log(` ${test_case}`); -}); -if (not_implemented.length > 3) { - console.log(` ...`); + fs.readdirSync(test_data_path, { withFileTypes: true }).forEach(dirent => { + if (dirent.isDirectory()) { + const opset = dirent.name; + fs.readdirSync(path.join(test_data_path, opset), { withFileTypes: true }).forEach(dirent => { + if (dirent.isDirectory()) { + const name = dirent.name; + if (tests.has(name)) { + fs.symlinkSync(path.join(test_data_path, opset, name), path.join(TEST_ROOT, `${opset}_${name}`), 'junction'); + } + } + }); + } + }); + fs.writeFileSync(test_list_json_filepath, test_list_json_data); } -console.log(`\n${failed.length} tests failed:`); -failed.slice(0, 3).forEach(test_case => { - console.log(` ${test_case}`); -}); -if (failed.length > 3) { - console.log(` ...`); + +const args = ['-A', '-M', '-j', '1', '-e', 'webgpu', '-C', '"session.disable_cpu_ep_fallback|1 storageBufferCacheMode|disabled"', 'webgpu_test_root']; +if (VERBOSE) { + args.unshift('-v'); } +process.exit( + spawnSync( + ONNX_TEST_RUNNER_FILENAME, + args, + { shell: true, cwd: __dirname, stdio: 'inherit' } + ).status); From 947aee18a2b15d1aa2501f546aa6426a20bf6466 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sun, 1 Sep 2024 12:46:01 -0700 Subject: [PATCH 27/77] device lost handler --- onnxruntime/core/providers/webgpu/webgpu_context.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 9e51cc08eec0f..776fbb069bb5e 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -86,10 +86,15 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info wgpu::RequiredLimits required_limits = GetAvailableRequiredLimits(adapter_); device_desc.requiredLimits = &required_limits; - // TODO: temporary error handling + // TODO: revise temporary error handling device_desc.SetUncapturedErrorCallback([](const wgpu::Device& /*device*/, wgpu::ErrorType type, const char* message) { LOGS_DEFAULT(ERROR) << "WebGPU device error(" << int(type) << "): " << message; }); + // TODO: revise temporary device lost handling + device_desc.SetDeviceLostCallback(wgpu::CallbackMode::AllowSpontaneous, [](const wgpu::Device& /*device*/, wgpu::DeviceLostReason reason, const char* message) { + // cannot use ORT logger because it may be already destroyed + std::cerr << "WebGPU device lost (" << int(reason) << "): " << message; + }); wgpu::RequestDeviceCallbackInfo req_device_callback_info = {}; req_device_callback_info.mode = wgpu::CallbackMode::WaitAnyOnly; From 99b2578a49444684ff820afa56ebb84a911c35de Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sun, 1 Sep 2024 14:38:47 -0700 Subject: [PATCH 28/77] add '-a' and '-t' to test runner --- .../webgpu/How_to_Write_WebGPU_EP_Kernel.md | 19 ++++++++++++++++--- .../test/providers/webgpu/test_webgpu.js | 3 ++- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md index 624cfd80dd8f7..9bc19a2099a4c 100644 --- a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md +++ b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md @@ -182,10 +182,23 @@ To add more tests to the suite list, edit the file at `C:\code\onnxruntime\onnxr to test or debug a single test, find the "onnx_test_runner.exe" in your build folder. run it like: ``` -onnx_test_runner.exe -v -e webgpu -C "session.disable_cpu_ep_fallback|1" C:\code\onnxruntime\js\test\data\node\opset17\test_abs +onnx_test_runner.exe -v -e webgpu -a 0.0001 -t 0.0001 -C "session.disable_cpu_ep_fallback|1" C:\code\onnxruntime\js\test\data\node\opset17\test_abs ``` -The `-C` flag is split by space for each key-value pair. Each key-value pair is separated by `|`. The key is the option name and the value is the option value. +The `-C` flag is split by space for each key-value pair. Each key-value pair is separated by `|`. The key is the option name and the value is the option value. See `onnxruntime\core\providers\webgpu\webgpu_provider_options.h` for available WebGPU EP options. + +The `-a` and `-t` flags are used to specify the absolute and relative tolerance for the test. +- currently the value is set to `0.0001` for both absolute and relative tolerance for the WebGPU EP. +- `onnx_test_runner` will try to load file `\testdata\onnx_backend_test_series_overrides.jsonc>` if available to set the default tolerance values. It is recommended to set the tolerance values in the command line to ensure consistent behavior. + > This is why the following command may have different results: + > + > ``` + > C:\code\onnxruntime> build\Windows\Debug\Debug\onnx_test_runner.exe -e webgpu C:\code\onnxruntime\js\test\data\node\opset9\test_asin_example + > ``` + > + > ``` + > C:\code\onnxruntime\build\Windows\Debug\Debug> onnx_test_runner.exe -e webgpu C:\code\onnxruntime\js\test\data\node\opset9\test_asin_example + > ``` Some features are useful but if you are troubleshooting and want to rule out the cause, you can: @@ -195,5 +208,5 @@ Some features are useful but if you are troubleshooting and want to rule out the Example: ``` -onnx_test_runner.exe -v -A -M -j 1 -e webgpu -C "session.disable_cpu_ep_fallback|1 storageBufferCacheMode|disabled" C:\code\onnxruntime\js\test\data\node\opset17\test_abs +onnx_test_runner.exe -v -A -M -j 1 -e webgpu -a 0.0001 -t 0.0001 -C "session.disable_cpu_ep_fallback|1 storageBufferCacheMode|disabled" C:\code\onnxruntime\js\test\data\node\opset17\test_abs ``` diff --git a/onnxruntime/test/providers/webgpu/test_webgpu.js b/onnxruntime/test/providers/webgpu/test_webgpu.js index 254bded19ae7c..e6d28c9e5b4da 100644 --- a/onnxruntime/test/providers/webgpu/test_webgpu.js +++ b/onnxruntime/test/providers/webgpu/test_webgpu.js @@ -1122,7 +1122,8 @@ if (!test_data_ready) { fs.writeFileSync(test_list_json_filepath, test_list_json_data); } -const args = ['-A', '-M', '-j', '1', '-e', 'webgpu', '-C', '"session.disable_cpu_ep_fallback|1 storageBufferCacheMode|disabled"', 'webgpu_test_root']; +// const args = ['-A', '-M', '-j', '1', '-t', '0.0001', '-a', '0.0001', '-e', 'webgpu', '-C', '"session.disable_cpu_ep_fallback|1 storageBufferCacheMode|disabled"', 'webgpu_test_root']; +const args = ['-j', '1', '-t', '0.0001', '-a', '0.0001', '-e', 'webgpu', '-C', '"session.disable_cpu_ep_fallback|1"', 'webgpu_test_root']; if (VERBOSE) { args.unshift('-v'); } From aa7b3f52aaef02e6faed4acd4668f597507b6672 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sun, 1 Sep 2024 15:25:51 -0700 Subject: [PATCH 29/77] atol/rtol 0.0001 -> 0.001 --- .../core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md | 6 +++--- onnxruntime/test/providers/webgpu/test_webgpu.js | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md index 9bc19a2099a4c..3e501cd957e03 100644 --- a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md +++ b/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md @@ -182,13 +182,13 @@ To add more tests to the suite list, edit the file at `C:\code\onnxruntime\onnxr to test or debug a single test, find the "onnx_test_runner.exe" in your build folder. run it like: ``` -onnx_test_runner.exe -v -e webgpu -a 0.0001 -t 0.0001 -C "session.disable_cpu_ep_fallback|1" C:\code\onnxruntime\js\test\data\node\opset17\test_abs +onnx_test_runner.exe -v -e webgpu -a 0.001 -t 0.001 -C "session.disable_cpu_ep_fallback|1" C:\code\onnxruntime\js\test\data\node\opset17\test_abs ``` The `-C` flag is split by space for each key-value pair. Each key-value pair is separated by `|`. The key is the option name and the value is the option value. See `onnxruntime\core\providers\webgpu\webgpu_provider_options.h` for available WebGPU EP options. The `-a` and `-t` flags are used to specify the absolute and relative tolerance for the test. -- currently the value is set to `0.0001` for both absolute and relative tolerance for the WebGPU EP. +- currently the value is set to `0.001` for both absolute and relative tolerance for the WebGPU EP. - `onnx_test_runner` will try to load file `\testdata\onnx_backend_test_series_overrides.jsonc>` if available to set the default tolerance values. It is recommended to set the tolerance values in the command line to ensure consistent behavior. > This is why the following command may have different results: > @@ -208,5 +208,5 @@ Some features are useful but if you are troubleshooting and want to rule out the Example: ``` -onnx_test_runner.exe -v -A -M -j 1 -e webgpu -a 0.0001 -t 0.0001 -C "session.disable_cpu_ep_fallback|1 storageBufferCacheMode|disabled" C:\code\onnxruntime\js\test\data\node\opset17\test_abs +onnx_test_runner.exe -v -A -M -j 1 -e webgpu -a 0.001 -t 0.001 -C "session.disable_cpu_ep_fallback|1 storageBufferCacheMode|disabled" C:\code\onnxruntime\js\test\data\node\opset17\test_abs ``` diff --git a/onnxruntime/test/providers/webgpu/test_webgpu.js b/onnxruntime/test/providers/webgpu/test_webgpu.js index e6d28c9e5b4da..d6c452e1625c5 100644 --- a/onnxruntime/test/providers/webgpu/test_webgpu.js +++ b/onnxruntime/test/providers/webgpu/test_webgpu.js @@ -1122,8 +1122,8 @@ if (!test_data_ready) { fs.writeFileSync(test_list_json_filepath, test_list_json_data); } -// const args = ['-A', '-M', '-j', '1', '-t', '0.0001', '-a', '0.0001', '-e', 'webgpu', '-C', '"session.disable_cpu_ep_fallback|1 storageBufferCacheMode|disabled"', 'webgpu_test_root']; -const args = ['-j', '1', '-t', '0.0001', '-a', '0.0001', '-e', 'webgpu', '-C', '"session.disable_cpu_ep_fallback|1"', 'webgpu_test_root']; +// const args = ['-A', '-M', '-j', '1', '-t', '0.001', '-a', '0.001', '-e', 'webgpu', '-C', '"session.disable_cpu_ep_fallback|1 storageBufferCacheMode|disabled"', 'webgpu_test_root']; +const args = ['-j', '1', '-t', '0.001', '-a', '0.001', '-e', 'webgpu', '-C', '"session.disable_cpu_ep_fallback|1"', 'webgpu_test_root']; if (VERBOSE) { args.unshift('-v'); } From e659acd0eb3d35603fb77e7ff1e7ac3943c6d1a1 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sun, 1 Sep 2024 23:31:52 -0700 Subject: [PATCH 30/77] Fix uniform --- .../core/providers/webgpu/program_manager.cc | 48 +++++------- .../core/providers/webgpu/program_manager.h | 8 -- .../core/providers/webgpu/shader_helper.cc | 4 + .../core/providers/webgpu/shader_variable.cc | 8 +- .../core/providers/webgpu/shader_variable.h | 61 ++++++++------- .../core/providers/webgpu/webgpu_context.cc | 74 ++++++++++++------- 6 files changed, 105 insertions(+), 98 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/program_manager.cc b/onnxruntime/core/providers/webgpu/program_manager.cc index de228a038b7db..00036a915f695 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.cc +++ b/onnxruntime/core/providers/webgpu/program_manager.cc @@ -14,37 +14,7 @@ namespace onnxruntime { namespace webgpu { ProgramArtifact::ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline) - : name{program.Name()}, compute_pipeline{compute_pipeline} { - // prepare uniform info - size_t current_offset = 0; - for (const auto& uniform : program.UniformVariables()) { - bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16; - size_t length = uniform.length; - size_t element_size = ProgramUniformVariableDataTypeSize[static_cast(uniform.data_type)]; - // https://www.w3.org/TR/WGSL/#alignof - size_t base_alignment = is_f16 - ? (length > 4 ? 16 : length > 2 ? 8 - : length * element_size) - : (length > 2 ? 16 : length * element_size); - size_t struct_size = is_f16 && length <= 4 ? length * element_size : 16; - - current_offset = (current_offset + base_alignment - 1) / base_alignment * base_alignment; - uniforms.push_back({uniform.data_type, current_offset, length}); - - // For non-float16 type, when length > 4, the uniform variable is of type array,N>, where - // N = ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * SizeOf(vec4). - // For float16 type, when length > 4, the uniform variable is of type array,N>, where - // N = ceil(data.length / 8) and SizeOf(mat2x4) = 16. The total byte length is N * SizeOf(mat2x4). - size_t element_per_struct = is_f16 ? 8 : 4; - current_offset += - length > 4 ? (length + element_per_struct - 1) / element_per_struct * struct_size : length * element_size; - } - - // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set - // max_alignment_of_field to 16 since the underlying buffer has been rounded up to 16. - const int max_alignment_of_field = 16; - uniform_total_size = (current_offset + max_alignment_of_field - 1) / max_alignment_of_field * max_alignment_of_field; -} + : name{program.Name()}, compute_pipeline{compute_pipeline} {} Status ProgramManager::NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint32_t& z) const { ORT_RETURN_IF(x == 0 || y == 0 || z == 0, "Invalid dispatch group size (", x, ", ", y, ", ", z, ")"); @@ -108,6 +78,22 @@ Status ProgramManager::Build(const ProgramBase& program, auto shader_module = device_.CreateShaderModule(&descriptor); + // TODO: a new cache hierarchy for constants. + // + // Explaination: + // Currently, we use Uniforms for dynamic data. This helps to reduce the number of program artifacts. + // + // "dynamic data" here means the data the determined at runtime, such as the shape of the input tensor. + // + // However, some programs may not necessarily depend on dynamic data. For example, "Clip" may depend on the value of "min" and "max". + // We are using uniforms for the value of "min" and "max" in the current implementation, but usually "min" and "max" are determined + // earlier because they are either from Attributes or from the initializers of the model. + // + // Questions: + // - can we use one instance of ShaderModule to create multiple ComputePipeline? + // - is there any benefit to do so compared to the current implementation? + // + // process overridable constants if available size_t constant_count = program.OverridableConstants().size(); diff --git a/onnxruntime/core/providers/webgpu/program_manager.h b/onnxruntime/core/providers/webgpu/program_manager.h index 9d1b7655c8640..087c75bfee773 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.h +++ b/onnxruntime/core/providers/webgpu/program_manager.h @@ -21,20 +21,12 @@ class Tensor; namespace webgpu { -struct ProgramUniformInfo { - ProgramUniformVariableDataType data_type; - size_t offset; - size_t length; -}; - class ProgramArtifact { public: ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline); std::string name; wgpu::ComputePipeline compute_pipeline; - std::vector uniforms; - size_t uniform_total_size; ProgramArtifact(ProgramArtifact&&) = default; ProgramArtifact& operator=(ProgramArtifact&&) = default; diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index 3986b13e0a7d7..5883696430de6 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -148,6 +148,10 @@ std::string ShaderHelper::GetFinalSourceCode() { const auto& data_type = uniform_def.data_type; const auto length = uniform_value.length; + if (length == 0) { + continue; + } + if (first) { first = false; } else { diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index 9483ab19036c4..fda4ad72deb20 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -12,12 +12,12 @@ namespace onnxruntime { namespace webgpu { -ShaderVariable::ShaderVariable(const std::string& name, ProgramVariableDataType type, int rank) +ShaderVariable::ShaderVariable(std::string_view name, ProgramVariableDataType type, int rank) : name_(name), type_(type), rank_(rank), usage_(UseUniform) { Init(); } -ShaderVariable::ShaderVariable(const std::string& name, ProgramVariableDataType type, const TensorShape& dims) +ShaderVariable::ShaderVariable(std::string_view name, ProgramVariableDataType type, const TensorShape& dims) : name_(name), type_(type), rank_(static_cast(dims.NumDimensions())), dims_(dims), usage_(None) { Init(); } @@ -171,7 +171,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { } } -std::string ShaderVariable::GetByOffsetImpl(const std::string& offset) const { +std::string ShaderVariable::GetByOffsetImpl(std::string_view offset) const { std::ostringstream ss; ss.imbue(std::locale::classic()); @@ -199,7 +199,7 @@ std::string ShaderVariable::GetByOffsetImpl(const std::string& offset) const { return ss.str(); } -std::string ShaderVariable::SetByOffsetImpl(const std::string& offset, const std::string& value) const { +std::string ShaderVariable::SetByOffsetImpl(std::string_view offset, std::string_view value) const { std::ostringstream ss; ss.imbue(std::locale::classic()); diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index fbdb6590a7359..34d7674148412 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -14,7 +14,7 @@ namespace onnxruntime { namespace webgpu { template -std::string GetElementAt(const std::string& var, const TIdx& idx, int rank, bool is_f16 = false) { +std::string GetElementAt(std::string_view var, const TIdx& idx, int rank, bool is_f16 = false) { // "std::string::rfind(str, 0) == 0" is equivalent to "std::string::starts_with(str)" before C++20. if (var.rfind("uniform.", 0) == 0) { if (rank > 4) { @@ -31,34 +31,32 @@ std::string GetElementAt(const std::string& var, const TIdx& idx, int rank, bool return MakeStringWithClassicLocale(var, "[(", idx, ") / 4][(", idx, ") % 4]"); } } - } else { - return rank > 1 ? MakeStringWithClassicLocale(var, "[", idx, "]") : var; } - } else { - return rank > 1 ? MakeStringWithClassicLocale(var, "[", idx, "]") : var; } + + return rank > 1 ? MakeStringWithClassicLocale(var, "[", idx, "]") : std::string{var}; } class ShaderVariable { public: - ShaderVariable(const std::string& name, ProgramVariableDataType type, int rank); - ShaderVariable(const std::string& name, ProgramVariableDataType type, const TensorShape& dims); + ShaderVariable(std::string_view name, ProgramVariableDataType type, int rank); + ShaderVariable(std::string_view name, ProgramVariableDataType type, const TensorShape& dims); ShaderVariable(ShaderVariable&&) = default; ShaderVariable& operator=(ShaderVariable&&) = default; // create a WGSL expression ({varname}_indices_t) for getting indices from offset. // \param offset: a WGSL expression (u32) representing the offset. - inline std::string OffsetToIndices(const std::string& offset_expr) const; + inline std::string OffsetToIndices(std::string_view offset_expr) const; // create a WGSL expression (u32) for getting offset from indices. // \param indices: a WGSL expression ({varname}_indices_t) representing the indices. - inline std::string IndicesToOffset(const std::string& indices_expr) const; + inline std::string IndicesToOffset(std::string_view indices_expr) const; // create a WGSL expression (u32) for getting original offset from broadcasted indices. // \param indices: a WGSL expression ({broadcasted_result_varname}_indices_t) representing the broadcasted indices. // \param broadcasted_result: the broadcasted result variable. - inline std::string BroadcastedIndicesToOffset(const std::string& indices_expr, const ShaderVariable& broadcasted_result) const; + inline std::string BroadcastedIndicesToOffset(std::string_view indices_expr, const ShaderVariable& broadcasted_result) const; // create a WGSL expression ({varname}_indices_t) as an indices literal // \param init: a list of indices values. @@ -70,13 +68,13 @@ class ShaderVariable { // \param idx: the index (i32|u32) of the dimension to set. // \param value: the value (u32) to set. template - inline std::string IndicesSet(const std::string& indices_var, const TIdx& idx_expr, const TVal& value) const; + inline std::string IndicesSet(std::string_view indices_var, const TIdx& idx_expr, const TVal& value) const; // create a WGSL expression (u32) for getting value of the specified dimension of the indices. // \param indices_var: name of the indices variable ({varname}_indices_t). // \param idx: the index (i32|u32) of the dimension to get. template - inline std::string IndicesGet(const std::string& indices_var, const TIdx& idx_expr) const; + inline std::string IndicesGet(std::string_view indices_var, const TIdx& idx_expr) const; // create a WGSL statement for setting data at the given indices. // \param args: a list of indices values (u32) followed by a value ({varname}_value_t). @@ -86,7 +84,7 @@ class ShaderVariable { // create a WGSL statement for setting data at the given indices. // \param indices_var: name of the indices variable ({varname}_indices_t). // \param value: the value ({varname}_value_t) to set. - inline std::string SetByIndices(const std::string& indices_var, const std::string& value) const; + inline std::string SetByIndices(std::string_view indices_var, std::string_view value) const; // create a WGSL statement for setting data at the given offset. // \param offset: a WGSL expression (u32) representing the offset. @@ -101,7 +99,7 @@ class ShaderVariable { // create a WGSL expression ({varname}_value_t) for getting data at the given indices. // \param indices_var: name of the indices variable ({varname}_indices_t). - inline std::string GetByIndices(const std::string& indices_var) const; + inline std::string GetByIndices(std::string_view indices_var) const; // create a WGSL expression ({varname}_value_t) for getting data at the given offset. // \param offset: a WGSL expression (u32) representing the offset. @@ -131,8 +129,8 @@ class ShaderVariable { void Init(); void Impl(std::ostringstream& ss) const; - std::string GetByOffsetImpl(const std::string& offset) const; - std::string SetByOffsetImpl(const std::string& offset, const std::string& value) const; + std::string GetByOffsetImpl(std::string_view offset) const; + std::string SetByOffsetImpl(std::string_view offset, std::string_view value) const; std::string_view StorageType() const; std::string_view ValueType() const; @@ -167,23 +165,29 @@ template >> std::string pass_as_string(T&& v) { return std::to_string(std::forward(v)); } +template +std::string_view pass_as_string(std::string_view sv) { + return sv; +} template -std::string pass_as_string(const T& v) { - return v; +std::string pass_as_string(T&& v) { + return std::forward(v); } } // namespace detail -inline std::string ShaderVariable::OffsetToIndices(const std::string& offset_expr) const { +inline std::string ShaderVariable::OffsetToIndices(std::string_view offset_expr) const { usage_ |= UseOffsetToIndices; - return rank_ < 2 ? offset_expr : MakeStringWithClassicLocale("o2i_", name_, '(', offset_expr, ')'); + return rank_ < 2 ? std::string{offset_expr} + : MakeStringWithClassicLocale("o2i_", name_, '(', offset_expr, ')'); } -inline std::string ShaderVariable::IndicesToOffset(const std::string& indices_expr) const { +inline std::string ShaderVariable::IndicesToOffset(std::string_view indices_expr) const { usage_ |= UseIndicesToOffset; - return rank_ < 2 ? indices_expr : MakeStringWithClassicLocale("i2o_", name_, '(', indices_expr, ')'); + return rank_ < 2 ? std::string{indices_expr} + : MakeStringWithClassicLocale("i2o_", name_, '(', indices_expr, ')'); } -inline std::string ShaderVariable::BroadcastedIndicesToOffset(const std::string& indices_expr, const ShaderVariable& broadcasted_result) const { +inline std::string ShaderVariable::BroadcastedIndicesToOffset(std::string_view indices_expr, const ShaderVariable& broadcasted_result) const { usage_ |= UseBroadcastedIndicesToOffset; broadcasted_to_.push_back(broadcasted_result); return MakeStringWithClassicLocale(broadcasted_result.name_, "_bi2o_", name_, '(', indices_expr, ')'); @@ -199,14 +203,15 @@ inline std::string ShaderVariable::Indices(TIndices&&... indices_args) const { } template -inline std::string ShaderVariable::IndicesSet(const std::string& indices_var, const TIdx& idx_expr, const TVal& value) const { +inline std::string ShaderVariable::IndicesSet(std::string_view indices_var, const TIdx& idx_expr, const TVal& value) const { return rank_ < 2 ? MakeStringWithClassicLocale(indices_var, '=', value, ';') : MakeStringWithClassicLocale(GetElementAt(indices_var, idx_expr, rank_), '=', value, ';'); } template -inline std::string ShaderVariable::IndicesGet(const std::string& indices_var, const TIdx& idx_expr) const { - return rank_ < 2 ? indices_var : GetElementAt(indices_var, idx_expr, rank_); +inline std::string ShaderVariable::IndicesGet(std::string_view indices_var, const TIdx& idx_expr) const { + return rank_ < 2 ? std::string{indices_var} + : GetElementAt(indices_var, idx_expr, rank_); } template @@ -229,7 +234,7 @@ inline std::string ShaderVariable::Set(TIndicesAndValue&&... args) const { } } -inline std::string ShaderVariable::SetByIndices(const std::string& indices_var, const std::string& value) const { +inline std::string ShaderVariable::SetByIndices(std::string_view indices_var, std::string_view value) const { if (rank_ < 2) { return SetByOffset(indices_var, value); } else { @@ -258,7 +263,7 @@ inline std::string ShaderVariable::Get(TIndices&&... indices) const { } } -inline std::string ShaderVariable::GetByIndices(const std::string& indices_var) const { +inline std::string ShaderVariable::GetByIndices(std::string_view indices_var) const { if (rank_ < 2) { return GetByOffset(indices_var); } else { diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 776fbb069bb5e..d2428d8bb7be8 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -157,7 +157,7 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog return Status::OK(); } - ProgramMetadata metadata = program.GetMetadata(); + const ProgramMetadata metadata = program.GetMetadata(); // validate program metadata { @@ -227,35 +227,55 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog #endif } + // prepare uniform info + const auto& uniforms = program.UniformVariables(); + size_t current_offset = 0; + std::vector> uniform_and_offsets; + uniform_and_offsets.reserve(uniforms.size()); + for (const auto& uniform : uniforms) { + bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16; + size_t length = uniform.length; + + // skip zero-length uniform + if (length == 0) { + continue; + } + + size_t element_size = ProgramUniformVariableDataTypeSize[static_cast(uniform.data_type)]; + // https://www.w3.org/TR/WGSL/#alignof + size_t base_alignment = is_f16 + ? (length > 4 ? 16 : length > 2 ? 8 + : length * element_size) + : (length > 2 ? 16 : length * element_size); + size_t struct_size = is_f16 && length <= 4 ? length * element_size : 16; + + current_offset = (current_offset + base_alignment - 1) / base_alignment * base_alignment; + uniform_and_offsets.emplace_back(uniform, current_offset); + + // For non-float16 type, when length > 4, the uniform variable is of type array,N>, where + // N = ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * SizeOf(vec4). + // For float16 type, when length > 4, the uniform variable is of type array,N>, where + // N = ceil(data.length / 8) and SizeOf(mat2x4) = 16. The total byte length is N * SizeOf(mat2x4). + size_t element_per_struct = is_f16 ? 8 : 4; + current_offset += + length > 4 ? (length + element_per_struct - 1) / element_per_struct * struct_size : length * element_size; + } + + // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set + // max_alignment_of_field to 16 since the underlying buffer has been rounded up to 16. + const size_t max_alignment_of_field = 16; + const size_t uniform_buffer_total_size = (current_offset + max_alignment_of_field - 1) / max_alignment_of_field * max_alignment_of_field; + WGPUBuffer uniform_buffer = nullptr; - auto uniform_buffer_size = program_artifact->uniform_total_size; - if (uniform_buffer_size > 0) { - auto num_uniforms = program.UniformVariables().size(); - ORT_ENFORCE(program_artifact->uniforms.size() == num_uniforms, - "Uniforms size mismatch. Artifact: ", program_artifact->uniforms.size(), ", Current: ", num_uniforms); - - std::vector uniform_data(uniform_buffer_size); - - for (size_t i = 0; i < num_uniforms; ++i) { - const auto& uniform = program.UniformVariables()[i]; - const auto& artifact_uniform = program_artifact->uniforms[i]; - - ORT_ENFORCE(uniform.data_type == artifact_uniform.data_type, - "Uniform[", i, "] data type mismatch. Artifact: ", artifact_uniform.data_type, - ", Current: ", uniform.data_type); - ORT_ENFORCE(uniform.length == artifact_uniform.length, - "Uniform[", i, "] elements number mismatch. Artifact: ", artifact_uniform.length, ", Current: ", uniform.length); - ORT_ENFORCE(uniform.data.size() == artifact_uniform.length * ProgramUniformVariableDataTypeSize[static_cast(uniform.data_type)], - "Uniform[", i, "] data size mismatch. Artifact: ", artifact_uniform.length * ProgramUniformVariableDataTypeSize[static_cast(uniform.data_type)], - ", Current: ", uniform.data.size()); - - auto offset = artifact_uniform.offset; - auto size = uniform.data.size(); - memcpy(uniform_data.data() + offset, uniform.data.data(), size); + if (uniform_buffer_total_size > 0) { + std::vector uniform_data_buffer(uniform_buffer_total_size); + + for (auto const& [uniform, offset] : uniform_and_offsets) { + memcpy(uniform_data_buffer.data() + offset, uniform.data.data(), uniform.data.size()); } - uniform_buffer = buffer_mgr_->Create(uniform_buffer_size, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform); - device_.GetQueue().WriteBuffer(uniform_buffer, 0, uniform_data.data(), uniform_buffer_size); + uniform_buffer = buffer_mgr_->Create(uniform_buffer_total_size, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform); + device_.GetQueue().WriteBuffer(uniform_buffer, 0, uniform_data_buffer.data(), uniform_buffer_total_size); } const auto& compute_pass_encoder = GetComputePassEncoder(); From 6ad89c56bfe85794c50f1e547446d649083be53a Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sun, 1 Sep 2024 23:32:51 -0700 Subject: [PATCH 31/77] add some unary ops --- .../webgpu/math/unary_elementwise_ops.cc | 193 +++++++++++++++--- .../webgpu/math/unary_elementwise_ops.h | 33 ++- .../webgpu/webgpu_execution_provider.cc | 66 +++--- 3 files changed, 225 insertions(+), 67 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index 5c774df84638e..0ae48ccbd6341 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -24,43 +24,172 @@ Status UnaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { return Status::OK(); } -#define WEBGPU_ELEMENTWISE_IMPL(OP_TYPE, ...) \ - class OP_TYPE final : public WebGpuKernel { \ - public: \ - OP_TYPE(const OpKernelInfo& info) : WebGpuKernel{info} {} \ - \ - protected: \ - Status ComputeInternal(ComputeContext& context) const override { \ - const auto* input_tensor = context.Input(0); \ - auto* output_tensor = context.Output(0, input_tensor->Shape()); \ - SafeInt vec_size = (input_tensor->Shape().Size() + 3) / 4; \ - UnaryElementwiseProgram program{#OP_TYPE, __VA_ARGS__}; \ - program \ - .Inputs({{input_tensor, ProgramInputTensorDependency::Type}}) \ - .Outputs({output_tensor}) \ - .DispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) \ - .UniformVariables({ \ - {static_cast(vec_size)}, \ - }); \ - return context.RunProgram(program); \ - } \ +Status UnaryElementwise::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + auto* output_tensor = context.Output(0, input_tensor->Shape()); + SafeInt vec_size = (input_tensor->Shape().Size() + 3) / 4; + UnaryElementwiseProgram program{kernel_name_, expression_, additional_impl_}; + program + .Inputs({{input_tensor, ProgramInputTensorDependency::Type}}) + .Outputs({output_tensor}) + .DispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .UniformVariables({ + {static_cast(vec_size)}, + }); + ORT_RETURN_IF_ERROR(ConfigureProgram(program)); + return context.RunProgram(program); +} + +#define WEBGPU_ELEMENTWISE_IMPL(OP_TYPE, ...) \ + class OP_TYPE final : public UnaryElementwise { \ + public: \ + OP_TYPE(const OpKernelInfo& info) : UnaryElementwise{info, #OP_TYPE, __VA_ARGS__} {} \ }; -#define WEBGPU_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \ - ONNX_OPERATOR_KERNEL_EX( \ - OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \ - KernelDefBuilder().TypeConstraint("T", TYPE), \ - KERNEL_CLASS); +#define WEBGPU_ELEMENTWISE_KERNEL(OP_TYPE_AND_CLASS_NAME, VERSION, TYPE) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE_AND_CLASS_NAME, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE), \ + OP_TYPE_AND_CLASS_NAME); + +#define WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE_AND_CLASS_NAME, VERSION_FROM, VERSION_TO, TYPE) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE_AND_CLASS_NAME, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE), \ + OP_TYPE_AND_CLASS_NAME); -#define WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \ - ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ - OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \ - KernelDefBuilder().TypeConstraint("T", TYPE), \ - KERNEL_CLASS); +// +// math +// WEBGPU_ELEMENTWISE_IMPL(Abs, "abs(a)") -WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Abs, 6, 12, Abs, WebGpuSupportedFloatTypes()) -WEBGPU_ELEMENTWISE_KERNEL(Abs, 13, Abs, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Abs, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Abs, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Neg, "-a") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Neg, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Neg, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Floor, "floor(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Floor, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Floor, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Ceil, "ceil(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Ceil, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Ceil, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Reciprocal, "1.0/a") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Reciprocal, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Reciprocal, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Sqrt, "sqrt(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Sqrt, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Sqrt, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Exp, "exp(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Exp, 13, WebGpuSupportedFloatTypes()) + +constexpr char ErfImpl[] = R"( +const r0: x_value_t = 0.3275911; +const r1: x_value_t = 0.254829592; +const r2: x_value_t = -0.284496736; +const r3: x_value_t = 1.421413741; +const r4: x_value_t = -1.453152027; +const r5: x_value_t = 1.061405429; + +fn erf_v(v: vec4) -> vec4 { + let absv = abs(v); + let x = 1.0 / (1.0 + r0 * absv); + return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv)); +} +)"; + +WEBGPU_ELEMENTWISE_IMPL(Erf, "erf_v(a)", ErfImpl) +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Erf, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Log, "log(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Log, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Sigmoid, "1.0 / (1.0 + exp(-a))") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Sigmoid, 13, WebGpuSupportedFloatTypes()) + +class HardSigmoid final : public UnaryElementwise { + public: + HardSigmoid(const OpKernelInfo& info) + : UnaryElementwise{ + info, + "HardSigmoid", + // alpha = uniforms.f32_attr[0] + // beta = uniforms.f32_attr[1] + "max(vec4(0.0), min(vec4(1.0), x_value_t(uniforms.f32_attr[0]) * a + vec4(uniforms.f32_attr[1])))"} { + info.GetAttrOrDefault("alpha", attr, 0.2f); + info.GetAttrOrDefault("beta", attr + 1, 0.5f); + } + + Status ConfigureProgram(UnaryElementwiseProgram& program) const override { + program.UniformVariables({gsl::make_span(attr, 2), {}}); + return Status::OK(); + } + + protected: + float attr[2]; +}; + +WEBGPU_ELEMENTWISE_KERNEL(HardSigmoid, 6, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Sin, "sin(a)") +WEBGPU_ELEMENTWISE_KERNEL(Sin, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Cos, "cos(a)") +WEBGPU_ELEMENTWISE_KERNEL(Cos, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Tan, "tan(a)") +WEBGPU_ELEMENTWISE_KERNEL(Tan, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Asin, "asin(a)") +WEBGPU_ELEMENTWISE_KERNEL(Asin, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Acos, "acos(a)") +WEBGPU_ELEMENTWISE_KERNEL(Acos, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Atan, "atan(a)") +WEBGPU_ELEMENTWISE_KERNEL(Atan, 7, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Sinh, "sinh(a)") +WEBGPU_ELEMENTWISE_KERNEL(Sinh, 9, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Cosh, "cosh(a)") +WEBGPU_ELEMENTWISE_KERNEL(Cosh, 9, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Tanh, "tanh(a)") +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Tanh, 13, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Asinh, "asinh(a)") +WEBGPU_ELEMENTWISE_KERNEL(Asinh, 9, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Acosh, "acosh(a)") +WEBGPU_ELEMENTWISE_KERNEL(Acosh, 9, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Atanh, "atanh(a)") +WEBGPU_ELEMENTWISE_KERNEL(Atanh, 9, WebGpuSupportedFloatTypes()) + +// todo: logical ops + +// +// activation +// + +// todo: clip + +// constexpr char EluImpl[] = R"( +//)"; +// +// WEBGPU_ELEMENTWISE_IMPL(Elu, "elu_v(a)", ) // TODO: add other unary elementwise ops diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h index 837f66af30dde..dbf15248b6b13 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h @@ -11,15 +11,44 @@ namespace webgpu { class UnaryElementwiseProgram final : public Program { public: - UnaryElementwiseProgram(const std::string& kernel_name, const std::string& expression, const std::string& additional_impl = "") + UnaryElementwiseProgram(const std::string& kernel_name, std::string_view expression, std::string_view additional_impl) : Program{kernel_name}, expression_{expression}, additional_impl_{additional_impl} { } Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"vec_size", ProgramUniformVariableDataType::Uint32}, // output size + {"f32_attr", ProgramUniformVariableDataType::Float32}, // float type attribute(s) + {"i32_attr", ProgramUniformVariableDataType::Int32}); // int type attribute(s) private: + std::string_view expression_; + std::string_view additional_impl_; +}; + +// TODO: after upgrading to C++20, use consteval to make a compile-time constructor so that it will be safe to switch +// the std::string to std::string_view. This will avoid the cost of copying the string. + +class UnaryElementwise : public WebGpuKernel { + public: + UnaryElementwise(const OpKernelInfo& info, + const std::string& kernel_name, + const std::string& expression, + const std::string& additional_impl = "") : WebGpuKernel{info}, + kernel_name_{kernel_name}, + expression_{expression}, + additional_impl_{additional_impl} {} + + protected: + Status ComputeInternal(ComputeContext& context) const final; + virtual Status ConfigureProgram(UnaryElementwiseProgram& program) const { + program.UniformVariables({{}, {}}); // empty for both float and int attribute(s) + return Status::OK(); + } + + private: + std::string kernel_name_; std::string expression_; std::string additional_impl_; }; diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index e7688d1fafb94..202742a1c79bc 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -397,39 +397,39 @@ std::unique_ptr RegisterKernels() { // unary - math KERNEL_CREATE_INFO_VERSIONED(6, 12, Abs), KERNEL_CREATE_INFO(13, Abs), - // KERNEL_CREATE_INFO_VERSIONED(6, 12, Neg), - // KERNEL_CREATE_INFO(13, Neg), - // KERNEL_CREATE_INFO_VERSIONED(6, 12, Floor), - // KERNEL_CREATE_INFO(13, Floor), - // KERNEL_CREATE_INFO_VERSIONED(6, 12, Ceil), - // KERNEL_CREATE_INFO(13, Ceil), - // KERNEL_CREATE_INFO_VERSIONED(6, 12, Reciprocal), - // KERNEL_CREATE_INFO(13, Reciprocal), - // KERNEL_CREATE_INFO_VERSIONED(6, 12, Sqrt), - // KERNEL_CREATE_INFO(13, Sqrt), - // KERNEL_CREATE_INFO_VERSIONED(6, 12, Exp), - // KERNEL_CREATE_INFO(13, Exp), - // KERNEL_CREATE_INFO_VERSIONED(9, 12, Erf), - // KERNEL_CREATE_INFO(13, Erf), - // KERNEL_CREATE_INFO_VERSIONED(6, 12, Sigmoid), - // KERNEL_CREATE_INFO(13, Sigmoid), - // KERNEL_CREATE_INFO(6, HardSigmoid), - // KERNEL_CREATE_INFO_VERSIONED(6, 12, Log), - // KERNEL_CREATE_INFO(13, Log), - - // KERNEL_CREATE_INFO(7, Sin), - // KERNEL_CREATE_INFO(7, Cos), - // KERNEL_CREATE_INFO(7, Tan), - // KERNEL_CREATE_INFO(7, Asin), - // KERNEL_CREATE_INFO(7, Acos), - // KERNEL_CREATE_INFO(7, Atan), - // KERNEL_CREATE_INFO(9, Sinh), - // KERNEL_CREATE_INFO(9, Cosh), - // KERNEL_CREATE_INFO(9, Asinh), - // KERNEL_CREATE_INFO(9, Acosh), - // KERNEL_CREATE_INFO(9, Atanh), - // KERNEL_CREATE_INFO_VERSIONED(6, 12, Tanh), - // KERNEL_CREATE_INFO(13, Tanh), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Neg), + KERNEL_CREATE_INFO(13, Neg), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Floor), + KERNEL_CREATE_INFO(13, Floor), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Ceil), + KERNEL_CREATE_INFO(13, Ceil), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Reciprocal), + KERNEL_CREATE_INFO(13, Reciprocal), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Sqrt), + KERNEL_CREATE_INFO(13, Sqrt), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Exp), + KERNEL_CREATE_INFO(13, Exp), + KERNEL_CREATE_INFO_VERSIONED(9, 12, Erf), + KERNEL_CREATE_INFO(13, Erf), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Sigmoid), + KERNEL_CREATE_INFO(13, Sigmoid), + KERNEL_CREATE_INFO(6, HardSigmoid), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Log), + KERNEL_CREATE_INFO(13, Log), + + KERNEL_CREATE_INFO(7, Sin), + KERNEL_CREATE_INFO(7, Cos), + KERNEL_CREATE_INFO(7, Tan), + KERNEL_CREATE_INFO(7, Asin), + KERNEL_CREATE_INFO(7, Acos), + KERNEL_CREATE_INFO(7, Atan), + KERNEL_CREATE_INFO(9, Sinh), + KERNEL_CREATE_INFO(9, Cosh), + KERNEL_CREATE_INFO(9, Asinh), + KERNEL_CREATE_INFO(9, Acosh), + KERNEL_CREATE_INFO(9, Atanh), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Tanh), + KERNEL_CREATE_INFO(13, Tanh), // KERNEL_CREATE_INFO(1, Not), // KERNEL_CREATE_INFO_VERSIONED(6, 8, Cast), From 8361fc3e440b53bb20671831ed7e10631d4fb528 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 2 Sep 2024 16:58:24 -0700 Subject: [PATCH 32/77] various of fixes --- .../webgpu/math/unary_elementwise_ops.cc | 58 +++--- .../webgpu/math/unary_elementwise_ops.h | 17 +- onnxruntime/core/providers/webgpu/program.cc | 40 ++++- onnxruntime/core/providers/webgpu/program.h | 53 ++++-- .../providers/webgpu/program_cache_key.cc | 50 ++++-- .../core/providers/webgpu/program_manager.cc | 3 +- .../core/providers/webgpu/shader_helper.cc | 169 +++++++++++++++++- .../core/providers/webgpu/shader_helper.h | 43 +++-- .../core/providers/webgpu/shader_variable.cc | 90 ++++++---- .../core/providers/webgpu/shader_variable.h | 34 ++-- .../core/providers/webgpu/webgpu_context.cc | 5 +- 11 files changed, 410 insertions(+), 152 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index 0ae48ccbd6341..97dd2c5984631 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -2,20 +2,17 @@ // Licensed under the MIT License. #include "core/providers/webgpu/math/unary_elementwise_ops.h" -#include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" namespace onnxruntime { namespace webgpu { Status UnaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& input = shader.AddVariable(ProgramVariableScope::Input, - "x", - ToProgramVariableDataType(Inputs()[0].tensor->GetElementType(), 4), - 1); - const auto& output = shader.AddVariable(ProgramVariableScope::Output, - "y", - ToProgramVariableDataType(Outputs()[0]->GetElementType(), 4), - 1); + const auto& input = shader.AddInput("x", + ToProgramVariableDataType(Inputs()[0].tensor->GetElementType(), 4), + ShaderVariable::UseUniform | additional_usage_); + const auto& output = shader.AddOutput("y", + ToProgramVariableDataType(Outputs()[0].tensor->GetElementType(), 4), + ShaderVariable::UseUniform); shader.AppendImplementation(additional_impl_); shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), "let a = ", input.GetByOffset("global_idx"), ";\n", @@ -27,11 +24,12 @@ Status UnaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { Status UnaryElementwise::ComputeInternal(ComputeContext& context) const { const auto* input_tensor = context.Input(0); auto* output_tensor = context.Output(0, input_tensor->Shape()); - SafeInt vec_size = (input_tensor->Shape().Size() + 3) / 4; - UnaryElementwiseProgram program{kernel_name_, expression_, additional_impl_}; + int64_t size = input_tensor->Shape().Size(); + SafeInt vec_size = (size + 3) / 4; + UnaryElementwiseProgram program{kernel_name_, expression_, additional_impl_, additional_usage_}; program - .Inputs({{input_tensor, ProgramInputTensorDependency::Type}}) - .Outputs({output_tensor}) + .Inputs({{input_tensor, ProgramTensorMetadataDependency::Type, {vec_size}}}) + .Outputs({{output_tensor, ProgramTensorMetadataDependency::None, {vec_size}}}) .DispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) .UniformVariables({ {static_cast(vec_size)}, @@ -91,21 +89,21 @@ WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Exp, 13, WebGpuSupportedFloatTypes()) constexpr char ErfImpl[] = R"( -const r0: x_value_t = 0.3275911; -const r1: x_value_t = 0.254829592; -const r2: x_value_t = -0.284496736; -const r3: x_value_t = 1.421413741; -const r4: x_value_t = -1.453152027; -const r5: x_value_t = 1.061405429; - -fn erf_v(v: vec4) -> vec4 { +const r0 = 0.3275911; +const r1 = 0.254829592; +const r2 = -0.284496736; +const r3 = 1.421413741; +const r4 = -1.453152027; +const r5 = 1.061405429; + +fn erf_v(v: x_value_t) -> x_value_t { let absv = abs(v); let x = 1.0 / (1.0 + r0 * absv); return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv)); } )"; -WEBGPU_ELEMENTWISE_IMPL(Erf, "erf_v(a)", ErfImpl) +WEBGPU_ELEMENTWISE_IMPL(Erf, "erf_v(a)", ErfImpl, ShaderVariable::UseValueTypeAlias) WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Erf, 13, WebGpuSupportedFloatTypes()) @@ -117,15 +115,19 @@ WEBGPU_ELEMENTWISE_IMPL(Sigmoid, "1.0 / (1.0 + exp(-a))") WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Sigmoid, 13, WebGpuSupportedFloatTypes()) +constexpr char HardSigmoidImpl[] = R"( +fn hard_sigmoid_v(v: x_value_t) -> x_value_t { + let alpha = x_element_t(uniforms.f32_attr[0]); + let beta_v = vec4(uniforms.f32_attr[1]); + return max(vec4(0.0), + min(vec4(1.0), alpha * v + beta_v)); +} +)"; class HardSigmoid final : public UnaryElementwise { public: HardSigmoid(const OpKernelInfo& info) - : UnaryElementwise{ - info, - "HardSigmoid", - // alpha = uniforms.f32_attr[0] - // beta = uniforms.f32_attr[1] - "max(vec4(0.0), min(vec4(1.0), x_value_t(uniforms.f32_attr[0]) * a + vec4(uniforms.f32_attr[1])))"} { + : UnaryElementwise{info, "HardSigmoid", "hard_sigmoid_v(a)", HardSigmoidImpl, ShaderVariable::UseElementTypeAlias | ShaderVariable::UseValueTypeAlias} { + // attr[0] is alpha, attr[1] is beta info.GetAttrOrDefault("alpha", attr, 0.2f); info.GetAttrOrDefault("beta", attr + 1, 0.5f); } diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h index dbf15248b6b13..2d084bf227f72 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h @@ -4,6 +4,7 @@ #pragma once #include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/program.h" namespace onnxruntime { @@ -11,8 +12,8 @@ namespace webgpu { class UnaryElementwiseProgram final : public Program { public: - UnaryElementwiseProgram(const std::string& kernel_name, std::string_view expression, std::string_view additional_impl) - : Program{kernel_name}, expression_{expression}, additional_impl_{additional_impl} { + UnaryElementwiseProgram(const std::string& kernel_name, std::string_view expression, std::string_view additional_impl, ShaderVariable::Usage usage) + : Program{kernel_name}, expression_{expression}, additional_impl_{additional_impl}, additional_usage_{usage} { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -25,6 +26,7 @@ class UnaryElementwiseProgram final : public Program { private: std::string_view expression_; std::string_view additional_impl_; + ShaderVariable::Usage additional_usage_; }; // TODO: after upgrading to C++20, use consteval to make a compile-time constructor so that it will be safe to switch @@ -35,10 +37,12 @@ class UnaryElementwise : public WebGpuKernel { UnaryElementwise(const OpKernelInfo& info, const std::string& kernel_name, const std::string& expression, - const std::string& additional_impl = "") : WebGpuKernel{info}, - kernel_name_{kernel_name}, - expression_{expression}, - additional_impl_{additional_impl} {} + const std::string& additional_impl = "", + ShaderVariable::Usage usage = ShaderVariable::None) : WebGpuKernel{info}, + kernel_name_{kernel_name}, + expression_{expression}, + additional_impl_{additional_impl}, + additional_usage_{usage} {} protected: Status ComputeInternal(ComputeContext& context) const final; @@ -51,6 +55,7 @@ class UnaryElementwise : public WebGpuKernel { std::string kernel_name_; std::string expression_; std::string additional_impl_; + ShaderVariable::Usage additional_usage_; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc index 8ba33bcafb316..91f86d2cf681a 100644 --- a/onnxruntime/core/providers/webgpu/program.cc +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include #include "core/session/onnxruntime_c_api.h" @@ -49,27 +50,27 @@ ProgramUniformVariableValue::ProgramUniformVariableValue(ProgramUniformVariableD } std::ostream& operator<<(std::ostream& os, ProgramUniformVariableDataType type) { - os << ProgramUniformVariableDataTypeName[static_cast(type)]; + os << ProgramUniformVariableDataTypeName[std::underlying_type::type(type)]; return os; } std::ostream& operator<<(std::ostream& os, ProgramConstantDataType type) { - os << ProgramConstantDataTypeName[static_cast(type)]; + os << ProgramConstantDataTypeName[std::underlying_type::type(type)]; return os; } -std::ostream& operator<<(std::ostream& os, ProgramInputTensorDependency dep) { +std::ostream& operator<<(std::ostream& os, ProgramTensorMetadataDependency dep) { bool first = true; - if ((dep & ProgramInputTensorDependency::Type) == ProgramInputTensorDependency::Type) { + if ((dep & ProgramTensorMetadataDependency::Type) == ProgramTensorMetadataDependency::Type) { os << "Type"; first = false; } - if ((dep & ProgramInputTensorDependency::Rank) == ProgramInputTensorDependency::Rank) { + if ((dep & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank) { if (!first) os << "|"; os << "Rank"; first = false; } - if ((dep & ProgramInputTensorDependency::Shape) == ProgramInputTensorDependency::Shape) { + if ((dep & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape) { if (!first) os << "|"; os << "Shape"; first = false; @@ -81,6 +82,31 @@ std::ostream& operator<<(std::ostream& os, ProgramInputTensorDependency dep) { return os; } +int NumberOfComponents(ProgramVariableDataType type) { + switch (type) { + case ProgramVariableDataType::Float32: + case ProgramVariableDataType::Int32: + case ProgramVariableDataType::Uint32: + case ProgramVariableDataType::Int64: + case ProgramVariableDataType::Uint64: + case ProgramVariableDataType::Float16: + return 1; + case ProgramVariableDataType::Vec2Float32: + case ProgramVariableDataType::Vec2Int32: + case ProgramVariableDataType::Vec2Uint32: + case ProgramVariableDataType::Vec2Float16: + return 2; + case ProgramVariableDataType::Vec4Float32: + case ProgramVariableDataType::Vec4Int32: + case ProgramVariableDataType::Vec4Uint32: + case ProgramVariableDataType::Vec4Float16: + case ProgramVariableDataType::Vec4Bool: + return 4; + default: + return -1; + } +} + ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int component /* = 1 */) { if (component == 1) { switch (element_type) { @@ -147,7 +173,7 @@ ProgramBase& ProgramBase::Inputs(std::initializer_list inputs) { return *this; } -ProgramBase& ProgramBase::Outputs(std::initializer_list outputs) { +ProgramBase& ProgramBase::Outputs(std::initializer_list outputs) { outputs_.assign(outputs.begin(), outputs.end()); return *this; } diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index d056ee8577f11..c48bdb1a4ff12 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -140,7 +140,7 @@ struct ProgramOverridableConstantDefinition { }; // represents whether the program shader depends on the type, rank, or shape of an input/output tensor -enum class ProgramInputTensorDependency : int { +enum class ProgramTensorMetadataDependency : int { None = 0, Type = 1, Rank = 2, @@ -148,24 +148,47 @@ enum class ProgramInputTensorDependency : int { TypeAndRank = Type | Rank, TypeAndShape = Type | Shape, }; -std::ostream& operator<<(std::ostream& os, ProgramInputTensorDependency); +std::ostream& operator<<(std::ostream& os, ProgramTensorMetadataDependency); -inline ProgramInputTensorDependency operator|(ProgramInputTensorDependency a, ProgramInputTensorDependency b) { - return (ProgramInputTensorDependency)((int&)a | (int&)b); +inline ProgramTensorMetadataDependency operator|(ProgramTensorMetadataDependency a, ProgramTensorMetadataDependency b) { + return (ProgramTensorMetadataDependency)((int&)a | (int&)b); } -inline ProgramInputTensorDependency operator&(ProgramInputTensorDependency a, ProgramInputTensorDependency b) { - return (ProgramInputTensorDependency)((int&)a & (int&)b); +inline ProgramTensorMetadataDependency operator&(ProgramTensorMetadataDependency a, ProgramTensorMetadataDependency b) { + return (ProgramTensorMetadataDependency)((int&)a & (int&)b); } -inline ProgramInputTensorDependency& operator|=(ProgramInputTensorDependency& a, ProgramInputTensorDependency b) { - return (ProgramInputTensorDependency&)((int&)a |= (int&)b); +inline ProgramTensorMetadataDependency& operator|=(ProgramTensorMetadataDependency& a, ProgramTensorMetadataDependency b) { + return (ProgramTensorMetadataDependency&)((int&)a |= (int&)b); } -inline ProgramInputTensorDependency& operator&=(ProgramInputTensorDependency& a, ProgramInputTensorDependency b) { - return (ProgramInputTensorDependency&)((int&)a &= (int&)b); +inline ProgramTensorMetadataDependency& operator&=(ProgramTensorMetadataDependency& a, ProgramTensorMetadataDependency b) { + return (ProgramTensorMetadataDependency&)((int&)a &= (int&)b); } struct ProgramInput { + ProgramInput(const Tensor* tensor) + : ProgramInput{tensor, ProgramTensorMetadataDependency::TypeAndRank} {} + ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency) + : tensor{tensor}, dependency{dependency}, use_override_shape{false}, override_shape{} {} + ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape) + : tensor{tensor}, dependency{dependency}, use_override_shape{true}, override_shape{override_shape} {} + const Tensor* tensor; - ProgramInputTensorDependency dependency; + ProgramTensorMetadataDependency dependency; + bool use_override_shape; + TensorShape override_shape; +}; + +struct ProgramOutput { + ProgramOutput(Tensor* tensor) + : ProgramOutput{tensor, ProgramTensorMetadataDependency::None} {} + ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency) + : tensor{tensor}, dependency{dependency}, use_override_shape{false}, override_shape{} {} + ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape) + : tensor{tensor}, dependency{dependency}, use_override_shape{true}, override_shape{override_shape} {} + + Tensor* tensor; + ProgramTensorMetadataDependency dependency; + bool use_override_shape; + TensorShape override_shape; }; constexpr SafeInt WORKGROUP_SIZE = 64; @@ -205,6 +228,8 @@ enum class ProgramVariableDataType { Vec4Bool, }; +int NumberOfComponents(ProgramVariableDataType type); + ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int component = 1); namespace detail { @@ -229,7 +254,7 @@ class ProgramBase { // set one or more program inputs ProgramBase& Inputs(std::initializer_list inputs); // set one or more program outputs - ProgramBase& Outputs(std::initializer_list outputs); + ProgramBase& Outputs(std::initializer_list outputs); // set the size of dispatch groups. Y and Z are 1 if not specified. ProgramBase& DispatchGroupSize(uint32_t x); @@ -289,7 +314,7 @@ class ProgramBase { inline const std::string& Name() const { return name_; } inline const std::string& CacheHint() const { return cache_hint_; } inline const std::vector& Inputs() const { return inputs_; } - inline const std::vector& Outputs() const { return outputs_; } + inline const std::vector& Outputs() const { return outputs_; } inline uint32_t DispatchGroupSizeX() const { return dispatch_group_size_x_; } inline uint32_t DispatchGroupSizeY() const { return dispatch_group_size_y_; } inline uint32_t DispatchGroupSizeZ() const { return dispatch_group_size_z_; } @@ -310,7 +335,7 @@ class ProgramBase { std::string name_; std::string cache_hint_; std::vector inputs_; - std::vector outputs_; + std::vector outputs_; uint32_t dispatch_group_size_x_; uint32_t dispatch_group_size_y_; diff --git a/onnxruntime/core/providers/webgpu/program_cache_key.cc b/onnxruntime/core/providers/webgpu/program_cache_key.cc index a4530910944d4..7bea82a1b0c65 100644 --- a/onnxruntime/core/providers/webgpu/program_cache_key.cc +++ b/onnxruntime/core/providers/webgpu/program_cache_key.cc @@ -8,6 +8,29 @@ namespace onnxruntime { namespace webgpu { +namespace { +void AppendTensorInfo(std::ostringstream& ss, const Tensor& tensor, ProgramTensorMetadataDependency dependency, bool& first) { + if (first) { + first = false; + } else { + ss << "|"; + } + if ((dependency & ProgramTensorMetadataDependency::Type) == ProgramTensorMetadataDependency::Type) { +#ifndef NDEBUG // if debug build + ss << DataTypeImpl::ToString(tensor.DataType()); +#else + ss << output.tensor->GetElementType(); +#endif + } + ss << ";"; + if ((dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape) { + ss D("Rank=") << tensor.Shape().NumDimensions(); + } else if ((dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank) { + ss D("Dims=") << tensor.Shape().ToString(); + } +} +} // namespace + std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_dispatch) { std::ostringstream ss; ss.imbue(std::locale::classic()); @@ -34,6 +57,7 @@ std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_disp x != 0 || y != 0 || z != 0) { ss << ":" D("WorkgroupSize="); // only append non-zero values. zero values are considered as use default + // todo: this is actually not working correctly. revisit this logic. currently even if it's default, the value is not zero and will be appended if (x > 0) { ss << x; } @@ -60,27 +84,17 @@ std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_disp ss << uniform.length; } } + ss << ":" D("Inputs="); first = true; for (const auto& input : program.Inputs()) { - if (first) { - first = false; - } else { - ss << "|"; - } - if ((input.dependency & ProgramInputTensorDependency::Type) == ProgramInputTensorDependency::Type) { -#ifndef NDEBUG // if debug build - ss << DataTypeImpl::ToString(input.tensor->DataType()); -#else - ss << input.tensor->GetElementType(); -#endif - } - ss << ";"; - if ((input.dependency & ProgramInputTensorDependency::Rank) == ProgramInputTensorDependency::Rank) { - ss D("Rank=") << input.tensor->Shape().NumDimensions(); - } else if ((input.dependency & ProgramInputTensorDependency::Shape) == ProgramInputTensorDependency::Shape) { - ss D("Dims=") << input.tensor->Shape().ToString(); - } + AppendTensorInfo(ss, *input.tensor, input.dependency, first); + } + + ss << ":" D("Outputs="); + first = true; + for (const auto& output : program.Outputs()) { + AppendTensorInfo(ss, *output.tensor, output.dependency, first); } return ss.str(); diff --git a/onnxruntime/core/providers/webgpu/program_manager.cc b/onnxruntime/core/providers/webgpu/program_manager.cc index 00036a915f695..a10412f21f498 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.cc +++ b/onnxruntime/core/providers/webgpu/program_manager.cc @@ -56,7 +56,8 @@ Status ProgramManager::Build(const ProgramBase& program, ORT_RETURN_IF_ERROR(program.GenerateShaderCode(shader_helper)); // code is a large std::string that contains the final shader code - auto code = shader_helper.GetFinalSourceCode(); + std::string code; + ORT_RETURN_IF_ERROR(shader_helper.GetFinalSourceCode(code)); LOGS_DEFAULT(VERBOSE) << "\n=== WebGPU Shader code [" << program.Name() #ifndef NDEBUG // if debug build diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index 5883696430de6..c8c79dd6233d2 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -4,10 +4,12 @@ #include #include #include +#include #include "core/session/onnxruntime_c_api.h" #include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/program.h" namespace onnxruntime { namespace webgpu { @@ -79,7 +81,145 @@ Status ShaderHelper::Init() { return Status::OK(); } -std::string ShaderHelper::GetFinalSourceCode() { +const ShaderVariable& ShaderHelper::AddInput(const std::string& name, ProgramVariableDataType type, ShaderVariable::Usage usage) { + const size_t input_index = vars_[std::underlying_type::type(ProgramVariableScope::Input)].size(); + ORT_ENFORCE(input_index < program_.Inputs().size(), + "Too many inputs in the program (", program_.Inputs().size(), ")"); + + const auto& dims = program_.Inputs()[input_index].use_override_shape ? program_.Inputs()[input_index].override_shape + : program_.Inputs()[input_index].tensor->Shape(); + return AddVariableImpl(ProgramVariableScope::Input, name, type, usage, dims); +} + +const ShaderVariable& ShaderHelper::AddOutput(const std::string& name, ProgramVariableDataType type, ShaderVariable::Usage usage) { + const size_t output_index = vars_[std::underlying_type::type(ProgramVariableScope::Output)].size(); + ORT_ENFORCE(output_index < program_.Outputs().size(), + "Too many outputs in the program (", program_.Outputs().size(), ")"); + + const auto& dims = program_.Outputs()[output_index].use_override_shape ? program_.Outputs()[output_index].override_shape + : program_.Outputs()[output_index].tensor->Shape(); + return AddVariableImpl(ProgramVariableScope::Output, name, type, usage, dims); +} + +#ifndef NDEBUG // if debug build +namespace { +Status ValidateVariableDataType(int32_t element_type, ProgramVariableDataType var_type) { + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Float32 || + var_type == ProgramVariableDataType::Vec2Float32 || + var_type == ProgramVariableDataType::Vec4Float32, + "Unexpected program variable type ", int(var_type), " for float32 tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Float16 || + var_type == ProgramVariableDataType::Vec2Float16 || + var_type == ProgramVariableDataType::Vec4Float16, + "Unexpected program variable type ", int(var_type), " for float16 tensor"); + + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Int32 || + var_type == ProgramVariableDataType::Vec2Int32 || + var_type == ProgramVariableDataType::Vec4Int32, + "Unexpected program variable type ", int(var_type), " for int32 tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Uint32 || + var_type == ProgramVariableDataType::Vec2Uint32 || + var_type == ProgramVariableDataType::Vec4Uint32, + "Unexpected program variable type ", int(var_type), " for uint32 tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Int64, + "Unexpected program variable type ", int(var_type), " for int64 tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Uint64, + "Unexpected program variable type ", int(var_type), " for uint64 tensor"); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + ORT_RETURN_IF_NOT(var_type == ProgramVariableDataType::Vec4Bool, + "Unexpected program variable type ", int(var_type), " for bool tensor"); + break; + default: + ORT_RETURN_IF(true, "Unsupported data type: ", element_type); + // todo: add int4/uint4 + } + return Status::OK(); +} + +using RankOrShape = std::variant>; + +Status ValidateVariableShape(const TensorShape& origin_shape, + bool use_override_shape, + const TensorShape& override_shape, + int num_components) { + if (use_override_shape) { + // if override shape specified, assert override_size == ceil( origin_size / 4 ) + ORT_RETURN_IF_NOT((origin_shape.Size() + num_components - 1) / num_components == override_shape.Size(), + "Tensor original shape ", origin_shape, " cannot reshape to ", override_shape, " with component number ", num_components); + } else if (num_components > 1) { + // if shape is not overriden, assert origin_shape[-1] % 4 == 0 + ORT_RETURN_IF_NOT(origin_shape.Size() > 0 && origin_shape[origin_shape.Size() - 1] % num_components == 0, + "Tensor original shape ", origin_shape, " cannot be divided by component number ", num_components); + } + + // if (use_uniform) { + // const auto& rank = std::get(rank_or_shape); + // ORT_RETURN_IF_NOT(rank == SafeInt(override_shape.NumDimensions()), + // "Shader variable rank ", rank, " does not match the tensor shape ", override_shape); + // } else { + // const auto& shape = std::get>(rank_or_shape).get(); + // ORT_RETURN_IF(use_override_shape, "Cannot specify both variable shape and program input/output shape override"); + // ORT_RETURN_IF_NOT(origin_shape.Size() == shape.Size() * num_components, + // "Tensor original shape ", origin_shape, " cannot reshape to ", shape, " with component number ", num_components); + // } + return Status::OK(); +} +} // namespace + +const ShaderVariable& ShaderHelper::AddVariableImpl(ProgramVariableScope scope, + const std::string& name, + ProgramVariableDataType type, + ShaderVariable::Usage usage, + const TensorShape& dims) { + if (scope == ProgramVariableScope::Input || scope == ProgramVariableScope::Output) { + ORT_ENFORCE(vars_[std::underlying_type::type(ProgramVariableScope::Input)].size() + + vars_[std::underlying_type::type(ProgramVariableScope::Output)].size() < + limits_.maxStorageBuffersPerShaderStage, + "Too many storage buffers in shader. Max is ", limits_.maxStorageBuffersPerShaderStage); + } + + if (type == ProgramVariableDataType::Float16 || type == ProgramVariableDataType::Vec2Float16 || type == ProgramVariableDataType::Vec4Float16) { + use_f16_ = true; + } + + if (scope == ProgramVariableScope::Local) { + ORT_NOT_IMPLEMENTED("Local variables are not supported yet."); + } + + return vars_[std::underlying_type::type(scope)].emplace_back(name, type, usage, dims); +} + +Status ShaderHelper::ValidateVariable(const ProgramInput& input, const ShaderVariable& var) const { + ORT_RETURN_IF_ERROR(ValidateVariableDataType(input.tensor->GetElementType(), var.type_)); + ORT_RETURN_IF_ERROR(ValidateVariableShape(input.tensor->Shape(), + input.use_override_shape, + input.use_override_shape ? input.override_shape : input.tensor->Shape(), + var.num_components_)); + + return Status::OK(); +} +Status ShaderHelper::ValidateVariable(const ProgramOutput& output, const ShaderVariable& var) const { + ORT_RETURN_IF_ERROR(ValidateVariableDataType(output.tensor->GetElementType(), var.type_)); + + // todo: add reshaped shape and check + return Status::OK(); +} +#endif + +Status ShaderHelper::GetFinalSourceCode(std::string& code) { std::ostringstream ss; ss.imbue(std::locale::classic()); @@ -87,14 +227,14 @@ std::string ShaderHelper::GetFinalSourceCode() { // Section feature enabling // if (use_f16_) { - ORT_ENFORCE(device_.HasFeature(wgpu::FeatureName::ShaderF16), "Program ", program_.Name(), " requires f16 but the device does not support it."); + ORT_RETURN_IF_NOT(device_.HasFeature(wgpu::FeatureName::ShaderF16), "Program ", program_.Name(), " requires f16 but the device does not support it."); ss << "enable f16;\n"; } // // Section constants // - ss << "\nconst workgroup_size_x: u32 = " << program_.WorkgroupSizeX() + ss << "const workgroup_size_x: u32 = " << program_.WorkgroupSizeX() << ";\nconst workgroup_size_y: u32 = " << program_.WorkgroupSizeY() << ";\nconst workgroup_size_z: u32 = " << program_.WorkgroupSizeZ() << ";\n"; @@ -122,11 +262,23 @@ std::string ShaderHelper::GetFinalSourceCode() { // // Input/output variables // - int variable_count = 0; - for (const auto& input : vars_[static_cast(ProgramVariableScope::Input)]) { + size_t variable_count = 0; + const auto& input_vars = vars_[static_cast(ProgramVariableScope::Input)]; + ORT_RETURN_IF_NOT(input_vars.size() == program_.Inputs().size(), + "Mismatched input variable count. Shader: ", variable_count, ", Program: ", program_.Inputs().size()); + for (const auto& input : input_vars) { +#ifndef NDEBUG // if debug build + ORT_RETURN_IF_ERROR(ValidateVariable(program_.Inputs()[variable_count], input)); +#endif ss << "@group(0) @binding(" << variable_count++ << ") var " << input.name_ << ": array<" << input.StorageType() << ">;\n"; } - for (const auto& output : vars_[static_cast(ProgramVariableScope::Output)]) { + const auto& output_vars = vars_[static_cast(ProgramVariableScope::Output)]; + ORT_RETURN_IF_NOT(output_vars.size() == program_.Outputs().size(), + "Mismatched output variable count. Shader: ", variable_count, ", Program: ", program_.Outputs().size()); + for (const auto& output : output_vars) { +#ifndef NDEBUG // if debug build + ORT_RETURN_IF_ERROR(ValidateVariable(program_.Outputs()[variable_count - input_vars.size()], output)); +#endif ss << "@group(0) @binding(" << variable_count++ << ") var " << output.name_ << ": array<" << output.StorageType() << ">;\n"; } @@ -188,8 +340,8 @@ std::string ShaderHelper::GetFinalSourceCode() { for (const auto& var : var_group) { var.Impl(ss); } - ss << "\n"; } + ss << "\n"; // // Additional Implementation @@ -205,7 +357,8 @@ std::string ShaderHelper::GetFinalSourceCode() { ss << "\n" "}\n"; - return ss.str(); + code = ss.str(); + return Status::OK(); } } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/shader_helper.h b/onnxruntime/core/providers/webgpu/shader_helper.h index ac6dfebfef816..e1f008ff6a901 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.h +++ b/onnxruntime/core/providers/webgpu/shader_helper.h @@ -77,12 +77,13 @@ class ShaderHelper final { Status Init(); - const ShaderVariable& AddVariable(ProgramVariableScope scope, const std::string& name, ProgramVariableDataType type, int rank = 1) { - return AddVariableImpl(scope, name, type, rank); - } - const ShaderVariable& AddVariable(ProgramVariableScope scope, const std::string& name, ProgramVariableDataType type, const TensorShape& dims) { - return AddVariableImpl(scope, name, type, dims); - } + const ShaderVariable& AddInput(const std::string& name, + ProgramVariableDataType type, + ShaderVariable::Usage usage = ShaderVariable::UseIndicesTypeAlias | ShaderVariable::UseValueTypeAlias | ShaderVariable::UseUniform); + + const ShaderVariable& AddOutput(const std::string& name, + ProgramVariableDataType type, + ShaderVariable::Usage usage = ShaderVariable::UseIndicesTypeAlias | ShaderVariable::UseValueTypeAlias | ShaderVariable::UseUniform); template inline std::ostringstream& AppendImplementation(Strs&&... impl) { @@ -91,8 +92,8 @@ class ShaderHelper final { } template - inline std::ostringstream& MainFunctionBody(Strs&&... body) { - onnxruntime::detail::MakeStringImpl(body_, std::forward(body)...); + inline std::ostringstream& MainFunctionBody(const Strs&... body) { + onnxruntime::detail::MakeStringImpl(body_, std::forward>(body)...); return body_; } @@ -101,19 +102,6 @@ class ShaderHelper final { } private: - template // T is one of {int, const TensorShape&} - const ShaderVariable& AddVariableImpl(ProgramVariableScope scope, const std::string& name, ProgramVariableDataType type, T&& arg) { - ORT_ENFORCE((scope == ProgramVariableScope::Input || scope == ProgramVariableScope::Output) && - vars_[static_cast(ProgramVariableScope::Input)].size() + vars_[static_cast(ProgramVariableScope::Output)].size() < limits_.maxStorageBuffersPerShaderStage, - "Too many storage buffers in shader. Max is ", limits_.maxStorageBuffersPerShaderStage); - - if (type == ProgramVariableDataType::Float16 || type == ProgramVariableDataType::Vec2Float16 || type == ProgramVariableDataType::Vec4Float16) { - use_f16_ = true; - } - - return vars_[static_cast(scope)].emplace_back(name, type, std::forward(arg)); - } - template // ConstantType is one of {ProgramConstant, ProgramOverridableConstantValue, ProgramOverridableConstantDefinition} void WriteConstantValue(std::ostringstream& ss, const ConstantType& constant) const { switch (constant.type) { @@ -137,7 +125,18 @@ class ShaderHelper final { } } - std::string GetFinalSourceCode(); + const ShaderVariable& AddVariableImpl(ProgramVariableScope scope, + const std::string& name, + ProgramVariableDataType type, + ShaderVariable::Usage usage, + const TensorShape& dims); + +#ifndef NDEBUG // if debug build + Status ValidateVariable(const ProgramInput& input, const ShaderVariable& var) const; + Status ValidateVariable(const ProgramOutput& output, const ShaderVariable& var) const; +#endif + + Status GetFinalSourceCode(std::string& code); friend class ProgramManager; const wgpu::Device& device_; diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index fda4ad72deb20..9a4ebc80bf665 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -5,6 +5,7 @@ #include #include +#include "core/common/safeint.h" #include "core/providers/webgpu/shader_variable.h" #include "core/providers/webgpu/shader_macros.h" @@ -12,18 +13,15 @@ namespace onnxruntime { namespace webgpu { -ShaderVariable::ShaderVariable(std::string_view name, ProgramVariableDataType type, int rank) - : name_(name), type_(type), rank_(rank), usage_(UseUniform) { - Init(); -} - -ShaderVariable::ShaderVariable(std::string_view name, ProgramVariableDataType type, const TensorShape& dims) - : name_(name), type_(type), rank_(static_cast(dims.NumDimensions())), dims_(dims), usage_(None) { - Init(); -} - -void ShaderVariable::Init() { +ShaderVariable::ShaderVariable(std::string_view name, ProgramVariableDataType type, Usage usage, const TensorShape& dims) + : name_(name), + type_(type), + num_components_{NumberOfComponents(type)}, + rank_{SafeInt(dims.NumDimensions())}, + dims_{dims}, + usage_(usage) { ORT_ENFORCE(type_ != ProgramVariableDataType::InvalidType, "Invalid type for variable ", name_); + ORT_ENFORCE(num_components_ > 0, "Invalid number of components for variable ", name_); } void ShaderVariable::Impl(std::ostringstream& ss) const { @@ -31,17 +29,27 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { const std::string value_t = name_ + "_value_t"; const std::string indices_t = name_ + "_indices_t"; + const std::string element_t = name_ + "_element_t"; const std::string shape = (usage_ & UseUniform) ? "uniforms." + name_ + "_shape" : name_ + "_shape"; const std::string stride = (usage_ & UseUniform) ? "uniforms." + name_ + "_stride" : name_ + "_stride"; // Types - SS("alias ", value_t, " = ", ValueType(), ";\n"); - SS("alias ", indices_t, " = ", IndicesType(), ";\n"); + std::string_view value_type = (usage_ & UseValueTypeAlias) ? value_t : ValueType(); + if (usage_ & UseValueTypeAlias) { + SS("alias ", name_, "_value_t = ", ValueType(), ";\n"); + } + std::string_view indices_type = (usage_ & UseIndicesTypeAlias) ? indices_t : IndicesType(); + if (usage_ & UseIndicesTypeAlias) { + SS("alias ", name_, "_indices_t = ", IndicesType(), ";\n"); + } + if (usage_ & UseElementTypeAlias) { + SS("alias ", name_, "_element_t = ", ElementType(), ";\n"); + } // Need shape and strides when (not use uniform) and (any other usage is enabled) if (!(usage_ & UseUniform) && (usage_ & ~UseUniform)) { - SS("const ", shape, " = ", indices_t, "("); + SS("const ", shape, " = ", indices_type, "("); bool first = true; for (auto dim : dims_.GetDims()) { @@ -54,7 +62,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { } ss << ");\n"; - SS("const ", stride, " = ", indices_t, "("); + SS("const ", stride, " = ", indices_type, "("); first = true; for (int i = rank_ - 1; i >= 0; i--) { if (!first) { @@ -69,8 +77,8 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { // Implementation of "fn o2i_{name}" if (usage_ & UseOffsetToIndices) { if (rank_ >= 2) { - SS("fn o2i_", name_, "(offset : u32)->", indices_t, " {\n"); - SS(" var indices: ", indices_t, ";\n"); + SS("fn o2i_", name_, "(offset : u32)->", indices_type, " {\n"); + SS(" var indices: ", indices_type, ";\n"); SS(" var current = offset;\n"); for (int i = 0; i < rank_ - 1; i++) { auto current_stride = GetElementAt(stride, i, rank_); @@ -88,7 +96,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { // Implementation of "fn i2o_{name}" if (usage_ & UseIndicesToOffset) { if (rank_ >= 2) { - SS("fn i2o_", name_, "(indices : ", indices_t, ")->u32 {\n"); + SS("fn i2o_", name_, "(indices : ", indices_type, ")->u32 {\n"); SS(" return "); for (int i = 0; i < rank_ - 1; i++) { SS("indices[", i, "] * ", GetElementAt(stride, i, rank_), " + "); @@ -125,7 +133,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { for (int i = 1; i < rank_; i++) { SS(", d", i, ": u32"); } - SS(", value: ", value_t, ") {\n"); + SS(", value: ", value_type, ") {\n"); SS(" set_", name_, "_by_indices(d0"); for (int i = 1; i < rank_; i++) { SS(", d", i); @@ -138,7 +146,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { // Implementation of "fn set_{name}_by_indices" if (usage_ & UseSetByIndices) { if (rank_ >= 2) { - SS("fn set_", name_, "_by_indices(indices: ", indices_t, ", value: ", value_t, ") {\n"); + SS("fn set_", name_, "_by_indices(indices: ", indices_type, ", value: ", value_type, ") {\n"); SS(" ", SetByOffset("i2o_" + name_ + "(indices)", "value"), "\n"); SS("}\n"); } @@ -151,7 +159,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { for (int i = 1; i < rank_; i++) { SS(", d", i, ": u32"); } - SS(")->", value_t, " {\n"); + SS(")->", value_type, " {\n"); SS(" return get_", name_, "_by_indices(d0"); for (int i = 1; i < rank_; i++) { SS(", d", i); @@ -164,7 +172,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { // Implementation of "fn get_{name}_by_indices" if (usage_ & UseGetByIndices) { if (rank_ >= 2) { - SS("fn get_", name_, "_by_indices(indices: ", indices_t, ")->", value_t, " {\n"); + SS("fn get_", name_, "_by_indices(indices: ", indices_type, ")->", value_type, " {\n"); SS(" return ", GetByOffset("i2o_" + name_ + "(indices)"), ";\n"); SS("}\n"); } @@ -248,17 +256,17 @@ std::string_view ShaderVariable::StorageType() const { std::string_view ShaderVariable::ValueType() const { constexpr static const std::string_view VALUE_TYPE[] = { "f32", // f32 - "f32", // vec2f32 - "f32", // vec4f32 + "vec2", // vec2f32 + "vec4", // vec4f32 "f16", // f16 - "f16", // vec2f16 - "f16", // vec4f16 + "vec2", // vec2f16 + "vec4", // vec4f16 "i32", // i32 - "i32", // vec2i32 - "i32", // vec4i32 + "vec2", // vec2i32 + "vec4", // vec4i32 "u32", // u32 - "u32", // vec2u32 - "u32", // vec4u32 + "vec2", // vec2u32 + "vec4", // vec4u32 "i32", // int64 (trancated to i32) "u32", // uint64 (trancated to u32) "vec4", // vec4bool @@ -267,6 +275,28 @@ std::string_view ShaderVariable::ValueType() const { return VALUE_TYPE[static_cast(type_)]; } +std::string_view ShaderVariable::ElementType() const { + constexpr static const std::string_view ELEMENT_TYPE[] = { + "f32", // f32 + "f32", // vec2f32 + "f32", // vec4f32 + "f16", // f16 + "f16", // vec2f16 + "f16", // vec4f16 + "i32", // i32 + "i32", // vec2i32 + "i32", // vec4i32 + "u32", // u32 + "u32", // vec2u32 + "u32", // vec4u32 + "i32", // int64 + "u32", // uint64 + "bool", // vec4bool + }; + + return ELEMENT_TYPE[static_cast(type_)]; +} + std::string ShaderVariable::IndicesType() const { return rank_ < 2 ? "u32" : (rank_ < 4 ? MakeStringWithClassicLocale("vec", rank_, "") diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index 34d7674148412..86eaaac5e1591 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -5,7 +5,6 @@ #include -#include "core/common/safeint.h" #include "core/framework/tensor_shape.h" #include "core/providers/webgpu/program.h" @@ -39,8 +38,22 @@ std::string GetElementAt(std::string_view var, const TIdx& idx, int rank, bool i class ShaderVariable { public: - ShaderVariable(std::string_view name, ProgramVariableDataType type, int rank); - ShaderVariable(std::string_view name, ProgramVariableDataType type, const TensorShape& dims); + enum Usage : uint32_t { + None = 0, // no usage. this means no additional implementation code will be generated. + UseIndicesTypeAlias = 1, // use type alias "{name}_indices_t" for indices (eg. u32, vec2, vec3, vec4, ...) + UseValueTypeAlias = 2, // use type alias "{name}_value_t" for value (eg. f32, vecT, vec4, ...) + UseElementTypeAlias = 4, // use type alias "{name}_element_t" for element (eg. f32, bool, ...) + UseOffsetToIndices = 8, // use implementation of fn o2i_{name} + UseIndicesToOffset = 16, // use implementation of fn i2o_{name} + UseBroadcastedIndicesToOffset = 32, // use implementation of fn {broadcasted_result_name}_bi2o_{name} + UseSet = 64, // use implementation of fn set_{name} + UseSetByIndices = 128, // use implementation of fn set_{name}_by_indices + UseGet = 256, // use implementation of fn get_{name} + UseGetByIndices = 512, // use implementation of fn get_{name}_by_indices + UseUniform = 1024, // use uniform for shape and stride + }; + + ShaderVariable(std::string_view name, ProgramVariableDataType type, Usage usage, const TensorShape& dims); ShaderVariable(ShaderVariable&&) = default; ShaderVariable& operator=(ShaderVariable&&) = default; @@ -107,18 +120,6 @@ class ShaderVariable { inline std::string GetByOffset(TOffset&& offset) const; private: - enum Usage : uint32_t { - None = 0, - UseOffsetToIndices = 1, - UseIndicesToOffset = 2, - UseBroadcastedIndicesToOffset = 4, - UseSet = 8, - UseSetByIndices = 16, - UseGet = 32, - UseGetByIndices = 64, - UseUniform = 128, - }; - friend ShaderVariable::Usage operator|(ShaderVariable::Usage a, ShaderVariable::Usage b); friend ShaderVariable::Usage operator&(ShaderVariable::Usage a, ShaderVariable::Usage b); friend ShaderVariable::Usage& operator|=(ShaderVariable::Usage& a, ShaderVariable::Usage b); @@ -126,7 +127,6 @@ class ShaderVariable { ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderVariable); - void Init(); void Impl(std::ostringstream& ss) const; std::string GetByOffsetImpl(std::string_view offset) const; @@ -134,10 +134,12 @@ class ShaderVariable { std::string_view StorageType() const; std::string_view ValueType() const; + std::string_view ElementType() const; std::string IndicesType() const; std::string name_; ProgramVariableDataType type_; + int num_components_; int rank_; TensorShape dims_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index d2428d8bb7be8..e5852d9a3a6ae 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -144,7 +144,8 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog }), "All inputs must be tensors on WebGPU buffers."); - ORT_ENFORCE(std::all_of(outputs.begin(), outputs.end(), [](Tensor* tensor) { + ORT_ENFORCE(std::all_of(outputs.begin(), outputs.end(), [](const ProgramOutput& output) { + const auto* tensor = output.tensor; return tensor != nullptr && tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault && tensor->Location().device.Type() == OrtDevice::GPU && @@ -288,7 +289,7 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog bind_group_entries.push_back({nullptr, entry_index++, reinterpret_cast(const_cast(input.tensor->DataRaw()))}); } for (const auto& output : outputs) { - bind_group_entries.push_back({nullptr, entry_index++, reinterpret_cast(output->MutableDataRaw())}); + bind_group_entries.push_back({nullptr, entry_index++, reinterpret_cast(output.tensor->MutableDataRaw())}); } if (uniform_buffer) { bind_group_entries.push_back({nullptr, entry_index++, uniform_buffer}); From c89159d4407685451a0284a0569cbc69b8be09da Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 2 Sep 2024 19:28:30 -0700 Subject: [PATCH 33/77] fix workgroup_size, cache key stringnify and indices type --- onnxruntime/core/providers/webgpu/program.cc | 6 +++--- onnxruntime/core/providers/webgpu/program_cache_key.cc | 10 +++++----- onnxruntime/core/providers/webgpu/shader_helper.cc | 8 +++----- onnxruntime/core/providers/webgpu/shader_variable.cc | 4 ++-- 4 files changed, 13 insertions(+), 15 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc index 91f86d2cf681a..4a5785dc4def1 100644 --- a/onnxruntime/core/providers/webgpu/program.cc +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -163,9 +163,9 @@ ProgramBase::ProgramBase(const std::string& name) dispatch_group_size_x_{0}, dispatch_group_size_y_{0}, dispatch_group_size_z_{0}, - workgroup_size_x_{WORKGROUP_SIZE}, - workgroup_size_y_{1}, - workgroup_size_z_{1} { + workgroup_size_x_{0}, + workgroup_size_y_{0}, + workgroup_size_z_{0} { } ProgramBase& ProgramBase::Inputs(std::initializer_list inputs) { diff --git a/onnxruntime/core/providers/webgpu/program_cache_key.cc b/onnxruntime/core/providers/webgpu/program_cache_key.cc index 7bea82a1b0c65..944fbb0bf8a50 100644 --- a/onnxruntime/core/providers/webgpu/program_cache_key.cc +++ b/onnxruntime/core/providers/webgpu/program_cache_key.cc @@ -13,7 +13,7 @@ void AppendTensorInfo(std::ostringstream& ss, const Tensor& tensor, ProgramTenso if (first) { first = false; } else { - ss << "|"; + ss << '|'; } if ((dependency & ProgramTensorMetadataDependency::Type) == ProgramTensorMetadataDependency::Type) { #ifndef NDEBUG // if debug build @@ -21,12 +21,12 @@ void AppendTensorInfo(std::ostringstream& ss, const Tensor& tensor, ProgramTenso #else ss << output.tensor->GetElementType(); #endif + ss << ';'; } - ss << ";"; if ((dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape) { - ss D("Rank=") << tensor.Shape().NumDimensions(); - } else if ((dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank) { ss D("Dims=") << tensor.Shape().ToString(); + } else if ((dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank) { + ss D("Rank=") << tensor.Shape().NumDimensions(); } } } // namespace @@ -49,7 +49,7 @@ std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_disp // append custom cache hint if any if (auto& hint = program.CacheHint(); !hint.empty()) { - ss << "[" D("CacheHint=") << hint << "]"; + ss << '[' D("CacheHint=") << hint << ']'; } // append workgroup size if overridden diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index c8c79dd6233d2..054910a7dd57c 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -35,12 +35,10 @@ Status ShaderHelper::Init() { // dispatch group size is normalized so no need to validate it here // validate workgroup size - auto workgroup_size_x = program_.WorkgroupSizeX(); - auto workgroup_size_y = program_.WorkgroupSizeY(); - auto workgroup_size_z = program_.WorkgroupSizeZ(); + auto workgroup_size_x = program_.WorkgroupSizeX() == 0 ? WORKGROUP_SIZE : program_.WorkgroupSizeX(); + auto workgroup_size_y = program_.WorkgroupSizeY() == 0 ? 1 : program_.WorkgroupSizeY(); + auto workgroup_size_z = program_.WorkgroupSizeZ() == 0 ? 1 : program_.WorkgroupSizeZ(); - ORT_RETURN_IF_NOT(workgroup_size_x > 0 && workgroup_size_y > 0 && workgroup_size_z > 0, - "Workgroup size must be greater than 0"); ORT_RETURN_IF_NOT(workgroup_size_x <= limits_.maxComputeWorkgroupSizeX && workgroup_size_y <= limits_.maxComputeWorkgroupSizeY && workgroup_size_z <= limits_.maxComputeWorkgroupSizeZ, diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index 9a4ebc80bf665..ef80fd3c57f6c 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -28,7 +28,6 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { // Start generating code const std::string value_t = name_ + "_value_t"; - const std::string indices_t = name_ + "_indices_t"; const std::string element_t = name_ + "_element_t"; const std::string shape = (usage_ & UseUniform) ? "uniforms." + name_ + "_shape" : name_ + "_shape"; @@ -36,10 +35,11 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { // Types std::string_view value_type = (usage_ & UseValueTypeAlias) ? value_t : ValueType(); + const std::string indices_type = (usage_ & UseIndicesTypeAlias) ? name_ + "_indices_t" : IndicesType(); + if (usage_ & UseValueTypeAlias) { SS("alias ", name_, "_value_t = ", ValueType(), ";\n"); } - std::string_view indices_type = (usage_ & UseIndicesTypeAlias) ? indices_t : IndicesType(); if (usage_ & UseIndicesTypeAlias) { SS("alias ", name_, "_indices_t = ", IndicesType(), ";\n"); } From 5ea5936a2e0927bd2f54242de83087e21376f75e Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 2 Sep 2024 20:37:33 -0700 Subject: [PATCH 34/77] shape_uniforms preparation --- .../core/providers/webgpu/program_manager.cc | 9 ++++-- .../core/providers/webgpu/program_manager.h | 6 ++-- .../core/providers/webgpu/shader_helper.cc | 23 ++++++++++++++ .../core/providers/webgpu/shader_helper.h | 31 ++++++++++++++----- .../core/providers/webgpu/webgpu_context.cc | 15 +++++++-- 5 files changed, 69 insertions(+), 15 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/program_manager.cc b/onnxruntime/core/providers/webgpu/program_manager.cc index a10412f21f498..ff956b46697c4 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.cc +++ b/onnxruntime/core/providers/webgpu/program_manager.cc @@ -13,8 +13,8 @@ namespace onnxruntime { namespace webgpu { -ProgramArtifact::ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline) - : name{program.Name()}, compute_pipeline{compute_pipeline} {} +ProgramArtifact::ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline, std::vector&& shape_uniforms) + : name{program.Name()}, compute_pipeline{compute_pipeline}, shape_uniforms{shape_uniforms} {} Status ProgramManager::NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint32_t& z) const { ORT_RETURN_IF(x == 0 || y == 0 || z == 0, "Invalid dispatch group size (", x, ", ", y, ", ", z, ")"); @@ -43,7 +43,8 @@ Status ProgramManager::Build(const ProgramBase& program, uint32_t normalized_dispatch_x, uint32_t normalized_dispatch_y, uint32_t normalized_dispatch_z, - wgpu::ComputePipeline& compute_pipeline) const { + wgpu::ComputePipeline& compute_pipeline, + std::vector& shape_uniforms) const { ShaderHelper shader_helper{program, program_metadata, device_, @@ -55,6 +56,8 @@ Status ProgramManager::Build(const ProgramBase& program, ORT_RETURN_IF_ERROR(program.GenerateShaderCode(shader_helper)); + ORT_RETURN_IF_ERROR(shader_helper.AppendShapeUniformValues(shape_uniforms)); + // code is a large std::string that contains the final shader code std::string code; ORT_RETURN_IF_ERROR(shader_helper.GetFinalSourceCode(code)); diff --git a/onnxruntime/core/providers/webgpu/program_manager.h b/onnxruntime/core/providers/webgpu/program_manager.h index 087c75bfee773..5f4c28a140a50 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.h +++ b/onnxruntime/core/providers/webgpu/program_manager.h @@ -23,10 +23,11 @@ namespace webgpu { class ProgramArtifact { public: - ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline); + ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline, std::vector&& shape_uniforms); std::string name; wgpu::ComputePipeline compute_pipeline; + std::vector shape_uniforms; ProgramArtifact(ProgramArtifact&&) = default; ProgramArtifact& operator=(ProgramArtifact&&) = default; @@ -49,7 +50,8 @@ class ProgramManager { uint32_t normalized_dispatch_x, uint32_t normalized_dispatch_y, uint32_t normalized_dispatch_z, - wgpu::ComputePipeline& compute_pipeline) const; + wgpu::ComputePipeline& compute_pipeline, + std::vector& shape_uniforms) const; const ProgramArtifact* Get(const std::string& key) const; const ProgramArtifact* Set(const std::string& key, ProgramArtifact&& program); diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index 054910a7dd57c..cf040ddfa4271 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -215,6 +215,29 @@ Status ShaderHelper::ValidateVariable(const ProgramOutput& output, const ShaderV // todo: add reshaped shape and check return Status::OK(); } + +Status ShaderHelper::AppendShapeUniformValues(std::vector& /*shape_uniforms*/) const { + // TODO: move input/output check(validation) here + // TODO: also check input dependencies with actual usages. + // [deps] [usages] + // input -> use shape && !use_uniform -> OK + // input -> use shape && use_uniform -> err + // input -> !use shape && !use_uniform -> err: must use shape if not using uniform + // input -> !use shape && use_uniform -> + // use_rank -> OK + // !use_rank -> err: must use rank + // + // output -> do not check + + // TODO: tensor shape and strides adding to uniforms (in front) + // when: use_rank && rank >=2 + // info need for codegen: [rank, variable name] content -> "vecN {name}_shape, vecN {name}_strides" + // // further optimization: strides can be vecN-1 + // minimal info stored in artifact: array<[rank, variable name] | not_use > + + return Status::OK(); +} + #endif Status ShaderHelper::GetFinalSourceCode(std::string& code) { diff --git a/onnxruntime/core/providers/webgpu/shader_helper.h b/onnxruntime/core/providers/webgpu/shader_helper.h index e1f008ff6a901..dc46b62754261 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.h +++ b/onnxruntime/core/providers/webgpu/shader_helper.h @@ -77,28 +77,41 @@ class ShaderHelper final { Status Init(); + // Add an input variable to the shader. + // + // depending on the usage of the variable, additional code may be generated. const ShaderVariable& AddInput(const std::string& name, ProgramVariableDataType type, ShaderVariable::Usage usage = ShaderVariable::UseIndicesTypeAlias | ShaderVariable::UseValueTypeAlias | ShaderVariable::UseUniform); + // Add an output variable to the shader. + // + // depending on the usage of the variable, additional code may be generated. const ShaderVariable& AddOutput(const std::string& name, ProgramVariableDataType type, ShaderVariable::Usage usage = ShaderVariable::UseIndicesTypeAlias | ShaderVariable::UseValueTypeAlias | ShaderVariable::UseUniform); + // Append additional implementation code to the shader. + // + // can be called multiple times. template - inline std::ostringstream& AppendImplementation(Strs&&... impl) { + inline ShaderHelper& AppendImplementation(Strs&&... impl) { onnxruntime::detail::MakeStringImpl(additional_implementation_, std::forward(impl)...); - return additional_implementation_; + return *this; } + // Set the main function body of the shader. + // + // can be called only once. template - inline std::ostringstream& MainFunctionBody(const Strs&... body) { + inline void MainFunctionBody(const Strs&... body) { + ORT_ENFORCE(!body_set_, "Main function body is already set"); onnxruntime::detail::MakeStringImpl(body_, std::forward>(body)...); - return body_; + body_set_ = true; } - std::string GuardAgainstOutOfBoundsWorkgroupSizes(const std::string& size) const { - return " if (global_idx >= " + size + ") { return; }\n"; + std::string GuardAgainstOutOfBoundsWorkgroupSizes(std::string_view size) const { + return MakeStringWithClassicLocale(" if (global_idx >= ", size, ") { return; }\n"); } private: @@ -135,7 +148,9 @@ class ShaderHelper final { Status ValidateVariable(const ProgramInput& input, const ShaderVariable& var) const; Status ValidateVariable(const ProgramOutput& output, const ShaderVariable& var) const; #endif - + // Append the uniform values of all shape variables. Including shape/strides of input/output variables, + // if UseUniform is set in the usage of the variable. + Status AppendShapeUniformValues(std::vector& shape_uniforms) const; Status GetFinalSourceCode(std::string& code); friend class ProgramManager; @@ -149,11 +164,11 @@ class ShaderHelper final { const ProgramMetadata& program_metadata_; std::array, static_cast(ProgramVariableScope::Count)> vars_; - std::ostringstream ss2; std::ostringstream additional_implementation_; std::ostringstream body_; bool use_f16_ = false; + bool body_set_ = false; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index e5852d9a3a6ae..638bc7c1f7c05 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -212,6 +212,7 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog const auto* program_artifact = program_mgr_->Get(key); if (program_artifact == nullptr) { wgpu::ComputePipeline compute_pipeline; + std::vector shape_uniforms; auto status = program_mgr_->Build(program, metadata, #ifndef NDEBUG // if debug build @@ -220,15 +221,25 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog x, y, z, - compute_pipeline); + compute_pipeline, + shape_uniforms); ORT_RETURN_IF_ERROR(status); - program_artifact = program_mgr_->Set(key, ProgramArtifact{program, std::move(compute_pipeline)}); + program_artifact = program_mgr_->Set(key, ProgramArtifact{program, + std::move(compute_pipeline), + std::move(shape_uniforms)}); #ifndef NDEBUG // if debug build ORT_ENFORCE(program_artifact != nullptr, "Program artifact should not be nullptr."); #endif } // prepare uniform info + + // TODO: also append artifacts uniform info and fill in actual input/output (override) shape value + + // foreach (uniform in artifact) { + // check if match; + // if match, create ProgramUniformVariableValue + // } const auto& uniforms = program.UniformVariables(); size_t current_offset = 0; std::vector> uniform_and_offsets; From 7d8305445379227709457b46389d02bf77dd048f Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 3 Sep 2024 01:06:34 -0700 Subject: [PATCH 35/77] allow uniforms of input/output shape/stride being added automatically --- .../providers/webgpu/program_cache_key.cc | 1 - .../core/providers/webgpu/program_manager.cc | 14 +- .../core/providers/webgpu/program_manager.h | 10 +- .../core/providers/webgpu/shader_helper.cc | 174 ++++++++++++------ .../core/providers/webgpu/shader_helper.h | 19 +- .../core/providers/webgpu/webgpu_context.cc | 47 +++-- 6 files changed, 177 insertions(+), 88 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/program_cache_key.cc b/onnxruntime/core/providers/webgpu/program_cache_key.cc index 944fbb0bf8a50..c6ab16a73423d 100644 --- a/onnxruntime/core/providers/webgpu/program_cache_key.cc +++ b/onnxruntime/core/providers/webgpu/program_cache_key.cc @@ -57,7 +57,6 @@ std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_disp x != 0 || y != 0 || z != 0) { ss << ":" D("WorkgroupSize="); // only append non-zero values. zero values are considered as use default - // todo: this is actually not working correctly. revisit this logic. currently even if it's default, the value is not zero and will be appended if (x > 0) { ss << x; } diff --git a/onnxruntime/core/providers/webgpu/program_manager.cc b/onnxruntime/core/providers/webgpu/program_manager.cc index ff956b46697c4..3e4fbd33a6bdf 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.cc +++ b/onnxruntime/core/providers/webgpu/program_manager.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include + #include "core/common/common.h" #include "core/common/safeint.h" @@ -13,8 +15,10 @@ namespace onnxruntime { namespace webgpu { -ProgramArtifact::ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline, std::vector&& shape_uniforms) - : name{program.Name()}, compute_pipeline{compute_pipeline}, shape_uniforms{shape_uniforms} {} +ProgramArtifact::ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline, std::vector&& shape_uniform_ranks) + : name{program.Name()}, + compute_pipeline{compute_pipeline}, + shape_uniform_ranks{shape_uniform_ranks} {} Status ProgramManager::NormalizeDispatchGroupSize(uint32_t& x, uint32_t& y, uint32_t& z) const { ORT_RETURN_IF(x == 0 || y == 0 || z == 0, "Invalid dispatch group size (", x, ", ", y, ", ", z, ")"); @@ -44,7 +48,7 @@ Status ProgramManager::Build(const ProgramBase& program, uint32_t normalized_dispatch_y, uint32_t normalized_dispatch_z, wgpu::ComputePipeline& compute_pipeline, - std::vector& shape_uniforms) const { + std::vector& shape_uniform_ranks) const { ShaderHelper shader_helper{program, program_metadata, device_, @@ -56,11 +60,11 @@ Status ProgramManager::Build(const ProgramBase& program, ORT_RETURN_IF_ERROR(program.GenerateShaderCode(shader_helper)); - ORT_RETURN_IF_ERROR(shader_helper.AppendShapeUniformValues(shape_uniforms)); + ORT_RETURN_IF_ERROR(shader_helper.ValidateShapeForInputsAndOutputs()); // code is a large std::string that contains the final shader code std::string code; - ORT_RETURN_IF_ERROR(shader_helper.GetFinalSourceCode(code)); + ORT_RETURN_IF_ERROR(shader_helper.GenerateSourceCode(code, shape_uniform_ranks)); LOGS_DEFAULT(VERBOSE) << "\n=== WebGPU Shader code [" << program.Name() #ifndef NDEBUG // if debug build diff --git a/onnxruntime/core/providers/webgpu/program_manager.h b/onnxruntime/core/providers/webgpu/program_manager.h index 5f4c28a140a50..782788910e3a5 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.h +++ b/onnxruntime/core/providers/webgpu/program_manager.h @@ -23,11 +23,11 @@ namespace webgpu { class ProgramArtifact { public: - ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline, std::vector&& shape_uniforms); + ProgramArtifact(const ProgramBase& program, wgpu::ComputePipeline&& compute_pipeline, std::vector&& shape_uniform_ranks); - std::string name; - wgpu::ComputePipeline compute_pipeline; - std::vector shape_uniforms; + const std::string name; + const wgpu::ComputePipeline compute_pipeline; + const std::vector shape_uniform_ranks; ProgramArtifact(ProgramArtifact&&) = default; ProgramArtifact& operator=(ProgramArtifact&&) = default; @@ -51,7 +51,7 @@ class ProgramManager { uint32_t normalized_dispatch_y, uint32_t normalized_dispatch_z, wgpu::ComputePipeline& compute_pipeline, - std::vector& shape_uniforms) const; + std::vector& shape_uniform_ranks) const; const ProgramArtifact* Get(const std::string& key) const; const ProgramArtifact* Set(const std::string& key, ProgramArtifact&& program); diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index cf040ddfa4271..d06a6573ab2bd 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -35,9 +35,9 @@ Status ShaderHelper::Init() { // dispatch group size is normalized so no need to validate it here // validate workgroup size - auto workgroup_size_x = program_.WorkgroupSizeX() == 0 ? WORKGROUP_SIZE : program_.WorkgroupSizeX(); - auto workgroup_size_y = program_.WorkgroupSizeY() == 0 ? 1 : program_.WorkgroupSizeY(); - auto workgroup_size_z = program_.WorkgroupSizeZ() == 0 ? 1 : program_.WorkgroupSizeZ(); + auto workgroup_size_x = program_.WorkgroupSizeX(); + auto workgroup_size_y = program_.WorkgroupSizeY(); + auto workgroup_size_z = program_.WorkgroupSizeZ(); ORT_RETURN_IF_NOT(workgroup_size_x <= limits_.maxComputeWorkgroupSizeX && workgroup_size_y <= limits_.maxComputeWorkgroupSizeY && @@ -163,16 +163,6 @@ Status ValidateVariableShape(const TensorShape& origin_shape, "Tensor original shape ", origin_shape, " cannot be divided by component number ", num_components); } - // if (use_uniform) { - // const auto& rank = std::get(rank_or_shape); - // ORT_RETURN_IF_NOT(rank == SafeInt(override_shape.NumDimensions()), - // "Shader variable rank ", rank, " does not match the tensor shape ", override_shape); - // } else { - // const auto& shape = std::get>(rank_or_shape).get(); - // ORT_RETURN_IF(use_override_shape, "Cannot specify both variable shape and program input/output shape override"); - // ORT_RETURN_IF_NOT(origin_shape.Size() == shape.Size() * num_components, - // "Tensor original shape ", origin_shape, " cannot reshape to ", shape, " with component number ", num_components); - // } return Status::OK(); } } // namespace @@ -211,36 +201,75 @@ Status ShaderHelper::ValidateVariable(const ProgramInput& input, const ShaderVar } Status ShaderHelper::ValidateVariable(const ProgramOutput& output, const ShaderVariable& var) const { ORT_RETURN_IF_ERROR(ValidateVariableDataType(output.tensor->GetElementType(), var.type_)); - - // todo: add reshaped shape and check + ORT_RETURN_IF_ERROR(ValidateVariableShape(output.tensor->Shape(), + output.use_override_shape, + output.use_override_shape ? output.override_shape : output.tensor->Shape(), + var.num_components_)); return Status::OK(); } -Status ShaderHelper::AppendShapeUniformValues(std::vector& /*shape_uniforms*/) const { - // TODO: move input/output check(validation) here - // TODO: also check input dependencies with actual usages. - // [deps] [usages] - // input -> use shape && !use_uniform -> OK - // input -> use shape && use_uniform -> err - // input -> !use shape && !use_uniform -> err: must use shape if not using uniform - // input -> !use shape && use_uniform -> - // use_rank -> OK - // !use_rank -> err: must use rank - // - // output -> do not check +Status ShaderHelper::ValidateShapeForInputsAndOutputs() const { + const auto& input_vars = vars_[static_cast(ProgramVariableScope::Input)]; + const auto& output_vars = vars_[static_cast(ProgramVariableScope::Output)]; + + // Validate input/output as dependencies of shape_uniforms + ORT_RETURN_IF_NOT(input_vars.size() == program_.Inputs().size(), + "Mismatched input variable count. Shader: ", input_vars.size(), ", Program: ", program_.Inputs().size()); + ORT_RETURN_IF_NOT(output_vars.size() == program_.Outputs().size(), + "Mismatched output variable count. Shader: ", output_vars.size(), ", Program: ", program_.Outputs().size()); + + for (size_t i = 0; i < input_vars.size(); i++) { +#ifndef NDEBUG // if debug build + // Validate input shape + ORT_RETURN_IF_ERROR(ValidateVariable(program_.Inputs()[i], input_vars[i])); +#endif - // TODO: tensor shape and strides adding to uniforms (in front) - // when: use_rank && rank >=2 - // info need for codegen: [rank, variable name] content -> "vecN {name}_shape, vecN {name}_strides" - // // further optimization: strides can be vecN-1 - // minimal info stored in artifact: array<[rank, variable name] | not_use > + // check input dependencies with actual usages. + auto usage = input_vars[i].usage_; + bool use_uniform = (usage & ShaderVariable::UseUniform) == ShaderVariable::UseUniform; + auto dependency = program_.Inputs()[i].dependency; + bool use_rank = (dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank; + bool use_shape = (dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape; + if (use_uniform) { + ORT_RETURN_IF_NOT((use_rank || input_vars[i].rank_ < 2) && !use_shape, + "When UseUniform is set in variable usage, the corresponding program input should depend on rank but not shape."); + } else { + ORT_RETURN_IF_NOT(use_shape, + "When UseUniform is not set in variable usage, the corresponding program input should depend on shape."); + // If you want neither hard-coded shape nor shape uniform, set UseUniform with a flattened shape (rank=1). + // This will not generate any shape variables in the shader, can you can only use offset to set/get values. + } + } + + for (size_t i = 0; i < output_vars.size(); i++) { +#ifndef NDEBUG // if debug build + // Validate output shape + ORT_RETURN_IF_ERROR(ValidateVariable(program_.Outputs()[i], output_vars[i])); +#endif + + // check output dependencies with actual usages. + auto usage = output_vars[i].usage_; + bool use_uniform = (usage & ShaderVariable::UseUniform) == ShaderVariable::UseUniform; + auto dependency = program_.Outputs()[i].dependency; + bool use_shape = (dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape; + + if (use_uniform) { + // output tensor shape check is looser than input tensor shape check, because output shape is always calculated so it is not + // necessarily a part of the cache key. + ORT_RETURN_IF_NOT(!use_shape, + "When UseUniform is set in variable usage, the corresponding program output should not depend on shape."); + } else { + ORT_RETURN_IF_NOT(use_shape, + "When UseUniform is not set in variable usage, the corresponding program output should depend on shape."); + } + } return Status::OK(); } #endif -Status ShaderHelper::GetFinalSourceCode(std::string& code) { +Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& shape_uniform_ranks) const { std::ostringstream ss; ss.imbue(std::locale::classic()); @@ -255,9 +284,10 @@ Status ShaderHelper::GetFinalSourceCode(std::string& code) { // // Section constants // - ss << "const workgroup_size_x: u32 = " << program_.WorkgroupSizeX() - << ";\nconst workgroup_size_y: u32 = " << program_.WorkgroupSizeY() - << ";\nconst workgroup_size_z: u32 = " << program_.WorkgroupSizeZ() << ";\n"; + ss << "const workgroup_size_x: u32 = " << (program_.WorkgroupSizeX() == 0 ? uint32_t(WORKGROUP_SIZE) : program_.WorkgroupSizeX()) + << ";\nconst workgroup_size_y: u32 = " << (program_.WorkgroupSizeY() == 0 ? uint32_t(1) : program_.WorkgroupSizeY()) + << ";\nconst workgroup_size_z: u32 = " << (program_.WorkgroupSizeZ() == 0 ? uint32_t(1) : program_.WorkgroupSizeZ()) + << ";\n"; for (const auto& constant : program_metadata_.constants) { ss << "const " << constant.name << ": " << constant.type << " = "; @@ -285,44 +315,44 @@ Status ShaderHelper::GetFinalSourceCode(std::string& code) { // size_t variable_count = 0; const auto& input_vars = vars_[static_cast(ProgramVariableScope::Input)]; - ORT_RETURN_IF_NOT(input_vars.size() == program_.Inputs().size(), - "Mismatched input variable count. Shader: ", variable_count, ", Program: ", program_.Inputs().size()); for (const auto& input : input_vars) { -#ifndef NDEBUG // if debug build - ORT_RETURN_IF_ERROR(ValidateVariable(program_.Inputs()[variable_count], input)); -#endif ss << "@group(0) @binding(" << variable_count++ << ") var " << input.name_ << ": array<" << input.StorageType() << ">;\n"; } const auto& output_vars = vars_[static_cast(ProgramVariableScope::Output)]; - ORT_RETURN_IF_NOT(output_vars.size() == program_.Outputs().size(), - "Mismatched output variable count. Shader: ", variable_count, ", Program: ", program_.Outputs().size()); for (const auto& output : output_vars) { -#ifndef NDEBUG // if debug build - ORT_RETURN_IF_ERROR(ValidateVariable(program_.Outputs()[variable_count - input_vars.size()], output)); -#endif ss << "@group(0) @binding(" << variable_count++ << ") var " << output.name_ << ": array<" << output.StorageType() << ">;\n"; } // // uniform variables // - if (std::any_of(program_.UniformVariables().cbegin(), - program_.UniformVariables().cend(), - [](const ProgramUniformVariableValue& x) { return x.length > 0; })) { - bool first = true; - ss << "struct Uniforms {"; - size_t uniform_count = program_.UniformVariables().size(); - for (size_t i = 0; i < uniform_count; i++) { - const auto& uniform_def = program_metadata_.uniform_variables[i]; - const auto& uniform_value = program_.UniformVariables()[i]; + // store shape uniform ranks in shape_uniform_ranks + bool use_any_shape_uniform = false; + ORT_ENFORCE(shape_uniform_ranks.size() == 0); + shape_uniform_ranks.reserve(input_vars.size() + output_vars.size()); + + for (const auto& input : vars_[static_cast(ProgramVariableScope::Input)]) { + bool use_uniform = (input.usage_ & ShaderVariable::UseUniform) == ShaderVariable::UseUniform && input.rank_ > 1; + use_any_shape_uniform |= use_uniform; + shape_uniform_ranks.push_back(use_uniform ? input.rank_ : 0); + } + for (const auto& output : vars_[static_cast(ProgramVariableScope::Output)]) { + bool use_uniform = (output.usage_ & ShaderVariable::UseUniform) == ShaderVariable::UseUniform && output.rank_ > 1; + use_any_shape_uniform |= use_uniform; + shape_uniform_ranks.push_back(use_uniform ? output.rank_ : 0); + } - const auto& name = uniform_def.name; - const auto& data_type = uniform_def.data_type; - const auto length = uniform_value.length; + if (use_any_shape_uniform || std::any_of(program_.UniformVariables().cbegin(), + program_.UniformVariables().cend(), + [](const ProgramUniformVariableValue& x) { return x.length > 0; })) { + bool first = true; + ss << "struct Uniforms {"; + // lambda append_uniform is used to append one uniform variable to the uniform struct + auto append_uniform = [&ss, &first](std::string_view name, ProgramUniformVariableDataType data_type, size_t length) { if (length == 0) { - continue; + return; } if (first) { @@ -346,6 +376,30 @@ Status ShaderHelper::GetFinalSourceCode(std::string& code) { } else { ss << data_type; } + }; + + for (const auto& input : vars_[static_cast(ProgramVariableScope::Input)]) { + if (input.rank_ > 1) { + std::string shape = input.name_ + "_shape"; + std::string stride = input.name_ + "_stride"; + append_uniform(shape, ProgramUniformVariableDataType::Uint32, input.rank_); + append_uniform(stride, ProgramUniformVariableDataType::Uint32, input.rank_); + } + } + + for (const auto& output : vars_[static_cast(ProgramVariableScope::Output)]) { + if (output.rank_ > 1) { + std::string shape = output.name_ + "_shape"; + std::string stride = output.name_ + "_stride"; + append_uniform(shape, ProgramUniformVariableDataType::Uint32, output.rank_); + append_uniform(stride, ProgramUniformVariableDataType::Uint32, output.rank_); + } + } + + for (size_t i = 0; i < program_.UniformVariables().size(); i++) { + const auto& uniform_def = program_metadata_.uniform_variables[i]; + const auto& uniform_value = program_.UniformVariables()[i]; + append_uniform(uniform_def.name, uniform_def.data_type, uniform_value.length); } ss << "\n};\n" @@ -368,13 +422,11 @@ Status ShaderHelper::GetFinalSourceCode(std::string& code) { // Additional Implementation // ss << additional_implementation_.str(); - additional_implementation_.str(""); // // Main Function Body // ss << body_.str(); - body_.str(""); ss << "\n" "}\n"; diff --git a/onnxruntime/core/providers/webgpu/shader_helper.h b/onnxruntime/core/providers/webgpu/shader_helper.h index dc46b62754261..bb04c4ad628aa 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.h +++ b/onnxruntime/core/providers/webgpu/shader_helper.h @@ -148,10 +148,21 @@ class ShaderHelper final { Status ValidateVariable(const ProgramInput& input, const ShaderVariable& var) const; Status ValidateVariable(const ProgramOutput& output, const ShaderVariable& var) const; #endif - // Append the uniform values of all shape variables. Including shape/strides of input/output variables, - // if UseUniform is set in the usage of the variable. - Status AppendShapeUniformValues(std::vector& shape_uniforms) const; - Status GetFinalSourceCode(std::string& code); + + Status ShaderHelper::ValidateShapeForInputsAndOutputs() const; + + // Generate source code. + // + // This function: + // - performs validation if neccessary, + // - appends the ranks for variables to the shape_uniform_ranks. + // (The rank value is zero if no uniform is needed for the variable.) + // - generates the final source code. + // + // \param code The generated full WGSL source code. + // \param shape_uniform_ranks The ranks for variables that need a uniform for the shape. + // + Status GenerateSourceCode(std::string& code, std::vector& shape_uniform_ranks) const; friend class ProgramManager; const wgpu::Device& device_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 638bc7c1f7c05..7c9763d6937f0 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -212,7 +212,7 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog const auto* program_artifact = program_mgr_->Get(key); if (program_artifact == nullptr) { wgpu::ComputePipeline compute_pipeline; - std::vector shape_uniforms; + std::vector shape_uniform_ranks; auto status = program_mgr_->Build(program, metadata, #ifndef NDEBUG // if debug build @@ -222,29 +222,52 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog y, z, compute_pipeline, - shape_uniforms); + shape_uniform_ranks); ORT_RETURN_IF_ERROR(status); program_artifact = program_mgr_->Set(key, ProgramArtifact{program, std::move(compute_pipeline), - std::move(shape_uniforms)}); + std::move(shape_uniform_ranks)}); #ifndef NDEBUG // if debug build ORT_ENFORCE(program_artifact != nullptr, "Program artifact should not be nullptr."); #endif } - // prepare uniform info + // prepare shape uniforms for shader variables (if any) and user defined uniforms + std::vector shape_uniforms; + shape_uniforms.reserve(program_artifact->shape_uniform_ranks.size() * 2); + ORT_RETURN_IF_NOT(program_artifact->shape_uniform_ranks.size() == inputs.size() + outputs.size(), + "Invalid program artifact: variable size (", program_artifact->shape_uniform_ranks.size(), + ") does not match current program (input: ", inputs.size(), ", output: ", outputs.size(), ")"); + for (size_t i = 0; i < program_artifact->shape_uniform_ranks.size(); ++i) { + SafeInt expected_rank = program_artifact->shape_uniform_ranks[i]; + if (expected_rank > 0) { + const auto& shape = i < inputs.size() ? (inputs[i].use_override_shape ? inputs[i].override_shape + : inputs[i].tensor->Shape()) + : (outputs[i - inputs.size()].use_override_shape ? outputs[i - inputs.size()].override_shape + : outputs[i - inputs.size()].tensor->Shape()); + ORT_RETURN_IF(expected_rank != shape.NumDimensions(), + "Invalid program artifact: variable[", i, "] rank mismatch. Expected: ", (int)expected_rank, + ", Actual: ", shape.NumDimensions()); + + std::vector dims(shape.NumDimensions()); + std::vector stride(shape.NumDimensions()); + for (size_t j = 0; j < shape.NumDimensions(); ++j) { + dims[j] = SafeInt(shape[j]); + stride[j] = SafeInt(shape.SizeFromDimension(j)); + } - // TODO: also append artifacts uniform info and fill in actual input/output (override) shape value + shape_uniforms.emplace_back(gsl::make_span(dims)); + shape_uniforms.emplace_back(gsl::make_span(stride)); + } + } - // foreach (uniform in artifact) { - // check if match; - // if match, create ProgramUniformVariableValue - // } - const auto& uniforms = program.UniformVariables(); + const size_t uniform_count = shape_uniforms.size() + program.UniformVariables().size(); size_t current_offset = 0; std::vector> uniform_and_offsets; - uniform_and_offsets.reserve(uniforms.size()); - for (const auto& uniform : uniforms) { + uniform_and_offsets.reserve(uniform_count); + for (size_t i = 0; i < uniform_count; i++) { + const auto& uniform = i < shape_uniforms.size() ? shape_uniforms[i] + : program.UniformVariables()[i - shape_uniforms.size()]; bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16; size_t length = uniform.length; From 1d53ac89429586768170aaef23133983b0e78bc3 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 3 Sep 2024 02:18:41 -0700 Subject: [PATCH 36/77] fix build (linux) --- onnxruntime/core/providers/webgpu/shader_helper.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/shader_helper.h b/onnxruntime/core/providers/webgpu/shader_helper.h index bb04c4ad628aa..ca1bf9ce7ff58 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.h +++ b/onnxruntime/core/providers/webgpu/shader_helper.h @@ -149,7 +149,7 @@ class ShaderHelper final { Status ValidateVariable(const ProgramOutput& output, const ShaderVariable& var) const; #endif - Status ShaderHelper::ValidateShapeForInputsAndOutputs() const; + Status ValidateShapeForInputsAndOutputs() const; // Generate source code. // From 4d52602a208daae68fab6510eabac5eeb5aac87f Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 3 Sep 2024 02:21:21 -0700 Subject: [PATCH 37/77] fix stride --- onnxruntime/core/providers/webgpu/webgpu_context.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 7c9763d6937f0..755ebbfd174ca 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -253,7 +253,7 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog std::vector stride(shape.NumDimensions()); for (size_t j = 0; j < shape.NumDimensions(); ++j) { dims[j] = SafeInt(shape[j]); - stride[j] = SafeInt(shape.SizeFromDimension(j)); + stride[j] = SafeInt(shape.SizeFromDimension(j + 1)); } shape_uniforms.emplace_back(gsl::make_span(dims)); From 3761aad4bd668af38813a330959d70ed30fb2060 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 3 Sep 2024 02:38:24 -0700 Subject: [PATCH 38/77] fix "{res_name}_bi2o_{name}" --- onnxruntime/core/providers/webgpu/shader_variable.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index ef80fd3c57f6c..d19116f570992 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -116,11 +116,11 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { SS(" return 0;\n"); } else { SS(" return "); - for (int i = 0; i < rank_ - 1; i++) { + for (int i = rank_ - 1; i >= 0; i--) { auto idx = broadcasted_result.IndicesGet("indices", i + broadcasted_result.rank_ - rank_); SS(IndicesGet(stride, i), " * (", idx, " % ", IndicesGet(shape, i), ") + "); } - SS(broadcasted_result.IndicesGet("indices", broadcasted_result.rank_ - 1), " % ", IndicesGet(shape, rank_ - 1), ";\n"); + SS("0;\n"); } SS("}\n"); } From 351da844d364d7c762deea845e7e4a81d816c8b2 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Wed, 4 Sep 2024 02:57:38 +0800 Subject: [PATCH 39/77] Add Expand operator (#21933) ### Description ### Motivation and Context --- .../core/providers/webgpu/shader_helper.cc | 4 +- .../core/providers/webgpu/shader_variable.cc | 4 +- .../core/providers/webgpu/tensor/expand.cc | 95 +++++++++++++++++++ .../core/providers/webgpu/tensor/expand.h | 30 ++++++ .../webgpu/webgpu_execution_provider.cc | 4 +- 5 files changed, 131 insertions(+), 6 deletions(-) create mode 100644 onnxruntime/core/providers/webgpu/tensor/expand.cc create mode 100644 onnxruntime/core/providers/webgpu/tensor/expand.h diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index d06a6573ab2bd..e6ae5ae0d9403 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -379,7 +379,7 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha }; for (const auto& input : vars_[static_cast(ProgramVariableScope::Input)]) { - if (input.rank_ > 1) { + if (input.rank_ > 1 && (input.usage_ & ShaderVariable::Usage::UseUniform)) { std::string shape = input.name_ + "_shape"; std::string stride = input.name_ + "_stride"; append_uniform(shape, ProgramUniformVariableDataType::Uint32, input.rank_); @@ -388,7 +388,7 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha } for (const auto& output : vars_[static_cast(ProgramVariableScope::Output)]) { - if (output.rank_ > 1) { + if (output.rank_ > 1 && (output.usage_ & ShaderVariable::Usage::UseUniform)) { std::string shape = output.name_ + "_shape"; std::string stride = output.name_ + "_stride"; append_uniform(shape, ProgramUniformVariableDataType::Uint32, output.rank_); diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index d19116f570992..a652d720dbf7b 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -64,11 +64,11 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { SS("const ", stride, " = ", indices_type, "("); first = true; - for (int i = rank_ - 1; i >= 0; i--) { + for (int i = 1; i <= rank_; i++) { if (!first) { ss << ","; } - ss << dims_.SizeToDimension(i); + ss << dims_.SizeFromDimension(i); first = false; } ss << ");\n"; diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc new file mode 100644 index 0000000000000..4d241da544150 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -0,0 +1,95 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/tensor/expand.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +namespace { +Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape, const TensorShape& rhs_shape, TensorShape& out_shape) { + size_t lhs_rank = lhs_shape.NumDimensions(); + size_t rhs_rank = rhs_shape.NumDimensions(); + size_t out_rank = std::max(lhs_rank, rhs_rank); + + std::vector output_dims(out_rank, 0); + for (size_t i = 0; i < out_rank; ++i) { + int64_t lhs_dim = 1; + if (i < lhs_rank) + lhs_dim = lhs_shape[lhs_rank - 1 - i]; + int64_t rhs_dim = 1; + if (i < rhs_rank) + rhs_dim = rhs_shape[rhs_rank - 1 - i]; + int64_t max = std::max(lhs_dim, rhs_dim); + int64_t min = std::min(lhs_dim, rhs_dim); + int64_t out_dim = (min == 0 ? min : max); // special case a dim value of 0. + if (lhs_dim != out_dim && lhs_dim != 1) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": left operand cannot broadcast on dim ", lhs_rank - 1 - i, + " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString()); + if (rhs_dim != out_dim && rhs_dim != 1) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": right operand cannot broadcast on dim ", rhs_rank - 1 - i, + " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString()); + output_dims[out_rank - 1 - i] = out_dim; + } + out_shape = TensorShape(output_dims); + return Status::OK(); +} +} // namespace + +Status ExpandProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddInput("input", + ToProgramVariableDataType(Inputs()[0].tensor->GetElementType()), + ShaderVariable::UseUniform); + const auto& output = shader.AddOutput("output", + ToProgramVariableDataType(Outputs()[0].tensor->GetElementType()), + ShaderVariable::UseUniform); + + shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), + "let output_indices = ", output.OffsetToIndices("global_idx"), ";\n", + "let input_offset = ", input.BroadcastedIndicesToOffset("output_indices", output), ";\n", + output.SetByOffset("global_idx", input.GetByOffset("input_offset"))); + + return Status::OK(); +} + +Status Expand::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + const auto* input_shape_tensor = context.Input(1); + + const auto* p_shape = input_shape_tensor->Data(); + TensorShapeVector output_dims{p_shape, p_shape + input_shape_tensor->Shape().Size()}; + TensorShape output_shape(output_dims); + ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), input_tensor->Shape(), output_dims, output_shape)); + + auto* output_tensor = context.Output(0, output_shape); + SafeInt vec_size = output_shape.Size(); + ExpandProgram program{"Expand"}; + program + .Inputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .Outputs({{output_tensor, ProgramTensorMetadataDependency::Rank}}) + .DispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .UniformVariables({ + {static_cast(vec_size)}, + }); + return context.RunProgram(program); +} + +#define WEBGPU_EXPAND_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE).InputMemoryType(OrtMemTypeCPU, 1), \ + KERNEL_CLASS); + +#define WEBGPU_EXPAND_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE).InputMemoryType(OrtMemTypeCPU, 1), \ + KERNEL_CLASS); + +WEBGPU_EXPAND_VERSIONED_KERNEL(Expand, 8, 12, Expand, WebGpuSupportedFloatTypes()) +WEBGPU_EXPAND_KERNEL(Expand, 13, Expand, WebGpuSupportedFloatTypes()) + +} // namespace webgpu +}; // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.h b/onnxruntime/core/providers/webgpu/tensor/expand.h new file mode 100644 index 0000000000000..a5c24f1fa4969 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/expand.h @@ -0,0 +1,30 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +class ExpandProgram final : public Program { + public: + ExpandProgram(const std::string& kernel_name) : Program{kernel_name} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); +}; + +class Expand final : public WebGpuKernel { + public: + Expand(const OpKernelInfo& info) : WebGpuKernel(info) {} + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 202742a1c79bc..1ee7a51618f7f 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -622,8 +622,8 @@ std::unique_ptr RegisterKernels() { // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, From 0b7ce771a7c8008607a223a2e9ed4fceab333b82 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 3 Sep 2024 13:36:52 -0700 Subject: [PATCH 40/77] support onnxruntime_test_all --- .../webgpu/webgpu_provider_factory.cc | 18 +++++++++--------- .../webgpu/webgpu_provider_factory_creator.h | 6 ++++-- .../core/session/provider_registration.cc | 2 +- onnxruntime/test/providers/base_tester.cc | 3 +++ onnxruntime/test/util/default_providers.cc | 7 ++++++- 5 files changed, 23 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index e871b66f1dc92..3848ccfc19f51 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -31,7 +31,7 @@ struct WebGpuProviderFactory : IExecutionProviderFactory { WebGpuExecutionProviderInfo info_; }; -std::shared_ptr WebGpuProviderFactoryCreator::Create(const SessionOptions* session_options) { +std::shared_ptr WebGpuProviderFactoryCreator::Create(const ConfigOptions& config_options) { // // STEP.1 - prepare WebGpuExecutionProviderInfo // @@ -43,7 +43,7 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( }; std::string preferred_layout_str; - if (session_options->config_options.TryGetConfigEntry(kPreferredLayout, preferred_layout_str)) { + if (config_options.TryGetConfigEntry(kPreferredLayout, preferred_layout_str)) { if (preferred_layout_str == kPreferredLayout_NHWC) { webgpu_ep_info.data_layout = DataLayout::NHWC; } else if (preferred_layout_str == kPreferredLayout_NCHW) { @@ -56,7 +56,7 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( << preferred_layout_str << "\")"; std::string enable_graph_capture_str; - if (session_options->config_options.TryGetConfigEntry(kEnableGraphCapture, enable_graph_capture_str)) { + if (config_options.TryGetConfigEntry(kEnableGraphCapture, enable_graph_capture_str)) { if (enable_graph_capture_str == kkEnableGraphCapture_ON) { webgpu_ep_info.enable_graph_capture = true; } else if (enable_graph_capture_str == kkEnableGraphCapture_OFF) { @@ -67,10 +67,10 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( } LOGS_DEFAULT(VERBOSE) << "WebGPU EP graph capture enable: " << webgpu_ep_info.enable_graph_capture; - auto parse_buffer_cache_mode = [session_options](const std::string& config_entry_str, + auto parse_buffer_cache_mode = [&config_options](const std::string& config_entry_str, webgpu::BufferCacheMode default_value) -> webgpu::BufferCacheMode { std::string buffer_cache_mode_str; - if (session_options->config_options.TryGetConfigEntry(config_entry_str, buffer_cache_mode_str)) { + if (config_options.TryGetConfigEntry(config_entry_str, buffer_cache_mode_str)) { if (buffer_cache_mode_str == kBufferCacheMode_Disabled) { return webgpu::BufferCacheMode::Disabled; } else if (buffer_cache_mode_str == kBufferCacheMode_LazyRelease) { @@ -104,14 +104,14 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( // int context_id = 0; std::string context_id_str; - if (session_options->config_options.TryGetConfigEntry(kDeviceId, context_id_str)) { + if (config_options.TryGetConfigEntry(kDeviceId, context_id_str)) { ORT_ENFORCE(std::errc{} == std::from_chars(context_id_str.data(), context_id_str.data() + context_id_str.size(), context_id).ec); } size_t webgpu_instance = 0; std::string webgpu_instance_str; - if (session_options->config_options.TryGetConfigEntry(kWebGpuInstance, webgpu_instance_str)) { + if (config_options.TryGetConfigEntry(kWebGpuInstance, webgpu_instance_str)) { static_assert(sizeof(WGPUInstance) == sizeof(size_t), "WGPUInstance size mismatch"); ORT_ENFORCE(std::errc{} == std::from_chars(webgpu_instance_str.data(), webgpu_instance_str.data() + webgpu_instance_str.size(), webgpu_instance).ec); @@ -119,7 +119,7 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( size_t webgpu_adapter = 0; std::string webgpu_adapter_str; - if (session_options->config_options.TryGetConfigEntry(kWebGpuAdapter, webgpu_adapter_str)) { + if (config_options.TryGetConfigEntry(kWebGpuAdapter, webgpu_adapter_str)) { static_assert(sizeof(WGPUAdapter) == sizeof(size_t), "WGPUAdapter size mismatch"); ORT_ENFORCE(std::errc{} == std::from_chars(webgpu_adapter_str.data(), webgpu_adapter_str.data() + webgpu_adapter_str.size(), webgpu_adapter).ec); @@ -127,7 +127,7 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( size_t webgpu_device = 0; std::string webgpu_device_str; - if (session_options->config_options.TryGetConfigEntry(kWebGpuDevice, webgpu_device_str)) { + if (config_options.TryGetConfigEntry(kWebGpuDevice, webgpu_device_str)) { static_assert(sizeof(WGPUDevice) == sizeof(size_t), "WGPUDevice size mismatch"); ORT_ENFORCE(std::errc{} == std::from_chars(webgpu_device_str.data(), webgpu_device_str.data() + webgpu_device_str.size(), webgpu_device).ec); diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h b/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h index 7fac9234b949b..e0030a3ec2a11 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h @@ -8,11 +8,13 @@ #include "core/framework/provider_options.h" #include "core/providers/providers.h" +#include "core/providers/webgpu/webgpu_provider_options.h" + namespace onnxruntime { -struct SessionOptions; +struct ConfigOptions; struct WebGpuProviderFactoryCreator { - static std::shared_ptr Create(const SessionOptions* session_options); + static std::shared_ptr Create(const ConfigOptions& config_options); }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index da97cdc25ab12..156b59a7af10a 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -135,7 +135,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, #endif } else if (strcmp(provider_name, "WebGPU") == 0) { #if defined(USE_WEBGPU) - options->provider_factories.push_back(WebGpuProviderFactoryCreator::Create(&(options->value))); + options->provider_factories.push_back(WebGpuProviderFactoryCreator::Create(options->value.config_options)); #else status = create_not_supported_status(); #endif diff --git a/onnxruntime/test/providers/base_tester.cc b/onnxruntime/test/providers/base_tester.cc index 01de15e6f8ec8..dea39bc99d3e9 100644 --- a/onnxruntime/test/providers/base_tester.cc +++ b/onnxruntime/test/providers/base_tester.cc @@ -657,6 +657,7 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, kQnnExecutionProvider, kSnpeExecutionProvider, kXnnpackExecutionProvider, + kWebGpuExecutionProvider, }; // need to special case any synthetic EP names in the exclude list @@ -712,6 +713,8 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, execution_provider = DefaultXnnpackExecutionProvider(); else if (provider_type == onnxruntime::kDmlExecutionProvider) execution_provider = DefaultDmlExecutionProvider(); + else if (provider_type == onnxruntime::kWebGpuExecutionProvider) + execution_provider = DefaultWebGpuExecutionProvider(); // skip if execution provider is disabled if (execution_provider == nullptr) diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 871285269daf4..c9c64003ddabb 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -303,7 +303,12 @@ std::unique_ptr DefaultXnnpackExecutionProvider() { std::unique_ptr DefaultWebGpuExecutionProvider() { #ifdef USE_WEBGPU - return WebGpuProviderFactoryCreator::Create(nullptr)->CreateProvider(); + ConfigOptions config_options{}; + // Disable storage buffer cache + ORT_ENFORCE(config_options.AddConfigEntry(webgpu::options::kStorageBufferCacheMode, + webgpu::options::kBufferCacheMode_Disabled) + .IsOK()); + return WebGpuProviderFactoryCreator::Create(config_options)->CreateProvider(); #else return nullptr; #endif From 33726b1aa5f435a0d6e7701b2de6eff29404f66d Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Tue, 3 Sep 2024 15:34:00 -0700 Subject: [PATCH 41/77] reflect change in WebGpuProviderFactoryCreator::Create signature (#21971) reflect change in WebGpuProviderFactoryCreator::Create signature --- onnxruntime/python/onnxruntime_pybind_state.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 036585586d9ac..01889df8fec1d 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1210,7 +1210,7 @@ std::unique_ptr CreateExecutionProviderInstance( #endif } else if (type == kWebGpuExecutionProvider) { #if defined(USE_WEBGPU) - return onnxruntime::WebGpuProviderFactoryCreator::Create(&session_options)->CreateProvider(); + return onnxruntime::WebGpuProviderFactoryCreator::Create(session_options.config_options)->CreateProvider(); #endif } else if (type == kCannExecutionProvider) { #ifdef USE_CANN From 50ea9eb959855544f4f42d62073bb3c0ad1505c4 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Tue, 3 Sep 2024 16:41:49 -0700 Subject: [PATCH 42/77] compare the content of WEBGPU_BUFFER, not the address (#21967) On linux (not sure about windows) WEBGPU_BUFFER is defined in multiple object files and comparing the address is not sufficient - use strcmp. onnxruntime uses strcmp for the most but there are some other places that compare against address which might make trouble if passed acrross object file boundary. --- onnxruntime/core/providers/webgpu/webgpu_context.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 755ebbfd174ca..343da693c716b 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -140,7 +140,7 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog return tensor != nullptr && tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault && tensor->Location().device.Type() == OrtDevice::GPU && - tensor->Location().name == WEBGPU_BUFFER; + !strcmp(tensor->Location().name, WEBGPU_BUFFER); }), "All inputs must be tensors on WebGPU buffers."); @@ -149,7 +149,7 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog return tensor != nullptr && tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault && tensor->Location().device.Type() == OrtDevice::GPU && - tensor->Location().name == WEBGPU_BUFFER; + !strcmp(tensor->Location().name, WEBGPU_BUFFER); }), "All outputs must be tensors on WebGPU buffers."); #endif From d6f6148fd58e33287c49ba5c0db13c2cac4d7d30 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 3 Sep 2024 16:38:32 -0700 Subject: [PATCH 43/77] fix tanh --- .../core/providers/webgpu/math/unary_elementwise_ops.cc | 4 +++- .../test/providers/cpu/math/element_wise_ops_test.cc | 9 ++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index 97dd2c5984631..9d47cab347296 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -167,7 +167,9 @@ WEBGPU_ELEMENTWISE_KERNEL(Sinh, 9, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_IMPL(Cosh, "cosh(a)") WEBGPU_ELEMENTWISE_KERNEL(Cosh, 9, WebGpuSupportedFloatTypes()) -WEBGPU_ELEMENTWISE_IMPL(Tanh, "tanh(a)") +// built-in function tanh() does not work with large input (f32 88.7 or f16 11.09) +// https://github.com/gpuweb/gpuweb/issues/4458 +WEBGPU_ELEMENTWISE_IMPL(Tanh, "sign(a) * (1 - exp(-2 * abs(a))) / (1 + exp(-2 * abs(a)))") WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Tanh, 13, WebGpuSupportedFloatTypes()) diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index bd3d21d4929f3..4ca915dd394c1 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -3016,7 +3016,14 @@ TEST(MathOpTest, Tan) { TEST(MathOpTest, Asin) { OpTester test("Asin"); - float abs_error = DefaultDmlExecutionProvider().get() != nullptr ? 0.0001f : -1.0f; + float abs_error = +#ifdef _WIN32 + // Set abs_error to 0.0001f for built-in function asin() in HLSL based EPs (DML and WebGPU) + DefaultDmlExecutionProvider().get() != nullptr || DefaultWebGpuExecutionProvider().get() != nullptr + ? 0.0001f + : +#endif + -1.0f; TrigFloatTest<::asinf>(test, {-1.0f, -0.5f, 0.0f, 0.5f, 1.0f}, abs_error); } From 626edafbd87e70cdc4040786c92edda0f8845b82 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 3 Sep 2024 17:08:50 -0700 Subject: [PATCH 44/77] support size==0 for element wise operators --- .../core/providers/webgpu/math/unary_elementwise_ops.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index 9d47cab347296..079a192213775 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -25,6 +25,9 @@ Status UnaryElementwise::ComputeInternal(ComputeContext& context) const { const auto* input_tensor = context.Input(0); auto* output_tensor = context.Output(0, input_tensor->Shape()); int64_t size = input_tensor->Shape().Size(); + if (size == 0) { + return Status::OK(); + } SafeInt vec_size = (size + 3) / 4; UnaryElementwiseProgram program{kernel_name_, expression_, additional_impl_, additional_usage_}; program From bacc54cc09e2140a1aa33fb1202dfd2295ecc8c0 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 3 Sep 2024 20:20:58 -0700 Subject: [PATCH 45/77] use shared ComputeBroadcastOutputShape() --- .../core/providers/webgpu/tensor/expand.cc | 34 ++----------------- 1 file changed, 3 insertions(+), 31 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc index 4d241da544150..53991365d6543 100644 --- a/onnxruntime/core/providers/webgpu/tensor/expand.cc +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/providers/common.h" + #include "core/providers/webgpu/tensor/expand.h" #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" @@ -8,36 +10,6 @@ namespace onnxruntime { namespace webgpu { -namespace { -Status ComputeOutputShape(const std::string& node_name, const TensorShape& lhs_shape, const TensorShape& rhs_shape, TensorShape& out_shape) { - size_t lhs_rank = lhs_shape.NumDimensions(); - size_t rhs_rank = rhs_shape.NumDimensions(); - size_t out_rank = std::max(lhs_rank, rhs_rank); - - std::vector output_dims(out_rank, 0); - for (size_t i = 0; i < out_rank; ++i) { - int64_t lhs_dim = 1; - if (i < lhs_rank) - lhs_dim = lhs_shape[lhs_rank - 1 - i]; - int64_t rhs_dim = 1; - if (i < rhs_rank) - rhs_dim = rhs_shape[rhs_rank - 1 - i]; - int64_t max = std::max(lhs_dim, rhs_dim); - int64_t min = std::min(lhs_dim, rhs_dim); - int64_t out_dim = (min == 0 ? min : max); // special case a dim value of 0. - if (lhs_dim != out_dim && lhs_dim != 1) - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": left operand cannot broadcast on dim ", lhs_rank - 1 - i, - " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString()); - if (rhs_dim != out_dim && rhs_dim != 1) - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_name, ": right operand cannot broadcast on dim ", rhs_rank - 1 - i, - " LeftShape: ", lhs_shape.ToString(), ", RightShape: ", rhs_shape.ToString()); - output_dims[out_rank - 1 - i] = out_dim; - } - out_shape = TensorShape(output_dims); - return Status::OK(); -} -} // namespace - Status ExpandProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& input = shader.AddInput("input", ToProgramVariableDataType(Inputs()[0].tensor->GetElementType()), @@ -61,7 +33,7 @@ Status Expand::ComputeInternal(ComputeContext& context) const { const auto* p_shape = input_shape_tensor->Data(); TensorShapeVector output_dims{p_shape, p_shape + input_shape_tensor->Shape().Size()}; TensorShape output_shape(output_dims); - ORT_RETURN_IF_ERROR(ComputeOutputShape(Node().Name(), input_tensor->Shape(), output_dims, output_shape)); + ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), input_tensor->Shape(), output_dims, output_shape)); auto* output_tensor = context.Output(0, output_shape); SafeInt vec_size = output_shape.Size(); From 7ecc5bbaac5e09508dc789add1ce2a8cf96e5338 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 4 Sep 2024 12:11:08 -0700 Subject: [PATCH 46/77] add workgroup_idx --- onnxruntime/core/providers/webgpu/shader_helper.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index e6ae5ae0d9403..245de6d7c2ed0 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -67,10 +67,11 @@ Status ShaderHelper::Init() { body_ << ") {\n"; if (is_1d_dispatch) { body_ << " let global_idx = global_id.x;\n" - " let local_idx = local_id.x;\n"; + " let local_idx = local_id.x;\n" + " let workgroup_idx = workgroup_id.x;\n"; } else { - body_ << " let global_idx = (workgroup_id.z * num_workgroups[0] * num_workgroups[1] + workgroup_id.y * num_workgroups[0] + workgroup_id.x)\n" - " * (workgroup_size_x * workgroup_size_y * workgroup_size_z) + local_idx;\n"; + body_ << " let workgroup_idx = workgroup_id.z * num_workgroups[0] * num_workgroups[1] + workgroup_id.y * num_workgroups[0] + workgroup_id.x;\n" + " let global_idx = workgroup_idx * (workgroup_size_x * workgroup_size_y * workgroup_size_z) + local_idx;\n"; } // init additional implementation string stream From ae836b129c71ee942e5db5a3319eced43dbec279 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 4 Sep 2024 15:05:04 -0700 Subject: [PATCH 47/77] expose name for shader variable --- .../core/providers/webgpu/shader_variable.cc | 162 +++++++++--------- .../core/providers/webgpu/shader_variable.h | 14 +- 2 files changed, 93 insertions(+), 83 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index a652d720dbf7b..0b7a7d390057c 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -13,13 +13,76 @@ namespace onnxruntime { namespace webgpu { +namespace { +constexpr static const std::string_view STORAGE_TYPE[] = { + "f32", // f32 + "vec2", // vec2f32 + "vec4", // vec4f32 + "f16", // f16 + "vec2", // vec2f16 + "vec4", // vec4f16 + "i32", // i32 + "vec2", // vec2i32 + "vec4", // vec4i32 + "u32", // u32 + "vec2", // vec2u32 + "vec4", // vec4u32 + "vec2", // int64 + "vec2", // uint64 + "u32", // vec4bool +}; + +constexpr static const std::string_view VALUE_TYPE[] = { + "f32", // f32 + "vec2", // vec2f32 + "vec4", // vec4f32 + "f16", // f16 + "vec2", // vec2f16 + "vec4", // vec4f16 + "i32", // i32 + "vec2", // vec2i32 + "vec4", // vec4i32 + "u32", // u32 + "vec2", // vec2u32 + "vec4", // vec4u32 + "i32", // int64 (trancated to i32) + "u32", // uint64 (trancated to u32) + "vec4", // vec4bool +}; + +constexpr static const std::string_view ELEMENT_TYPE[] = { + "f32", // f32 + "f32", // vec2f32 + "f32", // vec4f32 + "f16", // f16 + "f16", // vec2f16 + "f16", // vec4f16 + "i32", // i32 + "i32", // vec2i32 + "i32", // vec4i32 + "u32", // u32 + "u32", // vec2u32 + "u32", // vec4u32 + "i32", // int64 + "u32", // uint64 + "bool", // vec4bool +}; + +} // namespace + ShaderVariable::ShaderVariable(std::string_view name, ProgramVariableDataType type, Usage usage, const TensorShape& dims) : name_(name), type_(type), num_components_{NumberOfComponents(type)}, rank_{SafeInt(dims.NumDimensions())}, dims_{dims}, - usage_(usage) { + usage_(usage), + indices_type_{rank_ < 2 ? "u32" + : (rank_ < 4 ? MakeStringWithClassicLocale("vec", rank_, "") + : MakeStringWithClassicLocale("array"))}, + value_type_alias_{name_ + "_value_t"}, + element_type_alias_{name_ + "_element_t"}, + indices_type_alias_{name_ + "_indices_t"} { ORT_ENFORCE(type_ != ProgramVariableDataType::InvalidType, "Invalid type for variable ", name_); ORT_ENFORCE(num_components_ > 0, "Invalid number of components for variable ", name_); } @@ -27,29 +90,23 @@ ShaderVariable::ShaderVariable(std::string_view name, ProgramVariableDataType ty void ShaderVariable::Impl(std::ostringstream& ss) const { // Start generating code - const std::string value_t = name_ + "_value_t"; - const std::string element_t = name_ + "_element_t"; - const std::string shape = (usage_ & UseUniform) ? "uniforms." + name_ + "_shape" : name_ + "_shape"; const std::string stride = (usage_ & UseUniform) ? "uniforms." + name_ + "_stride" : name_ + "_stride"; // Types - std::string_view value_type = (usage_ & UseValueTypeAlias) ? value_t : ValueType(); - const std::string indices_type = (usage_ & UseIndicesTypeAlias) ? name_ + "_indices_t" : IndicesType(); - if (usage_ & UseValueTypeAlias) { - SS("alias ", name_, "_value_t = ", ValueType(), ";\n"); + SS("alias ", value_type_alias_, " = ", VALUE_TYPE[static_cast(type_)], ";\n"); } if (usage_ & UseIndicesTypeAlias) { - SS("alias ", name_, "_indices_t = ", IndicesType(), ";\n"); + SS("alias ", indices_type_alias_, " = ", indices_type_, ";\n"); } if (usage_ & UseElementTypeAlias) { - SS("alias ", name_, "_element_t = ", ElementType(), ";\n"); + SS("alias ", element_type_alias_, " = ", ELEMENT_TYPE[static_cast(type_)], ";\n"); } // Need shape and strides when (not use uniform) and (any other usage is enabled) if (!(usage_ & UseUniform) && (usage_ & ~UseUniform)) { - SS("const ", shape, " = ", indices_type, "("); + SS("const ", shape, " = ", IndicesType(), "("); bool first = true; for (auto dim : dims_.GetDims()) { @@ -62,7 +119,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { } ss << ");\n"; - SS("const ", stride, " = ", indices_type, "("); + SS("const ", stride, " = ", IndicesType(), "("); first = true; for (int i = 1; i <= rank_; i++) { if (!first) { @@ -77,8 +134,8 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { // Implementation of "fn o2i_{name}" if (usage_ & UseOffsetToIndices) { if (rank_ >= 2) { - SS("fn o2i_", name_, "(offset : u32)->", indices_type, " {\n"); - SS(" var indices: ", indices_type, ";\n"); + SS("fn o2i_", name_, "(offset : u32)->", IndicesType(), " {\n"); + SS(" var indices: ", IndicesType(), ";\n"); SS(" var current = offset;\n"); for (int i = 0; i < rank_ - 1; i++) { auto current_stride = GetElementAt(stride, i, rank_); @@ -96,7 +153,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { // Implementation of "fn i2o_{name}" if (usage_ & UseIndicesToOffset) { if (rank_ >= 2) { - SS("fn i2o_", name_, "(indices : ", indices_type, ")->u32 {\n"); + SS("fn i2o_", name_, "(indices : ", IndicesType(), ")->u32 {\n"); SS(" return "); for (int i = 0; i < rank_ - 1; i++) { SS("indices[", i, "] * ", GetElementAt(stride, i, rank_), " + "); @@ -111,7 +168,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { // TODO: do we need this if rank < 2? for (const auto& iter : broadcasted_to_) { const auto& broadcasted_result = iter.get(); - SS("fn ", broadcasted_result.name_, "_bi2o_", name_, "(indices : ", broadcasted_result.IndicesType(), ")->u32 {\n"); + SS("fn ", broadcasted_result.name_, "_bi2o_", name_, "(indices : ", broadcasted_result.indices_type_, ")->u32 {\n"); if (rank_ == 0) { SS(" return 0;\n"); } else { @@ -133,7 +190,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { for (int i = 1; i < rank_; i++) { SS(", d", i, ": u32"); } - SS(", value: ", value_type, ") {\n"); + SS(", value: ", ValueType(), ") {\n"); SS(" set_", name_, "_by_indices(d0"); for (int i = 1; i < rank_; i++) { SS(", d", i); @@ -146,7 +203,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { // Implementation of "fn set_{name}_by_indices" if (usage_ & UseSetByIndices) { if (rank_ >= 2) { - SS("fn set_", name_, "_by_indices(indices: ", indices_type, ", value: ", value_type, ") {\n"); + SS("fn set_", name_, "_by_indices(indices: ", IndicesType(), ", value: ", ValueType(), ") {\n"); SS(" ", SetByOffset("i2o_" + name_ + "(indices)", "value"), "\n"); SS("}\n"); } @@ -159,7 +216,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { for (int i = 1; i < rank_; i++) { SS(", d", i, ": u32"); } - SS(")->", value_type, " {\n"); + SS(")->", ValueType(), " {\n"); SS(" return get_", name_, "_by_indices(d0"); for (int i = 1; i < rank_; i++) { SS(", d", i); @@ -172,7 +229,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { // Implementation of "fn get_{name}_by_indices" if (usage_ & UseGetByIndices) { if (rank_ >= 2) { - SS("fn get_", name_, "_by_indices(indices: ", indices_type, ")->", value_type, " {\n"); + SS("fn get_", name_, "_by_indices(indices: ", IndicesType(), ")->", ValueType(), " {\n"); SS(" return ", GetByOffset("i2o_" + name_ + "(indices)"), ";\n"); SS("}\n"); } @@ -232,76 +289,19 @@ std::string ShaderVariable::SetByOffsetImpl(std::string_view offset, std::string } std::string_view ShaderVariable::StorageType() const { - constexpr static const std::string_view STORAGE_TYPE[] = { - "f32", // f32 - "vec2", // vec2f32 - "vec4", // vec4f32 - "f16", // f16 - "vec2", // vec2f16 - "vec4", // vec4f16 - "i32", // i32 - "vec2", // vec2i32 - "vec4", // vec4i32 - "u32", // u32 - "vec2", // vec2u32 - "vec4", // vec4u32 - "vec2", // int64 - "vec2", // uint64 - "u32", // vec4bool - }; - return STORAGE_TYPE[static_cast(type_)]; } std::string_view ShaderVariable::ValueType() const { - constexpr static const std::string_view VALUE_TYPE[] = { - "f32", // f32 - "vec2", // vec2f32 - "vec4", // vec4f32 - "f16", // f16 - "vec2", // vec2f16 - "vec4", // vec4f16 - "i32", // i32 - "vec2", // vec2i32 - "vec4", // vec4i32 - "u32", // u32 - "vec2", // vec2u32 - "vec4", // vec4u32 - "i32", // int64 (trancated to i32) - "u32", // uint64 (trancated to u32) - "vec4", // vec4bool - }; - - return VALUE_TYPE[static_cast(type_)]; + return (usage_ & UseValueTypeAlias) ? value_type_alias_ : VALUE_TYPE[static_cast(type_)]; } std::string_view ShaderVariable::ElementType() const { - constexpr static const std::string_view ELEMENT_TYPE[] = { - "f32", // f32 - "f32", // vec2f32 - "f32", // vec4f32 - "f16", // f16 - "f16", // vec2f16 - "f16", // vec4f16 - "i32", // i32 - "i32", // vec2i32 - "i32", // vec4i32 - "u32", // u32 - "u32", // vec2u32 - "u32", // vec4u32 - "i32", // int64 - "u32", // uint64 - "bool", // vec4bool - }; - - return ELEMENT_TYPE[static_cast(type_)]; + return (usage_ & UseElementTypeAlias) ? element_type_alias_ : ELEMENT_TYPE[static_cast(type_)]; } -std::string ShaderVariable::IndicesType() const { - return rank_ < 2 ? "u32" - : (rank_ < 4 ? MakeStringWithClassicLocale("vec", rank_, "") - : MakeStringWithClassicLocale("array")); +std::string_view ShaderVariable::IndicesType() const { + return (usage_ & UseIndicesTypeAlias) ? indices_type_alias_ : indices_type_; } - } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index 86eaaac5e1591..778017a50dda7 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -58,6 +58,9 @@ class ShaderVariable { ShaderVariable(ShaderVariable&&) = default; ShaderVariable& operator=(ShaderVariable&&) = default; + // get the name of the variable. + std::string_view Name() const; + // create a WGSL expression ({varname}_indices_t) for getting indices from offset. // \param offset: a WGSL expression (u32) representing the offset. inline std::string OffsetToIndices(std::string_view offset_expr) const; @@ -131,11 +134,10 @@ class ShaderVariable { std::string GetByOffsetImpl(std::string_view offset) const; std::string SetByOffsetImpl(std::string_view offset, std::string_view value) const; - std::string_view StorageType() const; std::string_view ValueType() const; std::string_view ElementType() const; - std::string IndicesType() const; + std::string_view IndicesType() const; std::string name_; ProgramVariableDataType type_; @@ -146,6 +148,14 @@ class ShaderVariable { mutable Usage usage_; mutable std::vector> broadcasted_to_; + // unlike storage/element/value type, indices type is not a string view to a constant string. so we need to store it. + std::string indices_type_; + + // the alias for the types + std::string value_type_alias_; + std::string element_type_alias_; + std::string indices_type_alias_; + friend class ShaderHelper; }; From 243078b0de15bbfebbf8b639da7d02ea8863ed70 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 4 Sep 2024 17:08:16 -0700 Subject: [PATCH 48/77] add uniform for 1D variable --- .../core/providers/webgpu/shader_helper.cc | 18 ++--- .../core/providers/webgpu/shader_variable.cc | 66 ++++++++++--------- .../core/providers/webgpu/shader_variable.h | 6 +- .../core/providers/webgpu/webgpu_context.cc | 21 +++--- 4 files changed, 62 insertions(+), 49 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index 245de6d7c2ed0..cd3507a6439ab 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -334,12 +334,12 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha shape_uniform_ranks.reserve(input_vars.size() + output_vars.size()); for (const auto& input : vars_[static_cast(ProgramVariableScope::Input)]) { - bool use_uniform = (input.usage_ & ShaderVariable::UseUniform) == ShaderVariable::UseUniform && input.rank_ > 1; + bool use_uniform = (input.usage_ & ShaderVariable::UseUniform) && input.rank_ > 0; use_any_shape_uniform |= use_uniform; shape_uniform_ranks.push_back(use_uniform ? input.rank_ : 0); } for (const auto& output : vars_[static_cast(ProgramVariableScope::Output)]) { - bool use_uniform = (output.usage_ & ShaderVariable::UseUniform) == ShaderVariable::UseUniform && output.rank_ > 1; + bool use_uniform = (output.usage_ & ShaderVariable::UseUniform) && output.rank_ > 0; use_any_shape_uniform |= use_uniform; shape_uniform_ranks.push_back(use_uniform ? output.rank_ : 0); } @@ -380,20 +380,22 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha }; for (const auto& input : vars_[static_cast(ProgramVariableScope::Input)]) { - if (input.rank_ > 1 && (input.usage_ & ShaderVariable::Usage::UseUniform)) { + const size_t rank = input.rank_; + if (rank > 0 && (input.usage_ & ShaderVariable::Usage::UseUniform)) { std::string shape = input.name_ + "_shape"; std::string stride = input.name_ + "_stride"; - append_uniform(shape, ProgramUniformVariableDataType::Uint32, input.rank_); - append_uniform(stride, ProgramUniformVariableDataType::Uint32, input.rank_); + append_uniform(shape, ProgramUniformVariableDataType::Uint32, rank); + append_uniform(stride, ProgramUniformVariableDataType::Uint32, rank - 1); } } for (const auto& output : vars_[static_cast(ProgramVariableScope::Output)]) { - if (output.rank_ > 1 && (output.usage_ & ShaderVariable::Usage::UseUniform)) { + const size_t rank = output.rank_; + if (rank > 0 && (output.usage_ & ShaderVariable::Usage::UseUniform)) { std::string shape = output.name_ + "_shape"; std::string stride = output.name_ + "_stride"; - append_uniform(shape, ProgramUniformVariableDataType::Uint32, output.rank_); - append_uniform(stride, ProgramUniformVariableDataType::Uint32, output.rank_); + append_uniform(shape, ProgramUniformVariableDataType::Uint32, rank); + append_uniform(stride, ProgramUniformVariableDataType::Uint32, rank - 1); } } diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index 0b7a7d390057c..98720c7854815 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -68,6 +68,12 @@ constexpr static const std::string_view ELEMENT_TYPE[] = { "bool", // vec4bool }; +inline std::string GetIndicesType(int rank) { + return rank < 2 ? "u32" + : (rank < 4 ? MakeStringWithClassicLocale("vec", rank, "") + : MakeStringWithClassicLocale("array")); +} + } // namespace ShaderVariable::ShaderVariable(std::string_view name, ProgramVariableDataType type, Usage usage, const TensorShape& dims) @@ -77,9 +83,7 @@ ShaderVariable::ShaderVariable(std::string_view name, ProgramVariableDataType ty rank_{SafeInt(dims.NumDimensions())}, dims_{dims}, usage_(usage), - indices_type_{rank_ < 2 ? "u32" - : (rank_ < 4 ? MakeStringWithClassicLocale("vec", rank_, "") - : MakeStringWithClassicLocale("array"))}, + indices_type_{GetIndicesType(rank_)}, value_type_alias_{name_ + "_value_t"}, element_type_alias_{name_ + "_element_t"}, indices_type_alias_{name_ + "_indices_t"} { @@ -105,7 +109,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { } // Need shape and strides when (not use uniform) and (any other usage is enabled) - if (!(usage_ & UseUniform) && (usage_ & ~UseUniform)) { + if (!(usage_ & UseUniform) && (usage_ & ~UseUniform) && rank_ > 0) { SS("const ", shape, " = ", IndicesType(), "("); bool first = true; @@ -119,16 +123,18 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { } ss << ");\n"; - SS("const ", stride, " = ", IndicesType(), "("); - first = true; - for (int i = 1; i <= rank_; i++) { - if (!first) { - ss << ","; + if (rank_ > 1) { + SS("const ", stride, " = ", GetIndicesType(rank_ - 1), "("); + first = true; + for (int i = 1; i < rank_; i++) { + if (!first) { + ss << ","; + } + ss << dims_.SizeFromDimension(i); + first = false; } - ss << dims_.SizeFromDimension(i); - first = false; + ss << ");\n"; } - ss << ");\n"; } // Implementation of "fn o2i_{name}" @@ -138,7 +144,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { SS(" var indices: ", IndicesType(), ";\n"); SS(" var current = offset;\n"); for (int i = 0; i < rank_ - 1; i++) { - auto current_stride = GetElementAt(stride, i, rank_); + auto current_stride = GetElementAt(stride, i, rank_ - 1); SS(" let dim", i, " = current / ", current_stride, ";\n"); SS(" let rest", i, " = current % ", current_stride, ";\n"); SS(" indices[", i, "] = dim", i, ";\n"); @@ -156,7 +162,7 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { SS("fn i2o_", name_, "(indices : ", IndicesType(), ")->u32 {\n"); SS(" return "); for (int i = 0; i < rank_ - 1; i++) { - SS("indices[", i, "] * ", GetElementAt(stride, i, rank_), " + "); + SS("indices[", i, "] * ", GetElementAt(stride, i, rank_ - 1), " + "); } SS("indices[", rank_ - 1, "];\n"); SS("}\n"); @@ -165,21 +171,23 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { // Implementation of "fn {res_name}_bi2o_{name}" if (usage_ & UseBroadcastedIndicesToOffset) { - // TODO: do we need this if rank < 2? - for (const auto& iter : broadcasted_to_) { - const auto& broadcasted_result = iter.get(); - SS("fn ", broadcasted_result.name_, "_bi2o_", name_, "(indices : ", broadcasted_result.indices_type_, ")->u32 {\n"); - if (rank_ == 0) { - SS(" return 0;\n"); - } else { - SS(" return "); - for (int i = rank_ - 1; i >= 0; i--) { - auto idx = broadcasted_result.IndicesGet("indices", i + broadcasted_result.rank_ - rank_); - SS(IndicesGet(stride, i), " * (", idx, " % ", IndicesGet(shape, i), ") + "); + if (rank_ > 0) { + for (const auto& iter : broadcasted_to_) { + const auto& broadcasted_result = iter.get(); + SS("fn ", broadcasted_result.name_, "_bi2o_", name_, "(indices : ", broadcasted_result.indices_type_, ")->u32 {\n"); + if (rank_ == 1) { + SS(" return ", broadcasted_result.IndicesGet("indices", broadcasted_result.rank_ - 1), " % ", shape, ";\n"); + } else { + SS(" return "); + for (int i = 0; i < rank_ - 1; i++) { + auto idx = broadcasted_result.IndicesGet("indices", i + broadcasted_result.rank_ - rank_); + std::string current_stride = rank_ == 2 ? stride : GetElementAt(stride, i, rank_ - 1); + SS(current_stride, " * (", idx, " % ", IndicesGet(shape, i), ") + "); + } + SS(broadcasted_result.IndicesGet("indices", broadcasted_result.rank_ - 1), " % ", IndicesGet(shape, rank_ - 1), ";\n"); } - SS("0;\n"); + SS("}\n"); } - SS("}\n"); } } @@ -245,10 +253,8 @@ std::string ShaderVariable::GetByOffsetImpl(std::string_view offset) const { ORT_THROW("Invalid type"); break; case onnxruntime::webgpu::ProgramVariableDataType::Int64: - ss << "i32(" << name_ << "[" << offset << "].x)"; - break; case onnxruntime::webgpu::ProgramVariableDataType::Uint64: - ss << "u32(" << name_ << "[" << offset << "].x)"; + ss << ElementType() << "(" << name_ << "[" << offset << "].x)"; break; case onnxruntime::webgpu::ProgramVariableDataType::Vec4Bool: ss << "vec4(bool(" diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index 778017a50dda7..c6d28975bae3e 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -202,13 +202,15 @@ inline std::string ShaderVariable::IndicesToOffset(std::string_view indices_expr inline std::string ShaderVariable::BroadcastedIndicesToOffset(std::string_view indices_expr, const ShaderVariable& broadcasted_result) const { usage_ |= UseBroadcastedIndicesToOffset; broadcasted_to_.push_back(broadcasted_result); - return MakeStringWithClassicLocale(broadcasted_result.name_, "_bi2o_", name_, '(', indices_expr, ')'); + return rank_ == 0 + ? "0" + : MakeStringWithClassicLocale(broadcasted_result.name_, "_bi2o_", name_, '(', indices_expr, ')'); } template inline std::string ShaderVariable::Indices(TIndices&&... indices_args) const { return rank_ == 0 - ? "" + ? "0" : MakeStringWithClassicLocale(name_, "_indices_t(", absl::StrJoin(std::forward_as_tuple(std::forward(indices_args)...), ", "), ')'); diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 343da693c716b..599ee9bbb82f6 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -249,15 +249,19 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog "Invalid program artifact: variable[", i, "] rank mismatch. Expected: ", (int)expected_rank, ", Actual: ", shape.NumDimensions()); - std::vector dims(shape.NumDimensions()); - std::vector stride(shape.NumDimensions()); - for (size_t j = 0; j < shape.NumDimensions(); ++j) { + std::vector dims(expected_rank); + std::vector stride(expected_rank - 1); + for (size_t j = 0; j < expected_rank; ++j) { dims[j] = SafeInt(shape[j]); - stride[j] = SafeInt(shape.SizeFromDimension(j + 1)); + if (j < expected_rank - 1) { + stride[j] = SafeInt(shape.SizeFromDimension(j + 1)); + } } shape_uniforms.emplace_back(gsl::make_span(dims)); - shape_uniforms.emplace_back(gsl::make_span(stride)); + if (expected_rank > 1) { + shape_uniforms.emplace_back(gsl::make_span(stride)); + } } } @@ -268,14 +272,13 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog for (size_t i = 0; i < uniform_count; i++) { const auto& uniform = i < shape_uniforms.size() ? shape_uniforms[i] : program.UniformVariables()[i - shape_uniforms.size()]; - bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16; size_t length = uniform.length; - - // skip zero-length uniform - if (length == 0) { + if (length == 0) { // skip zero-length uniform continue; } + bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16; + size_t element_size = ProgramUniformVariableDataTypeSize[static_cast(uniform.data_type)]; // https://www.w3.org/TR/WGSL/#alignof size_t base_alignment = is_f16 From 4d48d287feb3ae48ffb01e905fce70411d78416e Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 4 Sep 2024 17:50:08 -0700 Subject: [PATCH 49/77] fix GetElementAt with uniform --- onnxruntime/core/providers/webgpu/shader_variable.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index c6d28975bae3e..b8b44de92911b 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -15,7 +15,7 @@ namespace webgpu { template std::string GetElementAt(std::string_view var, const TIdx& idx, int rank, bool is_f16 = false) { // "std::string::rfind(str, 0) == 0" is equivalent to "std::string::starts_with(str)" before C++20. - if (var.rfind("uniform.", 0) == 0) { + if (var.rfind("uniforms.", 0) == 0) { if (rank > 4) { if constexpr (std::is_integral_v) { if (is_f16) { From dbe673bebc2e3159360fac62a49b486f5058fe9c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 5 Sep 2024 00:37:28 -0700 Subject: [PATCH 50/77] document update folder --- onnxruntime/core/providers/webgpu/README.md | 80 ++----------------- .../providers/webgpu/docs/Best_Practices.md | 37 +++++++++ .../core/providers/webgpu/docs/Conventions.md | 33 ++++++++ .../How_to_Write_WebGPU_EP_Kernel.md | 0 4 files changed, 75 insertions(+), 75 deletions(-) create mode 100644 onnxruntime/core/providers/webgpu/docs/Best_Practices.md create mode 100644 onnxruntime/core/providers/webgpu/docs/Conventions.md rename onnxruntime/core/providers/webgpu/{ => docs}/How_to_Write_WebGPU_EP_Kernel.md (100%) diff --git a/onnxruntime/core/providers/webgpu/README.md b/onnxruntime/core/providers/webgpu/README.md index 999f1fecbda76..fe0d99b1d602b 100644 --- a/onnxruntime/core/providers/webgpu/README.md +++ b/onnxruntime/core/providers/webgpu/README.md @@ -4,9 +4,7 @@ This folder is for the WebGPU execution provider(WebGPU EP). Currently, WebGPU E ## Build WebGPU EP -Just append `--use_webgpu --skip_tests` to the `build.bat`/`build.sh` command line. - -NOTE: `--skip_tests` is required for now. All existing tests are for CPU EP anyway so no need to run them. +Just append `--use_webgpu` to the `build.bat`/`build.sh` command line. For linux, a few dependencies need to be installed: ```sh @@ -19,83 +17,15 @@ TODO: add solutions to common problems. ## Development Guide -See [How to write WebGPU EP kernel](./How_to_Write_WebGPU_EP_Kernel.md) for more information. - -## Convention - -### Use "webgpu" other than "wgpu" in this folder - -This is referring to the naming convention of variables, classes and namespace. - -ORT C API is using "wgpu". - -Let's keep it "webgpu" for this folder for now. I have a very good reason to do so: - -- search for "webgpu" in the code base shows the WebGPU EP related code and search for "wgpu" shows the WebGPU API related code. This helps me easier to find the code I want to look at. - -And anyway, it's not hard to change it back to "wgpu" if we want to. (but it's harder to change it from "wgpu" to "webgpu") - -### Use macros defined in shader_macros.h - -Take `SS` as example. It's a macro defined in `shader_macros.h` and it's used to concatenate strings. It's just make the `std::ostream::operator<<` to be used in a function call style. - -I prefer to use the macro because I feel like it's easier to read. Check the following code: - -```cpp -ss << "vec4(" << type << ">(" << value1 << ", " << value2 << ", " << value3 << ", " << value4 << ")"; -``` - -vs. - -```cpp -SS("vec4<", type, ">(", value1, ", ", value2, ", ", value3, ", ", value4, ")"); -``` - -### Use the subfolder for kernel implementation +See [How to write WebGPU EP kernel](./docs/How_to_Write_WebGPU_EP_Kernel.md) for more information. -Operator implementation source code need to be put under a subfolder like "math"/"nn"/"tensor". +## Conventions -See folder structure under onnxruntime/core/providers/cpu/ or onnxruntime/core/providers/cuda/ for examples. +See [Conventions](./docs/Conventions.md) for more information. ## Best Practices -### Always use std::ostringstream to generate shader code if possible - -This helps to the performance of code generation. - -For example: - -```cpp -ss << "var " << name << " = " << value << ";\n"; -``` - -is better than - -```cpp -ss << ("var " + name + " = " + value + ";\n"); -``` - -### Avoid creating template class for kernel using data type as template parameter. - -This basically means that we should define class like this: - -```cpp -class Abs : public WebGpuKernel { - ... -}; -``` - -instead of - -```cpp - -template // T is tensor element type -class Abs : public WebGpuKernel { - ... -}; -``` - -This is because we don't really read and use `Tensor::Data()`. Tensor stores a handle to a WebGPU buffer but not a pointer to the data. Using template for data type only increases the binary size with no real benefit. +See [Best Practices](./docs/Best_Practices.md) for more information. ## TODO items diff --git a/onnxruntime/core/providers/webgpu/docs/Best_Practices.md b/onnxruntime/core/providers/webgpu/docs/Best_Practices.md new file mode 100644 index 0000000000000..d519292b226d0 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/docs/Best_Practices.md @@ -0,0 +1,37 @@ +### Always use std::ostringstream to generate shader code if possible + +This helps to the performance of code generation. + +For example: + +```cpp +ss << "var " << name << " = " << value << ";\n"; +``` + +is better than + +```cpp +ss << ("var " + name + " = " + value + ";\n"); +``` + +### Avoid creating template class for kernel using data type as template parameter. + +This basically means that we should define class like this: + +```cpp +class Abs : public WebGpuKernel { + ... +}; +``` + +instead of + +```cpp + +template // T is tensor element type +class Abs : public WebGpuKernel { + ... +}; +``` + +This is because we don't really read and use `Tensor::Data()`. Tensor stores a handle to a WebGPU buffer but not a pointer to the data. Using template for data type only increases the binary size with no real benefit. diff --git a/onnxruntime/core/providers/webgpu/docs/Conventions.md b/onnxruntime/core/providers/webgpu/docs/Conventions.md new file mode 100644 index 0000000000000..1a86e508cdda8 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/docs/Conventions.md @@ -0,0 +1,33 @@ +### Use "webgpu" other than "wgpu" in this folder + +This is referring to the naming convention of variables, classes and namespace. + +ORT C API is using "wgpu". + +Let's keep it "webgpu" for this folder for now. I have a very good reason to do so: + +- search for "webgpu" in the code base shows the WebGPU EP related code and search for "wgpu" shows the WebGPU API related code. This helps me easier to find the code I want to look at. + +And anyway, it's not hard to change it back to "wgpu" if we want to. (but it's harder to change it from "wgpu" to "webgpu") + +### Use macros defined in shader_macros.h + +Take `SS` as example. It's a macro defined in `shader_macros.h` and it's used to concatenate strings. It's just make the `std::ostream::operator<<` to be used in a function call style. + +I prefer to use the macro because I feel like it's easier to read. Check the following code: + +```cpp +ss << "vec4(" << type << ">(" << value1 << ", " << value2 << ", " << value3 << ", " << value4 << ")"; +``` + +vs. + +```cpp +SS("vec4<", type, ">(", value1, ", ", value2, ", ", value3, ", ", value4, ")"); +``` + +### Use the subfolder for kernel implementation + +Operator implementation source code need to be put under a subfolder like "math"/"nn"/"tensor". + +See folder structure under onnxruntime/core/providers/cpu/ or onnxruntime/core/providers/cuda/ for examples. diff --git a/onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md b/onnxruntime/core/providers/webgpu/docs/How_to_Write_WebGPU_EP_Kernel.md similarity index 100% rename from onnxruntime/core/providers/webgpu/How_to_Write_WebGPU_EP_Kernel.md rename to onnxruntime/core/providers/webgpu/docs/How_to_Write_WebGPU_EP_Kernel.md From 38f182e65e7a312d60eacd4bc29c8d5de19141c0 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 5 Sep 2024 01:04:13 -0700 Subject: [PATCH 51/77] fix adapter/device creating: add toggles --- .../external/onnxruntime_external_deps.cmake | 2 + .../core/providers/webgpu/webgpu_context.cc | 61 ++++++++++++++++++- 2 files changed, 60 insertions(+), 3 deletions(-) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 6640609aa71dd..a8ab4a53b9f3a 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -594,6 +594,8 @@ if (onnxruntime_USE_WEBGPU) set(DAWN_FETCH_DEPENDENCIES ON) set(DAWN_ENABLE_INSTALL ON) set(TINT_BUILD_TESTS OFF) + set(DAWN_USE_BUILT_DXC ON) + set(DAWN_DXC_ENABLE_ASSERTS_IN_NDEBUG OFF) onnxruntime_fetchcontent_makeavailable(dawn) endif() diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 599ee9bbb82f6..276d74905adb7 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -17,12 +17,46 @@ namespace onnxruntime { namespace webgpu { +namespace { + +std::vector GetEnabledAdapterToggles() { + // See the description of all the toggles in toggles.cpp + // "use_dxc" for Shader Model 6+ features (e.g. float16) + // "allow_unsafe_apis" for chromium experimental features + constexpr const char* toggles[] = { + "use_dxc", + "allow_unsafe_apis", + }; + return std::vector(std::begin(toggles), std::end(toggles)); +} + +std::vector GetEnabledDeviceToggles() { + // Enable / disable other toggles that may affect the performance. + // Other toggles that may be useful: "dump_shaders", "disable_symbol_renaming" + constexpr const char* toggles[] = { + "skip_validation", + "disable_robustness", + "disable_workgroup_init", + "d3d_disable_ieee_strictness", + }; + return std::vector(std::begin(toggles), std::end(toggles)); +} + +std::vector GetDisabledDeviceToggles() { + constexpr const char* toggles[] = { + "lazy_clear_resource_on_first_use", + }; + return std::vector(std::begin(toggles), std::end(toggles)); +} + std::vector GetAvailableRequiredFeatures(const wgpu::Adapter& adapter) { std::vector required_features; constexpr wgpu::FeatureName features[]{ wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses, wgpu::FeatureName::TimestampQuery, - wgpu::FeatureName::ShaderF16}; + wgpu::FeatureName::ShaderF16, + wgpu::FeatureName::Subgroups, + wgpu::FeatureName::SubgroupsF16}; for (auto feature : features) { if (adapter.HasFeature(feature)) { required_features.push_back(feature); @@ -31,7 +65,7 @@ std::vector GetAvailableRequiredFeatures(const wgpu::Adapter& return required_features; } -wgpu::RequiredLimits GetAvailableRequiredLimits(const wgpu::Adapter& adapter) { +wgpu::RequiredLimits GetRequiredLimits(const wgpu::Adapter& adapter) { wgpu::RequiredLimits required_limits{}; wgpu::SupportedLimits adapter_limits; ORT_ENFORCE(adapter.GetLimits(&adapter_limits)); @@ -49,6 +83,8 @@ wgpu::RequiredLimits GetAvailableRequiredLimits(const wgpu::Adapter& adapter) { return required_limits; } +} // namespace + void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info) { std::call_once(init_flag_, [this, &webgpu_ep_info]() { // Initialization.Step.1 - Create wgpu::Instance @@ -63,6 +99,13 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info // Initialization.Step.2 - Create wgpu::Adapter if (adapter_ == nullptr) { wgpu::RequestAdapterOptions req_adapter_options = {}; + wgpu::DawnTogglesDescriptor adapter_toggles_desc = {}; + req_adapter_options.nextInChain = &adapter_toggles_desc; + + auto enabled_adapter_toggles = GetEnabledAdapterToggles(); + adapter_toggles_desc.enabledToggleCount = enabled_adapter_toggles.size(); + adapter_toggles_desc.enabledToggles = enabled_adapter_toggles.data(); + wgpu::RequestAdapterCallbackInfo req_adapter_callback_info = {}; req_adapter_callback_info.mode = wgpu::CallbackMode::WaitAnyOnly; req_adapter_callback_info.callback = [](WGPURequestAdapterStatus status, @@ -79,11 +122,23 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info // Initialization.Step.3 - Create wgpu::Device if (device_ == nullptr) { wgpu::DeviceDescriptor device_desc = {}; + wgpu::DawnTogglesDescriptor device_toggles_desc = {}; + device_desc.nextInChain = &device_toggles_desc; + + auto enabled_device_toggles = GetEnabledDeviceToggles(); + device_toggles_desc.enabledToggleCount = enabled_device_toggles.size(); + device_toggles_desc.enabledToggles = enabled_device_toggles.data(); + + auto disabled_device_toggles = GetDisabledDeviceToggles(); + device_toggles_desc.disabledToggleCount = disabled_device_toggles.size(); + device_toggles_desc.disabledToggles = disabled_device_toggles.data(); + std::vector required_features = GetAvailableRequiredFeatures(adapter_); if (required_features.size() > 0) { device_desc.requiredFeatures = required_features.data(); + device_desc.requiredFeatureCount = required_features.size(); } - wgpu::RequiredLimits required_limits = GetAvailableRequiredLimits(adapter_); + wgpu::RequiredLimits required_limits = GetRequiredLimits(adapter_); device_desc.requiredLimits = &required_limits; // TODO: revise temporary error handling From eb80f7c4e28c08d735e1ad8efdb91335dfb5cd1c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 5 Sep 2024 17:16:51 -0700 Subject: [PATCH 52/77] more strict shape&stride usage check --- .../core/providers/webgpu/shader_helper.cc | 51 ++++++++++++++++--- .../core/providers/webgpu/shader_variable.cc | 4 +- .../core/providers/webgpu/shader_variable.h | 40 +++++++++------ 3 files changed, 71 insertions(+), 24 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index cd3507a6439ab..bf791a36858b3 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -102,6 +102,7 @@ const ShaderVariable& ShaderHelper::AddOutput(const std::string& name, ProgramVa #ifndef NDEBUG // if debug build namespace { +// Validate if the tensor element type matches the program variable data type Status ValidateVariableDataType(int32_t element_type, ProgramVariableDataType var_type) { switch (element_type) { case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: @@ -148,8 +149,7 @@ Status ValidateVariableDataType(int32_t element_type, ProgramVariableDataType va return Status::OK(); } -using RankOrShape = std::variant>; - +// Validate if the number of components and override shape match the original shape Status ValidateVariableShape(const TensorShape& origin_shape, bool use_override_shape, const TensorShape& override_shape, @@ -166,6 +166,36 @@ Status ValidateVariableShape(const TensorShape& origin_shape, return Status::OK(); } + +// Validate if the dependency and variable usage match +Status ValidateVariableDependency(ProgramTensorMetadataDependency dependency, ShaderVariable::Usage usage, bool is_input) { + bool dependency_rank = (dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank; + bool dependency_shape = (dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape; + bool dependency_type = (dependency & ProgramTensorMetadataDependency::Type) == ProgramTensorMetadataDependency::Type; + + // if dependency is already set for shape, it is no need to set for rank. + ORT_RETURN_IF(dependency_rank && dependency_shape, + "Dependency cannot set for both \"Rank\" and \"Shape\"."); + + // if dependency is set for shape, it's already part of the shader cache. no need to use uniform. + ORT_RETURN_IF(dependency_shape && (usage & ShaderVariable::UseUniform) == ShaderVariable::UseUniform, + "Dependency is set for \"Shape\", using uniform for shape is not allowed."); + + // for input variable, check is more strict. + // this is because usually output shape is determined by the existing information, which is already part of the shader cache. + if (is_input) { + // if dependency is not set for type, should not use type alias for element and value. + // storage type is always used. so setting not depending on type is at user's own risk. + ORT_RETURN_IF(!dependency_type && (usage & (ShaderVariable::UseElementTypeAlias | ShaderVariable::UseValueTypeAlias)), + "Input dependency is not set for \"Type\", but type alias for element type or value type is used."); + + // if dependency is not set for rank and shape, the shader should not use shape and stride. + ORT_RETURN_IF(!dependency_rank && !dependency_shape && (usage & ShaderVariable::UseShapeAndStride), + "Input dependency is set for neither \"Rank\" nor \"Shape\", but variable shape and stride is used."); + } + + return Status::OK(); +} } // namespace const ShaderVariable& ShaderHelper::AddVariableImpl(ProgramVariableScope scope, @@ -197,6 +227,7 @@ Status ShaderHelper::ValidateVariable(const ProgramInput& input, const ShaderVar input.use_override_shape, input.use_override_shape ? input.override_shape : input.tensor->Shape(), var.num_components_)); + ORT_RETURN_IF_ERROR(ValidateVariableDependency(input.dependency, var.usage_, true)); return Status::OK(); } @@ -206,6 +237,8 @@ Status ShaderHelper::ValidateVariable(const ProgramOutput& output, const ShaderV output.use_override_shape, output.use_override_shape ? output.override_shape : output.tensor->Shape(), var.num_components_)); + ORT_RETURN_IF_ERROR(ValidateVariableDependency(output.dependency, var.usage_, false)); + return Status::OK(); } @@ -280,6 +313,12 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha if (use_f16_) { ORT_RETURN_IF_NOT(device_.HasFeature(wgpu::FeatureName::ShaderF16), "Program ", program_.Name(), " requires f16 but the device does not support it."); ss << "enable f16;\n"; + if (device_.HasFeature(wgpu::FeatureName::SubgroupsF16)) { + ss << "enable subgroups_f16;\n"; + } + } + if (device_.HasFeature(wgpu::FeatureName::Subgroups)) { + ss << "enable subgroups;\n"; } // @@ -334,12 +373,12 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha shape_uniform_ranks.reserve(input_vars.size() + output_vars.size()); for (const auto& input : vars_[static_cast(ProgramVariableScope::Input)]) { - bool use_uniform = (input.usage_ & ShaderVariable::UseUniform) && input.rank_ > 0; + bool use_uniform = (input.usage_ & ShaderVariable::UseUniform) && (input.usage_ & ShaderVariable::UseShapeAndStride) && input.rank_ > 0; use_any_shape_uniform |= use_uniform; shape_uniform_ranks.push_back(use_uniform ? input.rank_ : 0); } for (const auto& output : vars_[static_cast(ProgramVariableScope::Output)]) { - bool use_uniform = (output.usage_ & ShaderVariable::UseUniform) && output.rank_ > 0; + bool use_uniform = (output.usage_ & ShaderVariable::UseUniform) && (output.usage_ & ShaderVariable::UseShapeAndStride) && output.rank_ > 0; use_any_shape_uniform |= use_uniform; shape_uniform_ranks.push_back(use_uniform ? output.rank_ : 0); } @@ -381,7 +420,7 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha for (const auto& input : vars_[static_cast(ProgramVariableScope::Input)]) { const size_t rank = input.rank_; - if (rank > 0 && (input.usage_ & ShaderVariable::Usage::UseUniform)) { + if (rank > 0 && (input.usage_ & ShaderVariable::Usage::UseUniform) && (input.usage_ & ShaderVariable::Usage::UseShapeAndStride)) { std::string shape = input.name_ + "_shape"; std::string stride = input.name_ + "_stride"; append_uniform(shape, ProgramUniformVariableDataType::Uint32, rank); @@ -391,7 +430,7 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha for (const auto& output : vars_[static_cast(ProgramVariableScope::Output)]) { const size_t rank = output.rank_; - if (rank > 0 && (output.usage_ & ShaderVariable::Usage::UseUniform)) { + if (rank > 0 && (output.usage_ & ShaderVariable::Usage::UseUniform) && (output.usage_ & ShaderVariable::Usage::UseShapeAndStride)) { std::string shape = output.name_ + "_shape"; std::string stride = output.name_ + "_stride"; append_uniform(shape, ProgramUniformVariableDataType::Uint32, rank); diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index 98720c7854815..f5fc236aca71d 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -108,8 +108,8 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { SS("alias ", element_type_alias_, " = ", ELEMENT_TYPE[static_cast(type_)], ";\n"); } - // Need shape and strides when (not use uniform) and (any other usage is enabled) - if (!(usage_ & UseUniform) && (usage_ & ~UseUniform) && rank_ > 0) { + // Need shape and strides when (not use uniform) and (use shape and stride is enabled) + if (!(usage_ & UseUniform) && (usage_ & UseShapeAndStride) && rank_ > 0) { SS("const ", shape, " = ", IndicesType(), "("); bool first = true; diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index b8b44de92911b..aa186d58740e3 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -39,18 +39,19 @@ std::string GetElementAt(std::string_view var, const TIdx& idx, int rank, bool i class ShaderVariable { public: enum Usage : uint32_t { - None = 0, // no usage. this means no additional implementation code will be generated. - UseIndicesTypeAlias = 1, // use type alias "{name}_indices_t" for indices (eg. u32, vec2, vec3, vec4, ...) - UseValueTypeAlias = 2, // use type alias "{name}_value_t" for value (eg. f32, vecT, vec4, ...) - UseElementTypeAlias = 4, // use type alias "{name}_element_t" for element (eg. f32, bool, ...) - UseOffsetToIndices = 8, // use implementation of fn o2i_{name} - UseIndicesToOffset = 16, // use implementation of fn i2o_{name} - UseBroadcastedIndicesToOffset = 32, // use implementation of fn {broadcasted_result_name}_bi2o_{name} - UseSet = 64, // use implementation of fn set_{name} - UseSetByIndices = 128, // use implementation of fn set_{name}_by_indices - UseGet = 256, // use implementation of fn get_{name} - UseGetByIndices = 512, // use implementation of fn get_{name}_by_indices - UseUniform = 1024, // use uniform for shape and stride + None = 0, // no usage. this means no additional implementation code will be generated. + UseIndicesTypeAlias = 1, // use type alias "{name}_indices_t" for indices (eg. u32, vec2, vec3, vec4, ...) + UseValueTypeAlias = 2, // use type alias "{name}_value_t" for value (eg. f32, vecT, vec4, ...) + UseElementTypeAlias = 4, // use type alias "{name}_element_t" for element (eg. f32, bool, ...) + UseShapeAndStride = 16, // use shape and stride for the variable + UseOffsetToIndices = 32, // use implementation of fn o2i_{name} + UseIndicesToOffset = 64, // use implementation of fn i2o_{name} + UseBroadcastedIndicesToOffset = 128, // use implementation of fn {broadcasted_result_name}_bi2o_{name} + UseSet = 256, // use implementation of fn set_{name} + UseSetByIndices = 512, // use implementation of fn set_{name}_by_indices + UseGet = 1024, // use implementation of fn get_{name} + UseGetByIndices = 2048, // use implementation of fn get_{name}_by_indices + UseUniform = 32768, // use uniform for shape and stride }; ShaderVariable(std::string_view name, ProgramVariableDataType type, Usage usage, const TensorShape& dims); @@ -188,19 +189,19 @@ std::string pass_as_string(T&& v) { } // namespace detail inline std::string ShaderVariable::OffsetToIndices(std::string_view offset_expr) const { - usage_ |= UseOffsetToIndices; + usage_ |= UseOffsetToIndices | UseShapeAndStride; return rank_ < 2 ? std::string{offset_expr} : MakeStringWithClassicLocale("o2i_", name_, '(', offset_expr, ')'); } inline std::string ShaderVariable::IndicesToOffset(std::string_view indices_expr) const { - usage_ |= UseIndicesToOffset; + usage_ |= UseIndicesToOffset | UseShapeAndStride; return rank_ < 2 ? std::string{indices_expr} : MakeStringWithClassicLocale("i2o_", name_, '(', indices_expr, ')'); } inline std::string ShaderVariable::BroadcastedIndicesToOffset(std::string_view indices_expr, const ShaderVariable& broadcasted_result) const { - usage_ |= UseBroadcastedIndicesToOffset; + usage_ |= UseBroadcastedIndicesToOffset | UseShapeAndStride; broadcasted_to_.push_back(broadcasted_result); return rank_ == 0 ? "0" @@ -209,21 +210,24 @@ inline std::string ShaderVariable::BroadcastedIndicesToOffset(std::string_view i template inline std::string ShaderVariable::Indices(TIndices&&... indices_args) const { + usage_ |= UseShapeAndStride; return rank_ == 0 ? "0" - : MakeStringWithClassicLocale(name_, "_indices_t(", + : MakeStringWithClassicLocale(IndicesType(), "(", absl::StrJoin(std::forward_as_tuple(std::forward(indices_args)...), ", "), ')'); } template inline std::string ShaderVariable::IndicesSet(std::string_view indices_var, const TIdx& idx_expr, const TVal& value) const { + usage_ |= UseShapeAndStride; return rank_ < 2 ? MakeStringWithClassicLocale(indices_var, '=', value, ';') : MakeStringWithClassicLocale(GetElementAt(indices_var, idx_expr, rank_), '=', value, ';'); } template inline std::string ShaderVariable::IndicesGet(std::string_view indices_var, const TIdx& idx_expr) const { + usage_ |= UseShapeAndStride; return rank_ < 2 ? std::string{indices_var} : GetElementAt(indices_var, idx_expr, rank_); } @@ -235,6 +239,7 @@ inline std::string ShaderVariable::SetByOffset(TOffset&& offset, TValue&& value) template inline std::string ShaderVariable::Set(TIndicesAndValue&&... args) const { + usage_ |= UseShapeAndStride; ORT_ENFORCE(sizeof...(TIndicesAndValue) == rank_ + 1, "Number of arguments should be ", rank_ + 1, "(rank + 1)"); if constexpr (sizeof...(TIndicesAndValue) == 1) { return SetByOffset("0", std::forward(args)...); @@ -249,6 +254,7 @@ inline std::string ShaderVariable::Set(TIndicesAndValue&&... args) const { } inline std::string ShaderVariable::SetByIndices(std::string_view indices_var, std::string_view value) const { + usage_ |= UseShapeAndStride; if (rank_ < 2) { return SetByOffset(indices_var, value); } else { @@ -264,6 +270,7 @@ inline std::string ShaderVariable::GetByOffset(TOffset&& offset) const { template inline std::string ShaderVariable::Get(TIndices&&... indices) const { + usage_ |= UseShapeAndStride; ORT_ENFORCE(sizeof...(TIndices) == rank_, "Number of arguments should be ", rank_, "(rank)"); if constexpr (sizeof...(TIndices) == 0) { return GetByOffset("0"); @@ -278,6 +285,7 @@ inline std::string ShaderVariable::Get(TIndices&&... indices) const { } inline std::string ShaderVariable::GetByIndices(std::string_view indices_var) const { + usage_ |= UseShapeAndStride; if (rank_ < 2) { return GetByOffset(indices_var); } else { From 39d55098f9db0d5c46b3085e9e8fcd5f76abb943 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 5 Sep 2024 23:47:23 -0700 Subject: [PATCH 53/77] fix vector realloc --- .../core/providers/webgpu/shader_helper.cc | 47 ++++++++++--------- .../core/providers/webgpu/shader_helper.h | 2 +- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index bf791a36858b3..f43806d1406c2 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -218,7 +218,8 @@ const ShaderVariable& ShaderHelper::AddVariableImpl(ProgramVariableScope scope, ORT_NOT_IMPLEMENTED("Local variables are not supported yet."); } - return vars_[std::underlying_type::type(scope)].emplace_back(name, type, usage, dims); + const auto& var = vars_[std::underlying_type::type(scope)].emplace_back(std::make_unique(name, type, usage, dims)); + return *var; } Status ShaderHelper::ValidateVariable(const ProgramInput& input, const ShaderVariable& var) const { @@ -255,18 +256,18 @@ Status ShaderHelper::ValidateShapeForInputsAndOutputs() const { for (size_t i = 0; i < input_vars.size(); i++) { #ifndef NDEBUG // if debug build // Validate input shape - ORT_RETURN_IF_ERROR(ValidateVariable(program_.Inputs()[i], input_vars[i])); + ORT_RETURN_IF_ERROR(ValidateVariable(program_.Inputs()[i], *input_vars[i])); #endif // check input dependencies with actual usages. - auto usage = input_vars[i].usage_; + auto usage = input_vars[i]->usage_; bool use_uniform = (usage & ShaderVariable::UseUniform) == ShaderVariable::UseUniform; auto dependency = program_.Inputs()[i].dependency; bool use_rank = (dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank; bool use_shape = (dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape; if (use_uniform) { - ORT_RETURN_IF_NOT((use_rank || input_vars[i].rank_ < 2) && !use_shape, + ORT_RETURN_IF_NOT((use_rank || input_vars[i]->rank_ < 2) && !use_shape, "When UseUniform is set in variable usage, the corresponding program input should depend on rank but not shape."); } else { ORT_RETURN_IF_NOT(use_shape, @@ -279,11 +280,11 @@ Status ShaderHelper::ValidateShapeForInputsAndOutputs() const { for (size_t i = 0; i < output_vars.size(); i++) { #ifndef NDEBUG // if debug build // Validate output shape - ORT_RETURN_IF_ERROR(ValidateVariable(program_.Outputs()[i], output_vars[i])); + ORT_RETURN_IF_ERROR(ValidateVariable(program_.Outputs()[i], *output_vars[i])); #endif // check output dependencies with actual usages. - auto usage = output_vars[i].usage_; + auto usage = output_vars[i]->usage_; bool use_uniform = (usage & ShaderVariable::UseUniform) == ShaderVariable::UseUniform; auto dependency = program_.Outputs()[i].dependency; bool use_shape = (dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape; @@ -356,11 +357,11 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha size_t variable_count = 0; const auto& input_vars = vars_[static_cast(ProgramVariableScope::Input)]; for (const auto& input : input_vars) { - ss << "@group(0) @binding(" << variable_count++ << ") var " << input.name_ << ": array<" << input.StorageType() << ">;\n"; + ss << "@group(0) @binding(" << variable_count++ << ") var " << input->name_ << ": array<" << input->StorageType() << ">;\n"; } const auto& output_vars = vars_[static_cast(ProgramVariableScope::Output)]; for (const auto& output : output_vars) { - ss << "@group(0) @binding(" << variable_count++ << ") var " << output.name_ << ": array<" << output.StorageType() << ">;\n"; + ss << "@group(0) @binding(" << variable_count++ << ") var " << output->name_ << ": array<" << output->StorageType() << ">;\n"; } // @@ -373,14 +374,18 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha shape_uniform_ranks.reserve(input_vars.size() + output_vars.size()); for (const auto& input : vars_[static_cast(ProgramVariableScope::Input)]) { - bool use_uniform = (input.usage_ & ShaderVariable::UseUniform) && (input.usage_ & ShaderVariable::UseShapeAndStride) && input.rank_ > 0; + bool use_uniform = (input->usage_ & ShaderVariable::UseUniform) && + (input->usage_ & ShaderVariable::UseShapeAndStride) && + input->rank_ > 0; use_any_shape_uniform |= use_uniform; - shape_uniform_ranks.push_back(use_uniform ? input.rank_ : 0); + shape_uniform_ranks.push_back(use_uniform ? input->rank_ : 0); } for (const auto& output : vars_[static_cast(ProgramVariableScope::Output)]) { - bool use_uniform = (output.usage_ & ShaderVariable::UseUniform) && (output.usage_ & ShaderVariable::UseShapeAndStride) && output.rank_ > 0; + bool use_uniform = (output->usage_ & ShaderVariable::UseUniform) && + (output->usage_ & ShaderVariable::UseShapeAndStride) && + output->rank_ > 0; use_any_shape_uniform |= use_uniform; - shape_uniform_ranks.push_back(use_uniform ? output.rank_ : 0); + shape_uniform_ranks.push_back(use_uniform ? output->rank_ : 0); } if (use_any_shape_uniform || std::any_of(program_.UniformVariables().cbegin(), @@ -419,20 +424,20 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha }; for (const auto& input : vars_[static_cast(ProgramVariableScope::Input)]) { - const size_t rank = input.rank_; - if (rank > 0 && (input.usage_ & ShaderVariable::Usage::UseUniform) && (input.usage_ & ShaderVariable::Usage::UseShapeAndStride)) { - std::string shape = input.name_ + "_shape"; - std::string stride = input.name_ + "_stride"; + const size_t rank = input->rank_; + if (rank > 0 && (input->usage_ & ShaderVariable::Usage::UseUniform) && (input->usage_ & ShaderVariable::Usage::UseShapeAndStride)) { + std::string shape = input->name_ + "_shape"; + std::string stride = input->name_ + "_stride"; append_uniform(shape, ProgramUniformVariableDataType::Uint32, rank); append_uniform(stride, ProgramUniformVariableDataType::Uint32, rank - 1); } } for (const auto& output : vars_[static_cast(ProgramVariableScope::Output)]) { - const size_t rank = output.rank_; - if (rank > 0 && (output.usage_ & ShaderVariable::Usage::UseUniform) && (output.usage_ & ShaderVariable::Usage::UseShapeAndStride)) { - std::string shape = output.name_ + "_shape"; - std::string stride = output.name_ + "_stride"; + const size_t rank = output->rank_; + if (rank > 0 && (output->usage_ & ShaderVariable::Usage::UseUniform) && (output->usage_ & ShaderVariable::Usage::UseShapeAndStride)) { + std::string shape = output->name_ + "_shape"; + std::string stride = output->name_ + "_stride"; append_uniform(shape, ProgramUniformVariableDataType::Uint32, rank); append_uniform(stride, ProgramUniformVariableDataType::Uint32, rank - 1); } @@ -455,7 +460,7 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha ss << "\n"; for (const auto& var_group : vars_) { for (const auto& var : var_group) { - var.Impl(ss); + var->Impl(ss); } } ss << "\n"; diff --git a/onnxruntime/core/providers/webgpu/shader_helper.h b/onnxruntime/core/providers/webgpu/shader_helper.h index ca1bf9ce7ff58..23c1ff42b0df5 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.h +++ b/onnxruntime/core/providers/webgpu/shader_helper.h @@ -174,7 +174,7 @@ class ShaderHelper final { const ProgramBase& program_; const ProgramMetadata& program_metadata_; - std::array, static_cast(ProgramVariableScope::Count)> vars_; + std::array>, static_cast(ProgramVariableScope::Count)> vars_; std::ostringstream additional_implementation_; std::ostringstream body_; From cd961c3a75d12014f472a41e6eca81fa0639f63c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 5 Sep 2024 23:59:39 -0700 Subject: [PATCH 54/77] simplify cache hint interface. --- onnxruntime/core/providers/webgpu/program.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index c48bdb1a4ff12..6b339af767f5e 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -245,9 +245,9 @@ class ProgramBase { // // set the cache hint for the program - template - ProgramBase& CacheHint(CacheHintArgs&&... args) { - cache_hint_ = absl::StrJoin(std::forward_as_tuple(std::forward(args)...), "|"); + template + ProgramBase& CacheHint(T&& hint) { + cache_hint_ = std::forward(hint); return *this; } From ddc2fbb7e948e21b566f2826b0e2c96e196db8ab Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 6 Sep 2024 00:43:04 -0700 Subject: [PATCH 55/77] revise expand --- .../core/providers/webgpu/tensor/expand.cc | 17 ++++++++--------- .../core/providers/webgpu/tensor/expand.h | 5 ++--- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc index 53991365d6543..9052095dec677 100644 --- a/onnxruntime/core/providers/webgpu/tensor/expand.cc +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -18,7 +18,7 @@ Status ExpandProgram::GenerateShaderCode(ShaderHelper& shader) const { ToProgramVariableDataType(Outputs()[0].tensor->GetElementType()), ShaderVariable::UseUniform); - shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), + shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size"), "let output_indices = ", output.OffsetToIndices("global_idx"), ";\n", "let input_offset = ", input.BroadcastedIndicesToOffset("output_indices", output), ";\n", output.SetByOffset("global_idx", input.GetByOffset("input_offset"))); @@ -30,20 +30,19 @@ Status Expand::ComputeInternal(ComputeContext& context) const { const auto* input_tensor = context.Input(0); const auto* input_shape_tensor = context.Input(1); - const auto* p_shape = input_shape_tensor->Data(); - TensorShapeVector output_dims{p_shape, p_shape + input_shape_tensor->Shape().Size()}; - TensorShape output_shape(output_dims); + auto output_dims = input_shape_tensor->DataAsSpan(); + TensorShape output_shape{}; ORT_RETURN_IF_ERROR(ComputeBroadcastOutputShape(Node().Name(), input_tensor->Shape(), output_dims, output_shape)); auto* output_tensor = context.Output(0, output_shape); - SafeInt vec_size = output_shape.Size(); - ExpandProgram program{"Expand"}; + uint32_t data_size = SafeInt(output_shape.Size()); + ExpandProgram program{}; program .Inputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) .Outputs({{output_tensor, ProgramTensorMetadataDependency::Rank}}) - .DispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .DispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) .UniformVariables({ - {static_cast(vec_size)}, + {data_size}, }); return context.RunProgram(program); } @@ -64,4 +63,4 @@ WEBGPU_EXPAND_VERSIONED_KERNEL(Expand, 8, 12, Expand, WebGpuSupportedFloatTypes( WEBGPU_EXPAND_KERNEL(Expand, 13, Expand, WebGpuSupportedFloatTypes()) } // namespace webgpu -}; // namespace onnxruntime +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.h b/onnxruntime/core/providers/webgpu/tensor/expand.h index a5c24f1fa4969..046520b479257 100644 --- a/onnxruntime/core/providers/webgpu/tensor/expand.h +++ b/onnxruntime/core/providers/webgpu/tensor/expand.h @@ -11,12 +11,11 @@ namespace webgpu { class ExpandProgram final : public Program { public: - ExpandProgram(const std::string& kernel_name) : Program{kernel_name} { - } + ExpandProgram() : Program{"Expand"} {} Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32}); }; class Expand final : public WebGpuKernel { From e8be835cae04ed21efa72a8e0d5ab417cbef7992 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 6 Sep 2024 00:50:54 -0700 Subject: [PATCH 56/77] revise unary --- .../core/providers/webgpu/math/unary_elementwise_ops.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index 079a192213775..630cfce486ca8 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -15,7 +15,7 @@ Status UnaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { ShaderVariable::UseUniform); shader.AppendImplementation(additional_impl_); shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), - "let a = ", input.GetByOffset("global_idx"), ";\n", + " let a = ", input.GetByOffset("global_idx"), ";\n ", output.SetByOffset("global_idx", expression_)); return Status::OK(); @@ -119,7 +119,7 @@ WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Sigmoid, 13, WebGpuSupportedFloatTypes()) constexpr char HardSigmoidImpl[] = R"( -fn hard_sigmoid_v(v: x_value_t) -> x_value_t { +fn hard_sigmoid_v(v: vec4) -> vec4 { let alpha = x_element_t(uniforms.f32_attr[0]); let beta_v = vec4(uniforms.f32_attr[1]); return max(vec4(0.0), @@ -129,7 +129,7 @@ fn hard_sigmoid_v(v: x_value_t) -> x_value_t { class HardSigmoid final : public UnaryElementwise { public: HardSigmoid(const OpKernelInfo& info) - : UnaryElementwise{info, "HardSigmoid", "hard_sigmoid_v(a)", HardSigmoidImpl, ShaderVariable::UseElementTypeAlias | ShaderVariable::UseValueTypeAlias} { + : UnaryElementwise{info, "HardSigmoid", "hard_sigmoid_v(a)", HardSigmoidImpl, ShaderVariable::UseElementTypeAlias} { // attr[0] is alpha, attr[1] is beta info.GetAttrOrDefault("alpha", attr, 0.2f); info.GetAttrOrDefault("beta", attr + 1, 0.5f); From bd7d592386932b5dd55793dd4a44328808114269 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 6 Sep 2024 15:28:35 -0700 Subject: [PATCH 57/77] Elu/Relu/LeakyRelu/ThresholdedRelu/Gelu --- .../webgpu/math/unary_elementwise_ops.cc | 87 +++++++++++++++++-- .../webgpu/math/unary_elementwise_ops.h | 2 + .../webgpu/webgpu_execution_provider.cc | 19 ++-- 3 files changed, 94 insertions(+), 14 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index 630cfce486ca8..baa92fdc5c3dc 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -37,6 +37,9 @@ Status UnaryElementwise::ComputeInternal(ComputeContext& context) const { .UniformVariables({ {static_cast(vec_size)}, }); + if (!cache_hint.empty()) { + program.CacheHint(cache_hint); + } ORT_RETURN_IF_ERROR(ConfigureProgram(program)); return context.RunProgram(program); } @@ -172,7 +175,13 @@ WEBGPU_ELEMENTWISE_KERNEL(Cosh, 9, WebGpuSupportedFloatTypes()) // built-in function tanh() does not work with large input (f32 88.7 or f16 11.09) // https://github.com/gpuweb/gpuweb/issues/4458 -WEBGPU_ELEMENTWISE_IMPL(Tanh, "sign(a) * (1 - exp(-2 * abs(a))) / (1 + exp(-2 * abs(a)))") +constexpr char TanhImpl[] = R"( +fn tanh_v(a: x_value_t) -> x_value_t { + let expr = exp(-2 * abs(a)); + return sign(a) * (1 - expr) / (1 + expr); +} +)"; +WEBGPU_ELEMENTWISE_IMPL(Tanh, "tanh_v(a)", TanhImpl, ShaderVariable::UseValueTypeAlias) WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Tanh, 13, WebGpuSupportedFloatTypes()) @@ -193,10 +202,78 @@ WEBGPU_ELEMENTWISE_KERNEL(Atanh, 9, WebGpuSupportedFloatTypes()) // todo: clip -// constexpr char EluImpl[] = R"( -//)"; -// -// WEBGPU_ELEMENTWISE_IMPL(Elu, "elu_v(a)", ) +class LinearUnit : public UnaryElementwise { + public: + LinearUnit(const OpKernelInfo& info, + const std::string& kernel_name, + const std::string& expression, + const std::string& additional_impl, + float default_alpha) + : UnaryElementwise{info, kernel_name, expression, additional_impl, ShaderVariable::UseElementTypeAlias} { + info.GetAttrOrDefault("alpha", &alpha_, default_alpha); + } + + Status ConfigureProgram(UnaryElementwiseProgram& program) const override { + program.UniformVariables({alpha_, {}}); + return Status::OK(); + } + + protected: + float alpha_; +}; + +#define WEBGPU_LU_IMPL(OP_TYPE, ...) \ + class OP_TYPE final : public LinearUnit { \ + public: \ + OP_TYPE(const OpKernelInfo& info) : LinearUnit{info, #OP_TYPE, __VA_ARGS__} {} \ + }; + +constexpr char EluImpl[] = R"( +fn elu(a: x_element_t) -> x_element_t { + let alpha = x_element_t(uniforms.f32_attr); + return select((exp(a) - 1.0) * alpha, a, a >= 0.0); +} + +fn elu_v(v: vec4) -> vec4 { + return vec4(elu(v.x), elu(v.y), elu(v.z), elu(v.w)); +} +)"; + +WEBGPU_LU_IMPL(Elu, "elu_v(a)", EluImpl, 1.0) +WEBGPU_ELEMENTWISE_KERNEL(Elu, 6, WebGpuSupportedFloatTypes()) + +// TODO: support attribute "approximate" +class Gelu : public UnaryElementwise { + public: + Gelu(const OpKernelInfo& info) + : UnaryElementwise{info, + "Gelu", + info.GetAttrOrDefault("approximate", "none") == "tanh" ? TanhBasedImpl : DefaultImpl, + info.GetAttrOrDefault("approximate", "none") == "tanh" ? TanhImpl : ErfImpl, + ShaderVariable::UseValueTypeAlias} { + cache_hint = info.GetAttrOrDefault("approximate", "none"); + } + + constexpr static const char DefaultImpl[] = "0.5 * a * (1.0 + erf_v(a * 0.7071067811865475))"; + constexpr static const char TanhBasedImpl[] = "0.5 * a * (1 + tanh_v(0.7978845608028654 * (a + 0.044715 * a * a * a)))"; + + protected: + float alpha_; +}; + +WEBGPU_ELEMENTWISE_KERNEL(Gelu, 20, WebGpuSupportedFloatTypes()) + +WEBGPU_ELEMENTWISE_IMPL(Relu, "select(x_value_t(0), a, a > x_value_t(0))", "", ShaderVariable::UseValueTypeAlias) +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(Relu, 14, WebGpuSupportedFloatTypes()) + +WEBGPU_LU_IMPL(LeakyRelu, "select(x_element_t(uniforms.f32_attr) * a, a, a >= vec4(0))", "", 0.01f) +WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(LeakyRelu, 6, 15, WebGpuSupportedFloatTypes()) +WEBGPU_ELEMENTWISE_KERNEL(LeakyRelu, 16, WebGpuSupportedFloatTypes()) + +WEBGPU_LU_IMPL(ThresholdedRelu, "select(vec4(0), a, a > vec4(uniforms.f32_attr))", "", 1.0f) +WEBGPU_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, WebGpuSupportedFloatTypes()) // TODO: add other unary elementwise ops diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h index 2d084bf227f72..711b0b0a6044c 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h @@ -45,6 +45,8 @@ class UnaryElementwise : public WebGpuKernel { additional_usage_{usage} {} protected: + std::string cache_hint; + Status ComputeInternal(ComputeContext& context) const final; virtual Status ConfigureProgram(UnaryElementwiseProgram& program) const { program.UniformVariables({{}, {}}); // empty for both float and int attribute(s) diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 1ee7a51618f7f..decc74b59cae6 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -134,6 +134,8 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Relu); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 15, LeakyRelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, LeakyRelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, ThresholdedRelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 20, Gelu); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceMax); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, ReduceMax); @@ -186,8 +188,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceLogSumExp); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceLogSumExp); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, ThresholdedRelu); - class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 12, Add); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Add); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Add); @@ -442,13 +442,14 @@ std::unique_ptr RegisterKernels() { // KERNEL_CREATE_INFO_VERSIONED(11, 11, Clip), // KERNEL_CREATE_INFO_VERSIONED(12, 12, Clip), // KERNEL_CREATE_INFO(13, Clip), - // KERNEL_CREATE_INFO(6, Elu), - // KERNEL_CREATE_INFO_VERSIONED(6, 12, Relu), - // KERNEL_CREATE_INFO_VERSIONED(13, 13, Relu), - // KERNEL_CREATE_INFO(14, Relu), - // KERNEL_CREATE_INFO_VERSIONED(6, 15, LeakyRelu), - // KERNEL_CREATE_INFO(16, LeakyRelu), - // KERNEL_CREATE_INFO(10, ThresholdedRelu), + KERNEL_CREATE_INFO(6, Elu), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Relu), + KERNEL_CREATE_INFO_VERSIONED(13, 13, Relu), + KERNEL_CREATE_INFO(14, Relu), + KERNEL_CREATE_INFO_VERSIONED(6, 15, LeakyRelu), + KERNEL_CREATE_INFO(16, LeakyRelu), + KERNEL_CREATE_INFO(10, ThresholdedRelu), + KERNEL_CREATE_INFO(20, Gelu), // // binary - math // KERNEL_CREATE_INFO_VERSIONED(7, 12, Add), From 601e50f142478db99923453112c051918eef2a07 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 6 Sep 2024 15:40:03 -0700 Subject: [PATCH 58/77] remove unused field in class Gelu --- .../core/providers/webgpu/math/unary_elementwise_ops.cc | 3 --- 1 file changed, 3 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index baa92fdc5c3dc..2c015524e1ac7 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -256,9 +256,6 @@ class Gelu : public UnaryElementwise { constexpr static const char DefaultImpl[] = "0.5 * a * (1.0 + erf_v(a * 0.7071067811865475))"; constexpr static const char TanhBasedImpl[] = "0.5 * a * (1 + tanh_v(0.7978845608028654 * (a + 0.044715 * a * a * a)))"; - - protected: - float alpha_; }; WEBGPU_ELEMENTWISE_KERNEL(Gelu, 20, WebGpuSupportedFloatTypes()) From 8f36da219dab7073f7a054e4b24ade0ea934d39c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 6 Sep 2024 16:53:12 -0700 Subject: [PATCH 59/77] remove out-of-dated comments --- onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index 2c015524e1ac7..d28769f080719 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -242,7 +242,6 @@ fn elu_v(v: vec4) -> vec4 { WEBGPU_LU_IMPL(Elu, "elu_v(a)", EluImpl, 1.0) WEBGPU_ELEMENTWISE_KERNEL(Elu, 6, WebGpuSupportedFloatTypes()) -// TODO: support attribute "approximate" class Gelu : public UnaryElementwise { public: Gelu(const OpKernelInfo& info) From 72ebd856efc9a1088105d5f4a190ef07ecf3110b Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 6 Sep 2024 19:01:33 -0700 Subject: [PATCH 60/77] Clip --- .../webgpu/math/unary_elementwise_ops.cc | 94 ++++++++++++++++--- .../webgpu/math/unary_elementwise_ops.h | 10 +- 2 files changed, 84 insertions(+), 20 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index d28769f080719..ceaae426ddde6 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include + #include "core/providers/webgpu/math/unary_elementwise_ops.h" #include "core/providers/webgpu/webgpu_supported_types.h" @@ -40,7 +42,7 @@ Status UnaryElementwise::ComputeInternal(ComputeContext& context) const { if (!cache_hint.empty()) { program.CacheHint(cache_hint); } - ORT_RETURN_IF_ERROR(ConfigureProgram(program)); + ORT_RETURN_IF_ERROR(ConfigureProgram(context, program)); return context.RunProgram(program); } @@ -62,6 +64,12 @@ Status UnaryElementwise::ComputeInternal(ComputeContext& context) const { KernelDefBuilder().TypeConstraint("T", TYPE), \ OP_TYPE_AND_CLASS_NAME); +#define WEBGPU_ELEMENTWISE_BOOLEAN_KERNEL(OP_TYPE_AND_CLASS_NAME, VERSION) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE_AND_CLASS_NAME, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + OP_TYPE_AND_CLASS_NAME); + // // math // @@ -123,8 +131,8 @@ WEBGPU_ELEMENTWISE_KERNEL(Sigmoid, 13, WebGpuSupportedFloatTypes()) constexpr char HardSigmoidImpl[] = R"( fn hard_sigmoid_v(v: vec4) -> vec4 { - let alpha = x_element_t(uniforms.f32_attr[0]); - let beta_v = vec4(uniforms.f32_attr[1]); + let alpha = x_element_t(uniforms.attr[0]); + let beta_v = vec4(uniforms.attr[1]); return max(vec4(0.0), min(vec4(1.0), alpha * v + beta_v)); } @@ -138,8 +146,8 @@ class HardSigmoid final : public UnaryElementwise { info.GetAttrOrDefault("beta", attr + 1, 0.5f); } - Status ConfigureProgram(UnaryElementwiseProgram& program) const override { - program.UniformVariables({gsl::make_span(attr, 2), {}}); + Status ConfigureProgram(const ComputeContext& /*context*/, UnaryElementwiseProgram& program) const override { + program.UniformVariables({gsl::make_span(attr, 2)}); return Status::OK(); } @@ -194,14 +202,72 @@ WEBGPU_ELEMENTWISE_KERNEL(Acosh, 9, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_IMPL(Atanh, "atanh(a)") WEBGPU_ELEMENTWISE_KERNEL(Atanh, 9, WebGpuSupportedFloatTypes()) -// todo: logical ops +WEBGPU_ELEMENTWISE_IMPL(Not, "!a") +WEBGPU_ELEMENTWISE_BOOLEAN_KERNEL(Not, 1) + +// No longer support Clip < opset 11 (where min and max are attributes) +// +// Use template class for "Clip" because the implementation is significantly different between float16 and float32 +template +class Clip final : public UnaryElementwise { + public: + Clip(const OpKernelInfo& info) + : UnaryElementwise{info, + "Clip", + std::is_same_v ? ClipF16Impl : ClipImpl, + "", ShaderVariable::UseElementTypeAlias} {} + + Status ConfigureProgram(const ComputeContext& context, UnaryElementwiseProgram& program) const override { + const auto* clip_min_tensor = context.Input(1); + const auto* clip_max_tensor = context.Input(2); + const T attr[] = {clip_min_tensor->Data()[0], + clip_max_tensor->Data()[0]}; + if constexpr (std::is_same_v) { + // F16: stores span as a single float + float encoded_value = *reinterpret_cast(attr); + program.UniformVariables({encoded_value}); + } else { + static_assert(sizeof(T) == sizeof(float), "T must be f32, i32 or u32"); + // stores span as-is + program.UniformVariables({gsl::make_span(attr, 2)}); + } + return Status::OK(); + } + + // uniforms.attr is a f32 value. It is encoded as a float for 2 f16 values. + // bitcast>(uniforms.attr)[0] is clip_min, bitcast>(uniforms.attr)[1] is clip_max + constexpr static const char ClipF16Impl[] = "clamp(a, vec4(bitcast>(uniforms.attr)[0]), vec4(bitcast>(uniforms.attr)[1]))"; + + // the size of element of uniforms.attr should be the same as x_element_t. use bitcast to convert between them + // uniforms.attr[0] is clip_min, uniforms.attr[1] is clip_max + constexpr static const char ClipImpl[] = "clamp(a, vec4(bitcast(uniforms.attr[0])), vec4(bitcast(uniforms.attr[1])))"; +}; +#define WEBGPU_CLIP_KERNEL(TYPE) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(Clip, kOnnxDomain, 11, 11, TYPE, kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPU, 1) \ + .InputMemoryType(OrtMemTypeCPU, 2), \ + Clip) \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(Clip, kOnnxDomain, 12, 12, TYPE, kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPU, 1) \ + .InputMemoryType(OrtMemTypeCPU, 2), \ + Clip) \ + ONNX_OPERATOR_TYPED_KERNEL_EX(Clip, kOnnxDomain, 13, TYPE, kWebGpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPU, 1) \ + .InputMemoryType(OrtMemTypeCPU, 2), \ + Clip); +WEBGPU_CLIP_KERNEL(float) +WEBGPU_CLIP_KERNEL(MLFloat16) // // activation // -// todo: clip - class LinearUnit : public UnaryElementwise { public: LinearUnit(const OpKernelInfo& info, @@ -213,8 +279,8 @@ class LinearUnit : public UnaryElementwise { info.GetAttrOrDefault("alpha", &alpha_, default_alpha); } - Status ConfigureProgram(UnaryElementwiseProgram& program) const override { - program.UniformVariables({alpha_, {}}); + Status ConfigureProgram(const ComputeContext& /*context*/, UnaryElementwiseProgram& program) const override { + program.UniformVariables({alpha_}); return Status::OK(); } @@ -230,7 +296,7 @@ class LinearUnit : public UnaryElementwise { constexpr char EluImpl[] = R"( fn elu(a: x_element_t) -> x_element_t { - let alpha = x_element_t(uniforms.f32_attr); + let alpha = x_element_t(uniforms.attr); return select((exp(a) - 1.0) * alpha, a, a >= 0.0); } @@ -264,14 +330,12 @@ WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Relu, 14, WebGpuSupportedFloatTypes()) -WEBGPU_LU_IMPL(LeakyRelu, "select(x_element_t(uniforms.f32_attr) * a, a, a >= vec4(0))", "", 0.01f) +WEBGPU_LU_IMPL(LeakyRelu, "select(x_element_t(uniforms.attr) * a, a, a >= vec4(0))", "", 0.01f) WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(LeakyRelu, 6, 15, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(LeakyRelu, 16, WebGpuSupportedFloatTypes()) -WEBGPU_LU_IMPL(ThresholdedRelu, "select(vec4(0), a, a > vec4(uniforms.f32_attr))", "", 1.0f) +WEBGPU_LU_IMPL(ThresholdedRelu, "select(vec4(0), a, a > vec4(uniforms.attr))", "", 1.0f) WEBGPU_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, WebGpuSupportedFloatTypes()) -// TODO: add other unary elementwise ops - } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h index 711b0b0a6044c..d870278f4c090 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h @@ -19,9 +19,9 @@ class UnaryElementwiseProgram final : public Program { Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( - {"vec_size", ProgramUniformVariableDataType::Uint32}, // output size - {"f32_attr", ProgramUniformVariableDataType::Float32}, // float type attribute(s) - {"i32_attr", ProgramUniformVariableDataType::Int32}); // int type attribute(s) + {"vec_size", ProgramUniformVariableDataType::Uint32}, // output size + {"attr", ProgramUniformVariableDataType::Float32}); // float type attribute(s) + // TODO: add u32/i32 attribute(s) if needed private: std::string_view expression_; @@ -48,8 +48,8 @@ class UnaryElementwise : public WebGpuKernel { std::string cache_hint; Status ComputeInternal(ComputeContext& context) const final; - virtual Status ConfigureProgram(UnaryElementwiseProgram& program) const { - program.UniformVariables({{}, {}}); // empty for both float and int attribute(s) + virtual Status ConfigureProgram(const ComputeContext& /*context*/, UnaryElementwiseProgram& program) const { + program.UniformVariables({{}}); // empty for attribute(s) return Status::OK(); } From a3244aeb685c1d048b684672429ae5b17af343de Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 6 Sep 2024 23:24:11 -0700 Subject: [PATCH 61/77] fix rank in shader helper --- onnxruntime/core/providers/webgpu/shader_helper.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index f43806d1406c2..7e6130dd4e917 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -160,7 +160,7 @@ Status ValidateVariableShape(const TensorShape& origin_shape, "Tensor original shape ", origin_shape, " cannot reshape to ", override_shape, " with component number ", num_components); } else if (num_components > 1) { // if shape is not overriden, assert origin_shape[-1] % 4 == 0 - ORT_RETURN_IF_NOT(origin_shape.Size() > 0 && origin_shape[origin_shape.Size() - 1] % num_components == 0, + ORT_RETURN_IF_NOT(origin_shape.Size() > 0 && origin_shape[origin_shape.NumDimensions() - 1] % num_components == 0, "Tensor original shape ", origin_shape, " cannot be divided by component number ", num_components); } From 5a2ae8c54347c0a51bad33d1ef55b9c1077a098c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 9 Sep 2024 01:56:28 -0700 Subject: [PATCH 62/77] fix shader variable --- onnxruntime/core/providers/webgpu/shader_variable.cc | 4 ++-- onnxruntime/core/providers/webgpu/shader_variable.h | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index f5fc236aca71d..07c5915be466b 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -172,8 +172,8 @@ void ShaderVariable::Impl(std::ostringstream& ss) const { // Implementation of "fn {res_name}_bi2o_{name}" if (usage_ & UseBroadcastedIndicesToOffset) { if (rank_ > 0) { - for (const auto& iter : broadcasted_to_) { - const auto& broadcasted_result = iter.get(); + for (const auto& broadcasted_result_ptr : broadcasted_to_) { + const auto& broadcasted_result = *broadcasted_result_ptr; SS("fn ", broadcasted_result.name_, "_bi2o_", name_, "(indices : ", broadcasted_result.indices_type_, ")->u32 {\n"); if (rank_ == 1) { SS(" return ", broadcasted_result.IndicesGet("indices", broadcasted_result.rank_ - 1), " % ", shape, ";\n"); diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index aa186d58740e3..d4281dd31d65c 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include "core/framework/tensor_shape.h" @@ -60,7 +61,7 @@ class ShaderVariable { ShaderVariable& operator=(ShaderVariable&&) = default; // get the name of the variable. - std::string_view Name() const; + inline std::string_view Name() const { return name_; } // create a WGSL expression ({varname}_indices_t) for getting indices from offset. // \param offset: a WGSL expression (u32) representing the offset. @@ -147,7 +148,7 @@ class ShaderVariable { TensorShape dims_; mutable Usage usage_; - mutable std::vector> broadcasted_to_; + mutable std::set broadcasted_to_; // unlike storage/element/value type, indices type is not a string view to a constant string. so we need to store it. std::string indices_type_; @@ -202,7 +203,7 @@ inline std::string ShaderVariable::IndicesToOffset(std::string_view indices_expr inline std::string ShaderVariable::BroadcastedIndicesToOffset(std::string_view indices_expr, const ShaderVariable& broadcasted_result) const { usage_ |= UseBroadcastedIndicesToOffset | UseShapeAndStride; - broadcasted_to_.push_back(broadcasted_result); + broadcasted_to_.insert(&broadcasted_result); return rank_ == 0 ? "0" : MakeStringWithClassicLocale(broadcasted_result.name_, "_bi2o_", name_, '(', indices_expr, ')'); From aa54ff8012d45ad6dc7ade798957cefa971397f8 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 9 Sep 2024 03:47:09 -0700 Subject: [PATCH 63/77] move components number from variable to program --- .../webgpu/math/unary_elementwise_ops.cc | 12 +-- onnxruntime/core/providers/webgpu/program.h | 74 ++++++++++++------- .../core/providers/webgpu/shader_helper.cc | 39 ++++++---- .../core/providers/webgpu/shader_helper.h | 4 - .../core/providers/webgpu/tensor/expand.cc | 8 +- 5 files changed, 77 insertions(+), 60 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index ceaae426ddde6..8d8f855ec20ae 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -9,12 +9,8 @@ namespace onnxruntime { namespace webgpu { Status UnaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& input = shader.AddInput("x", - ToProgramVariableDataType(Inputs()[0].tensor->GetElementType(), 4), - ShaderVariable::UseUniform | additional_usage_); - const auto& output = shader.AddOutput("y", - ToProgramVariableDataType(Outputs()[0].tensor->GetElementType(), 4), - ShaderVariable::UseUniform); + const auto& input = shader.AddInput("x", ShaderVariable::UseUniform | additional_usage_); + const auto& output = shader.AddOutput("y", ShaderVariable::UseUniform); shader.AppendImplementation(additional_impl_); shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), " let a = ", input.GetByOffset("global_idx"), ";\n ", @@ -33,8 +29,8 @@ Status UnaryElementwise::ComputeInternal(ComputeContext& context) const { SafeInt vec_size = (size + 3) / 4; UnaryElementwiseProgram program{kernel_name_, expression_, additional_impl_, additional_usage_}; program - .Inputs({{input_tensor, ProgramTensorMetadataDependency::Type, {vec_size}}}) - .Outputs({{output_tensor, ProgramTensorMetadataDependency::None, {vec_size}}}) + .Inputs({{input_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}}) + .Outputs({{output_tensor, ProgramTensorMetadataDependency::None, {vec_size}, 4}}) .DispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) .UniformVariables({ {static_cast(vec_size)}, diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index 6b339af767f5e..38e7a842aa32e 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -163,34 +163,6 @@ inline ProgramTensorMetadataDependency& operator&=(ProgramTensorMetadataDependen return (ProgramTensorMetadataDependency&)((int&)a &= (int&)b); } -struct ProgramInput { - ProgramInput(const Tensor* tensor) - : ProgramInput{tensor, ProgramTensorMetadataDependency::TypeAndRank} {} - ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency) - : tensor{tensor}, dependency{dependency}, use_override_shape{false}, override_shape{} {} - ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape) - : tensor{tensor}, dependency{dependency}, use_override_shape{true}, override_shape{override_shape} {} - - const Tensor* tensor; - ProgramTensorMetadataDependency dependency; - bool use_override_shape; - TensorShape override_shape; -}; - -struct ProgramOutput { - ProgramOutput(Tensor* tensor) - : ProgramOutput{tensor, ProgramTensorMetadataDependency::None} {} - ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency) - : tensor{tensor}, dependency{dependency}, use_override_shape{false}, override_shape{} {} - ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape) - : tensor{tensor}, dependency{dependency}, use_override_shape{true}, override_shape{override_shape} {} - - Tensor* tensor; - ProgramTensorMetadataDependency dependency; - bool use_override_shape; - TensorShape override_shape; -}; - constexpr SafeInt WORKGROUP_SIZE = 64; // represents the scope of a variable in a shader program. @@ -232,6 +204,52 @@ int NumberOfComponents(ProgramVariableDataType type); ProgramVariableDataType ToProgramVariableDataType(int32_t element_type, int component = 1); +struct ProgramInput { + ProgramInput(const Tensor* tensor) + : ProgramInput{tensor, ProgramTensorMetadataDependency::TypeAndRank} {} + ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, int component = 1) + : tensor{tensor}, + dependency{dependency}, + var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + use_override_shape{false}, + override_shape{} {} + ProgramInput(const Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component) + : tensor{tensor}, + dependency{dependency}, + var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + use_override_shape{true}, + override_shape{override_shape} {} + + const Tensor* tensor; + ProgramTensorMetadataDependency dependency; + ProgramVariableDataType var_type; + bool use_override_shape; + TensorShape override_shape; +}; + +struct ProgramOutput { + ProgramOutput(Tensor* tensor) + : ProgramOutput{tensor, ProgramTensorMetadataDependency::None} {} + ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, int component = 1) + : tensor{tensor}, + dependency{dependency}, + var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + use_override_shape{false}, + override_shape{} {} + ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dependency, const TensorShape& override_shape, int component) + : tensor{tensor}, + dependency{dependency}, + var_type{ToProgramVariableDataType(tensor->GetElementType(), component)}, + use_override_shape{true}, + override_shape{override_shape} {} + + Tensor* tensor; + ProgramTensorMetadataDependency dependency; + ProgramVariableDataType var_type; + bool use_override_shape; + TensorShape override_shape; +}; + namespace detail { class ProgramWrapper; } diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index 7e6130dd4e917..cd21f4752f300 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -27,9 +27,7 @@ ShaderHelper::ShaderHelper(const ProgramBase& program, dispatch_group_size_y_{dispatch_group_size_y}, dispatch_group_size_z_{dispatch_group_size_z}, program_{program}, - program_metadata_{program_metadata}, - use_f16_{false} { -} + program_metadata_{program_metadata} {} Status ShaderHelper::Init() { // dispatch group size is normalized so no need to validate it here @@ -80,24 +78,24 @@ Status ShaderHelper::Init() { return Status::OK(); } -const ShaderVariable& ShaderHelper::AddInput(const std::string& name, ProgramVariableDataType type, ShaderVariable::Usage usage) { +const ShaderVariable& ShaderHelper::AddInput(const std::string& name, ShaderVariable::Usage usage) { const size_t input_index = vars_[std::underlying_type::type(ProgramVariableScope::Input)].size(); ORT_ENFORCE(input_index < program_.Inputs().size(), "Too many inputs in the program (", program_.Inputs().size(), ")"); const auto& dims = program_.Inputs()[input_index].use_override_shape ? program_.Inputs()[input_index].override_shape : program_.Inputs()[input_index].tensor->Shape(); - return AddVariableImpl(ProgramVariableScope::Input, name, type, usage, dims); + return AddVariableImpl(ProgramVariableScope::Input, name, usage, dims); } -const ShaderVariable& ShaderHelper::AddOutput(const std::string& name, ProgramVariableDataType type, ShaderVariable::Usage usage) { +const ShaderVariable& ShaderHelper::AddOutput(const std::string& name, ShaderVariable::Usage usage) { const size_t output_index = vars_[std::underlying_type::type(ProgramVariableScope::Output)].size(); ORT_ENFORCE(output_index < program_.Outputs().size(), "Too many outputs in the program (", program_.Outputs().size(), ")"); const auto& dims = program_.Outputs()[output_index].use_override_shape ? program_.Outputs()[output_index].override_shape : program_.Outputs()[output_index].tensor->Shape(); - return AddVariableImpl(ProgramVariableScope::Output, name, type, usage, dims); + return AddVariableImpl(ProgramVariableScope::Output, name, usage, dims); } #ifndef NDEBUG // if debug build @@ -200,7 +198,6 @@ Status ValidateVariableDependency(ProgramTensorMetadataDependency dependency, Sh const ShaderVariable& ShaderHelper::AddVariableImpl(ProgramVariableScope scope, const std::string& name, - ProgramVariableDataType type, ShaderVariable::Usage usage, const TensorShape& dims) { if (scope == ProgramVariableScope::Input || scope == ProgramVariableScope::Output) { @@ -210,15 +207,20 @@ const ShaderVariable& ShaderHelper::AddVariableImpl(ProgramVariableScope scope, "Too many storage buffers in shader. Max is ", limits_.maxStorageBuffersPerShaderStage); } - if (type == ProgramVariableDataType::Float16 || type == ProgramVariableDataType::Vec2Float16 || type == ProgramVariableDataType::Vec4Float16) { - use_f16_ = true; - } + auto& vars = vars_[std::underlying_type::type(scope)]; + ProgramVariableDataType type = ProgramVariableDataType::InvalidType; - if (scope == ProgramVariableScope::Local) { + if (scope == ProgramVariableScope::Input) { + const auto& input = program_.Inputs()[vars.size()]; + type = input.var_type; + } else if (scope == ProgramVariableScope::Output) { + const auto& output = program_.Outputs()[vars.size()]; + type = output.var_type; + } else { ORT_NOT_IMPLEMENTED("Local variables are not supported yet."); } - const auto& var = vars_[std::underlying_type::type(scope)].emplace_back(std::make_unique(name, type, usage, dims)); + const auto& var = vars.emplace_back(std::make_unique(name, type, usage, dims)); return *var; } @@ -311,7 +313,16 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha // // Section feature enabling // - if (use_f16_) { + if (std::any_of(program_.Inputs().begin(), + program_.Inputs().end(), + [](const ProgramInput& input) { + return input.tensor->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + }) || + std::any_of(program_.Outputs().begin(), + program_.Outputs().end(), + [](const ProgramOutput& output) { + return output.tensor->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + })) { ORT_RETURN_IF_NOT(device_.HasFeature(wgpu::FeatureName::ShaderF16), "Program ", program_.Name(), " requires f16 but the device does not support it."); ss << "enable f16;\n"; if (device_.HasFeature(wgpu::FeatureName::SubgroupsF16)) { diff --git a/onnxruntime/core/providers/webgpu/shader_helper.h b/onnxruntime/core/providers/webgpu/shader_helper.h index 23c1ff42b0df5..08ff111f8a690 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.h +++ b/onnxruntime/core/providers/webgpu/shader_helper.h @@ -81,14 +81,12 @@ class ShaderHelper final { // // depending on the usage of the variable, additional code may be generated. const ShaderVariable& AddInput(const std::string& name, - ProgramVariableDataType type, ShaderVariable::Usage usage = ShaderVariable::UseIndicesTypeAlias | ShaderVariable::UseValueTypeAlias | ShaderVariable::UseUniform); // Add an output variable to the shader. // // depending on the usage of the variable, additional code may be generated. const ShaderVariable& AddOutput(const std::string& name, - ProgramVariableDataType type, ShaderVariable::Usage usage = ShaderVariable::UseIndicesTypeAlias | ShaderVariable::UseValueTypeAlias | ShaderVariable::UseUniform); // Append additional implementation code to the shader. @@ -140,7 +138,6 @@ class ShaderHelper final { const ShaderVariable& AddVariableImpl(ProgramVariableScope scope, const std::string& name, - ProgramVariableDataType type, ShaderVariable::Usage usage, const TensorShape& dims); @@ -178,7 +175,6 @@ class ShaderHelper final { std::ostringstream additional_implementation_; std::ostringstream body_; - bool use_f16_ = false; bool body_set_ = false; }; diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc index 9052095dec677..82451c9398243 100644 --- a/onnxruntime/core/providers/webgpu/tensor/expand.cc +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -11,12 +11,8 @@ namespace onnxruntime { namespace webgpu { Status ExpandProgram::GenerateShaderCode(ShaderHelper& shader) const { - const auto& input = shader.AddInput("input", - ToProgramVariableDataType(Inputs()[0].tensor->GetElementType()), - ShaderVariable::UseUniform); - const auto& output = shader.AddOutput("output", - ToProgramVariableDataType(Outputs()[0].tensor->GetElementType()), - ShaderVariable::UseUniform); + const auto& input = shader.AddInput("input", ShaderVariable::UseUniform); + const auto& output = shader.AddOutput("output", ShaderVariable::UseUniform); shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size"), "let output_indices = ", output.OffsetToIndices("global_idx"), ";\n", From 969384d23c2ea6043d9450b8007dc59455674f4e Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 9 Sep 2024 14:04:19 -0700 Subject: [PATCH 64/77] mark components in cache key --- onnxruntime/core/providers/webgpu/program.cc | 24 +++++++++++++++++++ onnxruntime/core/providers/webgpu/program.h | 3 +++ .../providers/webgpu/program_cache_key.cc | 11 +++++---- 3 files changed, 33 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc index 4a5785dc4def1..d4a2b24172d07 100644 --- a/onnxruntime/core/providers/webgpu/program.cc +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -82,6 +82,30 @@ std::ostream& operator<<(std::ostream& os, ProgramTensorMetadataDependency dep) return os; } +#ifndef NDEBUG +constexpr std::string_view ProgramVariableDataTypeName[] = { + "f32", // f32 + "f32x2", // vec2f32 + "f32x4", // vec4f32 + "f16", // f16 + "f16x2", // vec2f16 + "f16x4", // vec4f16 + "i32", // i32 + "i32x2", // vec2i32 + "i32x4", // vec4i32 + "u32", // u32 + "u32x2", // vec2u32 + "u32x4", // vec4u32 + "i64", // int64 + "u64", // uint64 + "boolx4", // vec4bool +}; +std::ostream& operator<<(std::ostream& os, ProgramVariableDataType type) { + os << ProgramVariableDataTypeName[std::underlying_type::type(type)]; + return os; +} +#endif + int NumberOfComponents(ProgramVariableDataType type) { switch (type) { case ProgramVariableDataType::Float32: diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index 38e7a842aa32e..e162cddbb6408 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -199,6 +199,9 @@ enum class ProgramVariableDataType { Uint64, Vec4Bool, }; +#ifndef NDEBUG +std::ostream& operator<<(std::ostream& os, ProgramVariableDataType); +#endif int NumberOfComponents(ProgramVariableDataType type); diff --git a/onnxruntime/core/providers/webgpu/program_cache_key.cc b/onnxruntime/core/providers/webgpu/program_cache_key.cc index c6ab16a73423d..09a536f7916b2 100644 --- a/onnxruntime/core/providers/webgpu/program_cache_key.cc +++ b/onnxruntime/core/providers/webgpu/program_cache_key.cc @@ -9,7 +9,8 @@ namespace onnxruntime { namespace webgpu { namespace { -void AppendTensorInfo(std::ostringstream& ss, const Tensor& tensor, ProgramTensorMetadataDependency dependency, bool& first) { +// append the info of an input or output to the cachekey +void AppendTensorInfo(std::ostringstream& ss, const Tensor& tensor, ProgramVariableDataType var_type, ProgramTensorMetadataDependency dependency, bool& first) { if (first) { first = false; } else { @@ -17,9 +18,9 @@ void AppendTensorInfo(std::ostringstream& ss, const Tensor& tensor, ProgramTenso } if ((dependency & ProgramTensorMetadataDependency::Type) == ProgramTensorMetadataDependency::Type) { #ifndef NDEBUG // if debug build - ss << DataTypeImpl::ToString(tensor.DataType()); + ss << var_type; #else - ss << output.tensor->GetElementType(); + ss << static_cast(var_type); #endif ss << ';'; } @@ -87,13 +88,13 @@ std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_disp ss << ":" D("Inputs="); first = true; for (const auto& input : program.Inputs()) { - AppendTensorInfo(ss, *input.tensor, input.dependency, first); + AppendTensorInfo(ss, *input.tensor, input.var_type, input.dependency, first); } ss << ":" D("Outputs="); first = true; for (const auto& output : program.Outputs()) { - AppendTensorInfo(ss, *output.tensor, output.dependency, first); + AppendTensorInfo(ss, *output.tensor, output.var_type, output.dependency, first); } return ss.str(); From 6b824861ad565b5b549cd80aad186cd71ec8835c Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Tue, 10 Sep 2024 08:30:54 +0800 Subject: [PATCH 65/77] Add FastGelu op (#21991) ### Description ### Motivation and Context --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- .../contrib_ops/webgpu/bert/fast_gelu.cc | 84 +++++++++++++++++++ .../contrib_ops/webgpu/bert/fast_gelu.h | 38 +++++++++ .../webgpu/webgpu_contrib_kernels.cc | 42 +++++----- .../webgpu/webgpu_contrib_kernels.h | 5 +- onnxruntime/core/providers/webgpu/program.cc | 19 ++++- onnxruntime/core/providers/webgpu/program.h | 13 ++- .../test/contrib_ops/fastgelu_op_test.cc | 13 ++- 7 files changed, 184 insertions(+), 30 deletions(-) create mode 100644 onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc create mode 100644 onnxruntime/contrib_ops/webgpu/bert/fast_gelu.h diff --git a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc new file mode 100644 index 0000000000000..42f056206f3f5 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "fast_gelu.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + FastGelu, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + FastGelu); + +Status FastGeluProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddInput("input", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias); + const auto& output = shader.AddOutput("output", ShaderVariable::UseUniform); + + std::string add_bias = ""; + if (Inputs().size() > 1) { + const auto& bias = shader.AddInput("bias", ShaderVariable::UseUniform | ShaderVariable::UseShapeAndStride); + add_bias = bias_components_ == 1 ? " let bias_offset = global_idx * 4;\n" + " x += input_value_t(" + + bias.GetByOffset("bias_offset % uniforms.bias_shape") + ", " + + bias.GetByOffset("(bias_offset + 1) % uniforms.bias_shape") + ", " + + bias.GetByOffset("(bias_offset + 2) % uniforms.bias_shape") + ", " + + bias.GetByOffset("(bias_offset + 3) % uniforms.bias_shape") + ");\n" + : " x += " + bias.GetByOffset("global_idx % uniforms.bias_shape") + ";\n"; + } + + shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), + " var x = ", input.GetByOffset("global_idx"), ";\n", + add_bias, + " let y = x * (0.5 + 0.5 * tanh(x * (0.035677408136300125 * x * x + 0.7978845608028654)));\n ", + output.SetByOffset("global_idx", "y")); + + return Status::OK(); +} + +Status FastGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const auto* input = context.Input(0); + const auto* bias = context.Input(1); + auto* output = context.Output(0, input->Shape()); + + uint32_t data_size = SafeInt(output->Shape().Size()); + if (data_size == 0) { + return Status::OK(); + } + + const auto vec_size = (data_size + 3) / 4; + uint32_t bias_size = 0; + int bias_components = 1; + + if (bias != nullptr) { + bias_size = SafeInt(bias->Shape().Size()); + if (bias_size % 4 == 0) { + bias_components = 4; + bias_size = bias_size / 4; + } + } + + FastGeluProgram program{bias_components}; + program.Input({input, ProgramTensorMetadataDependency::Type, {vec_size}, 4}) + .Output({output, ProgramTensorMetadataDependency::None, {vec_size}, 4}) + .DispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .UniformVariable({vec_size}); + + if (bias != nullptr) { + program.Input({bias, ProgramTensorMetadataDependency::TypeAndRank, {bias_size}, bias_components}) + .CacheHint(std::to_string(bias_components)); + } + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.h b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.h new file mode 100644 index 0000000000000..fa40d52bf301f --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.h @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +class FastGeluProgram final : public Program { + public: + FastGeluProgram(int bias_components) : Program{"FastGelu"}, bias_components_{bias_components} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + + private: + int bias_components_; +}; + +class FastGelu final : public WebGpuKernel { + public: + FastGelu(const OpKernelInfo& info) : WebGpuKernel(info) {} + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 91f51df588fca..def104b6cb108 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -26,11 +26,11 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Sk class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipSimplifiedLayerNormalization); -// template <> -// KernelCreateInfo BuildKernelCreateInfo() { -// KernelCreateInfo info; -// return info; -// } +template <> +KernelCreateInfo BuildKernelCreateInfo() { + KernelCreateInfo info; + return info; +} Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { @@ -38,22 +38,22 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h index 6cdf7382804f9..d73859de78239 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h @@ -3,13 +3,16 @@ #pragma once -#include "core/framework/op_kernel.h" #include "core/framework/kernel_registry.h" namespace onnxruntime { namespace contrib { namespace webgpu { +// forward declaration for this EP's namespace. +template +KernelCreateInfo BuildKernelCreateInfo(); + Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry); } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc index d4a2b24172d07..b05b576b4bc32 100644 --- a/onnxruntime/core/providers/webgpu/program.cc +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -192,13 +192,23 @@ ProgramBase::ProgramBase(const std::string& name) workgroup_size_z_{0} { } +ProgramBase& ProgramBase::Input(ProgramInput&& input) { + inputs_.emplace_back(input); + return *this; +} + ProgramBase& ProgramBase::Inputs(std::initializer_list inputs) { - inputs_.assign(inputs.begin(), inputs.end()); + inputs_.insert(inputs_.end(), inputs.begin(), inputs.end()); + return *this; +} + +ProgramBase& ProgramBase::Output(ProgramOutput&& output) { + outputs_.emplace_back(output); return *this; } ProgramBase& ProgramBase::Outputs(std::initializer_list outputs) { - outputs_.assign(outputs.begin(), outputs.end()); + outputs_.insert(outputs_.end(), outputs.begin(), outputs.end()); return *this; } @@ -232,6 +242,11 @@ ProgramBase& ProgramBase::WorkgroupSize(uint32_t x, uint32_t y, uint32_t z) { return *this; } +ProgramBase& ProgramBase::UniformVariable(ProgramUniformVariableValue&& variable) { + variables_.emplace_back(variable); + return *this; +} + ProgramBase& ProgramBase::UniformVariables(std::initializer_list variables) { variables_.insert(variables_.end(), variables.begin(), variables.end()); return *this; diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index e162cddbb6408..f5f75747dbe5a 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -272,8 +272,12 @@ class ProgramBase { return *this; } - // set one or more program inputs + // append a program input + ProgramBase& Input(ProgramInput&& input); + // append multiple program inputs ProgramBase& Inputs(std::initializer_list inputs); + // append a program output + ProgramBase& Output(ProgramOutput&& output); // set one or more program outputs ProgramBase& Outputs(std::initializer_list outputs); @@ -291,7 +295,12 @@ class ProgramBase { // set the size of a workgroup grid. ProgramBase& WorkgroupSize(uint32_t x, uint32_t y, uint32_t z); - // set the uniform variables. + // append a uniform variable. + // + // the specified uniform variable should match the uniform definition in the class, + // specified by macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES. + ProgramBase& UniformVariable(ProgramUniformVariableValue&& variable); + // append multiple uniform variables. // // the specified uniform variables should match the uniform definition in the class, // specified by macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES. diff --git a/onnxruntime/test/contrib_ops/fastgelu_op_test.cc b/onnxruntime/test/contrib_ops/fastgelu_op_test.cc index 5cf749dc4c97c..a7d751f4472fc 100644 --- a/onnxruntime/test/contrib_ops/fastgelu_op_test.cc +++ b/onnxruntime/test/contrib_ops/fastgelu_op_test.cc @@ -41,7 +41,7 @@ const std::vector GetExpectedResult(const std::vector& input_data, return ComputeGelu(add_bias_data); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) static void RunFastGeluGpuTest(const std::vector& input_data, const std::vector& bias_data, const std::vector& output_data, const std::vector& input_dims, const std::vector& bias_dims, const std::vector& output_dims, @@ -75,6 +75,8 @@ static void RunFastGeluGpuTest(const std::vector& input_data, const std:: execution_providers.push_back(DefaultCudaExecutionProvider()); #elif USE_ROCM execution_providers.push_back(DefaultRocmExecutionProvider()); +#elif USE_WEBGPU + execution_providers.push_back(DefaultWebGpuExecutionProvider()); #endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } @@ -142,7 +144,7 @@ static void RunFastGeluTest( std::vector input_dims = {batch_size, sequence_length, hidden_size}; std::vector bias_dims = {hidden_size}; std::vector output_dims = input_dims; -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) RunFastGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, has_bias); #endif RunFastGeluCpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, has_bias); @@ -245,8 +247,8 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat32) { RunFastGeluTest(input_data, bias_data, batch_size, sequence_length, hidden_size); } -// CUDA and ROCm only for Float16 and BFloat16 type. -#if defined(USE_CUDA) || defined(USE_ROCM) +// CUDA, ROCm and WebGPU only for Float16 type. +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) TEST(FastGeluTest, FastGeluWithBiasFloat16_2) { int batch_size = 1; int sequence_length = 2; @@ -381,7 +383,10 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat16_8) { RunFastGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, false, true); } +#endif +// CUDA and ROCm only for BFloat16 type. +#if defined(USE_CUDA) || defined(USE_ROCM) TEST(FastGeluTest, FastGeluWithBias_BFloat16) { #ifdef USE_CUDA int min_cuda_architecture = 530; From 2b3e7c2d81395c4aca7ad0a4f2ffb583e8622051 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 9 Sep 2024 19:04:39 -0700 Subject: [PATCH 66/77] use 'set/add' as prefix for some functions --- .../contrib_ops/webgpu/bert/fast_gelu.cc | 20 +++++----- .../webgpu/math/unary_elementwise_ops.cc | 22 +++++------ .../webgpu/math/unary_elementwise_ops.h | 2 +- onnxruntime/core/providers/webgpu/program.cc | 34 ++++++++--------- onnxruntime/core/providers/webgpu/program.h | 38 +++++++++---------- .../core/providers/webgpu/shader_helper.h | 2 +- .../core/providers/webgpu/tensor/expand.cc | 16 ++++---- 7 files changed, 67 insertions(+), 67 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc index 42f056206f3f5..40c083c76d33c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc @@ -35,11 +35,11 @@ Status FastGeluProgram::GenerateShaderCode(ShaderHelper& shader) const { : " x += " + bias.GetByOffset("global_idx % uniforms.bias_shape") + ";\n"; } - shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), - " var x = ", input.GetByOffset("global_idx"), ";\n", - add_bias, - " let y = x * (0.5 + 0.5 * tanh(x * (0.035677408136300125 * x * x + 0.7978845608028654)));\n ", - output.SetByOffset("global_idx", "y")); + shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), + " var x = ", input.GetByOffset("global_idx"), ";\n", + add_bias, + " let y = x * (0.5 + 0.5 * tanh(x * (0.035677408136300125 * x * x + 0.7978845608028654)));\n ", + output.SetByOffset("global_idx", "y")); return Status::OK(); } @@ -67,13 +67,13 @@ Status FastGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) c } FastGeluProgram program{bias_components}; - program.Input({input, ProgramTensorMetadataDependency::Type, {vec_size}, 4}) - .Output({output, ProgramTensorMetadataDependency::None, {vec_size}, 4}) - .DispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) - .UniformVariable({vec_size}); + program.AddInput({input, ProgramTensorMetadataDependency::Type, {vec_size}, 4}) + .AddOutput({output, ProgramTensorMetadataDependency::None, {vec_size}, 4}) + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariable({vec_size}); if (bias != nullptr) { - program.Input({bias, ProgramTensorMetadataDependency::TypeAndRank, {bias_size}, bias_components}) + program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, {bias_size}, bias_components}) .CacheHint(std::to_string(bias_components)); } return context.RunProgram(program); diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index 8d8f855ec20ae..272ff43a68dfd 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -12,9 +12,9 @@ Status UnaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& input = shader.AddInput("x", ShaderVariable::UseUniform | additional_usage_); const auto& output = shader.AddOutput("y", ShaderVariable::UseUniform); shader.AppendImplementation(additional_impl_); - shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), - " let a = ", input.GetByOffset("global_idx"), ";\n ", - output.SetByOffset("global_idx", expression_)); + shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), + " let a = ", input.GetByOffset("global_idx"), ";\n ", + output.SetByOffset("global_idx", expression_)); return Status::OK(); } @@ -29,10 +29,10 @@ Status UnaryElementwise::ComputeInternal(ComputeContext& context) const { SafeInt vec_size = (size + 3) / 4; UnaryElementwiseProgram program{kernel_name_, expression_, additional_impl_, additional_usage_}; program - .Inputs({{input_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}}) - .Outputs({{output_tensor, ProgramTensorMetadataDependency::None, {vec_size}, 4}}) - .DispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) - .UniformVariables({ + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::Type, {vec_size}, 4}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::None, {vec_size}, 4}}) + .SetDispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ {static_cast(vec_size)}, }); if (!cache_hint.empty()) { @@ -143,7 +143,7 @@ class HardSigmoid final : public UnaryElementwise { } Status ConfigureProgram(const ComputeContext& /*context*/, UnaryElementwiseProgram& program) const override { - program.UniformVariables({gsl::make_span(attr, 2)}); + program.AddUniformVariables({gsl::make_span(attr, 2)}); return Status::OK(); } @@ -221,11 +221,11 @@ class Clip final : public UnaryElementwise { if constexpr (std::is_same_v) { // F16: stores span as a single float float encoded_value = *reinterpret_cast(attr); - program.UniformVariables({encoded_value}); + program.AddUniformVariable({encoded_value}); } else { static_assert(sizeof(T) == sizeof(float), "T must be f32, i32 or u32"); // stores span as-is - program.UniformVariables({gsl::make_span(attr, 2)}); + program.AddUniformVariable({gsl::make_span(attr, 2)}); } return Status::OK(); } @@ -276,7 +276,7 @@ class LinearUnit : public UnaryElementwise { } Status ConfigureProgram(const ComputeContext& /*context*/, UnaryElementwiseProgram& program) const override { - program.UniformVariables({alpha_}); + program.AddUniformVariables({alpha_}); return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h index d870278f4c090..2691d67e1f9f6 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h @@ -49,7 +49,7 @@ class UnaryElementwise : public WebGpuKernel { Status ComputeInternal(ComputeContext& context) const final; virtual Status ConfigureProgram(const ComputeContext& /*context*/, UnaryElementwiseProgram& program) const { - program.UniformVariables({{}}); // empty for attribute(s) + program.AddUniformVariables({{}}); // empty for attribute(s) return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc index b05b576b4bc32..023fa78a4196b 100644 --- a/onnxruntime/core/providers/webgpu/program.cc +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -192,67 +192,67 @@ ProgramBase::ProgramBase(const std::string& name) workgroup_size_z_{0} { } -ProgramBase& ProgramBase::Input(ProgramInput&& input) { +ProgramBase& ProgramBase::AddInput(ProgramInput&& input) { inputs_.emplace_back(input); return *this; } -ProgramBase& ProgramBase::Inputs(std::initializer_list inputs) { +ProgramBase& ProgramBase::AddInputs(std::initializer_list inputs) { inputs_.insert(inputs_.end(), inputs.begin(), inputs.end()); return *this; } -ProgramBase& ProgramBase::Output(ProgramOutput&& output) { +ProgramBase& ProgramBase::AddOutput(ProgramOutput&& output) { outputs_.emplace_back(output); return *this; } -ProgramBase& ProgramBase::Outputs(std::initializer_list outputs) { +ProgramBase& ProgramBase::AddOutputs(std::initializer_list outputs) { outputs_.insert(outputs_.end(), outputs.begin(), outputs.end()); return *this; } -ProgramBase& ProgramBase::DispatchGroupSize(uint32_t x) { - return DispatchGroupSize(x, 1, 1); +ProgramBase& ProgramBase::SetDispatchGroupSize(uint32_t x) { + return SetDispatchGroupSize(x, 1, 1); } -ProgramBase& ProgramBase::DispatchGroupSize(uint32_t x, uint32_t y) { - return DispatchGroupSize(x, y, 1); +ProgramBase& ProgramBase::SetDispatchGroupSize(uint32_t x, uint32_t y) { + return SetDispatchGroupSize(x, y, 1); } -ProgramBase& ProgramBase::DispatchGroupSize(uint32_t x, uint32_t y, uint32_t z) { +ProgramBase& ProgramBase::SetDispatchGroupSize(uint32_t x, uint32_t y, uint32_t z) { dispatch_group_size_x_ = x; dispatch_group_size_y_ = y; dispatch_group_size_z_ = z; return *this; } -ProgramBase& ProgramBase::WorkgroupSize(uint32_t x) { - return WorkgroupSize(x, 1, 1); +ProgramBase& ProgramBase::SetWorkgroupSize(uint32_t x) { + return SetWorkgroupSize(x, 1, 1); } -ProgramBase& ProgramBase::WorkgroupSize(uint32_t x, uint32_t y) { - return WorkgroupSize(x, y, 1); +ProgramBase& ProgramBase::SetWorkgroupSize(uint32_t x, uint32_t y) { + return SetWorkgroupSize(x, y, 1); } -ProgramBase& ProgramBase::WorkgroupSize(uint32_t x, uint32_t y, uint32_t z) { +ProgramBase& ProgramBase::SetWorkgroupSize(uint32_t x, uint32_t y, uint32_t z) { workgroup_size_x_ = x; workgroup_size_y_ = y; workgroup_size_z_ = z; return *this; } -ProgramBase& ProgramBase::UniformVariable(ProgramUniformVariableValue&& variable) { +ProgramBase& ProgramBase::AddUniformVariable(ProgramUniformVariableValue&& variable) { variables_.emplace_back(variable); return *this; } -ProgramBase& ProgramBase::UniformVariables(std::initializer_list variables) { +ProgramBase& ProgramBase::AddUniformVariables(std::initializer_list variables) { variables_.insert(variables_.end(), variables.begin(), variables.end()); return *this; } -ProgramBase& ProgramBase::OverridableConstants(std::initializer_list overridable_constants) { +ProgramBase& ProgramBase::SetOverridableConstants(std::initializer_list overridable_constants) { overridable_constants_.insert(overridable_constants_.end(), overridable_constants.begin(), overridable_constants.end()); return *this; } diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index f5f75747dbe5a..ae3d82a6371d0 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -272,45 +272,45 @@ class ProgramBase { return *this; } - // append a program input - ProgramBase& Input(ProgramInput&& input); - // append multiple program inputs - ProgramBase& Inputs(std::initializer_list inputs); - // append a program output - ProgramBase& Output(ProgramOutput&& output); - // set one or more program outputs - ProgramBase& Outputs(std::initializer_list outputs); + // add a program input + ProgramBase& AddInput(ProgramInput&& input); + // add multiple program inputs + ProgramBase& AddInputs(std::initializer_list inputs); + // add a program output + ProgramBase& AddOutput(ProgramOutput&& output); + // add multiple program outputs + ProgramBase& AddOutputs(std::initializer_list outputs); // set the size of dispatch groups. Y and Z are 1 if not specified. - ProgramBase& DispatchGroupSize(uint32_t x); + ProgramBase& SetDispatchGroupSize(uint32_t x); // set the size of dispatch groups. Z is 1 if not specified. - ProgramBase& DispatchGroupSize(uint32_t x, uint32_t y); + ProgramBase& SetDispatchGroupSize(uint32_t x, uint32_t y); // set the size of dispatch groups. - ProgramBase& DispatchGroupSize(uint32_t x, uint32_t y, uint32_t z); + ProgramBase& SetDispatchGroupSize(uint32_t x, uint32_t y, uint32_t z); // set the size of a workgroup grid. Y and Z are 1 if not specified. - ProgramBase& WorkgroupSize(uint32_t x); + ProgramBase& SetWorkgroupSize(uint32_t x); // set the size of a workgroup grid. Z is 1 if not specified. - ProgramBase& WorkgroupSize(uint32_t x, uint32_t y); + ProgramBase& SetWorkgroupSize(uint32_t x, uint32_t y); // set the size of a workgroup grid. - ProgramBase& WorkgroupSize(uint32_t x, uint32_t y, uint32_t z); + ProgramBase& SetWorkgroupSize(uint32_t x, uint32_t y, uint32_t z); - // append a uniform variable. + // add a uniform variable. // // the specified uniform variable should match the uniform definition in the class, // specified by macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES. - ProgramBase& UniformVariable(ProgramUniformVariableValue&& variable); - // append multiple uniform variables. + ProgramBase& AddUniformVariable(ProgramUniformVariableValue&& variable); + // add multiple uniform variables. // // the specified uniform variables should match the uniform definition in the class, // specified by macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES. - ProgramBase& UniformVariables(std::initializer_list variables); + ProgramBase& AddUniformVariables(std::initializer_list variables); // set the overridable constants // // the specified overridable constants should match the overridable constant definition in the class, // specified by macro WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS. - ProgramBase& OverridableConstants(std::initializer_list overridable_constants); + ProgramBase& SetOverridableConstants(std::initializer_list overridable_constants); // // shader code generation diff --git a/onnxruntime/core/providers/webgpu/shader_helper.h b/onnxruntime/core/providers/webgpu/shader_helper.h index 08ff111f8a690..811ae3cfa15cc 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.h +++ b/onnxruntime/core/providers/webgpu/shader_helper.h @@ -102,7 +102,7 @@ class ShaderHelper final { // // can be called only once. template - inline void MainFunctionBody(const Strs&... body) { + inline void SetMainFunctionBody(const Strs&... body) { ORT_ENFORCE(!body_set_, "Main function body is already set"); onnxruntime::detail::MakeStringImpl(body_, std::forward>(body)...); body_set_ = true; diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc index 82451c9398243..45084472d3537 100644 --- a/onnxruntime/core/providers/webgpu/tensor/expand.cc +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -14,10 +14,10 @@ Status ExpandProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& input = shader.AddInput("input", ShaderVariable::UseUniform); const auto& output = shader.AddOutput("output", ShaderVariable::UseUniform); - shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size"), - "let output_indices = ", output.OffsetToIndices("global_idx"), ";\n", - "let input_offset = ", input.BroadcastedIndicesToOffset("output_indices", output), ";\n", - output.SetByOffset("global_idx", input.GetByOffset("input_offset"))); + shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size"), + " let output_indices = ", output.OffsetToIndices("global_idx"), ";\n", + " let input_offset = ", input.BroadcastedIndicesToOffset("output_indices", output), ";\n ", + output.SetByOffset("global_idx", input.GetByOffset("input_offset"))); return Status::OK(); } @@ -34,10 +34,10 @@ Status Expand::ComputeInternal(ComputeContext& context) const { uint32_t data_size = SafeInt(output_shape.Size()); ExpandProgram program{}; program - .Inputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) - .Outputs({{output_tensor, ProgramTensorMetadataDependency::Rank}}) - .DispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) - .UniformVariables({ + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank}}) + .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ {data_size}, }); return context.RunProgram(program); From ef0d53b78c518e56cbb5d69173bc9f4aa8ace387 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 9 Sep 2024 19:06:32 -0700 Subject: [PATCH 67/77] remove unnecessary cache hint for FastGelu --- onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc index 40c083c76d33c..f6631025f0b34 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc @@ -73,8 +73,7 @@ Status FastGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) c .AddUniformVariable({vec_size}); if (bias != nullptr) { - program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, {bias_size}, bias_components}) - .CacheHint(std::to_string(bias_components)); + program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, {bias_size}, bias_components}); } return context.RunProgram(program); } From c4ca47f763f6e7be8b021576195c44ea8992dc54 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 9 Sep 2024 21:17:38 -0700 Subject: [PATCH 68/77] revise unary - expose consts in header --- .../webgpu/math/unary_elementwise_ops.cc | 47 +---------------- .../webgpu/math/unary_elementwise_ops.h | 50 +++++++++++++++++++ 2 files changed, 51 insertions(+), 46 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index 272ff43a68dfd..b4b397b2c4b5f 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -98,21 +98,6 @@ WEBGPU_ELEMENTWISE_IMPL(Exp, "exp(a)") WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Exp, 6, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Exp, 13, WebGpuSupportedFloatTypes()) -constexpr char ErfImpl[] = R"( -const r0 = 0.3275911; -const r1 = 0.254829592; -const r2 = -0.284496736; -const r3 = 1.421413741; -const r4 = -1.453152027; -const r5 = 1.061405429; - -fn erf_v(v: x_value_t) -> x_value_t { - let absv = abs(v); - let x = 1.0 / (1.0 + r0 * absv); - return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv)); -} -)"; - WEBGPU_ELEMENTWISE_IMPL(Erf, "erf_v(a)", ErfImpl, ShaderVariable::UseValueTypeAlias) WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Erf, 9, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Erf, 13, WebGpuSupportedFloatTypes()) @@ -125,14 +110,6 @@ WEBGPU_ELEMENTWISE_IMPL(Sigmoid, "1.0 / (1.0 + exp(-a))") WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Sigmoid, 13, WebGpuSupportedFloatTypes()) -constexpr char HardSigmoidImpl[] = R"( -fn hard_sigmoid_v(v: vec4) -> vec4 { - let alpha = x_element_t(uniforms.attr[0]); - let beta_v = vec4(uniforms.attr[1]); - return max(vec4(0.0), - min(vec4(1.0), alpha * v + beta_v)); -} -)"; class HardSigmoid final : public UnaryElementwise { public: HardSigmoid(const OpKernelInfo& info) @@ -177,14 +154,6 @@ WEBGPU_ELEMENTWISE_KERNEL(Sinh, 9, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_IMPL(Cosh, "cosh(a)") WEBGPU_ELEMENTWISE_KERNEL(Cosh, 9, WebGpuSupportedFloatTypes()) -// built-in function tanh() does not work with large input (f32 88.7 or f16 11.09) -// https://github.com/gpuweb/gpuweb/issues/4458 -constexpr char TanhImpl[] = R"( -fn tanh_v(a: x_value_t) -> x_value_t { - let expr = exp(-2 * abs(a)); - return sign(a) * (1 - expr) / (1 + expr); -} -)"; WEBGPU_ELEMENTWISE_IMPL(Tanh, "tanh_v(a)", TanhImpl, ShaderVariable::UseValueTypeAlias) WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Tanh, 13, WebGpuSupportedFloatTypes()) @@ -290,17 +259,6 @@ class LinearUnit : public UnaryElementwise { OP_TYPE(const OpKernelInfo& info) : LinearUnit{info, #OP_TYPE, __VA_ARGS__} {} \ }; -constexpr char EluImpl[] = R"( -fn elu(a: x_element_t) -> x_element_t { - let alpha = x_element_t(uniforms.attr); - return select((exp(a) - 1.0) * alpha, a, a >= 0.0); -} - -fn elu_v(v: vec4) -> vec4 { - return vec4(elu(v.x), elu(v.y), elu(v.z), elu(v.w)); -} -)"; - WEBGPU_LU_IMPL(Elu, "elu_v(a)", EluImpl, 1.0) WEBGPU_ELEMENTWISE_KERNEL(Elu, 6, WebGpuSupportedFloatTypes()) @@ -309,14 +267,11 @@ class Gelu : public UnaryElementwise { Gelu(const OpKernelInfo& info) : UnaryElementwise{info, "Gelu", - info.GetAttrOrDefault("approximate", "none") == "tanh" ? TanhBasedImpl : DefaultImpl, + info.GetAttrOrDefault("approximate", "none") == "tanh" ? FastGeluExpr : GeluExpr, info.GetAttrOrDefault("approximate", "none") == "tanh" ? TanhImpl : ErfImpl, ShaderVariable::UseValueTypeAlias} { cache_hint = info.GetAttrOrDefault("approximate", "none"); } - - constexpr static const char DefaultImpl[] = "0.5 * a * (1.0 + erf_v(a * 0.7071067811865475))"; - constexpr static const char TanhBasedImpl[] = "0.5 * a * (1 + tanh_v(0.7978845608028654 * (a + 0.044715 * a * a * a)))"; }; WEBGPU_ELEMENTWISE_KERNEL(Gelu, 20, WebGpuSupportedFloatTypes()) diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h index 2691d67e1f9f6..de85c18da117a 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h @@ -60,5 +60,55 @@ class UnaryElementwise : public WebGpuKernel { ShaderVariable::Usage additional_usage_; }; +constexpr const char ErfImpl[] = R"( +const r0 = 0.3275911; +const r1 = 0.254829592; +const r2 = -0.284496736; +const r3 = 1.421413741; +const r4 = -1.453152027; +const r5 = 1.061405429; + +fn erf_v(v: x_value_t) -> x_value_t { + let absv = abs(v); + let x = 1.0 / (1.0 + r0 * absv); + return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv)); +} +)"; + +constexpr const char HardSigmoidImpl[] = R"( +fn hard_sigmoid_v(v: vec4) -> vec4 { + let alpha = x_element_t(uniforms.attr[0]); + let beta_v = vec4(uniforms.attr[1]); + return max(vec4(0.0), + min(vec4(1.0), alpha * v + beta_v)); +} +)"; + +// built-in function tanh() does not work with large input (f32 88.7 or f16 11.09) +// https://github.com/gpuweb/gpuweb/issues/4458 +constexpr const char TanhImpl[] = R"( +fn tanh_v(a: x_value_t) -> x_value_t { + let expr = exp(-2 * abs(a)); + return sign(a) * (1 - expr) / (1 + expr); +} +)"; + +constexpr const char EluImpl[] = R"( +fn elu(a: x_element_t) -> x_element_t { + let alpha = x_element_t(uniforms.attr); + return select((exp(a) - 1.0) * alpha, a, a >= 0.0); +} + +fn elu_v(v: vec4) -> vec4 { + return vec4(elu(v.x), elu(v.y), elu(v.z), elu(v.w)); +} +)"; + +// default GELU expression, depending on ErfImpl +constexpr const char GeluExpr[] = "0.5 * a * (1.0 + erf_v(a * 0.7071067811865475))"; + +// fast GELU expression, depending on TanhImpl +constexpr const char FastGeluExpr[] = "a * (0.5 + 0.5 * tanh_v(a * (0.035677408136300125 * a * a + 0.7978845608028654)))"; + } // namespace webgpu } // namespace onnxruntime From 8806d57727be2723ff4d84fd33fb9504503fb7b5 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 9 Sep 2024 21:53:24 -0700 Subject: [PATCH 69/77] use path for header file --- onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc index f6631025f0b34..7d8bef1e66f42 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc @@ -1,9 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "fast_gelu.h" #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" +#include "contrib_ops/webgpu/bert/fast_gelu.h" #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" namespace onnxruntime { From 0568e2b6e59e68f5c47265deb3c2a2739804025c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 10 Sep 2024 16:21:46 -0700 Subject: [PATCH 70/77] a few revises to the code (#22047) --- .../core/providers/webgpu/buffer_manager.cc | 8 ++-- .../providers/webgpu/program_cache_key.cc | 5 +- .../core/providers/webgpu/program_manager.h | 2 +- .../core/providers/webgpu/shader_helper.cc | 46 +++++++++---------- 4 files changed, 32 insertions(+), 29 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.cc b/onnxruntime/core/providers/webgpu/buffer_manager.cc index da544e1d1ed60..8751338d24178 100644 --- a/onnxruntime/core/providers/webgpu/buffer_manager.cc +++ b/onnxruntime/core/providers/webgpu/buffer_manager.cc @@ -243,10 +243,10 @@ std::ostream& operator<<(std::ostream& os, BufferCacheMode mode) { BufferManager::BufferManager(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode) : context_{context}, - storage_cache_{std::move(CreateBufferCacheManager(storage_buffer_cache_mode))}, - uniform_cache_{std::move(CreateBufferCacheManager(uniform_buffer_cache_mode))}, - query_resolve_cache_{std::move(CreateBufferCacheManager(query_resolve_buffer_cache_mode))}, - default_cache_{std::move(CreateBufferCacheManager(BufferCacheMode::Disabled))} { + storage_cache_{CreateBufferCacheManager(storage_buffer_cache_mode)}, + uniform_cache_{CreateBufferCacheManager(uniform_buffer_cache_mode)}, + query_resolve_cache_{CreateBufferCacheManager(query_resolve_buffer_cache_mode)}, + default_cache_{CreateBufferCacheManager(BufferCacheMode::Disabled)} { } void BufferManager::Upload(void* src, WGPUBuffer dst, size_t size) { diff --git a/onnxruntime/core/providers/webgpu/program_cache_key.cc b/onnxruntime/core/providers/webgpu/program_cache_key.cc index 09a536f7916b2..6c7ef2bc89c6b 100644 --- a/onnxruntime/core/providers/webgpu/program_cache_key.cc +++ b/onnxruntime/core/providers/webgpu/program_cache_key.cc @@ -10,12 +10,14 @@ namespace webgpu { namespace { // append the info of an input or output to the cachekey -void AppendTensorInfo(std::ostringstream& ss, const Tensor& tensor, ProgramVariableDataType var_type, ProgramTensorMetadataDependency dependency, bool& first) { +void AppendTensorInfo(std::ostringstream& ss, const Tensor& tensor, ProgramVariableDataType var_type, ProgramTensorMetadataDependency dependency, + bool& first) { if (first) { first = false; } else { ss << '|'; } + if ((dependency & ProgramTensorMetadataDependency::Type) == ProgramTensorMetadataDependency::Type) { #ifndef NDEBUG // if debug build ss << var_type; @@ -24,6 +26,7 @@ void AppendTensorInfo(std::ostringstream& ss, const Tensor& tensor, ProgramVaria #endif ss << ';'; } + if ((dependency & ProgramTensorMetadataDependency::Shape) == ProgramTensorMetadataDependency::Shape) { ss D("Dims=") << tensor.Shape().ToString(); } else if ((dependency & ProgramTensorMetadataDependency::Rank) == ProgramTensorMetadataDependency::Rank) { diff --git a/onnxruntime/core/providers/webgpu/program_manager.h b/onnxruntime/core/providers/webgpu/program_manager.h index 782788910e3a5..eded1cfa17970 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.h +++ b/onnxruntime/core/providers/webgpu/program_manager.h @@ -30,7 +30,7 @@ class ProgramArtifact { const std::vector shape_uniform_ranks; ProgramArtifact(ProgramArtifact&&) = default; - ProgramArtifact& operator=(ProgramArtifact&&) = default; + ProgramArtifact& operator=(ProgramArtifact&&) = delete; private: ORT_DISALLOW_COPY_AND_ASSIGNMENT(ProgramArtifact); diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index cd21f4752f300..be89efae5fc97 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -196,6 +196,29 @@ Status ValidateVariableDependency(ProgramTensorMetadataDependency dependency, Sh } } // namespace +Status ShaderHelper::ValidateVariable(const ProgramInput& input, const ShaderVariable& var) const { + ORT_RETURN_IF_ERROR(ValidateVariableDataType(input.tensor->GetElementType(), var.type_)); + ORT_RETURN_IF_ERROR(ValidateVariableShape(input.tensor->Shape(), + input.use_override_shape, + input.use_override_shape ? input.override_shape : input.tensor->Shape(), + var.num_components_)); + ORT_RETURN_IF_ERROR(ValidateVariableDependency(input.dependency, var.usage_, true)); + + return Status::OK(); +} +Status ShaderHelper::ValidateVariable(const ProgramOutput& output, const ShaderVariable& var) const { + ORT_RETURN_IF_ERROR(ValidateVariableDataType(output.tensor->GetElementType(), var.type_)); + ORT_RETURN_IF_ERROR(ValidateVariableShape(output.tensor->Shape(), + output.use_override_shape, + output.use_override_shape ? output.override_shape : output.tensor->Shape(), + var.num_components_)); + ORT_RETURN_IF_ERROR(ValidateVariableDependency(output.dependency, var.usage_, false)); + + return Status::OK(); +} + +#endif // NDEBUG + const ShaderVariable& ShaderHelper::AddVariableImpl(ProgramVariableScope scope, const std::string& name, ShaderVariable::Usage usage, @@ -224,27 +247,6 @@ const ShaderVariable& ShaderHelper::AddVariableImpl(ProgramVariableScope scope, return *var; } -Status ShaderHelper::ValidateVariable(const ProgramInput& input, const ShaderVariable& var) const { - ORT_RETURN_IF_ERROR(ValidateVariableDataType(input.tensor->GetElementType(), var.type_)); - ORT_RETURN_IF_ERROR(ValidateVariableShape(input.tensor->Shape(), - input.use_override_shape, - input.use_override_shape ? input.override_shape : input.tensor->Shape(), - var.num_components_)); - ORT_RETURN_IF_ERROR(ValidateVariableDependency(input.dependency, var.usage_, true)); - - return Status::OK(); -} -Status ShaderHelper::ValidateVariable(const ProgramOutput& output, const ShaderVariable& var) const { - ORT_RETURN_IF_ERROR(ValidateVariableDataType(output.tensor->GetElementType(), var.type_)); - ORT_RETURN_IF_ERROR(ValidateVariableShape(output.tensor->Shape(), - output.use_override_shape, - output.use_override_shape ? output.override_shape : output.tensor->Shape(), - var.num_components_)); - ORT_RETURN_IF_ERROR(ValidateVariableDependency(output.dependency, var.usage_, false)); - - return Status::OK(); -} - Status ShaderHelper::ValidateShapeForInputsAndOutputs() const { const auto& input_vars = vars_[static_cast(ProgramVariableScope::Input)]; const auto& output_vars = vars_[static_cast(ProgramVariableScope::Output)]; @@ -304,8 +306,6 @@ Status ShaderHelper::ValidateShapeForInputsAndOutputs() const { return Status::OK(); } -#endif - Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& shape_uniform_ranks) const { std::ostringstream ss; ss.imbue(std::locale::classic()); From b7a9c0e90a164ce2f97d39ac51a5d6b3e6646a72 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 10 Sep 2024 17:42:39 -0700 Subject: [PATCH 71/77] use OrtMutex --- onnxruntime/core/providers/webgpu/webgpu_context.cc | 6 +++--- onnxruntime/core/providers/webgpu/webgpu_context.h | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 276d74905adb7..01d7704d2be22 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -419,7 +419,7 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog } std::unordered_map> WebGpuContextFactory::contexts_; -std::mutex WebGpuContextFactory::mutex_; +OrtMutex WebGpuContextFactory::mutex_; WebGpuContext& WebGpuContextFactory::CreateContext(int context_id, WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device) { if (context_id == 0) { @@ -432,7 +432,7 @@ WebGpuContext& WebGpuContextFactory::CreateContext(int context_id, WGPUInstance "WebGPU EP custom context (contextId>0) must have custom WebGPU instance, adapter and device."); } - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); auto it = contexts_.find(context_id); if (it == contexts_.end()) { @@ -446,7 +446,7 @@ WebGpuContext& WebGpuContextFactory::CreateContext(int context_id, WGPUInstance } WebGpuContext& WebGpuContextFactory::GetContext(int context_id) { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); auto it = contexts_.find(context_id); ORT_ENFORCE(it != contexts_.end(), "WebGPU EP context ID ", context_id, " is not found."); diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index d8b0c2b48b067..2086213e248f3 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -13,6 +13,7 @@ #include #include "core/common/common.h" +#include "core/platform/ort_mutex.h" #include "core/providers/webgpu/webgpu_execution_provider.h" #include "core/providers/webgpu/buffer_manager.h" #include "core/providers/webgpu/program_manager.h" @@ -34,7 +35,7 @@ class WebGpuContextFactory { WebGpuContextFactory() {} static std::unordered_map> contexts_; - static std::mutex mutex_; + static OrtMutex mutex_; }; // Class WebGpuContext includes all necessary resources for the context. From d4a963d7bf7e9be9b09e41056eda4d6e9a9fe550 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Wed, 11 Sep 2024 15:31:11 +0800 Subject: [PATCH 72/77] [webgpu-native] Add transpose op (#21986) Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- .../core/providers/webgpu/tensor/transpose.cc | 103 ++++++++++++++++++ .../core/providers/webgpu/tensor/transpose.h | 37 +++++++ .../webgpu/webgpu_execution_provider.cc | 6 +- .../providers/webgpu/webgpu_supported_types.h | 6 +- 4 files changed, 146 insertions(+), 6 deletions(-) create mode 100644 onnxruntime/core/providers/webgpu/tensor/transpose.cc create mode 100644 onnxruntime/core/providers/webgpu/tensor/transpose.h diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc new file mode 100644 index 0000000000000..68af858d515c2 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -0,0 +1,103 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/inlined_containers.h" +#include "core/providers/webgpu/tensor/transpose.h" +#include "core/providers/cpu/tensor/utils.h" +#include "core/providers/webgpu/shader_variable.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Transpose, + kOnnxDomain, + 1, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Transpose); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Transpose, + kOnnxDomain, + 13, 20, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Transpose); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Transpose, + kOnnxDomain, + 21, 22, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Transpose); + +ONNX_OPERATOR_KERNEL_EX( + Transpose, + kOnnxDomain, + 23, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()), + Transpose); + +const std::string AppendPermFunction(gsl::span perm) { + std::ostringstream ss; + ss.imbue(std::locale::classic()); + ss << "fn perm(i: y_indices_t)->x_indices_t {\n" + " var a: x_indices_t;\n"; + for (auto i = 0; i < perm.size(); ++i) { + ss << " a[" << perm[i] << "] = i[" << i << "];\n"; + } + ss << " return a;\n" + "}\n"; + return ss.str(); +} + +Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddInput("x", ShaderVariable::UseUniform | ShaderVariable::UseIndicesTypeAlias); + const auto& output = shader.AddOutput("y", ShaderVariable::UseUniform | ShaderVariable::UseIndicesTypeAlias); + shader.AppendImplementation(AppendPermFunction(this->perm_)); + shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"), + " let indices = ", output.OffsetToIndices("global_idx"), + ";\n" + " let x_indices = perm(indices); \n" + " ", + output.SetByOffset("global_idx", input.GetByIndices("x_indices"))); + return Status::OK(); +} + +Status Transpose::ComputeInternal(ComputeContext& context) const { + // TODO: there is an optimized version of transpose to port. + const auto* input_tensor = context.Input(0); + const TensorShape& input_shape = input_tensor->Shape(); + int32_t rank = gsl::narrow_cast(input_shape.NumDimensions()); + + TensorShapeVector output_dims(rank); + InlinedVector default_perm(rank); + const InlinedVector* p_perm = nullptr; + ORT_RETURN_IF_ERROR(ComputeOutputShape(*input_tensor, output_dims, default_perm, p_perm)); + TensorShape output_shape(output_dims); + auto* output_tensor = context.Output(0, output_shape); + + uint32_t output_size = gsl::narrow_cast(input_tensor->Shape().Size()); + TransposeProgram program{*p_perm}; + program + .CacheHint(absl::StrJoin(*p_perm, "-")) + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutputs({output_tensor}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + {static_cast(output_size)}, + }); + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.h b/onnxruntime/core/providers/webgpu/tensor/transpose.h new file mode 100644 index 0000000000000..3ca5674d5dfab --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.h @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/cpu/tensor/transpose.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class TransposeProgram final : public Program { + public: + TransposeProgram(const gsl::span& permutations) + : Program{"Transpose"}, perm_(permutations.begin(), permutations.end()) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}); + + private: + InlinedVector perm_; +}; + +class Transpose final : public WebGpuKernel, public TransposeBase { + public: + Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} { + } + + Status ComputeInternal(ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index decc74b59cae6..ae5b429fb2301 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -239,7 +239,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, Where); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Transpose); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Transpose); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 20, Transpose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, DepthToSpace); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, DepthToSpace); @@ -552,8 +552,8 @@ std::unique_ptr RegisterKernels() { // KERNEL_CREATE_INFO_VERSIONED(9, 15, Where), // KERNEL_CREATE_INFO(16, Where), - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/webgpu/webgpu_supported_types.h b/onnxruntime/core/providers/webgpu/webgpu_supported_types.h index fccaef2c53575..ff66cd535399e 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_supported_types.h +++ b/onnxruntime/core/providers/webgpu/webgpu_supported_types.h @@ -8,7 +8,7 @@ namespace onnxruntime { namespace webgpu { -using SupportedTypes = +using SupportedNumberTypes = TypeList< float, MLFloat16, @@ -20,8 +20,8 @@ using SupportedFloats = float, MLFloat16>; -inline const std::vector& WebGpuSupportedDataTypes() { - static const std::vector supportedDataTypes = BuildKernelDefConstraintsFromTypeList(); +inline const std::vector& WebGpuSupportedNumberTypes() { + static const std::vector supportedDataTypes = BuildKernelDefConstraintsFromTypeList(); return supportedDataTypes; } From 8b61532e73f7d04b65c00ebfc719d03e085a9a06 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 11 Sep 2024 16:34:09 -0700 Subject: [PATCH 73/77] PushErrorScope and PopErrorScope --- .../core/providers/webgpu/compute_context.cc | 20 +++++++++++++++++++ .../core/providers/webgpu/compute_context.h | 14 +++++++++++++ .../core/providers/webgpu/webgpu_context.cc | 5 +++++ .../core/providers/webgpu/webgpu_kernel.h | 18 ++++++++--------- 4 files changed, 47 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/compute_context.cc b/onnxruntime/core/providers/webgpu/compute_context.cc index b7a1af5b26ef7..62289b7cd6aa6 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.cc +++ b/onnxruntime/core/providers/webgpu/compute_context.cc @@ -13,5 +13,25 @@ ComputeContext::ComputeContext(OpKernelContext& kernel_context) kernel_context_{kernel_context} { } +void ComputeContext::PushErrorScope() { + webgpu_context_.Device().PushErrorScope(wgpu::ErrorFilter::Validation); +} + +Status ComputeContext::PopErrorScope() { + Status status{}; + + ORT_RETURN_IF_ERROR(webgpu_context_.Wait( + webgpu_context_.Device().PopErrorScope( + wgpu::CallbackMode::WaitAnyOnly, [](wgpu::PopErrorScopeStatus pop_status, wgpu::ErrorType error_type, char const* message, Status* status) { + ORT_ENFORCE(pop_status == wgpu::PopErrorScopeStatus::Success, "Instance dropped."); + if (error_type == wgpu::ErrorType::NoError) { + return; + } + *status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "WebGPU validation failed. ", message); + }, + &status))); + return status; +} + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index 132f629ac745e..c98480523ae64 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -106,6 +106,20 @@ class ComputeContext { return webgpu_context_.Run(*this, program); } + // + // Push error scope. + // + // This is useful only when "skip_validation" is not set. + // + void PushErrorScope(); + + // + // Pop error scope. + // + // This is useful only when "skip_validation" is not set. + // + Status PopErrorScope(); + protected: WebGpuContext& webgpu_context_; OpKernelContext& kernel_context_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 01d7704d2be22..ec8f0cda10ee1 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -34,7 +34,12 @@ std::vector GetEnabledDeviceToggles() { // Enable / disable other toggles that may affect the performance. // Other toggles that may be useful: "dump_shaders", "disable_symbol_renaming" constexpr const char* toggles[] = { +#ifdef NDEBUG + // todo: when skip validation, the process may crash. + // need careful decision to enable this toggle. + // revisit this flag before release. "skip_validation", +#endif "disable_robustness", "disable_workgroup_init", "d3d_disable_ieee_strictness", diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.h b/onnxruntime/core/providers/webgpu/webgpu_kernel.h index 6486987501d14..72fea52313f9a 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_kernel.h +++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.h @@ -22,16 +22,14 @@ class WebGpuKernel : public OpKernel { Status Compute(OpKernelContext* p_op_kernel_context) const override { ComputeContext context{*p_op_kernel_context}; - auto s = ComputeInternal(context); - // use this to precisely locate the node where CUDA failure comes from - // if (cudaSuccess != cudaDeviceSynchronize()) - // __debugbreak(); - // if (s.IsOK()) { - // auto err = cudaGetLastError(); - // if (err != cudaSuccess) { - // return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA error ", cudaGetErrorName(err), ":", cudaGetErrorString(err)); - // } - // } +#ifndef NDEBUG + context.PushErrorScope(); +#endif + Status s = ComputeInternal(context); +#ifndef NDEBUG + ORT_RETURN_IF_ERROR(context.PopErrorScope()); +#endif + return s; } From dce0f181a272668fe459915374ca1cc1424525ff Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 11 Sep 2024 18:05:21 -0700 Subject: [PATCH 74/77] placeholder for setting proc table --- .../webgpu/webgpu_provider_factory.cc | 18 ++++++++++++++++-- .../providers/webgpu/webgpu_provider_options.h | 2 ++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index 3848ccfc19f51..b03bddf408b64 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -3,6 +3,8 @@ #include +#include + #include "core/framework/error_code_helper.h" #include "core/providers/webgpu/buffer_manager.h" #include "core/providers/webgpu/webgpu_execution_provider.h" @@ -33,7 +35,19 @@ struct WebGpuProviderFactory : IExecutionProviderFactory { std::shared_ptr WebGpuProviderFactoryCreator::Create(const ConfigOptions& config_options) { // - // STEP.1 - prepare WebGpuExecutionProviderInfo + // STEP.1 - set dawn proc table + // + std::string dawn_proc_table_str; + if (config_options.TryGetConfigEntry(kDawnProcTable, dawn_proc_table_str)) { + size_t dawn_proc_table = 0; + ORT_ENFORCE(std::errc{} == + std::from_chars(dawn_proc_table_str.data(), dawn_proc_table_str.data() + dawn_proc_table_str.size(), dawn_proc_table).ec); + // TODO: do this for static link build + // dawnProcSetProcs(reinterpret_cast(dawn_proc_table)); + } + + // + // STEP.2 - prepare WebGpuExecutionProviderInfo // WebGpuExecutionProviderInfo webgpu_ep_info{ // preferred layout is NHWC by default @@ -100,7 +114,7 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( LOGS_DEFAULT(VERBOSE) << "WebGPU EP default buffer cache mode: " << webgpu_ep_info.default_buffer_cache_mode; // - // STEP.2 - prepare WebGpuContext + // STEP.3 - prepare WebGpuContext // int context_id = 0; std::string context_id_str; diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_options.h b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h index 65ccbd800b122..334f21c737afe 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_options.h +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h @@ -17,6 +17,8 @@ constexpr const char* kWebGpuInstance = "webgpuInstance"; constexpr const char* kWebGpuAdapter = "webgpuAdapter"; constexpr const char* kWebGpuDevice = "webgpuDevice"; +constexpr const char* kDawnProcTable = "dawnProcTable"; + constexpr const char* kStorageBufferCacheMode = "storageBufferCacheMode"; constexpr const char* kUniformBufferCacheMode = "uniformBufferCacheMode"; constexpr const char* kQueryResolveBufferCacheMode = "queryResolveBufferCacheMode"; From 8978d8954bf343f2d6ed5426b2406812e69fc82e Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 11 Sep 2024 20:21:55 -0700 Subject: [PATCH 75/77] Revert "placeholder for setting proc table" This reverts commit dce0f181a272668fe459915374ca1cc1424525ff. --- .../webgpu/webgpu_provider_factory.cc | 18 ++---------------- .../providers/webgpu/webgpu_provider_options.h | 2 -- 2 files changed, 2 insertions(+), 18 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index b03bddf408b64..3848ccfc19f51 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -3,8 +3,6 @@ #include -#include - #include "core/framework/error_code_helper.h" #include "core/providers/webgpu/buffer_manager.h" #include "core/providers/webgpu/webgpu_execution_provider.h" @@ -35,19 +33,7 @@ struct WebGpuProviderFactory : IExecutionProviderFactory { std::shared_ptr WebGpuProviderFactoryCreator::Create(const ConfigOptions& config_options) { // - // STEP.1 - set dawn proc table - // - std::string dawn_proc_table_str; - if (config_options.TryGetConfigEntry(kDawnProcTable, dawn_proc_table_str)) { - size_t dawn_proc_table = 0; - ORT_ENFORCE(std::errc{} == - std::from_chars(dawn_proc_table_str.data(), dawn_proc_table_str.data() + dawn_proc_table_str.size(), dawn_proc_table).ec); - // TODO: do this for static link build - // dawnProcSetProcs(reinterpret_cast(dawn_proc_table)); - } - - // - // STEP.2 - prepare WebGpuExecutionProviderInfo + // STEP.1 - prepare WebGpuExecutionProviderInfo // WebGpuExecutionProviderInfo webgpu_ep_info{ // preferred layout is NHWC by default @@ -114,7 +100,7 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( LOGS_DEFAULT(VERBOSE) << "WebGPU EP default buffer cache mode: " << webgpu_ep_info.default_buffer_cache_mode; // - // STEP.3 - prepare WebGpuContext + // STEP.2 - prepare WebGpuContext // int context_id = 0; std::string context_id_str; diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_options.h b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h index 334f21c737afe..65ccbd800b122 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_options.h +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h @@ -17,8 +17,6 @@ constexpr const char* kWebGpuInstance = "webgpuInstance"; constexpr const char* kWebGpuAdapter = "webgpuAdapter"; constexpr const char* kWebGpuDevice = "webgpuDevice"; -constexpr const char* kDawnProcTable = "dawnProcTable"; - constexpr const char* kStorageBufferCacheMode = "storageBufferCacheMode"; constexpr const char* kUniformBufferCacheMode = "uniformBufferCacheMode"; constexpr const char* kQueryResolveBufferCacheMode = "queryResolveBufferCacheMode"; From 43ccaf45b6a791a8acb8b2e323a1cc6a38d33b13 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 11 Sep 2024 22:34:24 -0700 Subject: [PATCH 76/77] allow setting "ValidationMode" --- onnxruntime/core/providers/webgpu/program.h | 7 + .../core/providers/webgpu/webgpu_context.cc | 242 +++++++++--------- .../core/providers/webgpu/webgpu_context.h | 21 +- .../webgpu/webgpu_execution_provider.cc | 3 +- .../webgpu/webgpu_execution_provider.h | 6 +- .../webgpu/webgpu_provider_factory.cc | 25 +- .../webgpu/webgpu_provider_options.h | 7 + 7 files changed, 186 insertions(+), 125 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index ae3d82a6371d0..0daf247661362 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -253,6 +253,13 @@ struct ProgramOutput { TensorShape override_shape; }; +enum class ValidationMode { + Disabled = 0, + WGPUOnly, + Basic, + Full +}; + namespace detail { class ProgramWrapper; } diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index ec8f0cda10ee1..11a337cd3e37e 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -17,79 +17,6 @@ namespace onnxruntime { namespace webgpu { -namespace { - -std::vector GetEnabledAdapterToggles() { - // See the description of all the toggles in toggles.cpp - // "use_dxc" for Shader Model 6+ features (e.g. float16) - // "allow_unsafe_apis" for chromium experimental features - constexpr const char* toggles[] = { - "use_dxc", - "allow_unsafe_apis", - }; - return std::vector(std::begin(toggles), std::end(toggles)); -} - -std::vector GetEnabledDeviceToggles() { - // Enable / disable other toggles that may affect the performance. - // Other toggles that may be useful: "dump_shaders", "disable_symbol_renaming" - constexpr const char* toggles[] = { -#ifdef NDEBUG - // todo: when skip validation, the process may crash. - // need careful decision to enable this toggle. - // revisit this flag before release. - "skip_validation", -#endif - "disable_robustness", - "disable_workgroup_init", - "d3d_disable_ieee_strictness", - }; - return std::vector(std::begin(toggles), std::end(toggles)); -} - -std::vector GetDisabledDeviceToggles() { - constexpr const char* toggles[] = { - "lazy_clear_resource_on_first_use", - }; - return std::vector(std::begin(toggles), std::end(toggles)); -} - -std::vector GetAvailableRequiredFeatures(const wgpu::Adapter& adapter) { - std::vector required_features; - constexpr wgpu::FeatureName features[]{ - wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses, - wgpu::FeatureName::TimestampQuery, - wgpu::FeatureName::ShaderF16, - wgpu::FeatureName::Subgroups, - wgpu::FeatureName::SubgroupsF16}; - for (auto feature : features) { - if (adapter.HasFeature(feature)) { - required_features.push_back(feature); - } - } - return required_features; -} - -wgpu::RequiredLimits GetRequiredLimits(const wgpu::Adapter& adapter) { - wgpu::RequiredLimits required_limits{}; - wgpu::SupportedLimits adapter_limits; - ORT_ENFORCE(adapter.GetLimits(&adapter_limits)); - - required_limits.limits.maxBindGroups = adapter_limits.limits.maxBindGroups; - required_limits.limits.maxComputeWorkgroupStorageSize = adapter_limits.limits.maxComputeWorkgroupStorageSize; - required_limits.limits.maxComputeWorkgroupsPerDimension = adapter_limits.limits.maxComputeWorkgroupsPerDimension; - required_limits.limits.maxStorageBufferBindingSize = adapter_limits.limits.maxStorageBufferBindingSize; - required_limits.limits.maxBufferSize = adapter_limits.limits.maxBufferSize; - required_limits.limits.maxComputeInvocationsPerWorkgroup = adapter_limits.limits.maxComputeInvocationsPerWorkgroup; - required_limits.limits.maxComputeWorkgroupSizeX = adapter_limits.limits.maxComputeWorkgroupSizeX; - required_limits.limits.maxComputeWorkgroupSizeY = adapter_limits.limits.maxComputeWorkgroupSizeY; - required_limits.limits.maxComputeWorkgroupSizeZ = adapter_limits.limits.maxComputeWorkgroupSizeZ; - - return required_limits; -} - -} // namespace - void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info) { std::call_once(init_flag_, [this, &webgpu_ep_info]() { // Initialization.Step.1 - Create wgpu::Instance @@ -194,34 +121,34 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog const auto& inputs = program.Inputs(); const auto& outputs = program.Outputs(); -#ifndef NDEBUG // if debug build - ORT_ENFORCE(std::all_of(inputs.begin(), inputs.end(), [](const ProgramInput& input) { - const auto* tensor = input.tensor; - return tensor != nullptr && - tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault && - tensor->Location().device.Type() == OrtDevice::GPU && - !strcmp(tensor->Location().name, WEBGPU_BUFFER); - }), - "All inputs must be tensors on WebGPU buffers."); - - ORT_ENFORCE(std::all_of(outputs.begin(), outputs.end(), [](const ProgramOutput& output) { - const auto* tensor = output.tensor; - return tensor != nullptr && - tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault && - tensor->Location().device.Type() == OrtDevice::GPU && - !strcmp(tensor->Location().name, WEBGPU_BUFFER); - }), - "All outputs must be tensors on WebGPU buffers."); -#endif - if (outputs.size() == 0) { return Status::OK(); } + if (ValidationMode() >= ValidationMode::Basic) { + ORT_ENFORCE(std::all_of(inputs.begin(), inputs.end(), [](const ProgramInput& input) { + const auto* tensor = input.tensor; + return tensor != nullptr && + tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault && + tensor->Location().device.Type() == OrtDevice::GPU && + !strcmp(tensor->Location().name, WEBGPU_BUFFER); + }), + "All inputs must be tensors on WebGPU buffers."); + + ORT_ENFORCE(std::all_of(outputs.begin(), outputs.end(), [](const ProgramOutput& output) { + const auto* tensor = output.tensor; + return tensor != nullptr && + tensor->Location().mem_type == OrtMemType::OrtMemTypeDefault && + tensor->Location().device.Type() == OrtDevice::GPU && + !strcmp(tensor->Location().name, WEBGPU_BUFFER); + }), + "All outputs must be tensors on WebGPU buffers."); + } + const ProgramMetadata metadata = program.GetMetadata(); // validate program metadata - { + if (ValidationMode() >= ValidationMode::Basic) { const auto& [constants, overridable_constants, uniform_variables] = metadata; // check overridable constants @@ -229,17 +156,20 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog "Size of overridable constants mismatch in program \"", program.Name(), "\", Expected: ", overridable_constants.size(), ", Actual: ", program.OverridableConstants().size()); - size_t num_overridable_constants = program.OverridableConstants().size(); - for (size_t i = 0; i < num_overridable_constants; ++i) { - const auto& override_value = program.OverridableConstants()[i]; - const auto& definition = overridable_constants[i]; - ORT_RETURN_IF(override_value.has_value && override_value.type != definition.type, - "Overridable override_value[", i, "] (", definition.name, ") data type mismatch in program \"", program.Name(), - "\", Expected: ", definition.type, - ", Actual: ", override_value.type); - ORT_RETURN_IF(!override_value.has_value && !definition.has_default_value, - "Overridable override_value[", i, "] (", definition.name, ") no override_value specified in program \"", program.Name(), - "\""); + + if (ValidationMode() >= ValidationMode::Full) { + size_t num_overridable_constants = program.OverridableConstants().size(); + for (size_t i = 0; i < num_overridable_constants; ++i) { + const auto& override_value = program.OverridableConstants()[i]; + const auto& definition = overridable_constants[i]; + ORT_RETURN_IF(override_value.has_value && override_value.type != definition.type, + "Overridable override_value[", i, "] (", definition.name, ") data type mismatch in program \"", program.Name(), + "\", Expected: ", definition.type, + ", Actual: ", override_value.type); + ORT_RETURN_IF(!override_value.has_value && !definition.has_default_value, + "Overridable override_value[", i, "] (", definition.name, ") no override_value specified in program \"", program.Name(), + "\""); + } } // check uniform variables @@ -247,14 +177,17 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog "Size of uniform_value variables mismatch in program \"", program.Name(), "\", Expected: ", uniform_variables.size(), ", Actual: ", program.UniformVariables().size()); - size_t num_uniform_variables = program.UniformVariables().size(); - for (size_t i = 0; i < num_uniform_variables; ++i) { - const auto& uniform_value = program.UniformVariables()[i]; - const auto& definition = uniform_variables[i]; - ORT_RETURN_IF(uniform_value.length > 0 && uniform_value.data_type != definition.data_type, - "Uniform variable[", i, "] (", definition.name, ") data type mismatch in program \"", program.Name(), - "\", Expected: ", definition.data_type, - ", Actual: ", uniform_value.data_type); + + if (ValidationMode() >= ValidationMode::Full) { + size_t num_uniform_variables = program.UniformVariables().size(); + for (size_t i = 0; i < num_uniform_variables; ++i) { + const auto& uniform_value = program.UniformVariables()[i]; + const auto& definition = uniform_variables[i]; + ORT_RETURN_IF(uniform_value.length > 0 && uniform_value.data_type != definition.data_type, + "Uniform variable[", i, "] (", definition.name, ") data type mismatch in program \"", program.Name(), + "\", Expected: ", definition.data_type, + ", Actual: ", uniform_value.data_type); + } } } @@ -295,9 +228,11 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog // prepare shape uniforms for shader variables (if any) and user defined uniforms std::vector shape_uniforms; shape_uniforms.reserve(program_artifact->shape_uniform_ranks.size() * 2); - ORT_RETURN_IF_NOT(program_artifact->shape_uniform_ranks.size() == inputs.size() + outputs.size(), - "Invalid program artifact: variable size (", program_artifact->shape_uniform_ranks.size(), - ") does not match current program (input: ", inputs.size(), ", output: ", outputs.size(), ")"); + if (ValidationMode() >= ValidationMode::Basic) { + ORT_RETURN_IF_NOT(program_artifact->shape_uniform_ranks.size() == inputs.size() + outputs.size(), + "Invalid program artifact: variable size (", program_artifact->shape_uniform_ranks.size(), + ") does not match current program (input: ", inputs.size(), ", output: ", outputs.size(), ")"); + } for (size_t i = 0; i < program_artifact->shape_uniform_ranks.size(); ++i) { SafeInt expected_rank = program_artifact->shape_uniform_ranks[i]; if (expected_rank > 0) { @@ -423,10 +358,81 @@ Status WebGpuContext::Run(const ComputeContext& context, const ProgramBase& prog return Status::OK(); } +std::vector WebGpuContext::GetEnabledAdapterToggles() const { + // See the description of all the toggles in toggles.cpp + // "use_dxc" for Shader Model 6+ features (e.g. float16) + // "allow_unsafe_apis" for chromium experimental features + constexpr const char* toggles[] = { + "use_dxc", + "allow_unsafe_apis", + }; + return std::vector(std::begin(toggles), std::end(toggles)); +} + +std::vector WebGpuContext::GetEnabledDeviceToggles() const { + // Enable / disable other toggles that may affect the performance. + // Other toggles that may be useful: "dump_shaders", "disable_symbol_renaming" + constexpr const char* toggles[] = { + "skip_validation", // only use "skip_validation" when ValidationMode is set to "Disabled" + "disable_robustness", + "disable_workgroup_init", + "d3d_disable_ieee_strictness", + }; + return std::vector(ValidationMode() >= ValidationMode::WGPUOnly + ? std::begin(toggles) + 1 + : std::begin(toggles), + std::end(toggles)); +} + +std::vector WebGpuContext::GetDisabledDeviceToggles() const { + constexpr const char* toggles[] = { + "lazy_clear_resource_on_first_use", + }; + return std::vector(std::begin(toggles), std::end(toggles)); +} + +std::vector WebGpuContext::GetAvailableRequiredFeatures(const wgpu::Adapter& adapter) const { + std::vector required_features; + constexpr wgpu::FeatureName features[]{ + wgpu::FeatureName::ChromiumExperimentalTimestampQueryInsidePasses, + wgpu::FeatureName::TimestampQuery, + wgpu::FeatureName::ShaderF16, + wgpu::FeatureName::Subgroups, + wgpu::FeatureName::SubgroupsF16}; + for (auto feature : features) { + if (adapter.HasFeature(feature)) { + required_features.push_back(feature); + } + } + return required_features; +} + +wgpu::RequiredLimits WebGpuContext::GetRequiredLimits(const wgpu::Adapter& adapter) const { + wgpu::RequiredLimits required_limits{}; + wgpu::SupportedLimits adapter_limits; + ORT_ENFORCE(adapter.GetLimits(&adapter_limits)); + + required_limits.limits.maxBindGroups = adapter_limits.limits.maxBindGroups; + required_limits.limits.maxComputeWorkgroupStorageSize = adapter_limits.limits.maxComputeWorkgroupStorageSize; + required_limits.limits.maxComputeWorkgroupsPerDimension = adapter_limits.limits.maxComputeWorkgroupsPerDimension; + required_limits.limits.maxStorageBufferBindingSize = adapter_limits.limits.maxStorageBufferBindingSize; + required_limits.limits.maxBufferSize = adapter_limits.limits.maxBufferSize; + required_limits.limits.maxComputeInvocationsPerWorkgroup = adapter_limits.limits.maxComputeInvocationsPerWorkgroup; + required_limits.limits.maxComputeWorkgroupSizeX = adapter_limits.limits.maxComputeWorkgroupSizeX; + required_limits.limits.maxComputeWorkgroupSizeY = adapter_limits.limits.maxComputeWorkgroupSizeY; + required_limits.limits.maxComputeWorkgroupSizeZ = adapter_limits.limits.maxComputeWorkgroupSizeZ; + + return required_limits; +} + std::unordered_map> WebGpuContextFactory::contexts_; OrtMutex WebGpuContextFactory::mutex_; -WebGpuContext& WebGpuContextFactory::CreateContext(int context_id, WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device) { +WebGpuContext& WebGpuContextFactory::CreateContext(int context_id, + WGPUInstance instance, + WGPUAdapter adapter, + WGPUDevice device, + ValidationMode validation_mode) { if (context_id == 0) { // context ID is preserved for the default context. User cannot use context ID 0 as a custom context. ORT_ENFORCE(instance == nullptr && adapter == nullptr && device == nullptr, @@ -441,7 +447,7 @@ WebGpuContext& WebGpuContextFactory::CreateContext(int context_id, WGPUInstance auto it = contexts_.find(context_id); if (it == contexts_.end()) { - auto context = std::unique_ptr(new WebGpuContext(instance, adapter, device)); + auto context = std::unique_ptr(new WebGpuContext(instance, adapter, device, validation_mode)); it = contexts_.emplace(context_id, std::move(context)).first; } else if (context_id != 0) { ORT_ENFORCE(it->second->instance_.Get() == instance && it->second->adapter_.Get() == adapter && it->second->device_.Get() == device, diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index 2086213e248f3..3251364e85ce3 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -28,7 +28,11 @@ class ProgramBase; class WebGpuContextFactory { public: - static WebGpuContext& CreateContext(int context_id, WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device); + static WebGpuContext& CreateContext(int context_id, + WGPUInstance instance, + WGPUAdapter adapter, + WGPUDevice device, + ValidationMode validation_mode); static WebGpuContext& GetContext(int context_id); private: @@ -95,18 +99,31 @@ class WebGpuContext final { webgpu::BufferManager& BufferManager() const { return *buffer_mgr_; } + inline webgpu::ValidationMode ValidationMode() const { + return validation_mode_; + } + Status Run(const ComputeContext& context, const ProgramBase& program); private: - WebGpuContext(WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device) : instance_{instance}, adapter_{adapter}, device_{device} {} + WebGpuContext(WGPUInstance instance, WGPUAdapter adapter, WGPUDevice device, webgpu::ValidationMode validation_mode) + : instance_{instance}, adapter_{adapter}, device_{device}, validation_mode_{validation_mode} {} ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuContext); + std::vector WebGpuContext::GetEnabledAdapterToggles() const; + std::vector WebGpuContext::GetEnabledDeviceToggles() const; + std::vector WebGpuContext::GetDisabledDeviceToggles() const; + std::vector WebGpuContext::GetAvailableRequiredFeatures(const wgpu::Adapter& adapter) const; + wgpu::RequiredLimits WebGpuContext::GetRequiredLimits(const wgpu::Adapter& adapter) const; + std::once_flag init_flag_; wgpu::Instance instance_; wgpu::Adapter adapter_; wgpu::Device device_; + webgpu::ValidationMode validation_mode_; + wgpu::AdapterInfo adapter_info_; wgpu::Limits device_limits_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index ae5b429fb2301..d049cbbf64560 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -23,7 +23,8 @@ #include "core/framework/kernel_registry.h" #include "core/graph/function_utils.h" #include "core/graph/indexed_sub_graph.h" -#include "data_transfer.h" + +#include "core/providers/webgpu/data_transfer.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index 5f27fad14afc6..db9de9dc4933e 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -22,9 +22,9 @@ enum class BufferCacheMode; } // namespace webgpu struct WebGpuExecutionProviderInfo { - WebGpuExecutionProviderInfo(DataLayout data_layout1, bool enable_graph_capture1) - : data_layout{data_layout1}, - enable_graph_capture{enable_graph_capture1}, + WebGpuExecutionProviderInfo(DataLayout data_layout, bool enable_graph_capture) + : data_layout{data_layout}, + enable_graph_capture{enable_graph_capture}, storage_buffer_cache_mode{}, uniform_buffer_cache_mode{}, query_resolve_buffer_cache_mode{}, diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index 3848ccfc19f51..4ceaa06238590 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -99,6 +99,28 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( webgpu_ep_info.default_buffer_cache_mode = parse_buffer_cache_mode(kDefaultBufferCacheMode, webgpu::BufferCacheMode::Disabled); LOGS_DEFAULT(VERBOSE) << "WebGPU EP default buffer cache mode: " << webgpu_ep_info.default_buffer_cache_mode; + webgpu::ValidationMode validation_mode = +#ifndef NDEBUG + webgpu::ValidationMode::Full // for debug build, enable full validation by default +#else + webgpu::ValidationMode::WGPUOnly // for release build, only enable WGPU validation. +#endif // !NDEBUG + ; + std::string validation_mode_str; + if (config_options.TryGetConfigEntry(kValidationMode, validation_mode_str)) { + if (validation_mode_str == kValidationMode_Disabled) { + validation_mode = webgpu::ValidationMode::Disabled; + } else if (validation_mode_str == kValidationMode_wgpuOnly) { + validation_mode = webgpu::ValidationMode::WGPUOnly; + } else if (validation_mode_str == kValidationMode_basic) { + validation_mode = webgpu::ValidationMode::Basic; + } else if (validation_mode_str == kValidationMode_full) { + validation_mode = webgpu::ValidationMode::Full; + } else { + ORT_THROW("Invalid validation mode: ", validation_mode_str); + } + } + // // STEP.2 - prepare WebGpuContext // @@ -136,7 +158,8 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( auto& context = webgpu::WebGpuContextFactory::CreateContext(context_id, reinterpret_cast(webgpu_instance), reinterpret_cast(webgpu_adapter), - reinterpret_cast(webgpu_device)); + reinterpret_cast(webgpu_device), + validation_mode); context.Initialize(webgpu_ep_info); return std::make_shared(context_id, context, webgpu_ep_info); diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_options.h b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h index 65ccbd800b122..ebbca55a8c706 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_options.h +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h @@ -22,6 +22,8 @@ constexpr const char* kUniformBufferCacheMode = "uniformBufferCacheMode"; constexpr const char* kQueryResolveBufferCacheMode = "queryResolveBufferCacheMode"; constexpr const char* kDefaultBufferCacheMode = "defaultBufferCacheMode"; +constexpr const char* kValidationMode = "validationMode"; + // The following are the possible values for the provider options. constexpr const char* kPreferredLayout_NCHW = "NCHW"; @@ -35,6 +37,11 @@ constexpr const char* kBufferCacheMode_LazyRelease = "lazyRelease"; constexpr const char* kBufferCacheMode_Simple = "simple"; constexpr const char* kBufferCacheMode_Bucket = "bucket"; +constexpr const char* kValidationMode_Disabled = "disabled"; +constexpr const char* kValidationMode_wgpuOnly = "wgpuOnly"; +constexpr const char* kValidationMode_basic = "basic"; +constexpr const char* kValidationMode_full = "full"; + } // namespace options } // namespace webgpu } // namespace onnxruntime From 409ac5c9cfa808654a3284aa13331ee036e8b228 Mon Sep 17 00:00:00 2001 From: Cao Date: Thu, 19 Sep 2024 13:56:41 +0800 Subject: [PATCH 77/77] webgpu: support MultiHeadAttention operator --- .../webgpu/bert/multihead_attention.cc | 506 ++++++++++++++++++ .../webgpu/bert/multihead_attention.h | 113 ++++ .../webgpu/webgpu_contrib_kernels.cc | 2 +- .../multihead_attention_op_test.cc | 95 ++-- 4 files changed, 676 insertions(+), 40 deletions(-) create mode 100644 onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc create mode 100644 onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc new file mode 100644 index 0000000000000..3ff1140834754 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -0,0 +1,506 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/bert/multihead_attention_helper.h" +#include "contrib_ops/webgpu/bert/multihead_attention.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +#include "core/providers/webgpu/webgpu_supported_types.h" + +using namespace onnxruntime::webgpu; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::contrib::multihead_attention_helper; + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + MultiHeadAttention, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + MultiHeadAttention); + + +Status TransferBSDToBNSHProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("qkv_input", ShaderVariable::UseUniform); + const auto& qkv_output = shader.AddOutput("qkv_output", ShaderVariable::UseUniform | + ShaderVariable::UseOffsetToIndices); + + if (has_bias_) { + shader.AddInput("bias", ShaderVariable::UseUniform); + } + + shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size"), + "let output_indices = ", qkv_output.OffsetToIndices("global_idx"), ";\n", + "let input_offset_idx = output_indices[0] * uniforms.batch_offset + output_indices[1] * ", + "uniforms.head_offset + output_indices[2] * uniforms.sequence_offset + output_indices[3];\n", + has_bias_ ? "let bias_offset_idx = (input_offset_idx % uniforms.sequence_offset) + uniforms.bias_offset;\n" : "", + "qkv_output[global_idx] = qkv_input[input_offset_idx]", + has_bias_ ? " + bias[bias_offset_idx];\n" : ";\n"); + + return Status::OK(); +} + +Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_heads, int sequence_length, + int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor) { + assert(input_tensor->Shape().GetDims().size() == 3); + assert(output_tensor->Shape().GetDims().size() == 4); + + uint32_t data_size = SafeInt(output_tensor->Shape().Size()); + const int batch_offset = num_heads * sequence_length * head_size; + const int sequence_offset = num_heads * head_size; + const int head_offset = head_size; + bool has_bias = bias != nullptr; + + TransferBSDToBNSHProgram program{"TransferBSDToBNSH", has_bias}; + program.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + {data_size}, + {static_cast(batch_offset)}, + {static_cast(sequence_offset)}, + {static_cast(head_offset)}, + {static_cast(bias_offset)} + }); + + if (has_bias) { + program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank}); + } + + return context.RunProgram(program); +}; + +Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("q", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias); + shader.AddInput("key", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias); + if (feed_past_key_) { + shader.AddInput("past_key", ShaderVariable::UseUniform); + } + if (has_attention_bias_) { + shader.AddInput("attention_bias", ShaderVariable::UseUniform); + } + + shader.AddOutput("output", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias); + if (has_present_key_) { + shader.AddOutput("present_key", ShaderVariable::UseUniform); + } + + shader.AppendImplementation("const TILE_SIZE = ", tile_size_, "u;\n") + .AppendImplementation("var tileQ: array;\n") + .AppendImplementation("var tileK: array;\n"); + + std::string f32_str = components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32"); + std::ostringstream ss; + ss << "// x holds the N and y holds the M\n" + << "let headIdx = workgroup_id.z;\n" + << "let m = workgroup_id.y * TILE_SIZE;\n" + << "let n = workgroup_id.x * TILE_SIZE;\n" + << "let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;\n"; + + if (feed_past_key_ && has_present_key_) { + ss << "let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx;\n" + << "let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;\n"; + } else { + ss << "let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;\n"; + } + + if (has_present_key_) { + ss << "let presentKeyOffset = headIdx * uniforms.N * uniforms.K;\n"; + } + + ss << "var value = " << f32_str << "(0);\n" + << "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" + << "if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n" + << "tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n" + << "}\n" + << "if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n" + << "var idx = TILE_SIZE * local_id.y + local_id.x;\n"; + + if (feed_past_key_ && has_present_key_) { + ss << "if (n + local_id.y < uniforms.past_sequence_length) {\n" + << "tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" + << "} else {\n" + << "tileK[idx] =" + << "key[kOffset + (n + local_id.y - uniforms.past_sequence_length) * uniforms.K + w + local_id.x];\n" + << "}\n"; + } else { + ss << "tileK[idx] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];\n"; + } + + if (has_present_key_) { + ss << "present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n"; + } + + ss << "}\n" + << "workgroupBarrier();\n" + << "for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {\n" + << "value += "<< f32_str << "(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]);\n" + << "}\n" + << "workgroupBarrier();\n" + << "}\n"; + + ss << "let headOffset = headIdx * uniforms.M * uniforms.N;\n" + << "if (global_id.y < uniforms.M && global_id.x < uniforms.N) {\n" + << "let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n" + << "var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : + (components_ == 2 ? "value.x + value.y" : "value")) << ";\n"; + + ss << "output[outputIdx] = output_value_t(sum * uniforms.alpha) + " + << (has_attention_bias_ ? "attention_bias[outputIdx]" : "0.0") << ";\n" + << "}\n"; + + shader.SetMainFunctionBody(ss.str()); + + return Status::OK(); +} + +Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int output_count, const Tensor* Q, + const Tensor* K, const Tensor* past_key, const Tensor* attention_bias, Tensor* probs, Tensor* present_key, + AttentionParameters& parameters, int past_sequence_length, int total_sequence_length) { + const float alpha = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size)) + : parameters.scale; + + const bool feed_past_key = present_key != nullptr && past_key != nullptr && past_key->SizeInBytes() > 0; + const bool has_present_key = output_count > 1 && past_key; + const bool has_attention_bias = attention_bias != nullptr; + const int tile_size = 12; + const int components = parameters.head_size % 4 == 0 ? 4 : (parameters.head_size % 2 == 0 ? 2 : 1); + + AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size, + components}; + program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, + {K, ProgramTensorMetadataDependency::TypeAndRank, components}}); + if (feed_past_key) { + program.AddInput({past_key, ProgramTensorMetadataDependency::TypeAndRank, components}); + } + if (has_attention_bias) { + program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank}); + } + program.AddOutputs({{probs, ProgramTensorMetadataDependency::Rank}}); + if (has_present_key) { + program.AddOutput({present_key, ProgramTensorMetadataDependency::Rank, components}); + } + + const uint32_t vectorized_head_size = parameters.head_size / components; + program.SetDispatchGroupSize((total_sequence_length + tile_size - 1) / tile_size, + (parameters.sequence_length + tile_size - 1) / tile_size, + parameters.batch_size * parameters.num_heads) + .SetWorkgroupSize(tile_size, tile_size) + .AddUniformVariables({ + {static_cast(parameters.sequence_length)}, + {static_cast(vectorized_head_size)}, + {static_cast(total_sequence_length)}, + {static_cast(parameters.num_heads)}, + {static_cast(alpha)}, + {static_cast(past_sequence_length)}, + {static_cast(parameters.kv_sequence_length)}}); + + return context.RunProgram(program); +} + +Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddOutput("x", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias | + ShaderVariable::UseElementTypeAlias); + shader.AppendImplementation("var thread_max: array;\n") + .AppendImplementation("var thread_sum: array;\n"); + + std::string f32_str = components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32"); + std::ostringstream ss; + ss << "let local_offset = local_idx * uniforms.elements_per_thread;\n" + << "let offset = (global_idx / " << work_group_size_ << ") * uniforms.d_comp + local_offset;\n" + << "var thread_max_vector = " << f32_str << "(-3.402823e+38f);\n" + << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n" + << "thread_max_vector = max(" << f32_str << "(x[offset + i]), thread_max_vector);\n" + << "}\n" + << "thread_max[local_idx] = " << (components_ == 4 ? + "max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))" : + (components_ == 2 ? "max(thread_max_vector.x, thread_max_vector.y)" : "thread_max_vector")) << ";\n" + << "workgroupBarrier();\n" + << "var max_value = f32(-3.402823e+38f);\n" + << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" + << "max_value = max(thread_max[i], max_value);\n" + << "}\n" + << "var sum_vector = " << f32_str << "(0);\n" + << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n" + << "sum_vector += exp(" << f32_str << "(x[offset + i]) - max_value);\n" + << "}\n" + << "thread_sum[local_idx] = " << (components_ == 4 ? "sum_vector.x + sum_vector.y + sum_vector.z + sum_vector.w" : + (components_ == 2 ? "sum_vector.x + sum_vector.y" : "sum_vector")) << ";\n" + << "workgroupBarrier();\n" + << "var sum: f32 = 0;\n" + << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" + << "sum += thread_sum[i]\n;" + << "}\n" + << "if (sum == 0) {\n" + << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n" + << "x[offset + i] = x_value_t(x_element_t(uniforms.d_inv));\n" + << "}\n" + << "} else {\n" + << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n" + << "var f32input = " << f32_str << "(x[offset + i]);\n" + << "x[offset + i] = x_value_t(exp(f32input - max_value) / sum);\n" + << "}\n" + << "}\n"; + + shader.SetMainFunctionBody(ss.str()); + + return Status::OK(); +} + +Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tensor* probs, int n, int d) { + const int components = d % 4 == 0 ? 4 : (d % 2 == 0 ? 2 : 1); + int work_group_size = 64; + const int d_comp = d / components; + if (d_comp < work_group_size) { + work_group_size = 32; + } + const int elementsPerThread = (d_comp + work_group_size - 1) / work_group_size; + + InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components}; + program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}}) + .SetDispatchGroupSize(n) + .SetWorkgroupSize(work_group_size) + .AddUniformVariables({{static_cast(1.f / static_cast(d))}, + {static_cast(d_comp)}, + {static_cast(elementsPerThread)} + }); + + return context.RunProgram(program); +} + +Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("probs", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias | + ShaderVariable::UseElementTypeAlias); + shader.AddInput("v", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias); + if (feed_past_value_) { + shader.AddInput("past_value", ShaderVariable::UseUniform); + } + + shader.AddOutput("output", ShaderVariable::UseUniform); + if (has_present_value_) { + shader.AddOutput("present_value", ShaderVariable::UseUniform); + } + + shader.AppendImplementation("const TILE_SIZE = ", tile_size_, "u;\n") + .AppendImplementation("var tileQ: array;\n") + .AppendImplementation("var tileK: array;\n"); + + std::ostringstream ss; + ss << "let headIdx = workgroup_id.z;\n" + << "let m = global_id.y;\n" + << "let n = global_id.x;\n" + << "let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K;\n"; + + if (feed_past_value_ && has_present_value_) { + ss << "let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n;\n" + << "let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n;\n"; + } else { + ss << "let offsetB = headIdx * uniforms.N * uniforms.K + n;\n"; + } + + if (has_present_value_) { + ss << "let presentValueOffset = headIdx * uniforms.N * uniforms.K + n;\n"; + } + + ss << "var value = probs_element_t(0);\n" + << "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" + << "if (m < uniforms.M && w + local_id.x < uniforms.K) {\n" + << "tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];\n" + << "}\n" + << "if (n < uniforms.N && w + local_id.y < uniforms.K) {\n" + << "var idx = TILE_SIZE * local_id.y + local_id.x;\n"; + + if (feed_past_value_ && has_present_value_) { + ss << "if (w + local_id.y < uniforms.past_sequence_length) {\n" + << "tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N];\n" + << "} else {\n" + << "tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N];\n" + << "}\n"; + } else { + ss << "tileK[idx] = v[offsetB + (w + local_id.y) * uniforms.N];\n"; + } + + if (has_present_value_) { + ss << "present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n"; + } + + ss << "}\n" + << "workgroupBarrier();\n" + << "for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {\n" + << "value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];\n" + << "}\n" + << "workgroupBarrier();\n" + << "}\n"; + + ss << "// we need to transpose output from BNSH_v to BSND_v\n" + << "let batchIdx = workgroup_id.z / uniforms.num_heads;\n" + << "let currentBatchHeadNumber = workgroup_id.z % uniforms.num_heads;\n" + << "if (m < uniforms.M && n < uniforms.N) {\n" + << "let outputIdx = batchIdx * uniforms.M * uniforms.v_hidden_size + " + << "m * uniforms.v_hidden_size + currentBatchHeadNumber * uniforms.N + n;\n" + << "output[outputIdx] = value;\n" + << "}\n"; + + shader.SetMainFunctionBody(ss.str()); + + return Status::OK(); +} + +Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int output_count, + const Tensor* probs, + const Tensor* V, + const Tensor* past_value, + Tensor* output, + Tensor* present_value, + AttentionParameters& parameters, + int past_sequence_length, + int total_sequence_length) { + const bool feed_past_value = present_value != nullptr && past_value !=nullptr && past_value->SizeInBytes() > 0; + const bool has_present_value = output_count > 1 && past_value != nullptr; + const int tile_size = 12; + + VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size}; + program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank}, + {V, ProgramTensorMetadataDependency::TypeAndRank}}); + if (feed_past_value) { + program.AddInput({past_value, ProgramTensorMetadataDependency::TypeAndRank}); + } + program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}}); + if (has_present_value) { + program.AddOutput({present_value, ProgramTensorMetadataDependency::TypeAndRank}); + } + + program.SetDispatchGroupSize((parameters.v_head_size + tile_size - 1) / tile_size, + (parameters.sequence_length + tile_size - 1) / tile_size, + parameters.batch_size * parameters.num_heads) + .SetWorkgroupSize(tile_size, tile_size) + .AddUniformVariables({ + {static_cast(parameters.sequence_length)}, + {static_cast(total_sequence_length)}, + {static_cast(parameters.v_head_size)}, + {static_cast(parameters.num_heads)}, + {static_cast(parameters.v_hidden_size)}, + {static_cast(past_sequence_length)}, + {static_cast(parameters.kv_sequence_length)}}); + + return context.RunProgram(program); +} + +Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, + const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, + AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { + const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + + (past_value !=nullptr ? 1 : 0)}); + const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length : 0; + const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length; + + const TensorShapeVector probs_dims({parameters.batch_size, parameters.num_heads, + parameters.sequence_length, total_sequence_length}); + const TensorShape probs_shape(probs_dims); + Tensor probs = context.CreateGPUTensor(Q->DataType(), probs_shape); + ORT_RETURN_IF_ERROR(ComputeAttentionProbs(context, output_count, Q, K, past_key, attention_bias, &probs, present_key, + parameters, past_sequence_length, total_sequence_length)); + + ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs, + parameters.batch_size * parameters.num_heads * parameters.sequence_length, total_sequence_length)); + + ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value, + parameters, past_sequence_length, total_sequence_length)); + + return Status::OK(); +} + +MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) + : WebGpuKernel(info) { + int64_t num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + num_heads_ = static_cast(num_heads); + mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); + scale_ = info.GetAttrOrDefault("scale", 0.0f); + is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; + ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support webgpu kernel"); +} + +Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const Tensor* query = context.Input(0); + const Tensor* key = context.Input(1); + const Tensor* value = context.Input(2); + const Tensor* bias = context.Input(3); + const Tensor* key_padding_mask = context.Input(4); + const Tensor* attention_bias = context.Input(5); + const Tensor* past_key = context.Input(6); + const Tensor* past_value = context.Input(7); + + if (query->Shape().GetDims().size() == 5) { + ORT_NOT_IMPLEMENTED("Packed QKV of shape (B, L, N, 3, H) not implemented for webgpu"); + } + if (key != nullptr && key->Shape().GetDims().size() == 5) { + ORT_NOT_IMPLEMENTED("Packed KV not implemented for webgpu"); + } + if (key_padding_mask) { + ORT_NOT_IMPLEMENTED("input `key_padding_mask` not implemented for webgpu"); + } + + AttentionParameters parameters; + ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, key, value, + bias, key_padding_mask, attention_bias, past_key, past_value, nullptr, ¶meters, + num_heads_, mask_filter_value_, scale_, is_unidirectional_, false, kMultiHeadAttention, + context.DeviceLimits().maxComputeInvocationsPerWorkgroup)); + + TensorShapeVector output_shape(3); + output_shape[0] = static_cast(parameters.batch_size); + output_shape[1] = static_cast(parameters.sequence_length); + output_shape[2] = static_cast(parameters.v_hidden_size); + Tensor* output = context.Output(0, output_shape); + + // If optional outputs aren't needed, present_key and present_value will be null + std::vector present_dims{ + parameters.batch_size, + parameters.num_heads, + parameters.total_sequence_length, + parameters.head_size, + }; + TensorShape present_shape(present_dims); + Tensor* present_key = context.Output(1, present_shape); + Tensor* present_value = context.Output(2, present_shape); + + TensorShapeVector q_new_dims({parameters.batch_size, parameters.num_heads, + parameters.sequence_length, parameters.head_size}); + TensorShape q_new_shape(q_new_dims); + Tensor Q = context.CreateGPUTensor(query->DataType(), q_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH( + context, parameters.num_heads, parameters.sequence_length, parameters.head_size, query, bias, 0, &Q)); + + if (parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format + return ApplyAttention(&Q, key, value, attention_bias, past_key, past_value, output, present_key, + present_value, parameters, context); + } + + TensorShapeVector k_new_dims({parameters.batch_size, parameters.num_heads, + parameters.kv_sequence_length, parameters.head_size}); + TensorShape k_new_shape(k_new_dims); + Tensor K = context.CreateGPUTensor(key->DataType(), k_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length, + parameters.head_size, key, bias, parameters.hidden_size, &K)); + + TensorShapeVector v_new_dims({parameters.batch_size, parameters.num_heads, + parameters.kv_sequence_length, parameters.v_head_size}); + TensorShape v_new_shape(v_new_dims); + Tensor V = context.CreateGPUTensor(value->DataType(), v_new_shape); + ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length, + parameters.v_head_size, value, bias, 2 * parameters.hidden_size, &V)); + + // Compute the attention score and apply the score to V + return ApplyAttention(&Q, &K, &V, attention_bias, past_key, past_value, output, present_key, + present_value, parameters, context); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h new file mode 100644 index 0000000000000..86d49ab242ccd --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; + +class TransferBSDToBNSHProgram final : public Program { + public: + TransferBSDToBNSHProgram(const std::string& kernel_name, bool has_bias) : Program{kernel_name}, has_bias_(has_bias) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32}, + {"batch_offset", ProgramUniformVariableDataType::Uint32}, + {"sequence_offset", ProgramUniformVariableDataType::Uint32}, + {"head_offset", ProgramUniformVariableDataType::Uint32}, + {"bias_offset", ProgramUniformVariableDataType::Uint32}); + +private: + bool has_bias_; +}; + +class AttentionProbsProgram final : public Program { + public: + AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key, + bool has_attention_bias,int tile_size, int components) + : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), + has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"N", ProgramUniformVariableDataType::Uint32}, + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"alpha", ProgramUniformVariableDataType::Float32}, + {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}); + +private: + bool feed_past_key_; + bool has_present_key_; + bool has_attention_bias_; + int tile_size_; + int components_; +}; + +class InPlaceSoftmaxProgram final : public Program { + public: + InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components) + : Program{kernel_name}, work_group_size_(work_group_size), components_(components) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"d_inv", ProgramUniformVariableDataType::Float32}, + {"d_comp", ProgramUniformVariableDataType::Uint32}, + {"elements_per_thread", ProgramUniformVariableDataType::Uint32}); + +private: + int work_group_size_; + int components_; +}; + +class VxAttentionScoreProgram final : public Program { + public: + VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value,int tile_size) + : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), + tile_size_(tile_size) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"N", ProgramUniformVariableDataType::Uint32}, + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"v_hidden_size", ProgramUniformVariableDataType::Uint32}, + {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}); + +private: + bool feed_past_value_; + bool has_present_value_; + int tile_size_; +}; + +class MultiHeadAttention final : public WebGpuKernel { + public: + MultiHeadAttention(const OpKernelInfo& info); + Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; + + protected: + int num_heads_; + float mask_filter_value_; + float scale_; + bool is_unidirectional_{false}; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index def104b6cb108..84ab684c32c09 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -45,7 +45,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { // // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo("num_heads", static_cast(num_heads)); tester.AddAttribute("mask_filter_value", static_cast(-10000.0f)); @@ -266,6 +268,12 @@ static void RunMultiHeadAttentionTest( execution_providers.push_back(DefaultDmlExecutionProvider()); tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } + + if (enable_webgpu) { + std::vector> execution_providers; + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } } } @@ -295,8 +303,10 @@ static void RunMultiHeadAttentionKernel( bool is_static_kv = true, bool disable_cpu = false, // some cases not supported in cpu right now. bool disable_cuda = false, + bool disable_webgpu = false, bool disable_rocm = DISABLE_ROCM, - bool disable_dml = false) { + bool disable_dml = false + ) { if (kernel_type == AttentionKernelType::AttentionKernel_Default) { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ @@ -309,7 +319,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, + disable_dml, disable_webgpu); return; } @@ -325,7 +336,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, + disable_dml, disable_webgpu); return; } @@ -341,7 +353,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, + disable_dml, disable_webgpu); return; } @@ -358,7 +371,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, + disable_dml, disable_webgpu); return; } #endif @@ -376,7 +390,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, + disable_dml, disable_webgpu); } if (kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { @@ -392,11 +407,13 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, + disable_dml, disable_webgpu); } } -static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu = false, bool disable_cuda = false) { +static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_webgpu = true, bool disable_cpu = false, + bool disable_cuda = false) { if (data.fp32_output_data.size() > 0) { constexpr bool use_float16 = false; @@ -407,7 +424,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } #if USE_MEMORY_EFFICIENT_ATTENTION @@ -420,7 +437,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } } #endif @@ -431,7 +448,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } if (data.fp16_output_data.size() > 0) { @@ -443,7 +460,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } kernel_type = AttentionKernelType::AttentionKernel_TrtFusedAttention; @@ -453,7 +470,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } #if USE_MEMORY_EFFICIENT_ATTENTION @@ -464,7 +481,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } #endif @@ -475,7 +492,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } kernel_type = AttentionKernelType::AttentionKernel_Default; @@ -484,7 +501,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } } @@ -493,75 +510,75 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize40) { AttentionTestData data; GetCrossAttentionData_HeadSize40(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false); GetCrossAttentionData_HeadSize40_NoBias(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false); } TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask1D) { ROCM_GTEST_SKIP("ROCm MHA does not support mask type of MASK_1D_KEY_SEQ_LEN"); AttentionTestData data; GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, true); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true, /*disable_cpu=*/true); GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(data, true); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true, /*disable_cpu=*/true); } TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask2D) { AttentionTestData data; GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, false); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true, /*disable_cpu=*/true); GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(data, false); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true, /*disable_cpu=*/true); } TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize32_LeftSidePadding_Mask2D) { AttentionTestData data; GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(data); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true, /*disable_cpu=*/true); GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias(data); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true, /*disable_cpu=*/true); } TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_NoBias_NoMask_PackedKV) { AttentionTestData data; GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true); } TEST(MultiHeadAttentionTest, SelfAttention_Batch2_HeadSize32_NoBias_NoMask_PackedQKV) { AttentionTestData data; GetSelfAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedQKV(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true); } // This tests qk_head_size != v_head_size TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize16_8) { AttentionTestData data; GetCrossAttentionData_HeadSize16_8(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false); GetCrossAttentionData_HeadSize16_8_NoBias(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false); } TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize16) { AttentionTestData data; GetCrossAttentionData_HeadSize16(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false); GetCrossAttentionData_HeadSize16_NoBias(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false); } TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize8) { AttentionTestData data; GetCrossAttentionData_HeadSize8_NoBias(data); - RunMultiHeadAttentionTests(data, false, true); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false, false, true); } // TODO (pavignol): Fix this regression @@ -579,7 +596,7 @@ TEST(MultiHeadAttentionTest, SelfAttention_WithPast_WithAttnBias_ForT5) { ROCM_GTEST_SKIP("ROCm MHA only support head_size >= 8"); AttentionTestData data; GetSelfAttentionData_WithPast_WithAttnBias_ForT5(data); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false, true); } TEST(MultiHeadAttentionTest, AttentionCutlassAttnBias) { @@ -596,23 +613,23 @@ TEST(MultiHeadAttentionTest, CrossAttention_DiffSequenceLengths) { RunMultiHeadAttentionTests(data); GetCrossAttentionData_DiffSequenceLengths_HeadSize8(data); - RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true, /*disable_cpu=*/false, /*disable_cuda=*/true); GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(data); - RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true, /*disable_cpu=*/false, /*disable_cuda=*/true); } TEST(MultiHeadAttentionTest, SelfAttention_WithPastAndPresent_NoMask_NoAttnBias) { // Whisper decoder self attention with past_kv and present_kv AttentionTestData data; GetSelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false); GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias(data); - RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false, /*disable_cpu=*/false, /*disable_cuda=*/true); GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias(data); - RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false, /*disable_cpu=*/false, /*disable_cuda=*/true); } // This test is disabled since it is not used in Whisper anymore, and it fails in ROCm.