From c6c95592c43fcbc46861623c1823c9deb6dba586 Mon Sep 17 00:00:00 2001
From: Michael Melesse
Date: Wed, 16 Oct 2024 10:29:07 -0500
Subject: [PATCH] save

---
 .../flash_attn_triton_amd/bwd_prefill.py  | 101 +++++++++++++-----
 flash_attn/flash_attn_triton_amd/test.py  |  65 +++++------
 flash_attn/flash_attn_triton_amd/utils.py |  29 ++---
 3 files changed, 125 insertions(+), 70 deletions(-)

diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill.py b/flash_attn/flash_attn_triton_amd/bwd_prefill.py
index 227550153..0ee34bb10 100644
--- a/flash_attn/flash_attn_triton_amd/bwd_prefill.py
+++ b/flash_attn/flash_attn_triton_amd/bwd_prefill.py
@@ -7,6 +7,37 @@
 DEBUG = False
 
+@triton.jit
+def _bwd_preprocess_use_o_old(
+    Out,
+    DO,
+    Delta,
+    BLOCK_M: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+    ACTUAL_BLOCK_DMODEL: tl.constexpr,
+    N_CTX_Q: tl.constexpr
+):
+    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
+    off_d = tl.arange(0, BLOCK_DMODEL)
+
+    # create masks
+    # mask_m = off_m < N_CTX_Q
+    mask_d = off_d < ACTUAL_BLOCK_DMODEL
+    # o_mask = None
+    # o_mask = mask_m[:, None]
+    o_mask = mask_d[None, :]
+    # o_mask = mask_m[:, None] & mask_d[None, :]
+
+    # load
+    o = tl.load(Out + off_m[:, None] * ACTUAL_BLOCK_DMODEL + off_d[None, :], mask=o_mask).to(tl.float32)
+    do = tl.load(DO + off_m[:, None] * ACTUAL_BLOCK_DMODEL + off_d[None, :], mask=o_mask).to(tl.float32)
+    # compute
+    delta = tl.sum(o * do, axis=1)
+    # write-back
+    tl.store(Delta + off_m, delta)
+
+
 @triton.jit
 def _bwd_preprocess_use_o(
     Out,
@@ -521,16 +552,16 @@ def attention_prefill_backward_triton_new_impl(
     dq,
     dk,
     dv,
-    sm_scale,
+    sm_scale: float,
     alibi_slopes,
     causal,
-    layout,
+    layout: str,
     cu_seqlens_q,
     cu_seqlens_k,
-    max_seqlen_q,
-    max_seqlen_k,
-    use_exp2,
-    bwd_preprocessing_use_o,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    use_exp2: bool,
+    bwd_preprocessing_use_o: bool,
     BLOCK_M=64,
     BLOCK_N=64,
 ):
@@ -561,13 +592,23 @@ def attention_prefill_backward_triton_new_impl(
         softmax_lse = softmax_lse.contiguous()
 
     # get strides and shape
-    batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shape_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)
-    q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout)
-    stride_qz, stride_qh, stride_qm, stride_qk = q_strides
-    stride_kz, stride_kh, stride_kn, stride_kk = k_strides
-    stride_vz, stride_vh, stride_vn, stride_vk = v_strides
-    stride_oz, stride_oh, stride_om, stride_ok = o_strides
-    stride_dq_all = q.numel()
+    if True:
+        batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shape_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)
+        q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout)
+        stride_qz, stride_qh, stride_qm, stride_qk = q_strides
+        stride_kz, stride_kh, stride_kn, stride_kk = k_strides
+        stride_vz, stride_vh, stride_vn, stride_vk = v_strides
+        stride_oz, stride_oh, stride_om, stride_ok = o_strides
+        stride_dq_all = q.numel()
+        batch_headsize = batch * nheads_q
+    else:
+        batch_q, heads_q, seqlen_q, head_size_q = q.shape
+        batch_k, heads_k, seqlen_k, head_size_k = k.shape
+        batch_headsize = batch_q * heads_q
+        stride_dq_all = dq.numel()
+        stride_qz, stride_qh, stride_qm, stride_qk = q.stride(0), q.stride(1), q.stride(2), q.stride(3)
+        stride_kz, stride_kh, stride_kn, stride_kk = k.stride(0), k.stride(1), k.stride(2), k.stride(3)
+        stride_vz, stride_vh, stride_vn, stride_vk = v.stride(0), v.stride(1), v.stride(2), v.stride(3)
 
     sequence_parallel = False
     causal = False
@@ -627,7 +668,6 @@ def attention_prefill_backward_triton_new_impl(
         assert dk.is_contiguous()
         assert dv.is_contiguous()
 
-    batch_headsize = batch * nheads_q
     num_warps = 4 # NOTE: originial is 8. changing it to 1 caused issues be careful
     num_stages = 1
 
@@ -637,19 +677,28 @@ def attention_prefill_backward_triton_new_impl(
     delta = torch.empty_like(softmax_lse)
 
     if bwd_preprocessing_use_o:
-        _bwd_preprocess_use_o[(num_blocks_m, batch_headsize)](
+        _bwd_preprocess_use_o_old[(batch_headsize * num_blocks_m,)](
             o,
             do,
             delta,
-            stride_oz, stride_oh, stride_om, stride_ok,
-            stride_oz, stride_oh, stride_om, stride_ok,
             BLOCK_M=BLOCK_M,
             BLOCK_DMODEL=BLOCK_DMODEL,
             ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
-            N_CTX_Q=seqlen_q,
-            Z=batch,
-            H=nheads_q,
+            N_CTX_Q=seqlen_q
         )
+        # _bwd_preprocess_use_o[(num_blocks_m, batch_headsize)](
+        #     o,
+        #     do,
+        #     delta,
+        #     stride_oz, stride_oh, stride_om, stride_ok,
+        #     stride_oz, stride_oh, stride_om, stride_ok,
+        #     BLOCK_M=BLOCK_M,
+        #     BLOCK_DMODEL=BLOCK_DMODEL,
+        #     ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
+        #     N_CTX_Q=seqlen_q,
+        #     Z=batch,
+        #     H=nheads_q,
+        # )
     else:
         _bwd_preprocess_use_p[(num_blocks_m, batch_headsize)](
             q,
@@ -762,16 +811,16 @@ def attention_prefill_backward_triton_impl(
     dq,
     dk,
     dv,
-    sm_scale,
+    sm_scale: float,
     alibi_slopes,
     causal,
-    layout,
+    layout: str,
     cu_seqlens_q,
     cu_seqlens_k,
-    max_seqlen_q,
-    max_seqlen_k,
-    use_exp2,
-    bwd_preprocessing_use_o,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    use_exp2: bool,
+    bwd_preprocessing_use_o: bool,
     use_new,
 ):
     if use_new:
diff --git a/flash_attn/flash_attn_triton_amd/test.py b/flash_attn/flash_attn_triton_amd/test.py
index e556b8eaa..92fe802f0 100644
--- a/flash_attn/flash_attn_triton_amd/test.py
+++ b/flash_attn/flash_attn_triton_amd/test.py
@@ -9,7 +9,7 @@
 from .bwd_ref import attention_backward_pytorch_ref_impl
 from .fwd_decode import dequantize_kv_fp16, quantize_kv_int4
 
-DEBUG=False
+DEBUG = True
 
 # defailt fp16 tolerance is ATOL, RTOL = 1e-5, 1e-3. See table https://pytorch.org/docs/stable/testing.html
 ATOL, RTOL = 1e-2, 1e-2 # old standard. maybe to lose.
@@ -227,7 +227,7 @@ def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16
 def test_op_bwd(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, torch_sdpa_test, use_alibi, dtype=torch.float16):
     torch.manual_seed(20)
 
-    DEBUG_INPUT = False # if DEBUG is True it fails
+    DEBUG_INPUT = False
 
     # seqlens
     seqlen_q = N_CTX_Q
@@ -341,6 +341,7 @@ def test_op_bwd(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, torch_sdpa_test, use_ali
     (1, 1, 256, 512, 16),
     (1, 1, 128, 128, 64),
     (2, 4, 1024, 1024, 64),
+    (4, 6, 108, 256, 224),
     (4, 8, 2048, 2048, 128),
     (4, 16, 4096, 4096, 64),
     (2, 4, 8192, 8192, 32),
@@ -382,7 +383,6 @@ def test_op_fwd_prefill_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, return_scor
         v = torch.randn(Z, H, N_CTX_K, D_HEAD, device='cuda', dtype=dtype, requires_grad=True)
         o = torch.empty_like(q)
     elif layout == 'bshd':
-        # Generate random inputs
         q = torch.randn(Z, N_CTX_Q, H, D_HEAD, device='cuda', dtype=dtype, requires_grad=True)
         k = torch.randn(Z, N_CTX_K, H, D_HEAD, device='cuda', dtype=dtype, requires_grad=True)
         v = torch.randn(Z, N_CTX_K, H, D_HEAD, device='cuda', dtype=dtype, requires_grad=True)
@@ -469,39 +469,44 @@ def test_op_fwd_prefill_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, return_scor
 
 
 @pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [
-    (1, 1, 1, 1, 1),
-    (1, 1, 4, 4, 4),
-    (1, 1, 4, 4, 16),
-    (1, 1, 16, 16, 16),
-    (1, 1, 32, 32, 16),
-    (1, 1, 64, 64, 16), # pass # smallest head_size = 16
-    (1, 1, 64, 64, 64), # pass # smallest seq len seems to be 64
-    (1, 1, 128, 128, 64),
-    (1, 1, 128, 256, 45),
-    (1, 1, 256, 256, 64),
-    (1, 1, 256, 512, 16),
-    (1, 1, 512, 512, 64),
-    (1, 1, 1024, 1024, 64),
-    # old tests that work
-    (4, 48, 1024, 1024, 73),
-    (4, 48, 1024, 1024, 64),
-    (4, 48, 2048, 2048, 64),
-    (1, 24, 4096, 4096, 64),
-    (1, 16, 1024, 1024, 64),
-    (1, 16, 1024, 1024, 128),
+    # (1, 1, 1, 1, 1),
+    # (1, 1, 4, 4, 4),
+    # (1, 1, 4, 4, 16),
+    # (1, 1, 16, 16, 16),
+    # (1, 1, 32, 32, 16),
+    # (1, 1, 64, 64, 16), # pass # smallest head_size = 16
+    # (1, 1, 64, 64, 64), # pass # smallest seq len seems to be 64
+    # (1, 1, 128, 128, 64),
+    # (1, 1, 128, 256, 45),
+    # (1, 1, 256, 256, 64),
+    # (1, 1, 256, 512, 16),
+    # (1, 1, 512, 512, 64),
+    # (1, 1, 1024, 1024, 64),
+    # fa configs
+    # (2, 2, 128, 128, 65),
+    (2, 2, 128, 128, 224),
+    # (2, 2, 128, 128, 224),
+    # (2, 2, 108, 256, 224),
+    # (4, 6, 108, 256, 224),
+    # (1, 1, 256, 512, 16),
+    # # old tests that work
+    # (4, 48, 1024, 1024, 73),
+    # (4, 48, 1024, 1024, 64),
+    # (4, 48, 2048, 2048, 64),
+    # (1, 24, 4096, 4096, 64),
+    # (1, 16, 1024, 1024, 64),
+    # (1, 16, 1024, 1024, 128),
     # # old tests that were commented out
     # (1, 16, 8192, 8192, 63),
     # (1, 16, 1022, 1022, 64),
-    # bad fa configs
-    # (1, 1, 256, 512, 16),
 ])
 @pytest.mark.parametrize('causal', [False])
-@pytest.mark.parametrize('use_exp2', [True, False])
-@pytest.mark.parametrize('bwd_preprocessing_use_o', [True, False])
-@pytest.mark.parametrize('layout', ["bhsd", "bshd"])
+@pytest.mark.parametrize('use_exp2', [False])
+@pytest.mark.parametrize('bwd_preprocessing_use_o', [False])
+@pytest.mark.parametrize('layout', ["bhsd"])
 @pytest.mark.parametrize('use_new', [True])
 @pytest.mark.parametrize('DEBUG_INPUT', [False]) # debug output causes nans in both new and old backend
-def test_op_bwd_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, bwd_preprocessing_use_o, layout, use_new, DEBUG_INPUT):
+def test_op_bwd_prefill_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, bwd_preprocessing_use_o, layout, use_new, DEBUG_INPUT):
     dtype = torch.float16
     torch.manual_seed(20) # seed from test_op_bwd
 
@@ -509,7 +514,6 @@ def test_op_bwd_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, bwd_prepr
         sm_scale = 1
     else:
         sm_scale = D_HEAD ** -0.5
-    head_size = D_HEAD
     alibi_slopes = None
 
     if DEBUG_INPUT:
@@ -556,6 +560,7 @@ def test_op_bwd_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, bwd_prepr
         q_ref, k_ref, v_ref, sm_scale, causal, layout, use_exp2
     )
     if DEBUG:
+        print()
         print("attention_scores_ref:", attention_scores_ref, attention_scores_ref.shape)
         print("attention_shifted_scaled_scores_ref:", attention_shifted_scaled_scores_ref, attention_shifted_scaled_scores_ref.shape)
         print("exp_scores_ref:", exp_scores_ref, exp_scores_ref.shape)
diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py
index 80defc654..3ac393f23 100644
--- a/flash_attn/flash_attn_triton_amd/utils.py
+++ b/flash_attn/flash_attn_triton_amd/utils.py
@@ -57,23 +57,24 @@ def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlen
 
 
 def get_shape_from_layout(q, k, layout, cu_seqlens_q = None, cu_seqlens_k = None, max_seqlen_q=None, max_seqlen_k=None):
-    if layout == 'thd':
-        nheads_q, nheads_k = q.shape[1], k.shape[1]
-        head_size = q.shape[-1]
-        batch = len(cu_seqlens_q) - 1
-    elif layout == 'bhsd':
-        batch, nheads_q, seqlen_q, head_size = q.shape
-        _, nheads_k, seqlen_k, _ = k.shape
+    if layout == 'bhsd':
+        batch_q, nheads_q, seqlen_q, head_size_q = q.shape
+        batch_k, nheads_k, seqlen_k, head_size_k = k.shape
     elif layout == 'bshd':
-        batch, seqlen_q, nheads_q, head_size = q.shape
-        _, seqlen_k, nheads_k, _ = k.shape
+        batch_q, seqlen_q, nheads_q, head_size_q = q.shape
+        batch_k, seqlen_k, nheads_k, head_size_k = k.shape
+    elif layout == 'thd':
+        batch_q, seqlen_q, nheads_q, head_size_q = len(cu_seqlens_q) - 1, max_seqlen_q, q.shape[1], q.shape[2]
+        batch_k, seqlen_k, nheads_k, head_size_k = len(cu_seqlens_k) - 1, max_seqlen_k, k.shape[1], k.shape[2]
     else:
         assert False, "Got unsupported layout."
-    if layout == 'thd':
-        return batch, nheads_q, nheads_k, head_size, max_seqlen_q, max_seqlen_k
-    else:
-        return batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k
+    # assert
+    assert batch_q == batch_k
+    assert nheads_q == nheads_k # might not be true in mqa and gqa. Keep for now
+    assert head_size_q == head_size_k
+
+    return batch_q, nheads_q, nheads_k, head_size_q, seqlen_q, seqlen_k
 
 
 def get_strides_from_layout(q, k, v, o, layout):
     if layout == 'thd':
@@ -206,7 +207,7 @@ def need_dropout(self, dropout_p, return_scores):
 
     def check_args(self, q, k, v, o):
         assert q.dim() == k.dim() and q.dim() == v.dim()
-        batch, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(q, k, self.layout, self.cu_seqlens_q, self.cu_seqlens_k)
+        batch, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(q, k, self.layout, self.cu_seqlens_q, self.cu_seqlens_k, self.max_seqlens_q, self.max_seqlens_k)
         if self.varlen:
             assert q.dim() == 3
             assert self.cu_seqlens_q is not None
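
For reference, a minimal usage sketch of the get_shape_from_layout contract as changed in utils.py above. It assumes the module is importable as flash_attn.flash_attn_triton_amd.utils (matching the file path touched in this diff); the tensor shapes are illustrative, loosely echoing the (2, 2, 128, 128, 224) config enabled in test.py, and are not taken from the patch itself.

import torch

from flash_attn.flash_attn_triton_amd.utils import get_shape_from_layout

# 'bhsd' layout is (batch, heads, seqlen, head_dim). The cu_seqlens/max_seqlen
# arguments are only consumed by the 'thd' (varlen) branch, but the signature
# accepts them for every layout, and the updated check_args call passes them through.
q = torch.randn(2, 2, 128, 224, dtype=torch.float16)
k = torch.randn(2, 2, 128, 224, dtype=torch.float16)

batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shape_from_layout(
    q, k, "bhsd",
    cu_seqlens_q=None, cu_seqlens_k=None,
    max_seqlen_q=None, max_seqlen_k=None,
)
assert (batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k) == (2, 2, 2, 224, 128, 128)

The renamed backward test can be selected directly with pytest node syntax, e.g. pytest flash_attn/flash_attn_triton_amd/test.py::test_op_bwd_prefill_impl.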