Skip to content

Commit

Permalink
Add FastGelu op (#21991)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: Yulong Wang <[email protected]>
  • Loading branch information
qjia7 and fs-eire authored Sep 10, 2024
1 parent 969384d commit 6b82486
Show file tree
Hide file tree
Showing 7 changed files with 184 additions and 30 deletions.
84 changes: 84 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "fast_gelu.h"

Check warning on line 4 in onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Include the directory when naming header files [build/include_subdir] [4] Raw Output: onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc:4: Include the directory when naming header files [build/include_subdir] [4]
#include <string>

#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

// Kernel registration for the FastGelu contrib op on the WebGPU execution provider.
// Arguments, in order: op name, domain (kMSDomain), version (1), execution provider,
// kernel definition builder, and the kernel class that implements the op.
// The "T" type constraint limits the kernel to WebGpuSupportedFloatTypes().
ONNX_OPERATOR_KERNEL_EX(
FastGelu,
kMSDomain,
1,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", WebGpuSupportedFloatTypes()),
FastGelu);

// Generates the shader source for FastGelu.
//
// The shader loads `x` from "input" at `global_idx`, optionally adds the bias, and stores
//   y = x * (0.5 + 0.5 * tanh(x * (0.035677408136300125 * x * x + 0.7978845608028654)))
// to "output" at the same offset. The two constants are sqrt(2/pi) and 0.044715 * sqrt(2/pi),
// i.e. the tanh approximation of GELU.
//
// A second program input, when present, is treated as the bias. How it is read depends on
// bias_components_:
//   - 1: four consecutive bias elements starting at `global_idx * 4` are gathered into an
//        `input_value_t`, each index wrapped by `uniforms.bias_shape`.
//   - otherwise: a single element at `global_idx % uniforms.bias_shape` is added directly.
//
// Returns Status::OK() once the main function body has been emitted.
Status FastGeluProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const auto& input = shader.AddInput("input", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias);
  const auto& output = shader.AddOutput("output", ShaderVariable::UseUniform);

  // Shader snippet that adds the bias to `x`; stays empty when the program has no bias input.
  std::string add_bias;
  if (Inputs().size() > 1) {
    const auto& bias = shader.AddInput("bias", ShaderVariable::UseUniform | ShaderVariable::UseShapeAndStride);
    if (bias_components_ == 1) {
      // Bias is addressed one element at a time, so build the 4-wide value component by component.
      add_bias =
          " let bias_offset = global_idx * 4;\n"
          " x += input_value_t(" +
          bias.GetByOffset("bias_offset % uniforms.bias_shape") + ", " +
          bias.GetByOffset("(bias_offset + 1) % uniforms.bias_shape") + ", " +
          bias.GetByOffset("(bias_offset + 2) % uniforms.bias_shape") + ", " +
          bias.GetByOffset("(bias_offset + 3) % uniforms.bias_shape") + ");\n";
    } else {
      add_bias = " x += " + bias.GetByOffset("global_idx % uniforms.bias_shape") + ";\n";
    }
  }

  // The formula literal is split into two adjacent string literals only to keep the source line
  // under 120 columns; the compiler joins them into the same single string as before.
  shader.MainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"),
                          " var x = ", input.GetByOffset("global_idx"), ";\n",
                          add_bias,
                          " let y = x * (0.5 + 0.5 * tanh(x * (0.035677408136300125 * x * x + "
                          "0.7978845608028654)));\n ",
                          output.SetByOffset("global_idx", "y"));

  return Status::OK();
}

// Runs FastGelu on the WebGPU device.
//
// Input 0 is the data tensor and input 1 is the bias, which may be absent (nullptr). The output
// takes the shape of input 0. Data and output are bound with 4 components and an overriding
// shape of ceil(size / 4); the bias is bound with 4 components only when its element count is
// a multiple of 4, otherwise with 1.
//
// Returns Status::OK() without dispatching when the output has no elements; otherwise returns
// the result of context.RunProgram().
Status FastGelu::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
  const auto* x = context.Input(0);
  const auto* bias_tensor = context.Input(1);
  auto* y = context.Output(0, x->Shape());

  const uint32_t element_count = SafeInt<uint32_t>(y->Shape().Size());
  if (element_count == 0) {
    return Status::OK();
  }

  // Number of 4-wide elements needed to cover the output, rounded up.
  const uint32_t vec_size = (element_count + 3) / 4;
  const bool has_bias = bias_tensor != nullptr;

  // Defaults describe the "no bias" case.
  uint32_t bias_size = 0;
  int bias_components = 1;
  if (has_bias) {
    const uint32_t bias_elements = SafeInt<uint32_t>(bias_tensor->Shape().Size());
    const bool packs_by_four = (bias_elements % 4) == 0;
    bias_components = packs_by_four ? 4 : 1;
    bias_size = packs_by_four ? bias_elements / 4 : bias_elements;
  }

  FastGeluProgram program{bias_components};
  program.Input({x, ProgramTensorMetadataDependency::Type, {vec_size}, 4});
  program.Output({y, ProgramTensorMetadataDependency::None, {vec_size}, 4});
  program.DispatchGroupSize((vec_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE);
  program.UniformVariable({vec_size});

  if (has_bias) {
    // The generated shader differs per bias component count, so it is part of the cache hint.
    program.Input({bias_tensor, ProgramTensorMetadataDependency::TypeAndRank, {bias_size}, bias_components});
    program.CacheHint(std::to_string(bias_components));
  }

  return context.RunProgram(program);
}

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
38 changes: 38 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/fast_gelu.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/webgpu_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

