From 2ce62db73b7a9d786d6d633232210107ff6ae2d1 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Fri, 27 Sep 2024 12:08:19 -0700 Subject: [PATCH 1/7] support for webgpu layernorms --- onnxruntime/contrib_ops/webgpu/layer_norm.cc | 36 ++++ .../contrib_ops/webgpu/skip_layer_norm.cc | 193 ++++++++++++++++++ .../contrib_ops/webgpu/skip_layer_norm.h | 63 ++++++ .../webgpu/webgpu_contrib_kernels.cc | 15 +- .../core/providers/webgpu/nn/layer_norm.cc | 171 ++++++++++++++++ .../core/providers/webgpu/nn/layer_norm.h | 66 ++++++ .../webgpu/webgpu_execution_provider.cc | 2 +- .../test/contrib_ops/layer_norm_test.cc | 2 +- .../test/contrib_ops/skiplayernorm_op_test.cc | 12 +- 9 files changed, 548 insertions(+), 12 deletions(-) create mode 100644 onnxruntime/contrib_ops/webgpu/layer_norm.cc create mode 100644 onnxruntime/contrib_ops/webgpu/skip_layer_norm.cc create mode 100644 onnxruntime/contrib_ops/webgpu/skip_layer_norm.h create mode 100644 onnxruntime/core/providers/webgpu/nn/layer_norm.cc create mode 100644 onnxruntime/core/providers/webgpu/nn/layer_norm.h diff --git a/onnxruntime/contrib_ops/webgpu/layer_norm.cc b/onnxruntime/contrib_ops/webgpu/layer_norm.cc new file mode 100644 index 0000000000000..8997e8698d96d --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/layer_norm.cc @@ -0,0 +1,36 @@ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/nn/layer_norm.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + LayerNormalization, + kOnnxDomain, + 1, + 16, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + onnxruntime::webgpu::LayerNorm); + +ONNX_OPERATOR_KERNEL_EX( + SimplifiedLayerNormalization, + kOnnxDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + onnxruntime::webgpu::LayerNorm); + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/webgpu/skip_layer_norm.cc new file mode 100644 index 0000000000000..10642da7e8763 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/skip_layer_norm.cc @@ -0,0 +1,193 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "contrib_ops/webgpu/skip_layer_norm.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +static uint32_t getMaxComponents(int size) { + if (size % 4 == 0) { + return 4; + } else if (size % 2 == 0) { + return 2; + } + return 1; +} + +static std::string fillVar(std::string dataType, int components, std::string value) { + if (components == 1) { + return dataType + "(" + value + ")"; + } + return "vec" + std::to_string(components) + "<" + dataType + ">(" + value + ")"; +} + +static std::string sumVector(std::string x, int components) { + switch (components) { + case 1: + return x; + case 2: + return "(" + x + ".x + " + x + ".y" + ")"; + case 4: + return "(" + x + ".x + " + x + ".y + " + x + ".w + " + x + ".z" + ")"; + default: + ORT_THROW("Unsupported number of components: ", components); + } +} + +static std::string castToF32(int components, std::string value) { + if (components == 1) { + return "f32(" + value + ")"; + } + return "vec" + std::to_string(components) + "(" + value + ")"; +}; + +static std::string vecDataType(std::string datatype, int components) { + if (components == 1) { + return datatype; + } + return "vec" + std::to_string(components) + "<" + datatype + ">"; +}; + +Status SkipLayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.AddInput("skip", ShaderUsage::UseUniform); + shader.AddInput("gamma", ShaderUsage::UseUniform); + if (hasBeta_) { + shader.AddInput("beta", ShaderUsage::UseUniform); + } + if (hasBias_) { + shader.AddInput("bias", ShaderUsage::UseUniform); + } + shader.AddOutput("output", ShaderUsage::UseUniform); + + int components = x.NumComponents(); + + std::string bias = (hasBias_) ? " + bias[offset1d + i] " : ""; + std::string simpl1 = (simplified_) ? "" : "- mean * mean "; + std::string simpl2 = (simplified_) ? "" : "- element_t(mean) "; + std::string fillvec = fillVar("f32", components, "0"); + std::string beta = (hasBeta_) ? " + beta[offset1d + i] " : ""; + std::string element_type = (isFP16_) ? "f16;\n" : "f32;\n"; + + shader.AppendImplementation( + "alias element_t = " + element_type + + "var sum_shared : array<" + vecDataType("f32", components) + + ",workgroup_size_x>;\n" + "var sum_squared_shared : array<" + + vecDataType("f32", components) + ",workgroup_size_x>;\n"); + + std::stringstream ss; + ss << "let ix = local_idx;\n" + << "let iy = global_idx / workgroup_size_x;\n" + << "let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components;\n" + << "var stride = hidden_size_vectorized / workgroup_size_x;\n" + << "let offset = ix * stride + iy * hidden_size_vectorized;\n" + << "let offset1d = stride * ix;\n" + << "if (ix == workgroup_size_x - 1) {\n" + << " stride = hidden_size_vectorized - stride * ix;\n" + << "}\n" + << "for (var i: u32 = 0; i < stride; i++) {\n" + << " let skip_value = skip[offset + i];\n" + << " let input_value = x[offset + i];\n" + << " let value = input_value + skip_value" << bias << ";\n" + << " output[offset + i] = value;\n" + << " let f32_value = " << castToF32(components, "value") << ";\n" + << " sum_shared[ix] += f32_value;\n" + << " sum_squared_shared[ix] += f32_value * f32_value;\n" + << "}\n" + << "workgroupBarrier();\n" + << "var reduce_size : u32 = workgroup_size_x;\n" + << "for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n" + << " reduce_size = curr_size + (reduce_size & 1);\n" + << " if (ix < curr_size) {\n" + << " sum_shared[ix] += sum_shared[ix + reduce_size];\n" + << " sum_squared_shared[ix] += sum_squared_shared[ix + reduce_size];\n" + << " }\n" + << " workgroupBarrier();\n" + << "}\n" + << "let sum = sum_shared[0];\n" + << "let square_sum = sum_squared_shared[0];\n" + << "let mean = " << sumVector("sum", components) << " / f32(uniforms.hidden_size);\n" + << "let inv_std_dev = inverseSqrt(" << sumVector("square_sum", components) << " / f32(uniforms.hidden_size) " << simpl1 << "+ uniforms.epsilon);\n" + << "for (var i: u32 = 0; i < stride; i++) {\n" + << " output[offset + i] = (output[offset + i] " << simpl2 << ") * element_t(inv_std_dev) * gamma[offset1d + i]" << beta << ";\n" + << "};\n"; + + shader.SetMainFunctionBody(ss.str()); + return Status::OK(); +} + +template +Status SkipLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const Tensor* x = context.Input(0); + const Tensor* skip = context.Input(1); + const Tensor* gamma = context.Input(2); + // optional + const Tensor* beta = context.Input(3); + const Tensor* bias = context.Input(4); + + const auto x_shape = x->Shape(); + + auto* output = context.Output(0, x_shape); + + size_t data_size = x_shape.Size(); + if (data_size == 0) { + return Status::OK(); + } + + const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + const int hidden_size = x_shape[x_shape.NumDimensions() - 1]; + const int components = getMaxComponents(hidden_size); + + SkipLayerNormProgram program{beta != nullptr, bias != nullptr, epsilon_, hidden_size, is_fp16, simplified}; + program + .CacheHint(simplified) + .AddInputs({{x, ProgramTensorMetadataDependency::Type, components}}) + .AddInputs({{skip, ProgramTensorMetadataDependency::Type, components}}) + .AddInputs({{gamma, ProgramTensorMetadataDependency::Type, components}}) + .AddOutputs({{output, ProgramTensorMetadataDependency::None, components}}) + .SetDispatchGroupSize(ceil(data_size / hidden_size)) + .AddUniformVariables({ + {static_cast(components)}, + }) + .AddUniformVariables({ + {static_cast(hidden_size)}, + }) + .AddUniformVariables({ + {static_cast(epsilon_)}, + }); + + if (beta != nullptr) { + program.AddInput({beta, ProgramTensorMetadataDependency::Type, components}); + } + if (bias != nullptr) { + program.AddInput({bias, ProgramTensorMetadataDependency::Type, components}); + } + + return context.RunProgram(program); +} + +ONNX_OPERATOR_KERNEL_EX( + SkipLayerNormalization, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + SkipLayerNorm); + +ONNX_OPERATOR_KERNEL_EX( + SkipSimplifiedLayerNormalization, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + SkipLayerNorm); + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/skip_layer_norm.h b/onnxruntime/contrib_ops/webgpu/skip_layer_norm.h new file mode 100644 index 0000000000000..aabe106a2fbb1 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/skip_layer_norm.h @@ -0,0 +1,63 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + + +class SkipLayerNormProgram final : public Program { + public: + SkipLayerNormProgram(bool hasBeta, bool hasBias, float epsilon, int hiddenSize, bool isFP16, bool simplified) : Program{"SkipLayerNorm"} { + epsilon_ = epsilon; + hasBeta_ = hasBeta; + hasBias_ = hasBias; + epsilon_ = epsilon; + hiddenSize_ = hiddenSize; + simplified_ = simplified; + isFP16_ = isFP16; + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"components", ProgramUniformVariableDataType::Uint32}, + {"hidden_size", ProgramUniformVariableDataType::Uint32}, + {"epsilon", ProgramUniformVariableDataType::Float32}); + + private: + bool hasBeta_; + bool hasBias_; + float epsilon_; + int hiddenSize_; + bool isFP16_; + bool simplified_; +}; + +template +class SkipLayerNorm final : public WebGpuKernel { + public: + SkipLayerNorm(const OpKernelInfo& info) : WebGpuKernel(info) { + info.GetAttrOrDefault("epsilon", &epsilon_, 1e-05); + } + + Status ComputeInternal(ComputeContext& context) const override; + + protected: + std::string cache_hint; + + private: + float epsilon_; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index b5d7a90b9bbfd..5cb73b8f2735e 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -22,6 +22,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Ma class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipSimplifiedLayerNormalization); @@ -44,16 +45,14 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it // BuildKernelCreateInfo, - BuildKernelCreateInfo, + // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc new file mode 100644 index 0000000000000..5c3a1177146f3 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc @@ -0,0 +1,171 @@ + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/nn/layer_norm.h" + +namespace onnxruntime { +namespace webgpu { + +static uint32_t getMaxComponents(int size) { + if (size % 4 == 0) { + return 4; + } else if (size % 2 == 0) { + return 2; + } + return 1; +} + +static int normalizeAxis(int axis, int tensorRank) { + if (axis < -tensorRank && axis >= tensorRank) { + ORT_THROW("invalid axis: ", axis); + } + return axis < 0 ? axis + tensorRank : axis; +} + +static std::string fillVar(std::string dataType, int components, std::string value) { + if (components == 1) { + return dataType + "(" + value + ")"; + } + return "vec" + std::to_string(components) + "<" + dataType + ">(" + value + ")"; +} + +static std::string castToF32(int components, std::string value) { + if (components == 1) { + return "f32(" + value + ")"; + } + return "vec" + std::to_string(components) + "(" + value + ")"; +}; + +static std::string sumVector(std::string x, int components) { + switch (components) { + case 1: + return x; + case 2: + return "(" + x + ".x + " + x + ".y" + ")"; + case 4: + return "(" + x + ".x + " + x + ".y + " + x + ".w + " + x + ".z" + ")"; + default: + ORT_THROW("Unsupported number of components: ", components); + } +} + +Status LayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.AddInput("scale", ShaderUsage::UseUniform); + if (has_bias_) { + shader.AddInput("bias", ShaderUsage::UseUniform); + } + shader.AddOutput("output", ShaderUsage::UseUniform); + + int components = x.NumComponents(); + std::string bias = (has_bias_) ? " + bias[j] " : ""; + std::string simpl1 = (simplified_) ? "" : "- mean * mean "; + std::string simpl2 = (simplified_) ? "" : "- mean "; + std::string fillvec = fillVar("f32", components, "0"); + std::string element_type = (isFP16_) ? "f16;\n" : "f32;\n"; + + shader.AppendImplementation("alias element_t = " + element_type); + + std::stringstream ss; + ss << "let offset = global_idx * uniforms.norm_size_vectorized;\n" + << "var mean_vector = " << fillvec << ";\n" + << "var mean_square_vector = " << fillvec << ";\n" + << "for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) {\n" + << " let value = " << castToF32(components, "x[h + offset]") << ";\n" + << " mean_vector += value;\n" + << " mean_square_vector += value * value;\n" + << "}\n" + << "let mean = " << sumVector("mean_vector", components) << " / f32(uniforms.norm_size);\n" + << "let inv_std_dev = inverseSqrt(" << sumVector("mean_square_vector", components) << " / f32(uniforms.norm_size) " << simpl1 << "+ uniforms.epsilon);\n" + << "for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) {\n" + << " let f32input = " << castToF32(components, "x[j + offset]") << ";\n" + << " let f32scale = " << castToF32(components, "scale[j]") << ";\n" + << " output[j + offset] = x_value_t((f32input " << simpl2 << ") * inv_std_dev * f32scale)" << bias << ";\n" + << "}\n"; + + shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.norm_count"), ss.str()); + return Status::OK(); +} + +template +Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const auto* x = context.Input(0); + const auto* scale = context.Input(1); + const auto* bias = context.Input(2); + + const auto x_shape = x->Shape(); + + auto* output = context.Output(0, x_shape); + + size_t data_size = x_shape.Size(); + if (data_size == 0) { + return Status::OK(); + } + + const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; + + const int axis = normalizeAxis(axis_, x_shape.NumDimensions()); + const int norm_count = x_shape.SizeToDimension(axis); + const int norm_size = x_shape.SizeFromDimension(axis); + const int components = getMaxComponents(norm_size); + const int norm_size_vectorized = (norm_size + components - 1) / components; + + const auto scale_size = scale->Shape().Size(); + const auto bias_size = (bias) ? bias->Shape().Size() : 0; + if (scale_size != norm_size || (bias && bias_size != norm_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Size of X.shape()[axis:] == ", norm_size, + ". Size of scale and bias (if provided) must match this. Got scale size of ", + scale_size, " and bias size of ", bias_size); + } + + LayerNormProgram program{axis_, epsilon_, stash_type_, bias != nullptr, data_size, is_fp16, simplified}; + + program + .CacheHint(simplified) + .AddInputs({{x, ProgramTensorMetadataDependency::Type, components}}) + .AddInputs({{scale, ProgramTensorMetadataDependency::Type, components}}) + .AddOutputs({{output, ProgramTensorMetadataDependency::None, components}}) + .SetDispatchGroupSize((norm_count + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + {static_cast(norm_count)}, + }) + .AddUniformVariables({ + {static_cast(norm_size)}, + }) + .AddUniformVariables({ + {static_cast(norm_size_vectorized)}, + }) + .AddUniformVariables({ + {static_cast(epsilon_)}, + }); + + if (bias != nullptr) { + program.AddInput({bias, ProgramTensorMetadataDependency::Type, components}); + } + return context.RunProgram(program); +} + + +ONNX_OPERATOR_KERNEL_EX( + LayerNormalization, + kOnnxDomain, + 17, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + LayerNorm); + +ONNX_OPERATOR_KERNEL_EX( + SimplifiedLayerNormalization, + kOnnxDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + LayerNorm); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.h b/onnxruntime/core/providers/webgpu/nn/layer_norm.h new file mode 100644 index 0000000000000..b7cd2c6478901 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.h @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +class LayerNormProgram final : public Program { + public: + LayerNormProgram(int64_t axis, float epsilon, int64_t stash_type, bool has_bias, size_t x_size, bool isFP16, bool simplified) : Program{"LayerNorm"} { + axis_ = axis; + epsilon_ = epsilon; + stash_type_ = stash_type; + has_bias_ = has_bias; + x_size_ = x_size; + isFP16_ = isFP16; + simplified_ = simplified; + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"norm_count", ProgramUniformVariableDataType::Uint32}, + {"norm_size", ProgramUniformVariableDataType::Uint32}, + {"norm_size_vectorized", ProgramUniformVariableDataType::Uint32}, + {"epsilon", ProgramUniformVariableDataType::Float32}); + + private: + int64_t axis_; + float epsilon_; + int64_t stash_type_; + bool has_bias_; + int x_size_; + bool isFP16_; + bool simplified_; +}; + +template +class LayerNorm final : public WebGpuKernel { + public: + LayerNorm(const OpKernelInfo& info) : WebGpuKernel(info) { + info.GetAttrOrDefault("axis", &axis_, -1); + info.GetAttrOrDefault("epsilon", &epsilon_, 1e-05); + info.GetAttrOrDefault("stash_type", &stash_type_, 1); + } + + Status ComputeInternal(ComputeContext& context) const override; + + protected: + std::string cache_hint; + + private: + int64_t axis_; + float epsilon_; + int64_t stash_type_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index c1f13d652413d..093b1d8ce015f 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -671,7 +671,7 @@ std::unique_ptr RegisterKernels() { // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, diff --git a/onnxruntime/test/contrib_ops/layer_norm_test.cc b/onnxruntime/test/contrib_ops/layer_norm_test.cc index 438a1100ca95c..218c48a37de9b 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_test.cc @@ -6,7 +6,7 @@ namespace onnxruntime { namespace test { -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) || defined(USE_WEBGPU) constexpr auto k_epsilon_default = 1e-5f; constexpr auto k_random_data_min = -10.0f; constexpr auto k_random_data_max = 10.0f; diff --git a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc index edf9064bb43c9..b9ca55073d411 100644 --- a/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc +++ b/onnxruntime/test/contrib_ops/skiplayernorm_op_test.cc @@ -62,6 +62,8 @@ static void RunOneTest( auto rocm_ep = DefaultRocmExecutionProvider(); auto dml_ep = DefaultDmlExecutionProvider(); auto cpu_ep = DefaultCpuExecutionProvider(); + auto webgpu_ep = DefaultWebGpuExecutionProvider(); + std::vector> execution_providers; if (!use_float16) { OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain); @@ -95,10 +97,14 @@ static void RunOneTest( if (cpu_ep != nullptr) { execution_providers.push_back(DefaultCpuExecutionProvider()); } + if (webgpu_ep != nullptr) { + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + } test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } else if (HasCudaEnvironment(530 /*min_cuda_architecture*/) || dml_ep != nullptr || - rocm_ep != nullptr) { + rocm_ep != nullptr || + webgpu_ep != nullptr) { OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain); test.AddInput("input", input_dims, ToFloat16(input_data)); test.AddInput("skip", skip_dims, ToFloat16(skip_data)); @@ -132,7 +138,9 @@ static void RunOneTest( ToFloat16(sum_output_data)); } - if (dml_ep != nullptr) { + if (webgpu_ep != nullptr) { + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + } else if (dml_ep != nullptr) { execution_providers.push_back(DefaultDmlExecutionProvider()); } else if (rocm_ep != nullptr) { execution_providers.push_back(DefaultRocmExecutionProvider()); From 6d653801c87e39f6d9cd84ce4b66e6e3c12bec57 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Fri, 27 Sep 2024 17:17:41 -0700 Subject: [PATCH 2/7] make lint happy --- onnxruntime/contrib_ops/webgpu/skip_layer_norm.h | 7 +++---- onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc | 3 +-- onnxruntime/core/providers/webgpu/nn/layer_norm.cc | 3 +-- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/skip_layer_norm.h b/onnxruntime/contrib_ops/webgpu/skip_layer_norm.h index aabe106a2fbb1..a9a227adb5130 100644 --- a/onnxruntime/contrib_ops/webgpu/skip_layer_norm.h +++ b/onnxruntime/contrib_ops/webgpu/skip_layer_norm.h @@ -13,7 +13,6 @@ namespace webgpu { using namespace onnxruntime::webgpu; using onnxruntime::webgpu::ComputeContext; - class SkipLayerNormProgram final : public Program { public: SkipLayerNormProgram(bool hasBeta, bool hasBias, float epsilon, int hiddenSize, bool isFP16, bool simplified) : Program{"SkipLayerNorm"} { @@ -29,9 +28,9 @@ class SkipLayerNormProgram final : public Program { Status GenerateShaderCode(ShaderHelper& sh) const override; WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( - {"components", ProgramUniformVariableDataType::Uint32}, - {"hidden_size", ProgramUniformVariableDataType::Uint32}, - {"epsilon", ProgramUniformVariableDataType::Float32}); + {"components", ProgramUniformVariableDataType::Uint32}, + {"hidden_size", ProgramUniformVariableDataType::Uint32}, + {"epsilon", ProgramUniformVariableDataType::Float32}); private: bool hasBeta_; diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 5cb73b8f2735e..31d7432c54231 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -52,8 +52,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo - }; + BuildKernelCreateInfo}; for (auto& function_table_entry : function_table) { KernelCreateInfo info = function_table_entry(); diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc index 5c3a1177146f3..526655910d1df 100644 --- a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc @@ -149,14 +149,13 @@ Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContex return context.RunProgram(program); } - ONNX_OPERATOR_KERNEL_EX( LayerNormalization, kOnnxDomain, 17, kWebGpuExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", WebGpuSupportedFloatTypes()), + .TypeConstraint("T", WebGpuSupportedFloatTypes()), LayerNorm); ONNX_OPERATOR_KERNEL_EX( From ff6ef8ee7e8162f2b20211a09f6b46f65960c354 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Mon, 30 Sep 2024 09:09:17 -0700 Subject: [PATCH 3/7] track changes in main branch --- onnxruntime/contrib_ops/webgpu/skip_layer_norm.cc | 12 +++++------- onnxruntime/core/providers/webgpu/nn/layer_norm.cc | 4 ++-- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/webgpu/skip_layer_norm.cc index 10642da7e8763..7c9757b4d48bd 100644 --- a/onnxruntime/contrib_ops/webgpu/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/webgpu/skip_layer_norm.cc @@ -74,12 +74,10 @@ Status SkipLayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { std::string beta = (hasBeta_) ? " + beta[offset1d + i] " : ""; std::string element_type = (isFP16_) ? "f16;\n" : "f32;\n"; - shader.AppendImplementation( - "alias element_t = " + element_type + - "var sum_shared : array<" + vecDataType("f32", components) + - ",workgroup_size_x>;\n" - "var sum_squared_shared : array<" + - vecDataType("f32", components) + ",workgroup_size_x>;\n"); + shader.AdditionalImplementation() + << "alias element_t = " << element_type + << "var sum_shared : array<" << vecDataType("f32", components) << ",workgroup_size_x>;\n" + << "var sum_squared_shared : array<" << vecDataType("f32", components) << ",workgroup_size_x>;\n"; std::stringstream ss; ss << "let ix = local_idx;\n" @@ -118,7 +116,7 @@ Status SkipLayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { << " output[offset + i] = (output[offset + i] " << simpl2 << ") * element_t(inv_std_dev) * gamma[offset1d + i]" << beta << ";\n" << "};\n"; - shader.SetMainFunctionBody(ss.str()); + shader.MainFunctionBody() << ss.str(); return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc index 526655910d1df..608caafa61f08 100644 --- a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc @@ -67,7 +67,7 @@ Status LayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { std::string fillvec = fillVar("f32", components, "0"); std::string element_type = (isFP16_) ? "f16;\n" : "f32;\n"; - shader.AppendImplementation("alias element_t = " + element_type); + shader.AdditionalImplementation() << "alias element_t = " << element_type; std::stringstream ss; ss << "let offset = global_idx * uniforms.norm_size_vectorized;\n" @@ -86,7 +86,7 @@ Status LayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { << " output[j + offset] = x_value_t((f32input " << simpl2 << ") * inv_std_dev * f32scale)" << bias << ";\n" << "}\n"; - shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.norm_count"), ss.str()); + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.norm_count") << ss.str(); return Status::OK(); } From fa1ec0673da8587228c49b81d8adac379261fecf Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 30 Sep 2024 16:55:14 -0700 Subject: [PATCH 4/7] part of the fix --- .../contrib_ops/webgpu/skip_layer_norm.cc | 6 +-- .../contrib_ops/webgpu/skip_layer_norm.h | 12 +++--- .../core/providers/webgpu/nn/layer_norm.cc | 39 ++++++++++--------- .../core/providers/webgpu/nn/layer_norm.h | 32 ++++++++------- .../test/contrib_ops/layer_norm_test.cc | 2 + .../providers/compare_provider_test_utils.cc | 2 + 6 files changed, 50 insertions(+), 43 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/webgpu/skip_layer_norm.cc index 7c9757b4d48bd..8e961eba47a2e 100644 --- a/onnxruntime/contrib_ops/webgpu/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/webgpu/skip_layer_norm.cc @@ -72,7 +72,7 @@ Status SkipLayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { std::string simpl2 = (simplified_) ? "" : "- element_t(mean) "; std::string fillvec = fillVar("f32", components, "0"); std::string beta = (hasBeta_) ? " + beta[offset1d + i] " : ""; - std::string element_type = (isFP16_) ? "f16;\n" : "f32;\n"; + std::string element_type = (is_fp16_) ? "f16;\n" : "f32;\n"; shader.AdditionalImplementation() << "alias element_t = " << element_type @@ -139,7 +139,7 @@ Status SkipLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeCo } const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; - const int hidden_size = x_shape[x_shape.NumDimensions() - 1]; + const uint32_t hidden_size = SafeInt(x_shape[x_shape.NumDimensions() - 1]); const int components = getMaxComponents(hidden_size); SkipLayerNormProgram program{beta != nullptr, bias != nullptr, epsilon_, hidden_size, is_fp16, simplified}; @@ -149,7 +149,7 @@ Status SkipLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeCo .AddInputs({{skip, ProgramTensorMetadataDependency::Type, components}}) .AddInputs({{gamma, ProgramTensorMetadataDependency::Type, components}}) .AddOutputs({{output, ProgramTensorMetadataDependency::None, components}}) - .SetDispatchGroupSize(ceil(data_size / hidden_size)) + .SetDispatchGroupSize(SafeInt(ceil(1.0 * data_size / hidden_size))) .AddUniformVariables({ {static_cast(components)}, }) diff --git a/onnxruntime/contrib_ops/webgpu/skip_layer_norm.h b/onnxruntime/contrib_ops/webgpu/skip_layer_norm.h index a9a227adb5130..d9ef732e28af7 100644 --- a/onnxruntime/contrib_ops/webgpu/skip_layer_norm.h +++ b/onnxruntime/contrib_ops/webgpu/skip_layer_norm.h @@ -15,14 +15,14 @@ using onnxruntime::webgpu::ComputeContext; class SkipLayerNormProgram final : public Program { public: - SkipLayerNormProgram(bool hasBeta, bool hasBias, float epsilon, int hiddenSize, bool isFP16, bool simplified) : Program{"SkipLayerNorm"} { + SkipLayerNormProgram(bool hasBeta, bool hasBias, float epsilon, uint32_t hidden_size, bool is_fp16, bool simplified) : Program{"SkipLayerNorm"} { epsilon_ = epsilon; hasBeta_ = hasBeta; hasBias_ = hasBias; epsilon_ = epsilon; - hiddenSize_ = hiddenSize; + hidden_size_ = hidden_size; simplified_ = simplified; - isFP16_ = isFP16; + is_fp16_ = is_fp16; } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -36,8 +36,8 @@ class SkipLayerNormProgram final : public Program { bool hasBeta_; bool hasBias_; float epsilon_; - int hiddenSize_; - bool isFP16_; + uint32_t hidden_size_; + bool is_fp16_; bool simplified_; }; @@ -45,7 +45,7 @@ template class SkipLayerNorm final : public WebGpuKernel { public: SkipLayerNorm(const OpKernelInfo& info) : WebGpuKernel(info) { - info.GetAttrOrDefault("epsilon", &epsilon_, 1e-05); + info.GetAttrOrDefault("epsilon", &epsilon_, 1e-05f); } Status ComputeInternal(ComputeContext& context) const override; diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc index 608caafa61f08..d21067897c06b 100644 --- a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc @@ -9,7 +9,7 @@ namespace onnxruntime { namespace webgpu { -static uint32_t getMaxComponents(int size) { +static int GetMaxComponents(int64_t size) { if (size % 4 == 0) { return 4; } else if (size % 2 == 0) { @@ -18,28 +18,29 @@ static uint32_t getMaxComponents(int size) { return 1; } -static int normalizeAxis(int axis, int tensorRank) { - if (axis < -tensorRank && axis >= tensorRank) { +static size_t NormalizeAxis(int64_t axis, size_t tensor_rank) { + int64_t rank = static_cast(tensor_rank); + if (axis < -rank && axis >= rank) { ORT_THROW("invalid axis: ", axis); } - return axis < 0 ? axis + tensorRank : axis; + return SafeInt(axis < 0 ? axis + rank : axis); } -static std::string fillVar(std::string dataType, int components, std::string value) { +static std::string FillVar(std::string dataType, int components, std::string value) { if (components == 1) { return dataType + "(" + value + ")"; } return "vec" + std::to_string(components) + "<" + dataType + ">(" + value + ")"; } -static std::string castToF32(int components, std::string value) { +static std::string CastToF32(int components, std::string value) { if (components == 1) { return "f32(" + value + ")"; } return "vec" + std::to_string(components) + "(" + value + ")"; }; -static std::string sumVector(std::string x, int components) { +static std::string SumVector(std::string x, int components) { switch (components) { case 1: return x; @@ -64,8 +65,8 @@ Status LayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { std::string bias = (has_bias_) ? " + bias[j] " : ""; std::string simpl1 = (simplified_) ? "" : "- mean * mean "; std::string simpl2 = (simplified_) ? "" : "- mean "; - std::string fillvec = fillVar("f32", components, "0"); - std::string element_type = (isFP16_) ? "f16;\n" : "f32;\n"; + std::string fillvec = FillVar("f32", components, "0"); + std::string element_type = (is_fp16_) ? "f16;\n" : "f32;\n"; shader.AdditionalImplementation() << "alias element_t = " << element_type; @@ -74,15 +75,15 @@ Status LayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { << "var mean_vector = " << fillvec << ";\n" << "var mean_square_vector = " << fillvec << ";\n" << "for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) {\n" - << " let value = " << castToF32(components, "x[h + offset]") << ";\n" + << " let value = " << CastToF32(components, "x[h + offset]") << ";\n" << " mean_vector += value;\n" << " mean_square_vector += value * value;\n" << "}\n" - << "let mean = " << sumVector("mean_vector", components) << " / f32(uniforms.norm_size);\n" - << "let inv_std_dev = inverseSqrt(" << sumVector("mean_square_vector", components) << " / f32(uniforms.norm_size) " << simpl1 << "+ uniforms.epsilon);\n" + << "let mean = " << SumVector("mean_vector", components) << " / f32(uniforms.norm_size);\n" + << "let inv_std_dev = inverseSqrt(" << SumVector("mean_square_vector", components) << " / f32(uniforms.norm_size) " << simpl1 << "+ uniforms.epsilon);\n" << "for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) {\n" - << " let f32input = " << castToF32(components, "x[j + offset]") << ";\n" - << " let f32scale = " << castToF32(components, "scale[j]") << ";\n" + << " let f32input = " << CastToF32(components, "x[j + offset]") << ";\n" + << " let f32scale = " << CastToF32(components, "scale[j]") << ";\n" << " output[j + offset] = x_value_t((f32input " << simpl2 << ") * inv_std_dev * f32scale)" << bias << ";\n" << "}\n"; @@ -107,11 +108,11 @@ Status LayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeContex const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; - const int axis = normalizeAxis(axis_, x_shape.NumDimensions()); - const int norm_count = x_shape.SizeToDimension(axis); - const int norm_size = x_shape.SizeFromDimension(axis); - const int components = getMaxComponents(norm_size); - const int norm_size_vectorized = (norm_size + components - 1) / components; + const size_t axis = NormalizeAxis(axis_, x_shape.NumDimensions()); + const uint32_t norm_count = SafeInt(x_shape.SizeToDimension(axis)); + const int64_t norm_size = x_shape.SizeFromDimension(axis); + const int components = GetMaxComponents(norm_size); + const uint32_t norm_size_vectorized = SafeInt((norm_size + components - 1) / components); const auto scale_size = scale->Shape().Size(); const auto bias_size = (bias) ? bias->Shape().Size() : 0; diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.h b/onnxruntime/core/providers/webgpu/nn/layer_norm.h index b7cd2c6478901..e7014a1b80e20 100644 --- a/onnxruntime/core/providers/webgpu/nn/layer_norm.h +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.h @@ -9,20 +9,22 @@ namespace onnxruntime { namespace webgpu { -using namespace onnxruntime::webgpu; -using onnxruntime::webgpu::ComputeContext; - class LayerNormProgram final : public Program { public: - LayerNormProgram(int64_t axis, float epsilon, int64_t stash_type, bool has_bias, size_t x_size, bool isFP16, bool simplified) : Program{"LayerNorm"} { - axis_ = axis; - epsilon_ = epsilon; - stash_type_ = stash_type; - has_bias_ = has_bias; - x_size_ = x_size; - isFP16_ = isFP16; - simplified_ = simplified; - } + LayerNormProgram(int64_t axis, + float epsilon, + int64_t stash_type, + bool has_bias, + size_t x_size, + bool is_fp16, + bool simplified) : Program{"LayerNorm"}, + axis_{axis}, + epsilon_{epsilon}, + stash_type_{stash_type}, + has_bias_{has_bias}, + x_size_{x_size}, + is_fp16_{is_fp16}, + simplified_{simplified} {} Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -37,8 +39,8 @@ class LayerNormProgram final : public Program { float epsilon_; int64_t stash_type_; bool has_bias_; - int x_size_; - bool isFP16_; + size_t x_size_; + bool is_fp16_; bool simplified_; }; @@ -47,7 +49,7 @@ class LayerNorm final : public WebGpuKernel { public: LayerNorm(const OpKernelInfo& info) : WebGpuKernel(info) { info.GetAttrOrDefault("axis", &axis_, -1); - info.GetAttrOrDefault("epsilon", &epsilon_, 1e-05); + info.GetAttrOrDefault("epsilon", &epsilon_, 1e-05f); info.GetAttrOrDefault("stash_type", &stash_type_, 1); } diff --git a/onnxruntime/test/contrib_ops/layer_norm_test.cc b/onnxruntime/test/contrib_ops/layer_norm_test.cc index 218c48a37de9b..155db602b2f6f 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_test.cc @@ -84,6 +84,8 @@ static void TestLayerNorm(const std::vector& x_dims, test.CompareWithCPU(kRocmExecutionProvider); #elif USE_DML test.CompareWithCPU(kDmlExecutionProvider); +#elif USE_WEBGPU + test.CompareWithCPU(kWebGpuExecutionProvider); #endif } diff --git a/onnxruntime/test/providers/compare_provider_test_utils.cc b/onnxruntime/test/providers/compare_provider_test_utils.cc index 3ef74259e27b6..386a5656d8a01 100644 --- a/onnxruntime/test/providers/compare_provider_test_utils.cc +++ b/onnxruntime/test/providers/compare_provider_test_utils.cc @@ -36,6 +36,8 @@ std::unique_ptr GetExecutionProvider(const std::string& prov execution_provider = DefaultRocmExecutionProvider(); else if (provider_type == onnxruntime::kDmlExecutionProvider) execution_provider = DefaultDmlExecutionProvider(); + else if (provider_type == onnxruntime::kWebGpuExecutionProvider) + execution_provider = DefaultWebGpuExecutionProvider(); // skip if execution provider is disabled if (execution_provider == nullptr) { return nullptr; From b1463375d113e81dfa04d55d2d4b0735f9da9ef2 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 30 Sep 2024 17:30:07 -0700 Subject: [PATCH 5/7] fix test --- onnxruntime/test/contrib_ops/layer_norm_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/contrib_ops/layer_norm_test.cc b/onnxruntime/test/contrib_ops/layer_norm_test.cc index 155db602b2f6f..46082e1b0cd31 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_test.cc @@ -65,8 +65,8 @@ static void TestLayerNorm(const std::vector& x_dims, std::vector Y_data = FillZeros(n_x_m_dims); test.AddOutput("output", n_x_m_dims, Y_data); -#ifndef USE_DML - // DML doesn't support more than one output for these ops yet +#if !defined(USE_DML) && !defined(USE_WEBGPU) + // DML and WebGPU don't support more than one output for these ops yet const std::vector& stats_dims = keep_dims ? n_and_ones_dims : n_dims; std::vector mean_data = FillZeros(stats_dims); std::vector var_data = FillZeros(stats_dims); From e2809a677fdb5a331fc722df59d10da335fefa27 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Mon, 30 Sep 2024 17:33:33 -0700 Subject: [PATCH 6/7] fix mixed T test cases for webgpu --- onnxruntime/test/contrib_ops/layer_norm_op_test.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc index 655c4951f262d..7fbaf91824240 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc @@ -120,7 +120,7 @@ TEST(LayerNormTest, LayerNorm_Scale_Float16Input) { // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, - kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); + kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider}); } TEST(LayerNormTest, LayerNorm_Scale_Float16ScaleOutput) { @@ -134,7 +134,7 @@ TEST(LayerNormTest, LayerNorm_Scale_Float16ScaleOutput) { // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, - kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); + kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider}); } TEST(LayerNormTest, LayerNorm_Scale_Float16InputScaleOutput) { @@ -178,7 +178,7 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16Input) { // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kQnnExecutionProvider, - kOpenVINOExecutionProvider, kNnapiExecutionProvider, kCoreMLExecutionProvider}); + kOpenVINOExecutionProvider, kNnapiExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider}); } TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16ScaleBiasOutput) { @@ -193,7 +193,7 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16ScaleBiasOutput) { // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, - kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); + kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider}); } TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16InputScaleBiasOutput) { @@ -208,7 +208,7 @@ TEST(LayerNormTest, LayerNorm_Scale_Bias_Float16InputScaleBiasOutput) { // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, - kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); + kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider, kWebGpuExecutionProvider}); } // LayerNormalization became an ONNX operator in opset 17. It uses the same implementation so this is a sanity check. From e46af18b3bd7406a902b21371fe315612f794931 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 30 Sep 2024 18:31:03 -0700 Subject: [PATCH 7/7] revise string operation --- .../webgpu/{ => bert}/layer_norm.cc | 0 .../webgpu/{ => bert}/skip_layer_norm.cc | 115 +++++++----------- .../webgpu/{ => bert}/skip_layer_norm.h | 0 .../core/providers/webgpu/nn/layer_norm.cc | 64 ++++------ 4 files changed, 70 insertions(+), 109 deletions(-) rename onnxruntime/contrib_ops/webgpu/{ => bert}/layer_norm.cc (100%) rename onnxruntime/contrib_ops/webgpu/{ => bert}/skip_layer_norm.cc (54%) rename onnxruntime/contrib_ops/webgpu/{ => bert}/skip_layer_norm.h (100%) diff --git a/onnxruntime/contrib_ops/webgpu/layer_norm.cc b/onnxruntime/contrib_ops/webgpu/bert/layer_norm.cc similarity index 100% rename from onnxruntime/contrib_ops/webgpu/layer_norm.cc rename to onnxruntime/contrib_ops/webgpu/bert/layer_norm.cc diff --git a/onnxruntime/contrib_ops/webgpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc similarity index 54% rename from onnxruntime/contrib_ops/webgpu/skip_layer_norm.cc rename to onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc index 8e961eba47a2e..fb955b45f6948 100644 --- a/onnxruntime/contrib_ops/webgpu/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc @@ -3,14 +3,14 @@ #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" -#include "contrib_ops/webgpu/skip_layer_norm.h" #include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "contrib_ops/webgpu/bert/skip_layer_norm.h" namespace onnxruntime { namespace contrib { namespace webgpu { -static uint32_t getMaxComponents(int size) { +static uint32_t GetMaxComponents(int size) { if (size % 4 == 0) { return 4; } else if (size % 2 == 0) { @@ -19,14 +19,7 @@ static uint32_t getMaxComponents(int size) { return 1; } -static std::string fillVar(std::string dataType, int components, std::string value) { - if (components == 1) { - return dataType + "(" + value + ")"; - } - return "vec" + std::to_string(components) + "<" + dataType + ">(" + value + ")"; -} - -static std::string sumVector(std::string x, int components) { +static std::string SumVector(std::string x, int components) { switch (components) { case 1: return x; @@ -39,20 +32,6 @@ static std::string sumVector(std::string x, int components) { } } -static std::string castToF32(int components, std::string value) { - if (components == 1) { - return "f32(" + value + ")"; - } - return "vec" + std::to_string(components) + "(" + value + ")"; -}; - -static std::string vecDataType(std::string datatype, int components) { - if (components == 1) { - return datatype; - } - return "vec" + std::to_string(components) + "<" + datatype + ">"; -}; - Status SkipLayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); shader.AddInput("skip", ShaderUsage::UseUniform); @@ -70,53 +49,51 @@ Status SkipLayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { std::string bias = (hasBias_) ? " + bias[offset1d + i] " : ""; std::string simpl1 = (simplified_) ? "" : "- mean * mean "; std::string simpl2 = (simplified_) ? "" : "- element_t(mean) "; - std::string fillvec = fillVar("f32", components, "0"); std::string beta = (hasBeta_) ? " + beta[offset1d + i] " : ""; - std::string element_type = (is_fp16_) ? "f16;\n" : "f32;\n"; shader.AdditionalImplementation() - << "alias element_t = " << element_type - << "var sum_shared : array<" << vecDataType("f32", components) << ",workgroup_size_x>;\n" - << "var sum_squared_shared : array<" << vecDataType("f32", components) << ",workgroup_size_x>;\n"; - - std::stringstream ss; - ss << "let ix = local_idx;\n" - << "let iy = global_idx / workgroup_size_x;\n" - << "let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components;\n" - << "var stride = hidden_size_vectorized / workgroup_size_x;\n" - << "let offset = ix * stride + iy * hidden_size_vectorized;\n" - << "let offset1d = stride * ix;\n" - << "if (ix == workgroup_size_x - 1) {\n" - << " stride = hidden_size_vectorized - stride * ix;\n" - << "}\n" - << "for (var i: u32 = 0; i < stride; i++) {\n" - << " let skip_value = skip[offset + i];\n" - << " let input_value = x[offset + i];\n" - << " let value = input_value + skip_value" << bias << ";\n" - << " output[offset + i] = value;\n" - << " let f32_value = " << castToF32(components, "value") << ";\n" - << " sum_shared[ix] += f32_value;\n" - << " sum_squared_shared[ix] += f32_value * f32_value;\n" - << "}\n" - << "workgroupBarrier();\n" - << "var reduce_size : u32 = workgroup_size_x;\n" - << "for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n" - << " reduce_size = curr_size + (reduce_size & 1);\n" - << " if (ix < curr_size) {\n" - << " sum_shared[ix] += sum_shared[ix + reduce_size];\n" - << " sum_squared_shared[ix] += sum_squared_shared[ix + reduce_size];\n" - << " }\n" - << " workgroupBarrier();\n" - << "}\n" - << "let sum = sum_shared[0];\n" - << "let square_sum = sum_squared_shared[0];\n" - << "let mean = " << sumVector("sum", components) << " / f32(uniforms.hidden_size);\n" - << "let inv_std_dev = inverseSqrt(" << sumVector("square_sum", components) << " / f32(uniforms.hidden_size) " << simpl1 << "+ uniforms.epsilon);\n" - << "for (var i: u32 = 0; i < stride; i++) {\n" - << " output[offset + i] = (output[offset + i] " << simpl2 << ") * element_t(inv_std_dev) * gamma[offset1d + i]" << beta << ";\n" - << "};\n"; - - shader.MainFunctionBody() << ss.str(); + << "alias element_t = " << (is_fp16_ ? "f16;\n" : "f32;\n") + << "alias f32_val_t = " << (components == 4 ? "vec4" : (components == 2 ? "vec2" : "f32")) << ";\n" + << "var sum_shared : array;\n" + << "var sum_squared_shared : array;\n"; + + shader.MainFunctionBody() + << "let ix = local_idx;\n" + << "let iy = global_idx / workgroup_size_x;\n" + << "let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components;\n" + << "var stride = hidden_size_vectorized / workgroup_size_x;\n" + << "let offset = ix * stride + iy * hidden_size_vectorized;\n" + << "let offset1d = stride * ix;\n" + << "if (ix == workgroup_size_x - 1) {\n" + << " stride = hidden_size_vectorized - stride * ix;\n" + << "}\n" + << "for (var i: u32 = 0; i < stride; i++) {\n" + << " let skip_value = skip[offset + i];\n" + << " let input_value = x[offset + i];\n" + << " let value = input_value + skip_value" << bias << ";\n" + << " output[offset + i] = value;\n" + << " let f32_value = f32_val_t(value);\n" + << " sum_shared[ix] += f32_value;\n" + << " sum_squared_shared[ix] += f32_value * f32_value;\n" + << "}\n" + << "workgroupBarrier();\n" + << "var reduce_size : u32 = workgroup_size_x;\n" + << "for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n" + << " reduce_size = curr_size + (reduce_size & 1);\n" + << " if (ix < curr_size) {\n" + << " sum_shared[ix] += sum_shared[ix + reduce_size];\n" + << " sum_squared_shared[ix] += sum_squared_shared[ix + reduce_size];\n" + << " }\n" + << " workgroupBarrier();\n" + << "}\n" + << "let sum = sum_shared[0];\n" + << "let square_sum = sum_squared_shared[0];\n" + << "let mean = " << SumVector("sum", components) << " / f32(uniforms.hidden_size);\n" + << "let inv_std_dev = inverseSqrt(" << SumVector("square_sum", components) << " / f32(uniforms.hidden_size) " << simpl1 << "+ uniforms.epsilon);\n" + << "for (var i: u32 = 0; i < stride; i++) {\n" + << " output[offset + i] = (output[offset + i] " << simpl2 << ") * element_t(inv_std_dev) * gamma[offset1d + i]" << beta << ";\n" + << "};\n"; + return Status::OK(); } @@ -140,7 +117,7 @@ Status SkipLayerNorm::ComputeInternal(onnxruntime::webgpu::ComputeCo const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; const uint32_t hidden_size = SafeInt(x_shape[x_shape.NumDimensions() - 1]); - const int components = getMaxComponents(hidden_size); + const int components = GetMaxComponents(hidden_size); SkipLayerNormProgram program{beta != nullptr, bias != nullptr, epsilon_, hidden_size, is_fp16, simplified}; program diff --git a/onnxruntime/contrib_ops/webgpu/skip_layer_norm.h b/onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h similarity index 100% rename from onnxruntime/contrib_ops/webgpu/skip_layer_norm.h rename to onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h diff --git a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc index d21067897c06b..5e4d5b7ad10e6 100644 --- a/onnxruntime/core/providers/webgpu/nn/layer_norm.cc +++ b/onnxruntime/core/providers/webgpu/nn/layer_norm.cc @@ -26,20 +26,6 @@ static size_t NormalizeAxis(int64_t axis, size_t tensor_rank) { return SafeInt(axis < 0 ? axis + rank : axis); } -static std::string FillVar(std::string dataType, int components, std::string value) { - if (components == 1) { - return dataType + "(" + value + ")"; - } - return "vec" + std::to_string(components) + "<" + dataType + ">(" + value + ")"; -} - -static std::string CastToF32(int components, std::string value) { - if (components == 1) { - return "f32(" + value + ")"; - } - return "vec" + std::to_string(components) + "(" + value + ")"; -}; - static std::string SumVector(std::string x, int components) { switch (components) { case 1: @@ -62,32 +48,30 @@ Status LayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AddOutput("output", ShaderUsage::UseUniform); int components = x.NumComponents(); - std::string bias = (has_bias_) ? " + bias[j] " : ""; - std::string simpl1 = (simplified_) ? "" : "- mean * mean "; - std::string simpl2 = (simplified_) ? "" : "- mean "; - std::string fillvec = FillVar("f32", components, "0"); - std::string element_type = (is_fp16_) ? "f16;\n" : "f32;\n"; - - shader.AdditionalImplementation() << "alias element_t = " << element_type; - - std::stringstream ss; - ss << "let offset = global_idx * uniforms.norm_size_vectorized;\n" - << "var mean_vector = " << fillvec << ";\n" - << "var mean_square_vector = " << fillvec << ";\n" - << "for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) {\n" - << " let value = " << CastToF32(components, "x[h + offset]") << ";\n" - << " mean_vector += value;\n" - << " mean_square_vector += value * value;\n" - << "}\n" - << "let mean = " << SumVector("mean_vector", components) << " / f32(uniforms.norm_size);\n" - << "let inv_std_dev = inverseSqrt(" << SumVector("mean_square_vector", components) << " / f32(uniforms.norm_size) " << simpl1 << "+ uniforms.epsilon);\n" - << "for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) {\n" - << " let f32input = " << CastToF32(components, "x[j + offset]") << ";\n" - << " let f32scale = " << CastToF32(components, "scale[j]") << ";\n" - << " output[j + offset] = x_value_t((f32input " << simpl2 << ") * inv_std_dev * f32scale)" << bias << ";\n" - << "}\n"; - - shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.norm_count") << ss.str(); + std::string bias = (has_bias_) ? " + bias[j]" : ""; + std::string simpl1 = (simplified_) ? "" : " - mean * mean"; + std::string simpl2 = (simplified_) ? "" : " - mean"; + + shader.AdditionalImplementation() << "alias element_t = " << (is_fp16_ ? "f16;\n" : "f32;\n") + << "alias f32_val_t = " << (components == 4 ? "vec4" : (components == 2 ? "vec2" : "f32")) << ";\n"; + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.norm_count") + << "let offset = global_idx * uniforms.norm_size_vectorized;\n" + << "var mean_vector = f32_val_t(0);\n" + << "var mean_square_vector = f32_val_t(0);\n" + << "for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) {\n" + << " let value = f32_val_t(x[h + offset]);\n" + << " mean_vector += value;\n" + << " mean_square_vector += value * value;\n" + << "}\n" + << "let mean = " << SumVector("mean_vector", components) << " / f32(uniforms.norm_size);\n" + << "let inv_std_dev = inverseSqrt(" << SumVector("mean_square_vector", components) << " / f32(uniforms.norm_size)" << simpl1 << " + uniforms.epsilon);\n" + << "for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) {\n" + << " let f32input = f32_val_t(x[j + offset]);\n" + << " let f32scale = f32_val_t(scale[j]);\n" + << " output[j + offset] = x_value_t((f32input" << simpl2 << ") * inv_std_dev * f32scale)" << bias << ";\n" + << "}\n"; + return Status::OK(); }