
[webgpu-native] support for webgpu layernorms #22249

Merged · 8 commits · Oct 1, 2024
36 changes: 36 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/layer_norm.cc
@@ -0,0 +1,36 @@

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/webgpu/nn/layer_norm.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

using namespace onnxruntime::webgpu;

[cpplint] onnxruntime/contrib_ops/webgpu/bert/layer_norm.cc:14: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]
using onnxruntime::webgpu::ComputeContext;

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    LayerNormalization,
    kOnnxDomain,
    1,
    16,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()),
    onnxruntime::webgpu::LayerNorm<false>);

ONNX_OPERATOR_KERNEL_EX(
    SimplifiedLayerNormalization,
    kOnnxDomain,
    1,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()),
    onnxruntime::webgpu::LayerNorm<true>);

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
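For reference, the two registered variants differ only in how the mean is handled; with γ = scale, β = bias, ε = epsilon, and statistics taken per row, the standard definitions (a summary, not text from this PR) are:

\mathrm{LayerNorm}(x) = \frac{x - \mu}{\sqrt{\mathbb{E}[x^2] - \mu^2 + \varepsilon}}\,\gamma + \beta, \qquad \mu = \mathbb{E}[x]

\mathrm{SimplifiedLayerNorm}(x) = \frac{x}{\sqrt{\mathbb{E}[x^2] + \varepsilon}}\,\gamma

The LayerNorm<false> / LayerNorm<true> template argument selects between them, and LayerNormalization is registered only for opsets 1–16 because, as the comment in webgpu_contrib_kernels.cc below notes, it used to be a contrib op that (incorrectly) used kOnnxDomain.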
168 changes: 168 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc
@@ -0,0 +1,168 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"
#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"
#include "contrib_ops/webgpu/bert/skip_layer_norm.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

static uint32_t GetMaxComponents(int size) {
  if (size % 4 == 0) {
    return 4;
  } else if (size % 2 == 0) {
    return 2;
  }
  return 1;
}
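Hand-computed examples of the width selection (illustrative comments, not code from the PR):

// GetMaxComponents(768) == 4   // 768 % 4 == 0 -> rows are processed as vec4 values
// GetMaxComponents(10)  == 2   // divisible by 2 but not 4 -> vec2
// GetMaxComponents(7)   == 1   // odd hidden size -> scalar fallback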

static std::string SumVector(std::string x, int components) {
  switch (components) {
    case 1:
      return x;
    case 2:
      return "(" + x + ".x + " + x + ".y" + ")";
    case 4:
      return "(" + x + ".x + " + x + ".y + " + x + ".w + " + x + ".z" + ")";
    default:
      ORT_THROW("Unsupported number of components: ", components);
  }
}
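Expanding the helper by hand for each supported width (derived directly from the switch above):

// SumVector("sum", 1) -> "sum"
// SumVector("sum", 2) -> "(sum.x + sum.y)"
// SumVector("sum", 4) -> "(sum.x + sum.y + sum.w + sum.z)"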

Status SkipLayerNormProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
  shader.AddInput("skip", ShaderUsage::UseUniform);
  shader.AddInput("gamma", ShaderUsage::UseUniform);
  if (hasBeta_) {
    shader.AddInput("beta", ShaderUsage::UseUniform);
  }
  if (hasBias_) {
    shader.AddInput("bias", ShaderUsage::UseUniform);
  }
  shader.AddOutput("output", ShaderUsage::UseUniform);

  int components = x.NumComponents();

  std::string bias = (hasBias_) ? " + bias[offset1d + i] " : "";
  std::string simpl1 = (simplified_) ? "" : "- mean * mean ";
  std::string simpl2 = (simplified_) ? "" : "- element_t(mean) ";
  std::string beta = (hasBeta_) ? " + beta[offset1d + i] " : "";

[cpplint] onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc:52: Add #include <string> for string [build/include_what_you_use] [4]

  shader.AdditionalImplementation()
      << "alias element_t = " << (is_fp16_ ? "f16;\n" : "f32;\n")
      << "alias f32_val_t = " << (components == 4 ? "vec4<f32>" : (components == 2 ? "vec2<f32>" : "f32")) << ";\n"
      << "var<workgroup> sum_shared : array<f32_val_t, workgroup_size_x>;\n"
      << "var<workgroup> sum_squared_shared : array<f32_val_t, workgroup_size_x>;\n";

  shader.MainFunctionBody()
      << "let ix = local_idx;\n"
      << "let iy = global_idx / workgroup_size_x;\n"
      << "let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components;\n"
      << "var stride = hidden_size_vectorized / workgroup_size_x;\n"
      << "let offset = ix * stride + iy * hidden_size_vectorized;\n"
      << "let offset1d = stride * ix;\n"
      << "if (ix == workgroup_size_x - 1) {\n"
      << " stride = hidden_size_vectorized - stride * ix;\n"
      << "}\n"
      << "for (var i: u32 = 0; i < stride; i++) {\n"
      << " let skip_value = skip[offset + i];\n"
      << " let input_value = x[offset + i];\n"
      << " let value = input_value + skip_value" << bias << ";\n"
      << " output[offset + i] = value;\n"
      << " let f32_value = f32_val_t(value);\n"
      << " sum_shared[ix] += f32_value;\n"
      << " sum_squared_shared[ix] += f32_value * f32_value;\n"
      << "}\n"
      << "workgroupBarrier();\n"
      << "var reduce_size : u32 = workgroup_size_x;\n"
      << "for (var curr_size = reduce_size >> 1; curr_size > 0; curr_size = reduce_size >> 1) {\n"
      << " reduce_size = curr_size + (reduce_size & 1);\n"
      << " if (ix < curr_size) {\n"
      << "  sum_shared[ix] += sum_shared[ix + reduce_size];\n"
      << "  sum_squared_shared[ix] += sum_squared_shared[ix + reduce_size];\n"
      << " }\n"
      << " workgroupBarrier();\n"
      << "}\n"
      << "let sum = sum_shared[0];\n"
      << "let square_sum = sum_squared_shared[0];\n"
      << "let mean = " << SumVector("sum", components) << " / f32(uniforms.hidden_size);\n"
      << "let inv_std_dev = inverseSqrt(" << SumVector("square_sum", components) << " / f32(uniforms.hidden_size) " << simpl1 << "+ uniforms.epsilon);\n"

[cpplint] onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc:92: Lines should be <= 120 characters long [whitespace/line_length] [2]
<< "for (var i: u32 = 0; i < stride; i++) {\n"
<< " output[offset + i] = (output[offset + i] " << simpl2 << ") * element_t(inv_std_dev) * gamma[offset1d + i]" << beta << ";\n"

[cpplint] onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.cc:94: Lines should be <= 120 characters long [whitespace/line_length] [2]
<< "};\n";

return Status::OK();
}
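Each thread accumulates partial sums of the fused value v = x + skip (+ bias) and of v² in sum_shared / sum_squared_shared; the tree reduction above halves the active size each round, with the reduce_size & 1 adjustment keeping the unpaired middle element reachable when the active size is odd. The final statistics therefore follow the single-pass variance identity (H = hidden_size):

\mu = \frac{1}{H}\sum_{j=1}^{H} v_j, \qquad \texttt{inv\_std\_dev} = \left(\frac{1}{H}\sum_{j=1}^{H} v_j^2 - \mu^2 + \varepsilon\right)^{-1/2}

The simplified variant drops both the −μ² term (simpl1) and the mean subtraction (simpl2), turning the op into an RMS norm.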

template <bool simplified>
Status SkipLayerNorm<simplified>::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
  const Tensor* x = context.Input(0);
  const Tensor* skip = context.Input(1);
  const Tensor* gamma = context.Input(2);
  // optional
  const Tensor* beta = context.Input(3);
  const Tensor* bias = context.Input(4);

  const auto x_shape = x->Shape();

  auto* output = context.Output(0, x_shape);

  size_t data_size = x_shape.Size();
  if (data_size == 0) {
    return Status::OK();
  }

  const bool is_fp16 = x->GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
  const uint32_t hidden_size = SafeInt<uint32_t>(x_shape[x_shape.NumDimensions() - 1]);
  const int components = GetMaxComponents(hidden_size);

  SkipLayerNormProgram program{beta != nullptr, bias != nullptr, epsilon_, hidden_size, is_fp16, simplified};
  program
      .CacheHint(simplified)
      .AddInputs({{x, ProgramTensorMetadataDependency::Type, components}})
      .AddInputs({{skip, ProgramTensorMetadataDependency::Type, components}})
      .AddInputs({{gamma, ProgramTensorMetadataDependency::Type, components}})
      .AddOutputs({{output, ProgramTensorMetadataDependency::None, components}})
      .SetDispatchGroupSize(SafeInt<uint32_t>(ceil(1.0 * data_size / hidden_size)))
      .AddUniformVariables({
          {static_cast<uint32_t>(components)},
      })
      .AddUniformVariables({
          {static_cast<uint32_t>(hidden_size)},
      })
      .AddUniformVariables({
          {static_cast<float>(epsilon_)},
      });

  if (beta != nullptr) {
    program.AddInput({beta, ProgramTensorMetadataDependency::Type, components});
  }
  if (bias != nullptr) {
    program.AddInput({bias, ProgramTensorMetadataDependency::Type, components});
  }

  return context.RunProgram(program);
}
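A hand-worked sizing example (hypothetical shape, not from the PR): for an fp16 input of shape [2, 128, 768],

// data_size   = 2 * 128 * 768 = 196608 elements
// hidden_size = 768; components = GetMaxComponents(768) = 4
// hidden_size_vectorized = 768 / 4 = 192 vec4<f16> loads per row
// dispatch    = ceil(196608 / 768) = 256 workgroups, i.e. one per row,
// each reducing its row in sum_shared / sum_squared_shared.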

ONNX_OPERATOR_KERNEL_EX(
    SkipLayerNormalization,
    kMSDomain,
    1,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()),
    SkipLayerNorm<false>);

ONNX_OPERATOR_KERNEL_EX(
    SkipSimplifiedLayerNormalization,
    kMSDomain,
    1,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()),
    SkipLayerNorm<true>);

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
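Taken together, the kernel implements the usual fused contract

\mathrm{SkipLayerNorm}(x, \mathrm{skip}, \mathrm{bias}) = \mathrm{LayerNorm}(x + \mathrm{skip} + \mathrm{bias})

with the shader writing the intermediate sum into the output buffer on the first pass and normalizing it in place on the second. Only output 0 is produced here; the contrib-op schema's optional mean / inverse-std / input-skip-bias-sum outputs are not populated by this kernel.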
62 changes: 62 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h
@@ -0,0 +1,62 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/program.h"
#include "core/providers/webgpu/webgpu_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

using namespace onnxruntime::webgpu;

[cpplint] onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h:13: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]
using onnxruntime::webgpu::ComputeContext;

class SkipLayerNormProgram final : public Program<SkipLayerNormProgram> {
 public:
  SkipLayerNormProgram(bool hasBeta, bool hasBias, float epsilon, uint32_t hidden_size, bool is_fp16, bool simplified) : Program{"SkipLayerNorm"} {
[cpplint] onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h:18: Lines should be <= 120 characters long [whitespace/line_length] [2]
    hasBeta_ = hasBeta;
    hasBias_ = hasBias;
    epsilon_ = epsilon;
    hidden_size_ = hidden_size;
    simplified_ = simplified;
    is_fp16_ = is_fp16;
  }

  Status GenerateShaderCode(ShaderHelper& sh) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES(
      {"components", ProgramUniformVariableDataType::Uint32},
      {"hidden_size", ProgramUniformVariableDataType::Uint32},
      {"epsilon", ProgramUniformVariableDataType::Float32});

 private:
  bool hasBeta_;
  bool hasBias_;
  float epsilon_;
  uint32_t hidden_size_;
  bool is_fp16_;
  bool simplified_;
};
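The uniform names declared by this macro are exactly what the WGSL in skip_layer_norm.cc reads back, and the three AddUniformVariables calls in ComputeInternal supply the values in the same order. A quick cross-reference (derived from the two files above, not additional code):

// {"components", Uint32}   -> uniforms.components    (vectorization width from GetMaxComponents)
// {"hidden_size", Uint32}  -> uniforms.hidden_size   (row length in scalar elements)
// {"epsilon", Float32}     -> uniforms.epsilon       (stabilizer added before inverseSqrt)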

template <bool simplified>
class SkipLayerNorm final : public WebGpuKernel {
 public:
  SkipLayerNorm(const OpKernelInfo& info) : WebGpuKernel(info) {
    info.GetAttrOrDefault<float>("epsilon", &epsilon_, 1e-05f);
  }

  Status ComputeInternal(ComputeContext& context) const override;

 protected:
  std::string cache_hint;

[cpplint] onnxruntime/contrib_ops/webgpu/bert/skip_layer_norm.h:54: Add #include <string> for string [build/include_what_you_use] [4]

 private:
  float epsilon_;
};

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
17 changes: 7 additions & 10 deletions onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc
@@ -22,6 +22,7 @@
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipSimplifiedLayerNormalization);
@@ -42,19 +43,15 @@
      // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv)>,
      // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu)>,
      // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
      // // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it
      // BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
      // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding)>,
      // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1,
      // SkipLayerNormalization)>,
      // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1,
      // SimplifiedLayerNormalization)>,
      // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1,
      // SkipSimplifiedLayerNormalization)>
  };
      // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding)>,

[cpplint] onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc:49: Lines should be <= 120 characters long [whitespace/line_length] [2]
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipLayerNormalization)>,
[cpplint] onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc:50: Lines should be <= 120 characters long [whitespace/line_length] [2]
      // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it
      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization)>,
[cpplint] onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc:52: Lines should be <= 120 characters long [whitespace/line_length] [2]
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipSimplifiedLayerNormalization)>};

for (auto& function_table_entry : function_table) {
KernelCreateInfo info = function_table_entry();