Multi head attention #22143

Closed
wants to merge 83 commits into from
Changes from 1 commit
Commits
4037bd4
[WIP] WebGPU EP initial commit
fs-eire Aug 28, 2024
9c36250
update C-API
fs-eire Aug 28, 2024
3a0756d
fix build break
fs-eire Aug 28, 2024
5199e98
add an empty symbols.txt file
fs-eire Aug 28, 2024
1c68dbd
fix an error in doc
fs-eire Aug 29, 2024
7db03de
remove string_join.h in favor of absl::StrJoin
fs-eire Aug 29, 2024
6a373c2
fix DLL copy
fs-eire Aug 29, 2024
ee42bba
update doc: require --skip_tests
fs-eire Aug 29, 2024
5fac202
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-ep
fs-eire Aug 29, 2024
3f46e5c
update dawn version
fs-eire Aug 29, 2024
9f61279
disable Tint tests
fs-eire Aug 29, 2024
6bb6335
fix one build break in Linux
fs-eire Aug 29, 2024
d839dbc
remove unused variables
fs-eire Aug 30, 2024
b70943d
make webgpu build on linux and known to most tools (#21937)
guschmue Aug 30, 2024
c33ac2e
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-ep
fs-eire Aug 30, 2024
8437267
revert type of ShaderVariable::rank_ to int
fs-eire Aug 30, 2024
3caf032
output Impl() for variables
fs-eire Aug 30, 2024
84494c4
code formatting
fs-eire Aug 30, 2024
aa70163
better format of Uniform
fs-eire Aug 30, 2024
d772db7
revise document
fs-eire Aug 30, 2024
6ef3dad
more build fix for linux
fs-eire Aug 31, 2024
a56f6c3
apply formatter
fs-eire Aug 31, 2024
12cd79d
simple test runner
fs-eire Aug 31, 2024
14c8966
Program macros update - allow extend
fs-eire Aug 31, 2024
4fff35f
fix BucketCacheManager
fs-eire Sep 1, 2024
4fd8ad1
add a method to get logger from ComputeContext
fs-eire Sep 1, 2024
3bd92ad
add verbose log for cache key
fs-eire Sep 1, 2024
6a1bbfe
revise suite test
fs-eire Sep 1, 2024
947aee1
device lost handler
fs-eire Sep 1, 2024
99b2578
add '-a' and '-t' to test runner
fs-eire Sep 1, 2024
aa7b3f5
atol/rtol 0.0001 -> 0.001
fs-eire Sep 1, 2024
e659acd
Fix uniform
fs-eire Sep 2, 2024
6ad89c5
add some unary ops
fs-eire Sep 2, 2024
8361fc3
various of fixes
fs-eire Sep 2, 2024
c89159d
fix workgroup_size, cache key stringnify and indices type
fs-eire Sep 3, 2024
5ea5936
shape_uniforms preparation
fs-eire Sep 3, 2024
7d83054
allow uniforms of input/output shape/stride being added automatically
fs-eire Sep 3, 2024
7a64cc7
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-ep
fs-eire Sep 3, 2024
1d53ac8
fix build (linux)
fs-eire Sep 3, 2024
4d52602
fix stride
fs-eire Sep 3, 2024
3761aad
fix "{res_name}_bi2o_{name}"
fs-eire Sep 3, 2024
351da84
Add Expand operator (#21933)
qjia7 Sep 3, 2024
0b7ce77
support onnxruntime_test_all
fs-eire Sep 3, 2024
33726b1
reflect change in WebGpuProviderFactoryCreator::Create signature (#21…
guschmue Sep 3, 2024
50ea9eb
compare the content of WEBGPU_BUFFER, not the address (#21967)
guschmue Sep 3, 2024
d6f6148
fix tanh
fs-eire Sep 3, 2024
626edaf
support size==0 for element wise operators
fs-eire Sep 4, 2024
8913da1
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-ep
fs-eire Sep 4, 2024
bacc54c
use shared ComputeBroadcastOutputShape()
fs-eire Sep 4, 2024
7ecc5bb
add workgroup_idx
fs-eire Sep 4, 2024
ae836b1
expose name for shader variable
fs-eire Sep 4, 2024
243078b
add uniform for 1D variable
fs-eire Sep 5, 2024
4d48d28
fix GetElementAt with uniform
fs-eire Sep 5, 2024
dbe673b
document update folder
fs-eire Sep 5, 2024
38f182e
fix adapter/device creating: add toggles
fs-eire Sep 5, 2024
eb80f7c
more strict shape&stride usage check
fs-eire Sep 6, 2024
39d5509
fix vector realloc
fs-eire Sep 6, 2024
cd961c3
simplify cache hint interface.
fs-eire Sep 6, 2024
ddc2fbb
revise expand
fs-eire Sep 6, 2024
e8be835
revise unary
fs-eire Sep 6, 2024
bd7d592
Elu/Relu/LeakyRelu/ThresholdedRelu/Gelu
fs-eire Sep 6, 2024
eecac18
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-ep
fs-eire Sep 6, 2024
601e50f
remove unused field in class Gelu
fs-eire Sep 6, 2024
8f36da2
remove out-of-dated comments
fs-eire Sep 6, 2024
72ebd85
Clip
fs-eire Sep 7, 2024
a3244ae
fix rank in shader helper
fs-eire Sep 7, 2024
5a2ae8c
fix shader variable
fs-eire Sep 9, 2024
aa54ff8
move components number from variable to program
fs-eire Sep 9, 2024
969384d
mark components in cache key
fs-eire Sep 9, 2024
6b82486
Add FastGelu op (#21991)
qjia7 Sep 10, 2024
2b3e7c2
use 'set/add' as prefix for some functions
fs-eire Sep 10, 2024
ef0d53b
remove unnecessary cache hint for FastGelu
fs-eire Sep 10, 2024
c4ca47f
revise unary - expose consts in header
fs-eire Sep 10, 2024
8806d57
use path for header file
fs-eire Sep 10, 2024
0568e2b
a few revises to the code (#22047)
fs-eire Sep 10, 2024
b7a9c0e
use OrtMutex
fs-eire Sep 11, 2024
f65ade9
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-ep
fs-eire Sep 11, 2024
d4a963d
[webgpu-native] Add transpose op (#21986)
axinging Sep 11, 2024
8b61532
PushErrorScope and PopErrorScope
fs-eire Sep 11, 2024
dce0f18
placeholder for setting proc table
fs-eire Sep 12, 2024
8978d89
Revert "placeholder for setting proc table"
fs-eire Sep 12, 2024
43ccaf4
allow setting "ValidationMode"
fs-eire Sep 12, 2024
409ac5c
webgpu: support MultiHeadAttention operator
xhcao Sep 19, 2024
Elu/Relu/LeakyRelu/ThresholdedRelu/Gelu
fs-eire committed Sep 6, 2024

commit bd7d592386932b5dd55793dd4a44328808114269
87 changes: 82 additions & 5 deletions onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc
@@ -37,6 +37,9 @@ Status UnaryElementwise::ComputeInternal(ComputeContext& context) const {
       .UniformVariables({
           {static_cast<uint32_t>(vec_size)},
       });
+  if (!cache_hint.empty()) {
+    program.CacheHint(cache_hint);
+  }
   ORT_RETURN_IF_ERROR(ConfigureProgram(program));
   return context.RunProgram(program);
 }
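
Note on the hunk above: a non-empty `cache_hint` is now forwarded to the program. Later in this commit, Gelu sets that hint to its `approximate` attribute, presumably so that the erf-based and tanh-based shaders do not share one cached program. The sketch below illustrates that idea only. `ProgramCache`, `MakeKey`, and `GetOrCreate` are hypothetical names, and the key layout is an assumption, not the ONNX Runtime implementation.

```cpp
#include <cstddef>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for a compiled-shader cache; NOT the ONNX Runtime API.
class ProgramCache {
 public:
  // Builds a lookup key from the program name plus an optional hint.
  // The "name[hint]" layout is an assumption made for this sketch.
  static std::string MakeKey(const std::string& name, const std::string& hint) {
    return hint.empty() ? name : name + "[" + hint + "]";
  }

  // Returns the cached shader source, calling generate() only on a cache miss.
  template <typename Generator>
  const std::string& GetOrCreate(const std::string& name,
                                 const std::string& hint,
                                 Generator generate) {
    const std::string key = MakeKey(name, hint);
    auto it = cache_.find(key);
    if (it == cache_.end()) {
      it = cache_.emplace(key, generate()).first;
    }
    return it->second;
  }

  std::size_t Size() const { return cache_.size(); }

 private:
  std::unordered_map<std::string, std::string> cache_;
};
```

Without the hint, two kernels with the same name but different generated code would map to the same key in a cache like this one, and the second kernel would silently reuse the first kernel's shader.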
@@ -172,7 +175,13 @@ WEBGPU_ELEMENTWISE_KERNEL(Cosh, 9, WebGpuSupportedFloatTypes())

 // built-in function tanh() does not work with large input (f32 88.7 or f16 11.09)
 // https://github.com/gpuweb/gpuweb/issues/4458
-WEBGPU_ELEMENTWISE_IMPL(Tanh, "sign(a) * (1 - exp(-2 * abs(a))) / (1 + exp(-2 * abs(a)))")
+constexpr char TanhImpl[] = R"(
+fn tanh_v(a: x_value_t) -> x_value_t {
+  let expr = exp(-2 * abs(a));
+  return sign(a) * (1 - expr) / (1 + expr);
+}
+)";
+WEBGPU_ELEMENTWISE_IMPL(Tanh, "tanh_v(a)", TanhImpl, ShaderVariable::UseValueTypeAlias)
 WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Tanh, 6, 12, WebGpuSupportedFloatTypes())
 WEBGPU_ELEMENTWISE_KERNEL(Tanh, 13, WebGpuSupportedFloatTypes())
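
Note on the Tanh rewrite: the thresholds in the comment (88.7 for f32, 11.09 for f16) are roughly where `exp(a)` exceeds the largest finite value of each type, so one plausible failure mode is a tanh built on `exp(a)` that overflows. The C++ sketch below contrasts such a naive form with the rewritten form from the diff. `NaiveTanh` and `StableTanh` are hypothetical names, and the naive form is an assumption about how a built-in might behave, not a statement about any particular driver.

```cpp
#include <cmath>

// Naive form: exp(a) overflows to +inf for large a in float,
// so the quotient becomes inf / inf, which is NaN.
inline float NaiveTanh(float a) {
  const float ep = std::exp(a);
  const float en = std::exp(-a);
  return (ep - en) / (ep + en);
}

// Form used in the diff: exp(-2 * |a|) always lies in [0, 1],
// so neither the numerator nor the denominator can overflow.
inline float StableTanh(float a) {
  const float expr = std::exp(-2.0f * std::fabs(a));
  // Mirrors WGSL sign(): -1, 0, or +1.
  const float sign = static_cast<float>((a > 0.0f) - (a < 0.0f));
  return sign * (1.0f - expr) / (1.0f + expr);
}
```

For moderate inputs both forms agree with `std::tanh` to within rounding. They differ only once `exp(a)` stops being representable.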

@@ -193,10 +202,78 @@ WEBGPU_ELEMENTWISE_KERNEL(Atanh, 9, WebGpuSupportedFloatTypes())

 // todo: clip

-// constexpr char EluImpl[] = R"(
-//)";
-//
-// WEBGPU_ELEMENTWISE_IMPL(Elu, "elu_v(a)", )
+class LinearUnit : public UnaryElementwise {
+ public:
+  LinearUnit(const OpKernelInfo& info,
+             const std::string& kernel_name,
+             const std::string& expression,
+             const std::string& additional_impl,
+             float default_alpha)
+      : UnaryElementwise{info, kernel_name, expression, additional_impl, ShaderVariable::UseElementTypeAlias} {
+    info.GetAttrOrDefault("alpha", &alpha_, default_alpha);
+  }
+
+  Status ConfigureProgram(UnaryElementwiseProgram& program) const override {
+    program.UniformVariables({alpha_, {}});
+    return Status::OK();
+  }
+
+ protected:
+  float alpha_;
+};
+
+#define WEBGPU_LU_IMPL(OP_TYPE, ...) \
+  class OP_TYPE final : public LinearUnit { \
+   public: \
+    OP_TYPE(const OpKernelInfo& info) : LinearUnit{info, #OP_TYPE, __VA_ARGS__} {} \
+  };
+
+constexpr char EluImpl[] = R"(
+fn elu(a: x_element_t) -> x_element_t {
+  let alpha = x_element_t(uniforms.f32_attr);
+  return select((exp(a) - 1.0) * alpha, a, a >= 0.0);
+}
+
+fn elu_v(v: vec4<x_element_t>) -> vec4<x_element_t> {
+  return vec4(elu(v.x), elu(v.y), elu(v.z), elu(v.w));
+}
+)";
+
+WEBGPU_LU_IMPL(Elu, "elu_v(a)", EluImpl, 1.0)
+WEBGPU_ELEMENTWISE_KERNEL(Elu, 6, WebGpuSupportedFloatTypes())
+
+// TODO: support attribute "approximate"
+class Gelu : public UnaryElementwise {
+ public:
+  Gelu(const OpKernelInfo& info)
+      : UnaryElementwise{info,
+                         "Gelu",
+                         info.GetAttrOrDefault<std::string>("approximate", "none") == "tanh" ? TanhBasedImpl : DefaultImpl,
+                         info.GetAttrOrDefault<std::string>("approximate", "none") == "tanh" ? TanhImpl : ErfImpl,
+                         ShaderVariable::UseValueTypeAlias} {
+    cache_hint = info.GetAttrOrDefault<std::string>("approximate", "none");
+  }
+
+  constexpr static const char DefaultImpl[] = "0.5 * a * (1.0 + erf_v(a * 0.7071067811865475))";
+  constexpr static const char TanhBasedImpl[] = "0.5 * a * (1 + tanh_v(0.7978845608028654 * (a + 0.044715 * a * a * a)))";
+
+ protected:
+  float alpha_;
+};
+
+WEBGPU_ELEMENTWISE_KERNEL(Gelu, 20, WebGpuSupportedFloatTypes())
+
+WEBGPU_ELEMENTWISE_IMPL(Relu, "select(x_value_t(0), a, a > x_value_t(0))", "", ShaderVariable::UseValueTypeAlias)
+WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Relu, 6, 12, WebGpuSupportedFloatTypes())
+WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Relu, 13, 13, WebGpuSupportedFloatTypes())
+WEBGPU_ELEMENTWISE_KERNEL(Relu, 14, WebGpuSupportedFloatTypes())
+
+WEBGPU_LU_IMPL(LeakyRelu, "select(x_element_t(uniforms.f32_attr) * a, a, a >= vec4<x_element_t>(0))", "", 0.01f)
+WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(LeakyRelu, 6, 15, WebGpuSupportedFloatTypes())
+WEBGPU_ELEMENTWISE_KERNEL(LeakyRelu, 16, WebGpuSupportedFloatTypes())
+
+WEBGPU_LU_IMPL(ThresholdedRelu, "select(vec4<x_element_t>(0), a, a > vec4<x_element_t>(uniforms.f32_attr))", "", 1.0f)
+WEBGPU_ELEMENTWISE_KERNEL(ThresholdedRelu, 10, WebGpuSupportedFloatTypes())

 // TODO: add other unary elementwise ops
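
Note on the activations added in this hunk: WGSL `select(f, t, cond)` yields `t` where `cond` is true and `f` otherwise, which makes the shader expressions easy to misread. The scalar C++ sketch below mirrors each expression from the diff so the math can be checked on the CPU. It is a reference under stated assumptions, not the kernel code: `Select` and the function names are hypothetical, everything runs in `float`, and `std::tanh` and `std::erf` stand in for the shader helpers `tanh_v` and `erf_v`.

```cpp
#include <cmath>

// Mirrors WGSL select(f, t, cond): returns t when cond is true, otherwise f.
// As in WGSL, both candidate values are computed before one is chosen.
template <typename T>
T Select(T f, T t, bool cond) {
  return cond ? t : f;
}

// select((exp(a) - 1.0) * alpha, a, a >= 0.0)
inline float Elu(float a, float alpha = 1.0f) {
  return Select((std::exp(a) - 1.0f) * alpha, a, a >= 0.0f);
}

// select(alpha * a, a, a >= 0)
inline float LeakyRelu(float a, float alpha = 0.01f) {
  return Select(alpha * a, a, a >= 0.0f);
}

// select(0, a, a > alpha); note the strict comparison against alpha.
inline float ThresholdedRelu(float a, float alpha = 1.0f) {
  return Select(0.0f, a, a > alpha);
}

// select(0, a, a > 0)
inline float Relu(float a) {
  return Select(0.0f, a, a > 0.0f);
}

// DefaultImpl: 0.5 * a * (1 + erf(a / sqrt(2)))
inline float GeluErf(float a) {
  return 0.5f * a * (1.0f + std::erf(a * 0.7071067811865475f));
}

// TanhBasedImpl: 0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3)))
inline float GeluTanh(float a) {
  return 0.5f * a * (1.0f + std::tanh(0.7978845608028654f * (a + 0.044715f * a * a * a)));
}
```

The two Gelu variants are close but not identical, so a test that compares the `approximate="tanh"` path against an erf-based reference needs a tolerance on the order of 1e-3 rather than exact equality.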

@@ -45,6 +45,8 @@ class UnaryElementwise : public WebGpuKernel {
         additional_usage_{usage} {}

  protected:
+  std::string cache_hint;
+
   Status ComputeInternal(ComputeContext& context) const final;
   virtual Status ConfigureProgram(UnaryElementwiseProgram& program) const {
     program.UniformVariables({{}, {}});  // empty for both float and int attribute(s)
19 changes: 10 additions & 9 deletions onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
@@ -134,6 +134,8 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Relu);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 6, 15, LeakyRelu);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, LeakyRelu);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, ThresholdedRelu);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 20, Gelu);

 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceMax);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 11, ReduceMax);
@@ -186,8 +188,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceLogSumExp);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceLogSumExp);

-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, ThresholdedRelu);
-
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 7, 12, Add);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Add);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, Add);
@@ -442,13 +442,14 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
       // KERNEL_CREATE_INFO_VERSIONED(11, 11, Clip),
       // KERNEL_CREATE_INFO_VERSIONED(12, 12, Clip),
       // KERNEL_CREATE_INFO(13, Clip),
-      // KERNEL_CREATE_INFO(6, Elu),
-      // KERNEL_CREATE_INFO_VERSIONED(6, 12, Relu),
-      // KERNEL_CREATE_INFO_VERSIONED(13, 13, Relu),
-      // KERNEL_CREATE_INFO(14, Relu),
-      // KERNEL_CREATE_INFO_VERSIONED(6, 15, LeakyRelu),
-      // KERNEL_CREATE_INFO(16, LeakyRelu),
-      // KERNEL_CREATE_INFO(10, ThresholdedRelu),
+      KERNEL_CREATE_INFO(6, Elu),
+      KERNEL_CREATE_INFO_VERSIONED(6, 12, Relu),
+      KERNEL_CREATE_INFO_VERSIONED(13, 13, Relu),
+      KERNEL_CREATE_INFO(14, Relu),
+      KERNEL_CREATE_INFO_VERSIONED(6, 15, LeakyRelu),
+      KERNEL_CREATE_INFO(16, LeakyRelu),
+      KERNEL_CREATE_INFO(10, ThresholdedRelu),
+      KERNEL_CREATE_INFO(20, Gelu),

       // // binary - math
       // KERNEL_CREATE_INFO_VERSIONED(7, 12, Add),