From 409ac5c9cfa808654a3284aa13331ee036e8b228 Mon Sep 17 00:00:00 2001 From: Cao Date: Thu, 19 Sep 2024 13:56:41 +0800 Subject: [PATCH 1/4] webgpu: support MultiHeadAttention operator --- .../webgpu/bert/multihead_attention.cc | 506 ++++++++++++++++++ .../webgpu/bert/multihead_attention.h | 113 ++++ .../webgpu/webgpu_contrib_kernels.cc | 2 +- .../multihead_attention_op_test.cc | 95 ++-- 4 files changed, 676 insertions(+), 40 deletions(-) create mode 100644 onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc create mode 100644 onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc new file mode 100644 index 0000000000000..3ff1140834754 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -0,0 +1,506 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/bert/multihead_attention_helper.h" +#include "contrib_ops/webgpu/bert/multihead_attention.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" + +#include "core/providers/webgpu/webgpu_supported_types.h" + +using namespace onnxruntime::webgpu; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::contrib::multihead_attention_helper; + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +ONNX_OPERATOR_KERNEL_EX( + MultiHeadAttention, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + MultiHeadAttention); + + +Status TransferBSDToBNSHProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("qkv_input", ShaderVariable::UseUniform); + const auto& qkv_output = shader.AddOutput("qkv_output", ShaderVariable::UseUniform | + ShaderVariable::UseOffsetToIndices); + + if (has_bias_) { + shader.AddInput("bias", ShaderVariable::UseUniform); + } + + shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size"), + "let output_indices = ", qkv_output.OffsetToIndices("global_idx"), ";\n", + "let input_offset_idx = output_indices[0] * uniforms.batch_offset + output_indices[1] * ", + "uniforms.head_offset + output_indices[2] * uniforms.sequence_offset + output_indices[3];\n", + has_bias_ ? "let bias_offset_idx = (input_offset_idx % uniforms.sequence_offset) + uniforms.bias_offset;\n" : "", + "qkv_output[global_idx] = qkv_input[input_offset_idx]", + has_bias_ ? 
" + bias[bias_offset_idx];\n" : ";\n"); + + return Status::OK(); +} + +Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_heads, int sequence_length, + int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor) { + assert(input_tensor->Shape().GetDims().size() == 3); + assert(output_tensor->Shape().GetDims().size() == 4); + + uint32_t data_size = SafeInt(output_tensor->Shape().Size()); + const int batch_offset = num_heads * sequence_length * head_size; + const int sequence_offset = num_heads * head_size; + const int head_offset = head_size; + bool has_bias = bias != nullptr; + + TransferBSDToBNSHProgram program{"TransferBSDToBNSH", has_bias}; + program.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({ + {data_size}, + {static_cast(batch_offset)}, + {static_cast(sequence_offset)}, + {static_cast(head_offset)}, + {static_cast(bias_offset)} + }); + + if (has_bias) { + program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank}); + } + + return context.RunProgram(program); +}; + +Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("q", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias); + shader.AddInput("key", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias); + if (feed_past_key_) { + shader.AddInput("past_key", ShaderVariable::UseUniform); + } + if (has_attention_bias_) { + shader.AddInput("attention_bias", ShaderVariable::UseUniform); + } + + shader.AddOutput("output", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias); + if (has_present_key_) { + shader.AddOutput("present_key", ShaderVariable::UseUniform); + } + + shader.AppendImplementation("const TILE_SIZE = ", tile_size_, "u;\n") + .AppendImplementation("var tileQ: array;\n") + .AppendImplementation("var tileK: array;\n"); + + std::string f32_str = components_ == 4 ? "vec4" : (components_ == 2 ? 
"vec2" : "f32"); + std::ostringstream ss; + ss << "// x holds the N and y holds the M\n" + << "let headIdx = workgroup_id.z;\n" + << "let m = workgroup_id.y * TILE_SIZE;\n" + << "let n = workgroup_id.x * TILE_SIZE;\n" + << "let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;\n"; + + if (feed_past_key_ && has_present_key_) { + ss << "let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx;\n" + << "let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;\n"; + } else { + ss << "let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;\n"; + } + + if (has_present_key_) { + ss << "let presentKeyOffset = headIdx * uniforms.N * uniforms.K;\n"; + } + + ss << "var value = " << f32_str << "(0);\n" + << "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" + << "if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n" + << "tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n" + << "}\n" + << "if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n" + << "var idx = TILE_SIZE * local_id.y + local_id.x;\n"; + + if (feed_past_key_ && has_present_key_) { + ss << "if (n + local_id.y < uniforms.past_sequence_length) {\n" + << "tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" + << "} else {\n" + << "tileK[idx] =" + << "key[kOffset + (n + local_id.y - uniforms.past_sequence_length) * uniforms.K + w + local_id.x];\n" + << "}\n"; + } else { + ss << "tileK[idx] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];\n"; + } + + if (has_present_key_) { + ss << "present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n"; + } + + ss << "}\n" + << "workgroupBarrier();\n" + << "for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {\n" + << "value += "<< f32_str << "(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]);\n" + << "}\n" + << "workgroupBarrier();\n" + << "}\n"; + + ss << "let headOffset = headIdx * uniforms.M * uniforms.N;\n" + << "if (global_id.y < uniforms.M && global_id.x < uniforms.N) {\n" + << "let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n" + << "var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : + (components_ == 2 ? "value.x + value.y" : "value")) << ";\n"; + + ss << "output[outputIdx] = output_value_t(sum * uniforms.alpha) + " + << (has_attention_bias_ ? "attention_bias[outputIdx]" : "0.0") << ";\n" + << "}\n"; + + shader.SetMainFunctionBody(ss.str()); + + return Status::OK(); +} + +Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int output_count, const Tensor* Q, + const Tensor* K, const Tensor* past_key, const Tensor* attention_bias, Tensor* probs, Tensor* present_key, + AttentionParameters& parameters, int past_sequence_length, int total_sequence_length) { + const float alpha = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size)) + : parameters.scale; + + const bool feed_past_key = present_key != nullptr && past_key != nullptr && past_key->SizeInBytes() > 0; + const bool has_present_key = output_count > 1 && past_key; + const bool has_attention_bias = attention_bias != nullptr; + const int tile_size = 12; + const int components = parameters.head_size % 4 == 0 ? 4 : (parameters.head_size % 2 == 0 ? 
2 : 1); + + AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size, + components}; + program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, + {K, ProgramTensorMetadataDependency::TypeAndRank, components}}); + if (feed_past_key) { + program.AddInput({past_key, ProgramTensorMetadataDependency::TypeAndRank, components}); + } + if (has_attention_bias) { + program.AddInput({attention_bias, ProgramTensorMetadataDependency::TypeAndRank}); + } + program.AddOutputs({{probs, ProgramTensorMetadataDependency::Rank}}); + if (has_present_key) { + program.AddOutput({present_key, ProgramTensorMetadataDependency::Rank, components}); + } + + const uint32_t vectorized_head_size = parameters.head_size / components; + program.SetDispatchGroupSize((total_sequence_length + tile_size - 1) / tile_size, + (parameters.sequence_length + tile_size - 1) / tile_size, + parameters.batch_size * parameters.num_heads) + .SetWorkgroupSize(tile_size, tile_size) + .AddUniformVariables({ + {static_cast(parameters.sequence_length)}, + {static_cast(vectorized_head_size)}, + {static_cast(total_sequence_length)}, + {static_cast(parameters.num_heads)}, + {static_cast(alpha)}, + {static_cast(past_sequence_length)}, + {static_cast(parameters.kv_sequence_length)}}); + + return context.RunProgram(program); +} + +Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddOutput("x", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias | + ShaderVariable::UseElementTypeAlias); + shader.AppendImplementation("var thread_max: array;\n") + .AppendImplementation("var thread_sum: array;\n"); + + std::string f32_str = components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32"); + std::ostringstream ss; + ss << "let local_offset = local_idx * uniforms.elements_per_thread;\n" + << "let offset = (global_idx / " << work_group_size_ << ") * uniforms.d_comp + local_offset;\n" + << "var thread_max_vector = " << f32_str << "(-3.402823e+38f);\n" + << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n" + << "thread_max_vector = max(" << f32_str << "(x[offset + i]), thread_max_vector);\n" + << "}\n" + << "thread_max[local_idx] = " << (components_ == 4 ? + "max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))" : + (components_ == 2 ? "max(thread_max_vector.x, thread_max_vector.y)" : "thread_max_vector")) << ";\n" + << "workgroupBarrier();\n" + << "var max_value = f32(-3.402823e+38f);\n" + << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" + << "max_value = max(thread_max[i], max_value);\n" + << "}\n" + << "var sum_vector = " << f32_str << "(0);\n" + << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n" + << "sum_vector += exp(" << f32_str << "(x[offset + i]) - max_value);\n" + << "}\n" + << "thread_sum[local_idx] = " << (components_ == 4 ? "sum_vector.x + sum_vector.y + sum_vector.z + sum_vector.w" : + (components_ == 2 ? 
"sum_vector.x + sum_vector.y" : "sum_vector")) << ";\n" + << "workgroupBarrier();\n" + << "var sum: f32 = 0;\n" + << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n" + << "sum += thread_sum[i]\n;" + << "}\n" + << "if (sum == 0) {\n" + << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n" + << "x[offset + i] = x_value_t(x_element_t(uniforms.d_inv));\n" + << "}\n" + << "} else {\n" + << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n" + << "var f32input = " << f32_str << "(x[offset + i]);\n" + << "x[offset + i] = x_value_t(exp(f32input - max_value) / sum);\n" + << "}\n" + << "}\n"; + + shader.SetMainFunctionBody(ss.str()); + + return Status::OK(); +} + +Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tensor* probs, int n, int d) { + const int components = d % 4 == 0 ? 4 : (d % 2 == 0 ? 2 : 1); + int work_group_size = 64; + const int d_comp = d / components; + if (d_comp < work_group_size) { + work_group_size = 32; + } + const int elementsPerThread = (d_comp + work_group_size - 1) / work_group_size; + + InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components}; + program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}}) + .SetDispatchGroupSize(n) + .SetWorkgroupSize(work_group_size) + .AddUniformVariables({{static_cast(1.f / static_cast(d))}, + {static_cast(d_comp)}, + {static_cast(elementsPerThread)} + }); + + return context.RunProgram(program); +} + +Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { + shader.AddInput("probs", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias | + ShaderVariable::UseElementTypeAlias); + shader.AddInput("v", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias); + if (feed_past_value_) { + shader.AddInput("past_value", ShaderVariable::UseUniform); + } + + shader.AddOutput("output", ShaderVariable::UseUniform); + if (has_present_value_) { + shader.AddOutput("present_value", ShaderVariable::UseUniform); + } + + shader.AppendImplementation("const TILE_SIZE = ", tile_size_, "u;\n") + .AppendImplementation("var tileQ: array;\n") + .AppendImplementation("var tileK: array;\n"); + + std::ostringstream ss; + ss << "let headIdx = workgroup_id.z;\n" + << "let m = global_id.y;\n" + << "let n = global_id.x;\n" + << "let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K;\n"; + + if (feed_past_value_ && has_present_value_) { + ss << "let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n;\n" + << "let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n;\n"; + } else { + ss << "let offsetB = headIdx * uniforms.N * uniforms.K + n;\n"; + } + + if (has_present_value_) { + ss << "let presentValueOffset = headIdx * uniforms.N * uniforms.K + n;\n"; + } + + ss << "var value = probs_element_t(0);\n" + << "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" + << "if (m < uniforms.M && w + local_id.x < uniforms.K) {\n" + << "tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];\n" + << "}\n" + << "if (n < uniforms.N && w + local_id.y < uniforms.K) {\n" + << "var idx = TILE_SIZE * local_id.y + local_id.x;\n"; + + if (feed_past_value_ && has_present_value_) { + ss << "if (w + local_id.y < uniforms.past_sequence_length) {\n" + << "tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N];\n" + << "} else {\n" + << "tileK[idx] = v[vOffset + (w + 
local_id.y - uniforms.past_sequence_length) * uniforms.N];\n" + << "}\n"; + } else { + ss << "tileK[idx] = v[offsetB + (w + local_id.y) * uniforms.N];\n"; + } + + if (has_present_value_) { + ss << "present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n"; + } + + ss << "}\n" + << "workgroupBarrier();\n" + << "for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {\n" + << "value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];\n" + << "}\n" + << "workgroupBarrier();\n" + << "}\n"; + + ss << "// we need to transpose output from BNSH_v to BSND_v\n" + << "let batchIdx = workgroup_id.z / uniforms.num_heads;\n" + << "let currentBatchHeadNumber = workgroup_id.z % uniforms.num_heads;\n" + << "if (m < uniforms.M && n < uniforms.N) {\n" + << "let outputIdx = batchIdx * uniforms.M * uniforms.v_hidden_size + " + << "m * uniforms.v_hidden_size + currentBatchHeadNumber * uniforms.N + n;\n" + << "output[outputIdx] = value;\n" + << "}\n"; + + shader.SetMainFunctionBody(ss.str()); + + return Status::OK(); +} + +Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int output_count, + const Tensor* probs, + const Tensor* V, + const Tensor* past_value, + Tensor* output, + Tensor* present_value, + AttentionParameters& parameters, + int past_sequence_length, + int total_sequence_length) { + const bool feed_past_value = present_value != nullptr && past_value !=nullptr && past_value->SizeInBytes() > 0; + const bool has_present_value = output_count > 1 && past_value != nullptr; + const int tile_size = 12; + + VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size}; + program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank}, + {V, ProgramTensorMetadataDependency::TypeAndRank}}); + if (feed_past_value) { + program.AddInput({past_value, ProgramTensorMetadataDependency::TypeAndRank}); + } + program.AddOutputs({{output, ProgramTensorMetadataDependency::TypeAndRank}}); + if (has_present_value) { + program.AddOutput({present_value, ProgramTensorMetadataDependency::TypeAndRank}); + } + + program.SetDispatchGroupSize((parameters.v_head_size + tile_size - 1) / tile_size, + (parameters.sequence_length + tile_size - 1) / tile_size, + parameters.batch_size * parameters.num_heads) + .SetWorkgroupSize(tile_size, tile_size) + .AddUniformVariables({ + {static_cast(parameters.sequence_length)}, + {static_cast(total_sequence_length)}, + {static_cast(parameters.v_head_size)}, + {static_cast(parameters.num_heads)}, + {static_cast(parameters.v_hidden_size)}, + {static_cast(past_sequence_length)}, + {static_cast(parameters.kv_sequence_length)}}); + + return context.RunProgram(program); +} + +Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, + const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, + AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) { + const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + + (past_value !=nullptr ? 1 : 0)}); + const int past_sequence_length = output_count > 1 ? 
parameters.past_sequence_length : 0;
+  const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length;
+
+  const TensorShapeVector probs_dims({parameters.batch_size, parameters.num_heads,
+                                      parameters.sequence_length, total_sequence_length});
+  const TensorShape probs_shape(probs_dims);
+  Tensor probs = context.CreateGPUTensor(Q->DataType(), probs_shape);
+  ORT_RETURN_IF_ERROR(ComputeAttentionProbs(context, output_count, Q, K, past_key, attention_bias, &probs, present_key,
+                                            parameters, past_sequence_length, total_sequence_length));
+
+  ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs,
+                                            parameters.batch_size * parameters.num_heads * parameters.sequence_length,
+                                            total_sequence_length));
+
+  ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value,
+                                              parameters, past_sequence_length, total_sequence_length));
+
+  return Status::OK();
+}
+
+MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info)
+    : WebGpuKernel(info) {
+  int64_t num_heads = 0;
+  ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
+  num_heads_ = static_cast<int>(num_heads);
+  mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
+  scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
+  is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
+  ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support webgpu kernel");
+}
+
+Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const {
+  const Tensor* query = context.Input<Tensor>(0);
+  const Tensor* key = context.Input<Tensor>(1);
+  const Tensor* value = context.Input<Tensor>(2);
+  const Tensor* bias = context.Input<Tensor>(3);
+  const Tensor* key_padding_mask = context.Input<Tensor>(4);
+  const Tensor* attention_bias = context.Input<Tensor>(5);
+  const Tensor* past_key = context.Input<Tensor>(6);
+  const Tensor* past_value = context.Input<Tensor>(7);
+
+  if (query->Shape().GetDims().size() == 5) {
+    ORT_NOT_IMPLEMENTED("Packed QKV of shape (B, L, N, 3, H) not implemented for webgpu");
+  }
+  if (key != nullptr && key->Shape().GetDims().size() == 5) {
+    ORT_NOT_IMPLEMENTED("Packed KV not implemented for webgpu");
+  }
+  if (key_padding_mask) {
+    ORT_NOT_IMPLEMENTED("input `key_padding_mask` not implemented for webgpu");
+  }
+
+  AttentionParameters parameters;
+  ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs<Tensor>(query, key, value,
+                                                                      bias, key_padding_mask, attention_bias,
+                                                                      past_key, past_value, nullptr, &parameters,
+                                                                      num_heads_, mask_filter_value_, scale_,
+                                                                      is_unidirectional_, false, kMultiHeadAttention,
+                                                                      context.DeviceLimits().maxComputeInvocationsPerWorkgroup));
+
+  TensorShapeVector output_shape(3);
+  output_shape[0] = static_cast<int64_t>(parameters.batch_size);
+  output_shape[1] = static_cast<int64_t>(parameters.sequence_length);
+  output_shape[2] = static_cast<int64_t>(parameters.v_hidden_size);
+  Tensor* output = context.Output(0, output_shape);
+
+  // If optional outputs aren't needed, present_key and present_value will be null
+  std::vector<int64_t> present_dims{
+      parameters.batch_size,
+      parameters.num_heads,
+      parameters.total_sequence_length,
+      parameters.head_size,
+  };
+  TensorShape present_shape(present_dims);
+  Tensor* present_key = context.Output(1, present_shape);
+  Tensor* present_value = context.Output(2, present_shape);
+
+  TensorShapeVector q_new_dims({parameters.batch_size, parameters.num_heads,
+                                parameters.sequence_length, parameters.head_size});
+  TensorShape q_new_shape(q_new_dims);
+  Tensor Q = context.CreateGPUTensor(query->DataType(), q_new_shape);
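+  // TransferBSDToBNSH repacks Q (and below, K and V) from (batch, sequence, num_heads * head_size)
+  // to (batch, num_heads, sequence, head_size), folding in the matching slice of the optional
+  // bias (offsets 0, hidden_size and 2 * hidden_size for Q, K and V respectively), so the
+  // attention programs can index each head contiguously.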
+  ORT_RETURN_IF_ERROR(TransferBSDToBNSH(
+      context, parameters.num_heads, parameters.sequence_length, parameters.head_size, query, bias, 0, &Q));
+
+  if (parameters.qkv_format == Q_K_V_BSNH_BNSH_BNSH) {  // key and value in BNSH format
+    return ApplyAttention(&Q, key, value, attention_bias, past_key, past_value, output, present_key,
+                          present_value, parameters, context);
+  }
+
+  TensorShapeVector k_new_dims({parameters.batch_size, parameters.num_heads,
+                                parameters.kv_sequence_length, parameters.head_size});
+  TensorShape k_new_shape(k_new_dims);
+  Tensor K = context.CreateGPUTensor(key->DataType(), k_new_shape);
+  ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length,
+                                        parameters.head_size, key, bias, parameters.hidden_size, &K));
+
+  TensorShapeVector v_new_dims({parameters.batch_size, parameters.num_heads,
+                                parameters.kv_sequence_length, parameters.v_head_size});
+  TensorShape v_new_shape(v_new_dims);
+  Tensor V = context.CreateGPUTensor(value->DataType(), v_new_shape);
+  ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length,
+                                        parameters.v_head_size, value, bias, 2 * parameters.hidden_size, &V));
+
+  // Compute the attention score and apply the score to V
+  return ApplyAttention(&Q, &K, &V, attention_bias, past_key, past_value, output, present_key,
+                        present_value, parameters, context);
+}
+
+}  // namespace webgpu
+}  // namespace contrib
+}  // namespace onnxruntime

diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h
new file mode 100644
index 0000000000000..86d49ab242ccd
--- /dev/null
+++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h
@@ -0,0 +1,113 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
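+//
+// Declares the WebGPU programs that implement MultiHeadAttention:
+// TransferBSDToBNSHProgram (layout transfer with optional bias add),
+// AttentionProbsProgram (scaled Q*K^T), InPlaceSoftmaxProgram, and
+// VxAttentionScoreProgram (probs * V, repacked back to BSD layout).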
+ +#pragma once + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; + +class TransferBSDToBNSHProgram final : public Program { + public: + TransferBSDToBNSHProgram(const std::string& kernel_name, bool has_bias) : Program{kernel_name}, has_bias_(has_bias) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"data_size", ProgramUniformVariableDataType::Uint32}, + {"batch_offset", ProgramUniformVariableDataType::Uint32}, + {"sequence_offset", ProgramUniformVariableDataType::Uint32}, + {"head_offset", ProgramUniformVariableDataType::Uint32}, + {"bias_offset", ProgramUniformVariableDataType::Uint32}); + +private: + bool has_bias_; +}; + +class AttentionProbsProgram final : public Program { + public: + AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key, + bool has_attention_bias,int tile_size, int components) + : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), + has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"N", ProgramUniformVariableDataType::Uint32}, + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"alpha", ProgramUniformVariableDataType::Float32}, + {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}); + +private: + bool feed_past_key_; + bool has_present_key_; + bool has_attention_bias_; + int tile_size_; + int components_; +}; + +class InPlaceSoftmaxProgram final : public Program { + public: + InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components) + : Program{kernel_name}, work_group_size_(work_group_size), components_(components) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"d_inv", ProgramUniformVariableDataType::Float32}, + {"d_comp", ProgramUniformVariableDataType::Uint32}, + {"elements_per_thread", ProgramUniformVariableDataType::Uint32}); + +private: + int work_group_size_; + int components_; +}; + +class VxAttentionScoreProgram final : public Program { + public: + VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value,int tile_size) + : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), + tile_size_(tile_size) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"M", ProgramUniformVariableDataType::Uint32}, + {"K", ProgramUniformVariableDataType::Uint32}, + {"N", ProgramUniformVariableDataType::Uint32}, + {"num_heads", ProgramUniformVariableDataType::Uint32}, + {"v_hidden_size", ProgramUniformVariableDataType::Uint32}, + {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, + {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}); + +private: + bool feed_past_value_; + bool has_present_value_; + int tile_size_; +}; + +class MultiHeadAttention final : public WebGpuKernel { + public: 
+ MultiHeadAttention(const OpKernelInfo& info); + Status ComputeInternal(onnxruntime::webgpu::ComputeContext& context) const override; + + protected: + int num_heads_; + float mask_filter_value_; + float scale_; + bool is_unidirectional_{false}; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index def104b6cb108..84ab684c32c09 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -45,7 +45,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { // // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo("num_heads", static_cast(num_heads)); tester.AddAttribute("mask_filter_value", static_cast(-10000.0f)); @@ -266,6 +268,12 @@ static void RunMultiHeadAttentionTest( execution_providers.push_back(DefaultDmlExecutionProvider()); tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } + + if (enable_webgpu) { + std::vector> execution_providers; + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } } } @@ -295,8 +303,10 @@ static void RunMultiHeadAttentionKernel( bool is_static_kv = true, bool disable_cpu = false, // some cases not supported in cpu right now. bool disable_cuda = false, + bool disable_webgpu = false, bool disable_rocm = DISABLE_ROCM, - bool disable_dml = false) { + bool disable_dml = false + ) { if (kernel_type == AttentionKernelType::AttentionKernel_Default) { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ @@ -309,7 +319,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, + disable_dml, disable_webgpu); return; } @@ -325,7 +336,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, + disable_dml, disable_webgpu); return; } @@ -341,7 +353,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, 
disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, + disable_dml, disable_webgpu); return; } @@ -358,7 +371,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, + disable_dml, disable_webgpu); return; } #endif @@ -376,7 +390,8 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, + disable_dml, disable_webgpu); } if (kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) { @@ -392,11 +407,13 @@ static void RunMultiHeadAttentionKernel( query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data, past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data, mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length, - hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, disable_dml); + hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm, + disable_dml, disable_webgpu); } } -static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu = false, bool disable_cuda = false) { +static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_webgpu = true, bool disable_cpu = false, + bool disable_cuda = false) { if (data.fp32_output_data.size() > 0) { constexpr bool use_float16 = false; @@ -407,7 +424,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } #if USE_MEMORY_EFFICIENT_ATTENTION @@ -420,7 +437,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } } #endif @@ -431,7 +448,7 
@@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } if (data.fp16_output_data.size() > 0) { @@ -443,7 +460,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } kernel_type = AttentionKernelType::AttentionKernel_TrtFusedAttention; @@ -453,7 +470,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } #if USE_MEMORY_EFFICIENT_ATTENTION @@ -464,7 +481,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } #endif @@ -475,7 +492,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } kernel_type = AttentionKernelType::AttentionKernel_Default; @@ -484,7 +501,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.attention_bias_data, data.past_key_data, data.past_value_data, data.present_key_data, data.present_value_data, data.key_padding_mask_data, data.mask_type, data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length, data.hidden_size, - data.v_hidden_size, 
kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); + data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda, disable_webgpu); } } @@ -493,75 +510,75 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize40) { AttentionTestData data; GetCrossAttentionData_HeadSize40(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false); GetCrossAttentionData_HeadSize40_NoBias(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false); } TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask1D) { ROCM_GTEST_SKIP("ROCm MHA does not support mask type of MASK_1D_KEY_SEQ_LEN"); AttentionTestData data; GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, true); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true, /*disable_cpu=*/true); GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(data, true); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true, /*disable_cpu=*/true); } TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask2D) { AttentionTestData data; GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, false); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true, /*disable_cpu=*/true); GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(data, false); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true, /*disable_cpu=*/true); } TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize32_LeftSidePadding_Mask2D) { AttentionTestData data; GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(data); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true, /*disable_cpu=*/true); GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias(data); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true, /*disable_cpu=*/true); } TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_NoBias_NoMask_PackedKV) { AttentionTestData data; GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true); } TEST(MultiHeadAttentionTest, SelfAttention_Batch2_HeadSize32_NoBias_NoMask_PackedQKV) { AttentionTestData data; GetSelfAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedQKV(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true); } // This tests qk_head_size != v_head_size TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize16_8) { AttentionTestData data; GetCrossAttentionData_HeadSize16_8(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false); GetCrossAttentionData_HeadSize16_8_NoBias(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false); } TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize16) { AttentionTestData data; GetCrossAttentionData_HeadSize16(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false); GetCrossAttentionData_HeadSize16_NoBias(data); - 
RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false); } TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize8) { AttentionTestData data; GetCrossAttentionData_HeadSize8_NoBias(data); - RunMultiHeadAttentionTests(data, false, true); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false, false, true); } // TODO (pavignol): Fix this regression @@ -579,7 +596,7 @@ TEST(MultiHeadAttentionTest, SelfAttention_WithPast_WithAttnBias_ForT5) { ROCM_GTEST_SKIP("ROCm MHA only support head_size >= 8"); AttentionTestData data; GetSelfAttentionData_WithPast_WithAttnBias_ForT5(data); - RunMultiHeadAttentionTests(data, true); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false, true); } TEST(MultiHeadAttentionTest, AttentionCutlassAttnBias) { @@ -596,23 +613,23 @@ TEST(MultiHeadAttentionTest, CrossAttention_DiffSequenceLengths) { RunMultiHeadAttentionTests(data); GetCrossAttentionData_DiffSequenceLengths_HeadSize8(data); - RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true, /*disable_cpu=*/false, /*disable_cuda=*/true); GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(data); - RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true, /*disable_cpu=*/false, /*disable_cuda=*/true); } TEST(MultiHeadAttentionTest, SelfAttention_WithPastAndPresent_NoMask_NoAttnBias) { // Whisper decoder self attention with past_kv and present_kv AttentionTestData data; GetSelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias(data); - RunMultiHeadAttentionTests(data); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false); GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias(data); - RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false, /*disable_cpu=*/false, /*disable_cuda=*/true); GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias(data); - RunMultiHeadAttentionTests(data, /*disable_cpu=*/false, /*disable_cuda=*/true); + RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false, /*disable_cpu=*/false, /*disable_cuda=*/true); } // This test is disabled since it is not used in Whisper anymore, and it fails in ROCm. 
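For orientation before the follow-up patches: the four programs added in PATCH 1/4 compose into standard scaled dot-product attention, computed per head. TransferBSDToBNSH repacks Q/K/V, AttentionProbs computes alpha * Q*K^T (plus the optional attention bias), InPlaceSoftmax normalizes each row, and VxAttentionScore multiplies by V and repacks to BSD layout. The single-head CPU sketch below shows that math only; it is not part of the patch, its function and parameter names are invented for the example, and the past/present key-value concatenation and BNSH layout handling are omitted.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Single-head reference: probs = softmax(alpha * q * k^T), out = probs * v.
// q: [m, d] row-major, k: [n, d], v: [n, dv], out: [m, dv] (caller-allocated).
// alpha falls back to 1/sqrt(d) when scale == 0, as in ComputeAttentionProbs.
void AttentionReference(const std::vector<float>& q, const std::vector<float>& k,
                        const std::vector<float>& v, std::vector<float>& out,
                        std::size_t m, std::size_t n, std::size_t d, std::size_t dv,
                        float scale = 0.0f) {
  const float alpha = scale == 0.0f ? 1.0f / std::sqrt(static_cast<float>(d)) : scale;
  std::vector<float> probs(m * n);
  for (std::size_t i = 0; i < m; ++i) {
    // alpha * Q * K^T for one output row (AttentionProbsProgram).
    float row_max = -3.402823e+38f;  // same float-min sentinel the shaders use
    for (std::size_t j = 0; j < n; ++j) {
      float dot = 0.0f;
      for (std::size_t c = 0; c < d; ++c) dot += q[i * d + c] * k[j * d + c];
      probs[i * n + j] = dot * alpha;
      row_max = std::max(row_max, probs[i * n + j]);
    }
    // Numerically stable in-place softmax over the row (InPlaceSoftmaxProgram).
    float sum = 0.0f;
    for (std::size_t j = 0; j < n; ++j) {
      probs[i * n + j] = std::exp(probs[i * n + j] - row_max);
      sum += probs[i * n + j];
    }
    for (std::size_t j = 0; j < n; ++j) probs[i * n + j] /= sum;
    // probs * V, one output row (VxAttentionScoreProgram).
    for (std::size_t c = 0; c < dv; ++c) {
      float acc = 0.0f;
      for (std::size_t j = 0; j < n; ++j) acc += probs[i * n + j] * v[j * dv + c];
      out[i * dv + c] = acc;
    }
  }
}

The WebGPU programs tile this computation (TILE_SIZE x TILE_SIZE workgroups, vectorized over 2 or 4 components of head_size) and additionally write the concatenated past+current keys/values into the present_key/present_value outputs when those are requested.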
From 53a29ffed7ba4a15123906db8fd132f73914ba34 Mon Sep 17 00:00:00 2001 From: Cao Date: Fri, 20 Sep 2024 16:10:03 +0800 Subject: [PATCH 2/4] Address Yulong's comments --- .../webgpu/bert/multihead_attention.cc | 15 ++++++++------- .../contrib_ops/webgpu/bert/multihead_attention.h | 6 +++++- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc index 3ff1140834754..aa2a557729d2f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -57,7 +57,7 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h const int head_offset = head_size; bool has_bias = bias != nullptr; - TransferBSDToBNSHProgram program{"TransferBSDToBNSH", has_bias}; + TransferBSDToBNSHProgram program{has_bias}; program.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) @@ -91,8 +91,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AddOutput("present_key", ShaderVariable::UseUniform); } - shader.AppendImplementation("const TILE_SIZE = ", tile_size_, "u;\n") - .AppendImplementation("var tileQ: array;\n") + shader.AppendImplementation("var tileQ: array;\n") .AppendImplementation("var tileK: array;\n"); std::string f32_str = components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32"); @@ -192,6 +191,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o (parameters.sequence_length + tile_size - 1) / tile_size, parameters.batch_size * parameters.num_heads) .SetWorkgroupSize(tile_size, tile_size) + .CacheHint(std::to_string(tile_size)) .AddUniformVariables({ {static_cast(parameters.sequence_length)}, {static_cast(vectorized_head_size)}, @@ -199,7 +199,8 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o {static_cast(parameters.num_heads)}, {static_cast(alpha)}, {static_cast(past_sequence_length)}, - {static_cast(parameters.kv_sequence_length)}}); + {static_cast(parameters.kv_sequence_length)}}) + .SetOverridableConstants({{static_cast(tile_size)}}); return context.RunProgram(program); } @@ -287,8 +288,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AddOutput("present_value", ShaderVariable::UseUniform); } - shader.AppendImplementation("const TILE_SIZE = ", tile_size_, "u;\n") - .AppendImplementation("var tileQ: array;\n") + shader.AppendImplementation("var tileQ: array;\n") .AppendImplementation("var tileK: array;\n"); std::ostringstream ss; @@ -387,7 +387,8 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int {static_cast(parameters.num_heads)}, {static_cast(parameters.v_hidden_size)}, {static_cast(past_sequence_length)}, - {static_cast(parameters.kv_sequence_length)}}); + {static_cast(parameters.kv_sequence_length)}}) + .SetOverridableConstants({{static_cast(tile_size)}});; return context.RunProgram(program); } diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h index 86d49ab242ccd..01672298b375f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h @@ -16,7 +16,7 @@ using namespace 
onnxruntime::webgpu; class TransferBSDToBNSHProgram final : public Program { public: - TransferBSDToBNSHProgram(const std::string& kernel_name, bool has_bias) : Program{kernel_name}, has_bias_(has_bias) {} + TransferBSDToBNSHProgram(bool has_bias) : Program{"TransferBSDToBNSH"}, has_bias_(has_bias) {} Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -48,6 +48,8 @@ class AttentionProbsProgram final : public Program { {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}); + WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); + private: bool feed_past_key_; bool has_present_key_; @@ -90,6 +92,8 @@ class VxAttentionScoreProgram final : public Program { {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}); + WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); + private: bool feed_past_value_; bool has_present_value_; From f5f201fb71bd3889fba775a6ba02cc39214f5e23 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 30 Sep 2024 04:06:58 -0700 Subject: [PATCH 3/4] update impl --- .../webgpu/bert/multihead_attention.cc | 436 +++++++++--------- .../webgpu/bert/multihead_attention.h | 20 +- .../multihead_attention_op_test.cc | 3 +- 3 files changed, 221 insertions(+), 238 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc index aa2a557729d2f..d836c1ddf8675 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -17,37 +17,41 @@ namespace contrib { namespace webgpu { ONNX_OPERATOR_KERNEL_EX( - MultiHeadAttention, - kMSDomain, - 1, - kWebGpuExecutionProvider, - (*KernelDefBuilder::Create()) - .TypeConstraint("T", WebGpuSupportedFloatTypes()), - MultiHeadAttention); - + MultiHeadAttention, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedFloatTypes()), + MultiHeadAttention); Status TransferBSDToBNSHProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("qkv_input", ShaderVariable::UseUniform); - const auto& qkv_output = shader.AddOutput("qkv_output", ShaderVariable::UseUniform | - ShaderVariable::UseOffsetToIndices); + shader.AddInput("qkv_input", ShaderUsage::UseUniform); + const auto& qkv_output = shader.AddOutput("qkv_output", ShaderUsage::UseUniform | ShaderUsage::UseOffsetToIndices); if (has_bias_) { - shader.AddInput("bias", ShaderVariable::UseUniform); + shader.AddInput("bias", ShaderUsage::UseUniform); } - shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size"), - "let output_indices = ", qkv_output.OffsetToIndices("global_idx"), ";\n", - "let input_offset_idx = output_indices[0] * uniforms.batch_offset + output_indices[1] * ", - "uniforms.head_offset + output_indices[2] * uniforms.sequence_offset + output_indices[3];\n", - has_bias_ ? "let bias_offset_idx = (input_offset_idx % uniforms.sequence_offset) + uniforms.bias_offset;\n" : "", - "qkv_output[global_idx] = qkv_input[input_offset_idx]", - has_bias_ ? 
" + bias[bias_offset_idx];\n" : ";\n"); + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.data_size") + << "let output_indices = " << qkv_output.OffsetToIndices("global_idx") << ";\n" + << "let input_offset_idx = output_indices[0] * uniforms.batch_offset + output_indices[1] *" + << " uniforms.head_offset + output_indices[2] * uniforms.sequence_offset + output_indices[3];\n"; + if (has_bias_) { + shader.MainFunctionBody() << "let bias_offset_idx = (input_offset_idx % uniforms.sequence_offset) + uniforms.bias_offset;\n"; + } + shader.MainFunctionBody() << "qkv_output[global_idx] = qkv_input[input_offset_idx]"; + if (has_bias_) { + shader.MainFunctionBody() << " + bias[bias_offset_idx];\n"; + } else { + shader.MainFunctionBody() << ";\n"; + } return Status::OK(); } Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_heads, int sequence_length, - int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor) { + int head_size, const Tensor* input_tensor, const Tensor* bias, int bias_offset, Tensor* output_tensor) { assert(input_tensor->Shape().GetDims().size() == 3); assert(output_tensor->Shape().GetDims().size() == 4); @@ -61,13 +65,11 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h program.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) .SetDispatchGroupSize((data_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) - .AddUniformVariables({ - {data_size}, - {static_cast(batch_offset)}, - {static_cast(sequence_offset)}, - {static_cast(head_offset)}, - {static_cast(bias_offset)} - }); + .AddUniformVariables({{data_size}, + {static_cast(batch_offset)}, + {static_cast(sequence_offset)}, + {static_cast(head_offset)}, + {static_cast(bias_offset)}}); if (has_bias) { program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank}); @@ -77,91 +79,89 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h }; Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { - shader.AddInput("q", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias); - shader.AddInput("key", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias); + shader.AddInput("q", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.AddInput("key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); if (feed_past_key_) { - shader.AddInput("past_key", ShaderVariable::UseUniform); + shader.AddInput("past_key", ShaderUsage::UseUniform); } if (has_attention_bias_) { - shader.AddInput("attention_bias", ShaderVariable::UseUniform); + shader.AddInput("attention_bias", ShaderUsage::UseUniform); } - shader.AddOutput("output", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias); + shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); if (has_present_key_) { - shader.AddOutput("present_key", ShaderVariable::UseUniform); + shader.AddOutput("present_key", ShaderUsage::UseUniform); } - shader.AppendImplementation("var tileQ: array;\n") - .AppendImplementation("var tileK: array;\n"); + shader.AdditionalImplementation() << "var tileQ: array;\n" + << "var tileK: array;\n" + << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32")) << ";\n"; - std::string f32_str = components_ == 4 ? "vec4" : (components_ == 2 ? 
"vec2" : "f32"); - std::ostringstream ss; - ss << "// x holds the N and y holds the M\n" - << "let headIdx = workgroup_id.z;\n" - << "let m = workgroup_id.y * TILE_SIZE;\n" - << "let n = workgroup_id.x * TILE_SIZE;\n" - << "let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;\n"; + shader.MainFunctionBody() << "// x holds the N and y holds the M\n" + "let headIdx = workgroup_id.z;\n" + "let m = workgroup_id.y * TILE_SIZE;\n" + "let n = workgroup_id.x * TILE_SIZE;\n" + "let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;\n"; if (feed_past_key_ && has_present_key_) { - ss << "let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx;\n" - << "let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;\n"; + shader.MainFunctionBody() << "let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx;\n" + << "let pastKeyOffset = uniforms.past_sequence_length * uniforms.K * headIdx;\n"; } else { - ss << "let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;\n"; + shader.MainFunctionBody() << "let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;\n"; } if (has_present_key_) { - ss << "let presentKeyOffset = headIdx * uniforms.N * uniforms.K;\n"; + shader.MainFunctionBody() << "let presentKeyOffset = headIdx * uniforms.N * uniforms.K;\n"; } - ss << "var value = " << f32_str << "(0);\n" - << "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" - << "if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n" - << "tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n" - << "}\n" - << "if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n" - << "var idx = TILE_SIZE * local_id.y + local_id.x;\n"; + shader.MainFunctionBody() << "var value = f32_val_t(0);\n" + "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" + " if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n" + " tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n" + " }\n" + " if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n" + " var idx = TILE_SIZE * local_id.y + local_id.x;\n"; if (feed_past_key_ && has_present_key_) { - ss << "if (n + local_id.y < uniforms.past_sequence_length) {\n" - << "tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" - << "} else {\n" - << "tileK[idx] =" - << "key[kOffset + (n + local_id.y - uniforms.past_sequence_length) * uniforms.K + w + local_id.x];\n" - << "}\n"; + shader.MainFunctionBody() << " if (n + local_id.y < uniforms.past_sequence_length) {\n" + " tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" + " } else {\n" + " tileK[idx] = key[kOffset + (n + local_id.y - uniforms.past_sequence_length) * uniforms.K + w + local_id.x];\n" + " }\n"; } else { - ss << "tileK[idx] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];\n"; + shader.MainFunctionBody() << " tileK[idx] = key[kOffset + local_id.y * uniforms.K + w + local_id.x];\n"; } if (has_present_key_) { - ss << "present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n"; + shader.MainFunctionBody() << " present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];\n"; } - ss << "}\n" - << "workgroupBarrier();\n" - << "for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {\n" - << "value += "<< f32_str << "(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * 
local_id.x + k]);\n" - << "}\n" - << "workgroupBarrier();\n" - << "}\n"; - - ss << "let headOffset = headIdx * uniforms.M * uniforms.N;\n" - << "if (global_id.y < uniforms.M && global_id.x < uniforms.N) {\n" - << "let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n" - << "var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : - (components_ == 2 ? "value.x + value.y" : "value")) << ";\n"; + shader.MainFunctionBody() << " }\n" + << " workgroupBarrier();\n" + << " for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {\n" + << " value += f32_val_t(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]);\n" + << " }\n" + << " workgroupBarrier();\n" + << "}\n"; - ss << "output[outputIdx] = output_value_t(sum * uniforms.alpha) + " - << (has_attention_bias_ ? "attention_bias[outputIdx]" : "0.0") << ";\n" - << "}\n"; + shader.MainFunctionBody() << "let headOffset = headIdx * uniforms.M * uniforms.N;\n" + << "if (global_id.y < uniforms.M && global_id.x < uniforms.N) {\n" + << " let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n" + << " var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? "value.x + value.y" : "value")) << ";\n"; - shader.SetMainFunctionBody(ss.str()); + shader.MainFunctionBody() << " output[outputIdx] = output_value_t(sum * uniforms.alpha)"; + if (has_attention_bias_) { + shader.MainFunctionBody() << " + attention_bias[outputIdx]"; + } + shader.MainFunctionBody() << ";\n" + << "}\n"; return Status::OK(); } Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int output_count, const Tensor* Q, - const Tensor* K, const Tensor* past_key, const Tensor* attention_bias, Tensor* probs, Tensor* present_key, - AttentionParameters& parameters, int past_sequence_length, int total_sequence_length) { + const Tensor* K, const Tensor* past_key, const Tensor* attention_bias, Tensor* probs, Tensor* present_key, + AttentionParameters& parameters, int past_sequence_length, int total_sequence_length) { const float alpha = parameters.scale == 0.0f ? 
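+  // A scale of 0 means "not set"; fall back to 1/sqrt(head_size), the standard
+  // scaled-dot-product-attention factor.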
   const float alpha = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(parameters.head_size)) : parameters.scale;

@@ -188,68 +188,60 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o
   const uint32_t vectorized_head_size = parameters.head_size / components;
   program.SetDispatchGroupSize((total_sequence_length + tile_size - 1) / tile_size,
-      (parameters.sequence_length + tile_size - 1) / tile_size,
-      parameters.batch_size * parameters.num_heads)
-      .SetWorkgroupSize(tile_size, tile_size)
-      .CacheHint(std::to_string(tile_size))
-      .AddUniformVariables({
-          {static_cast<uint32_t>(parameters.sequence_length)},
-          {static_cast<uint32_t>(vectorized_head_size)},
-          {static_cast<uint32_t>(total_sequence_length)},
-          {static_cast<uint32_t>(parameters.num_heads)},
-          {static_cast<float>(alpha)},
-          {static_cast<uint32_t>(past_sequence_length)},
-          {static_cast<uint32_t>(parameters.kv_sequence_length)}})
-      .SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});
+                              (parameters.sequence_length + tile_size - 1) / tile_size,
+                              parameters.batch_size * parameters.num_heads)
+      .SetWorkgroupSize(tile_size, tile_size)
+      .CacheHint(std::to_string(tile_size))
+      .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length)},
+                            {static_cast<uint32_t>(vectorized_head_size)},
+                            {static_cast<uint32_t>(total_sequence_length)},
+                            {static_cast<uint32_t>(parameters.num_heads)},
+                            {static_cast<float>(alpha)},
+                            {static_cast<uint32_t>(past_sequence_length)},
+                            {static_cast<uint32_t>(parameters.kv_sequence_length)}})
+      .SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});

   return context.RunProgram(program);
 }

 Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const {
-  shader.AddOutput("x", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias |
-      ShaderVariable::UseElementTypeAlias);
-  shader.AppendImplementation("var<workgroup> thread_max: array<f32, ", work_group_size_, ">;\n")
-      .AppendImplementation("var<workgroup> thread_sum: array<f32, ", work_group_size_, ">;\n");
-
-  std::string f32_str = components_ == 4 ? "vec4<f32>" : (components_ == 2 ? "vec2<f32>" : "f32");
-  std::ostringstream ss;
-  ss << "let local_offset = local_idx * uniforms.elements_per_thread;\n"
-     << "let offset = (global_idx / " << work_group_size_ << ") * uniforms.d_comp + local_offset;\n"
-     << "var thread_max_vector = " << f32_str << "(-3.402823e+38f);\n"
-     << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n"
-     << "thread_max_vector = max(" << f32_str << "(x[offset + i]), thread_max_vector);\n"
-     << "}\n"
-     << "thread_max[local_idx] = " << (components_ == 4 ?
-            "max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))" :
-            (components_ == 2 ? "max(thread_max_vector.x, thread_max_vector.y)" : "thread_max_vector")) << ";\n"
-     << "workgroupBarrier();\n"
-     << "var max_value = f32(-3.402823e+38f);\n"
-     << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n"
-     << "max_value = max(thread_max[i], max_value);\n"
-     << "}\n"
-     << "var sum_vector = " << f32_str << "(0);\n"
-     << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n"
-     << "sum_vector += exp(" << f32_str << "(x[offset + i]) - max_value);\n"
-     << "}\n"
-     << "thread_sum[local_idx] = " << (components_ == 4 ? "sum_vector.x + sum_vector.y + sum_vector.z + sum_vector.w" :
-            (components_ == 2 ? "sum_vector.x + sum_vector.y" : "sum_vector")) << ";\n"
-     << "workgroupBarrier();\n"
-     << "var sum: f32 = 0;\n"
-     << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n"
-     << "sum += thread_sum[i]\n;"
-     << "}\n"
-     << "if (sum == 0) {\n"
-     << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n"
-     << "x[offset + i] = x_value_t(x_element_t(uniforms.d_inv));\n"
-     << "}\n"
-     << "} else {\n"
-     << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n"
-     << "var f32input = " << f32_str << "(x[offset + i]);\n"
-     << "x[offset + i] = x_value_t(exp(f32input - max_value) / sum);\n"
-     << "}\n"
-     << "}\n";
-
-  shader.SetMainFunctionBody(ss.str());
+  shader.AddOutput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+  shader.AdditionalImplementation() << "var<workgroup> thread_max: array<f32, " << work_group_size_ << ">;\n"
+                                    << "var<workgroup> thread_sum: array<f32, " << work_group_size_ << ">;\n"
+                                    << "alias f32_val_t = " << (components_ == 4 ? "vec4<f32>" : (components_ == 2 ? "vec2<f32>" : "f32")) << ";\n";
+
+  shader.MainFunctionBody() << "let local_offset = local_idx * uniforms.elements_per_thread;\n"
+                            << "let offset = (global_idx / " << work_group_size_ << ") * uniforms.d_comp + local_offset;\n"
+                            << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n"
+                            << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n"
+                            << "  thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n"
+                            << "}\n"
+                            << "thread_max[local_idx] = " << (components_ == 4 ? "max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))" : (components_ == 2 ? "max(thread_max_vector.x, thread_max_vector.y)" : "thread_max_vector")) << ";\n"
+                            << "workgroupBarrier();\n"
+                            << "var max_value = f32(-3.402823e+38f);\n"
+                            << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n"
+                            << "  max_value = max(thread_max[i], max_value);\n"
+                            << "}\n"
+                            << "var sum_vector = f32_val_t(0);\n"
+                            << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n"
+                            << "  sum_vector += exp(f32_val_t(x[offset + i]) - max_value);\n"
+                            << "}\n"
+                            << "thread_sum[local_idx] = " << (components_ == 4 ? "sum_vector.x + sum_vector.y + sum_vector.z + sum_vector.w" : (components_ == 2 ? "sum_vector.x + sum_vector.y" : "sum_vector")) << ";\n"
+                            << "workgroupBarrier();\n"
+                            << "var sum: f32 = 0;\n"
+                            << "for (var i = 0u; i < " << work_group_size_ << "; i++) {\n"
+                            << "  sum += thread_sum[i];\n"
+                            << "}\n"
+                            << "if (sum == 0) {\n"
+                            << "  for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n"
+                            << "    x[offset + i] = x_value_t(x_element_t(uniforms.d_inv));\n"
+                            << "  }\n"
+                            << "} else {\n"
+                            << "  for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {\n"
+                            << "    var f32input = f32_val_t(x[offset + i]);\n"
+                            << "    x[offset + i] = x_value_t(exp(f32input - max_value) / sum);\n"
+                            << "  }\n"
+                            << "}\n";

   return Status::OK();
 }

@@ -265,109 +257,104 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso
   InPlaceSoftmaxProgram program{"InPlaceSoftmax", work_group_size, components};
   program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}})
-      .SetDispatchGroupSize(n)
-      .SetWorkgroupSize(work_group_size)
-      .AddUniformVariables({{static_cast<float>(1.f / static_cast<float>(d))},
+      .SetDispatchGroupSize(n)
+      .SetWorkgroupSize(work_group_size)
+      .AddUniformVariables({{static_cast<float>(1.f / static_cast<float>(d))},
                             {static_cast<uint32_t>(d_comp)},
-          {static_cast<uint32_t>(elementsPerThread)}
-      });
+                            {static_cast<uint32_t>(elementsPerThread)}});

   return context.RunProgram(program);
 }

 Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const {
-  shader.AddInput("probs", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias |
-      ShaderVariable::UseElementTypeAlias);
-  shader.AddInput("v", ShaderVariable::UseUniform | ShaderVariable::UseValueTypeAlias);
+  shader.AddInput("probs", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
+  shader.AddInput("v", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
   if (feed_past_value_) {
-    shader.AddInput("past_value", ShaderVariable::UseUniform);
+    shader.AddInput("past_value", ShaderUsage::UseUniform);
   }

-  shader.AddOutput("output", ShaderVariable::UseUniform);
+  shader.AddOutput("output", ShaderUsage::UseUniform);
   if (has_present_value_) {
-    shader.AddOutput("present_value", ShaderVariable::UseUniform);
+    shader.AddOutput("present_value", ShaderUsage::UseUniform);
   }

-  shader.AppendImplementation("var<workgroup> tileQ: array<probs_value_t, TILE_SIZE * TILE_SIZE>;\n")
-      .AppendImplementation("var<workgroup> tileK: array<v_value_t, TILE_SIZE * TILE_SIZE>;\n");
+  shader.AdditionalImplementation() << "var<workgroup> tileQ: array<probs_value_t, TILE_SIZE * TILE_SIZE>;\n"
+                                    << "var<workgroup> tileK: array<v_value_t, TILE_SIZE * TILE_SIZE>;\n";

-  std::ostringstream ss;
-  ss << "let headIdx = workgroup_id.z;\n"
-     << "let m = global_id.y;\n"
-     << "let n = global_id.x;\n"
-     << "let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K;\n";
+  shader.MainFunctionBody() << "let headIdx = workgroup_id.z;\n"
+                            << "let m = global_id.y;\n"
+                            << "let n = global_id.x;\n"
+                            << "let offsetA = headIdx * (uniforms.M * uniforms.K) + m * uniforms.K;\n";

   if (feed_past_value_ && has_present_value_) {
-    ss << "let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n;\n"
-       << "let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n;\n";
+    shader.MainFunctionBody() << "let pastValueOffset = headIdx * uniforms.N * uniforms.past_sequence_length + n;\n"
+                              << "let vOffset = headIdx * uniforms.N * uniforms.kv_sequence_length + n;\n";
   } else {
-    ss << "let offsetB = headIdx * uniforms.N * uniforms.K + n;\n";
+    shader.MainFunctionBody() << "let offsetB = headIdx * uniforms.N * uniforms.K + n;\n";
   }

   if (has_present_value_) {
-    ss << "let presentValueOffset = headIdx * uniforms.N * uniforms.K + n;\n";
+    shader.MainFunctionBody() << "let presentValueOffset = headIdx * uniforms.N * uniforms.K + n;\n";
   }

-  ss << "var value = probs_element_t(0);\n"
-     << "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n"
-     << "if (m < uniforms.M && w + local_id.x < uniforms.K) {\n"
-     << "tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];\n"
-     << "}\n"
-     << "if (n < uniforms.N && w + local_id.y < uniforms.K) {\n"
-     << "var idx = TILE_SIZE * local_id.y + local_id.x;\n";
+  shader.MainFunctionBody() << "var value = probs_element_t(0);\n"
+                            << "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n"
+                            << "  if (m < uniforms.M && w + local_id.x < uniforms.K) {\n"
+                            << "    tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];\n"
+                            << "  }\n"
+                            << "  if (n < uniforms.N && w + local_id.y < uniforms.K) {\n"
+                            << "    var idx = TILE_SIZE * local_id.y + local_id.x;\n";

   if (feed_past_value_ && has_present_value_) {
-    ss << "if (w + local_id.y < uniforms.past_sequence_length) {\n"
-       << "tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N];\n"
-       << "} else {\n"
-       << "tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N];\n"
-       << "}\n";
+    shader.MainFunctionBody() << "    if (w + local_id.y < uniforms.past_sequence_length) {\n"
+                              << "      tileK[idx] = past_value[pastValueOffset + (w + local_id.y) * uniforms.N];\n"
+                              << "    } else {\n"
+                              << "      tileK[idx] = v[vOffset + (w + local_id.y - uniforms.past_sequence_length) * uniforms.N];\n"
+                              << "    }\n";
   } else {
-    ss << "tileK[idx] = v[offsetB + (w + local_id.y) * uniforms.N];\n";
+    shader.MainFunctionBody() << "    tileK[idx] = v[offsetB + (w + local_id.y) * uniforms.N];\n";
   }

   if (has_present_value_) {
-    ss << "present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n";
+    shader.MainFunctionBody() << "    present_value[presentValueOffset + (w + local_id.y) * uniforms.N] = tileK[idx];\n";
   }

-  ss << "}\n"
-     << "workgroupBarrier();\n"
-     << "for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {\n"
-     << "value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];\n"
-     << "}\n"
-     << "workgroupBarrier();\n"
-     << "}\n";
-
-  ss << "// we need to transpose output from BNSH_v to BSND_v\n"
-     << "let batchIdx = workgroup_id.z / uniforms.num_heads;\n"
-     << "let currentBatchHeadNumber = workgroup_id.z % uniforms.num_heads;\n"
-     << "if (m < uniforms.M && n < uniforms.N) {\n"
-     << "let outputIdx = batchIdx * uniforms.M * uniforms.v_hidden_size + "
-     << "m * uniforms.v_hidden_size + currentBatchHeadNumber * uniforms.N + n;\n"
-     << "output[outputIdx] = value;\n"
-     << "}\n";
-
-  shader.SetMainFunctionBody(ss.str());
+  shader.MainFunctionBody() << "  }\n"
+                            << "  workgroupBarrier();\n"
+                            << "  for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {\n"
+                            << "    value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];\n"
+                            << "  }\n"
+                            << "  workgroupBarrier();\n"
+                            << "}\n";
+
+  shader.MainFunctionBody() << "// we need to transpose output from BNSH_v to BSND_v\n"
+                            << "let batchIdx = workgroup_id.z / uniforms.num_heads;\n"
+                            << "let currentBatchHeadNumber = workgroup_id.z % uniforms.num_heads;\n"
+                            << "if (m < uniforms.M && n < uniforms.N) {\n"
+                            << "  let outputIdx = batchIdx * uniforms.M * uniforms.v_hidden_size + "
+                            << "  m * uniforms.v_hidden_size + currentBatchHeadNumber * uniforms.N + n;\n"
+                            << "  output[outputIdx] = value;\n"
+                            << "}\n";

   return Status::OK();
 }

 Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int output_count,
-    const Tensor* probs,
-    const Tensor* V,
-    const Tensor* past_value,
-    Tensor* output,
-    Tensor* present_value,
-    AttentionParameters& parameters,
-    int past_sequence_length,
-    int total_sequence_length) {
-  const bool feed_past_value = present_value != nullptr && past_value !=nullptr && past_value->SizeInBytes() > 0;
+                               const Tensor* probs,
+                               const Tensor* V,
+                               const Tensor* past_value,
+                               Tensor* output,
+                               Tensor* present_value,
+                               AttentionParameters& parameters,
+                               int past_sequence_length,
+                               int total_sequence_length) {
+  const bool feed_past_value = present_value != nullptr && past_value != nullptr && past_value->SizeInBytes() > 0;
   const bool has_present_value = output_count > 1 && past_value != nullptr;
   const int tile_size = 12;

   VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size};
   program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank},
-      {V, ProgramTensorMetadataDependency::TypeAndRank}});
+                     {V, ProgramTensorMetadataDependency::TypeAndRank}});
   if (feed_past_value) {
     program.AddInput({past_value, ProgramTensorMetadataDependency::TypeAndRank});
   }
@@ -377,27 +364,26 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int
   }

   program.SetDispatchGroupSize((parameters.v_head_size + tile_size - 1) / tile_size,
-      (parameters.sequence_length + tile_size - 1) / tile_size,
-      parameters.batch_size * parameters.num_heads)
-      .SetWorkgroupSize(tile_size, tile_size)
-      .AddUniformVariables({
-          {static_cast<uint32_t>(parameters.sequence_length)},
-          {static_cast<uint32_t>(total_sequence_length)},
-          {static_cast<uint32_t>(parameters.v_head_size)},
-          {static_cast<uint32_t>(parameters.num_heads)},
-          {static_cast<uint32_t>(parameters.v_hidden_size)},
-          {static_cast<uint32_t>(past_sequence_length)},
-          {static_cast<uint32_t>(parameters.kv_sequence_length)}})
-      .SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});;
+                              (parameters.sequence_length + tile_size - 1) / tile_size,
+                              parameters.batch_size * parameters.num_heads)
+      .SetWorkgroupSize(tile_size, tile_size)
+      .AddUniformVariables({{static_cast<uint32_t>(parameters.sequence_length)},
+                            {static_cast<uint32_t>(total_sequence_length)},
+                            {static_cast<uint32_t>(parameters.v_head_size)},
+                            {static_cast<uint32_t>(parameters.num_heads)},
+                            {static_cast<uint32_t>(parameters.v_hidden_size)},
+                            {static_cast<uint32_t>(past_sequence_length)},
+                            {static_cast<uint32_t>(parameters.kv_sequence_length)}})
+      .SetOverridableConstants({{static_cast<uint32_t>(tile_size)}});

   return context.RunProgram(program);
 }

 Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias,
-    const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
-    AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) {
-  const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0)
-      + (past_value !=nullptr ? 1 : 0)});
+                      const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value,
+                      AttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context) {
+  const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)});
   const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length : 0;
   const int total_sequence_length = past_sequence_length + parameters.kv_sequence_length;

@@ -406,13 +392,13 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
   const TensorShape probs_shape(probs_dims);
   Tensor probs = context.CreateGPUTensor(Q->DataType(), probs_shape);
   ORT_RETURN_IF_ERROR(ComputeAttentionProbs(context, output_count, Q, K, past_key, attention_bias, &probs, present_key,
-      parameters, past_sequence_length, total_sequence_length));
+                                            parameters, past_sequence_length, total_sequence_length));

   ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs,
-      parameters.batch_size * parameters.num_heads * parameters.sequence_length, total_sequence_length));
+                                            parameters.batch_size * parameters.num_heads * parameters.sequence_length, total_sequence_length));

   ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value,
-      parameters, past_sequence_length, total_sequence_length));
+                                              parameters, past_sequence_length, total_sequence_length));

   return Status::OK();
 }

@@ -450,9 +436,9 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
   AttentionParameters parameters;
   ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs<Tensor>(query, key, value,
-      bias, key_padding_mask, attention_bias, past_key, past_value, nullptr, &parameters,
-      num_heads_, mask_filter_value_, scale_, is_unidirectional_, false, kMultiHeadAttention,
-      context.DeviceLimits().maxComputeInvocationsPerWorkgroup));
+                                                                      bias, key_padding_mask, attention_bias, past_key, past_value, nullptr, &parameters,
+                                                                      num_heads_, mask_filter_value_, scale_, is_unidirectional_, false, kMultiHeadAttention,
+                                                                      context.DeviceLimits().maxComputeInvocationsPerWorkgroup));

   TensorShapeVector output_shape(3);
   output_shape[0] = static_cast<int64_t>(parameters.batch_size);
@@ -488,14 +474,14 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
   TensorShape k_new_shape(k_new_dims);
   Tensor K = context.CreateGPUTensor(key->DataType(), k_new_shape);
   ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length,
-      parameters.head_size, key, bias, parameters.hidden_size, &K));
+                                        parameters.head_size, key, bias, parameters.hidden_size, &K));

   TensorShapeVector v_new_dims({parameters.batch_size, parameters.num_heads, parameters.kv_sequence_length, parameters.v_head_size});
   TensorShape v_new_shape(v_new_dims);
   Tensor V = context.CreateGPUTensor(value->DataType(), v_new_shape);
   ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.num_heads, parameters.kv_sequence_length,
-      parameters.v_head_size, value, bias, 2 * parameters.hidden_size, &V));
+                                        parameters.v_head_size, value, bias, 2 * parameters.hidden_size, &V));

   // Compute the attention score and apply the score to V
   return ApplyAttention(&Q, &K, &V, attention_bias, past_key, past_value, output, present_key,
diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h
index 01672298b375f..36803e3027b4c 100644
--- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h
+++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.h
@@ -26,16 +26,15 @@ class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram>
                                         {"head_offset", ProgramUniformVariableDataType::Uint32},
                                         {"bias_offset", ProgramUniformVariableDataType::Uint32});

-private:
+ private:
   bool has_bias_;
 };

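+// Computes the scaled Q*K^T scores (optionally adding attention_bias); the softmax itself runs
+// as a separate in-place pass (InPlaceSoftmaxProgram below).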
std::string& kernel_name, bool feed_past_key, bool has_present_key, - bool has_attention_bias,int tile_size, int components) - : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), - has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components) { + bool has_attention_bias, int tile_size, int components) + : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -50,7 +49,7 @@ class AttentionProbsProgram final : public Program { WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); -private: + private: bool feed_past_key_; bool has_present_key_; bool has_attention_bias_; @@ -61,7 +60,7 @@ class AttentionProbsProgram final : public Program { class InPlaceSoftmaxProgram final : public Program { public: InPlaceSoftmaxProgram(const std::string& kernel_name, int work_group_size, int components) - : Program{kernel_name}, work_group_size_(work_group_size), components_(components) { + : Program{kernel_name}, work_group_size_(work_group_size), components_(components) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -70,16 +69,15 @@ class InPlaceSoftmaxProgram final : public Program { {"d_comp", ProgramUniformVariableDataType::Uint32}, {"elements_per_thread", ProgramUniformVariableDataType::Uint32}); -private: + private: int work_group_size_; int components_; }; class VxAttentionScoreProgram final : public Program { public: - VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value,int tile_size) - : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), - tile_size_(tile_size) { + VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size) + : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -94,7 +92,7 @@ class VxAttentionScoreProgram final : public Program { WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); -private: + private: bool feed_past_value_; bool has_present_value_; int tile_size_; diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc index 3dbf3a66b345f..4bcf85ae0b76a 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc @@ -305,8 +305,7 @@ static void RunMultiHeadAttentionKernel( bool disable_cuda = false, bool disable_webgpu = false, bool disable_rocm = DISABLE_ROCM, - bool disable_dml = false - ) { + bool disable_dml = false) { if (kernel_type == AttentionKernelType::AttentionKernel_Default) { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ From 4117a8b99c2ceef957ef372117728c3629814c98 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 30 Sep 2024 15:31:10 -0700 Subject: [PATCH 4/4] revise tests --- .../multihead_attention_op_test.cc | 97 +++++++++++-------- 1 file changed, 57 insertions(+), 40 deletions(-) diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc index 
diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc
index 4bcf85ae0b76a..6b6799d73fb56 100644
--- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc
@@ -49,9 +49,9 @@ static void RunMultiHeadAttentionTest(
     bool use_float16 = false,
     bool disable_cpu = false,  // some cases not supported in cpu right now.
     bool disable_cuda = false,
+    bool disable_webgpu = false,
     bool disable_rocm = DISABLE_ROCM,  // not supported in rocm right now.
-    bool disable_dml = false,
-    bool disable_webgpu = false) {
+    bool disable_dml = false) {
   kv_sequence_length = (kv_sequence_length == 0 ? sequence_length : kv_sequence_length);

   int min_cuda_architecture = use_float16 ? 750 : 0;
@@ -318,8 +318,8 @@ static void RunMultiHeadAttentionKernel(
         query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data,
         past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data,
         mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length,
-        hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm,
-        disable_dml, disable_webgpu);
+        hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu,
+        disable_rocm, disable_dml);
     return;
   }

@@ -335,8 +335,8 @@ static void RunMultiHeadAttentionKernel(
         query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data,
         past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data,
         mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length,
-        hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm,
-        disable_dml, disable_webgpu);
+        hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu,
+        disable_rocm, disable_dml);
     return;
   }

@@ -352,8 +352,8 @@ static void RunMultiHeadAttentionKernel(
         query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data,
         past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data,
         mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length,
-        hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm,
-        disable_dml, disable_webgpu);
+        hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu,
+        disable_rocm, disable_dml);
     return;
   }

@@ -370,8 +370,8 @@ static void RunMultiHeadAttentionKernel(
         query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data,
         past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data,
         mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length,
-        hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm,
-        disable_dml, disable_webgpu);
+        hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu,
+        disable_rocm, disable_dml);
     return;
   }
 #endif

@@ -389,8 +389,8 @@ static void RunMultiHeadAttentionKernel(
         query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data,
         past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data,
         mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length,
-        hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm,
-        disable_dml, disable_webgpu);
+        hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu,
+        disable_rocm, disable_dml);
   }

   if (kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
@@ -406,13 +406,30 @@ static void RunMultiHeadAttentionKernel(
         query_data, key_data, value_data, kv_data, qkv_data, bias_data, attention_bias_data,
         past_key_data, past_value_data, present_key_data, present_value_data, key_padding_mask_data,
         mask_type, output_data, num_heads, batch_size, sequence_length, kv_sequence_length,
-        hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_rocm,
-        disable_dml, disable_webgpu);
+        hidden_size, v_hidden_size, is_static_kv, use_float16, disable_cpu, disable_cuda, disable_webgpu,
+        disable_rocm, disable_dml);
   }
 }

-static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_webgpu = true, bool disable_cpu = false,
-                                       bool disable_cuda = false) {
+enum RunMultiHeadAttentionTestToggles : uint32_t {
+  DISABLE_NONE = 0,
+  DISABLE_CPU = 1 << 0,
+  DISABLE_CUDA = 1 << 1,
+  DISABLE_WEBGPU = 1 << 2,
+};
+inline RunMultiHeadAttentionTestToggles operator|(RunMultiHeadAttentionTestToggles a, RunMultiHeadAttentionTestToggles b) {
+  return static_cast<RunMultiHeadAttentionTestToggles>(static_cast<uint32_t>(a) | static_cast<uint32_t>(b));
+}
+inline RunMultiHeadAttentionTestToggles operator&(RunMultiHeadAttentionTestToggles a, RunMultiHeadAttentionTestToggles b) {
+  return static_cast<RunMultiHeadAttentionTestToggles>(static_cast<uint32_t>(a) & static_cast<uint32_t>(b));
+}
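+// DISABLE_NONE runs every provider; combine flags with |, e.g. DISABLE_CPU | DISABLE_WEBGPU,
+// to skip providers that do not support a given case.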
+
+static void RunMultiHeadAttentionTests(AttentionTestData& data,
+                                       RunMultiHeadAttentionTestToggles toggles = DISABLE_NONE) {
+  bool disable_cpu = toggles & DISABLE_CPU;
+  bool disable_cuda = toggles & DISABLE_CUDA;
+  bool disable_webgpu = toggles & DISABLE_WEBGPU;
+
   if (data.fp32_output_data.size() > 0) {
     constexpr bool use_float16 = false;
@@ -509,75 +526,75 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_web
 TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize40) {
   AttentionTestData data;
   GetCrossAttentionData_HeadSize40(data);
-  RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false);
+  RunMultiHeadAttentionTests(data);

   GetCrossAttentionData_HeadSize40_NoBias(data);
-  RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false);
+  RunMultiHeadAttentionTests(data);
 }

 TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask1D) {
   ROCM_GTEST_SKIP("ROCm MHA does not support mask type of MASK_1D_KEY_SEQ_LEN");
   AttentionTestData data;
   GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, true);
-  RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true, /*disable_cpu=*/true);
+  RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU);

   GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(data, true);
-  RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true, /*disable_cpu=*/true);
+  RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU);
 }

 TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask2D) {
   AttentionTestData data;
   GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, false);
-  RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true, /*disable_cpu=*/true);
+  RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU);

   GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding_NoBias(data, false);
-  RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true, /*disable_cpu=*/true);
+  RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU);
 }

 TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize32_LeftSidePadding_Mask2D) {
   AttentionTestData data;
   GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(data);
-  RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true, /*disable_cpu=*/true);
+  RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU);

   GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding_NoBias(data);
-  RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true, /*disable_cpu=*/true);
+  RunMultiHeadAttentionTests(data, DISABLE_CPU | DISABLE_WEBGPU);
 }

 TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_NoBias_NoMask_PackedKV) {
   AttentionTestData data;
   GetCrossAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedKV(data);
-  RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true);
+  RunMultiHeadAttentionTests(data, DISABLE_WEBGPU);
 }

 TEST(MultiHeadAttentionTest, SelfAttention_Batch2_HeadSize32_NoBias_NoMask_PackedQKV) {
   AttentionTestData data;
   GetSelfAttentionData_Batch2_HeadSize32_NoBias_NoMask_PackedQKV(data);
-  RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true);
+  RunMultiHeadAttentionTests(data, DISABLE_WEBGPU);
 }

 // This tests qk_head_size != v_head_size
 TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize16_8) {
   AttentionTestData data;
   GetCrossAttentionData_HeadSize16_8(data);
-  RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false);
+  RunMultiHeadAttentionTests(data);

   GetCrossAttentionData_HeadSize16_8_NoBias(data);
-  RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false);
+  RunMultiHeadAttentionTests(data);
 }

 TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize16) {
   AttentionTestData data;
   GetCrossAttentionData_HeadSize16(data);
-  RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false);
+  RunMultiHeadAttentionTests(data);

   GetCrossAttentionData_HeadSize16_NoBias(data);
-  RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false);
+  RunMultiHeadAttentionTests(data);
 }

 TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize8) {
   AttentionTestData data;
   GetCrossAttentionData_HeadSize8_NoBias(data);
-  RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false, false, true);
+  RunMultiHeadAttentionTests(data, DISABLE_CUDA);
 }

 // TODO (pavignol): Fix this regression
@@ -587,7 +604,7 @@ TEST(MultiHeadAttentionTest, CrossAttentionWithPast) {
   ROCM_GTEST_SKIP("ROCm MHA only support head_size >= 8");
   AttentionTestData data;
   GetCrossAttentionDataWithPast(data);
-  RunMultiHeadAttentionTests(data);
+  RunMultiHeadAttentionTests(data, DISABLE_WEBGPU);
 }
 #endif

@@ -595,40 +612,40 @@ TEST(MultiHeadAttentionTest, SelfAttention_WithPast_WithAttnBias_ForT5) {
   ROCM_GTEST_SKIP("ROCm MHA only support head_size >= 8");
   AttentionTestData data;
   GetSelfAttentionData_WithPast_WithAttnBias_ForT5(data);
-  RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false, true);
+  RunMultiHeadAttentionTests(data, DISABLE_CPU);
 }

 TEST(MultiHeadAttentionTest, AttentionCutlassAttnBias) {
   // ROCM_GTEST_SKIP("ROCm does not support cutlass");
   AttentionTestData data;
   GetAttentionDataCutlassAttnBias(data);
-  RunMultiHeadAttentionTests(data);
+  RunMultiHeadAttentionTests(data, DISABLE_WEBGPU);
 }

 TEST(MultiHeadAttentionTest, CrossAttention_DiffSequenceLengths) {
   // Whisper decoder cross attention without mask and different sequence lengths for Q and K/V
   AttentionTestData data;
   GetCrossAttentionData_DiffSequenceLengths(data);
-  RunMultiHeadAttentionTests(data);
+  RunMultiHeadAttentionTests(data, DISABLE_WEBGPU);

   GetCrossAttentionData_DiffSequenceLengths_HeadSize8(data);
-  RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true, /*disable_cpu=*/false, /*disable_cuda=*/true);
+  RunMultiHeadAttentionTests(data, DISABLE_CUDA | DISABLE_WEBGPU);

   GetCrossAttentionData_DiffSequenceLengths_HeadSize8_NoBias(data);
-  RunMultiHeadAttentionTests(data, /*disable_webgpu=*/true, /*disable_cpu=*/false, /*disable_cuda=*/true);
+  RunMultiHeadAttentionTests(data, DISABLE_CUDA | DISABLE_WEBGPU);
 }

 TEST(MultiHeadAttentionTest, SelfAttention_WithPastAndPresent_NoMask_NoAttnBias) {
   // Whisper decoder self attention with past_kv and present_kv
   AttentionTestData data;
   GetSelfAttentionData_WithPastAndPresent_NoMask_NoAttnBias(data);
-  RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false);
+  RunMultiHeadAttentionTests(data);

   GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias(data);
-  RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false, /*disable_cpu=*/false, /*disable_cuda=*/true);
+  RunMultiHeadAttentionTests(data, DISABLE_CUDA);

   GetSelfAttentionData_WithPastAndPresent_HeadSize8_NoMask_NoAttnBias_NoBias(data);
-  RunMultiHeadAttentionTests(data, /*disable_webgpu=*/false, /*disable_cpu=*/false, /*disable_cuda=*/true);
+  RunMultiHeadAttentionTests(data, DISABLE_CUDA);
 }

 // This test is disabled since it is not used in Whisper anymore, and it fails in ROCm.