[js/webgpu] Optimize matmulnbits with M > 1 #23092
base: main
Conversation
@guschmue @sushraja-msft FYI. I noticed @sushraja-msft has already done some very good optimizations on the WebGPU native EP. I will port this PR to the WebGPU EP to make some comparisons.
Thanks JiaJia. JFYI, there is also this pending change, #23071, which improves on that previous change. Looking forward to comparing performance.
This is the WebGPU native EP implementation of #23092. I used https://github.com/fs-eire/ort-webgpu-nodejs-chatapp-prototype to test, and applied fs-eire/ort-webgpu-nodejs-chatapp-prototype#2 to print the first token time. The results are as follows.

The latest main branch:

Intel Arc Graphics
```
659 tokens in 24.8sec, 26.57 tokens/sec
Decoding first token with input 449 tokens: 13.0 sec
Decoding remaining 210 tokens: 11.8 sec 17.79 tokens/sec
```

NV RTX 2000
```
659 tokens in 14.4sec, 45.85 tokens/sec
Decoding first token with input 449 tokens: 7.3 sec
Decoding remaining 210 tokens: 7.0 sec 29.81 tokens/sec
```

With this PR:

Intel Arc Graphics
```
657 tokens in 20.6sec, 31.92 tokens/sec
Decoding first token with input 449 tokens: 8.5 sec
Decoding remaining 208 tokens: 12.1 sec 17.23 tokens/sec
```

NV RTX 2000
```
659 tokens in 11.4sec, 57.93 tokens/sec
Decoding first token with input 449 tokens: 4.1 sec
Decoding remaining 210 tokens: 7.2 sec 28.98 tokens/sec
```

From the data above, you can see that with this PR, the first token time improves on both the Intel GPU (13.0 s -> 8.5 s) and the NV GPU (7.3 s -> 4.1 s).
/azp run ONNX Runtime Web CI Pipeline,Windows GPU CI Pipeline,Linux Android Emulator QNN CI Pipeline
Azure Pipelines successfully started running 1 pipeline(s).
/azp run Linux CPU CI Pipeline,Linux CPU Minimal Build E2E CI Pipeline,Linux GPU CI Pipeline,Linux GPU TensorRT CI Pipeline,Linux OpenVINO CI Pipeline,Linux QNN CI Pipeline,MacOS CI Pipeline,Windows ARM64 QNN CI Pipeline,Windows CPU CI Pipeline
/azp run Windows GPU TensorRT CI Pipeline,onnxruntime-binary-size-checks-ci-pipeline,orttraining-linux-ci-pipeline,orttraining-linux-gpu-ci-pipeline,orttraining-ortmodule-distributed,Windows x64 QNN CI Pipeline,Big Models
Azure Pipelines could not run because the pipeline triggers exclude this branch/path.
Azure Pipelines successfully started running 1 pipeline(s).
/azp run Windows GPU CUDA CI Pipeline, Windows GPU DML CI Pipeline, Windows GPU Doc Gen CI Pipeline
Azure Pipelines could not run because the pipeline triggers exclude this branch/path.
Description
This PR mainly optimizes decoding of the first token in the Phi-3 model. For the first token, MatMulNBits is a matrix-matrix multiplication, which becomes very slow when the prompt is long (e.g., more than 450 input tokens).
Both Intel and NV GPUs see a good improvement with this PR.
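For context, here is a minimal NumPy sketch of what a MatMulNBits-style op computes: blockwise-quantized 4-bit weights are dequantized and then multiplied against the activations. The shapes, nibble order, and the `dequantize_block` helper are illustrative assumptions, not ORT's exact packing layout:

```python
import numpy as np

def dequantize_block(packed, scale, zero_point, block_size=32):
    """Unpack one block of 4-bit weights (two values per byte) to float32."""
    lo = (packed & 0x0F).astype(np.float32)   # even elements: low nibbles (assumed order)
    hi = (packed >> 4).astype(np.float32)     # odd elements: high nibbles
    q = np.empty(block_size, dtype=np.float32)
    q[0::2], q[1::2] = lo, hi
    return (q - zero_point) * scale

def matmulnbits_reference(A, B_packed, scales, zero_points, block_size=32):
    """Dequantize blockwise 4-bit weights, then run a plain matmul.

    Assumed shapes (illustrative, not ORT's exact layout):
      A:           [M, K]  float32 activations
      B_packed:    [N, K // block_size, block_size // 2]  uint8
      scales:      [N, K // block_size]  float32
      zero_points: [N, K // block_size]  float32
    """
    N, n_blocks, _ = B_packed.shape
    K = n_blocks * block_size
    B = np.zeros((N, K), dtype=np.float32)
    for n in range(N):
        for b in range(n_blocks):
            B[n, b * block_size:(b + 1) * block_size] = dequantize_block(
                B_packed[n, b], scales[n, b], zero_points[n, b], block_size)
    return A @ B.T   # [M, N]: matrix-matrix when M > 1, GEMV when M == 1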
My test data is as follows, decoding the first token with a 499-token input:
It improves from 6.2 s to 4.9 s on NV RTX 2000.
It improves from 52.6 s to 28.8 s on Intel UHD 770.
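As a rough intuition for why the M > 1 (matrix-matrix) case has so much headroom: a large share of the work is unpacking the quantized weights, and with multiple rows a kernel can dequantize each weight block once and reuse it for all M rows (on a GPU this reuse typically goes through workgroup shared memory). Below is a minimal sketch of that reuse idea, building on the hypothetical `dequantize_block` helper above; it illustrates the general technique, not necessarily this PR's exact shader strategy:

```python
import numpy as np
# dequantize_block as defined in the previous sketch

def matmulnbits_reuse(A, B_packed, scales, zero_points, block_size=32):
    """Dequantize each weight block once and apply it to every row of A.

    A GEMV-style kernel run M times would redo the dequantization per row;
    amortizing that cost across rows is the general reason matrix-matrix
    shapes (first-token decoding of a long prompt) can be made much faster.
    """
    M = A.shape[0]
    N, n_blocks, _ = B_packed.shape
    out = np.zeros((M, N), dtype=np.float32)
    for n in range(N):
        for b in range(n_blocks):
            w = dequantize_block(B_packed[n, b], scales[n, b],
                                 zero_points[n, b], block_size)  # unpacked once
            k0 = b * block_size
            out[:, n] += A[:, k0:k0 + block_size] @ w            # reused by all M rows
    return out
```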