flash_attn_varlen support tree attention #1188

Open · wants to merge 1 commit into main
58 changes: 56 additions & 2 deletions csrc/flash_attn/flash_api.cpp
@@ -551,7 +551,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
int window_size_right,
const float softcap,
const bool return_softmax,
c10::optional<at::Generator> gen_) {
c10::optional<at::Generator> gen_,
c10::optional<const at::Tensor> &tree_end_position_id_k_,
c10::optional<const at::Tensor> &tree_start_position_id_q_) {

auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
@@ -591,6 +593,22 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
CHECK_CONTIGUOUS(cu_seqlens_q);
CHECK_CONTIGUOUS(cu_seqlens_k);

TORCH_CHECK(tree_start_position_id_q_.has_value() == tree_end_position_id_k_.has_value(), "tree_start_position_id and tree_end_position_id must be passed together");
if (tree_end_position_id_k_.has_value()) {
const at::Tensor tree_end_position_id_k = tree_end_position_id_k_.value(), tree_start_position_id_q = tree_start_position_id_q_.value();
TORCH_CHECK(is_causal, "In tree attention, is_causal must be True");
TORCH_CHECK(window_size_left == -1 && window_size_right == -1, "In tree attention, is_local must be False");
TORCH_CHECK(!alibi_slopes_.has_value(), "tree attention does not support alibi");
TORCH_CHECK(tree_start_position_id_q.dtype() == torch::kInt32, "tree_start_position_id_q must have dtype int32");
TORCH_CHECK(tree_end_position_id_k.dtype() == torch::kInt32, "tree_end_position_id_k must have dtype int32");
TORCH_CHECK(tree_start_position_id_q.sizes().size() == 1, "tree_start_position_id_q must be 1D tensor");
TORCH_CHECK(tree_end_position_id_k.sizes().size() == 1, "tree_end_position_id_k must be 1D tensor");
TORCH_CHECK(tree_start_position_id_q.sizes()[0] == q.sizes()[0], "tree_start_position_id_q and q must have the same length");
TORCH_CHECK(tree_end_position_id_k.sizes()[0] == k.sizes()[0], "tree_end_position_id_k and k must have the same length");
CHECK_DEVICE(tree_start_position_id_q);
CHECK_DEVICE(tree_end_position_id_k);
}

const auto sizes = q.sizes();

const int batch_size = cu_seqlens_q.numel() - 1;
@@ -770,6 +788,16 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s

set_params_alibi(params, alibi_slopes_, batch_size, num_heads);

if (tree_end_position_id_k_.has_value()) {
params.is_tree_attention = true;
params.tree_end_position_id_k = static_cast<int *>(tree_end_position_id_k_.value().data_ptr());
params.tree_start_position_id_q = static_cast<int *>(tree_start_position_id_q_.value().data_ptr());
} else {
params.is_tree_attention = false;
params.tree_end_position_id_k = nullptr;
params.tree_start_position_id_q = nullptr;
}

if (max_seqlen_k > 0) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream, paged_KV);
@@ -1062,7 +1090,9 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
const float softcap,
const bool deterministic,
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state) {
c10::optional<at::Tensor> &rng_state,
c10::optional<const at::Tensor> &tree_end_position_id_k_,
c10::optional<const at::Tensor> &tree_start_position_id_q_) {

#ifdef FLASHATTENTION_DISABLE_BACKWARD
TORCH_CHECK(false, "This flash attention build does not support backward.");
@@ -1105,6 +1135,20 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
CHECK_CONTIGUOUS(cu_seqlens_q);
CHECK_CONTIGUOUS(cu_seqlens_k);

TORCH_CHECK(tree_start_position_id_q_.has_value() == tree_end_position_id_k_.has_value(), "tree_start_position_id and tree_end_position_id must be passed together");
if (tree_end_position_id_k_.has_value()) {
const at::Tensor tree_end_position_id_k = tree_end_position_id_k_.value(), tree_start_position_id_q = tree_start_position_id_q_.value();
TORCH_CHECK(is_causal, "In tree attention, is_causal must be True");
TORCH_CHECK(tree_start_position_id_q.dtype() == torch::kInt32, "tree_start_position_id_q must have dtype int32");
TORCH_CHECK(tree_end_position_id_k.dtype() == torch::kInt32, "tree_end_position_id_k must have dtype int32");
TORCH_CHECK(tree_start_position_id_q.sizes().size() == 1, "tree_start_position_id_q must be 1D tensor");
TORCH_CHECK(tree_end_position_id_k.sizes().size() == 1, "tree_end_position_id_k must be 1D tensor");
TORCH_CHECK(tree_start_position_id_q.sizes()[0] == q.sizes()[0], "tree_start_position_id_q and q must have the same length");
TORCH_CHECK(tree_end_position_id_k.sizes()[0] == k.sizes()[0], "tree_end_position_id_k and k must have the same length");
CHECK_DEVICE(tree_start_position_id_q);
CHECK_DEVICE(tree_end_position_id_k);
}

const auto sizes = q.sizes();

const int total_q = sizes[0];
@@ -1270,6 +1314,16 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size

set_params_alibi(params, alibi_slopes_, batch_size, num_heads);

if (tree_end_position_id_k_.has_value()) {
params.is_tree_attention = true;
params.tree_end_position_id_k = static_cast<int *>(tree_end_position_id_k_.value().data_ptr());
params.tree_start_position_id_q = static_cast<int *>(tree_start_position_id_q_.value().data_ptr());
} else {
params.is_tree_attention = false;
params.tree_end_position_id_k = nullptr;
params.tree_start_position_id_q = nullptr;
}

if (max_seqlen_q > 0) {
launch(params, stream);
} else {
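
Note on the new arguments: the checks added to mha_varlen_fwd and mha_varlen_bwd above boil down to a simple caller-side contract. Both tensors must be 1-D, int32, on the same device as q/k, and carry one id per packed token (lengths total_q and total_k), and they must be passed together with is_causal == true, window_size_left == window_size_right == -1, and no alibi slopes. The following sketch is not taken from this PR; it only builds placeholder tensors that satisfy those checks, using the libtorch C++ API, and does not encode the PR's actual position-id convention.

#include <torch/torch.h>
#include <utility>

// Builds hypothetical tree-position-id tensors that satisfy the new checks:
// 1-D, int32, same device as q/k, one entry per packed token.
// The values are placeholders, not a real token tree.
std::pair<at::Tensor, at::Tensor> make_dummy_tree_ids(const at::Tensor &q, const at::Tensor &k) {
    auto i32 = torch::dtype(torch::kInt32).device(q.device());
    at::Tensor start_q = torch::arange(q.size(0), i32);                 // length total_q
    at::Tensor end_k   = torch::full({k.size(0)}, q.size(0) - 1, i32);  // length total_k
    // Pass these as the new trailing optionals tree_start_position_id_q_ and
    // tree_end_position_id_k_ of mha_varlen_fwd / mha_varlen_bwd.
    return {start_q, end_k};
}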
5 changes: 5 additions & 0 deletions csrc/flash_attn/src/flash.h
@@ -78,6 +78,10 @@ struct Flash_fwd_params : public Qkv_params {
int * __restrict__ cu_seqlens_k;
int * __restrict__ leftpad_k;

// tree attention
int * __restrict__ tree_end_position_id_k;
int * __restrict__ tree_start_position_id_q;
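// Layout note: one int32 id per packed token; the kernels index these with the
// per-sequence cu_seqlens offsets (sum_s_k / sum_s_q), mirroring the varlen q/k layout.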

// If provided, the actual length of each k sequence.
int * __restrict__ seqused_k;

@@ -129,6 +133,7 @@ struct Flash_fwd_params : public Qkv_params {

bool is_bf16;
bool is_causal;
bool is_tree_attention;

// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
24 changes: 16 additions & 8 deletions csrc/flash_attn/src/flash_bwd_kernel.h
@@ -76,7 +76,7 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Is_tree_attention, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {

using Element = typename Kernel_traits::Element;
@@ -503,18 +503,26 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
flash::apply_mask(scores, binfo.actual_seqlen_k,
n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16);
}
} else if (Is_tree_attention) {
flash::apply_mask_causal_tree_attention(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
binfo.actual_seqlen_q,
// binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
AtomLayoutMS * 16,
params.tree_end_position_id_k + binfo.sum_s_k,
params.tree_start_position_id_q + binfo.sum_s_q);
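// Note: the per-sequence slices are selected by offsetting with binfo.sum_s_k /
// binfo.sum_s_q (the cu_seqlens prefix sums), so the id arrays are packed the
// same way as q and k. The masking rule itself lives in
// flash::apply_mask_causal_tree_attention, which is defined outside the hunks
// shown here (presumably alongside apply_mask_causal).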
} else if (Is_causal) {
// Putting this causal masking right after acc_s is *much* slower for some reason.
// TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short
// (e.g., 256 and 2), the 2nd block of seqlen_q (from 128 to 255), we're not doing causal masking.
// But we still want to mask out elements beyond actual_seqlen_k.
if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k
|| (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
binfo.actual_seqlen_q,
// binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
AtomLayoutMS * 16);
}
} else if (Is_local) {
if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right
@@ -820,7 +828,7 @@ inline __device__ void compute_dq_dk_dv(const Params &params) {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Is_tree_attention, typename Params>
inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {

// The block index for the batch.
@@ -830,7 +838,7 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {

// If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) {
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Is_tree_attention, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
}
}

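
The body of flash::apply_mask_causal_tree_attention is not included in the hunks above. As a point of reference only: a common tree-attention convention assigns every node of the token tree a DFS interval, so that a query may attend to a key exactly when the key is one of its ancestors. Under that assumption, which may not match this PR's actual rule, the element-wise visibility test would look roughly like the sketch below, where the pointers are the per-sequence slices passed in above.

// Hypothetical element-wise test, assuming a DFS-interval convention in which
// a query attends only to keys that are its ancestors in the token tree; the
// usual causal row/column bookkeeping of apply_mask_causal is unchanged and
// omitted here.
__forceinline__ __device__ bool tree_visible(int row_idx, int col_idx,
                                             const int *start_q,   // tree_start_position_id_q + sum_s_q
                                             const int *end_k) {   // tree_end_position_id_k + sum_s_k
    // Key col_idx's subtree interval must contain query row_idx's start id.
    return start_q[row_idx] <= end_k[col_idx];
}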
28 changes: 15 additions & 13 deletions csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -35,10 +35,10 @@ DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bo
#endif
}

DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Is_tree_attention) {
#if defined(ARCH_SUPPORTS_FLASH)
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Is_tree_attention>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
@@ -95,17 +95,19 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
if (smem_size_dq_dk_dv >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
}
kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
TREE_ATTENTION_SWITCH(params.is_tree_attention, Is_tree_attention, [&] {
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
// If Is_local, set Is_causal to false
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Is_tree_attention>;
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
if (smem_size_dq_dk_dv >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
}
kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
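
TREE_ATTENTION_SWITCH is used above but not defined in the hunks shown; it presumably sits next to the other dispatch macros in csrc/flash_attn/src/static_switch.h. The following is a sketch of what such a switch usually looks like, modeled on the existing BOOL_SWITCH pattern; the PR's actual definition may differ, for example it could be compiled out behind a build flag to limit the number of template instantiations.

// Hypothetical definition, mirroring the BOOL_SWITCH pattern used by the other
// *_SWITCH macros: lifts a runtime bool into a constexpr template parameter.
#define TREE_ATTENTION_SWITCH(COND, CONST_NAME, ...)   \
  [&] {                                                \
    if (COND) {                                        \
      constexpr static bool CONST_NAME = true;         \
      return __VA_ARGS__();                            \
    } else {                                           \
      constexpr static bool CONST_NAME = false;        \
      return __VA_ARGS__();                            \
    }                                                  \
  }()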