[CUDA] Fix SparseAttention Kernel (microsoft#20716)
### Description

Currently, a single bool flag indicates whether the kernel has been loaded. However, there are v1 and v2 kernels, so the flag allows only one kernel version to be loaded. We use the v1 kernel for prompt processing and the v2 kernel for token generation, so the flag causes a problem when both prompt processing and token generation are needed.

This bug was found in an integration test. The unit tests only exercise one kernel at a time, so the issue was not caught earlier.
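
To make the failure mode concrete, here is a minimal standalone sketch (hypothetical, simplified names rather than the actual op code): a single loaded flag guards loading of either version, so the first call pins whichever version it happened to need.

```cpp
#include <iostream>

static bool kernel_loaded = false;  // one flag shared by the v1 and v2 kernels

void LoadKernelIfNeeded(bool use_v2) {
  if (!kernel_loaded) {
    std::cout << "loading " << (use_v2 ? "v2" : "v1") << " kernel\n";
    kernel_loaded = true;  // set after the first load, whichever version it was
  }
  // If v1 was loaded first (prompt), a later request for v2 (token generation)
  // is silently skipped because the flag is already set.
}

int main() {
  LoadKernelIfNeeded(/*use_v2=*/false);  // prompt phase: loads v1
  LoadKernelIfNeeded(/*use_v2=*/true);   // token generation: v2 never loads
}
```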

Another possible workaround, without this fix, is to set the environment variable `ORT_DISABLE_SPARSE_ATTENTION_V1=1`.
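
With v1 disabled, the v2 kernel is used for prompt processing as well (see the `disable_v1_kernel_` comment in the header diff below), so only one kernel version is ever needed and the single flag no longer conflicts. A standalone sketch of how such a flag is typically parsed (the real code uses `ParseEnvironmentVariableWithDefault<bool>` with the key `kDisableSparseAttentionV1`; this equivalent is an assumption for illustration):

```cpp
#include <cstdlib>
#include <cstring>

// Approximation of the v1-disable check read from the environment.
bool DisableSparseAttentionV1() {
  const char* value = std::getenv("ORT_DISABLE_SPARSE_ATTENTION_V1");
  return value != nullptr && std::strcmp(value, "1") == 0;
}
```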
tianleiwu authored May 18, 2024
1 parent d7f7c3b commit 2e7de54
Showing 2 changed files with 21 additions and 27 deletions.
31 changes: 13 additions & 18 deletions onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc
```diff
@@ -55,8 +55,6 @@ SparseAttention<T>::SparseAttention(const OpKernelInfo& info)
 
   scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
 
-  kernel_loaded_ = false;
-
   disable_v1_kernel_ = ParseEnvironmentVariableWithDefault<bool>(sparse_attention::kDisableSparseAttentionV1, false);
 }
 
```

```diff
@@ -150,24 +148,21 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
     CUDA_RETURN_IF_ERROR(cudaEventRecord(isCopyDone, cuda_stream));
   }
 
-  if (!kernel_loaded_) {
-    if constexpr (std::is_same<T, MLFloat16>::value) {
-      // std::call_once is used in load_sparse_attention_fp16 so no need to use mutex here.
-      // After kernel is loaded, it will stay in memory until the process exits. We do not unload explicitly.
-      // TODO(tianleiwu): use TSharedCubinKernelFactory to manage kernel loading/unloading.
-      if (use_v2_kernel) {
-        sparse_attention_v2::load_sparse_attention_fp16(sm);
-      } else {
-        sparse_attention_v1::load_sparse_attention_fp16(sm);
-      }
+  if constexpr (std::is_same<T, MLFloat16>::value) {
+    // std::call_once is used in load_sparse_attention_fp16 so no need to use mutex here.
+    // After kernel is loaded, it will stay in memory until the process exits. We do not unload explicitly.
+    // TODO(tianleiwu): use TSharedCubinKernelFactory to manage kernel loading/unloading.
+    if (use_v2_kernel) {
+      sparse_attention_v2::load_sparse_attention_fp16(sm);
     } else {
-      if (use_v2_kernel) {
-        sparse_attention_v2::load_sparse_attention_bf16(sm);
-      } else {
-        sparse_attention_v1::load_sparse_attention_bf16(sm);
-      }
+      sparse_attention_v1::load_sparse_attention_fp16(sm);
+    }
+  } else {
+    if (use_v2_kernel) {
+      sparse_attention_v2::load_sparse_attention_bf16(sm);
+    } else {
+      sparse_attention_v1::load_sparse_attention_bf16(sm);
     }
-    kernel_loaded_ = true;
   }
 
   // Compute output shape and get output tensors.
```
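
With the flag removed, loading is requested on every `ComputeInternal` call. This is safe and cheap because, per the in-code comment, each loader guards its one-time work with `std::call_once`. A minimal sketch of that pattern (the loader bodies here are assumed for illustration):

```cpp
#include <iostream>
#include <mutex>

namespace sparse_attention_v1 {
std::once_flag loaded;
void load_sparse_attention_fp16(int sm) {
  std::call_once(loaded, [sm] { std::cout << "v1 fp16 kernel loaded for sm" << sm << "\n"; });
}
}  // namespace sparse_attention_v1

namespace sparse_attention_v2 {
std::once_flag loaded;
void load_sparse_attention_fp16(int sm) {
  std::call_once(loaded, [sm] { std::cout << "v2 fp16 kernel loaded for sm" << sm << "\n"; });
}
}  // namespace sparse_attention_v2

int main() {
  // v1 (prompt) and v2 (token generation) can now both be loaded in the same
  // process; repeated calls are no-ops thanks to std::call_once.
  sparse_attention_v1::load_sparse_attention_fp16(80);
  sparse_attention_v2::load_sparse_attention_fp16(80);
  sparse_attention_v1::load_sparse_attention_fp16(80);  // no-op
}
```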
17 changes: 8 additions & 9 deletions onnxruntime/contrib_ops/cuda/sparse/sparse_attention.h
```diff
@@ -18,15 +18,14 @@ class SparseAttention final : public CudaKernel {
   Status ComputeInternal(OpKernelContext* context) const override;
 
  protected:
-  int num_heads_;               // number of attention heads for q
-  int kv_num_heads_;            // number of attention heads for k and v
-  float scale_;                 // Scaling factor applied prior to softmax.
-  bool is_causal_;              // unidirectional attention or not
-  int sparse_block_size_;       // block size for sparsity
-  bool do_rotary_;              // Has rotary positional embedding
-  bool rotary_interleaved_;     // Interleaved rotary positional embedding
-  bool disable_v1_kernel_;      // Whether disable v1 kernel and use v2 kernel for prompt.
-  mutable bool kernel_loaded_;  // Kernel has been loaded
+  int num_heads_;            // number of attention heads for q
+  int kv_num_heads_;         // number of attention heads for k and v
+  float scale_;              // Scaling factor applied prior to softmax.
+  bool is_causal_;           // unidirectional attention or not
+  int sparse_block_size_;    // block size for sparsity
+  bool do_rotary_;           // Has rotary positional embedding
+  bool rotary_interleaved_;  // Interleaved rotary positional embedding
+  bool disable_v1_kernel_;   // Whether disable v1 kernel and use v2 kernel for prompt.
 };
 
 }  // namespace cuda
```
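
Note on the removed member: it was declared `mutable` because `ComputeInternal` is a `const` method. Beyond blocking the second kernel version, a plain non-atomic `mutable bool` written from a const method is a potential data race if the same op instance runs concurrently, which is one more reason to leave the "loaded" bookkeeping to the `std::call_once`-guarded loaders.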
