From 64819f6f8cad8387b23d7cc8af1a4b4207e2dfbb Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 26 Jul 2024 18:45:14 -0700 Subject: [PATCH] Update benchmark_mha.py to compare with PyTorch SDPA (#21449) ### Description * Update benchmark_mha.py to compare with PyTorch SDPA api. * Write results to csv file. * Use sdpa_kernel cuda provider option instead of environment variables for better control. * Add arguments (`--use_gpu`, `--causal` etc) to allow testing different senarios. * Update benchmark_mha.sh to add cpu benchmarks For Q,K,V format, torch uses BNSH format, while ort uses BSNH format, so the result is not apple-to-apple. However, if the latency difference is large, that could be a warning. #### Example GPU results Example results on A100-SXM4-80GB with settings (use_gpu=TRUE, enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0, intra_op_num_threads=0) in Azure Linux. ORT: build from source with CUDA 12.5; PyTorch 2.3.1 for cuda 12.1. format | batch_size | sequence_length | num_heads | head_size | latency (s) | tflops | kernel -- | -- | -- | -- | -- | -- | -- | -- Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient #### Example CPU results Dell XPS 8960 with i9-13900 CPU (use_gpu=FALSE, causal=FALSE, past_sequence_length=0) in Windows. ORT: build from source with CUDA 12.5; PyTorch 2.3.1 for cuda 12.1. format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel -- | -- | -- | -- | -- | -- | -- | -- | -- Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default ### Motivation and Context To compare with PyTorch SDPA on CPU and CUDA latency. --- .../python/transformers/benchmark_mha.cmd | 47 ++ .../test/python/transformers/benchmark_mha.py | 690 +++++++++++++----- .../test/python/transformers/benchmark_mha.sh | 48 +- .../test/python/transformers/test_mha.py | 46 +- 4 files changed, 609 insertions(+), 222 deletions(-) create mode 100644 onnxruntime/test/python/transformers/benchmark_mha.cmd diff --git a/onnxruntime/test/python/transformers/benchmark_mha.cmd b/onnxruntime/test/python/transformers/benchmark_mha.cmd new file mode 100644 index 0000000000000..0a6d0c37b4a35 --- /dev/null +++ b/onnxruntime/test/python/transformers/benchmark_mha.cmd @@ -0,0 +1,47 @@ +echo "Benchmark Scaled Dot Product Attention (SDPA) performance on GPU:" + +set CUDA_VISIBLE_DEVICES=0 +python benchmark_mha.py --use_gpu +python benchmark_mha.py --use_gpu --use_cuda_graph +python benchmark_mha.py --use_gpu --torch + +type benchmark_mha_gpu_*.csv > mha_gpu_benchmark_results.csv + +echo "Benchmark performance on CPU with number of threads:" +set MKL_DYNAMIC=FALSE +set OMP_NUM_THREADS=1 +python benchmark_mha.py --torch + +set OMP_NUM_THREADS=2 +python benchmark_mha.py --torch + +set OMP_NUM_THREADS=4 +python benchmark_mha.py --torch + +set OMP_NUM_THREADS=8 +python benchmark_mha.py --torch + +set MKL_DYNAMIC= +set OMP_NUM_THREADS= + +set ORT_DISABLE_FLASH_ATTENTION=0 +python benchmark_mha.py --intra_op_num_threads 1 +python benchmark_mha.py --intra_op_num_threads 2 +python benchmark_mha.py --intra_op_num_threads 4 +python benchmark_mha.py --intra_op_num_threads 8 + +echo "Benchmark performance on CPU with default threads settings:" +python benchmark_mha.py + +python benchmark_mha.py --torch + +python benchmark_mha.py --causal +python benchmark_mha.py --torch --causal + +python benchmark_mha.py --causal --has_past + +set ORT_DISABLE_FLASH_ATTENTION=1 +python benchmark_mha.py +set ORT_DISABLE_FLASH_ATTENTION= + +type benchmark_mha_cpu_*.csv > mha_cpu_benchmark_results.csv diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py index 111c417479d20..715a92431e6bf 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.py +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -4,21 +4,35 @@ # -------------------------------------------------------------------------- """ -Benchmark performance of MultiHeadAttention with Nvidia GPU of Compute Capability 8.0, 8.6 or 8.9 in Linux: -sh benchmark_mha.sh +Benchmark performance of MultiHeadAttention with ORT or PyTorch. + +In Linux, run the the following: + sh benchmark_mha.sh + +In Windows, run the the following: + benchmark_mha.cmd """ +import argparse +import csv import math import os import platform import statistics import time -from typing import List, Optional +from contextlib import nullcontext +from datetime import datetime +from enum import IntEnum +from typing import Callable, Dict, List, Optional, Tuple import torch +import torch.utils.benchmark as benchmark from onnx import TensorProto, helper +from packaging.version import Version +from torch.nn.attention import SDPBackend, sdpa_kernel +from torch.nn.functional import scaled_dot_product_attention -from onnxruntime import InferenceSession, get_available_providers +from onnxruntime import InferenceSession, SessionOptions, get_available_providers from onnxruntime.transformers.io_binding_helper import CudaSession @@ -43,6 +57,20 @@ def get_name_list() -> List[str]: return ["Q,K,V", "QKV", "Q,KV", "Q,K',V'"] +class SdpaKernel(IntEnum): + """Bit flags for sdpa_kernel CUDA provider option""" + + DEFAULT = 0 + FLASH_ATTENTION = 1 + EFFICIENT_ATTENTION = 2 + TRT_FUSED_ATTENTION = 4 + CUDNN_FLASH_ATTENTION = 8 + MATH = 16 + TRT_FLASH_ATTENTION = 32 + TRT_CROSS_ATTENTION = 64 + TRT_CAUSAL_ATTENTION = 128 + + class MultiHeadAttentionConfig: def __init__( self, @@ -62,6 +90,7 @@ def __init__( use_kv_cache: bool = False, share_past_present_buffer: bool = False, input_format: int = InputFormats.Q_K_V_BSNH_BSNH_BSNH, + verbose: bool = False, ): self.operator = "MultiHeadAttention" self.batch_size = batch_size @@ -100,6 +129,7 @@ def __init__( self.input_format = input_format self.is_packed_qkv = input_format == InputFormats.QKV_BSN3H self.is_packed_kv = input_format == InputFormats.Q_KV_BSNH_BSN2H + self.verbose = verbose def __repr__(self): return ( @@ -114,89 +144,93 @@ def __repr__(self): ) def shape_dict(self, input_format=None): + shapes: Dict[str, Tuple] = { + "output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + } + input_format = input_format or self.input_format - if input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: - # cross attention does not have past state - return { + if input_format == InputFormats.QKV_BSN3H: + shapes = { + **shapes, + "query": (self.batch_size, self.sequence_length, self.num_heads, 3, self.head_size), + } + elif input_format == InputFormats.Q_KV_BSNH_BSN2H: + shapes = { + **shapes, + "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + "key": (self.batch_size, self.sequence_length, self.num_heads, 2, self.head_size), + } + elif input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH: + shapes = { + **shapes, + "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + "key": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + "value": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + } + else: + assert input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH + shapes = { + **shapes, "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), "key": (self.batch_size, self.num_heads, self.sequence_length, self.head_size), "value": (self.batch_size, self.num_heads, self.sequence_length, self.head_size), - "output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), } if self.use_kv_cache: + assert input_format != InputFormats.Q_K_V_BSNH_BNSH_BNSH, "cross attention shall not have past state" shapes = { + **shapes, "past_key": (self.batch_size, self.num_heads, self.past_buffer_length, self.head_size), "past_value": (self.batch_size, self.num_heads, self.past_buffer_length, self.head_size), - "output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), "present_key": (self.batch_size, self.num_heads, self.present_buffer_length, self.head_size), "present_value": (self.batch_size, self.num_heads, self.present_buffer_length, self.head_size), } - else: - shapes = { - "output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), - } - if input_format == InputFormats.QKV_BSN3H: - shapes.update({"query": (self.batch_size, self.sequence_length, self.num_heads, 3, self.head_size)}) - elif input_format == InputFormats.Q_KV_BSNH_BSN2H: - shapes.update( - { - "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), - "key": (self.batch_size, self.sequence_length, self.num_heads, 2, self.head_size), - } - ) - else: # input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH - shapes.update( - { - "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), - "key": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), - "value": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), - } - ) return shapes def symbolic_shape_dict(self, input_format=None): + shapes: Dict[str, Tuple] = { + "output": ("batch_size", "sequence_length", self.num_heads * self.head_size), + } + input_format = input_format or self.input_format - if input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: - # cross attention does not have past state - return { + if input_format == InputFormats.QKV_BSN3H: + shapes = { + **shapes, + "query": ("batch_size", "sequence_length", self.num_heads, 3, self.head_size), + } + elif input_format == InputFormats.Q_KV_BSNH_BSN2H: + shapes = { + **shapes, + "query": ("batch_size", "sequence_length", self.num_heads * self.head_size), + "key": ("batch_size", "sequence_length", self.num_heads, 2, self.head_size), + } + elif input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH: + shapes = { + **shapes, + "query": ("batch_size", "sequence_length", self.num_heads * self.head_size), + "key": ("batch_size", "sequence_length", self.num_heads * self.head_size), + "value": ("batch_size", "sequence_length", self.num_heads * self.head_size), + } + else: + assert input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH + shapes = { + **shapes, "query": ("batch_size", "sequence_length", self.num_heads * self.head_size), "key": ("batch_size", self.num_heads, "sequence_length", self.head_size), "value": ("batch_size", self.num_heads, "sequence_length", self.head_size), - "output": ("batch_size", "sequence_length", self.num_heads * self.head_size), } if self.use_kv_cache: + assert input_format != InputFormats.Q_K_V_BSNH_BNSH_BNSH, "cross attention shall not have past state" shapes = { + **shapes, "past_key": ("batch_size", self.num_heads, "past_buffer_length", self.head_size), "past_value": ("batch_size", self.num_heads, "past_buffer_length", self.head_size), - "output": ("batch_size", "sequence_length", self.num_heads * self.head_size), "present_key": ("batch_size", self.num_heads, "present_buffer_length", self.head_size), "present_value": ("batch_size", self.num_heads, "present_buffer_length", self.head_size), } - else: - shapes = { - "output": ("batch_size", "sequence_length", self.num_heads * self.head_size), - } - if input_format == InputFormats.QKV_BSN3H: - shapes.update({"query": ("batch_size", "sequence_length", self.num_heads, 3, self.head_size)}) - elif input_format == InputFormats.Q_KV_BSNH_BSN2H: - shapes.update( - { - "query": ("batch_size", "sequence_length", self.num_heads * self.head_size), - "key": ("batch_size", "sequence_length", self.num_heads, 2, self.head_size), - } - ) - else: # input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH - shapes.update( - { - "query": ("batch_size", "sequence_length", self.num_heads * self.head_size), - "key": ("batch_size", "sequence_length", self.num_heads * self.head_size), - "value": ("batch_size", "sequence_length", self.num_heads * self.head_size), - } - ) return shapes def random_inputs(self, seed: int = 123): @@ -215,44 +249,42 @@ def random_inputs(self, seed: int = 123): k_bnsh = k.transpose(1, 2) v_bnsh = v.transpose(1, 2) - if self.input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: - return { + if self.input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH: + feeds = { "query": q.reshape(shape_dict["query"]), - "key": k_bnsh.contiguous(), - "value": v_bnsh.contiguous(), + "key": k.reshape(shape_dict["key"]), + "value": v.reshape(shape_dict["value"]), } - - feeds = {} - if self.use_kv_cache: - feeds.update( - { - "past_key": torch.empty(shape_dict["past_key"], device=device, dtype=dtype).normal_( - mean=0, std=0.1 - ), - "past_value": torch.empty(shape_dict["past_value"], device=device, dtype=dtype).normal_( - mean=0, std=0.1 - ), - } - ) - - if self.input_format == InputFormats.Q_K_V_BSNH_BSNH_BSNH: - feeds.update( - { - "query": q.reshape(shape_dict["query"]), - "key": k.reshape(shape_dict["key"]), - "value": v.reshape(shape_dict["value"]), - } - ) elif self.input_format == InputFormats.QKV_BSN3H: query = q.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size) key = k.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size) value = v.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size) - feeds["query"] = torch.dstack((query, key, value)).reshape(shape_dict["query"]).contiguous() + feeds = { + "query": torch.dstack((query, key, value)).reshape(shape_dict["query"]).contiguous(), + } elif self.input_format == InputFormats.Q_KV_BSNH_BSN2H: key = k.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size) value = v.view(self.batch_size * self.sequence_length, self.num_heads, self.head_size) - feeds["query"] = q.reshape(shape_dict["query"]) - feeds["key"] = torch.dstack((key, value)).reshape(shape_dict["key"]).contiguous() + feeds = { + "query": q.reshape(shape_dict["query"]), + "key": torch.dstack((key, value)).reshape(shape_dict["key"]).contiguous(), + } + else: + assert self.input_format == InputFormats.Q_K_V_BSNH_BNSH_BNSH + feeds = { + "query": q.reshape(shape_dict["query"]), + "key": k_bnsh.contiguous(), + "value": v_bnsh.contiguous(), + } + + if self.use_kv_cache: + feeds = { + **feeds, + "past_key": torch.empty(shape_dict["past_key"], device=device, dtype=dtype).normal_(mean=0, std=0.1), + "past_value": torch.empty(shape_dict["past_value"], device=device, dtype=dtype).normal_( + mean=0, std=0.1 + ), + } return feeds @@ -318,19 +350,32 @@ def create_multi_head_attention_onnx_model(config: MultiHeadAttentionConfig, use return model.SerializeToString() -def create_session( +def create_ort_session( config: MultiHeadAttentionConfig, + session_options=None, + attention_kernel=SdpaKernel.DEFAULT, + use_symbolic_shape: bool = True, ) -> CudaSession: - onnx_model_str = create_multi_head_attention_onnx_model(config) + if config.verbose: + print(f"create session for {vars(config)}") + onnx_model_str = create_multi_head_attention_onnx_model(config, use_symbolic_shape=use_symbolic_shape) if config.provider == "CUDAExecutionProvider": device_id = torch.cuda.current_device() if isinstance(config.device, str) else config.device.index provider_options = CudaSession.get_cuda_provider_options(device_id, config.enable_cuda_graph) + provider_options["sdpa_kernel"] = int(attention_kernel) providers = [(config.provider, provider_options), "CPUExecutionProvider"] else: providers = ["CPUExecutionProvider"] - ort_session = InferenceSession(onnx_model_str, providers=providers) + ort_session = InferenceSession(onnx_model_str, session_options, providers=providers) + return ort_session + + +def create_session( + config: MultiHeadAttentionConfig, session_options=None, attention_kernel=SdpaKernel.DEFAULT +) -> CudaSession: + ort_session = create_ort_session(config, session_options, attention_kernel, use_symbolic_shape=False) cuda_session = CudaSession(ort_session, config.device, config.enable_cuda_graph) shape_dict = config.shape_dict() cuda_session.allocate_buffers(shape_dict) @@ -340,11 +385,8 @@ def create_session( class OrtMultiHeadAttention: """A wrapper of ORT MultiHeadAttention to test relevance and performance.""" - def __init__( - self, - config: MultiHeadAttentionConfig, - ): - self.ort_session = create_session(config) + def __init__(self, config: MultiHeadAttentionConfig, session_options=None): + self.ort_session = create_session(config, session_options) self.feed_dict = config.random_inputs() def infer(self): @@ -363,53 +405,90 @@ def flops(batch, sequence_length, head_size, num_heads, causal): def tflops_per_second(flop, time): - return (flop / time / 10**12) if not math.isnan(time) else 0.0 - - -def get_gpu_kernel_name(config: MultiHeadAttentionConfig) -> str: - # This classification is for Nvidia GPU of Compute Capability 8.* like A100. - # Note that some kernel might not exist in older or newer GPUs. - if os.getenv("ORT_DISABLE_FLASH_ATTENTION") != "1": - if config.input_format == InputFormats.QKV_BSN3H: - min_seq_len = os.getenv("ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV") - min_length = int(min_seq_len) if min_seq_len is not None else 513 - if config.sequence_length >= min_length: - return "Flash" - else: - return "Flash" + try: + return (flop / time / 10**12) if not math.isnan(time) else 0.0 + except ZeroDivisionError: + return None + + +def get_gpu_kernel_name(attention_kernel: SdpaKernel) -> str: + kernel_names = { + SdpaKernel.DEFAULT: "ort:default", + SdpaKernel.FLASH_ATTENTION: "ort:flash", + SdpaKernel.EFFICIENT_ATTENTION: "ort:efficient", + SdpaKernel.CUDNN_FLASH_ATTENTION: "ort:cudnn", + SdpaKernel.MATH: "ort:math", + } + assert attention_kernel in kernel_names + return kernel_names[attention_kernel] - if (os.getenv("ORT_DISABLE_FUSED_CROSS_ATTENTION") != "1" and config.kv_sequence_length <= 128) or ( - os.getenv("ORT_DISABLE_FUSED_ATTENTION") != "1" - and (config.sequence_length <= 384 or os.getenv("ORT_DISABLE_TRT_FLASH_ATTENTION") != "1") - ): - return "TRT" - if os.getenv("ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION") != "1": - return "MemEff" +def get_cpu_kernel_name(config: MultiHeadAttentionConfig) -> str: + # CPU Flash Attention does not support causal and kv cache etc. + if not (config.causal or config.use_kv_cache or config.past_sequence_length > 0): + if os.getenv("ORT_DISABLE_FLASH_ATTENTION") != "1": + return "ort:flash" - return "Unfused" + return "ort:math" -def get_cpu_kernel_name() -> str: - if os.getenv("ORT_DISABLE_FLASH_ATTENTION") != "1": - return "CPU:Flash" - return "CPU:Unfused" +# ------------------------------------------------------------------ +# Functions for benchmarking PyTorch SDPA +# ------------------------------------------------------------------ +def benchmark_torch_function(func: Callable, *args, **kwargs) -> float: + warmup = 5 + repeats = 100 + for _ in range(warmup): + func(*args, **kwargs) + timer = benchmark.Timer( + stmt="func(*args, **kwargs)", + globals={"args": args, "kwargs": kwargs, "func": func}, + ) + + return timer.timeit(number=repeats).median -def run_tflops_test(use_gpu: bool = True, enable_cuda_graph: bool = False, repeats: int = 100): - if use_gpu: - device_id = torch.cuda.current_device() - device = torch.device("cuda", device_id) - formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH, InputFormats.Q_KV_BSNH_BSN2H, InputFormats.QKV_BSN3H] - provider = "CUDAExecutionProvider" - print(f"enable_cuda_graph={enable_cuda_graph}") - else: - device_id = 0 - device = torch.device("cpu") - formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH] - enable_cuda_graph = False - provider = "CPUExecutionProvider" +def run_torch_sdpa( + batch_size: int, + q_seq_len: int, + kv_seq_len: int, + num_heads: int, + head_size: int, + causal: bool, + device, + dtype, + has_mask: bool = False, + mask_dim: int = 2, + mask_dtype=torch.bool, + backend: Optional[int] = None, +): + q_shape = (batch_size, num_heads, q_seq_len, head_size) + kv_shape = (batch_size, num_heads, kv_seq_len, head_size) + q = torch.randn(q_shape, device=device, dtype=dtype) + k = torch.randn(kv_shape, device=device, dtype=dtype) + v = torch.randn(kv_shape, device=device, dtype=dtype) + + attn_mask = None + if has_mask: + mask_shape = (batch_size, num_heads, q_seq_len, kv_seq_len) if mask_dim == 4 else (q_seq_len, kv_seq_len) + attn_mask = torch.ones(mask_shape, dtype=mask_dtype, device=device) + + context = sdpa_kernel(backend) if backend is not None else nullcontext() + + with context: + average_latency = benchmark_torch_function( + scaled_dot_product_attention, + q, + k, + v, + is_causal=causal, + attn_mask=attn_mask, + ) + return average_latency + + +def get_test_configs(use_gpu: bool = True): if use_gpu: # (batch_size, sequence_length, past_sequence_length, num_heads, head_size, run_unfused) configs = [ @@ -450,31 +529,70 @@ def run_tflops_test(use_gpu: bool = True, enable_cuda_graph: bool = False, repea ] else: configs = [ + # TNLGv4 (1, 128, 0, 32, 128, True), (1, 256, 0, 32, 128, True), (1, 512, 0, 32, 128, True), (1, 1024, 0, 32, 128, True), - (1, 2048, 0, 32, 128, True), + # (1, 2048, 0, 32, 128, True), + # bert-base + (1, 128, 0, 12, 64, True), + (1, 384, 0, 12, 64, True), + (1, 512, 0, 12, 64, True), + (4, 128, 0, 12, 64, True), + (4, 384, 0, 12, 64, True), + (4, 512, 0, 12, 64, True), + # bert-large + (1, 128, 0, 16, 64, True), + (1, 384, 0, 16, 64, True), + (1, 512, 0, 16, 64, True), + (4, 128, 0, 16, 64, True), + (4, 384, 0, 16, 64, True), + (4, 512, 0, 16, 64, True), ] + return configs + + +def get_compute_capability(): + assert torch.cuda.is_available() + major, minor = torch.cuda.get_device_capability() + sm = major * 10 + minor + return sm - # List of environment variables to enable/disable attention kernels - print("Environment Variables:") - env_names = [ - "ORT_DISABLE_FLASH_ATTENTION", - "ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV", - "ORT_DISABLE_FUSED_ATTENTION", - "ORT_DISABLE_TRT_FLASH_ATTENTION", - "ORT_ENABLE_FUSED_CAUSAL_ATTENTION", - "ORT_DISABLE_FUSED_CROSS_ATTENTION", - "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION", - ] - for name in env_names: - value = os.getenv(name) - if value is not None: - print(f"{name}={value}") - print("\nformat\tcausal\tbatch\tseqlen\theads\th_dim\tms\tTFLOPS\tkernel") - causal = False +def run_tflops_test( + csv_writer: csv.DictWriter, + use_gpu: bool = True, + enable_cuda_graph: bool = False, + causal: bool = False, + has_past: bool = False, + intra_op_num_threads: int = 0, + repeats: int = 100, +): + print(f"run_tflops_test: causal={causal}") + + if use_gpu: + device_id = torch.cuda.current_device() + device = torch.device("cuda", device_id) + formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH, InputFormats.Q_KV_BSNH_BSN2H, InputFormats.QKV_BSN3H] + provider = "CUDAExecutionProvider" + # flash attention is available for sm >= 80 + sm = get_compute_capability() + if sm >= 80: + backends = [SdpaKernel.DEFAULT, SdpaKernel.FLASH_ATTENTION, SdpaKernel.EFFICIENT_ATTENTION] + else: + backends = [SdpaKernel.DEFAULT, SdpaKernel.EFFICIENT_ATTENTION] + else: + device_id = 0 + device = torch.device("cpu") + formats = [InputFormats.Q_K_V_BSNH_BSNH_BSNH] + enable_cuda_graph = False + provider = "CPUExecutionProvider" + backends = [SdpaKernel.DEFAULT] + + configs = get_test_configs(use_gpu) + + print("\nformat\tcausal\tprompt\tbatch\tseqlen\theads\th_dim\tthreads\tms\tTFLOPS\tkernel") for input_format in formats: for batch_size, sequence_length, past_sequence_length, num_heads, head_size, enable_unfused in configs: @@ -496,21 +614,27 @@ def run_tflops_test(use_gpu: bool = True, enable_cuda_graph: bool = False, repea share_past_present_buffer=False, input_format=input_format, ) - - session = create_session(config) + for attention_kernel in backends: + sess_options = SessionOptions() + sess_options.intra_op_num_threads = intra_op_num_threads + session = create_session(config, sess_options, attention_kernel=attention_kernel) if use_gpu: - kernel = get_gpu_kernel_name(config) + kernel = get_gpu_kernel_name(attention_kernel) else: - kernel = get_cpu_kernel_name() + kernel = get_cpu_kernel_name(config) - if kernel == "Unfused": + if "math" in kernel: # Skip large sequence length for Unfused kernel to avoid OOM. if not enable_unfused: + if config.verbose: + print(f"skip unfused kernel for {vars(config)}") continue # Unfused kernel does not support packed QKV or packed KV formats. if input_format not in [InputFormats.Q_K_V_BSNH_BSNH_BSNH]: + if config.verbose: + print(f"skip input_format for {vars(config)}") continue input_dict = config.random_inputs() @@ -526,19 +650,168 @@ def run_tflops_test(use_gpu: bool = True, enable_cuda_graph: bool = False, repea del session + format_str = InputFormats.input_format_str(input_format) + # compute TFLOPS per second - speed = tflops_per_second( - flops(batch_size, sequence_length, head_size, num_heads, causal), average_latency - ) + speed = None + if past_sequence_length == 0: + speed = tflops_per_second( + flops(batch_size, sequence_length, head_size, num_heads, causal), average_latency + ) + + row = { + "use_gpu": use_gpu, + "enable_cuda_graph": enable_cuda_graph, + "format": format_str, + "causal": causal, + "batch_size": batch_size, + "sequence_length": sequence_length, + "past_sequence_length": past_sequence_length, + "num_heads": num_heads, + "head_size": head_size, + "intra_op_num_threads": intra_op_num_threads, + "average_latency": average_latency, + "tflops": speed, + "kernel": kernel, + } + csv_writer.writerow(row) - format = InputFormats.input_format_str(input_format) + speed = f"{speed:.2f}" if speed is not None else "NA" print( - f"{format}\t{causal}\t{batch_size}\t{sequence_length}\t{num_heads}\t{head_size}\t{average_latency * 1000:.2f}\t{speed:.2f}\t{kernel}" + f"{format_str}\t{causal}\t{not has_past}\t{batch_size}\t{sequence_length}\t{num_heads}\t{head_size}\t" + f"{intra_op_num_threads}\t{average_latency * 1000:.2f}\t{speed}\t{kernel}" ) +def run_torch_test( + csv_writer: csv.DictWriter, + use_gpu: bool = True, + causal: bool = False, +): + configs = get_test_configs(use_gpu) + + if use_gpu: + if not torch.cuda.is_available(): + return + device_id = torch.cuda.current_device() + device = torch.device("cuda", device_id) + dtype = torch.float16 + backends = [ + None, + SDPBackend.FLASH_ATTENTION, + SDPBackend.EFFICIENT_ATTENTION, + SDPBackend.CUDNN_ATTENTION, + SDPBackend.MATH, + ] + else: + device = torch.device("cpu") + dtype = torch.float32 + backends = [None] + + backend_names = { + SDPBackend.FLASH_ATTENTION: "torch:flash", + SDPBackend.EFFICIENT_ATTENTION: "torch:efficient", + SDPBackend.CUDNN_ATTENTION: "torch:cudnn", + SDPBackend.MATH: "torch:math", + None: "torch:default", + } + + # Test PyTorch latency + for batch_size, sequence_length, past_sequence_length, num_heads, head_size, enable_unfused in configs: + for backend in backends: + if backend == SDPBackend.MATH and not enable_unfused: + continue + if backend == SDPBackend.FLASH_ATTENTION and platform.system() != "Linux": + continue + + backend_name = backend_names[backend] + try: + with torch.no_grad(): + torch_latency = run_torch_sdpa( + batch_size, + sequence_length, + sequence_length, + num_heads, + head_size, + causal, + has_mask=False, + mask_dim=2, + mask_dtype=torch.bool, + device=device, + dtype=dtype, + backend=backend, + ) + except RuntimeError: + continue + + speed = tflops_per_second(flops(batch_size, sequence_length, head_size, num_heads, causal), torch_latency) + input_format = "Q,K,V" + print( + f"{input_format}\t{causal}\t{batch_size}\t{sequence_length}\t{num_heads}\t{head_size}\t" + f"{0}\t{torch_latency * 1000:.2f}\t{speed:.2f}\t{backend_name}" + ) + row = { + "use_gpu": use_gpu, + "enable_cuda_graph": False, + "format": input_format, + "causal": causal, + "batch_size": batch_size, + "sequence_length": sequence_length, + "past_sequence_length": past_sequence_length, + "num_heads": num_heads, + "head_size": head_size, + "intra_op_num_threads": torch.get_num_threads(), + "average_latency": torch_latency, + "tflops": speed, + "kernel": backend_name, + } + csv_writer.writerow(row) + + +def run_tflops_tests(args): + features = "gpu" if args.use_gpu else "cpu" + if args.causal: + features += "_causal" + if args.has_past: + features += "_past" + csv_filename = "benchmark_mha_{}_{}_{}.csv".format( + features, + "torch" if args.torch else "ort", + datetime.now().strftime("%Y%m%d-%H%M%S"), + ) + with open(csv_filename, mode="a", newline="") as csv_file: + column_names = [ + "use_gpu", + "enable_cuda_graph", + "format", + "causal", + "batch_size", + "sequence_length", + "past_sequence_length", + "num_heads", + "head_size", + "intra_op_num_threads", + "average_latency", + "tflops", + "kernel", + ] + csv_writer = csv.DictWriter(csv_file, fieldnames=column_names) + csv_writer.writeheader() + + if args.torch: + run_torch_test(csv_writer, args.use_gpu, args.causal) + else: + run_tflops_test( + csv_writer, + use_gpu=args.use_gpu, + enable_cuda_graph=args.use_cuda_graph, + causal=args.causal, + has_past=args.has_past, + intra_op_num_threads=args.intra_op_num_threads, + ) + + def plot_prompt_performance( - sm: int, model_name: str, batch_size: int, num_heads: int, @@ -558,6 +831,7 @@ def plot_prompt_performance( "styles": [("red", "solid"), ("yellow", "dashdot"), ("blue", "dashed"), ("green", "dotted")][0 : len(formats)], } + sm = get_compute_capability() configs = [ triton.testing.Benchmark( x_names=["sequence_length"], @@ -591,13 +865,14 @@ def benchmark( sequence_length=sequence_length, num_heads=num_heads, head_size=head_size, - causal=True, + causal=False, past_sequence_length=0, kv_sequence_length=sequence_length if input_format == InputFormats.get_name_list()[-1] else None, max_cache_sequence_length=max_seq_len, provider="CUDAExecutionProvider", enable_cuda_graph=False, device=device, + dtype=torch.float16, use_kv_cache=False, input_format=InputFormats.convert(input_format), ) @@ -609,14 +884,14 @@ def benchmark( benchmark.run(save_path=".", print_data=True) -def run_performance_test(sm: int): +def run_bert_performance_test(): """ Run performance tests for prompt and token generation. """ configures = [ - (1, 32, 128, 8192, "TNLGv4"), - (4, 32, 128, 8192, "TNLGv4"), + # (1, 32, 128, 8192, "TNLGv4"), + # (4, 32, 128, 8192, "TNLGv4"), (1, 12, 64, 1024, "BertBase"), (16, 12, 64, 1024, "BertBase"), (1, 16, 64, 1024, "BertLarge"), @@ -625,7 +900,6 @@ def run_performance_test(sm: int): for batch_size, num_heads, head_size, max_seq_len, model_name in configures: plot_prompt_performance( - sm=sm, batch_size=batch_size, num_heads=num_heads, head_size=head_size, @@ -634,18 +908,84 @@ def run_performance_test(sm: int): ) +def _parse_arguments(): + parser = argparse.ArgumentParser(description="Benchmark MultiHeadAttention for ONNX Runtime and PyTorch.") + + parser.add_argument( + "--use_gpu", + required=False, + action="store_true", + help="Use GPU for inference.", + ) + parser.set_defaults(use_gpu=False) + + parser.add_argument( + "--use_cuda_graph", + required=False, + action="store_true", + help="Use cuda graph in onnxruntime.", + ) + parser.set_defaults(use_cuda_graph=False) + + parser.add_argument( + "--intra_op_num_threads", + required=False, + type=int, + choices=[0, 1, 2, 4, 8, 16], + default=0, + help="intra_op_num_threads for onnxruntime. ", + ) + + parser.add_argument( + "--has_past", + required=False, + action="store_true", + help="whether past_sequence_length > 0", + ) + parser.set_defaults(has_past=False) + + parser.add_argument( + "--causal", + required=False, + action="store_true", + help="test unidirectional", + ) + parser.set_defaults(causal=False) + + parser.add_argument( + "--torch", + required=False, + action="store_true", + help="test pytorch instead of onnxruntime", + ) + parser.set_defaults(torch=False) + + args = parser.parse_args() + + return args + + if __name__ == "__main__": - if torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers(): - # Test CUDA provider - major, minor = torch.cuda.get_device_capability() - sm = major * 10 + minor + args = _parse_arguments() + print(f"arguments:{args}") + + if args.has_past: + assert args.causal, "--has_past need --causal specified" + + if args.use_gpu: + assert args.torch or not args.causal, "no causal cuda kernel in MHA op" + assert torch.cuda.is_available() + if not args.torch: + assert "CUDAExecutionProvider" in get_available_providers() + if args.torch: + assert Version(torch.__version__) >= Version("2.3.0") + assert args.has_past is False + + if args.use_gpu and not args.torch: if platform.system() == "Linux": s = torch.cuda.Stream() with torch.cuda.stream(s), torch.no_grad(): - run_performance_test(sm) - - run_tflops_test(use_gpu=True, enable_cuda_graph=True) + run_bert_performance_test() - # Test CPU provider - run_tflops_test(use_gpu=False, enable_cuda_graph=False) + run_tflops_tests(args) diff --git a/onnxruntime/test/python/transformers/benchmark_mha.sh b/onnxruntime/test/python/transformers/benchmark_mha.sh index 7b21cf1cc1e08..613543d0172dd 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.sh +++ b/onnxruntime/test/python/transformers/benchmark_mha.sh @@ -1,14 +1,40 @@ -echo "flash attention v2" -ORT_DISABLE_FLASH_ATTENTION=0 ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV=0 python benchmark_mha.py | tee result.txt +#!/bin/sh -echo "===" -echo "TensorRT attention kernels - cross attention (when kv_seq_len <= 128) or fused attention (when seq_len <= 384) or flash attention (seq_len > 384)" -ORT_DISABLE_FLASH_ATTENTION=1 python benchmark_mha.py | tee -a result.txt +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- -echo "===" -echo "Memory Efficient attention" -ORT_DISABLE_FLASH_ATTENTION=1 ORT_DISABLE_TRT_FLASH_ATTENTION=1 ORT_DISABLE_FUSED_ATTENTION=1 ORT_DISABLE_FUSED_CROSS_ATTENTION=1 python benchmark_mha.py | tee -a result.txt +echo "Benchmark Scaled Dot Product Attention (SDPA) performance on GPU:" -echo "===" -echo "Unfused Attention (some configurations might fail)" -ORT_DISABLE_FLASH_ATTENTION=1 ORT_DISABLE_TRT_FLASH_ATTENTION=1 ORT_DISABLE_FUSED_ATTENTION=1 ORT_DISABLE_FUSED_CROSS_ATTENTION=1 ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION=1 python benchmark_mha.py | tee -a result.txt +export CUDA_VISIBLE_DEVICES=0 +python benchmark_mha.py --use_gpu +python benchmark_mha.py --use_gpu --use_cuda_graph +python benchmark_mha.py --use_gpu --torch + +cat benchmark_mha_gpu_*.csv > mha_gpu_benchmark_results.csv + +echo "Benchmark performance on CPU with number of threads:" +MKL_DYNAMIC=FALSE OMP_NUM_THREADS=1 python benchmark_mha.py --torch +MKL_DYNAMIC=FALSE OMP_NUM_THREADS=2 python benchmark_mha.py --torch +MKL_DYNAMIC=FALSE OMP_NUM_THREADS=4 python benchmark_mha.py --torch +MKL_DYNAMIC=FALSE OMP_NUM_THREADS=8 python benchmark_mha.py --torch + +python benchmark_mha.py --intra_op_num_threads 1 +python benchmark_mha.py --intra_op_num_threads 2 +python benchmark_mha.py --intra_op_num_threads 4 +python benchmark_mha.py --intra_op_num_threads 8 + + +echo "Benchmark performance on CPU with default threads settings:" +python benchmark_mha.py +ORT_DISABLE_FLASH_ATTENTION=1 python benchmark_mha.py +python benchmark_mha.py --torch + +python benchmark_mha.py --causal +python benchmark_mha.py --torch --causal + +# Pytorch SDPA does not support causal attention with past state, we only test ORT here. +python benchmark_mha.py --causal --has_past + +cat benchmark_mha_cpu_*.csv > mha_cpu_benchmark_results.csv diff --git a/onnxruntime/test/python/transformers/test_mha.py b/onnxruntime/test/python/transformers/test_mha.py index ff473cc2ced92..0fcbd889847e9 100644 --- a/onnxruntime/test/python/transformers/test_mha.py +++ b/onnxruntime/test/python/transformers/test_mha.py @@ -10,36 +10,15 @@ import concurrent.futures import itertools import unittest -from enum import IntEnum from typing import Dict, List, Optional import numpy import torch -from benchmark_mha import ( - InputFormats, - MultiHeadAttentionConfig, - OrtMultiHeadAttention, - create_multi_head_attention_onnx_model, -) +from benchmark_mha import InputFormats, MultiHeadAttentionConfig, OrtMultiHeadAttention, SdpaKernel, create_ort_session from einops import rearrange from parameterized import parameterized import onnxruntime -from onnxruntime import InferenceSession - - -class SdpaKernel(IntEnum): - """Bit flags for sdpa_kernel CUDA provider option""" - - DEFAULT = 0 - FLASH_ATTENTION = 1 - EFFICIENT_ATTENTION = 2 - TRT_FUSED_ATTENTION = 4 - CUDNN_FLASH_ATTENTION = 8 - MATH = 16 - TRT_FLASH_ATTENTION = 32 - TRT_CROSS_ATTENTION = 64 - TRT_CAUSAL_ATTENTION = 128 def attention_reference( @@ -466,7 +445,7 @@ def parity_check_mha_multi_threading( test_inputs: List[Dict], rtol: float = 1e-3, atol: float = 1e-3, - sdpa_kernel: int = SdpaKernel.DEFAULT, + attention_kernel: int = SdpaKernel.DEFAULT, max_threads: int = 5, verbose: bool = False, ): @@ -476,21 +455,14 @@ def parity_check_mha_multi_threading( if config.causal and config.provider == "CUDAExecutionProvider": return None # Some kernel does not support certain input format. - if sdpa_kernel not in [ + if attention_kernel not in [ SdpaKernel.DEFAULT, SdpaKernel.FLASH_ATTENTION, SdpaKernel.EFFICIENT_ATTENTION, ] and config.input_format in [InputFormats.Q_KV_BSNH_BSN2H]: return None - if verbose: - print(f"create a shared session with {vars(config)}") - onnx_model_str = create_multi_head_attention_onnx_model(config, use_symbolic_shape=True) - if config.provider == "CUDAExecutionProvider": - provider_options = {"arena_extend_strategy": "kSameAsRequested", "sdpa_kernel": int(sdpa_kernel)} - providers = [(config.provider, provider_options), "CPUExecutionProvider"] - else: - providers = ["CPUExecutionProvider"] - ort_session = InferenceSession(onnx_model_str, providers=providers) + + ort_session = create_ort_session(config, attention_kernel=attention_kernel, use_symbolic_shape=True) def convert_to_ort_inputs(feed_dict): ort_inputs = {} @@ -613,7 +585,7 @@ def test_mha_cuda(self, config): def test_mha_cpu(self, config): parity_check_mha(config) - def run_mha_cuda_multi_threading(self, spda_kernel): + def run_mha_cuda_multi_threading(self, attention_kernel): for configs in multi_thread_test_cases("CUDAExecutionProvider", comprehensive_mode): test_inputs = [] for config in configs: @@ -626,8 +598,10 @@ def run_mha_cuda_multi_threading(self, spda_kernel): config.input_format = old_format test_inputs.append({"config": config, "ort_inputs": ort_inputs, "ref_inputs": ref_inputs}) - exception = parity_check_mha_multi_threading(test_inputs, sdpa_kernel=spda_kernel, max_threads=len(configs)) - assert exception is None, f"{spda_kernel=}, {vars(configs[0])}, {exception}" + exception = parity_check_mha_multi_threading( + test_inputs, attention_kernel=attention_kernel, max_threads=len(configs) + ) + assert exception is None, f"{attention_kernel=}, {vars(configs[0])}, {exception}" def test_mha_cuda_multi_threading(self): self.run_mha_cuda_multi_threading(SdpaKernel.DEFAULT)