[webgpu] Tweak WGSL entry point emission so '{' isn't visually swallowed #6764

Merged
merged 3 commits into from
Aug 22, 2022


Contributor

@hujiajie hujiajie commented Aug 18, 2022

To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.



@gyagp gyagp left a comment

LGTM. Thanks for the unification!

@gyagp gyagp requested review from qjia7 and xhcao August 18, 2022 08:54
Contributor

@qjia7 qjia7 left a comment

LGTM with one comment. I am OK with merging this PR first and addressing the comment with further cleanup in follow-up PRs.

@builtin(num_workgroups) NumWorkgroups : vec3<u32>) {
localId = LocalId;
globalId = GlobalId;
numWorkgroups = NumWorkgroups;
Contributor

Can you check whether numWorkgroups = NumWorkgroups; is still needed in this case? I expect that with a non-linear workgroup size we can drop the dependency on num_workgroups and getGlobalIndex() entirely. I am OK with doing it in a follow-up PR. Thanks.
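For context on why a flattened index depends on num_workgroups at all, here is a minimal CPU-side sketch. It is not the tfjs implementation: flatGlobalIndex, Vec3, and the parameter names are hypothetical, and it only assumes the usual x-fastest ordering of workgroups and invocations.

```typescript
// Hypothetical CPU-side model of a flattened invocation index; not tfjs code.
type Vec3 = [number, number, number];

function flatGlobalIndex(
    workgroupId: Vec3, localId: Vec3, workgroupSize: Vec3,
    numWorkgroups: Vec3): number {
  const [wx, wy, wz] = workgroupSize;
  const invocationsPerGroup = wx * wy * wz;
  // Index of this invocation inside its own workgroup (x varies fastest).
  const localIndex = localId[2] * wx * wy + localId[1] * wx + localId[0];
  // Index of the workgroup inside the dispatch grid. This is the step that
  // needs the grid extents, i.e. the value num_workgroups provides.
  const groupIndex = workgroupId[2] * numWorkgroups[0] * numWorkgroups[1] +
      workgroupId[1] * numWorkgroups[0] + workgroupId[0];
  return groupIndex * invocationsPerGroup + localIndex;
}
```

With a workgroup size of [64, 1, 1] and a 2 x 2 x 1 grid, the invocation with local id [3, 0, 0] in workgroup [1, 1, 0] maps to 195. If the dispatch only ever extends along x, the y and z terms vanish and the grid extents are no longer needed.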

Contributor Author

I don't feel it's safe to assume getGlobalIndex() is unused elsewhere. E.g. we currently have this dependency chain in depth-wise conv2d: main() -> getOutputCoords() -> getGlobalIndex().

Contributor

getGlobalIndex() is only called when x.length === outRank in getOutputCoords, which means a flat dispatch layout. Let's keep it for now. The overhead of NumWorkgroups is very small, so it is probably not worth worrying about.
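The condition described here can be sketched roughly as follows. This is an illustrative model only: the DispatchLayout shape, usesFlatIndexPath, and coordsFromFlatIndex are hypothetical names, and the row-major decoding is an assumption rather than the code getOutputCoords actually emits.

```typescript
// Hypothetical model of the branch discussed above; names are illustrative.
interface DispatchLayout {
  x: number[];   // output dimensions mapped to the x dispatch axis
  y?: number[];  // output dimensions mapped to y, if any
  z?: number[];  // output dimensions mapped to z, if any
}

// When every output dimension is mapped to x, the layout is flat and the
// output coordinates have to be decoded from a single global index.
function usesFlatIndexPath(layout: DispatchLayout, outRank: number): boolean {
  return layout.x.length === outRank;
}

// Row-major decoding of a flat index into coordinates (assumed ordering).
function coordsFromFlatIndex(index: number, shape: number[]): number[] {
  const coords = new Array<number>(shape.length).fill(0);
  let rest = index;
  for (let d = shape.length - 1; d >= 0; d--) {
    coords[d] = rest % shape[d];
    rest = Math.floor(rest / shape[d]);
  }
  return coords;
}
```

Under these assumptions, a rank-4 output with layout {x: [0, 1, 2, 3]} takes the flat path, while {x: [3], y: [2], z: [0, 1]} does not.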

Contributor Author

I do find that the x.length === outRank branch is taken while running depthwise_conv2d_test. Is this expected?

@qjia7 qjia7 merged commit 2a25492 into tensorflow:master Aug 22, 2022
4 participants