Skip to content

Commit

Permalink
Add FastGelu op
Browse files Browse the repository at this point in the history
  • Loading branch information
qjia7 committed Sep 8, 2024
1 parent a3244ae commit 5bb4e9e
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 25 deletions.
78 changes: 78 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "fast_gelu.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

ONNX_OPERATOR_KERNEL_EX(
FastGelu,
kMSDomain,
1,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedFloatTypes()),
FastGelu);

Status FastGeluProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto& input = shader.AddInput("input",
ToProgramVariableDataType(Inputs()[0].tensor->GetElementType(), 4),
ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias);

const auto& bias = Inputs().size() > 1 ? shader.AddInput("bias",
ToProgramVariableDataType(Inputs()[1].tensor->GetElementType(), bias_components_),
ShaderVariable::UseUniform | ShaderVariable::UseShapeAndStride)
: input;

const auto& output = shader.AddOutput("output",
ToProgramVariableDataType(Outputs()[0].tensor->GetElementType(), 4),
ShaderVariable::UseUniform);
const std::string& get_bias = bias_components_ == 1 ? "let x_offset = global_idx * 4;\nlet bias = input_value_t(" + bias.GetByOffset("x_offset % uniforms.bias_shape") + ", " + bias.GetByOffset("(x_offset + 1) % uniforms.bias_shape") + ", " + bias.GetByOffset("(x_offset + 2) % uniforms.bias_shape") + ", " + bias.GetByOffset("(x_offset + 3) % uniforms.bias_shape") + ") " : " let bias = " + bias.GetByOffset(" global_idx % uniforms.bias_shape ");
const std::string& add_bias = Inputs().size() > 1 ? get_bias + ";\n x += bias;\n" : "";
shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"),
"var x = ", input.GetByOffset("global_idx"), ";\n", add_bias,
"let y = x * (0.5 + 0.5 * tanh(x * (0.035677408136300125 * x * x + 0.7978845608028654)));\n",
output.SetByOffset("global_idx", "y"));

return Status::OK();
}

Status FastGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
const auto* input = context.Input(0);
const auto* bias = context.Input(1);
auto* output = context.Output(0, input->Shape());

uint32_t data_size = SafeInt<uint32_t>(output->Shape().Size());
if (data_size == 0) {
return Status::OK();
}

const auto vec_size = (data_size + 3) / 4;
uint32_t bias_size = nullptr == bias ? 0 : SafeInt<uint32_t>(bias->Shape().Size());
int bias_components = 1;
if (bias != nullptr && bias_size % 4 == 0) {
bias_components = 4;
bias_size = bias_size / 4;
}
FastGeluProgram program{"FastGelu", bias_components};
if (nullptr == bias) {
program.Inputs({{input, ProgramTensorMetadataDependency::Type, {vec_size}}});
} else {
program.Inputs({{input, ProgramTensorMetadataDependency::Type, {{vec_size}}}, {bias, ProgramTensorMetadataDependency::TypeAndRank, {bias_size}}});
}
program
.Outputs({{output, ProgramTensorMetadataDependency::None, {vec_size}}})
.DispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.UniformVariables({{vec_size}})
.CacheHint(std::to_string(bias_components));
return context.RunProgram(program);
}

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
37 changes: 37 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/fast_gelu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/webgpu_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

using namespace onnxruntime::webgpu;

class FastGeluProgram final : public Program<FastGeluProgram> {
public:
FastGeluProgram(const std::string& kernel_name, int bias_components) : Program{kernel_name}, bias_components_{bias_components} {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32});

private:
int bias_components_;
};

class FastGelu final : public WebGpuKernel {
public:
FastGelu(const OpKernelInfo& info) : WebGpuKernel(info) {}

Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override;
};

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
42 changes: 21 additions & 21 deletions onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,34 +26,34 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Sk
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipSimplifiedLayerNormalization);

// template <>
// KernelCreateInfo BuildKernelCreateInfo<void>() {
// KernelCreateInfo info;
// return info;
// }
template <>
KernelCreateInfo BuildKernelCreateInfo<void>() {
KernelCreateInfo info;
return info;
}

Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Attention)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
// // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1,
// SkipLayerNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1,
// SimplifiedLayerNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1,
// SkipSimplifiedLayerNormalization)>
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
// // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1,
// SkipLayerNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1,
// SimplifiedLayerNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1,
// SkipSimplifiedLayerNormalization)>
};

for (auto& function_table_entry : function_table) {
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@

#pragma once

#include "core/framework/op_kernel.h"
#include "core/framework/kernel_registry.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

// forward declaration for this EP's namespace.
template <typename T>
KernelCreateInfo BuildKernelCreateInfo();

Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry);

} // namespace webgpu
Expand Down
8 changes: 5 additions & 3 deletions onnxruntime/test/contrib_ops/fastgelu_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ const std::vector<float> GetExpectedResult(const std::vector<float>& input_data,
return ComputeGelu(add_bias_data);
}

#if defined(USE_CUDA) || defined(USE_ROCM)
#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU)
static void RunFastGeluGpuTest(const std::vector<float>& input_data, const std::vector<float>& bias_data,
const std::vector<float>& output_data, const std::vector<int64_t>& input_dims,
const std::vector<int64_t>& bias_dims, const std::vector<int64_t>& output_dims,
Expand Down Expand Up @@ -75,6 +75,8 @@ static void RunFastGeluGpuTest(const std::vector<float>& input_data, const std::
execution_providers.push_back(DefaultCudaExecutionProvider());
#elif USE_ROCM
execution_providers.push_back(DefaultRocmExecutionProvider());
#elif USE_WEBGPU
execution_providers.push_back(DefaultWebGpuExecutionProvider());
#endif
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
Expand Down Expand Up @@ -142,7 +144,7 @@ static void RunFastGeluTest(
std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
std::vector<int64_t> bias_dims = {hidden_size};
std::vector<int64_t> output_dims = input_dims;
#if defined(USE_CUDA) || defined(USE_ROCM)
#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU)
RunFastGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, has_bias);
#endif
RunFastGeluCpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, has_bias);
Expand Down Expand Up @@ -246,7 +248,7 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat32) {
}

// CUDA and ROCm only for Float16 and BFloat16 type.
#if defined(USE_CUDA) || defined(USE_ROCM)
#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU)
TEST(FastGeluTest, FastGeluWithBiasFloat16_2) {
int batch_size = 1;
int sequence_length = 2;
Expand Down

0 comments on commit 5bb4e9e

Please sign in to comment.