Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wasm] Add BroadcastArgs kernel #7290

Merged
merged 9 commits into from
Jan 23, 2023

Conversation

chunnienc
Copy link
Collaborator

@chunnienc chunnienc commented Jan 19, 2023

Considering the use case of this kernel and there is already an util for it, this kernel is executed completely in JS.

To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.


This change is Reviewable

@chunnienc chunnienc marked this pull request as ready for review January 19, 2023 19:17
Copy link
Member

@mattsoulanille mattsoulanille left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with some minor nits, but see my comment about a more general solution.

tfjs-backend-wasm/src/backend_wasm.ts Outdated Show resolved Hide resolved
const s0Vals = backend.typedArrayFromHeap(s0);
const s1Vals = backend.typedArrayFromHeap(s1);

const broadcastShape = backend_util.assertAndGetBroadcastShape(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should have a generic way to fall back to the CPU implementation for ops that aren't implemented in WASM. That would allow us to simply not implement this op, and it would use the fallback by default. IIRC we have something like this for the WebGL backend, but it's mostly used to avoid sending tensors to the GPU for really small ops.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm thinking about the same thing. I may do it in a separate CL later.
This is not the first wasm kernel runs with cpu util/kernel function, and there was no general bridge for wasm-cpu kernel. I will revisit all the cpu function usages in wasm and find a general solution for them after I am familiar with more kernels.

const s1Vals = backend.typedArrayFromHeap(s1);

const broadcastShape = backend_util.assertAndGetBroadcastShape(
Array.from(s0Vals), Array.from(s1Vals));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: We can probably change the type of assertAndGetBroadcastShape to accept typed arrays.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not worthy - Only BroadcastArgs kernel (wasm, webgl, and webgpu) passes Int32Array to it, while most of the callers pass the shape array in number[]. Besides, not all the typed array can be the input, where the tensor type is guarded in core.

Copy link
Collaborator

@Linchenn Linchenn left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@chunnienc chunnienc merged commit ceb69e6 into tensorflow:master Jan 23, 2023
@chunnienc chunnienc deleted the wasm-broadcastargs branch January 23, 2023 21:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants