From 5bb4e9e7b231cb206f964a2de11d7e0ce6afbc6c Mon Sep 17 00:00:00 2001 From: Qin Jiajia Date: Wed, 4 Sep 2024 12:36:02 +0800 Subject: [PATCH] Add FastGelu op --- .../contrib_ops/webgpu/bert/fast_gelu.cc | 78 +++++++++++++++++++ .../contrib_ops/webgpu/bert/fast_gelu.h | 37 +++++++++ .../webgpu/webgpu_contrib_kernels.cc | 42 +++++----- .../webgpu/webgpu_contrib_kernels.h | 5 +- .../test/contrib_ops/fastgelu_op_test.cc | 8 +- 5 files changed, 145 insertions(+), 25 deletions(-) create mode 100644 onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc create mode 100644 onnxruntime/contrib_ops/webgpu/bert/fast_gelu.h diff --git a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc new file mode 100644 index 0000000000000..bdfe345590447 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "fast_gelu.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + FastGelu, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + FastGelu); + +Status FastGeluProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& input = shader.AddInput("input", + ToProgramVariableDataType(Inputs()[0].tensor->GetElementType(), 4), + ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias); + + const auto& bias = Inputs().size() > 1 ? shader.AddInput("bias", + ToProgramVariableDataType(Inputs()[1].tensor->GetElementType(), bias_components_), + ShaderVariable::UseUniform | ShaderVariable::UseShapeAndStride) + : input; + + const auto& output = shader.AddOutput("output", + ToProgramVariableDataType(Outputs()[0].tensor->GetElementType(), 4), + ShaderVariable::UseUniform); + const std::string& get_bias = bias_components_ == 1 ? "let x_offset = global_idx * 4;\nlet bias = input_value_t(" + bias.GetByOffset("x_offset % uniforms.bias_shape") + ", " + bias.GetByOffset("(x_offset + 1) % uniforms.bias_shape") + ", " + bias.GetByOffset("(x_offset + 2) % uniforms.bias_shape") + ", " + bias.GetByOffset("(x_offset + 3) % uniforms.bias_shape") + ") " : " let bias = " + bias.GetByOffset(" global_idx % uniforms.bias_shape "); + const std::string& add_bias = Inputs().size() > 1 ? get_bias + ";\n x += bias;\n" : ""; + shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"), + "var x = ", input.GetByOffset("global_idx"), ";\n", add_bias, + "let y = x * (0.5 + 0.5 * tanh(x * (0.035677408136300125 * x * x + 0.7978845608028654)));\n", + output.SetByOffset("global_idx", "y")); + + return Status::OK(); +} + +Status FastGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const { + const auto* input = context.Input(0); + const auto* bias = context.Input(1); + auto* output = context.Output(0, input->Shape()); + + uint32_t data_size = SafeInt(output->Shape().Size()); + if (data_size == 0) { + return Status::OK(); + } + + const auto vec_size = (data_size + 3) / 4; + uint32_t bias_size = nullptr == bias ? 0 : SafeInt(bias->Shape().Size()); + int bias_components = 1; + if (bias != nullptr && bias_size % 4 == 0) { + bias_components = 4; + bias_size = bias_size / 4; + } + FastGeluProgram program{"FastGelu", bias_components}; + if (nullptr == bias) { + program.Inputs({{input, ProgramTensorMetadataDependency::Type, {vec_size}}}); + } else { + program.Inputs({{input, ProgramTensorMetadataDependency::Type, {{vec_size}}}, {bias, ProgramTensorMetadataDependency::TypeAndRank, {bias_size}}}); + } + program + .Outputs({{output, ProgramTensorMetadataDependency::None, {vec_size}}}) + .DispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .UniformVariables({{vec_size}}) + .CacheHint(std::to_string(bias_components)); + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.h b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.h new file mode 100644 index 0000000000000..16889c4f70c62 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/fast_gelu.h @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; + +class FastGeluProgram final : public Program { + public: + FastGeluProgram(const std::string& kernel_name, int bias_components) : Program{kernel_name}, bias_components_{bias_components} { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32}); + + private: + int bias_components_; +}; + +class FastGelu final : public WebGpuKernel { + public: + FastGelu(const OpKernelInfo& info) : WebGpuKernel(info) {} + + Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 91f51df588fca..def104b6cb108 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -26,11 +26,11 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Sk class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipSimplifiedLayerNormalization); -// template <> -// KernelCreateInfo BuildKernelCreateInfo() { -// KernelCreateInfo info; -// return info; -// } +template <> +KernelCreateInfo BuildKernelCreateInfo() { + KernelCreateInfo info; + return info; +} Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { @@ -38,22 +38,22 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo + BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h index 6cdf7382804f9..d73859de78239 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h @@ -3,13 +3,16 @@ #pragma once -#include "core/framework/op_kernel.h" #include "core/framework/kernel_registry.h" namespace onnxruntime { namespace contrib { namespace webgpu { +// forward declaration for this EP's namespace. +template +KernelCreateInfo BuildKernelCreateInfo(); + Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry); } // namespace webgpu diff --git a/onnxruntime/test/contrib_ops/fastgelu_op_test.cc b/onnxruntime/test/contrib_ops/fastgelu_op_test.cc index 5cf749dc4c97c..fd02c84ed1788 100644 --- a/onnxruntime/test/contrib_ops/fastgelu_op_test.cc +++ b/onnxruntime/test/contrib_ops/fastgelu_op_test.cc @@ -41,7 +41,7 @@ const std::vector GetExpectedResult(const std::vector& input_data, return ComputeGelu(add_bias_data); } -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) static void RunFastGeluGpuTest(const std::vector& input_data, const std::vector& bias_data, const std::vector& output_data, const std::vector& input_dims, const std::vector& bias_dims, const std::vector& output_dims, @@ -75,6 +75,8 @@ static void RunFastGeluGpuTest(const std::vector& input_data, const std:: execution_providers.push_back(DefaultCudaExecutionProvider()); #elif USE_ROCM execution_providers.push_back(DefaultRocmExecutionProvider()); +#elif USE_WEBGPU + execution_providers.push_back(DefaultWebGpuExecutionProvider()); #endif tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } @@ -142,7 +144,7 @@ static void RunFastGeluTest( std::vector input_dims = {batch_size, sequence_length, hidden_size}; std::vector bias_dims = {hidden_size}; std::vector output_dims = input_dims; -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) RunFastGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, has_bias); #endif RunFastGeluCpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, has_bias); @@ -246,7 +248,7 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat32) { } // CUDA and ROCm only for Float16 and BFloat16 type. -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU) TEST(FastGeluTest, FastGeluWithBiasFloat16_2) { int batch_size = 1; int sequence_length = 2;