using namespace onnxruntime::webgpu;

Check warning on line 13 in onnxruntime/contrib_ops/webgpu/bert/fast_gelu.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5] Raw Output: onnxruntime/contrib_ops/webgpu/bert/fast_gelu.h:13: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]
using onnxruntime::webgpu::ComputeContext;

// WebGPU program that generates the FastGelu shader.
//
// Declares a single uniform, "vec_size" (Uint32), which the generated shader uses as its
// out-of-bounds guard.
class FastGeluProgram final : public Program<FastGeluProgram> {
 public:
  // bias_components: selects how the optional bias input is read by the generated shader
  // (1 takes the element-by-element path). The constructor is explicit so that an int is never
  // silently converted into a program object.
  explicit FastGeluProgram(int bias_components) : Program{"FastGelu"}, bias_components_{bias_components} {
  }

  // Writes the shader inputs, output and main function body into `shader`.
  Status GenerateShaderCode(ShaderHelper& shader) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32});

 private:
  // Component count used when reading the bias input; fixed at construction.
  int bias_components_;
};

// WebGPU kernel for the FastGelu op.
class FastGelu final : public WebGpuKernel {
 public:
  // Forwards the kernel info to the WebGpuKernel base. Explicit to rule out accidental implicit
  // conversion from OpKernelInfo to a kernel object.
  explicit FastGelu(const OpKernelInfo& info) : WebGpuKernel(info) {}

  // Computes the op using the inputs and outputs exposed by `context`.
  Status ComputeInternal(ComputeContext& context) const override;
};

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
42 changes: 21 additions & 21 deletions onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,34 +26,34 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Sk
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipSimplifiedLayerNormalization);

// template <>
// KernelCreateInfo BuildKernelCreateInfo<void>() {
// KernelCreateInfo info;
// return info;
// }
// Specialization for `void`: produces a default-constructed KernelCreateInfo.
// It backs the placeholder entry of the registration table so the table is never empty.
template <>
KernelCreateInfo BuildKernelCreateInfo<void>() {
  KernelCreateInfo placeholder_entry;
  return placeholder_entry;
}

Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Attention)>,

Check warning on line 38 in onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc:38: Lines should be <= 120 characters long [whitespace/line_length] [2]
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd)>,

Check warning on line 39 in onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc:39: Lines should be <= 120 characters long [whitespace/line_length] [2]
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
// // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1,
// SkipLayerNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1,
// SimplifiedLayerNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1,
// SkipSimplifiedLayerNormalization)>
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
// // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1,
// SkipLayerNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1,
// SimplifiedLayerNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1,
// SkipSimplifiedLayerNormalization)>
};

for (auto& function_table_entry : function_table) {
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@

#pragma once

#include "core/framework/op_kernel.h"
#include "core/framework/kernel_registry.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

// Forward declaration of the kernel-create-info factory template for this EP's namespace.
// Only the primary template is declared here; explicit specializations (one per registered
// kernel class, plus a `void` placeholder) are expected to be provided elsewhere.
template <typename T>
KernelCreateInfo BuildKernelCreateInfo();

Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry);

} // namespace webgpu
Expand Down
19 changes: 17 additions & 2 deletions onnxruntime/core/providers/webgpu/program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,23 @@ ProgramBase::ProgramBase(const std::string& name)
workgroup_size_z_{0} {
}

// Appends one program input and returns *this so calls can be chained.
//
// `input` is declared as an rvalue reference, but a named parameter is an lvalue inside the
// function body. It is cast back to an rvalue so emplace_back move-constructs the stored element
// instead of copying it. The cast is equivalent to std::move and is spelled out so this
// definition does not depend on an additional header.
ProgramBase& ProgramBase::Input(ProgramInput&& input) {
  inputs_.emplace_back(static_cast<ProgramInput&&>(input));
  return *this;
}

// Appends every element of `inputs` after the inputs already registered and returns *this so
// calls can be chained.
//
// The list is inserted exactly once. Assigning the list first and then inserting it again
// would drop previously added inputs and store each new one twice.
ProgramBase& ProgramBase::Inputs(std::initializer_list<ProgramInput> inputs) {
  inputs_.insert(inputs_.end(), inputs.begin(), inputs.end());
  return *this;
}

// Appends one program output and returns *this so calls can be chained.
//
// `output` is declared as an rvalue reference, but a named parameter is an lvalue inside the
// function body. It is cast back to an rvalue so emplace_back move-constructs the stored element
// instead of copying it. The cast is equivalent to std::move and is spelled out so this
// definition does not depend on an additional header.
ProgramBase& ProgramBase::Output(ProgramOutput&& output) {
  outputs_.emplace_back(static_cast<ProgramOutput&&>(output));
  return *this;
}

