Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description

* Update benchmark_mha.py to compare with the PyTorch SDPA API.
* Write results to a CSV file.
* Use the sdpa_kernel CUDA provider option instead of environment variables for better control.
* Add arguments (`--use_gpu`, `--causal`, etc.) to allow testing different scenarios.
* Update benchmark_mha.sh to add CPU benchmarks.

For the Q,K,V format, torch uses the BNSH layout while ort uses BSNH, so the comparison is not apples-to-apples. However, a large latency difference could still be a warning sign.

#### Example GPU results

Example results on A100-SXM4-80GB with settings (use_gpu=TRUE, enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0, intra_op_num_threads=0) in Azure Linux. ORT: built from source with CUDA 12.5; PyTorch 2.3.1 for CUDA 12.1.

format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient

#### Example CPU results

Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE, past_sequence_length=0) in Windows. ORT: built from source with CUDA 12.5; PyTorch 2.3.1 for CUDA 12.1.
format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default

### Motivation and Context

To compare latency with PyTorch SDPA on CPU and CUDA.
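For readers sanity-checking the GPU table, the tflops column can be approximately reproduced from the other columns. Below is a minimal sketch assuming the standard two-matmul attention FLOP count (QK^T plus the attention-times-V product, i.e. 4 · batch · num_heads · seq_len² · head_size); the exact formula used by benchmark_mha.py may differ, and the latencies in the table are rounded, so recomputed numbers only approximate the reported ones.

```python
# Hypothetical helper (not part of benchmark_mha.py): derive tflops from
# shape parameters and measured latency, assuming the usual attention FLOP
# count of two matmuls: 4 * batch * num_heads * seq_len^2 * head_size.
def attention_tflops(batch, seq_len, num_heads, head_size, latency_s):
    flops = 4.0 * batch * num_heads * seq_len * seq_len * head_size
    return flops / latency_s / 1e12


# First GPU row: batch=4, seq_len=2048, heads=32, head_size=128, 0.0015 s.
# The rounded latency gives ~183 tflops; the table reports 179.5, computed
# from the unrounded latency.
print(round(attention_tflops(4, 2048, 32, 128, 0.0015), 1))
```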