-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[wasm] Add BroadcastArgs kernel #7290
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with some minor nits, but see my comment about a more general solution.
const s0Vals = backend.typedArrayFromHeap(s0); | ||
const s1Vals = backend.typedArrayFromHeap(s1); | ||
|
||
const broadcastShape = backend_util.assertAndGetBroadcastShape( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we should have a generic way to fall back to the CPU implementation for ops that aren't implemented in WASM. That would allow us to simply not implement this op, and it would use the fallback by default. IIRC we have something like this for the WebGL backend, but it's mostly used to avoid sending tensors to the GPU for really small ops.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm thinking about the same thing. I may do it in a separate CL later.
This is not the first wasm kernel runs with cpu util/kernel function, and there was no general bridge for wasm-cpu kernel. I will revisit all the cpu function usages in wasm and find a general solution for them after I am familiar with more kernels.
const s1Vals = backend.typedArrayFromHeap(s1); | ||
|
||
const broadcastShape = backend_util.assertAndGetBroadcastShape( | ||
Array.from(s0Vals), Array.from(s1Vals)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: We can probably change the type of assertAndGetBroadcastShape
to accept typed arrays.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not worthy - Only BroadcastArgs kernel (wasm, webgl, and webgpu) passes Int32Array to it, while most of the callers pass the shape array in number[]. Besides, not all the typed array can be the input, where the tensor type is guarded in core.
Co-authored-by: Matthew Soulanille <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Considering the use case of this kernel and there is already an util for it, this kernel is executed completely in JS.
To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.
This change is