diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py
index 09efc06de4b9..9d857c50b0bf 100644
--- a/python/tutorials/06-fused-attention.py
+++ b/python/tutorials/06-fused-attention.py
@@ -320,7 +320,7 @@ def _attn_bwd(Q, K, V, sm_scale,  #
               BLOCK_M2: tl.constexpr,  #
               BLOCK_N2: tl.constexpr,  #
               BLK_SLICE_FACTOR: tl.constexpr,  #
-              HEAD_DIM: tl.constexpr):
+              CAUSAL: tl.constexpr, HEAD_DIM: tl.constexpr):
     LN2: tl.constexpr = 0.6931471824645996  # = ln(2)
 
     bhid = tl.program_id(2)
@@ -343,7 +343,6 @@ def _attn_bwd(Q, K, V, sm_scale,  #
     offs_k = tl.arange(0, HEAD_DIM)
 
     start_n = pid * BLOCK_N1
-    start_m = start_n
 
     MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
     offs_n = start_n + tl.arange(0, BLOCK_N1)
@@ -355,21 +354,26 @@ def _attn_bwd(Q, K, V, sm_scale,  #
     k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
     v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
 
-    num_steps = BLOCK_N1 // MASK_BLOCK_M1
-
-    dk, dv = _attn_bwd_dkdv(dk, dv,  #
-                            Q, k, v, sm_scale,  #
-                            DO,  #
-                            M, D,  #
-                            stride_tok, stride_d,  #
-                            H, N_CTX,  #
-                            MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM,  #
-                            start_n, start_m, num_steps,  #
-                            MASK=True  #
-                            )
-
-    start_m += num_steps * MASK_BLOCK_M1
-    num_steps = (N_CTX - start_m) // BLOCK_M1
+    if CAUSAL:
+        # compute masked (diagonal) blocks of dk and dv
+        start_m = start_n
+        num_steps = BLOCK_N1 // MASK_BLOCK_M1
+        dk, dv = _attn_bwd_dkdv(dk, dv,  #
+                                Q, k, v, sm_scale,  #
+                                DO,  #
+                                M, D,  #
+                                stride_tok, stride_d,  #
+                                H, N_CTX,  #
+                                MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM,  #
+                                start_n, start_m, num_steps,  #
+                                MASK=True  #
+                                )
+        start_m += num_steps * MASK_BLOCK_M1
+        num_steps = (N_CTX - start_m) // BLOCK_M1
+    else:
+        # if non-causal, we compute all of dk, dv
+        start_m = 0
+        num_steps = N_CTX // BLOCK_M1
 
     # Compute dK and dV for non-masked blocks.
     dk, dv = _attn_bwd_dkdv(  #
@@ -394,7 +398,6 @@ def _attn_bwd(Q, K, V, sm_scale,  #
 
     # THIS BLOCK DOES DQ:
     start_m = pid * BLOCK_M2
-    end_n = start_m + BLOCK_M2
 
     MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
     offs_m = start_m + tl.arange(0, BLOCK_M2)
@@ -406,29 +409,37 @@ def _attn_bwd(Q, K, V, sm_scale,  #
     m = tl.load(M + offs_m)
     m = m[:, None]
 
-    # Compute dQ for masked (diagonal) blocks.
-    # NOTE: This code scans each row of QK^T backward (from right to left,
-    # but inside each call to _attn_bwd_dq, from left to right), but that's
-    # not due to anything important.  I just wanted to reuse the loop
-    # structure for dK & dV above as much as possible.
-    num_steps = BLOCK_M2 // MASK_BLOCK_N2
-    dq = _attn_bwd_dq(dq, q, K, V,  #
-                      do, m, D,  #
-                      stride_tok, stride_d,  #
-                      H, N_CTX,  #
-                      BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM,  #
-                      start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,  #
-                      MASK=True  #
-                      )
-    end_n -= num_steps * MASK_BLOCK_N2
-    # stage 2
-    num_steps = end_n // BLOCK_N2
+    if CAUSAL:
+        # Compute dQ for masked (diagonal) blocks.
+        # NOTE: This code scans each row of QK^T backward (from right to left,
+        # but inside each call to _attn_bwd_dq, from left to right), but that's
+        # not due to anything important.  I just wanted to reuse the loop
+        # structure for dK & dV above as much as possible.
+        end_n = start_m + BLOCK_M2
+        num_steps = BLOCK_M2 // MASK_BLOCK_N2
+        dq = _attn_bwd_dq(dq, q, K, V,  #
+                          do, m, D,  #
+                          stride_tok, stride_d,  #
+                          H, N_CTX,  #
+                          BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM,  #
+                          start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,  #
+                          MASK=True  #
+                          )
+        end_n -= num_steps * MASK_BLOCK_N2
+        num_steps = end_n // BLOCK_N2
+        start_n = end_n - num_steps * BLOCK_N2
+    else:
+        # if non-causal, compute all of dq
+        start_n = 0
+        num_steps = N_CTX // BLOCK_N2
+
+    # compute non-masked blocks of dq
     dq = _attn_bwd_dq(dq, q, K, V,  #
                       do, m, D,  #
                       stride_tok, stride_d,  #
                       H, N_CTX,  #
                       BLOCK_M2, BLOCK_N2, HEAD_DIM,  #
-                      start_m, end_n - num_steps * BLOCK_N2, num_steps,  #
+                      start_m, start_n, num_steps,  #
                       MASK=False  #
                       )
     # Write back dQ.
@@ -512,7 +523,7 @@ def backward(ctx, do):
             BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2,  #
             BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
             HEAD_DIM=ctx.HEAD_DIM,  #
-            num_warps=NUM_WARPS,  #
+            CAUSAL=ctx.causal, num_warps=NUM_WARPS,  #
            num_stages=NUM_STAGES  #
        )
 
@@ -523,7 +534,7 @@ def backward(ctx, do):
 
 
 @pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM", [(1, 2, 1024, 64)])
-@pytest.mark.parametrize("causal", [True])
+@pytest.mark.parametrize("causal", [False, True])
 def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16):
     torch.manual_seed(20)
     q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
@@ -574,8 +585,6 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16):
 configs = []
 for mode in ["fwd", "bwd"]:
     for causal in [True, False]:
-        if mode == "bwd" and not causal:
-            continue
         configs.append(
             triton.testing.Benchmark(
                 x_names=["N_CTX"],
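For readers who want to sanity-check the new control flow outside of Triton, below is a minimal plain-Python sketch of the block-range bookkeeping that the first if CAUSAL / else branch above performs for the dK/dV pass. The helper dkdv_query_ranges and the tile sizes in the usage comment are purely illustrative (they are not part of the patch or the tutorial); only the start/step/count arithmetic mirrors the diff.

def dkdv_query_ranges(pid, N_CTX, BLOCK_M1, BLOCK_N1, BLK_SLICE_FACTOR, causal):
    """Return (start_m, step, num_steps, masked) passes for one key/value block.

    Illustrative helper only: it reproduces the start_m / num_steps arithmetic
    of the dK/dV branch in the kernel, not the kernel itself.
    """
    MASK_BLOCK_M1 = BLOCK_M1 // BLK_SLICE_FACTOR
    start_n = pid * BLOCK_N1
    passes = []
    if causal:
        # Masked (diagonal) query blocks that overlap this key block.
        start_m = start_n
        num_steps = BLOCK_N1 // MASK_BLOCK_M1
        passes.append((start_m, MASK_BLOCK_M1, num_steps, True))
        # Full (unmasked) query blocks after the diagonal band.
        start_m += num_steps * MASK_BLOCK_M1
        passes.append((start_m, BLOCK_M1, (N_CTX - start_m) // BLOCK_M1, False))
    else:
        # Non-causal: every query row contributes to this key block.
        passes.append((0, BLOCK_M1, N_CTX // BLOCK_M1, False))
    return passes

# Example with illustrative sizes (N_CTX=1024, BLOCK_M1=32, BLOCK_N1=128,
# BLK_SLICE_FACTOR=2, pid=3):
#   causal=True  -> [(384, 16, 8, True), (512, 32, 16, False)]  # query rows 384..1023
#   causal=False -> [(0, 32, 32, False)]                        # query rows 0..1023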