Skip to content

Commit

Permalink
revise fast gelu
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Sep 25, 2024
1 parent 929725e commit c5e5af3
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions onnxruntime/contrib_ops/webgpu/bert/fast_gelu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <string>

#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/math/unary_elementwise_ops.h"
#include "contrib_ops/webgpu/bert/fast_gelu.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"

Expand All @@ -20,26 +21,25 @@ ONNX_OPERATOR_KERNEL_EX(
FastGelu);

// Emits the shader source for the FastGelu program.
//
// Registers the "x" input and the "y" output and, when the program has more
// than one input, an additional "bias" input. The generated main body:
//   1. emits the guard produced by GuardAgainstOutOfBoundsWorkgroupSizes for
//      uniforms.vec_size,
//   2. loads x[global_idx] into the shader-local variable `a`,
//   3. adds the bias to `a` when a bias input is present,
//   4. stores FastGeluExpr to y[global_idx].
//
// Always returns Status::OK().
Status FastGeluProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
  const auto& y = shader.AddOutput("y", ShaderUsage::UseUniform);

  // Shader statements that add the bias to `a`; stays empty when there is no bias input.
  std::string add_bias;
  if (Inputs().size() > 1) {
    const auto& bias = shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride);
    if (bias_components_ == 1) {
      // Build an x_value_t from four consecutive bias elements starting at
      // global_idx * 4, each index wrapped by uniforms.bias_shape.
      // NOTE(review): presumably x is read as 4-component vectors while bias is
      // scalar in this case -- confirm against the code that sets bias_components_.
      add_bias = " let bias_offset = global_idx * 4;\n"
                 " a += x_value_t(" +
                 bias.GetByOffset("bias_offset % uniforms.bias_shape") + ", " +
                 bias.GetByOffset("(bias_offset + 1) % uniforms.bias_shape") + ", " +
                 bias.GetByOffset("(bias_offset + 2) % uniforms.bias_shape") + ", " +
                 bias.GetByOffset("(bias_offset + 3) % uniforms.bias_shape") + ");\n";
    } else {
      // Bias elements map one-to-one onto x elements, wrapped by uniforms.bias_shape.
      add_bias = " a += " + bias.GetByOffset("global_idx % uniforms.bias_shape") + ";\n";
    }
  }

  shader.AppendImplementation(TanhImpl);
  // NOTE(review): FastGeluExpr is presumably declared in
  // core/providers/webgpu/math/unary_elementwise_ops.h and written in terms of
  // the shader variable `a` declared below -- confirm before renaming `a`.
  shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"),
                             " var a = ", x.GetByOffset("global_idx"), ";\n",
                             add_bias,
                             y.SetByOffset("global_idx", onnxruntime::webgpu::FastGeluExpr));

  return Status::OK();
}
Expand Down

0 comments on commit c5e5af3

Please sign in to comment.