[CUDA] Update benchmark_mha.py to capture debug info to identify sdpa kernel #21804

Merged 3 commits on Aug 22, 2024
6 changes: 5 additions & 1 deletion onnxruntime/contrib_ops/cpu/utils/console_dumper.h
@@ -53,7 +53,11 @@ void PrintTensorByDims(const TConsoleDumper* dumper,
const char* name,
const T* tensor,
gsl::span<const int64_t>& dims) {
if (dumper->IsEnabled() && (tensor == nullptr || dims.size() == 0)) {
if (!dumper->IsEnabled()) {
return;
}

if ((tensor == nullptr || dims.size() == 0)) {
std::cout << std::string(name) << " is None" << std::endl;
return;
}
54 changes: 16 additions & 38 deletions onnxruntime/contrib_ops/cuda/bert/attention_kernel_options.cc
@@ -128,45 +128,23 @@ void AttentionKernelDebugInfo::Print(const char* operator_name,
sstream << " DataType=fp32";
}

sstream << " SdpaKernel=";
if (use_flash_attention.has_value() && use_flash_attention.value()) {
sstream << " FLASH_ATTENTION=" << int(use_flash_attention.value());
}

if (use_efficient_attention.has_value() && use_efficient_attention.value()) {
sstream << " EFFICIENT_ATTENTION=" << int(use_efficient_attention.value());
}

if (use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) {
sstream << " TRT_FUSED_ATTENTION=" << int(use_trt_fused_attention.value());
}

if (use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) {
sstream << " CUDNN_FLASH_ATTENTION=" << int(use_cudnn_flash_attention.value());
}

if (use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) {
sstream << " TRT_FLASH_ATTENTION=" << int(use_trt_flash_attention.value());
}

if (use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) {
sstream << " TRT_CROSS_ATTENTION=" << int(use_trt_cross_attention.value());
}

if (use_trt_causal_attention.has_value() && use_trt_causal_attention.value()) {
sstream << " TRT_CAUSAL_ATTENTION=" << int(use_trt_causal_attention.value());
}

bool use_fused = (use_flash_attention.has_value() && use_flash_attention.value()) ||
(use_efficient_attention.has_value() && use_efficient_attention.value()) ||
(use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) ||
(use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) ||
(use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) ||
(use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) ||
(use_trt_causal_attention.has_value() && use_trt_causal_attention.value());

// Fall back to unfused when no fused kernel is enabled.
if (!use_fused) {
sstream << " MATH=1";
sstream << "FLASH_ATTENTION";
} else if (use_efficient_attention.has_value() && use_efficient_attention.value()) {
sstream << "EFFICIENT_ATTENTION";
} else if (use_trt_fused_attention.has_value() && use_trt_fused_attention.value()) {
sstream << "TRT_FUSED_ATTENTION";
} else if (use_cudnn_flash_attention.has_value() && use_cudnn_flash_attention.value()) {
sstream << "CUDNN_FLASH_ATTENTION";
} else if (use_trt_flash_attention.has_value() && use_trt_flash_attention.value()) {
sstream << "TRT_FLASH_ATTENTION";
} else if (use_trt_cross_attention.has_value() && use_trt_cross_attention.value()) {
sstream << "TRT_CROSS_ATTENTION";
} else if (use_trt_causal_attention.has_value() && use_trt_causal_attention.value()) {
sstream << "TRT_CAUSAL_ATTENTION";
} else {
sstream << "MATH";
}

// Output text in Cyan color to make it easier to spot.
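The rewritten Print() above emits a single "SdpaKernel=<NAME>" token instead of one flag per kernel, and that token is what the updated benchmark_mha.py greps for. A minimal Python sketch of parsing it; the text surrounding the token in the sample string is an assumption, while the token itself and the name-to-label mapping come from this PR:

import re
from typing import Optional

# Mapping from the kernel names printed above to the short labels used by
# benchmark_mha.py (both sides of the mapping appear in this PR).
KERNEL_NAMES = {
    "FLASH_ATTENTION": "ort:flash",
    "EFFICIENT_ATTENTION": "ort:efficient",
    "CUDNN_FLASH_ATTENTION": "ort:cudnn",
    "MATH": "ort:math",
    "TRT_FUSED_ATTENTION": "ort:trt_fmha",
    "TRT_FLASH_ATTENTION": "ort:trt_flash",
    "TRT_CROSS_ATTENTION": "ort:trt_cross",
    "TRT_CAUSAL_ATTENTION": "ort:trt_causal",
}

def parse_sdpa_kernel(debug_text: str) -> Optional[str]:
    """Return the short kernel label found in captured debug output, or None."""
    m = re.search(r"SdpaKernel=(?P<kernel>[A-Z_]+)", debug_text)
    return KERNEL_NAMES.get(m.group("kernel")) if m else None

# Only the "SdpaKernel=<NAME>" token is guaranteed by the C++ change above;
# the rest of this example line is hypothetical.
print(parse_sdpa_kernel("... DataType=fp16 SdpaKernel=FLASH_ATTENTION"))  # ort:flash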
@@ -314,7 +314,7 @@ struct BytesHash {
};

// Use thread local caches because cuDNN execution plans are not guaranteed to be thread safe.
// TODO(tianleiwu): since we the key includes sequence lengths, we may want to limit the cache size.
// TODO(tianleiwu): since the key includes sequence lengths, we may want to limit the cache size.
thread_local
std::unordered_map<GraphParams, std::shared_ptr<fe::graph::Graph>, BytesHash<GraphParams> > mha_graph_cache;

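The TODO above notes that because the cache key includes sequence lengths, the thread-local plan cache can grow without bound for workloads with many distinct lengths. A hedged Python sketch of the idea behind the TODO, not ONNX Runtime code: a per-thread cache with a simple LRU size cap.

import threading
from collections import OrderedDict

class ThreadLocalLRUCache:
    """Per-thread cache with a size cap, illustrating the TODO above.

    Each thread gets its own OrderedDict, mirroring the thread_local
    unordered_map of cuDNN execution plans; the cap evicts the least
    recently used entry once max_size is exceeded.
    """

    def __init__(self, max_size: int = 64):
        self.max_size = max_size
        self._local = threading.local()

    def _store(self) -> OrderedDict:
        if not hasattr(self._local, "cache"):
            self._local.cache = OrderedDict()
        return self._local.cache

    def get_or_create(self, key, factory):
        cache = self._store()
        if key in cache:
            cache.move_to_end(key)      # mark as recently used
            return cache[key]
        value = factory()               # e.g. build a cuDNN graph / execution plan
        cache[key] = value
        if len(cache) > self.max_size:
            cache.popitem(last=False)   # evict the least recently used entry
        return value

# Usage sketch: the key would include batch size, head count and sequence lengths.
plans = ThreadLocalLRUCache(max_size=16)
plan = plans.get_or_create((1, 16, 128, 128), lambda: object())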
1 change: 0 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -233,7 +233,6 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
bool use_fused_runner =
kernel_type == AttentionKernelType::AttentionKernel_Default &&
!disable_fused_self_attention_ &&
fused_cross_attention_kernel == nullptr &&
nullptr == attention_bias &&
(parameters.qkv_format == Q_K_V_BSNH || parameters.qkv_format == QKV_BSN3H) &&
nullptr == past_key && nullptr == present_key &&
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/packed_attention.cc
@@ -111,7 +111,7 @@ Status PackedAttention<T>::CheckInputs(const TensorShape& input_shape,
// Abbreviation and Meanings:
// T: token_count
// B: batch_size
// S: sequence_length (input sequence length of query)
// S: sequence_length
// N: num_heads
// H: head size for Q and K, aka q_head_size or v_head_size or qk_head_size
// H_v: v_head_size
@@ -125,7 +125,7 @@ Status PackedAttention<T>::CheckInputs(const TensorShape& input_shape,
// bias (Q/K/V) : (D + D + D_v)
// token_offset : (B, S)
// cu_seq_len_shape : (B + 1)
// attention_bias : (B, N, S, S), (1, N, S, S) or NULL
// attention_bias : (B or 1, N or 1, S, S) or NULL
const auto& input_dims = input_shape.GetDims();
if (input_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
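The shape comments above describe the packed layout: real tokens from all sequences are concatenated into a single token_count dimension, with token_offset and a cumulative sequence-length tensor describing how to unpack them. A small NumPy sketch of those shapes for a toy batch; the prefix-sum convention for the cumulative lengths and the toy numbers are assumptions, while the shapes follow the comments above:

import numpy as np

# Toy batch: B = 3 sequences with valid lengths 8, 5 and 2, padded length S = 8.
seq_lens = np.array([8, 5, 2], dtype=np.int32)
batch_size, sequence_length = len(seq_lens), 8
num_heads, head_size, v_head_size = 4, 16, 16

# T: token_count, the total number of real (non-padding) tokens.
token_count = int(seq_lens.sum())  # 15

# cu_seq_len_shape is (B + 1); assumed here to be the usual exclusive
# prefix sum over sequence lengths: [0, 8, 13, 15].
cu_seq_len = np.concatenate(([0], np.cumsum(seq_lens))).astype(np.int32)

# token_offset has shape (B, S); how padded positions are remapped to the
# packed layout is not shown in this diff, so only the shape is illustrated.
token_offset = np.zeros((batch_size, sequence_length), dtype=np.int32)

# Packed input carries Q, K and V for every real token: (T, D + D + D_v),
# with D = N * H and D_v = N * H_v.
d, d_v = num_heads * head_size, num_heads * v_head_size
packed_input = np.zeros((token_count, d + d + d_v), dtype=np.float16)

print(packed_input.shape, token_offset.shape, cu_seq_len)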
@@ -68,7 +68,7 @@ Status PackedMultiHeadAttention<T>::CheckInputs(const TensorShape& query_shape,
// Input 'value': None
// Input 'token_offset': (batch_size, sequence_length)
// Input 'cumulative_sequence_length': (batch_size + 1)
// Input 'attention_bias': (batch_size or 1, num_heads, sequence_length, sequence_length) or None
// Input 'attention_bias': (batch_size or 1, num_heads or 1, sequence_length, sequence_length) or None
// Output 'output': (token_count, v_hidden_size)

const auto& query_dims = query_shape.GetDims();
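Both comment updates above say the same thing: attention_bias may broadcast over the batch and head dimensions, i.e. shape (batch_size or 1, num_heads or 1, sequence_length, sequence_length). A quick NumPy sketch of the accepted shapes; the shape rule comes from the comments, the broadcasting demonstration itself is generic NumPy:

import numpy as np

batch_size, num_heads, sequence_length = 2, 4, 8
target = (batch_size, num_heads, sequence_length, sequence_length)

# Each of these is a valid attention_bias shape per the comment:
# (B or 1, N or 1, S, S), and each broadcasts to (B, N, S, S).
for b in (1, batch_size):
    for n in (1, num_heads):
        bias = np.zeros((b, n, sequence_length, sequence_length), dtype=np.float16)
        assert np.broadcast_shapes(bias.shape, target) == target
        print(bias.shape, "->", target)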
106 changes: 96 additions & 10 deletions onnxruntime/test/python/transformers/benchmark_mha.py
@@ -19,6 +19,8 @@
import os
import platform
import statistics
import sys
import threading
import time
from contextlib import nullcontext
from datetime import datetime
@@ -771,6 +773,73 @@
return sm


class CaptureStdout:
def __init__(self):
self.fd = sys.stdout.fileno()
self.chunk_size = 1024
self.output = b""

def _capture(self):
chunks = []
while chunk := os.read(self._pipe_reader, self.chunk_size):
chunks.append(chunk)
self.output = b"".join(chunks)

def __enter__(self):
self._duped_fd = os.dup(self.fd)
self._pipe_reader, pipe_writer = os.pipe()
os.dup2(pipe_writer, self.fd)
os.close(pipe_writer)
self._capture_thread = threading.Thread(target=self._capture)
self._capture_thread.start()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
os.close(self.fd)
self._capture_thread.join()
os.close(self._pipe_reader)
os.dup2(self._duped_fd, self.fd)
os.close(self._duped_fd)


def sdpa_kernel_from_debug_info(
config: MultiHeadAttentionConfig, attention_kernel: SdpaKernel, sess_options: SessionOptions
):
os.environ["ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO"] = "1"
try:
with CaptureStdout() as captured:
session = create_session(config, sess_options, attention_kernel=attention_kernel)
input_dict = config.random_inputs()
session.infer(input_dict)
except Exception as e:
print(f"Failed to run {attention_kernel=} for {config=}. Exception: {e}")
finally:
os.environ["ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO"] = "0"

captured_text = captured.output.decode()

import re

m = re.search("SdpaKernel=(?P<kernel>[A-Z_]+)", captured_text)
if m is not None:
name = m.group("kernel")
kernel_names = {
"FLASH_ATTENTION": "ort:flash",
"EFFICIENT_ATTENTION": "ort:efficient",
"CUDNN_FLASH_ATTENTION": "ort:cudnn",
"MATH": "ort:math",
"TRT_FUSED_ATTENTION": "ort:trt_fmha",
"TRT_FLASH_ATTENTION": "ort:trt_flash",
"TRT_CROSS_ATTENTION": "ort:trt_cross",
"TRT_CAUSAL_ATTENTION": "ort:trt_causal",
}
return kernel_names[name]
else:
print("Failed to get sdpa kernel from debug info:", captured_text)

return None


def run_tflops_test(
csv_writer: csv.DictWriter,
args: argparse.Namespace,
@@ -809,7 +878,9 @@
backends = [SdpaKernel.DEFAULT]

configs = get_test_configs(args)
print("\nformat\tcausal\tattBias\tbatch\tseqlen\tpast\theads\th_dim\tthreads\tms\tTFLOPS\tkernel")
print(
"\nformat\tcausal\tattBias\tbatch\tseqlen\tpast\theads\th_dim\tthreads\tms\tTFLOPS\tsdpa_kernel\trequest_kernel"
)

for input_format in formats:
for batch_size, sequence_length, past_sequence_length, num_heads, head_size, enable_unfused in configs:
@@ -836,14 +907,13 @@
for attention_kernel in backends:
sess_options = SessionOptions()
sess_options.intra_op_num_threads = intra_op_num_threads
session = create_session(config, sess_options, attention_kernel=attention_kernel)

if use_gpu:
kernel = get_gpu_kernel_name(attention_kernel)
request_kernel = get_gpu_kernel_name(attention_kernel)
else:
kernel = get_cpu_kernel_name(config)
request_kernel = get_cpu_kernel_name(config)

if "math" in kernel:
if "math" in request_kernel:
# Skip large sequence length for Unfused kernel to avoid OOM.
if not enable_unfused:
if config.verbose:
@@ -856,13 +926,23 @@
print(f"skip input_format for {vars(config)}")
continue

if use_gpu:
actual_kernel = sdpa_kernel_from_debug_info(config, attention_kernel, sess_options)
if actual_kernel is None:
print(f"Warning: skip {config} since kernel from debug info is None")
continue
else:
# CPU has no debug info for now.
actual_kernel = request_kernel

session = create_session(config, sess_options, attention_kernel=attention_kernel)
input_dict = config.random_inputs()

# warm up session
try:
_ = measure_latency(session, input_dict)
except Exception as e:
print(f"Failed to run {kernel=} for {config=}. Exception: {e}")
print(f"Failed to run {request_kernel=} for {config=}. Exception: {e}")
continue

latency_list = []
@@ -898,15 +978,16 @@
"intra_op_num_threads": intra_op_num_threads,
"average_latency": average_latency,
"tflops": speed,
"kernel": kernel,
"request_kernel": request_kernel,
"kernel": actual_kernel,
}
csv_writer.writerow(row)

speed = f"{speed:.2f}" if speed is not None else "NA"
print(
f"{format_str}\t{causal}\t{args.has_attn_bias}\t{batch_size}\t"
f"{sequence_length}\t{past_sequence_length}\t{num_heads}\t{head_size}\t"
f"{intra_op_num_threads}\t{average_latency * 1000:.2f}\t{speed}\t{kernel}"
f"{intra_op_num_threads}\t{average_latency * 1000:.2f}\t{speed}\t{actual_kernel}\t{request_kernel}"
)


@@ -979,7 +1060,7 @@
print(
f"{input_format}\t{causal}\t{False}\t{batch_size}\t"
f"{sequence_length}\t{past_sequence_length}\t{num_heads}\t{head_size}\t"
f"{torch.get_num_threads()}\t{torch_latency * 1000:.2f}\t{speed}\t{backend_name}"
f"{torch.get_num_threads()}\t{torch_latency * 1000:.2f}\t{speed}\t{backend_name}\t{backend_name}"
)
row = {
"use_gpu": use_gpu,
@@ -997,6 +1078,7 @@
"intra_op_num_threads": torch.get_num_threads(),
"average_latency": torch_latency,
"tflops": speed,
"request_kernel": backend_name,
"kernel": backend_name,
}
csv_writer.writerow(row)
@@ -1030,6 +1112,7 @@
"intra_op_num_threads",
"average_latency",
"tflops",
"request_kernel",
"kernel",
]
csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
@@ -1224,7 +1307,7 @@
"--repeats",
required=False,
type=int,
default=100,
default=0,
help="number of repeats for performance test",
)

@@ -1269,6 +1352,9 @@
args = _parse_arguments()
print(f"arguments:{args}")

if args.repeats == 0:
args.repeats = 10000 if args.use_gpu else 100

if args.use_gpu:
assert args.torch or not args.causal, "no causal cuda kernel in MHA op"
assert torch.cuda.is_available()
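Putting the new pieces together: the benchmark sets ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO, redirects the process-level stdout file descriptor through a pipe with CaptureStdout, runs one inference, then greps the captured text for the SdpaKernel= token. A condensed, hedged sketch of that flow; CaptureStdout is copied from the diff above, while fake_inference is a hypothetical stand-in for create_session plus session.infer, so no onnxruntime install is needed to run it:

import os
import re
import sys
import threading

class CaptureStdout:
    """Capture everything written to the process-level stdout file descriptor,
    including output from native (C++) code such as ONNX Runtime debug info."""

    def __init__(self):
        self.fd = sys.stdout.fileno()
        self.chunk_size = 1024
        self.output = b""

    def _capture(self):
        chunks = []
        while chunk := os.read(self._pipe_reader, self.chunk_size):
            chunks.append(chunk)
        self.output = b"".join(chunks)

    def __enter__(self):
        self._duped_fd = os.dup(self.fd)        # keep a copy of the real stdout
        self._pipe_reader, pipe_writer = os.pipe()
        os.dup2(pipe_writer, self.fd)           # route stdout into the pipe
        os.close(pipe_writer)
        self._capture_thread = threading.Thread(target=self._capture)
        self._capture_thread.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        os.close(self.fd)                       # signal EOF to the reader thread
        self._capture_thread.join()
        os.close(self._pipe_reader)
        os.dup2(self._duped_fd, self.fd)        # restore the original stdout
        os.close(self._duped_fd)

def fake_inference():
    # Stand-in for session.infer(input_dict): with the env var below set, the
    # real CUDA provider prints a debug line containing "SdpaKernel=<NAME>".
    os.write(sys.stdout.fileno(), b"... SdpaKernel=EFFICIENT_ATTENTION\n")

os.environ["ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO"] = "1"
try:
    with CaptureStdout() as captured:
        fake_inference()
finally:
    os.environ["ORT_ENABLE_ATTENTION_KERNEL_DEBUG_INFO"] = "0"

m = re.search(r"SdpaKernel=(?P<kernel>[A-Z_]+)", captured.output.decode())
print(m.group("kernel") if m else None)  # EFFICIENT_ATTENTION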