Update benchmark_mha.py to compare with PyTorch SDPA (#21449)
### Description
* Update benchmark_mha.py to compare with the PyTorch SDPA API.
* Write results to a CSV file.
* Use the `sdpa_kernel` CUDA provider option instead of environment variables for better control (see the sketch below).
* Add arguments (`--use_gpu`, `--causal`, etc.) to allow testing different scenarios.
* Update benchmark_mha.sh to add CPU benchmarks.
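
As a rough illustration of the new provider option (a hedged sketch, not code taken from benchmark_mha.py), `sdpa_kernel` can be passed to the CUDA execution provider when creating the session; the model path and the option value below are placeholders:

```python
import onnxruntime as ort

# Select the SDPA kernel through a CUDA execution provider option instead of
# ORT_DISABLE_FLASH_ATTENTION-style environment variables.
# "model.onnx" and the option value are placeholders; see benchmark_mha.py for
# the enum values actually used by the benchmark.
sess = ort.InferenceSession(
    "model.onnx",
    providers=[
        ("CUDAExecutionProvider", {"sdpa_kernel": 0}),  # 0 assumed to mean default kernel selection
        "CPUExecutionProvider",
    ],
)
```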

For the Q,K,V format, torch uses the BNSH layout while ORT uses BSNH, so the comparison is not exactly apples-to-apples. However, a large latency difference could still be a warning sign.
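
To make the layout difference concrete, here is a minimal sketch assuming PyTorch tensors and the example shapes from the tables below:

```python
import torch

# ORT MultiHeadAttention consumes BSNH: (batch, seq_len, num_heads, head_size),
# while torch.nn.functional.scaled_dot_product_attention expects BNSH:
# (batch, num_heads, seq_len, head_size).
batch, seq_len, num_heads, head_size = 4, 2048, 32, 128
q_bsnh = torch.randn(batch, seq_len, num_heads, head_size)

# Converting between the layouts takes a transpose (and usually .contiguous()),
# which is one reason the two libraries are not measured on identical inputs.
q_bnsh = q_bsnh.transpose(1, 2).contiguous()
```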

#### Example GPU results

Example results on A100-SXM4-80GB with settings (use_gpu=TRUE, enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0, intra_op_num_threads=0) on Azure Linux. ORT: built from source with CUDA 12.5; PyTorch 2.3.1 for CUDA 12.1.

format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel
-- | -- | -- | -- | -- | -- | -- | --
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default
QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash
Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default
Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash
Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient
QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient
Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default
Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash
QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default
Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash
Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient
QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient
Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash
QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default
QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash
Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default
Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash
Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient
QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient
Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default
QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default
Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash
Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient
QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient
Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient

#### Example CPU results

Example results on a Dell XPS 8960 with an i9-13900 CPU (use_gpu=FALSE, causal=FALSE, past_sequence_length=0) on Windows. ORT: built from source with CUDA 12.5; PyTorch 2.3.1 for CUDA 12.1.

format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel
-- | -- | -- | -- | -- | -- | -- | -- | --
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default
Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default
Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default
Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default
Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default
Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default


### Motivation and Context
Compare latency with PyTorch SDPA on CPU and CUDA.
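
For reference, a minimal sketch of the PyTorch side of the comparison, assuming the public SDPA API in PyTorch 2.3 (not the exact code in benchmark_mha.py):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

batch, num_heads, seq_len, head_size = 4, 32, 2048, 128
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
q = torch.randn(batch, num_heads, seq_len, head_size, device=device, dtype=dtype)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Default backend selection ("torch:default" rows in the tables above).
out = F.scaled_dot_product_attention(q, k, v, is_causal=False)

# Restrict SDPA to one backend, e.g. flash attention ("torch:flash" rows).
# The flash backend needs CUDA and fp16/bf16 inputs, so only try it on GPU.
if device == "cuda":
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
```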
tianleiwu authored Jul 27, 2024
1 parent fb61e14 commit 64819f6
Showing 4 changed files with 609 additions and 222 deletions.
onnxruntime/test/python/transformers/benchmark_mha.cmd (47 additions, 0 deletions)
@@ -0,0 +1,47 @@
echo "Benchmark Scaled Dot Product Attention (SDPA) performance on GPU:"

set CUDA_VISIBLE_DEVICES=0
python benchmark_mha.py --use_gpu
python benchmark_mha.py --use_gpu --use_cuda_graph
python benchmark_mha.py --use_gpu --torch

type benchmark_mha_gpu_*.csv > mha_gpu_benchmark_results.csv

echo "Benchmark performance on CPU with number of threads:"
set MKL_DYNAMIC=FALSE
set OMP_NUM_THREADS=1
python benchmark_mha.py --torch

set OMP_NUM_THREADS=2
python benchmark_mha.py --torch

set OMP_NUM_THREADS=4
python benchmark_mha.py --torch

set OMP_NUM_THREADS=8
python benchmark_mha.py --torch

set MKL_DYNAMIC=
set OMP_NUM_THREADS=

set ORT_DISABLE_FLASH_ATTENTION=0
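REM Benchmark ORT CPU kernels with different intra-op thread counts (flash attention enabled).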
python benchmark_mha.py --intra_op_num_threads 1
python benchmark_mha.py --intra_op_num_threads 2
python benchmark_mha.py --intra_op_num_threads 4
python benchmark_mha.py --intra_op_num_threads 8

echo "Benchmark performance on CPU with default threads settings:"
python benchmark_mha.py

python benchmark_mha.py --torch

python benchmark_mha.py --causal
python benchmark_mha.py --torch --causal

python benchmark_mha.py --causal --has_past

set ORT_DISABLE_FLASH_ATTENTION=1
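REM Repeat with ORT flash attention disabled to compare against the math kernel.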
python benchmark_mha.py
set ORT_DISABLE_FLASH_ATTENTION=

type benchmark_mha_cpu_*.csv > mha_cpu_benchmark_results.csv