
[js/WebGPU] webgpu performance gains not showing with transformers.js #17373

Open
guschmue opened this issue Aug 31, 2023 · 19 comments

@guschmue
Contributor

continued from #17167

Performance gains we are seeing with webgpu in standalone tests are not showing up with transformers.js.

fyi @dakenf, @fs-eire, @xenova

@guschmue guschmue added the ep:WebGPU ort-web webgpu provider label Aug 31, 2023
@guschmue guschmue self-assigned this Aug 31, 2023
@github-actions github-actions bot added the platform:web issues related to ONNX Runtime web; typically submitted using template label Aug 31, 2023
@guschmue
Contributor Author

guschmue commented Sep 5, 2023

I did some testing with all models from the transformers.js demo app. I see some good and some not so good numbers. Trying to understand what I'm seeing.

@xenova

xenova commented Sep 5, 2023

I did some testing with all models from the transformers.js demo app. I see some good and some not so good numbers. Trying to understand what I'm seeing.

I assume this testing was outside of transformers.js (i.e., just using standalone onnxruntime-web)? Also, could you share which models you had issues with?

@guschmue
Contributor Author

guschmue commented Sep 5, 2023

I used your demo app and added some perf logging, then tried to reproduce the results with my standalone app, cloning the input shapes I see in the demo app.
There are some differences I can't explain yet, for example t5-decoder and whisper-decoder. Looking at those two first.
distill-bert-squad and detr-resnet50 are not working on webgpu ... pretty sure we can fix those.

Just stepping through the code ... the kv-caches for generative models must be hurting webgpu a lot. @fs-eire is looking at IO bindings, which should help a lot.
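For reference, the perf logging was just a thin wrapper around session.run; a minimal sketch (my illustration, not the demo app's actual code), assuming onnxruntime-web's standard API:

```js
// Hypothetical perf-logging wrapper: times each session.run() call and logs
// the input shapes alongside the latency.
import * as ort from 'onnxruntime-web';

async function timedRun(session, feeds, tag = 'run') {
  const t0 = performance.now();
  const results = await session.run(feeds);
  const t1 = performance.now();
  const shapes = Object.fromEntries(
    Object.entries(feeds).map(([name, tensor]) => [name, tensor.dims]));
  console.log(`[perf] ${tag}: ${(t1 - t0).toFixed(2)} ms`, shapes);
  return results;
}
```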

@dakenf
Contributor

dakenf commented Sep 6, 2023

I've made a script that summarizes profiler output, and here's what I got for SD after some optimizations for Attention and InstanceNorm:

Operation: Conv, Invocations: 64, Total Time: 888.88 ms
Operation: Attention, Invocations: 16, Total Time: 824.56 ms
Operation: MatMul, Invocations: 112, Total Time: 54.34 ms
Operation: MultiHeadAttention, Invocations: 16, Total Time: 17.78 ms
Operation: SkipLayerNormalization, Invocations: 32, Total Time: 7.49 ms
Operation: InstanceNormalization, Invocations: 61, Total Time: 6.51 ms
Operation: Transpose, Invocations: 156, Total Time: 4.59 ms
Operation: Add, Invocations: 153, Total Time: 3.16 ms
Operation: Mul, Invocations: 109, Total Time: 2.45 ms
Operation: Gemm, Invocations: 24, Total Time: 1.91 ms
Operation: LayerNormalization, Invocations: 16, Total Time: 1.83 ms
Operation: BiasSplitGelu, Invocations: 16, Total Time: 1.29 ms
Operation: Sigmoid, Invocations: 47, Total Time: 0.76 ms
Operation: Concat, Invocations: 14, Total Time: 0.39 ms
Operation: BiasAdd, Invocations: 16, Total Time: 0.35 ms
Operation: Resize, Invocations: 3, Total Time: 0.09 ms
Operation: Slice, Invocations: 1, Total Time: 0.01 ms
Operation: Expand, Invocations: 1, Total Time: 0.00 ms
Operation: Cos, Invocations: 1, Total Time: 0.00 ms
Operation: Sin, Invocations: 1, Total Time: 0.00 ms

Summary:
Total Invocations: 859
Total Time Consumed: 1816.40 ms
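
For context, here is a minimal sketch of this kind of summarizer (my illustration, not dakenf's actual script); it parses the `[profiling] kernel "...|[Op] ..." execution time: N ns` lines that ort-web's WebGPU profiling prints and totals the time per operator:

```js
// Hypothetical summarizer: aggregate per-op kernel time from a saved log file.
import { readFileSync } from 'node:fs';

const lines = readFileSync(process.argv[2], 'utf8').split('\n');
const totals = new Map(); // op name -> { count, ns }

for (const line of lines) {
  const m = line.match(
    /\[profiling\] kernel "[^"]*\[(\w+)\][^"]*" execution time: (\d+) ns/);
  if (!m) continue;
  const [, op, ns] = m;
  const entry = totals.get(op) ?? { count: 0, ns: 0 };
  entry.count += 1;
  entry.ns += Number(ns);
  totals.set(op, entry);
}

let totalNs = 0, totalCalls = 0;
for (const [op, { count, ns }] of [...totals].sort((a, b) => b[1].ns - a[1].ns)) {
  console.log(`Operation: ${op}, Invocations: ${count}, Total Time: ${(ns / 1e6).toFixed(2)} ms`);
  totalNs += ns;
  totalCalls += count;
}
console.log(`\nSummary:\nTotal Invocations: ${totalCalls}\nTotal Time Consumed: ${(totalNs / 1e6).toFixed(2)} ms`);
```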

Will try to add tiling to MatMul/Gemm in Attention. It should drastically improve performance. Maybe it's possible to use the existing packed matmul implementation. Typical Attention performance now is:

[profiling] kernel "162594264|[Attention] 162594264" execution time: 7876608 ns // split input into QKV and multiply with weights
[profiling] kernel "162594264|[Attention] 162594264" execution time: 66883584 ns // attention probs
[profiling] kernel "162594264|[Attention] 162594264" execution time: 14159872 ns // softmax
[profiling] kernel "162594264|[Attention] 162594264" execution time: 11609088 ns // attention score
[profiling] kernel "162594264|[Attention] 162594264" execution time: 30720 ns // transpose output

Then the next step would be NhwcConv with vec4 optimizations. So in theory, after dealing with these two kernels, I can get it to the same performance as these folks have https://websd.mlc.ai/ but for general-purpose models, not precompiled ones.

Also, fp16 gave about a 1.5x performance benefit. But it does not work on Windows for some reason, and on a Mac webgpu does not support profiling. So I'm slightly disappointed, but will make a PR with the changes anyway since at some point it will work on Windows too.

@guschmue
Contributor Author

guschmue commented Sep 7, 2023

awesome! We had the f16 support on the list but you beat us to it. The Windows support is coming soon as far as I know, and it might already be in Canary (will check on it).

For the transformers.js performance:
I finally have all the correct shapes in my benchmark app to match what the demo app does.
We are doing OK on encoders and non-generative models (say CLIP).
What is hurting us is the one-token-at-a-time decoding for generative models.
Two issues:

  1. Some ops still fall back to CPU. In many cases that is fine, but a few are not. Working on that one; I'm looking at t5-decoder and whisper-decoder first.
  2. The past*/present* key/value inputs and outputs: they are large and we copy them across devices, which is super costly, and all of it for one token. In a Python app one would use onnxruntime IO bindings for them, which ort-web doesn't have yet. We'll add something for that (see the sketch below).

Hope we have something soon to address both.
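
A rough sketch of what addressing issue 2 could look like once ort-web has GPU-resident tensors (not the actual implementation; the preferredOutputLocation option reflects what onnxruntime-web later shipped for its WebGPU EP, and the past*/present* names are illustrative):

```js
import * as ort from 'onnxruntime-web';

// Sketch only (assumption): keep the decoder's present.* KV outputs in GPU
// buffers and feed them straight back as past.* inputs, so the large KV cache
// never makes a GPU->CPU->GPU round trip for a single new token.
async function createDecoderSession(url) {
  return ort.InferenceSession.create(url, {
    executionProviders: ['webgpu'],
    // ask ort-web to leave these outputs on the GPU instead of downloading them
    preferredOutputLocation: {
      'present.0.key': 'gpu-buffer',
      'present.0.value': 'gpu-buffer',
    },
  });
}

async function decodeStep(session, feeds) {
  const outputs = await session.run(feeds);
  // GPU-resident tensors are reused directly as next-step inputs.
  // (In real code, dispose() of the replaced GPU tensors to avoid leaking buffers.)
  feeds['past_key_values.0.key'] = outputs['present.0.key'];
  feeds['past_key_values.0.value'] = outputs['present.0.value'];
  return outputs.logits; // caller samples the next token from this
}
```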

@dakenf
Contributor

dakenf commented Sep 7, 2023

I guess after I finish with Attention, it should speed up LLMs too. However, right now my implementation is quite narrow and does not support past/present and mask. But it won't be hard to add them.

@xenova

xenova commented Sep 7, 2023

@guschmue amazing!!! On that note, do you have a list of models which should be supported fully in transformers.js (bert, vit, sam maybe)? I can start developing some example applications for them.

@guschmue
Contributor Author

guschmue commented Sep 8, 2023

I think I can come up with a list of models that work well.

@dakenf
Contributor

dakenf commented Sep 10, 2023

Finally got some good results with the Attention optimizations. I expected a bit better, but 4x is still good. Will revisit it at some point later with some kind of flash attention implementation; most likely it will be quite challenging on webgpu.
But now, if I manage to optimize Conv, I'll get it to less than 0.5 s per step.

Operation: Conv, Invocations: 64, Total Time: 874.26 ms
Operation: Attention, Invocations: 16, Total Time: 207.21 ms
Operation: MatMul, Invocations: 111, Total Time: 53.95 ms
Operation: MultiHeadAttention, Invocations: 17, Total Time: 7.74 ms
Operation: SkipLayerNormalization, Invocations: 32, Total Time: 7.00 ms
Operation: InstanceNormalization, Invocations: 61, Total Time: 6.49 ms
Operation: Transpose, Invocations: 156, Total Time: 4.52 ms
Operation: Add, Invocations: 153, Total Time: 3.20 ms
Operation: Mul, Invocations: 109, Total Time: 2.45 ms
Operation: Gemm, Invocations: 24, Total Time: 1.89 ms
Operation: LayerNormalization, Invocations: 16, Total Time: 1.83 ms
Operation: BiasSplitGelu, Invocations: 16, Total Time: 1.28 ms
Operation: Sigmoid, Invocations: 47, Total Time: 0.76 ms
Operation: Concat, Invocations: 14, Total Time: 0.39 ms
Operation: BiasAdd, Invocations: 16, Total Time: 0.35 ms
Operation: Resize, Invocations: 3, Total Time: 0.10 ms
Operation: Slice, Invocations: 1, Total Time: 0.01 ms
Operation: Cos, Invocations: 1, Total Time: 0.00 ms
Operation: Sin, Invocations: 1, Total Time: 0.00 ms
Operation: Expand, Invocations: 1, Total Time: 0.00 ms

Summary:
Total Invocations: 859
Total Time Consumed: 1173.43 ms

@qjia7
Contributor

qjia7 commented Sep 11, 2023

Also, fp16 gave about a 1.5x performance benefit. But it does not work on Windows for some reason, and on a Mac webgpu does not support profiling.

On Windows, to support fp16, you need to add the flag --enable-dawn-features=allow_unsafe_apis,use_dxc to Chrome Canary. Compared with Mac, use_dxc is additionally needed. Please also make sure the Windows SDK is installed so that dxc.dll can be found. But I think DXC will be supported by default on Windows soon.

And on Mac, the profiling is not supported now because we are using an in-pass timestamp query, which is not supported there. We should change it to use the core feature timestamp-query instead of the extension timestamp-query-inside-passes. Reference here. Maybe @gyagp can provide a similar PR to quickly fix it.
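
For anyone checking their own setup, a small sketch using the standard WebGPU API (not ort-web internals) to see whether the browser exposes shader-f16 and the core timestamp-query feature:

```js
// Check the WebGPU features relevant here: shader-f16 (for the fp16 path)
// and the core timestamp-query feature (needed for kernel profiling).
const adapter = await navigator.gpu.requestAdapter();
console.log('shader-f16:', adapter.features.has('shader-f16'));
console.log('timestamp-query:', adapter.features.has('timestamp-query'));

// Request a device with whichever of them the adapter actually supports.
const requiredFeatures = ['shader-f16', 'timestamp-query']
  .filter((f) => adapter.features.has(f));
const device = await adapter.requestDevice({ requiredFeatures });
console.log('device features:', [...device.features]);
```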

@qjia7
Contributor

qjia7 commented Sep 11, 2023

Finally got some good results with the Attention optimizations. I expected a bit better, but 4x is still good. Will revisit it at some point later with some kind of flash attention implementation; most likely it will be quite challenging on webgpu. But now, if I manage to optimize Conv, I'll get it to less than 0.5 s per step.

@dakenf Could you please share the onnx model you tested? I would like to look at the performance issue in it.

@dakenf
Contributor

dakenf commented Sep 11, 2023

It's Stable Diffusion 2.1. My changes to support a 64-bit wasm build and big model loading are not merged yet, so it will be hard to reproduce locally.
Thanks for the fp16 flag info, will try it.

@dakenf
Contributor

dakenf commented Sep 11, 2023

@dakenf Could you please share the onnx model you tested? I would like to look at the performance issue in it.

You can try this one https://huggingface.co/aislamov/stable-diffusion-2-1-base-onnx/tree/main/vae_decoder and feed it random data in the latent_sample input with shape [1, 4, 64, 64].
Apparently this one is even slower than the SD UNet:

Operation: Conv, Invocations: 33, Total Time: 2908.64 ms
Operation: InstanceNormalization, Invocations: 30, Total Time: 49.30 ms
Operation: Transpose, Invocations: 67, Total Time: 34.44 ms
Operation: Mul, Invocations: 61, Total Time: 16.24 ms
Operation: Add, Invocations: 49, Total Time: 10.64 ms
Operation: MatMul, Invocations: 6, Total Time: 5.16 ms
Operation: Sigmoid, Invocations: 29, Total Time: 4.27 ms
Operation: Resize, Invocations: 3, Total Time: 2.00 ms
Operation: Softmax, Invocations: 1, Total Time: 0.44 ms
Operation: Sqrt, Invocations: 2, Total Time: 0.01 ms
Operation: Div, Invocations: 1, Total Time: 0.00 ms

Summary:
Total Invocations: 282
Total Time Consumed: 3031.13 ms
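
For reference, a minimal sketch of that kind of standalone test with plain onnxruntime-web (the latent_sample name and [1, 4, 64, 64] shape come from the comment above; the model file path is illustrative):

```js
import * as ort from 'onnxruntime-web';

// Feed the VAE decoder random data on the WebGPU EP and time one run.
// 'vae_decoder/model.onnx' is a placeholder path for the model linked above.
const session = await ort.InferenceSession.create('vae_decoder/model.onnx', {
  executionProviders: ['webgpu'],
});

const dims = [1, 4, 64, 64];
const data = Float32Array.from({ length: dims.reduce((a, b) => a * b) },
                               () => Math.random() * 2 - 1);
const latent = new ort.Tensor('float32', data, dims);

const t0 = performance.now();
const outputs = await session.run({ latent_sample: latent });
console.log(`run: ${(performance.now() - t0).toFixed(1)} ms`,
            Object.keys(outputs));
```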

@qjia7
Contributor

qjia7 commented Sep 11, 2023

@dakenf I hit the error below:

Error: Can't create a session. ERROR_CODE: 1, ERROR_MESSAGE: Deserialize tensor onnx::Add_959 failed.GetFileLength for ./weights.pb failed:File is too large.

Maybe it's like you said: "My changes to support 64bit wasm build and big model loading are not merged yet so it will be hard to reproduce locally." Do you have PRs in review to enable them? Hoping they are supported soon :)

@dakenf
Contributor

dakenf commented Sep 11, 2023

Yeah, sorry, did not think about that (there is a PR for that too :). Here is the model version with just one file: https://huggingface.co/aislamov/stable-diffusion-2-1-base-onnx/tree/9f697c96d42e5c09437ff14b0a2b287366ce488d/vae_decoder
And ignore the commit message there; it is an fp32 version.

@dakenf
Contributor

dakenf commented Sep 11, 2023

Got the results with fp16 on Windows. Conv got slightly faster, 874.26 ms -> 656.86 ms, and Attention 207.21 ms -> 174.33 ms.
But MatMul and SkipLayerNormalization got slightly slower.

So @qjia7, if you can spend some time checking the Conv I'd be very thankful. I'm going to focus on refining the existing PRs since I've reached my sub-one-second goal (it was more than 70 seconds on my machine when I started).

Operation: Conv, Invocations: 64, Total Time: 656.86 ms
Operation: Attention, Invocations: 16, Total Time: 174.33 ms
Operation: MatMul, Invocations: 113, Total Time: 79.07 ms
Operation: SkipLayerNormalization, Invocations: 32, Total Time: 14.86 ms
Operation: MultiHeadAttention, Invocations: 16, Total Time: 7.26 ms
Operation: InstanceNormalization, Invocations: 61, Total Time: 5.87 ms
Operation: Transpose, Invocations: 156, Total Time: 4.03 ms
Operation: LayerNormalization, Invocations: 16, Total Time: 3.01 ms
Operation: Add, Invocations: 153, Total Time: 2.35 ms
Operation: Gemm, Invocations: 24, Total Time: 1.85 ms
Operation: Mul, Invocations: 109, Total Time: 1.71 ms
Operation: BiasSplitGelu, Invocations: 16, Total Time: 0.67 ms
Operation: Sigmoid, Invocations: 47, Total Time: 0.47 ms
Operation: Concat, Invocations: 14, Total Time: 0.37 ms
Operation: BiasAdd, Invocations: 16, Total Time: 0.20 ms
Operation: Resize, Invocations: 3, Total Time: 0.09 ms
Operation: Cast, Invocations: 4, Total Time: 0.02 ms
Operation: Slice, Invocations: 1, Total Time: 0.01 ms
Operation: Expand, Invocations: 1, Total Time: 0.00 ms
Operation: Cos, Invocations: 1, Total Time: 0.00 ms
Operation: Sin, Invocations: 1, Total Time: 0.00 ms

Summary:
Total Invocations: 864
Total Time Consumed: 953.04 ms

@qjia7
Contributor

qjia7 commented Sep 12, 2023

Got the results with fp16 on Windows. Conv got slightly faster, 874.26 ms -> 656.86 ms, and Attention 207.21 ms -> 174.33 ms.

Excellent work!

@dakenf, fyi, I can get the profiling data from https://huggingface.co/aislamov/stable-diffusion-2-1-base-onnx/tree/9f697c96d42e5c09437ff14b0a2b287366ce488d/vae_decoder. From the profiling result, it seems that starting from some Conv, the layout becomes NCHW instead of NHWC, so those Convs can't take our optimized path and execute really slowly. I think that's why you see Conv consuming so much time. I will take a further look.

@dakenf
Contributor

dakenf commented Sep 12, 2023

From the profiling result, it seems that starting from some Conv, the layout becomes NCHW instead of NHWC, so those Convs can't take our optimized path

I was thinking of adding the NhwcConv operator as described here: https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md
I think the only difference from the existing Conv implementation in com.microsoft.internal.nhwc is that in NhwcConv the weights are in NHWC format too, while the current one takes the weights input as NCHW (but maybe there's something else).
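
Purely to illustrate the layout difference being discussed (not ort-web's implementation), repacking a flat NCHW weight tensor into NHWC is just a fixed index permutation:

```js
// Illustration only: repack a flat [N, C, H, W] float array into [N, H, W, C],
// the weight-layout difference between the current Conv and NhwcConv noted above.
function nchwToNhwc(src, [N, C, H, W]) {
  const dst = new Float32Array(src.length);
  for (let n = 0; n < N; n++)
    for (let c = 0; c < C; c++)
      for (let h = 0; h < H; h++)
        for (let w = 0; w < W; w++)
          dst[((n * H + h) * W + w) * C + c] = src[((n * C + c) * H + h) * W + w];
  return dst;
}
```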

@gyagp

gyagp commented Oct 27, 2023

And on Mac, the profiling is not supported now because we are using an in-pass timestamp query, which is not supported there. We should change it to use the core feature timestamp-query instead of the extension timestamp-query-inside-passes. Reference here. Maybe @gyagp can provide a similar PR to quickly fix it.

#18108 just got merged, so the profiling should work on macOS now.
