-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
webgpu: support MultiHeadAttention operator #22144
webgpu: support MultiHeadAttention operator #22144
Conversation
|
||
class TransferBSDToBNSHProgram final : public Program<TransferBSDToBNSHProgram> { | ||
public: | ||
TransferBSDToBNSHProgram(const std::string& kernel_name, bool has_bias) : Program{kernel_name}, has_bias_(has_bias) {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TransferBSDToBNSHProgram(const std::string& kernel_name, bool has_bias) : Program{kernel_name}, has_bias_(has_bias) {} | |
TransferBSDToBNSHProgram(bool has_bias) : Program{"TransferBSDToBNSH"}, has_bias_(has_bias) {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
shader.AddOutput("present_key", ShaderVariable::UseUniform); | ||
} | ||
|
||
shader.AppendImplementation("const TILE_SIZE = ", tile_size_, "u;\n") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it seems that cache keys for the programs are not set correctly.
for example, here tile_size_
is used as a part of the shader source code, but it is not set in the cache key. Use program.CacheHint()
to set the cache key.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TILE_SIZE can also be declared in overridable constants.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
shader.AppendImplementation("var<workgroup> thread_max: array<f32, ", work_group_size_, ">;\n") | ||
.AppendImplementation("var<workgroup> thread_sum: array<f32, ", work_group_size_, ">;\n"); | ||
|
||
std::string f32_str = components_ == 4 ? "vec4<f32>" : (components_ == 2 ? "vec2<f32>" : "f32"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can use x_value_t
for the value type of x.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No matter x
's type is f32 or f16, the program only uses f32 to define max and sum values.
I didn't see any call to set program cache key. This may be correct (if necessary information is already in uniform). need to confirm. |
Merged to latest. |
Description
Motivation and Context