Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sparse attention kernel for H100 (sm90) #20553

Merged
merged 1 commit into from
May 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,15 @@
total_seq_len));
// Some limitations of CUDA kernels
// The v1 and v2 kernels have same coverage, so only check one of them to see whether it is supported.
if (!sparse_attention_v1::is_supported_device(device_prop)) {
int sm = device_prop.major * 10 + device_prop.minor;
if (!sparse_attention_v1::is_supported_device(sm)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"SparseAttention only support CUDA device with compute capacity 8.*. Got ",
device_prop.major);
"SparseAttention only supports CUDA device with compute capacity 7.5, 8.0, 8.6, 8.9 and 9.0. Got sm=",

Check warning on line 112 in onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc:112: Lines should be <= 120 characters long [whitespace/line_length] [2]
sm);
}
if (!sparse_attention_v1::is_supported_sparse_attention(parameters.head_size, sparse_block_size_)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"SparseAttention only support head_size=128 and sparse_block_size=64. Got head_size=",
"SparseAttention only supports head_size=128 and sparse_block_size=64. Got head_size=",
parameters.head_size,
",sparse_block_size=",
sparse_block_size_);
Expand Down Expand Up @@ -149,7 +150,6 @@
}

if (!kernel_loaded_) {
int sm = device_prop.major * 10 + device_prop.minor;
if constexpr (std::is_same<T, MLFloat16>::value) {
// std::call_once is used in load_sparse_attention_fp16 so no need to use mutex here.
// After kernel is loaded, it will stay in memory until the process exits. We do not unload explicitly.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

# Use triton AoT compiler to convert sparse_attention_triton.py to C source files including cubin and dispatcher.
# Example to use this script (Tested with Python 3.10 and CUDA 12.3 in Ubuntu 20.04):
# python3 -m pip install torch==2.3.0 triton==2.3.0
# python3 -m pip install numpy==1.26.4 torch==2.3.0 triton==2.3.0
# python3 compile_sparse_attention.py | sh
#
# Note that sparse_attention_v1_*.cc and sparse_attention_dispatcher_*.h are the generated files.
Expand Down Expand Up @@ -35,31 +35,35 @@ def generate_triton_compile_shell_script(sm, dtype="fp16"):
print(f"rm -rf {out_dir}")
print(f"mkdir -p {out_dir}")

# Note that block_n * num_block_d is the head_size. We support head_size = 128 for now.
block_n_values = [64]
block_d_values = [64]
num_block_d_values = [2]
even_m_values = [True, False]
even_n_values = [True, False]

# Use triton compiler to compile the kernel of different combinations of constant parameters.
for block_n, block_d, num_blocks_d, even_m, even_n in product(
block_n_values, block_d_values, num_block_d_values, even_m_values, even_n_values
for block_n, block_d, num_blocks_d, even_n in product(
block_n_values, block_d_values, num_block_d_values, even_n_values
):
block_m_values = [16, block_n] if block_n != 16 else [block_n]
for block_m in block_m_values:
scalar_params = "i32,i32,i32,fp32,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32,i32,i32"
sig = f"*{dtype}:16,*{dtype}:16,*{dtype}:16,*{dtype}:16,*i32:16,*i32:16,{scalar_params},{block_m},{int(even_m)},{block_n},{int(even_n)},{block_d},{num_blocks_d}"
prefix = "python compile.py sparse_attention_triton.py"
filename = f"sparse_attention_v1_{dtype}_m{block_m}_{int(even_m)}_n{block_n}_{int(even_n)}_d{block_d}_{num_blocks_d}_sm{sm}"
name = f"sparse_attention_{dtype}_sm{sm}"
num_warps = max(1, 2 ** int(math.log2(min(block_m, block_n, block_d) / 16)))
num_stages = 2
# TODO: use different kernel name (change the name in sparse_attention_triton.py before running compile.py)
print(
f"{prefix} -n block_sparse_attention_kernel -o {out_dir}/{filename} --out-name {name} "
f'-w {num_warps} -ns {num_stages} -s "{sig}" -g "(total_seq_len - past_seq_len + {block_m} - 1) / {block_m}, batch_size * num_heads, 1"'
)
head_size = block_d * num_blocks_d
block_m = block_n
even_m = even_n
scalar_params = "i32,i32,i32,fp32,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32,i32,i32"
sig = f"*{dtype}:16,*{dtype}:16,*{dtype}:16,*{dtype}:16,*i32:16,*i32:16,{scalar_params},{block_m},{int(even_m)},{block_n},{int(even_n)},{block_d},{num_blocks_d}"
prefix = "python compile.py sparse_attention_triton.py"
filename = f"sparse_attention_v1_{dtype}_d{head_size}_n{block_n}_e{int(even_n)}_sm{sm}"
name = f"sparse_attention_{dtype}_sm{sm}"
num_warps = max(1, 2 ** int(math.log2(min(block_m, block_n, block_d) / 16)))

# Shared memory is 96KB for V100 (sm70), 64KB for T4 (sm75), 164KB for A100 (sm80), 228KB for H100 (sm90).
# Adjust stages so that shared memory size is within limit, and choose the one with best performance.
sm_to_stages = {90: 3, 80: 2, 75: 2}

num_stages = sm_to_stages[sm]

# TODO: use different kernel name (change the name in sparse_attention_triton.py before running compile.py)
print(
f"{prefix} -n block_sparse_attention_kernel -o {out_dir}/{filename} --out-name {name} "
f'-w {num_warps} -ns {num_stages} -s "{sig}" -g "(total_seq_len - past_seq_len + {block_m} - 1) / {block_m}, batch_size * num_heads, 1"'
)

# Generate the dispatcher.
dispatcher = f"sparse_attention_dispatcher_{dtype}_sm{sm}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,82 +11,6 @@ namespace contrib {
namespace cuda {
namespace sparse_attention_v1 {

// launcher for: sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2
Status sparse_attention_bf16_sm80_ba65ff9c(SparseAttentionParams& params);

Status sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2(SparseAttentionParams& params) {
return sparse_attention_bf16_sm80_ba65ff9c(params);
}

// load for: sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2
void load_sparse_attention_bf16_sm80_ba65ff9c();
void load_sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2() {
load_sparse_attention_bf16_sm80_ba65ff9c();
}

// unload for: sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2
void unload_sparse_attention_bf16_sm80_ba65ff9c();
void unload_sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2() {
unload_sparse_attention_bf16_sm80_ba65ff9c();
}

// launcher for: sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2
Status sparse_attention_bf16_sm80_f951a16d(SparseAttentionParams& params);

Status sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2(SparseAttentionParams& params) {
return sparse_attention_bf16_sm80_f951a16d(params);
}

// load for: sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2
void load_sparse_attention_bf16_sm80_f951a16d();
void load_sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2() {
load_sparse_attention_bf16_sm80_f951a16d();
}

// unload for: sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2
void unload_sparse_attention_bf16_sm80_f951a16d();
void unload_sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2() {
unload_sparse_attention_bf16_sm80_f951a16d();
}

// launcher for: sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2
Status sparse_attention_bf16_sm80_646fefc8(SparseAttentionParams& params);

Status sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2(SparseAttentionParams& params) {
return sparse_attention_bf16_sm80_646fefc8(params);
}

// load for: sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2
void load_sparse_attention_bf16_sm80_646fefc8();
void load_sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2() {
load_sparse_attention_bf16_sm80_646fefc8();
}

// unload for: sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2
void unload_sparse_attention_bf16_sm80_646fefc8();
void unload_sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2() {
unload_sparse_attention_bf16_sm80_646fefc8();
}

// launcher for: sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2
Status sparse_attention_bf16_sm80_21cac990(SparseAttentionParams& params);

Status sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2(SparseAttentionParams& params) {
return sparse_attention_bf16_sm80_21cac990(params);
}

// load for: sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2
void load_sparse_attention_bf16_sm80_21cac990();
void load_sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2() {
load_sparse_attention_bf16_sm80_21cac990();
}

// unload for: sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2
void unload_sparse_attention_bf16_sm80_21cac990();
void unload_sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2() {
unload_sparse_attention_bf16_sm80_21cac990();
}

// launcher for: sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2
Status sparse_attention_bf16_sm80_31acb592(SparseAttentionParams& params);

Expand All @@ -106,44 +30,6 @@ void unload_sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2() {
unload_sparse_attention_bf16_sm80_31acb592();
}

// launcher for: sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2
Status sparse_attention_bf16_sm80_d55ab166(SparseAttentionParams& params);

Status sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2(SparseAttentionParams& params) {
return sparse_attention_bf16_sm80_d55ab166(params);
}

// load for: sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2
void load_sparse_attention_bf16_sm80_d55ab166();
void load_sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2() {
load_sparse_attention_bf16_sm80_d55ab166();
}

// unload for: sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2
void unload_sparse_attention_bf16_sm80_d55ab166();
void unload_sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2() {
unload_sparse_attention_bf16_sm80_d55ab166();
}

// launcher for: sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2
Status sparse_attention_bf16_sm80_b0560d11(SparseAttentionParams& params);

Status sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2(SparseAttentionParams& params) {
return sparse_attention_bf16_sm80_b0560d11(params);
}

// load for: sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2
void load_sparse_attention_bf16_sm80_b0560d11();
void load_sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2() {
load_sparse_attention_bf16_sm80_b0560d11();
}

// unload for: sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2
void unload_sparse_attention_bf16_sm80_b0560d11();
void unload_sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2() {
unload_sparse_attention_bf16_sm80_b0560d11();
}

// launcher for: sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2
Status sparse_attention_bf16_sm80_c777f3f5(SparseAttentionParams& params);

Expand All @@ -165,13 +51,7 @@ void unload_sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2() {

typedef Status (*kernel_func_t)(SparseAttentionParams& params);
kernel_func_t sparse_attention_bf16_sm80_kernels[] = {
sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2,
sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2,
sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2,
sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2,
sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2,
sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2,
sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2,
sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2,
};

Expand All @@ -185,24 +65,12 @@ Status sparse_attention_bf16_sm80(SparseAttentionParams& params, int algo_id) {
}

void load_sparse_attention_bf16_sm80(void) {
load_sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2();
load_sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2();
load_sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2();
load_sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2();
load_sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2();
load_sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2();
load_sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2();
load_sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2();
}

void unload_sparse_attention_bf16_sm80(void) {
unload_sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2();
unload_sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2();
unload_sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2();
unload_sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2();
unload_sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2();
unload_sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2();
unload_sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2();
unload_sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// This file is generated by compile_sparse_attention.py using triton AoT compiler

#pragma once
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_common.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {
namespace sparse_attention_v1 {

// launcher for: sparse_attention_bf16_sm90_64x0x64x0x64x2_warps4xstages3
Status sparse_attention_bf16_sm90_eb17c351(SparseAttentionParams& params);

Status sparse_attention_bf16_sm90_64x0x64x0x64x2_warps4xstages3(SparseAttentionParams& params) {
return sparse_attention_bf16_sm90_eb17c351(params);
}

// load for: sparse_attention_bf16_sm90_64x0x64x0x64x2_warps4xstages3
void load_sparse_attention_bf16_sm90_eb17c351();
void load_sparse_attention_bf16_sm90_64x0x64x0x64x2_warps4xstages3() {
load_sparse_attention_bf16_sm90_eb17c351();
}

// unload for: sparse_attention_bf16_sm90_64x0x64x0x64x2_warps4xstages3
void unload_sparse_attention_bf16_sm90_eb17c351();
void unload_sparse_attention_bf16_sm90_64x0x64x0x64x2_warps4xstages3() {
unload_sparse_attention_bf16_sm90_eb17c351();
}

// launcher for: sparse_attention_bf16_sm90_64x1x64x1x64x2_warps4xstages3
Status sparse_attention_bf16_sm90_d7dba852(SparseAttentionParams& params);

Status sparse_attention_bf16_sm90_64x1x64x1x64x2_warps4xstages3(SparseAttentionParams& params) {
return sparse_attention_bf16_sm90_d7dba852(params);
}

// load for: sparse_attention_bf16_sm90_64x1x64x1x64x2_warps4xstages3
void load_sparse_attention_bf16_sm90_d7dba852();
void load_sparse_attention_bf16_sm90_64x1x64x1x64x2_warps4xstages3() {
load_sparse_attention_bf16_sm90_d7dba852();
}

// unload for: sparse_attention_bf16_sm90_64x1x64x1x64x2_warps4xstages3
void unload_sparse_attention_bf16_sm90_d7dba852();
void unload_sparse_attention_bf16_sm90_64x1x64x1x64x2_warps4xstages3() {
unload_sparse_attention_bf16_sm90_d7dba852();
}

typedef Status (*kernel_func_t)(SparseAttentionParams& params);
kernel_func_t sparse_attention_bf16_sm90_kernels[] = {
sparse_attention_bf16_sm90_64x0x64x0x64x2_warps4xstages3,
sparse_attention_bf16_sm90_64x1x64x1x64x2_warps4xstages3,
};

int sparse_attention_bf16_sm90_get_num_algos(void) {
return (int)sizeof(sparse_attention_bf16_sm90_kernels);

Check warning on line 59 in onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_dispatcher_bf16_sm90.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Using C-style cast. Use static_cast<int>(...) instead [readability/casting] [4] Raw Output: onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_dispatcher_bf16_sm90.h:59: Using C-style cast. Use static_cast<int>(...) instead [readability/casting] [4]
}

Status sparse_attention_bf16_sm90(SparseAttentionParams& params, int algo_id) {
assert(algo_id < (int)sizeof(sparse_attention_bf16_sm90_kernels));

Check warning on line 63 in onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_dispatcher_bf16_sm90.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Using C-style cast. Use static_cast<int>(...) instead [readability/casting] [4] Raw Output: onnxruntime/contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_dispatcher_bf16_sm90.h:63: Using C-style cast. Use static_cast<int>(...) instead [readability/casting] [4]
return sparse_attention_bf16_sm90_kernels[algo_id](params);
}

void load_sparse_attention_bf16_sm90(void) {
load_sparse_attention_bf16_sm90_64x0x64x0x64x2_warps4xstages3();
load_sparse_attention_bf16_sm90_64x1x64x1x64x2_warps4xstages3();
}

void unload_sparse_attention_bf16_sm90(void) {
unload_sparse_attention_bf16_sm90_64x0x64x0x64x2_warps4xstages3();
unload_sparse_attention_bf16_sm90_64x1x64x1x64x2_warps4xstages3();
}

Status sparse_attention_bf16_sm90_default(SparseAttentionParams& params) {
return sparse_attention_bf16_sm90(params, 0);
}

} // namespace sparse_attention_v1
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
Loading
Loading