
[Bug] Softmax op is very slow #3132

Open
gesanqiu opened this issue Feb 13, 2025 · 0 comments
Labels
bug Confirmed bugs

🐛 Bug

To Reproduce

Steps to reproduce the behavior:

  1. Run any model with `mlc_llm serve`, passing the `--enable-tracing` and `--enable-debug` arguments, for example: `mlc_llm serve /workdir/Qwen2-1.5B-Instruct-mlc/ --device cuda --model-lib /workdir/Qwen2-1.5B-Instruct-mlc/qwen2-1.5b.so --port 8090 --host 0.0.0.0 --enable-tracing --enable-debug`
  2. Fetch the Chrome trace with `curl -X POST http://127.0.0.1:8090/debug/dump_event_trace -H "Content-Type: application/json" -d '{"model": "/workdir/Qwen2-1.5B-Instruct-mlc/"}'` (the port and model must match the serve command above)
  3. Parse the tracing log; the softmax operator alone takes over 65% of the total per-token time:
```
embedding (12) time cost: 0.129 ms
apply logit bias (12) time cost: 0.004 ms
apply penalty (12) time cost: 0.005 ms
apply logit mask (12) time cost: 0.004 ms
update logits (12) time cost: 0.024 ms
softmax (12) time cost: 6.229 ms
renormalization by top p (12) time cost: 0.21 ms
sampling (12) time cost: 0.114 ms
detokenization (12) time cost: 0.052 ms
callback (12) time cost: 0.104 ms
decode (12) time cost: 2.51 ms
```
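Step 3 can be scripted. Below is a minimal sketch that aggregates per-operator time from the dumped trace, assuming the dump follows the standard Chrome trace format (either complete `"X"` events carrying a `dur` field, or paired `"B"`/`"E"` events); the `trace.json` filename and function names here are hypothetical, not part of MLC-LLM:

```python
import json
from collections import defaultdict

def aggregate_trace(trace):
    """Sum duration (microseconds) per event name from a Chrome trace.

    Accepts both the bare-list form and the {"traceEvents": [...]} form.
    Complete ("X") events contribute their "dur"; "B"/"E" events are
    paired by name and contribute the timestamp difference.
    """
    events = trace.get("traceEvents", trace) if isinstance(trace, dict) else trace
    totals = defaultdict(float)
    open_begin = {}  # event name -> "B" timestamp awaiting its "E"
    for ev in events:
        name, ph = ev.get("name"), ev.get("ph")
        if ph == "X":
            totals[name] += ev.get("dur", 0)
        elif ph == "B":
            open_begin[name] = ev["ts"]
        elif ph == "E" and name in open_begin:
            totals[name] += ev["ts"] - open_begin.pop(name)
    return dict(totals)

def report(totals):
    """Print each operator's total time in ms and its share of the sum."""
    grand = sum(totals.values()) or 1.0
    for name, us in sorted(totals.items(), key=lambda kv: -kv[1]):
        print(f"{name:30s} {us / 1000:8.3f} ms  {100 * us / grand:5.1f}%")

if __name__ == "__main__":
    with open("trace.json") as f:
        report(aggregate_trace(json.load(f)))
```

Summing the timings above this way is how the 65% figure falls out: softmax's 6.229 ms against roughly 9.4 ms total is about 66% of the per-token cost.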

Expected behavior

The softmax step should cost time on the same order as the other sampling steps (well under 1 ms), rather than dominating the per-token latency.

Environment

  • Platform (e.g. WebGPU/Vulkan/IOS/Android/CUDA): CUDA
  • Operating system (e.g. Ubuntu/Windows/MacOS/...): Ubuntu 22.04
  • Device (e.g. iPhone 12 Pro, PC+RTX 3090, ...): Jetson AGX Orin 64GB Developer Kit
  • How you installed MLC-LLM (conda, source): source
  • How you installed TVM-Unity (pip, source): source
  • Python version (e.g. 3.10): 3.12
  • GPU driver version (if applicable):
  • CUDA/cuDNN version (if applicable): 12.6
  • TVM Unity Hash Tag (python -c "import tvm; print('\n'.join(f'{k}: {v}' for k, v in tvm.support.libinfo().items()))", applicable if you compile models):
  • Any other relevant information:

Additional context

gesanqiu added the bug (Confirmed bugs) label on Feb 13, 2025