-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[webgpu] Tweak WGSL entry point emission so '{' isn't visually swallowed #6764
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks for the unification!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with a comment. But I am ok to merge this PR first and see if we can do further cleanup to address the comment in follow-up PRs.
@builtin(num_workgroups) NumWorkgroups : vec3<u32>) { | ||
localId = LocalId; | ||
globalId = GlobalId; | ||
numWorkgroups = NumWorkgroups; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you check if numWorkgroups = NumWorkgroups;
is still needed for this case? I expect if we use non-linear work group size, we can totally remove the dependency on num_workgroups
and getGlobalIndex()
. I am ok to do it in a follow-up PR. Thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't feel it's safe to assume getGlobalIndex()
is unused elsewhere. E.g. we currently have this dependency chain in depth-wise conv2d: main() -> getOutputCoords() -> getGlobalIndex()
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
getGlobalIndex()
is only called when x.length === outRank
in getOutputCoords
, which means a flat dispatch layout. Let's keep it for now. Maybe it's not necessary to worry about the very small overhead of NumWorkgroups
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I do find the x.length === outRank
branch is taken while running depthwise_conv2d_test, is this expected?
To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.
This change is