diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h index 0aaf5e5f1ba28..89e2351428d40 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h @@ -18,81 +18,89 @@ constexpr int D_DIM = 2; struct Qkv_params { using index_t = uint32_t; // The QKV matrices. - void* __restrict__ q_ptr; - void* __restrict__ k_ptr; - void* __restrict__ v_ptr; + void* __restrict__ q_ptr = nullptr; + void* __restrict__ k_ptr = nullptr; + void* __restrict__ v_ptr = nullptr; // The stride between rows of the Q, K and V matrices. - index_t q_batch_stride; - index_t k_batch_stride; - index_t v_batch_stride; - index_t q_row_stride; - index_t k_row_stride; - index_t v_row_stride; - index_t q_head_stride; - index_t k_head_stride; - index_t v_head_stride; + index_t q_batch_stride = 0; + index_t k_batch_stride = 0; + index_t v_batch_stride = 0; + index_t q_row_stride = 0; + index_t k_row_stride = 0; + index_t v_row_stride = 0; + index_t q_head_stride = 0; + index_t k_head_stride = 0; + index_t v_head_stride = 0; // The number of heads. - int h, h_k; + int h = 0; + int h_k = 0; // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be // different from nheads (query). - int h_h_k_ratio; // precompute h / h_k, + int h_h_k_ratio = 0; // precompute h / h_k, }; //////////////////////////////////////////////////////////////////////////////////////////////////// struct Flash_fwd_params : public Qkv_params { // The O matrix (output). - void* __restrict__ o_ptr; - void* __restrict__ oaccum_ptr; + void* __restrict__ o_ptr = nullptr; + void* __restrict__ oaccum_ptr = nullptr; // The stride between rows of O. - index_t o_batch_stride; - index_t o_row_stride; - index_t o_head_stride; + index_t o_batch_stride = 0; + index_t o_row_stride = 0; + index_t o_head_stride = 0; // The pointer to the P matrix. - void* __restrict__ p_ptr; + void* __restrict__ p_ptr = nullptr; // The pointer to the softmax sum. - void* __restrict__ softmax_lse_ptr; - void* __restrict__ softmax_lseaccum_ptr; + void* __restrict__ softmax_lse_ptr = nullptr; + void* __restrict__ softmax_lseaccum_ptr = nullptr; // The dimensions. - int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded; + int b = 0; + int seqlen_q = 0; + int seqlen_k = 0; + int seqlen_knew = 0; + int d = 0; + int seqlen_q_rounded = 0; + int seqlen_k_rounded = 0; + int d_rounded = 0; // The scaling factors for the kernel. - float scale_softmax; - float scale_softmax_log2; + float scale_softmax = 0.0; + float scale_softmax_log2 = 0.0; // array of length b+1 holding starting offset of each sequence. - int* __restrict__ cu_seqlens_q; - int* __restrict__ cu_seqlens_k; + int* __restrict__ cu_seqlens_q = nullptr; + int* __restrict__ cu_seqlens_k = nullptr; - int* __restrict__ blockmask; + int* __restrict__ blockmask = nullptr; // The K_new and V_new matrices. - void* __restrict__ knew_ptr; - void* __restrict__ vnew_ptr; + void* __restrict__ knew_ptr = nullptr; + void* __restrict__ vnew_ptr = nullptr; // The stride between rows of the Q, K and V matrices. - index_t knew_batch_stride; - index_t vnew_batch_stride; - index_t knew_row_stride; - index_t vnew_row_stride; - index_t knew_head_stride; - index_t vnew_head_stride; + index_t knew_batch_stride = 0; + index_t vnew_batch_stride = 0; + index_t knew_row_stride = 0; + index_t vnew_row_stride = 0; + index_t knew_head_stride = 0; + index_t vnew_head_stride = 0; bool is_bf16 = false; - bool is_causal; + bool is_causal = false; // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. - bool is_seqlens_k_cumulative; - int num_splits; // For split-KV version + bool is_seqlens_k_cumulative = true; + int num_splits = 0; // For split-KV version - const cudaDeviceProp* dprops; + const cudaDeviceProp* dprops = nullptr; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index 805a73be96778..ff7a22d253a5b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -215,7 +215,6 @@ Status mha_fwd(const cudaDeviceProp& dprops, const int seqlen_k_rounded = round_multiple(seqlen_k, 128); Flash_fwd_params params; - params.dprops = &dprops; set_params_fprop(params, batch_size, seqlen_q, seqlen_k, @@ -230,7 +229,7 @@ Status mha_fwd(const cudaDeviceProp& dprops, softmax_scale, is_causal, kv_bsnh); - + params.dprops = &dprops; params.knew_ptr = nullptr; params.vnew_ptr = nullptr; params.knew_batch_stride = 0; @@ -276,7 +275,6 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); Flash_fwd_params params; - params.dprops = &dprops; set_params_fprop(params, batch_size, max_seqlen_q, max_seqlen_k, @@ -290,6 +288,12 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, softmax_lse, softmax_scale, is_causal); + params.dprops = &dprops; + params.num_splits = 0; + params.softmax_lseaccum_ptr = nullptr; + params.oaccum_ptr = nullptr; + params.knew_ptr = nullptr; + params.vnew_ptr = nullptr; run_mha_fwd(params, stream); return Status::OK(); } @@ -336,7 +340,6 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, const int seqlen_k_rounded = round_multiple(seqlen_k, 128); Flash_fwd_params params; - params.dprops = &dprops; set_params_fprop(params, batch_size, seqlen_q, seqlen_k, @@ -351,6 +354,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, softmax_scale, is_causal, past_bsnh); + params.dprops = &dprops; if (k != nullptr && v != nullptr) { params.seqlen_knew = seqlen_k_new;