Fix Sparse Attention with Packed QKV inputs (microsoft#20591)
### Description
(1) Fix the UnpackQKV kernel: when unpacking packed QKV to BNSH, write each element to its own destination buffer (Q, K, or V) instead of writing everything into the Q buffer.
(2) Update test_sparse_attention.py with a packed QKV option.
tianleiwu authored and poweiw committed Jun 25, 2024
1 parent 3d594ba commit b6b3507
Showing 3 changed files with 226 additions and 386 deletions.
@@ -531,19 +531,22 @@ __global__ void UnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T*
   int offset = tid % d;
   if (output_bnsh) {  // output BNSH
     int head_count = kv_num_heads;
+    T* unpacked;
     if (offset < q_hidden) {
+      unpacked = unpacked_q;
       head_count = num_heads;
     } else if (offset < q_hidden + k_hidden) {
+      unpacked = unpacked_k;
       offset -= q_hidden;
     } else {
+      unpacked = unpacked_v;
       offset -= (q_hidden + k_hidden);
     }
     int n = offset / head_size;
     int h = offset % head_size;
 
     int unpacked_i = INDEX_4D(head_count, sequence_length, head_size, b, n, s, h);
 
-    unpacked_q[unpacked_i] = packed_qkv[tid];
+    unpacked[unpacked_i] = packed_qkv[tid];
   } else {  // output BSNH
     if (offset < q_hidden) {
       int unpacked_i = b * sequence_length * num_heads * head_size + s * num_heads * head_size + offset;
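The fix above selects the destination buffer per element; previously every element was written into unpacked_q, so K and V were never populated when packed QKV was used. Below is a minimal host-side sketch of the intended mapping, assuming packed_qkv is laid out as (batch, sequence, (num_heads + 2 * kv_num_heads) * head_size) and that INDEX_4D(N, S, H, b, n, s, h) expands to the usual row-major ((b * N + n) * S + s) * H + h. The helper name and signature are illustrative, not code from the repository.

```cpp
#include <vector>

// Host-side reference of the BNSH unpack performed by the fixed kernel.
// Layout assumptions are stated in the text above; this is a sketch, not ORT code.
template <typename T>
void UnpackQkvBnshReference(const std::vector<T>& packed_qkv,
                            std::vector<T>& q, std::vector<T>& k, std::vector<T>& v,
                            int batch_size, int sequence_length,
                            int num_heads, int kv_num_heads, int head_size) {
  const int q_hidden = num_heads * head_size;
  const int k_hidden = kv_num_heads * head_size;
  const int d = q_hidden + 2 * k_hidden;  // packed hidden size per token
  for (int b = 0; b < batch_size; ++b) {
    for (int s = 0; s < sequence_length; ++s) {
      for (int offset = 0; offset < d; ++offset) {
        // Choose the destination buffer and its head count; this selection is
        // exactly what the old kernel skipped by always writing into q.
        std::vector<T>* dst = &q;
        int head_count = num_heads;
        int local = offset;
        if (offset >= q_hidden + k_hidden) {
          dst = &v;
          head_count = kv_num_heads;
          local = offset - (q_hidden + k_hidden);
        } else if (offset >= q_hidden) {
          dst = &k;
          head_count = kv_num_heads;
          local = offset - q_hidden;
        }
        const int n = local / head_size;  // head index in the destination tensor
        const int h = local % head_size;  // element index within the head
        const int unpacked_i = ((b * head_count + n) * sequence_length + s) * head_size + h;  // BNSH
        (*dst)[unpacked_i] = packed_qkv[(static_cast<size_t>(b) * sequence_length + s) * d + offset];
      }
    }
  }
}
```

With this mapping, q holds batch_size * num_heads * sequence_length * head_size elements and k/v each hold batch_size * kv_num_heads * sequence_length * head_size, which matches how the scratch buffer is partitioned in the next file.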
16 changes: 10 additions & 6 deletions onnxruntime/contrib_ops/cuda/sparse/sparse_attention_impl.cu
@@ -138,6 +138,7 @@ Status QkvToContext(
     auto q = reinterpret_cast<T*>(data.unpacked_qkv_buffer);
     auto k = reinterpret_cast<T*>(data.unpacked_qkv_buffer + q_size);
     auto v = reinterpret_cast<T*>(data.unpacked_qkv_buffer + q_size + k_size);
+
     Status status = LaunchUnpackQKV<T, LAYOUT_BNSH>(data.query, q, k, v, num_heads, kv_num_heads, head_size,
                                                     sequence_length, batch_size, stream, max_threads_per_block);
     if (status != Status::OK()) {
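q_size and k_size are computed earlier in this file (not shown in the hunk). The sketch below shows how such offsets would be derived for the BNSH tensors produced by LaunchUnpackQKV; the helper is hypothetical, and the element counts are assumptions consistent with the shapes above rather than code copied from the repository.

```cpp
#include <cstddef>

// Hypothetical helper: element counts that would partition unpacked_qkv_buffer
// into back-to-back Q, K, and V regions for the BNSH unpack.
struct QkvPartition {
  size_t q_size;  // Q: batch * num_heads    * sequence * head_size elements
  size_t k_size;  // K: batch * kv_num_heads * sequence * head_size elements
  size_t v_size;  // V: same shape as K
};

inline QkvPartition ComputeQkvPartition(size_t batch_size, size_t sequence_length,
                                        size_t num_heads, size_t kv_num_heads, size_t head_size) {
  QkvPartition p;
  p.q_size = batch_size * num_heads * sequence_length * head_size;
  p.k_size = batch_size * kv_num_heads * sequence_length * head_size;
  p.v_size = p.k_size;
  return p;
}
```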
@@ -152,15 +153,15 @@ Status QkvToContext(
   constexpr bool q_layout = LAYOUT_BNSH;
   bool kv_layout = parameters.is_packed_qkv ? LAYOUT_BNSH : LAYOUT_BSNH;
 
-  DUMP_TENSOR("query", reinterpret_cast<const T*>(query), batch_size, num_heads, sequence_length, head_size);
-
+#if DUMP_TENSOR_LEVEL > 0
+  DUMP_TENSOR("query (BNSH)", reinterpret_cast<const T*>(query), batch_size, num_heads, sequence_length, head_size);
+
   if (LAYOUT_BNSH == kv_layout) {
-    DUMP_TENSOR("key", reinterpret_cast<const T*>(key), batch_size, kv_num_heads, sequence_length, head_size);
-    DUMP_TENSOR("value", reinterpret_cast<const T*>(value), batch_size, kv_num_heads, sequence_length, head_size);
+    DUMP_TENSOR("key (BNSH)", reinterpret_cast<const T*>(key), batch_size, kv_num_heads, sequence_length, head_size);
+    DUMP_TENSOR("value (BNSH)", reinterpret_cast<const T*>(value), batch_size, kv_num_heads, sequence_length, head_size);
   } else {
-    DUMP_TENSOR("key", reinterpret_cast<const T*>(key), batch_size, sequence_length, kv_num_heads, head_size);
-    DUMP_TENSOR("value", reinterpret_cast<const T*>(value), batch_size, sequence_length, kv_num_heads, head_size);
+    DUMP_TENSOR("key (BSNH)", reinterpret_cast<const T*>(key), batch_size, sequence_length, kv_num_heads, head_size);
+    DUMP_TENSOR("value (BSNH)", reinterpret_cast<const T*>(value), batch_size, sequence_length, kv_num_heads, head_size);
   }
+#endif
 
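The layout tags added to the dump labels matter because the same (batch, head, sequence, element) entry sits at a different offset in the two layouts this file handles. Below is a minimal sketch of the two indexing schemes under the usual row-major convention (an assumption for illustration, not taken from the diff).

```cpp
#include <cstddef>

// BNSH: (batch, num_heads, sequence, head_size)
inline size_t IndexBNSH(size_t b, size_t n, size_t s, size_t h,
                        size_t num_heads, size_t sequence_length, size_t head_size) {
  return ((b * num_heads + n) * sequence_length + s) * head_size + h;
}

// BSNH: (batch, sequence, num_heads, head_size)
inline size_t IndexBSNH(size_t b, size_t n, size_t s, size_t h,
                        size_t num_heads, size_t sequence_length, size_t head_size) {
  return ((b * sequence_length + s) * num_heads + n) * head_size + h;
}
```

For packed QKV, the UnpackQKV launch above emits K and V in BNSH, which is why kv_layout is set to LAYOUT_BNSH when parameters.is_packed_qkv is true.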
@@ -317,6 +318,9 @@ Status QkvToContext(
       ORT_RETURN_IF_ERROR(sparse_attention_v1::run_sparse_attention_fp16(params));
     }
   }
+
+  DUMP_TENSOR("output", reinterpret_cast<const T*>(data.output), batch_size, num_heads, sequence_length, head_size);
+
   return Status::OK();
 }
 