[js/WebGPU] webgpu performance gains not showing with transformers.js #17373
continued from #17167
Performance gains we are seeing with webgpu in standalone tests are not showing up with transformers.js.
fyi @dakenf, @fs-eire, @xenova

Comments
I did some testing with all models from the transformers.js demo app. I see some good and some not-so-good numbers. Trying to understand what I'm seeing. |
I assume this testing was outside of transformers.js (i.e., just using standalone onnxruntime-web)? Also, could you share which models you had issues with? |
I used your demo app and added some perf logging, then tried to reproduce this with my standalone app, cloning the input shapes I see in the demo app. Just stepping through the code ... the kv-caches for generative models must be hurting webgpu a lot. @fs-eire is looking at io bindings that should help a lot. |
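As an illustration of why io bindings matter here, a minimal sketch of keeping the kv-cache GPU-resident between decoding steps, assuming onnxruntime-web's `preferredOutputLocation` session option; the model path, tensor names (`past_key_values.0.key` / `present.0.key`), and helper functions are hypothetical:

```js
// Sketch only: keep kv-cache outputs on the GPU so each decoding step
// does not round-trip the cache through JS. Names below are assumptions.
import * as ort from 'onnxruntime-web/webgpu';

const session = await ort.InferenceSession.create('decoder_model_merged.onnx', {
  executionProviders: ['webgpu'],
  // Ask the runtime to leave the present.* outputs as GPU buffers.
  preferredOutputLocation: {
    'present.0.key': 'gpu-buffer',
    'present.0.value': 'gpu-buffer',
  },
});

let feeds = buildInitialFeeds(); // hypothetical helper: input_ids, mask, empty past
for (let step = 0; step < maxNewTokens; step++) {
  const results = await session.run(feeds);
  // Re-feed the GPU-resident present.* tensors as past.* without a download;
  // only the logits need to come back to the CPU for sampling.
  feeds['past_key_values.0.key'] = results['present.0.key'];
  feeds['past_key_values.0.value'] = results['present.0.value'];
  feeds['input_ids'] = sampleNextToken(results.logits); // hypothetical helper
}
```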
I've made a script that summarizes profiler output, and here's what I got for SD after some optimizations for Attention and InstanceNorm:
Will try to add tiling to MatMul/Gemm in Attention (see the sketch after this comment). It should drastically improve performance. Maybe it's possible to reuse the existing packed MatMul implementation. Typical Attention performance now is:
Then the next step would be NhwcConv with vec4 optimizations. So in theory, after dealing with these two kernels, I can get to the same performance as these folks have at https://websd.mlc.ai/, but for general-purpose models rather than precompiled ones. Also, fp16 gave about a 1.5x performance benefit, but it does not work on Windows for some reason, and on a Mac webgpu does not support profiling. So I'm slightly disappointed, but I will make a PR with the changes anyway, since at some point it will work on Windows too. |
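For readers unfamiliar with the tiling idea mentioned above, here is a plain-JS reference of what blocking a MatMul buys; this is a sketch of the access pattern, not the actual WGSL kernel:

```js
// CPU reference for tiled matmul: C = A (MxK) * B (KxN), row-major.
// Each TILE x TILE block of A and B is reused TILE times; a WGSL kernel
// would stage the tiles in workgroup-shared memory instead of relying
// on caches, which is where the speedup comes from.
function matmulTiled(A, B, M, K, N, TILE = 16) {
  const C = new Float32Array(M * N);
  for (let i0 = 0; i0 < M; i0 += TILE) {
    for (let j0 = 0; j0 < N; j0 += TILE) {
      for (let k0 = 0; k0 < K; k0 += TILE) {
        const iMax = Math.min(i0 + TILE, M);
        const jMax = Math.min(j0 + TILE, N);
        const kMax = Math.min(k0 + TILE, K);
        for (let i = i0; i < iMax; i++) {
          for (let k = k0; k < kMax; k++) {
            const a = A[i * K + k];
            for (let j = j0; j < jMax; j++) {
              C[i * N + j] += a * B[k * N + j];
            }
          }
        }
      }
    }
  }
  return C;
}
```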
Awesome! We had the f16 support on the list, but you beat us to it. The Windows support is coming soon as far as I know, and it might already be in Canary (will check on it). For the transformers.js performance:
Hope we have something soon to address both. |
I guess after I finish with Attention, it should speed up LLMs too. However, right now my implementation is quite narrow and does not support past/present or mask. But it won't be hard to add them. |
@guschmue amazing!!! On that note, do you have a list of models which should be supported fully in transformers.js (bert, vit, sam maybe)? I can start developing some example applications for them. |
I think I can come up with a list of models that work well. |
Finally got some good results with the Attention optimizations. I expected a bit better, but 4x is still good. Will revisit it at some point later with some kind of flash-attention implementation; most likely it will be quite challenging on webgpu.
|
On Windows, to support fp16, you need to add a flag. And on Mac, profiling is not supported right now, because we are using an in-pass timestamp query, which is not supported on Mac. We should change it to use the core feature. |
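For context, both capabilities are gated on WebGPU adapter features; a sketch of requesting them (the feature names come from the WebGPU spec, but whether a given browser/OS exposes them varies):

```js
// Request f16 and the core timestamp-query feature when available.
const adapter = await navigator.gpu.requestAdapter();
const requiredFeatures = ['shader-f16', 'timestamp-query']
  .filter((f) => adapter.features.has(f));
const device = await adapter.requestDevice({ requiredFeatures });

// With core timestamp-query, timestamps are taken at pass boundaries
// (supported on Mac), unlike the Chromium-only in-pass variant.
const querySet = device.createQuerySet({ type: 'timestamp', count: 2 });
const encoder = device.createCommandEncoder();
const pass = encoder.beginComputePass({
  timestampWrites: {
    querySet,
    beginningOfPassWriteIndex: 0,
    endOfPassWriteIndex: 1,
  },
});
```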
@dakenf Could you please share the onnx model you tested? I would like to look at the performance issue in it. |
It's Stable Diffusion 2.1. My changes to support 64-bit wasm builds and big-model loading are not merged yet, so it will be hard to reproduce locally. |
You can try this one https://huggingface.co/aislamov/stable-diffusion-2-1-base-onnx/tree/main/vae_decoder and feed it with random data in
|
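A minimal repro along those lines; the `latent_sample` input name and the [1, 4, 64, 64] shape are assumptions based on typical diffusers ONNX exports, not taken from this thread:

```js
import * as ort from 'onnxruntime-web/webgpu';

// Feed the VAE decoder random latents in roughly the right range.
const session = await ort.InferenceSession.create('vae_decoder/model.onnx', {
  executionProviders: ['webgpu'],
});
const latents = Float32Array.from(
  { length: 1 * 4 * 64 * 64 },
  () => Math.random() * 2 - 1
);
const results = await session.run({
  latent_sample: new ort.Tensor('float32', latents, [1, 4, 64, 64]),
});
console.log(results);
```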
@dakenf I ran into the errors below:
Maybe it's like you said. |
Yeah, sorry, did not think about that (there is a PR for that too :). Here is the model version with just one file https://huggingface.co/aislamov/stable-diffusion-2-1-base-onnx/tree/9f697c96d42e5c09437ff14b0a2b287366ce488d/vae_decoder |
Got the results with fp16 on Windows: Conv got slightly faster, 874.26 ms -> 656.86 ms, and Attention went from 207.21 ms -> 174.33 ms. So @qjia7, if you can spend some time checking the Conv, I'd be very thankful. Going to focus on refining the existing PRs, since I've reached my sub-one-second goal (it was more than 70 seconds on my machine when I started).
|
Excellent work! @dakenf, fyi, I can get the profiling data from https://huggingface.co/aislamov/stable-diffusion-2-1-base-onnx/tree/9f697c96d42e5c09437ff14b0a2b287366ce488d/vae_decoder. From the profiling result, it seems that from some |
I was thinking of adding the NhwcConv operator, as described here: https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md |
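The layout difference is easy to state in code; a sketch of the NCHW -> NHWC shuffle (illustrative only; the runtime would do this with a Transpose, not in JS):

```js
// NHWC stores all channels of a pixel contiguously, so a WebGPU Conv
// kernel can load them as vec4s; NCHW scatters them H*W elements apart.
function nchwToNhwc(src, N, C, H, W) {
  const dst = new Float32Array(src.length);
  for (let n = 0; n < N; n++)
    for (let c = 0; c < C; c++)
      for (let h = 0; h < H; h++)
        for (let w = 0; w < W; w++)
          dst[((n * H + h) * W + w) * C + c] = src[((n * C + c) * H + h) * W + w];
  return dst;
}
```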
#18108 just got merged, so the profiling should work on macOS now. |