From 11a355cc1a68c950bf4fa23258b1d3e36a15eb49 Mon Sep 17 00:00:00 2001
From: ispobock
Date: Sat, 17 Aug 2024 23:08:33 +0800
Subject: [PATCH 1/2] add grouped decode for GQA

---
 python/sglang/srt/layers/decode_attention.py | 368 ++++++++++++++++---
 1 file changed, 318 insertions(+), 50 deletions(-)

diff --git a/python/sglang/srt/layers/decode_attention.py b/python/sglang/srt/layers/decode_attention.py
index c868299ef4..536af8e52d 100644
--- a/python/sglang/srt/layers/decode_attention.py
+++ b/python/sglang/srt/layers/decode_attention.py
@@ -39,7 +39,6 @@ def tanh(x):
     # Tanh is just a scaled sigmoid
     return 2 * tl.sigmoid(2 * x) - 1
 
-
 @triton.jit
 def _fwd_kernel_stage1(
     Q,
@@ -58,7 +57,6 @@ def _fwd_kernel_stage1(
     att_stride_h,
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
-    BLOCK_DPE: tl.constexpr,
     BLOCK_N: tl.constexpr,
     logit_cap: tl.constexpr,
 ):
@@ -78,10 +76,6 @@ def _fwd_kernel_stage1(
 
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
 
-    if BLOCK_DPE > 0:
-        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
-        off_qpe = cur_batch * stride_qbs + cur_head * stride_qh + offs_dpe
-
     offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
 
     block_stard_index = start_n * BLOCK_N
@@ -106,19 +100,6 @@ def _fwd_kernel_stage1(
             other=0.0,
         ).to(REDUCE_TRITON_TYPE)
         att_value = tl.sum(q[None, :] * k, 1)
-        if BLOCK_DPE > 0:
-            qpe = tl.load(Q + off_qpe + start_mark).to(REDUCE_TRITON_TYPE)
-            offs_buf_kpe = (
-                k_loc[:, None] * stride_buf_kbs
-                + cur_kv_head * stride_buf_kh
-                + offs_dpe[None, :]
-            )
-            kpe = tl.load(
-                K_Buffer + offs_buf_kpe,
-                mask=offs_n_new[:, None] < cur_batch_end_index,
-                other=0.0,
-            ).to(REDUCE_TRITON_TYPE)
-            att_value += tl.sum(qpe[None, :] * kpe, 1)
         att_value *= sm_scale
 
         if logit_cap > 0:
@@ -214,14 +195,7 @@ def _decode_att_m_fwd(
     # shape constraints
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
     assert Lq == Lk
-    assert Lk in {16, 32, 64, 128, 256, 576}
-
-    if Lk == 576:
-        BLOCK_DMODEL = 512
-        BLOCK_DPE = 64
-    else:
-        BLOCK_DMODEL = Lk
-        BLOCK_DPE = 0
+    assert Lk in {16, 32, 64, 128, 256}
 
     batch, head_num = B_req_idx.shape[0], q.shape[1]
 
@@ -249,8 +223,7 @@ def _decode_att_m_fwd(
         k_buffer.stride(1),
         att_out.stride(0),
         kv_group_num=kv_group_num,
-        BLOCK_DMODEL=BLOCK_DMODEL,
-        BLOCK_DPE=BLOCK_DPE,
+        BLOCK_DMODEL=Lk,
         BLOCK_N=BLOCK,
         logit_cap=logit_cap,
         num_warps=num_warps,
@@ -295,6 +268,274 @@ def _decode_softmax_reducev_fwd(
         num_stages=3,
     )
 
+@triton.jit
+def _fwd_grouped_kernel_stage1(
+    Q,
+    K_Buffer,
+    sm_scale,
+    Req_to_tokens,
+    B_req_idx,
+    B_Start_Loc,
+    B_Seqlen,
+    Att_Out,
+    stride_req_to_tokens_b,
+    stride_qbs,
+    stride_qh,
+    stride_buf_kbs,
+    stride_buf_kh,
+    att_stride_h,
+    kv_group_num: tl.constexpr,
+    q_head_num: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_H: tl.constexpr,
+    logit_cap: tl.constexpr,
+):
+    cur_batch = tl.program_id(0)
+    cur_kv_head = tl.program_id(1)
+    start_n = tl.program_id(2)
+
+    cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
+    mask_h = cur_head < (cur_kv_head+1) * kv_group_num
+    mask_h = mask_h & (cur_head < q_head_num)
+
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+
+    cur_batch_start_index = 0
+    cur_batch_end_index = cur_batch_seq_len
+
+    offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
+
+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        off_qpe = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
+
+    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
+
+    block_stard_index = start_n * BLOCK_N
+    block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
+
+    for start_mark in range(0, block_mask, 1):
+        q = tl.load(Q + offs_q + start_mark, mask=mask_h[:, None]).to(REDUCE_TRITON_TYPE)
+        offs_n_new = cur_batch_start_index + offs_n
+        k_loc = tl.load(
+            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
+            mask=offs_n_new < cur_batch_end_index,
+            other=0,
+        )
+        offs_buf_k = (
+            k_loc[None, :] * stride_buf_kbs
+            + cur_kv_head * stride_buf_kh
+            + offs_d[:, None]
+        )
+        k = tl.load(
+            K_Buffer + offs_buf_k,
+            mask=offs_n_new[None, :] < cur_batch_end_index,
+            other=0.0,
+        ).to(REDUCE_TRITON_TYPE)
+        qk = tl.dot(q, k)
+        if BLOCK_DPE > 0:
+            qpe = tl.load(Q + off_qpe + start_mark, mask=mask_h[:, None]).to(REDUCE_TRITON_TYPE)
+            offs_buf_kpe = (
+                k_loc[None, :] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_buf_kpe,
+                mask=offs_n_new[None, :] < cur_batch_end_index,
+                other=0.0,
+            ).to(REDUCE_TRITON_TYPE)
+            qk += tl.dot(qpe, kpe)
+        qk *= sm_scale
+
+        if logit_cap > 0:
+            qk = logit_cap * tanh(qk / logit_cap)
+
+        offs_o = cur_head[:, None] * att_stride_h + (cur_batch_in_all_start_index + offs_n[None, :])
+
+        tl.store(Att_Out + offs_o, qk, mask=mask_h[:, None] & (offs_n_new[None, :] < cur_batch_end_index))
+
+
+@triton.jit
+def _fwd_grouped_kernel_stage2(
+    Logics,
+    V_Buffer,
+    Out,
+    Req_to_tokens,
+    B_req_idx,
+    B_Start_Loc,
+    B_Seqlen,
+    stride_logic_h,
+    stride_buf_vbs,
+    stride_buf_vh,
+    stride_obs,
+    stride_oh,
+    stride_req_to_token_b,
+    kv_group_num: tl.constexpr,
+    q_head_num: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_H: tl.constexpr,
+):
+    cur_batch = tl.program_id(0)
+    cur_kv_head = tl.program_id(1)
+
+    cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
+    mask_h = cur_head < (cur_kv_head+1) * kv_group_num
+    mask_h = mask_h & (cur_head < q_head_num)
+
+    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+    cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)
+    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+
+    offs_n = tl.arange(0, BLOCK_N)
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+
+    offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :]
+    v_ptrs = V_Buffer + offs_buf_v
+
+    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
+    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_H, BLOCK_DMODEL], dtype=tl.float32)
+
+    for start_n in range(0, cur_batch_seq_len, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        v_index = tl.load(
+            Req_to_tokens
+            + cur_batch_req_idx * stride_req_to_token_b
+            + (start_n + offs_n),
+            mask=(start_n + offs_n) < cur_batch_seq_len,
+            other=0,
+        )
+
+        offs_qk = cur_head[:, None] * stride_logic_h + (cur_batch_start_loc + start_n + offs_n[None, :])
+
+        qk = tl.load(
+            Logics + offs_qk,
+            mask=mask_h[:, None] & (start_n + offs_n[None, :] < cur_batch_seq_len),
+            other=float("-inf"),
+        )
+
+        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+        old_scale = tl.exp(e_max - n_e_max)
+        p = tl.exp(qk - n_e_max[:, None])
+        e_sum = e_sum * old_scale + tl.sum(p, 1)
+        v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs)
+        p = p.to(v.dtype)
+        acc = acc * old_scale[:, None] + tl.dot(p, v)
+        e_max = n_e_max
+
+    acc = acc / e_sum[:, None]
+    off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :]
+    out_ptrs = Out + off_o
+    tl.store(out_ptrs, acc, mask=mask_h[:, None])
+
+
+def _decode_grouped_att_m_fwd(
+    q,
+    k_buffer,
+    att_out,
+    Req_to_tokens,
+    B_req_idx,
+    B_Start_Loc,
+    B_Seqlen,
+    max_len_in_batch,
+    sm_scale,
+    logit_cap,
+):
+    BLOCK = 32
+    # shape constraints
+    Lq, Lk = q.shape[-1], k_buffer.shape[-1]
+    assert Lq == Lk
+    assert Lk in {16, 32, 64, 128, 256, 576}
+
+    if Lk == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = Lk
+        BLOCK_DPE = 0
+
+    batch, head_num = B_req_idx.shape[0], q.shape[1]
+    kv_group_num = q.shape[1] // k_buffer.shape[1]
+
+    BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
+    grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), triton.cdiv(max_len_in_batch, BLOCK))
+
+    num_warps = 4
+
+    _fwd_grouped_kernel_stage1[grid](
+        q,
+        k_buffer,
+        sm_scale,
+        Req_to_tokens,
+        B_req_idx,
+        B_Start_Loc,
+        B_Seqlen,
+        att_out,
+        Req_to_tokens.stride(0),
+        q.stride(0),
+        q.stride(1),
+        k_buffer.stride(0),
+        k_buffer.stride(1),
+        att_out.stride(0),
+        kv_group_num=kv_group_num,
+        q_head_num=head_num,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
+        BLOCK_N=BLOCK,
+        BLOCK_H=BLOCK_H,
+        logit_cap=logit_cap,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+
+
+def _decode_grouped_softmax_reducev_fwd(
+    logics,
+    v_buffer,
+    o,
+    req_to_tokens,
+    b_req_idx,
+    b_start_loc,
+    b_seq_len,
+):
+    BLOCK = 128
+    batch, head_num = b_seq_len.shape[0], logics.shape[0]
+    kv_group_num = logics.shape[0] // v_buffer.shape[1]
+    BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
+    grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
+
+    num_warps = 8
+
+    _fwd_grouped_kernel_stage2[grid](
+        logics,
+        v_buffer,
+        o,
+        req_to_tokens,
+        b_req_idx,
+        b_start_loc,
+        b_seq_len,
+        logics.stride(0),
+        v_buffer.stride(0),
+        v_buffer.stride(1),
+        o.stride(0),
+        o.stride(1),
+        req_to_tokens.stride(0),
+        kv_group_num=kv_group_num,
+        q_head_num=head_num,
+        BLOCK_DMODEL=v_buffer.shape[-1],
+        BLOCK_N=BLOCK,
+        BLOCK_H=BLOCK_H,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+
 
 def decode_attention_fwd(
     q,
@@ -315,25 +556,52 @@ def decode_attention_fwd(
     att_m = torch.empty(
         (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
     )
+    
+    kv_group_num = q.shape[1] // v_buffer.shape[1]
 
-    _decode_att_m_fwd(
-        q,
-        k_buffer,
-        att_m,
-        req_to_token,
-        b_req_idx,
-        b_start_loc,
-        b_seq_len,
-        max_len_in_batch,
-        sm_scale,
-        logit_cap,
-    )
-    _decode_softmax_reducev_fwd(
-        att_m,
-        v_buffer,
-        o,
-        req_to_token,
-        b_req_idx,
-        b_start_loc,
-        b_seq_len,
-    )
+    if kv_group_num == 1:
+        # MHA
+        _decode_att_m_fwd(
+            q,
+            k_buffer,
+            att_m,
+            req_to_token,
+            b_req_idx,
+            b_start_loc,
+            b_seq_len,
+            max_len_in_batch,
+            sm_scale,
+            logit_cap,
+        )
+        _decode_softmax_reducev_fwd(
+            att_m,
+            v_buffer,
+            o,
+            req_to_token,
+            b_req_idx,
+            b_start_loc,
+            b_seq_len,
+        )
+    else:
+        # GQA/MQA/MLA
+        _decode_grouped_att_m_fwd(
+            q,
+            k_buffer,
+            att_m,
+            req_to_token,
+            b_req_idx,
+            b_start_loc,
+            b_seq_len,
+            max_len_in_batch,
+            sm_scale,
+            logit_cap,
+        )
+        _decode_grouped_softmax_reducev_fwd(
+            att_m,
+            v_buffer,
+            o,
+            req_to_token,
+            b_req_idx,
+            b_start_loc,
+            b_seq_len,
+        )
\ No newline at end of file
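
Note on the approach in PATCH 1/2: the existing _fwd_kernel_stage1 launches one program per (batch, query head), so under GQA every query head in a group reloads the same K tile and reduces it with tl.sum. The grouped kernels launch one program per (batch, KV head) instead, stacking the whole group of query heads into the rows of a tl.dot: each K tile is loaded once per group, and kv_group_num separate dot products become one small matmul. Because tl.dot needs each operand dimension to be at least 16, BLOCK_H is padded to max(16, next_power_of_2(kv_group_num)) and the surplus rows are disabled by mask_h. A rough NumPy sketch of what one stage-1 program computes per (batch, kv_head, token block); the function and shapes here are illustrative, not the kernel's API:

    import numpy as np

    def grouped_stage1_block(q_group, k_block, sm_scale, logit_cap=0.0):
        # q_group: (BLOCK_H, BLOCK_DMODEL) rows for the query heads sharing one KV head
        # k_block: (BLOCK_DMODEL, BLOCK_N) keys of one token block
        qk = q_group @ k_block  # the tl.dot that replaces per-head tl.sum reductions
        qk *= sm_scale
        if logit_cap > 0:
            qk = logit_cap * np.tanh(qk / logit_cap)  # same soft cap as the kernel
        return qk  # (BLOCK_H, BLOCK_N) logits, written to Att_Out under mask_h

Padding rows (beyond kv_group_num or q_head_num) never reach memory: the final tl.store is masked with mask_h[:, None] & (offs_n_new[None, :] < cur_batch_end_index).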
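
Stage 2 (_fwd_grouped_kernel_stage2) reduces those logits with an online softmax, so a full attention row is never materialized: for each BLOCK_N chunk it keeps a running max e_max, a running normalizer e_sum, and a weighted accumulator acc, rescaling the old state by exp(e_max - n_e_max) whenever the max grows. A minimal single-head NumPy sketch of the same recurrence (a hypothetical helper, not code from this patch):

    import numpy as np

    def online_softmax_v(logits, v, block_n=128):
        # logits: (seq_len,) stage-1 scores; v: (seq_len, d) values
        # returns softmax(logits) @ v, computed chunk by chunk as stage 2 does
        e_max, e_sum = float("-inf"), 0.0
        acc = np.zeros(v.shape[1])
        for s in range(0, len(logits), block_n):
            qk = logits[s : s + block_n]
            n_e_max = max(qk.max(), e_max)       # new running max
            old_scale = np.exp(e_max - n_e_max)  # rescales previous partial sums
            p = np.exp(qk - n_e_max)
            e_sum = e_sum * old_scale + p.sum()
            acc = acc * old_scale + p @ v[s : s + block_n]
            e_max = n_e_max
        return acc / e_sum

    # sanity check against the direct formula:
    # x, v = np.random.randn(300), np.random.randn(300, 64)
    # w = np.exp(x - x.max())
    # assert np.allclose(online_softmax_v(x, v), (w / w.sum()) @ v)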
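
On the launch side, both grouped entry points build the grid as (batch, cdiv(head_num, min(BLOCK_H, kv_group_num)), ...), so the second axis enumerates one program per KV-head group; min(BLOCK_H, kv_group_num) keeps that count correct when the real group is smaller than the padded tile. The Lk == 576 branch, moved here from _decode_att_m_fwd, splits the head dimension into BLOCK_DMODEL = 512 plus BLOCK_DPE = 64 so an MLA-style key (e.g. DeepSeek-V2, whose 576-dim key packs a 512-dim latent part and a 64-dim rotary part) gets a second tl.dot on its positional tail. Worked numbers under assumed head counts (not taken from the patch):

    # e.g. a Llama-3-70B-style config: 64 query heads over 8 KV heads
    head_num, kv_head_num = 64, 8
    kv_group_num = head_num // kv_head_num                    # 8 query heads per KV head
    BLOCK_H = max(16, 2 ** (kv_group_num - 1).bit_length())   # next_power_of_2 -> 16
    grid_dim1 = -(-head_num // min(BLOCK_H, kv_group_num))    # cdiv -> 8 programs

decode_attention_fwd then routes kv_group_num == 1 (plain MHA) to the original per-head kernels and everything grouped (GQA/MQA/MLA) to the new pair; PATCH 2/2 below only reflows this code to the repo's line-length style, with no behavior change.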
From 448879ae4663038ae8a35fd9377afe740bbda5f3 Mon Sep 17 00:00:00 2001
From: ispobock
Date: Sun, 18 Aug 2024 00:01:35 +0800
Subject: [PATCH 2/2] fix format

---
 python/sglang/srt/layers/decode_attention.py | 42 +++++++++++++++-----
 1 file changed, 31 insertions(+), 11 deletions(-)

diff --git a/python/sglang/srt/layers/decode_attention.py b/python/sglang/srt/layers/decode_attention.py
index 536af8e52d..eef3c00096 100644
--- a/python/sglang/srt/layers/decode_attention.py
+++ b/python/sglang/srt/layers/decode_attention.py
@@ -39,6 +39,7 @@ def tanh(x):
     # Tanh is just a scaled sigmoid
     return 2 * tl.sigmoid(2 * x) - 1
 
+
 @triton.jit
 def _fwd_kernel_stage1(
     Q,
@@ -268,6 +269,7 @@ def _decode_softmax_reducev_fwd(
         num_stages=3,
     )
 
+
 @triton.jit
 def _fwd_grouped_kernel_stage1(
     Q,
@@ -297,7 +299,7 @@ def _fwd_grouped_kernel_stage1(
     start_n = tl.program_id(2)
 
     cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
-    mask_h = cur_head < (cur_kv_head+1) * kv_group_num
+    mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
     mask_h = mask_h & (cur_head < q_head_num)
 
     offs_d = tl.arange(0, BLOCK_DMODEL)
@@ -312,7 +314,9 @@ def _fwd_grouped_kernel_stage1(
 
     if BLOCK_DPE > 0:
         offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
-        off_qpe = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
+        off_qpe = (
+            cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
+        )
 
     offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
 
@@ -320,7 +324,9 @@ def _fwd_grouped_kernel_stage1(
     block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
 
     for start_mark in range(0, block_mask, 1):
-        q = tl.load(Q + offs_q + start_mark, mask=mask_h[:, None]).to(REDUCE_TRITON_TYPE)
+        q = tl.load(Q + offs_q + start_mark, mask=mask_h[:, None]).to(
+            REDUCE_TRITON_TYPE
+        )
         offs_n_new = cur_batch_start_index + offs_n
         k_loc = tl.load(
             Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
@@ -339,7 +345,9 @@ def _fwd_grouped_kernel_stage1(
         ).to(REDUCE_TRITON_TYPE)
         qk = tl.dot(q, k)
         if BLOCK_DPE > 0:
-            qpe = tl.load(Q + off_qpe + start_mark, mask=mask_h[:, None]).to(REDUCE_TRITON_TYPE)
+            qpe = tl.load(Q + off_qpe + start_mark, mask=mask_h[:, None]).to(
+                REDUCE_TRITON_TYPE
+            )
             offs_buf_kpe = (
                 k_loc[None, :] * stride_buf_kbs
                 + cur_kv_head * stride_buf_kh
@@ -356,9 +364,15 @@ def _fwd_grouped_kernel_stage1(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)
 
-        offs_o = cur_head[:, None] * att_stride_h + (cur_batch_in_all_start_index + offs_n[None, :])
+        offs_o = cur_head[:, None] * att_stride_h + (
+            cur_batch_in_all_start_index + offs_n[None, :]
+        )
 
-        tl.store(Att_Out + offs_o, qk, mask=mask_h[:, None] & (offs_n_new[None, :] < cur_batch_end_index))
+        tl.store(
+            Att_Out + offs_o,
+            qk,
+            mask=mask_h[:, None] & (offs_n_new[None, :] < cur_batch_end_index),
+        )
 
 
 @triton.jit
@@ -386,7 +400,7 @@ def _fwd_grouped_kernel_stage2(
     cur_kv_head = tl.program_id(1)
 
     cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
-    mask_h = cur_head < (cur_kv_head+1) * kv_group_num
+    mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
     mask_h = mask_h & (cur_head < q_head_num)
 
     cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
@@ -413,7 +427,9 @@ def _fwd_grouped_kernel_stage2(
             other=0,
         )
 
-        offs_qk = cur_head[:, None] * stride_logic_h + (cur_batch_start_loc + start_n + offs_n[None, :])
+        offs_qk = cur_head[:, None] * stride_logic_h + (
+            cur_batch_start_loc + start_n + offs_n[None, :]
+        )
 
         qk = tl.load(
             Logics + offs_qk,
@@ -465,7 +481,11 @@ def _decode_grouped_att_m_fwd(
     kv_group_num = q.shape[1] // k_buffer.shape[1]
 
     BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
-    grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), triton.cdiv(max_len_in_batch, BLOCK))
+    grid = (
+        batch,
+        triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
+        triton.cdiv(max_len_in_batch, BLOCK),
+    )
 
     num_warps = 4
 
@@ -556,7 +576,7 @@ def decode_attention_fwd(
     att_m = torch.empty(
         (q.shape[-2], total_num_tokens), dtype=REDUCE_TORCH_TYPE, device="cuda"
     )
-    
+
     kv_group_num = q.shape[1] // v_buffer.shape[1]
 
     if kv_group_num == 1:
@@ -604,4 +624,4 @@ def decode_attention_fwd(
         b_req_idx,
         b_start_loc,
         b_seq_len,
-    )
\ No newline at end of file
+    )
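
For review, the grouped path can be sanity-checked against a naive reference. A PyTorch sketch under my reading of the layouts (q: [batch, num_q_heads, head_dim]; k_buffer and v_buffer: [total_tokens, num_kv_heads, head_dim], addressed through req_to_token; logit_cap omitted for brevity; these shapes are assumptions, not a documented API):

    import torch

    def naive_decode_attention(q, k_buffer, v_buffer, req_to_token,
                               b_req_idx, b_seq_len, sm_scale):
        batch, num_q_heads, _ = q.shape
        group = num_q_heads // k_buffer.shape[1]  # kv_group_num
        out = torch.empty_like(q)
        for b in range(batch):
            idx = req_to_token[b_req_idx[b], : b_seq_len[b]]  # KV slots of this request
            k, v = k_buffer[idx], v_buffer[idx]               # (seq, kv_heads, head_dim)
            for h in range(num_q_heads):
                kv_h = h // group                             # GQA head mapping
                logits = (k[:, kv_h] @ q[b, h]) * sm_scale    # (seq,)
                out[b, h] = torch.softmax(logits, dim=0) @ v[:, kv_h]
        return out

One trade-off worth noting: the att_m logits still round-trip through global memory between the two stages. A fully fused FlashDecoding-style kernel would avoid that traffic, but keeping the two-stage split preserves the file's existing structure and lets the MHA and grouped paths share the same att_m buffer and reduce interface.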