From bd2cbb12e8b72bd02d89f317ff9b0f86cd04cb99 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 26 Sep 2024 22:50:18 -0700 Subject: [PATCH 1/6] use Abseil OStringStream in WebGPU EP string concat --- .../contrib_ops/webgpu/bert/fast_gelu.cc | 27 ++-- .../core/providers/webgpu/docs/Conventions.md | 23 ++- .../webgpu/math/binary_elementwise_ops.cc | 144 +++++++++++------- .../webgpu/math/unary_elementwise_ops.cc | 8 +- onnxruntime/core/providers/webgpu/program.cc | 2 +- onnxruntime/core/providers/webgpu/program.h | 8 +- .../providers/webgpu/program_cache_key.cc | 16 +- .../core/providers/webgpu/shader_helper.cc | 47 +++--- .../core/providers/webgpu/shader_helper.h | 32 ++-- .../core/providers/webgpu/shader_macros.h | 66 -------- .../core/providers/webgpu/shader_variable.cc | 108 +++++++------ .../core/providers/webgpu/shader_variable.h | 4 +- .../core/providers/webgpu/string_macros.h | 18 +++ .../core/providers/webgpu/string_utils.h | 39 +++++ .../core/providers/webgpu/tensor/cast.cc | 6 +- .../core/providers/webgpu/tensor/expand.cc | 8 +- .../core/providers/webgpu/tensor/transpose.cc | 14 +- .../core/providers/webgpu/tensor/where.cc | 4 +- 18 files changed, 300 insertions(+), 274 deletions(-) delete mode 100644 onnxruntime/core/providers/webgpu/shader_macros.h create mode 100644 onnxruntime/core/providers/webgpu/string_macros.h create mode 100644 onnxruntime/core/providers/webgpu/string_utils.h diff --git a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc index 52459b0632d5f..d1e5f53d7f637 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc @@ -24,22 +24,23 @@ Status FastGeluProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); const auto& y = shader.AddOutput("y", ShaderUsage::UseUniform); - std::string 
add_bias = ""; + shader.AdditionalImplementation() << TanhImpl; + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size") + << " var a = " << x.GetByOffset("global_idx") << ";\n"; if (Inputs().size() > 1) { const auto& bias = shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride); - add_bias = bias_components_ == 1 ? " let bias_offset = global_idx * 4;\n" - " a += x_value_t(" + - bias.GetByOffset("bias_offset % uniforms.bias_shape") + ", " + - bias.GetByOffset("(bias_offset + 1) % uniforms.bias_shape") + ", " + - bias.GetByOffset("(bias_offset + 2) % uniforms.bias_shape") + ", " + - bias.GetByOffset("(bias_offset + 3) % uniforms.bias_shape") + ");\n" - : " a += " + bias.GetByOffset("global_idx % uniforms.bias_shape") + ";\n"; + if (bias_components_ == 1) { + shader.MainFunctionBody() << " let bias_offset = global_idx * 4;\n" + " a += x_value_t(" + << bias.GetByOffset("bias_offset % uniforms.bias_shape") << ", " + << bias.GetByOffset("(bias_offset + 1) % uniforms.bias_shape") << ", " + << bias.GetByOffset("(bias_offset + 2) % uniforms.bias_shape") << ", " + << bias.GetByOffset("(bias_offset + 3) % uniforms.bias_shape") << ");\n"; + } else { + shader.MainFunctionBody() << " a += " << bias.GetByOffset("global_idx % uniforms.bias_shape") + ";\n"; + } } - shader.AppendImplementation(TanhImpl); - shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), - " var a = ", x.GetByOffset("global_idx"), ";\n", - add_bias, - y.SetByOffset("global_idx", onnxruntime::webgpu::FastGeluExpr)); + shader.MainFunctionBody() << y.SetByOffset("global_idx", onnxruntime::webgpu::FastGeluExpr); return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/docs/Conventions.md b/onnxruntime/core/providers/webgpu/docs/Conventions.md index 1a86e508cdda8..fecccc76a4db7 100644 --- a/onnxruntime/core/providers/webgpu/docs/Conventions.md +++ 
b/onnxruntime/core/providers/webgpu/docs/Conventions.md @@ -10,20 +10,27 @@ Let's keep it "webgpu" for this folder for now. I have a very good reason to do And anyway, it's not hard to change it back to "wgpu" if we want to. (but it's harder to change it from "wgpu" to "webgpu") -### Use macros defined in shader_macros.h +### Use `OStringStream` defined in string_utils.h and macros defined in string_macros.h -Take `SS` as example. It's a macro defined in `shader_macros.h` and it's used to concatenate strings. It's just make the `std::ostream::operator<<` to be used in a function call style. +Type `onnxruntime::webgpu::OStringStream` is a type alias of Abseil's OStringStream. It's a lightweight implementation +of `std::ostream`. It's recommended to use `OStringStream` instead of `std::ostringstream` in the code base. -I prefer to use the macro because I feel like it's easier to read. Check the following code: +The macros defined in `string_macros.h` are used to make coding easier: ```cpp -ss << "vec4(" << type << ">(" << value1 << ", " << value2 << ", " << value3 << ", " << value4 << ")"; -``` +std::string MyFunction() { + SS(code /* name of the string stream */, 2048 /* initial capacity */); -vs. + code << "var my_var = "; -```cpp -SS("vec4<", type, ">(", value1, ", ", value2, ", ", value3, ", ", value4, ")"); + // function call style string append. 
equivalent to: + // + // code << "vec4<" << type << ">(" << value1 << ", " << value2 << ", " << value3 << ", " << value4 << ")"; + // + SS_APPEND(code, "vec4<", type, ">(", value1, ", ", value2, ", ", value3, ", ", value4, ")"); + + return SS_GET(code); // return the string +} ``` ### Use the subfolder for kernel implementation diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc index bae7c6a73c4c7..1da1544252bdb 100644 --- a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc @@ -12,6 +12,9 @@ Status BinaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); const auto& c = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"); + std::string common; std::string get_a_data = is_lhs_scalar_ ? "let a = input_a_value_t(" + a.GetByOffset("0") + ".x" + ");\n" : "let a = " + a.GetByOffset("global_idx") + ";\n"; @@ -20,7 +23,21 @@ Status BinaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const // check whether can use element-wise mode. // If either A or B is scalar, or A and B have the same shape, element-wise mode can be used. // In element-wise mode, no indices calculation is needed.
- if (!is_lhs_scalar_ && !is_rhs_scalar_ && is_broadcast_) { + if (is_lhs_scalar_ || is_rhs_scalar_ || !is_broadcast_) { + // get A data + if (is_lhs_scalar_) { + shader.MainFunctionBody() << "let a = input_a_value_t(" << a.GetByOffset("0") << ".x);\n"; + } else { + shader.MainFunctionBody() << "let a = " << a.GetByOffset("global_idx") << ";\n"; + } + + // get B data + if (is_rhs_scalar_) { + shader.MainFunctionBody() << "let b = input_b_value_t(" << b.GetByOffset("0") << ".x);\n"; + } else { + shader.MainFunctionBody() << "let b = " << b.GetByOffset("global_idx") << ";\n"; + } + } else { const auto& c_indices = shader.AddIndices("bcast_indices"); // check whether can use vectorize mode. // If either last dimension of A or B is divisible by 4, or the shared dimension is divisible by 4, vectorize mode @@ -30,67 +47,80 @@ Status BinaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const if (vectorize_) { const auto& a_indices = shader.AddIndices("a_indices"); const auto& b_indices = shader.AddIndices("b_indices"); - common = "let outputIndices = " + c_indices.OffsetToIndices("global_idx * 4") + - ";\n" - "let offset_a = " + - a_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) + - ";\n" - "let offset_b = " + - b_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) + ";\n"; - get_a_data = a.NumComponents() == 4 ? "let a = " + a.GetByOffset("offset_a / 4") + ";\n" - : "let a = input_b_value_t(" + a.GetByOffset("offset_a") + ");\n"; - get_b_data = b.NumComponents() == 4 ? 
"let b = " + b.GetByOffset("offset_b / 4") + ";\n" - : "let b = input_a_value_t(" + b.GetByOffset("offset_b") + ");\n"; + + shader.MainFunctionBody() << "let outputIndices = " << c_indices.OffsetToIndices("global_idx * 4") + << ";\n" + "let offset_a = " + << a_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) + << ";\n" + "let offset_b = " + << b_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"; + // get A data + if (a.NumComponents() == 4) { + shader.MainFunctionBody() << "let a = " << a.GetByOffset("offset_a / 4") << ";\n"; + } else { + shader.MainFunctionBody() << "let a = input_a_value_t(" << a.GetByOffset("offset_a") << ");\n"; + } + + // get B data + if (b.NumComponents() == 4) { + shader.MainFunctionBody() << "let b = " << b.GetByOffset("offset_b / 4") << ";\n"; + } else { + shader.MainFunctionBody() << "let b = input_b_value_t(" << b.GetByOffset("offset_b") << ");\n"; + } } else { // In broadcast mode, each element of the vec4 value of A and B will be loaded separately to calculate the output value. 
- common = "var outputIndices = " + c_indices.OffsetToIndices("global_idx * 4") + - ";\n" - "let offset_a0 = " + - a.BroadcastedIndicesToOffset("outputIndices", c_indices) + - ";\n" - "let offset_b0 = " + - b.BroadcastedIndicesToOffset("outputIndices", c_indices) + - ";\n" - "outputIndices = " + - c_indices.OffsetToIndices("global_idx * 4 + 1") + - ";\n" - "let offset_a1 = " + - a.BroadcastedIndicesToOffset("outputIndices", c_indices) + - ";\n" - "let offset_b1 = " + - b.BroadcastedIndicesToOffset("outputIndices", c_indices) + - ";\n" - "outputIndices = " + - c_indices.OffsetToIndices("global_idx * 4 + 2") + - ";\n" - "let offset_a2 = " + - a.BroadcastedIndicesToOffset("outputIndices", c_indices) + - ";\n" - "let offset_b2 = " + - b.BroadcastedIndicesToOffset("outputIndices", c_indices) + - ";\n" - "outputIndices = " + - c_indices.OffsetToIndices("global_idx * 4 + 3") + - ";\n" - "let offset_a3 = " + - a.BroadcastedIndicesToOffset("outputIndices", c_indices) + - ";\n" - "let offset_b3 = " + - b.BroadcastedIndicesToOffset("outputIndices", c_indices) + ";\n"; - get_a_data = "let a = vec4(" + a.GetByOffset("offset_a0") + ", " + - a.GetByOffset("offset_a1") + ", " + - a.GetByOffset("offset_a2") + ", " + - a.GetByOffset("offset_a3") + ");\n"; - get_b_data = "let b = vec4(" + b.GetByOffset("offset_b0") + ", " + - b.GetByOffset("offset_b1") + ", " + - b.GetByOffset("offset_b2") + ", " + - b.GetByOffset("offset_b3") + ");\n"; + shader.MainFunctionBody() << "var outputIndices = " << c_indices.OffsetToIndices("global_idx * 4") + << ";\n" + "let offset_a0 = " + << a.BroadcastedIndicesToOffset("outputIndices", c_indices) + << ";\n" + "let offset_b0 = " + << b.BroadcastedIndicesToOffset("outputIndices", c_indices) + << ";\n" + "outputIndices = " + << c_indices.OffsetToIndices("global_idx * 4 + 1") + << ";\n" + "let offset_a1 = " + << a.BroadcastedIndicesToOffset("outputIndices", c_indices) + << ";\n" + "let offset_b1 = " + << b.BroadcastedIndicesToOffset("outputIndices", 
c_indices) + << ";\n" + "outputIndices = " + << c_indices.OffsetToIndices("global_idx * 4 + 2") + << ";\n" + "let offset_a2 = " + << a.BroadcastedIndicesToOffset("outputIndices", c_indices) + << ";\n" + "let offset_b2 = " + << b.BroadcastedIndicesToOffset("outputIndices", c_indices) + << ";\n" + "outputIndices = " + << c_indices.OffsetToIndices("global_idx * 4 + 3") + << ";\n" + "let offset_a3 = " + << a.BroadcastedIndicesToOffset("outputIndices", c_indices) + << ";\n" + "let offset_b3 = " + << b.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"; + + // get A data + shader.MainFunctionBody() << "let a = vec4(" + << a.GetByOffset("offset_a0") << ", " + << a.GetByOffset("offset_a1") << ", " + << a.GetByOffset("offset_a2") << ", " + << a.GetByOffset("offset_a3") << ");\n"; + // get B data + shader.MainFunctionBody() << "let b = vec4(" + << b.GetByOffset("offset_b0") << ", " + << b.GetByOffset("offset_b1") << ", " + << b.GetByOffset("offset_b2") << ", " + << b.GetByOffset("offset_b3") << ");\n"; } } - shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), - common, get_a_data, get_b_data, - c.SetByOffset("global_idx", expression_)); + shader.MainFunctionBody() << c.SetByOffset("global_idx", expression_); return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index 9e8117aa34a92..f6d6b18a3d365 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -12,10 +12,10 @@ namespace webgpu { Status UnaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | additional_usage_); const auto& output = shader.AddOutput("y", ShaderUsage::UseUniform); - shader.AppendImplementation(additional_impl_); - 
shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), - " let a = ", input.GetByOffset("global_idx"), ";\n ", - output.SetByOffset("global_idx", expression_)); + shader.AdditionalImplementation() << additional_impl_; + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size") + << " let a = " << input.GetByOffset("global_idx") << ";\n " + << output.SetByOffset("global_idx", expression_); return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/program.cc b/onnxruntime/core/providers/webgpu/program.cc index 75c3c9ee96081..749c802c55be3 100644 --- a/onnxruntime/core/providers/webgpu/program.cc +++ b/onnxruntime/core/providers/webgpu/program.cc @@ -233,7 +233,7 @@ ProgramOutput::ProgramOutput(Tensor* tensor, ProgramTensorMetadataDependency dep use_override_shape{true}, override_shape{override_shape} {} -ProgramBase::ProgramBase(const std::string& name, ProgramMetadata&& metadata) +ProgramBase::ProgramBase(std::string_view name, ProgramMetadata&& metadata) : name_{name}, metadata_{metadata}, dispatch_group_size_x_{0}, diff --git a/onnxruntime/core/providers/webgpu/program.h b/onnxruntime/core/providers/webgpu/program.h index f05ca9c2bf224..04ef6ad2b9c40 100644 --- a/onnxruntime/core/providers/webgpu/program.h +++ b/onnxruntime/core/providers/webgpu/program.h @@ -243,9 +243,9 @@ class ProgramBase { // // set the cache hint for the program - template - ProgramBase& CacheHint(T&& hint) { - cache_hint_ = std::forward(hint); + template + ProgramBase& CacheHint(T&&... 
hints) { + cache_hint_ = absl::StrJoin(std::forward_as_tuple(std::forward(hints)...), "|"); return *this; } @@ -324,7 +324,7 @@ class ProgramBase { private: // Make the constructor private to prevent direct instantiation or inheritance from this class // Use the Program template class as base class to create a new program class - explicit ProgramBase(const std::string& name, ProgramMetadata&& metadata); + explicit ProgramBase(std::string_view name, ProgramMetadata&& metadata); std::string name_; ProgramMetadata metadata_; diff --git a/onnxruntime/core/providers/webgpu/program_cache_key.cc b/onnxruntime/core/providers/webgpu/program_cache_key.cc index 6c7ef2bc89c6b..a5c21563dbfcd 100644 --- a/onnxruntime/core/providers/webgpu/program_cache_key.cc +++ b/onnxruntime/core/providers/webgpu/program_cache_key.cc @@ -3,14 +3,21 @@ #include "core/providers/webgpu/program_cache_key.h" -#include "core/providers/webgpu/shader_macros.h" +#include "core/providers/webgpu/string_macros.h" namespace onnxruntime { namespace webgpu { +// macro "D" - append to the ostream only in debug build +#ifndef NDEBUG // if debug build +#define D(str) << str +#else +#define D(str) +#endif + namespace { // append the info of an input or output to the cachekey -void AppendTensorInfo(std::ostringstream& ss, const Tensor& tensor, ProgramVariableDataType var_type, ProgramTensorMetadataDependency dependency, +void AppendTensorInfo(std::ostream& ss, const Tensor& tensor, ProgramVariableDataType var_type, ProgramTensorMetadataDependency dependency, bool& first) { if (first) { first = false; @@ -36,8 +43,7 @@ void AppendTensorInfo(std::ostringstream& ss, const Tensor& tensor, ProgramVaria } // namespace std::string CalculateProgramCacheKey(const ProgramBase& program, bool is_1d_dispatch) { - std::ostringstream ss; - ss.imbue(std::locale::classic()); + SS(ss, kStringInitialSizeCacheKey); // final key format: // =[]:::: @@ -100,7 +106,7 @@ std::string CalculateProgramCacheKey(const ProgramBase& program, 
bool is_1d_disp AppendTensorInfo(ss, *output.tensor, output.var_type, output.dependency, first); } - return ss.str(); + return SS_GET(ss); } } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index c229e821cbf8c..b990d2e348db9 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -10,6 +10,8 @@ #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/string_utils.h" +#include "core/providers/webgpu/string_macros.h" namespace onnxruntime { namespace webgpu { @@ -27,7 +29,9 @@ ShaderHelper::ShaderHelper(const ProgramBase& program, dispatch_group_size_y_{dispatch_group_size_y}, dispatch_group_size_z_{dispatch_group_size_z}, program_{program}, - program_metadata_{program_metadata} {} + program_metadata_{program_metadata}, + additional_implementation_ss_{&additional_implementation_}, + body_ss_{&body_} {} Status ShaderHelper::Init() { // dispatch group size is normalized so no need to validate it here @@ -50,31 +54,29 @@ Status ShaderHelper::Init() { // init body string stream bool is_1d_dispatch = dispatch_group_size_y_ == 1 && dispatch_group_size_z_ == 1; - body_.imbue(std::locale::classic()); + body_.reserve(4096); + additional_implementation_.reserve(1024); // append header for main function so it is ready for user to append main function body - body_ << "@compute @workgroup_size(workgroup_size_x, workgroup_size_y, workgroup_size_z)\n" - "fn main(@builtin(global_invocation_id) global_id : vec3,\n" - " @builtin(workgroup_id) workgroup_id : vec3,\n" - " @builtin(local_invocation_id) local_id : vec3"; + body_ss_ << "@compute @workgroup_size(workgroup_size_x, workgroup_size_y, workgroup_size_z)\n" + "fn main(@builtin(global_invocation_id) global_id : vec3,\n" + " @builtin(workgroup_id) workgroup_id : vec3,\n" + " @builtin(local_invocation_id) 
local_id : vec3"; if (!is_1d_dispatch) { - body_ << ",\n" - " @builtin(local_invocation_index) local_idx : u32,\n" - " @builtin(num_workgroups) num_workgroups : vec3"; + body_ss_ << ",\n" + " @builtin(local_invocation_index) local_idx : u32,\n" + " @builtin(num_workgroups) num_workgroups : vec3"; } - body_ << ") {\n"; + body_ss_ << ") {\n"; if (is_1d_dispatch) { - body_ << " let global_idx = global_id.x;\n" - " let local_idx = local_id.x;\n" - " let workgroup_idx = workgroup_id.x;\n"; + body_ss_ << " let global_idx = global_id.x;\n" + " let local_idx = local_id.x;\n" + " let workgroup_idx = workgroup_id.x;\n"; } else { - body_ << " let workgroup_idx = workgroup_id.z * num_workgroups[0] * num_workgroups[1] + workgroup_id.y * num_workgroups[0] + workgroup_id.x;\n" - " let global_idx = workgroup_idx * (workgroup_size_x * workgroup_size_y * workgroup_size_z) + local_idx;\n"; + body_ss_ << " let workgroup_idx = workgroup_id.z * num_workgroups[0] * num_workgroups[1] + workgroup_id.y * num_workgroups[0] + workgroup_id.x;\n" + " let global_idx = workgroup_idx * (workgroup_size_x * workgroup_size_y * workgroup_size_z) + local_idx;\n"; } - // init additional implementation string stream - additional_implementation_.imbue(std::locale::classic()); - return Status::OK(); } @@ -316,8 +318,7 @@ Status ShaderHelper::ValidateIndices() const { } Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& shape_uniform_ranks) const { - std::ostringstream ss; - ss.imbue(std::locale::classic()); + SS(ss, kStringInitialSizeShaderSourceCode); // // Section feature enabling @@ -507,16 +508,16 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha // // Additional Implementation // - ss << additional_implementation_.str(); + ss << additional_implementation_; // // Main Function Body // - ss << body_.str(); + ss << body_; ss << "\n" "}\n"; - code = ss.str(); + code = SS_GET(ss); return Status::OK(); } diff --git 
a/onnxruntime/core/providers/webgpu/shader_helper.h b/onnxruntime/core/providers/webgpu/shader_helper.h index bdc14669cfb51..5e60c1293acea 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.h +++ b/onnxruntime/core/providers/webgpu/shader_helper.h @@ -16,6 +16,7 @@ #include "core/providers/webgpu/program.h" #include "core/providers/webgpu/shader_variable.h" +#include "core/providers/webgpu/string_utils.h" namespace onnxruntime { namespace webgpu { @@ -92,23 +93,14 @@ class ShaderHelper final { // Add an indices variable to the shader. const ShaderIndicesHelper& AddIndices(const std::string& name, bool use_uniform = true); - // Append additional implementation code to the shader. - // - // can be called multiple times. - template - inline ShaderHelper& AppendImplementation(Strs&&... impl) { - onnxruntime::detail::MakeStringImpl(additional_implementation_, std::forward(impl)...); - return *this; + // Get the string stream for additional implementation code to the shader. + inline OStringStream& AdditionalImplementation() { + return additional_implementation_ss_; } - // Set the main function body of the shader. - // - // can be called only once. - template - inline void SetMainFunctionBody(const Strs&... body) { - ORT_ENFORCE(!body_set_, "Main function body is already set"); - onnxruntime::detail::MakeStringImpl(body_, std::forward>(body)...); - body_set_ = true; + // Get the string stream for the main function body of the shader. 
+ inline OStringStream& MainFunctionBody() { + return body_ss_; } std::string GuardAgainstOutOfBoundsWorkgroupSizes(std::string_view size) const { @@ -117,7 +109,7 @@ class ShaderHelper final { private: template // ConstantType is one of {ProgramConstant, ProgramOverridableConstantValue, ProgramOverridableConstantDefinition} - void WriteConstantValue(std::ostringstream& ss, const ConstantType& constant) const { + void WriteConstantValue(std::ostream& ss, const ConstantType& constant) const { switch (constant.type) { case ProgramConstantDataType::Float16: ss << constant.f16.ToFloat(); @@ -179,10 +171,10 @@ class ShaderHelper final { std::vector> input_vars_; std::vector> output_vars_; std::vector> indices_vars_; - std::ostringstream additional_implementation_; - std::ostringstream body_; - - bool body_set_ = false; + std::string additional_implementation_; + OStringStream additional_implementation_ss_; + std::string body_; + OStringStream body_ss_; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/shader_macros.h b/onnxruntime/core/providers/webgpu/shader_macros.h deleted file mode 100644 index a1c61950e6a10..0000000000000 --- a/onnxruntime/core/providers/webgpu/shader_macros.h +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -// macro "D": append to the ostream only in debug build -// -// Usage example: -// -// ss << "error code: " << err_code D(" (") << D(err_msg) D(")"); -// -// This resolves to: (debug build) -// ss << "error code: " << err_code << " (" << err_msg << ")"; -// -// This resolves to: (release build) -// ss << "error code: " << err_code; - -#ifdef D -#undef D -#endif - -#ifndef NDEBUG // if debug build -#define D(str) << str -#else -#define D(str) -#endif - -// macro "DSS" append to the ostream only in debug build -// (assume variable "ss" is in scope) -// -// Usage example: -// -// DSS << "detail error message: " << err_msg; -// -// This resolves to: (debug build) -// ss << "detail error message: " << err_msg; -// -// This resolves to: (release build) -// if constexpr (false) ss << "detail error message: " << err_msg; // no-op - -#ifdef DSS -#undef DSS -#endif - -#ifndef NDEBUG // if debug build -#define DSS ss -#else -#define DSS \ - if constexpr (false) ss -#endif - -// macro "SS" - use function call style to append to the ostream -// (assume variable "ss" is in scope) -// -// Usage example: -// -// SS("error code: ", err_code, " (", err_msg, ")"); -// -// This resolves to: -// ss << "error code: " << err_code << " (" << err_msg << ")"; - -#ifdef SS -#undef SS -#endif - -#define SS(...) 
::onnxruntime::detail::MakeStringImpl(ss, __VA_ARGS__) diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index f2a5b049b4777..3d55dc252a248 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -8,7 +8,7 @@ #include "core/common/safeint.h" #include "core/providers/webgpu/shader_variable.h" -#include "core/providers/webgpu/shader_macros.h" +#include "core/providers/webgpu/string_macros.h" namespace onnxruntime { namespace webgpu { @@ -94,7 +94,7 @@ ShaderVariableHelper::ShaderVariableHelper(std::string_view name, ProgramVariabl ORT_ENFORCE(num_components_ > 0, "Invalid number of components for variable ", name_); } -void ShaderIndicesHelper::Impl(std::ostringstream& ss) const { +void ShaderIndicesHelper::Impl(std::ostream& ss) const { // Start generating code const std::string shape = (usage_ & ShaderUsage::UseUniform) ? "uniforms." + name_ + "_shape" : name_ + "_shape"; @@ -102,18 +102,18 @@ void ShaderIndicesHelper::Impl(std::ostringstream& ss) const { // Types if (usage_ & ShaderUsage::UseValueTypeAlias) { - SS("alias ", value_type_alias_, " = ", VALUE_TYPE[static_cast(type_)], ";\n"); + SS_APPEND(ss, "alias ", value_type_alias_, " = ", VALUE_TYPE[static_cast(type_)], ";\n"); } if (usage_ & ShaderUsage::UseIndicesTypeAlias) { - SS("alias ", indices_type_alias_, " = ", indices_type_, ";\n"); + SS_APPEND(ss, "alias ", indices_type_alias_, " = ", indices_type_, ";\n"); } if (usage_ & ShaderUsage::UseElementTypeAlias) { - SS("alias ", element_type_alias_, " = ", ELEMENT_TYPE[static_cast(type_)], ";\n"); + SS_APPEND(ss, "alias ", element_type_alias_, " = ", ELEMENT_TYPE[static_cast(type_)], ";\n"); } // Need shape and strides when (not use uniform) and (use shape and stride is enabled) if (!(usage_ & ShaderUsage::UseUniform) && (usage_ & ShaderUsage::UseShapeAndStride) && rank_ > 0) { - SS("const ", shape, " = ", 
IndicesType(), "("); + SS_APPEND(ss, "const ", shape, " = ", IndicesType(), "("); bool first = true; for (auto dim : dims_.GetDims()) { @@ -127,7 +127,7 @@ void ShaderIndicesHelper::Impl(std::ostringstream& ss) const { ss << ");\n"; if (rank_ > 1) { - SS("const ", stride, " = ", GetIndicesType(rank_ - 1), "("); + SS_APPEND(ss, "const ", stride, " = ", GetIndicesType(rank_ - 1), "("); first = true; for (int i = 1; i < rank_; i++) { if (!first) { @@ -143,32 +143,32 @@ void ShaderIndicesHelper::Impl(std::ostringstream& ss) const { // Implementation of "fn o2i_{name}" if (usage_ & ShaderUsage::UseOffsetToIndices) { if (rank_ >= 2) { - SS("fn o2i_", name_, "(offset : u32)->", IndicesType(), " {\n"); - SS(" var indices: ", IndicesType(), ";\n"); - SS(" var current = offset;\n"); + SS_APPEND(ss, "fn o2i_", name_, "(offset : u32)->", IndicesType(), " {\n"); + SS_APPEND(ss, " var indices: ", IndicesType(), ";\n"); + SS_APPEND(ss, " var current = offset;\n"); for (int i = 0; i < rank_ - 1; i++) { auto current_stride = GetElementAt(stride, i, rank_ - 1); - SS(" let dim", i, " = current / ", current_stride, ";\n"); - SS(" let rest", i, " = current % ", current_stride, ";\n"); - SS(" indices[", i, "] = dim", i, ";\n"); - SS(" current = rest", i, ";\n"); + SS_APPEND(ss, " let dim", i, " = current / ", current_stride, ";\n"); + SS_APPEND(ss, " let rest", i, " = current % ", current_stride, ";\n"); + SS_APPEND(ss, " indices[", i, "] = dim", i, ";\n"); + SS_APPEND(ss, " current = rest", i, ";\n"); } - SS(" indices[", rank_ - 1, "] = current;\n"); - SS(" return indices;\n"); - SS("}\n"); + SS_APPEND(ss, " indices[", rank_ - 1, "] = current;\n"); + SS_APPEND(ss, " return indices;\n"); + SS_APPEND(ss, "}\n"); } } // Implementation of "fn i2o_{name}" if (usage_ & ShaderUsage::UseIndicesToOffset) { if (rank_ >= 2) { - SS("fn i2o_", name_, "(indices : ", IndicesType(), ")->u32 {\n"); - SS(" return "); + SS_APPEND(ss, "fn i2o_", name_, "(indices : ", IndicesType(), ")->u32 {\n"); + 
SS_APPEND(ss, " return "); for (int i = 0; i < rank_ - 1; i++) { - SS("indices[", i, "] * ", GetElementAt(stride, i, rank_ - 1), " + "); + SS_APPEND(ss, "indices[", i, "] * ", GetElementAt(stride, i, rank_ - 1), " + "); } - SS("indices[", rank_ - 1, "];\n"); - SS("}\n"); + SS_APPEND(ss, "indices[", rank_ - 1, "];\n"); + SS_APPEND(ss, "}\n"); } } @@ -177,83 +177,82 @@ void ShaderIndicesHelper::Impl(std::ostringstream& ss) const { if (rank_ > 0) { for (const auto& broadcasted_result_ptr : broadcasted_to_) { const auto& broadcasted_result = *broadcasted_result_ptr; - SS("fn ", broadcasted_result.name_, "_bi2o_", name_, "(indices : ", broadcasted_result.indices_type_, ")->u32 {\n"); + SS_APPEND(ss, "fn ", broadcasted_result.name_, "_bi2o_", name_, "(indices : ", broadcasted_result.indices_type_, ")->u32 {\n"); if (rank_ == 1) { - SS(" return ", broadcasted_result.IndicesGet("indices", broadcasted_result.rank_ - 1), " % ", shape, ";\n"); + SS_APPEND(ss, " return ", broadcasted_result.IndicesGet("indices", broadcasted_result.rank_ - 1), " % ", shape, ";\n"); } else { - SS(" return "); + SS_APPEND(ss, " return "); for (int i = 0; i < rank_ - 1; i++) { auto idx = broadcasted_result.IndicesGet("indices", i + broadcasted_result.rank_ - rank_); std::string current_stride = rank_ == 2 ? 
stride : GetElementAt(stride, i, rank_ - 1); - SS(current_stride, " * (", idx, " % ", IndicesGet(shape, i), ") + "); + SS_APPEND(ss, current_stride, " * (", idx, " % ", IndicesGet(shape, i), ") + "); } - SS(broadcasted_result.IndicesGet("indices", broadcasted_result.rank_ - 1), " % ", IndicesGet(shape, rank_ - 1), ";\n"); + SS_APPEND(ss, broadcasted_result.IndicesGet("indices", broadcasted_result.rank_ - 1), " % ", IndicesGet(shape, rank_ - 1), ";\n"); } - SS("}\n"); + SS_APPEND(ss, "}\n"); } } } } -void ShaderVariableHelper::Impl(std::ostringstream& ss) const { +void ShaderVariableHelper::Impl(std::ostream& ss) const { ShaderIndicesHelper::Impl(ss); // Implementation of "fn set_{name}" if (usage_ & ShaderUsage::UseSet) { if (rank_ >= 2) { - SS("fn set_", name_, "(d0: u32"); + SS_APPEND(ss, "fn set_", name_, "(d0: u32"); for (int i = 1; i < rank_; i++) { - SS(", d", i, ": u32"); + SS_APPEND(ss, ", d", i, ": u32"); } - SS(", value: ", ValueType(), ") {\n"); - SS(" set_", name_, "_by_indices(d0"); + SS_APPEND(ss, ", value: ", ValueType(), ") {\n"); + SS_APPEND(ss, " set_", name_, "_by_indices(d0"); for (int i = 1; i < rank_; i++) { - SS(", d", i); + SS_APPEND(ss, ", d", i); } - SS(", value);\n"); - SS("}\n"); + SS_APPEND(ss, ", value);\n"); + SS_APPEND(ss, "}\n"); } } // Implementation of "fn set_{name}_by_indices" if (usage_ & ShaderUsage::UseSetByIndices) { if (rank_ >= 2) { - SS("fn set_", name_, "_by_indices(indices: ", IndicesType(), ", value: ", ValueType(), ") {\n"); - SS(" ", SetByOffset("i2o_" + name_ + "(indices)", "value"), "\n"); - SS("}\n"); + SS_APPEND(ss, "fn set_", name_, "_by_indices(indices: ", IndicesType(), ", value: ", ValueType(), ") {\n"); + SS_APPEND(ss, " ", SetByOffset("i2o_" + name_ + "(indices)", "value"), "\n"); + SS_APPEND(ss, "}\n"); } } // Implementation of "fn get_{name}" if (usage_ & ShaderUsage::UseGet) { if (rank_ >= 2) { - SS("fn get_", name_, "(d0: u32"); + SS_APPEND(ss, "fn get_", name_, "(d0: u32"); for (int i = 1; i < rank_; 
i++) { - SS(", d", i, ": u32"); + SS_APPEND(ss, ", d", i, ": u32"); } - SS(")->", ValueType(), " {\n"); - SS(" return get_", name_, "_by_indices(d0"); + SS_APPEND(ss, ")->", ValueType(), " {\n"); + SS_APPEND(ss, " return get_", name_, "_by_indices(d0"); for (int i = 1; i < rank_; i++) { - SS(", d", i); + SS_APPEND(ss, ", d", i); } - SS(");\n"); - SS("}\n"); + SS_APPEND(ss, ");\n"); + SS_APPEND(ss, "}\n"); } } // Implementation of "fn get_{name}_by_indices" if (usage_ & ShaderUsage::UseGetByIndices) { if (rank_ >= 2) { - SS("fn get_", name_, "_by_indices(indices: ", IndicesType(), ")->", ValueType(), " {\n"); - SS(" return ", GetByOffset("i2o_" + name_ + "(indices)"), ";\n"); - SS("}\n"); + SS_APPEND(ss, "fn get_", name_, "_by_indices(indices: ", IndicesType(), ")->", ValueType(), " {\n"); + SS_APPEND(ss, " return ", GetByOffset("i2o_" + name_ + "(indices)"), ";\n"); + SS_APPEND(ss, "}\n"); } } } std::string ShaderVariableHelper::GetByOffsetImpl(std::string_view offset) const { - std::ostringstream ss; - ss.imbue(std::locale::classic()); + SS(ss, kStringInitialSizeGetByOffsetImpl); switch (type_) { case onnxruntime::webgpu::ProgramVariableDataType::InvalidType: @@ -274,12 +273,11 @@ std::string ShaderVariableHelper::GetByOffsetImpl(std::string_view offset) const ss << name_ << "[" << offset << "]"; } - return ss.str(); + return SS_GET(ss); } std::string ShaderVariableHelper::SetByOffsetImpl(std::string_view offset, std::string_view value) const { - std::ostringstream ss; - ss.imbue(std::locale::classic()); + SS(ss, kStringInitialSizeSetByOffsetImpl); switch (type_) { case onnxruntime::webgpu::ProgramVariableDataType::InvalidType: @@ -298,7 +296,7 @@ std::string ShaderVariableHelper::SetByOffsetImpl(std::string_view offset, std:: ss << name_ << "[" << offset << "]=" << value << ";"; } - return ss.str(); + return SS_GET(ss); } std::string_view ShaderVariableHelper::StorageType() const { diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h 
b/onnxruntime/core/providers/webgpu/shader_variable.h index 72f38aecb99ce..cad7b0ceb8309 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -110,7 +110,7 @@ class ShaderIndicesHelper { protected: ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderIndicesHelper); - void Impl(std::ostringstream& ss) const; + void Impl(std::ostream& ss) const; std::string_view IndicesType() const; @@ -175,7 +175,7 @@ class ShaderVariableHelper : public ShaderIndicesHelper { private: ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderVariableHelper); - void Impl(std::ostringstream& ss) const; + void Impl(std::ostream& ss) const; std::string GetByOffsetImpl(std::string_view offset) const; std::string SetByOffsetImpl(std::string_view offset, std::string_view value) const; diff --git a/onnxruntime/core/providers/webgpu/string_macros.h b/onnxruntime/core/providers/webgpu/string_macros.h new file mode 100644 index 0000000000000..fd82375ce5589 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/string_macros.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/string_utils.h" + +// macro "SS" - declare an ostream variable and its string buffer +#define SS(ss, reserve_size) \ + std::string ss##_str; \ + ss##_str.reserve(reserve_size); \ + ::onnxruntime::webgpu::OStringStream ss(&ss##_str) + +// macro "SS_GET" - get the string from the ostream +#define SS_GET(ss) ss##_str + +// macro "SS_APPEND" - use function call style to append to the ostream +#define SS_APPEND(ss, ...) 
::onnxruntime::webgpu::detail::OStringStreamAppendImpl(ss, __VA_ARGS__) diff --git a/onnxruntime/core/providers/webgpu/string_utils.h b/onnxruntime/core/providers/webgpu/string_utils.h new file mode 100644 index 0000000000000..6d7e4f40c9d19 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/string_utils.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +namespace onnxruntime { +namespace webgpu { + +constexpr const size_t kStringInitialSizeSetByOffsetImpl = 128; +constexpr const size_t kStringInitialSizeGetByOffsetImpl = 128; +constexpr const size_t kStringInitialSizeShaderSourceCode = 2048; +#ifndef NDEBUG +constexpr const size_t kStringInitialSizeCacheKey = 512; +#else +constexpr const size_t kStringInitialSizeCacheKey = 256; +#endif + +using OStringStream = absl::strings_internal::OStringStream; + +namespace detail { +inline void OStringStreamAppendImpl(std::ostream& /*ss*/) noexcept { +} + +template +inline void OStringStreamAppendImpl(std::ostream& ss, const T& t) noexcept { + ss << t; +} + +template +inline void OStringStreamAppendImpl(std::ostream& ss, const T& t, const Args&... 
args) noexcept { + OStringStreamAppendImpl(ss, t); + OStringStreamAppendImpl(ss, args...); +} +} // namespace detail + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/cast.cc b/onnxruntime/core/providers/webgpu/tensor/cast.cc index 8d59570de9967..06eae971309c5 100644 --- a/onnxruntime/core/providers/webgpu/tensor/cast.cc +++ b/onnxruntime/core/providers/webgpu/tensor/cast.cc @@ -106,9 +106,9 @@ Status CastProgram::GenerateShaderCode(ShaderHelper& sh) const { default: ORT_NOT_IMPLEMENTED("Cast to type ", to_, " is not supported."); } - sh.SetMainFunctionBody(sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), - " let a = ", input.GetByOffset("global_idx"), ";\n ", - output.SetByOffset("global_idx", expression)); + sh.MainFunctionBody() << sh.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size") + << " let a = " << input.GetByOffset("global_idx") << ";\n " + << output.SetByOffset("global_idx", expression); return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/tensor/expand.cc b/onnxruntime/core/providers/webgpu/tensor/expand.cc index a106583651885..84cdb35d77f0b 100644 --- a/onnxruntime/core/providers/webgpu/tensor/expand.cc +++ b/onnxruntime/core/providers/webgpu/tensor/expand.cc @@ -14,10 +14,10 @@ Status ExpandProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& input = shader.AddInput("input", ShaderUsage::UseUniform); const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); - shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size"), - " let output_indices = ", output.OffsetToIndices("global_idx"), ";\n", - " let input_offset = ", input.BroadcastedIndicesToOffset("output_indices", output), ";\n ", - output.SetByOffset("global_idx", input.GetByOffset("input_offset"))); + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size") + << " let output_indices = " << 
output.OffsetToIndices("global_idx") << ";\n" + << " let input_offset = " << input.BroadcastedIndicesToOffset("output_indices", output) << ";\n " + << output.SetByOffset("global_idx", input.GetByOffset("input_offset")); return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc index 0962d9191d785..b6272dca473d2 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -63,13 +63,13 @@ const std::string AppendPermFunction(gsl::span perm) { Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& input = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); const auto& output = shader.AddOutput("y", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); - shader.AppendImplementation(AppendPermFunction(this->perm_)); - shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"), - " let indices = ", output.OffsetToIndices("global_idx"), - ";\n" - " let x_indices = perm(indices); \n" - " ", - output.SetByOffset("global_idx", input.GetByIndices("x_indices"))); + shader.AdditionalImplementation() << AppendPermFunction(this->perm_); + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << " let indices = " << output.OffsetToIndices("global_idx") + << ";\n" + " let x_indices = perm(indices); \n" + " " + << output.SetByOffset("global_idx", input.GetByIndices("x_indices")); return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/tensor/where.cc b/onnxruntime/core/providers/webgpu/tensor/where.cc index 1d58538a7489c..d8e6d208610ba 100644 --- a/onnxruntime/core/providers/webgpu/tensor/where.cc +++ b/onnxruntime/core/providers/webgpu/tensor/where.cc @@ -114,8 +114,8 @@ Status WhereProgram::GenerateShaderCode(ShaderHelper& shader) const { 
single_assignment("output_data[global_idx]", "3"); } } - shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), - assignment); + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size") + << assignment; return Status::OK(); } From da96b489de94dc9bc41ea8abd1fe3f6126d1203b Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 27 Sep 2024 04:39:03 -0700 Subject: [PATCH 2/6] revise operator using latest string lib --- .../webgpu/bert/rotary_embedding.cc | 49 ++-- .../webgpu/quantization/matmul_nbits.cc | 250 +++++++++--------- .../webgpu/math/binary_elementwise_ops.cc | 5 - .../core/providers/webgpu/tensor/concat.cc | 84 +++--- .../core/providers/webgpu/tensor/gather.cc | 27 +- .../core/providers/webgpu/tensor/transpose.cc | 67 +++-- .../core/providers/webgpu/tensor/where.cc | 66 +++-- 7 files changed, 240 insertions(+), 308 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc index eb5cfad87597f..85ab94706b149 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/rotary_embedding.cc @@ -29,38 +29,23 @@ Status RotaryEmbeddingProgram::GenerateShaderCode(ShaderHelper& shader) const { // TODO: remove output_indices. const auto& output_indices = shader.AddIndices("output_indices", false); const auto interleaved_str = interleaved_ ? 
"true" : "false"; - shader.SetMainFunctionBody( - " let half_rotary_emb_dim = uniforms.cos_cache_shape[1];\n" - " let bsnh = global_idx / uniforms.global_stride % uniforms.global_shape;\n" - " let size = uniforms.global_shape[0] * uniforms.global_stride[0];\n", - " if (global_idx >= size) { return; }\n" - " if (bsnh[3] < half_rotary_emb_dim) {\n" - " let position_ids_idx = " + - position_ids.BroadcastedIndicesToOffset("bsnh.xy", output_indices) + ";\n" + - " let position_id = u32(" + - position_ids.GetByOffset("position_ids_idx") + ")" + - " + select(0, bsnh[1], position_ids_idx == 0);\n" - " let i = dot(bsnh, uniforms.input_output_stride) + select(0, bsnh[3], " + - interleaved_str + - ");\n" - " let j = i + select(half_rotary_emb_dim, 1, " + - interleaved_str + - ");\n" - " let re = " + - input.GetByOffset("i") + " * " + cos_cache.GetByIndices("vec2(position_id, bsnh[3])") + "-" + - input.GetByOffset("j") + " * " + sin_cache.GetByIndices("vec2(position_id, bsnh[3])") + ";\n" + - " " + output.SetByOffset("i", "re") + "\n" + - " let im = " + input.GetByOffset("i") + " * " + - sin_cache.GetByIndices("vec2(position_id, bsnh[3])") + - "+ " + input.GetByOffset("j") + - " * " + cos_cache.GetByIndices("vec2(position_id, bsnh[3])") + - ";\n " + output.SetByOffset("j", "im") + - "\n" - " } else { \n" - " let k = dot(bsnh, uniforms.input_output_stride) + half_rotary_emb_dim;\n" + - " " + output.SetByOffset("k", input.GetByOffset("k")) + - "\n" - " }"); + shader.MainFunctionBody() << " let half_rotary_emb_dim = uniforms.cos_cache_shape[1];\n" + " let bsnh = global_idx / uniforms.global_stride % uniforms.global_shape;\n" + " let size = uniforms.global_shape[0] * uniforms.global_stride[0];\n" + " if (global_idx >= size) { return; }\n" + " if (bsnh[3] < half_rotary_emb_dim) {\n" + << " let position_ids_idx = " << position_ids.BroadcastedIndicesToOffset("bsnh.xy", output_indices) << ";\n" + << " let position_id = u32(" << position_ids.GetByOffset("position_ids_idx") << ") + 
select(0, bsnh[1], position_ids_idx == 0);\n" + << " let i = dot(bsnh, uniforms.input_output_stride) + select(0, bsnh[3], " << interleaved_str << ");\n" + << " let j = i + select(half_rotary_emb_dim, 1, " << interleaved_str << ");\n" + << " let re = " << input.GetByOffset("i") << " * " << cos_cache.GetByIndices("vec2(position_id, bsnh[3])") << " - " << input.GetByOffset("j") << " * " << sin_cache.GetByIndices("vec2(position_id, bsnh[3])") << ";\n" + << " " << output.SetByOffset("i", "re") << "\n" + << " let im = " << input.GetByOffset("i") << " * " << sin_cache.GetByIndices("vec2(position_id, bsnh[3])") << " + " << input.GetByOffset("j") + " * " << cos_cache.GetByIndices("vec2(position_id, bsnh[3])") << ";\n" + << " " << output.SetByOffset("j", "im") << "\n" + << " } else { \n" + " let k = dot(bsnh, uniforms.input_output_stride) + half_rotary_emb_dim;\n" + << " " << output.SetByOffset("k", input.GetByOffset("k")) << "\n" + << " }"; return Status::OK(); } diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc index b1f1a3a9ad8d0..2057627c27c20 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.cc @@ -59,175 +59,161 @@ Status MatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { const std::string quantized_data_type = QuantizedDataType(a.NumComponents()); const int output_element_number = y.NumComponents() * SafeInt(output_number_); - std::ostringstream prepare_scale_and_zero_point; - prepare_scale_and_zero_point.imbue(std::locale::classic()); - prepare_scale_and_zero_point << " var col_index = col * " << y.NumComponents() << ";\n"; + + const uint32_t shared_memory_size = output_number_ * WORKGROUP_SIZE; + std::string offset = "workgroup_idx * " + std::to_string(output_number_); + shader.AdditionalImplementation() << "var workgroup_shared : array;\n"; + shader.MainFunctionBody() << " let 
output_indices = " << y.OffsetToIndices(offset) << ";\n" + << " let col = output_indices[2];\n" + " let row = output_indices[1];\n" + " let batch = output_indices[0];\n" + " let n_blocks_per_col = uniforms.input_b_shape[1];\n" + " let blob_size = uniforms.input_b_shape[2];\n" + " for (var block = local_id.x; block < n_blocks_per_col; block += workgroup_size_x) {\n" + << " var word_offset = block * uniforms.block_size / " << a.NumComponents() << ";\n"; + + // prepare scale and zero point + shader.MainFunctionBody() << " var col_index = col * " << y.NumComponents() << ";\n"; if (has_zero_points_) { const auto& zero_points = shader.AddInput("zero_points", ShaderUsage::UseUniform); - prepare_scale_and_zero_point << " let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;\n" - << " var zero_point_byte_count: u32;\n" - << " var zero_point_word_index: u32;\n" - << " var zero_point_byte_offset: u32;\n" - << " let zero_point_nibble_offset: u32 = block & 0x1u;\n" - << " var zero_point_bits_offset: u32;\n" - << " var zero_point_word: u32;\n"; + shader.MainFunctionBody() << " let zero_point_bytes_per_col = (n_blocks_per_col + 1) / 2;\n" + " var zero_point_byte_count: u32;\n" + " var zero_point_word_index: u32;\n" + " var zero_point_byte_offset: u32;\n" + " let zero_point_nibble_offset: u32 = block & 0x1u;\n" + " var zero_point_bits_offset: u32;\n" + " var zero_point_word: u32;\n"; for (int c = 0; c < output_element_number; c++) { - prepare_scale_and_zero_point << " let scale" << c << " = " << scales.GetByOffset("col_index * n_blocks_per_col + block") << ";\n"; - prepare_scale_and_zero_point << " zero_point_byte_count = col_index * zero_point_bytes_per_col + (block >> 0x1u);\n" - << " zero_point_word_index = zero_point_byte_count >> 0x2u;\n" - << " zero_point_byte_offset = zero_point_byte_count & 0x3u;\n" - << " zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);\n" - << " zero_point_word = " << 
zero_points.GetByOffset("zero_point_word_index") << " >> zero_point_bits_offset;\n" - << " let zero_point" << c << " = output_element_t((zero_point_word) & 0xFu);\n"; - prepare_scale_and_zero_point << " col_index += 1;\n"; + shader.MainFunctionBody() << " let scale" << c << " = " << scales.GetByOffset("col_index * n_blocks_per_col + block") << ";\n" + << " zero_point_byte_count = col_index * zero_point_bytes_per_col + (block >> 0x1u);\n" + " zero_point_word_index = zero_point_byte_count >> 0x2u;\n" + " zero_point_byte_offset = zero_point_byte_count & 0x3u;\n" + " zero_point_bits_offset = (zero_point_byte_offset << 3) + (zero_point_nibble_offset << 2);\n" + << " zero_point_word = " << zero_points.GetByOffset("zero_point_word_index") << " >> zero_point_bits_offset;\n" + << " let zero_point" << c << " = output_element_t((zero_point_word) & 0xFu);\n" + << " col_index += 1;\n"; } } else { - prepare_scale_and_zero_point << " let zero_point = output_element_t(8.0);\n"; + shader.MainFunctionBody() << " let zero_point = output_element_t(8.0);\n"; for (int c = 0; c < output_element_number; c++) { - prepare_scale_and_zero_point << " let scale" << c << " = " << scales.GetByOffset("col_index * n_blocks_per_col + block") << ";\n"; - prepare_scale_and_zero_point << " col_index += 1;\n"; + shader.MainFunctionBody() << " let scale" << c << " = " << scales.GetByOffset("col_index * n_blocks_per_col + block") << ";\n" + << " col_index += 1;\n"; } } - std::ostringstream prepare_b_data; - prepare_b_data.imbue(std::locale::classic()); - prepare_b_data << " col_index = col * " << y.NumComponents() << ";\n"; + shader.MainFunctionBody() << " for (var word: u32 = 0; word < blob_size; word += 1) {\n"; + + // prepare b data + shader.MainFunctionBody() << " col_index = col * " << y.NumComponents() << ";\n"; for (int c = 0; c < output_element_number; c++) { - prepare_b_data << " let b" << c << "_data = " << b.GetByIndices("input_b_indices_t(col_index, block, word)") << ";\n" - << " col_index += 
1;\n"; + shader.MainFunctionBody() << " let b" << c << "_data = " << b.GetByIndices("input_b_indices_t(col_index, block, word)") << ";\n" + << " col_index += 1;\n"; } - prepare_b_data << " var b_value : u32;\n" - << " let b_mask : u32 = 0x0F0F0F0Fu;\n" - << " var b_value_lower : vec4;\n" - << " var b_value_upper : vec4;\n" - << " var b_quantized_values : " << quantized_data_type << ";\n" - << " var b_dequantized_values : " << quantized_data_type << ";\n"; + shader.MainFunctionBody() << " var b_value : u32;\n" + " let b_mask : u32 = 0x0F0F0F0Fu;\n" + " var b_value_lower : vec4;\n" + " var b_value_upper : vec4;\n" + << " var b_quantized_values : " << quantized_data_type << ";\n" + << " var b_dequantized_values : " << quantized_data_type << ";\n"; + + shader.MainFunctionBody() << " for (var i: u32 = 0; i < " << components_b_ << "; i++) {\n"; - std::ostringstream process_one_word; - process_one_word.imbue(std::locale::classic()); - process_one_word << " var input_offset = " << a.IndicesToOffset("input_a_indices_t(batch, row, word_offset)") << ";\n" - << " var a_data: " << quantized_data_type << ";\n" - << " for (var j: u32 = 0; j < " << (8 / a.NumComponents()) << "; j++) {\n" - << " if (word_offset + j < uniforms.input_a_shape[2]) {\n" - << " a_data[j] = " << a.GetByOffset("input_offset") << ";\n" - << " input_offset++;\n" - << " } else {\n" - << " a_data[j] = input_a_value_t(0);\n" - << " }\n" - << " }\n"; + // process one word + shader.MainFunctionBody() << " var input_offset = " << a.IndicesToOffset("input_a_indices_t(batch, row, word_offset)") << ";\n" + << " var a_data: " << quantized_data_type << ";\n" + << " for (var j: u32 = 0; j < " << (8 / a.NumComponents()) << "; j++) {\n" + << " if (word_offset + j < uniforms.input_a_shape[2]) {\n" + << " a_data[j] = " << a.GetByOffset("input_offset") << ";\n" + << " input_offset++;\n" + " } else {\n" + " a_data[j] = input_a_value_t(0);\n" + " }\n" + " }\n"; for (int c = 0; c < output_element_number; c++) { - 
process_one_word << " b_value = " << "b" << c << "_data"; + shader.MainFunctionBody() << " b_value = b" << c << "_data"; if (components_b_ > 1) { - process_one_word << "[i]"; + shader.MainFunctionBody() << "[i]"; } - process_one_word << ";\n" - << " b_value_lower = unpack4xU8(b_value & b_mask);\n" - << " b_value_upper = unpack4xU8((b_value >> 4) & b_mask);\n" - << " b_quantized_values = " << quantized_data_type << "(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));\n" - << " b_dequantized_values = "; + shader.MainFunctionBody() << ";\n" + " b_value_lower = unpack4xU8(b_value & b_mask);\n" + " b_value_upper = unpack4xU8((b_value >> 4) & b_mask);\n" + << " b_quantized_values = " << quantized_data_type << "(output_element_t(b_value_lower[0]), output_element_t(b_value_upper[0]), output_element_t(b_value_lower[1]), output_element_t(b_value_upper[1]), output_element_t(b_value_lower[2]), output_element_t(b_value_upper[2]), output_element_t(b_value_lower[3]), output_element_t(b_value_upper[3]));\n" + << " b_dequantized_values = "; if (a.NumComponents() == 1) { if (has_zero_points_) { - process_one_word << quantized_data_type << "((b_quantized_values[0] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[1] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[2] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[3] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[4] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[5] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[6] - zero_point" << c << ") * scale" << c << ", " - << "(b_quantized_values[7] - zero_point" << c << ") * scale" << c << ");\n"; + 
shader.MainFunctionBody() << quantized_data_type << "((b_quantized_values[0] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[1] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[2] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[3] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[4] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[5] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[6] - zero_point" << c << ") * scale" << c << ", " + << "(b_quantized_values[7] - zero_point" << c << ") * scale" << c << ");\n"; } else { - process_one_word << quantized_data_type << "((b_quantized_values[0] - zero_point) * scale" << c << ", " - << "(b_quantized_values[1] - zero_point) * scale" << c << "," - << "(b_quantized_values[2] - zero_point) * scale" << c << "," - << "(b_quantized_values[3] - zero_point) * scale" << c << "," - << "(b_quantized_values[4] - zero_point) * scale" << c << "," - << "(b_quantized_values[5] - zero_point) * scale" << c << "," - << "(b_quantized_values[6] - zero_point) * scale" << c << "," - << "(b_quantized_values[7] - zero_point) * scale" << c << ");\n"; + shader.MainFunctionBody() << quantized_data_type << "((b_quantized_values[0] - zero_point) * scale" << c << ", " + << "(b_quantized_values[1] - zero_point) * scale" << c << "," + << "(b_quantized_values[2] - zero_point) * scale" << c << "," + << "(b_quantized_values[3] - zero_point) * scale" << c << "," + << "(b_quantized_values[4] - zero_point) * scale" << c << "," + << "(b_quantized_values[5] - zero_point) * scale" << c << "," + << "(b_quantized_values[6] - zero_point) * scale" << c << "," + << "(b_quantized_values[7] - zero_point) * scale" << c << ");\n"; } } else { - process_one_word << "(b_quantized_values - " << quantized_data_type << "("; + shader.MainFunctionBody() << "(b_quantized_values - " << quantized_data_type << "("; for (int i = 0; i < 8; i++) { 
if (has_zero_points_) { - process_one_word << "zero_point" << c; + shader.MainFunctionBody() << "zero_point" << c; } else { - process_one_word << "zero_point"; + shader.MainFunctionBody() << "zero_point"; } if (i < 7) { - process_one_word << ", "; + shader.MainFunctionBody() << ", "; } } - process_one_word << ")) * scale" << c << ";\n"; + shader.MainFunctionBody() << ")) * scale" << c << ";\n"; } - process_one_word << " workgroup_shared[local_id.x * " << output_number_ << " + " << c / y.NumComponents() << "]"; + shader.MainFunctionBody() << " workgroup_shared[local_id.x * " << output_number_ << " + " << c / y.NumComponents() << "]"; if (y.NumComponents() > 1) { - process_one_word << "[" << c % y.NumComponents() << "]"; + shader.MainFunctionBody() << "[" << c % y.NumComponents() << "]"; } - process_one_word << " += "; + shader.MainFunctionBody() << " += "; if (a.NumComponents() == 1) { - process_one_word << "a_data[0] * b_dequantized_values[0] + " - << "a_data[1] * b_dequantized_values[1] + " - << "a_data[2] * b_dequantized_values[2] + " - << "a_data[3] * b_dequantized_values[3] + " - << "a_data[4] * b_dequantized_values[4] + " - << "a_data[5] * b_dequantized_values[5] + " - << "a_data[6] * b_dequantized_values[6] + " - << "a_data[7] * b_dequantized_values[7];\n"; + shader.MainFunctionBody() << "a_data[0] * b_dequantized_values[0] + " + "a_data[1] * b_dequantized_values[1] + " + "a_data[2] * b_dequantized_values[2] + " + "a_data[3] * b_dequantized_values[3] + " + "a_data[4] * b_dequantized_values[4] + " + "a_data[5] * b_dequantized_values[5] + " + "a_data[6] * b_dequantized_values[6] + " + "a_data[7] * b_dequantized_values[7];\n"; } else if (a.NumComponents() == 2) { - process_one_word << "dot(a_data[0], b_dequantized_values[0]) + " - << "dot(a_data[1], b_dequantized_values[1]) + " - << "dot(a_data[2], b_dequantized_values[2]) + " - << "dot(a_data[3], b_dequantized_values[3]);\n"; + shader.MainFunctionBody() << "dot(a_data[0], b_dequantized_values[0]) + " + 
"dot(a_data[1], b_dequantized_values[1]) + " + "dot(a_data[2], b_dequantized_values[2]) + " + "dot(a_data[3], b_dequantized_values[3]);\n"; } else if (a.NumComponents() == 4) { - process_one_word << "dot(a_data[0], b_dequantized_values[0]) + " - << "dot(a_data[1], b_dequantized_values[1]);\n"; + shader.MainFunctionBody() << "dot(a_data[0], b_dequantized_values[0]) + " + "dot(a_data[1], b_dequantized_values[1]);\n"; } } - const uint32_t shared_memory_size = output_number_ * WORKGROUP_SIZE; - std::string offset = "workgroup_idx * " + std::to_string(output_number_); - shader.AppendImplementation("var workgroup_shared : array;\n"); - shader.SetMainFunctionBody(" let output_indices = ", y.OffsetToIndices(offset), - ";\n" - " let col = output_indices[2];\n" - " let row = output_indices[1];\n" - " let batch = output_indices[0];\n" - " let n_blocks_per_col = uniforms.input_b_shape[1];\n" - " let blob_size = uniforms.input_b_shape[2]" - ";\n" - " for (var block = local_id.x; block < n_blocks_per_col; block += workgroup_size_x) {\n" - " var word_offset = block * uniforms.block_size / ", - a.NumComponents(), ";\n", - prepare_scale_and_zero_point.str(), - " for (var word: u32 = 0; word < blob_size; word += 1) {\n", - prepare_b_data.str(), - " for (var i: u32 = 0; i < ", components_b_, "; i++) {\n", - process_one_word.str(), - " word_offset += ", 8 / a.NumComponents(), - ";\n" - " }\n" - " }\n" - " }\n" - " workgroupBarrier();\n" - " if (local_id.x < ", - output_number_, - ") {\n" - " var output_value = output_value_t(0);\n" - " var workgroup_shared_offset = local_id.x;\n" - " let blocks_num = min(", - shared_memory_size, - ", n_blocks_per_col);\n" - " for (var b = 0u; b < blocks_num; b++) {\n" - " output_value += workgroup_shared[workgroup_shared_offset];\n" - " workgroup_shared_offset += ", - output_number_, - ";\n" - " }\n", - " ", - y.SetByIndices("output_indices_t(batch, row, col + local_id.x)", "output_value"), - "\n" - " }\n"); + shader.MainFunctionBody() << " 
word_offset += " << 8 / a.NumComponents() << ";\n" + << " }\n" + " }\n" + " }\n" + " workgroupBarrier();\n" + << " if (local_id.x < " << output_number_ << ") {\n" + << " var output_value = output_value_t(0);\n" + " var workgroup_shared_offset = local_id.x;\n" + << " let blocks_num = min(" << shared_memory_size << ", n_blocks_per_col);\n" + << " for (var b = 0u; b < blocks_num; b++) {\n" + " output_value += workgroup_shared[workgroup_shared_offset];\n" + << " workgroup_shared_offset += " << output_number_ << ";\n" + << " }\n" + << " " << y.SetByIndices("output_indices_t(batch, row, col + local_id.x)", "output_value") << "\n" + << " }\n"; return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc index 1da1544252bdb..14e51ce0d3c47 100644 --- a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc @@ -15,11 +15,6 @@ Status BinaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"); - std::string common; - std::string get_a_data = is_lhs_scalar_ ? "let a = input_a_value_t(" + a.GetByOffset("0") + ".x" + ");\n" - : "let a = " + a.GetByOffset("global_idx") + ";\n"; - std::string get_b_data = is_rhs_scalar_ ? "let b = input_b_value_t(" + b.GetByOffset("0") + ".x" + ");\n" - : "let b = " + b.GetByOffset("global_idx") + ";\n"; // check whether can use element-wise mode. // If either A or B is scalar, or A and B have the same shape, element-wise mode can be used. // In element-wise mode, no indices calculation is needed. 
diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc index 671a6a1ed072c..8ef501d3ee413 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.cc +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -38,48 +38,6 @@ WEBGPU_CONCAT_VERSIONED_KERNEL(4, 10) WEBGPU_CONCAT_VERSIONED_KERNEL(11, 12) WEBGPU_CONCAT_KERNEL(13) -const std::string AppendCalCulateInputIndexFunction(size_t input_count) { - std::ostringstream ss; - ss.imbue(std::locale::classic()); - ss << "fn calculate_input_index(index: u32) -> u32 {" << std::endl - << " for (var i = 0u; i < " << input_count << "; i = i + 1u) {" << std::endl - << " if (index < uniforms.size_in_concat_axis[i]) {" << std::endl - << " return i;" << std::endl - << " }" << std::endl - << " }" << std::endl - << " return " << input_count << ";" << std::endl - << "}" << std::endl; - return ss.str(); -} - -const void AppendAssignOutput(std::ostringstream& ss, const ShaderVariableHelper& input, const ShaderVariableHelper& output) { - ss << output.SetByOffset("global_idx", input.GetByIndices("indices")) << ";" << std::endl; -} - -const std::string AppendAssignOutputDataFunction(gsl::span inputs, const ShaderVariableHelper& output) { - std::ostringstream ss; - size_t input_count = inputs.size(); - ss.imbue(std::locale::classic()); - ss << "fn assign_output_data(global_idx: u32, input_index: u32, indices: output_indices_t) {" << std::endl; - if (input_count == 0) { - AppendAssignOutput(ss, *inputs[0], output); - } else { - for (size_t i = 0; i < input_count; ++i) { - if (i == 0) { - ss << " if (input_index == 0u) {" << std::endl; - } else if (i == input_count - 1) { - ss << " } else {" << std::endl; - } else { - ss << " } else if (input_index == " << i << "u) {" << std::endl; - } - ss << " "; - AppendAssignOutput(ss, *inputs[i], output); - } - ss << " }" << std::endl; - } - ss << "}" << std::endl; - return ss.str(); -} Status 
ConcatProgram::GenerateShaderCode(ShaderHelper& shader) const { size_t input_count = Inputs().size(); std::vector inputs; @@ -88,16 +46,38 @@ Status ConcatProgram::GenerateShaderCode(ShaderHelper& shader) const { inputs.push_back(&shader.AddInput("input_" + std::to_string(i), ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias)); } const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); - shader.AppendImplementation(AppendCalCulateInputIndexFunction(input_count)); - shader.AppendImplementation(AppendAssignOutputDataFunction(gsl::make_span(inputs), output)); - shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"), - " var indices = ", output.OffsetToIndices("global_idx"), ";\n", - " let indices_axis = ", output.IndicesGet("indices", axis_), ";\n", - " let input_index = calculate_input_index(indices_axis);\n", - " if (input_index != 0u) {\n", - " ", output.IndicesSet("indices", axis_, "indices_axis - uniforms.size_in_concat_axis[input_index - 1]"), ";\n", - " }\n", - " assign_output_data(global_idx, input_index, indices);\n"); + + shader.AdditionalImplementation() << "fn calculate_input_index(index: u32) -> u32 {\n" + << " for (var i = 0u; i < " << input_count << "; i = i + 1u) {\n" + << " if (index < uniforms.size_in_concat_axis[i]) {\n" + << " return i;\n" + << " }\n" + << " }\n" + << " return " << input_count << ";\n" + << "}\n"; + + shader.AdditionalImplementation() << "fn assign_output_data(global_idx: u32, input_index: u32, indices: output_indices_t) {\n"; + for (size_t i = 0; i < input_count; ++i) { + if (i == 0) { + shader.AdditionalImplementation() << " if (input_index == 0u) {\n"; + } else if (i == input_count - 1) { + shader.AdditionalImplementation() << " } else {\n"; + } else { + shader.AdditionalImplementation() << " } else if (input_index == " << i << "u) {\n"; + } + shader.AdditionalImplementation() << " " << output.SetByOffset("global_idx", 
inputs[i]->GetByIndices("indices")) << ";\n"; + } + shader.AdditionalImplementation() << " }\n" + "}\n"; + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << " var indices = " << output.OffsetToIndices("global_idx") << ";\n" + << " let indices_axis = " << output.IndicesGet("indices", axis_) << ";\n" + << " let input_index = calculate_input_index(indices_axis);\n" + " if (input_index != 0u) {\n" + << " " << output.IndicesSet("indices", axis_, "indices_axis - uniforms.size_in_concat_axis[input_index - 1]") << ";\n" + << " }\n" + " assign_output_data(global_idx, input_index, indices);\n"; return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/tensor/gather.cc b/onnxruntime/core/providers/webgpu/tensor/gather.cc index 31e0a9e883239..9d5875c4efb41 100644 --- a/onnxruntime/core/providers/webgpu/tensor/gather.cc +++ b/onnxruntime/core/providers/webgpu/tensor/gather.cc @@ -13,31 +13,28 @@ Status GatherProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& indices = shader.AddInput("input_indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform); - std::ostringstream calc_data_indices; - calc_data_indices.imbue(std::locale::classic()); - calc_data_indices << " var indices_indices = input_indices_indices_t(0);\n"; + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size") + << " let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" + << " var indices_indices = input_indices_indices_t(0);\n"; for (int i = 0; i < indices.Rank(); i++) { - calc_data_indices << " " << indices.IndicesSet("indices_indices", i, output.IndicesGet("output_indices", axis_ + i)) << ";\n"; + shader.MainFunctionBody() << " " << indices.IndicesSet("indices_indices", i, output.IndicesGet("output_indices", axis_ + i)) << ";\n"; } - calc_data_indices << 
" var idx = " << indices.GetByIndices("indices_indices") << ";\n" - << " if (idx < 0) {\n" - << " idx = idx + input_indices_value_t(uniforms.data_shape[" << axis_ << "]);\n" - << " }\n" - << " var data_indices : data_indices_t;\n"; + shader.MainFunctionBody() << " var idx = " << indices.GetByIndices("indices_indices") << ";\n" + << " if (idx < 0) {\n" + << " idx = idx + input_indices_value_t(uniforms.data_shape[" << axis_ << "]);\n" + << " }\n" + << " var data_indices : data_indices_t;\n"; for (int i = 0, j = 0; i < data.Rank(); i++) { if (i == SafeInt(axis_)) { - calc_data_indices << " " << data.IndicesSet("data_indices", i, "u32(idx)") << ";\n"; + shader.MainFunctionBody() << " " << data.IndicesSet("data_indices", i, "u32(idx)") << ";\n"; j += indices.Rank(); } else { - calc_data_indices << " " << data.IndicesSet("data_indices", i, output.IndicesGet("output_indices", j)) << ";\n"; + shader.MainFunctionBody() << " " << data.IndicesSet("data_indices", i, output.IndicesGet("output_indices", j)) << ";\n"; j++; } } - shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size"), - " let output_indices = ", output.OffsetToIndices("global_idx"), ";\n", - calc_data_indices.str(), " ", - output.SetByOffset("global_idx", data.GetByIndices("data_indices"))); + shader.MainFunctionBody() << " " << output.SetByOffset("global_idx", data.GetByIndices("data_indices")); return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc index e0a0113e13224..d102620877c57 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -47,19 +47,6 @@ ONNX_OPERATOR_KERNEL_EX( .TypeConstraint("T", WebGpuSupportedNumberTypes()), Transpose); -const std::string AppendPermFunction(gsl::span perm) { - std::ostringstream ss; - ss.imbue(std::locale::classic()); - ss << "fn perm(i: output_indices_t)->a_indices_t {\n" - 
" var a: a_indices_t;\n"; - for (size_t i = 0; i < perm.size(); ++i) { - ss << " a[" << perm[i] << "] = i[" << i << "];\n"; - } - ss << " return a;\n" - "}\n"; - return ss.str(); -} - auto SqueezeShape(const gsl::span& shape, const gsl::span& adjusted_perm, InlinedVector& new_shape, InlinedVector& new_perm) { for (auto i = 0; i < shape.size(); ++i) { if (shape[i] != 1) { @@ -76,31 +63,37 @@ Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); if (use_shared_) { - shader.AppendImplementation("var tile : array, tile_size>;\n"); - shader.SetMainFunctionBody( - " let stride = (uniforms.output_shape[1] - 1) / tile_size + 1;\n" - " let workgroup_id_x = workgroup_idx % stride;\n" - " let workgroup_id_y = workgroup_idx / stride;\n" - " let input_col = workgroup_id_y * tile_size + local_id.x;\n" - " let input_row = workgroup_id_x * tile_size + local_id.y;\n" - " if (input_row < uniforms.a_shape[0] && input_col < uniforms.a_shape[1]) {\n" - " tile[local_id.y][local_id.x] = " + - input.GetByIndices("a_indices_t(input_row, input_col)") + - ";\n" - " }\n" - " workgroupBarrier();\n" - " let output_col = workgroup_id_x * tile_size + local_id.x;\n" - " let output_row = workgroup_id_y * tile_size + local_id.y;\n" - " if (output_row < uniforms.output_shape[0] && output_col < uniforms.output_shape[1]) {\n " + - output.SetByIndices("output_indices_t(output_row, output_col)", "tile[local_id.x][local_id.y]") + "\n }"); + shader.AdditionalImplementation() << "var tile : array, tile_size>;\n"; + shader.MainFunctionBody() << " let stride = (uniforms.output_shape[1] - 1) / tile_size + 1;\n" + " let workgroup_id_x = workgroup_idx % stride;\n" + " let workgroup_id_y = workgroup_idx / stride;\n" + " let input_col = workgroup_id_y * tile_size + local_id.x;\n" + " let input_row = workgroup_id_x * tile_size + local_id.y;\n" + " if 
(input_row < uniforms.a_shape[0] && input_col < uniforms.a_shape[1]) {\n" + " tile[local_id.y][local_id.x] = " + << input.GetByIndices("a_indices_t(input_row, input_col)") + << ";\n" + " }\n" + " workgroupBarrier();\n" + " let output_col = workgroup_id_x * tile_size + local_id.x;\n" + " let output_row = workgroup_id_y * tile_size + local_id.y;\n" + " if (output_row < uniforms.output_shape[0] && output_col < uniforms.output_shape[1]) {\n " + << output.SetByIndices("output_indices_t(output_row, output_col)", "tile[local_id.x][local_id.y]") << "\n }"; } else { - shader.AppendImplementation(AppendPermFunction(this->perm_)); - shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"), - " let indices = ", output.OffsetToIndices("global_idx"), - ";\n" - " let x_indices = perm(indices);\n", - " ", - output.SetByOffset("global_idx", input.GetByIndices("x_indices"))); + shader.AdditionalImplementation() << "fn perm(i: output_indices_t)->a_indices_t {\n" + " var a: a_indices_t;\n"; + for (size_t i = 0; i < perm_.size(); ++i) { + shader.AdditionalImplementation() << " a[" << perm_[i] << "] = i[" << i << "];\n"; + } + shader.AdditionalImplementation() << " return a;\n" + "}\n"; + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << " let indices = " << output.OffsetToIndices("global_idx") + << ";\n" + " let x_indices = perm(indices);\n" + " " + << output.SetByOffset("global_idx", input.GetByIndices("x_indices")); } return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/tensor/where.cc b/onnxruntime/core/providers/webgpu/tensor/where.cc index d8e6d208610ba..b37014eb05da2 100644 --- a/onnxruntime/core/providers/webgpu/tensor/where.cc +++ b/onnxruntime/core/providers/webgpu/tensor/where.cc @@ -59,12 +59,14 @@ Status WhereProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& b_input = shader.AddInput("b_data", ShaderUsage::UseUniform); const auto& output = 
shader.AddOutput("output_data", ShaderUsage::UseUniform); - const auto expression = [](const std::string& a, const std::string& b, const std::string& c) -> auto { - return "select(" + b + ", " + a + ", " + c + ")"; + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"); + + const auto expression = [](std::string_view a, std::string_view b, std::string_view c) -> auto { + return absl::StrCat("select(", b, ", ", a, ", ", c, ")"); }; - std::string assignment; + if (!is_broadcast_) { - assignment = output.SetByOffset( + shader.MainFunctionBody() << output.SetByOffset( "global_idx", expression(a_input.GetByOffset("global_idx"), b_input.GetByOffset("global_idx"), c_input.GetByOffset("global_idx"))); @@ -75,47 +77,41 @@ Status WhereProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& output_indices = shader.AddIndices("output_indices"); const auto single_assignment = - [&expression, &output_indices, &a_indices, &b_indices, &c_indices]( - const std::string& rest_str, const std::string& x, const std::string& type_cast = "") - -> auto { + [&expression, &shader, &output_indices, &a_indices, &b_indices, &c_indices]( + std::string_view rest_str, const std::string& x, std::string_view type_cast = "") + -> void { const std::string a_expression = "a_data[index_a" + x + "][component_a" + x + "]"; const std::string b_expression = "b_data[index_b" + x + "][component_b" + x + "]"; const std::string c_expression = "bool(c_data[index_c" + x + "] & (0xffu << (component_c" + x + " * 8)))"; - std::ostringstream ss; - ss.imbue(std::locale::classic()); - ss << "let output_indices" + x + " = " << output_indices.OffsetToIndices("global_idx * 4u + " + x + "u") << ";\n"; - ss << "let offset_a" + x + " = " + a_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) + ";\n"; - ss << "let offset_b" + x + " = " + b_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) + ";\n"; - ss << "let offset_c" + x + 
" = " + c_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) + ";\n"; - ss << "let index_a" + x + " = offset_a" + x + " / 4u;\n"; - ss << "let index_b" + x + " = offset_b" + x + " / 4u;\n"; - ss << "let index_c" + x + " = offset_c" + x + " / 4u;\n"; - ss << "let component_a" + x + " = offset_a" + x + " % 4u;\n"; - ss << "let component_b" + x + " = offset_b" + x + " % 4u;\n"; - ss << "let component_c" + x + " = offset_c" + x + " % 4u;\n"; - ss << rest_str + "[" + x + "] = " + type_cast + "(" + expression(a_expression, b_expression, c_expression) + ");\n"; - return ss.str(); + shader.MainFunctionBody() << "let output_indices" << x << " = " << output_indices.OffsetToIndices("global_idx * 4u + " + x + "u") << ";\n" + << "let offset_a" << x << " = " << a_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) << ";\n" + << "let offset_b" << x << " = " << b_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) << ";\n" + << "let offset_c" << x << " = " << c_indices.BroadcastedIndicesToOffset("output_indices" + x, output_indices) << ";\n" + << "let index_a" << x << " = offset_a" << x << " / 4u;\n" + << "let index_b" << x << " = offset_b" << x << " / 4u;\n" + << "let index_c" << x << " = offset_c" << x << " / 4u;\n" + << "let component_a" << x << " = offset_a" << x << " % 4u;\n" + << "let component_b" << x << " = offset_b" << x << " % 4u;\n" + << "let component_c" << x << " = offset_c" << x << " % 4u;\n" + << rest_str << "[" << x << "] = " << type_cast << "(" << expression(a_expression, b_expression, c_expression) << ");\n"; }; if (Outputs()[0].tensor->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_BOOL) { - assignment = - "var data = vec4(0); \n" + - single_assignment("data", "0", "u32") + - single_assignment("data", "1", "u32") + - single_assignment("data", "2", "u32") + - single_assignment("data", "3", "u32") + - "output_data[global_idx] = dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(data));\n"; + 
shader.MainFunctionBody() << "var data = vec4(0);\n"; + single_assignment("data", "0", "u32"); + single_assignment("data", "1", "u32"); + single_assignment("data", "2", "u32"); + single_assignment("data", "3", "u32"); + shader.MainFunctionBody() << "output_data[global_idx] = dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(data));\n"; } else { - assignment = - single_assignment("output_data[global_idx]", "0") + - single_assignment("output_data[global_idx]", "1") + - single_assignment("output_data[global_idx]", "2") + - single_assignment("output_data[global_idx]", "3"); + single_assignment("output_data[global_idx]", "0"); + single_assignment("output_data[global_idx]", "1"); + single_assignment("output_data[global_idx]", "2"); + single_assignment("output_data[global_idx]", "3"); } } - shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size") - << assignment; + return Status::OK(); } From f9996b94e0825c2efeb76e52c47c976394940f31 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 27 Sep 2024 13:28:18 -0700 Subject: [PATCH 3/6] revise format --- .../webgpu/math/binary_elementwise_ops.cc | 56 +++++-------------- .../core/providers/webgpu/tensor/transpose.cc | 11 ++-- 2 files changed, 20 insertions(+), 47 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc index 14e51ce0d3c47..6077ef0499069 100644 --- a/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc @@ -43,13 +43,9 @@ Status BinaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const const auto& a_indices = shader.AddIndices("a_indices"); const auto& b_indices = shader.AddIndices("b_indices"); - shader.MainFunctionBody() << "let outputIndices = " << c_indices.OffsetToIndices("global_idx * 4") - << ";\n" - "let offset_a = " - << 
a_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) - << ";\n" - "let offset_b = " - << b_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"; + shader.MainFunctionBody() << "let outputIndices = " << c_indices.OffsetToIndices("global_idx * 4") << ";\n" + << "let offset_a = " << a_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "let offset_b = " << b_indices.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"; // get A data if (a.NumComponents() == 4) { shader.MainFunctionBody() << "let a = " << a.GetByOffset("offset_a / 4") << ";\n"; @@ -65,40 +61,18 @@ Status BinaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const } } else { // In broadcast mode, each element of the vec4 value of A and B will be loaded separately to calculate the output value. - shader.MainFunctionBody() << "var outputIndices = " << c_indices.OffsetToIndices("global_idx * 4") - << ";\n" - "let offset_a0 = " - << a.BroadcastedIndicesToOffset("outputIndices", c_indices) - << ";\n" - "let offset_b0 = " - << b.BroadcastedIndicesToOffset("outputIndices", c_indices) - << ";\n" - "outputIndices = " - << c_indices.OffsetToIndices("global_idx * 4 + 1") - << ";\n" - "let offset_a1 = " - << a.BroadcastedIndicesToOffset("outputIndices", c_indices) - << ";\n" - "let offset_b1 = " - << b.BroadcastedIndicesToOffset("outputIndices", c_indices) - << ";\n" - "outputIndices = " - << c_indices.OffsetToIndices("global_idx * 4 + 2") - << ";\n" - "let offset_a2 = " - << a.BroadcastedIndicesToOffset("outputIndices", c_indices) - << ";\n" - "let offset_b2 = " - << b.BroadcastedIndicesToOffset("outputIndices", c_indices) - << ";\n" - "outputIndices = " - << c_indices.OffsetToIndices("global_idx * 4 + 3") - << ";\n" - "let offset_a3 = " - << a.BroadcastedIndicesToOffset("outputIndices", c_indices) - << ";\n" - "let offset_b3 = " - << b.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"; + shader.MainFunctionBody() << "var 
outputIndices = " << c_indices.OffsetToIndices("global_idx * 4") << ";\n" + << "let offset_a0 = " << a.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "let offset_b0 = " << b.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "outputIndices = " << c_indices.OffsetToIndices("global_idx * 4 + 1") << ";\n" + << "let offset_a1 = " << a.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "let offset_b1 = " << b.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "outputIndices = " << c_indices.OffsetToIndices("global_idx * 4 + 2") << ";\n" + << "let offset_a2 = " << a.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "let offset_b2 = " << b.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "outputIndices = " << c_indices.OffsetToIndices("global_idx * 4 + 3") << ";\n" + << "let offset_a3 = " << a.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n" + << "let offset_b3 = " << b.BroadcastedIndicesToOffset("outputIndices", c_indices) << ";\n"; // get A data shader.MainFunctionBody() << "let a = vec4(" diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc index d102620877c57..adcee8b64fd80 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -70,15 +70,14 @@ Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const { " let input_col = workgroup_id_y * tile_size + local_id.x;\n" " let input_row = workgroup_id_x * tile_size + local_id.y;\n" " if (input_row < uniforms.a_shape[0] && input_col < uniforms.a_shape[1]) {\n" - " tile[local_id.y][local_id.x] = " - << input.GetByIndices("a_indices_t(input_row, input_col)") - << ";\n" - " }\n" + << " tile[local_id.y][local_id.x] = " << input.GetByIndices("a_indices_t(input_row, input_col)") << ";\n" + << " }\n" " workgroupBarrier();\n" " let output_col = 
workgroup_id_x * tile_size + local_id.x;\n" " let output_row = workgroup_id_y * tile_size + local_id.y;\n" - " if (output_row < uniforms.output_shape[0] && output_col < uniforms.output_shape[1]) {\n " - << output.SetByIndices("output_indices_t(output_row, output_col)", "tile[local_id.x][local_id.y]") << "\n }"; + " if (output_row < uniforms.output_shape[0] && output_col < uniforms.output_shape[1]) {\n" + << " " << output.SetByIndices("output_indices_t(output_row, output_col)", "tile[local_id.x][local_id.y]") << "\n" + << " }"; } else { shader.AdditionalImplementation() << "fn perm(i: output_indices_t)->a_indices_t {\n" " var a: a_indices_t;\n"; From 1d2d47681a46f0b7535cbecdf8900d528c730134 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sat, 28 Sep 2024 14:12:25 -0700 Subject: [PATCH 4/6] resolve comments --- .../core/providers/webgpu/tensor/concat.cc | 53 +++++++++++-------- 1 file changed, 31 insertions(+), 22 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc index 8ef501d3ee413..866f99b587bc6 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.cc +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -38,6 +38,33 @@ WEBGPU_CONCAT_VERSIONED_KERNEL(4, 10) WEBGPU_CONCAT_VERSIONED_KERNEL(11, 12) WEBGPU_CONCAT_KERNEL(13) +void AppendCalCulateInputIndexFunction(std::ostream& os, size_t input_count) { + os << "fn calculate_input_index(index: u32) -> u32 {\n" + << " for (var i = 0u; i < " << input_count << "; i = i + 1u) {\n" + << " if (index < uniforms.size_in_concat_axis[i]) {\n" + << " return i;\n" + << " }\n" + << " }\n" + << " return " << input_count << ";\n" + << "}\n"; +} + +void AppendAssignOutputDataFunction(std::ostream& os, gsl::span inputs, const ShaderVariableHelper& output) { + os << "fn assign_output_data(global_idx: u32, input_index: u32, indices: output_indices_t) {\n"; + for (size_t i = 0; i < inputs.size(); ++i) { + 
if (i == 0) { + os << " if (input_index == 0u) {\n"; + } else if (i == inputs.size() - 1) { + os << " } else {\n"; + } else { + os << " } else if (input_index == " << i << "u) {\n"; + } + os << " " << output.SetByOffset("global_idx", inputs[i]->GetByIndices("indices")) << ";\n"; + } + os << " }\n" + "}\n"; +} + Status ConcatProgram::GenerateShaderCode(ShaderHelper& shader) const { size_t input_count = Inputs().size(); std::vector inputs; @@ -47,28 +74,10 @@ Status ConcatProgram::GenerateShaderCode(ShaderHelper& shader) const { } const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); - shader.AdditionalImplementation() << "fn calculate_input_index(index: u32) -> u32 {\n" - << " for (var i = 0u; i < " << input_count << "; i = i + 1u) {\n" - << " if (index < uniforms.size_in_concat_axis[i]) {\n" - << " return i;\n" - << " }\n" - << " }\n" - << " return " << input_count << ";\n" - << "}\n"; - - shader.AdditionalImplementation() << "fn assign_output_data(global_idx: u32, input_index: u32, indices: output_indices_t) {\n"; - for (size_t i = 0; i < input_count; ++i) { - if (i == 0) { - shader.AdditionalImplementation() << " if (input_index == 0u) {\n"; - } else if (i == input_count - 1) { - shader.AdditionalImplementation() << " } else {\n"; - } else { - shader.AdditionalImplementation() << " } else if (input_index == " << i << "u) {\n"; - } - shader.AdditionalImplementation() << " " << output.SetByOffset("global_idx", inputs[i]->GetByIndices("indices")) << ";\n"; - } - shader.AdditionalImplementation() << " }\n" - "}\n"; + // add implementation of fn calculate_input_index + AppendCalCulateInputIndexFunction(shader.AdditionalImplementation(), input_count); + // add implementation of fn assign_output_data + AppendAssignOutputDataFunction(shader.AdditionalImplementation(), inputs, output); shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") << " var indices = " << 
output.OffsetToIndices("global_idx") << ";\n" From adc550831397b489625c671f860c3618d2b2acd3 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sun, 29 Sep 2024 22:19:57 -0700 Subject: [PATCH 5/6] Tile --- .../core/providers/webgpu/tensor/tile.cc | 22 ++++++++----------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/tensor/tile.cc b/onnxruntime/core/providers/webgpu/tensor/tile.cc index 2737b6dafea88..841c36724df30 100644 --- a/onnxruntime/core/providers/webgpu/tensor/tile.cc +++ b/onnxruntime/core/providers/webgpu/tensor/tile.cc @@ -30,22 +30,18 @@ Status TileProgram::GenerateShaderCode(ShaderHelper& shader) const { const ShaderVariableHelper& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); - std::ostringstream ss; - ss.imbue(std::locale::classic()); - - ss << "var input_indices: input_indices_t;\n"; + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << "let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" + << "var input_indices: input_indices_t;\n"; for (auto i = 0; i < input.Rank(); i++) { - std::string input_dim_i = "input_dim_" + std::to_string(i); - std::string input_dim_value = "input_dim_" + std::to_string(i) + "_value"; - ss << "let " << input_dim_i << " = " << input.IndicesGet("uniforms.input_shape", i) << ";\n"; - ss << "let " << input_dim_value << " = " << output.IndicesGet("output_indices", i) << " % " << input_dim_i << ";\n"; - ss << input.IndicesSet("input_indices", i, input_dim_value) << ";\n"; + std::string input_dim_i = absl::StrCat("input_dim_", i); + std::string input_dim_value = absl::StrCat("input_dim_", i, "_value"); + shader.MainFunctionBody() << "let " << input_dim_i << " = " << 
input.IndicesGet("uniforms.input_shape", i) << ";\n" + << "let " << input_dim_value << " = " << output.IndicesGet("output_indices", i) << " % " << input_dim_i << ";\n" + << input.IndicesSet("input_indices", i, input_dim_value) << ";\n"; } - shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"), - "let output_indices = ", output.OffsetToIndices("global_idx"), ";\n", - ss.str(), - output.SetByOffset("global_idx", input.GetByIndices("input_indices"))); + shader.MainFunctionBody() << output.SetByOffset("global_idx", input.GetByIndices("input_indices")); return Status::OK(); } From fd82c395273c30ff6ec488cfc4b5f5b376350848 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 30 Sep 2024 02:20:34 -0700 Subject: [PATCH 6/6] resolve comments --- onnxruntime/core/providers/webgpu/string_macros.h | 2 +- onnxruntime/core/providers/webgpu/string_utils.h | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webgpu/string_macros.h b/onnxruntime/core/providers/webgpu/string_macros.h index fd82375ce5589..7821d9c49a171 100644 --- a/onnxruntime/core/providers/webgpu/string_macros.h +++ b/onnxruntime/core/providers/webgpu/string_macros.h @@ -15,4 +15,4 @@ #define SS_GET(ss) ss##_str // macro "SS_APPEND" - use function call style to append to the ostream -#define SS_APPEND(ss, ...) ::onnxruntime::webgpu::detail::OStringStreamAppendImpl(ss, __VA_ARGS__) +#define SS_APPEND(ss, ...) 
::onnxruntime::webgpu::detail::OStringStreamAppend(ss, __VA_ARGS__) diff --git a/onnxruntime/core/providers/webgpu/string_utils.h b/onnxruntime/core/providers/webgpu/string_utils.h index 6d7e4f40c9d19..e6d7097ad6182 100644 --- a/onnxruntime/core/providers/webgpu/string_utils.h +++ b/onnxruntime/core/providers/webgpu/string_utils.h @@ -3,6 +3,7 @@ #pragma once +#include "core/common/make_string.h" #include namespace onnxruntime { @@ -33,6 +34,12 @@ inline void OStringStreamAppendImpl(std::ostream& ss, const T& t, const Args&... OStringStreamAppendImpl(ss, t); OStringStreamAppendImpl(ss, args...); } + +template +inline void OStringStreamAppend(std::ostream& ss, const Args&... args) { + return OStringStreamAppendImpl(ss, ::onnxruntime::detail::if_char_array_make_ptr_t(args)...); +} + } // namespace detail } // namespace webgpu