// Appends every element of `outputs` after the outputs already registered and returns *this so
// calls can be chained. This mirrors Inputs() and the single-element Output().
//
// The list is inserted exactly once. Assigning the list first and then inserting it again
// would drop previously added outputs and store each new one twice.
ProgramBase& ProgramBase::Outputs(std::initializer_list<ProgramOutput> outputs) {
  outputs_.insert(outputs_.end(), outputs.begin(), outputs.end());
  return *this;
}

Expand Down Expand Up @@ -232,6 +242,11 @@ ProgramBase& ProgramBase::WorkgroupSize(uint32_t x, uint32_t y, uint32_t z) {
return *this;
}

// Appends one uniform variable value and returns *this so calls can be chained.
//
// `variable` is declared as an rvalue reference, but a named parameter is an lvalue inside the
// function body. It is cast back to an rvalue so emplace_back move-constructs the stored element
// instead of copying it. The cast is equivalent to std::move and is spelled out so this
// definition does not depend on an additional header.
ProgramBase& ProgramBase::UniformVariable(ProgramUniformVariableValue&& variable) {
  variables_.emplace_back(static_cast<ProgramUniformVariableValue&&>(variable));
  return *this;
}

ProgramBase& ProgramBase::UniformVariables(std::initializer_list<ProgramUniformVariableValue> variables) {
variables_.insert(variables_.end(), variables.begin(), variables.end());
return *this;
Expand Down
13 changes: 11 additions & 2 deletions onnxruntime/core/providers/webgpu/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,12 @@ class ProgramBase {
return *this;
}

// set one or more program inputs
// append a program input
ProgramBase& Input(ProgramInput&& input);
// append multiple program inputs
ProgramBase& Inputs(std::initializer_list<ProgramInput> inputs);
// append a program output
ProgramBase& Output(ProgramOutput&& output);
// set one or more program outputs
ProgramBase& Outputs(std::initializer_list<ProgramOutput> outputs);

Expand All @@ -291,7 +295,12 @@ class ProgramBase {
// set the size of a workgroup grid.
ProgramBase& WorkgroupSize(uint32_t x, uint32_t y, uint32_t z);

// set the uniform variables.
// append a uniform variable.
//
// the specified uniform variable should match the uniform definition in the class,
// specified by macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES.
ProgramBase& UniformVariable(ProgramUniformVariableValue&& variable);
// append multiple uniform variables.
//
// the specified uniform variables should match the uniform definition in the class,
// specified by macro WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES.
Expand Down
13 changes: 9 additions & 4 deletions onnxruntime/test/contrib_ops/fastgelu_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ const std::vector<float> GetExpectedResult(const std::vector<float>& input_data,
return ComputeGelu(add_bias_data);
}

#if defined(USE_CUDA) || defined(USE_ROCM)
#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU)
static void RunFastGeluGpuTest(const std::vector<float>& input_data, const std::vector<float>& bias_data,
const std::vector<float>& output_data, const std::vector<int64_t>& input_dims,
const std::vector<int64_t>& bias_dims, const std::vector<int64_t>& output_dims,
Expand Down Expand Up @@ -75,6 +75,8 @@ static void RunFastGeluGpuTest(const std::vector<float>& input_data, const std::
execution_providers.push_back(DefaultCudaExecutionProvider());
#elif USE_ROCM
execution_providers.push_back(DefaultRocmExecutionProvider());
#elif USE_WEBGPU
execution_providers.push_back(DefaultWebGpuExecutionProvider());
#endif
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
Expand Down Expand Up @@ -142,7 +144,7 @@ static void RunFastGeluTest(
std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
std::vector<int64_t> bias_dims = {hidden_size};
std::vector<int64_t> output_dims = input_dims;
#if defined(USE_CUDA) || defined(USE_ROCM)
#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU)
RunFastGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, has_bias);
#endif
RunFastGeluCpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, has_bias);
Expand Down Expand Up @@ -245,8 +247,8 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat32) {
RunFastGeluTest(input_data, bias_data, batch_size, sequence_length, hidden_size);
}

// CUDA and ROCm only for Float16 and BFloat16 type.
#if defined(USE_CUDA) || defined(USE_ROCM)
// CUDA, ROCm and WebGPU only for Float16 type.
#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_WEBGPU)
TEST(FastGeluTest, FastGeluWithBiasFloat16_2) {
int batch_size = 1;
int sequence_length = 2;
Expand Down Expand Up @@ -381,7 +383,10 @@ TEST(FastGeluTest, FastGeluWithoutBiasFloat16_8) {

RunFastGeluGpuTest(input_data, bias_data, output_data, input_dims, bias_dims, output_dims, false, true);
}
#endif

// CUDA and ROCm only for BFloat16 type.
#if defined(USE_CUDA) || defined(USE_ROCM)
TEST(FastGeluTest, FastGeluWithBias_BFloat16) {
#ifdef USE_CUDA
int min_cuda_architecture = 530;
Expand Down

0 comments on commit 6b82486

Please sign in to comment.