Commit c6c9559

save

micmelesse committed Oct 16, 2024
1 parent ba7e5d9 commit c6c9559
Showing 3 changed files with 125 additions and 70 deletions.
101 changes: 75 additions & 26 deletions flash_attn/flash_attn_triton_amd/bwd_prefill.py
@@ -7,6 +7,37 @@

DEBUG = False

@triton.jit
def _bwd_preprocess_use_o_old(
Out,
DO,
Delta,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
N_CTX_Q: tl.constexpr
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_d = tl.arange(0, BLOCK_DMODEL)

# create masks
# mask_m = off_m < N_CTX_Q
mask_d = off_d < ACTUAL_BLOCK_DMODEL
# o_mask = None
# o_mask = mask_m[:, None]
o_mask = mask_d[None, :]
# o_mask = mask_m[:, None] & mask_d[None, :]

# load
o = tl.load(Out + off_m[:, None] * ACTUAL_BLOCK_DMODEL + off_d[None, :], mask=o_mask).to(tl.float32)
do = tl.load(DO + off_m[:, None] * ACTUAL_BLOCK_DMODEL + off_d[None, :], mask=o_mask).to(tl.float32)
# compute
delta = tl.sum(o * do, axis=1)
# write-back
tl.store(Delta + off_m, delta)
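
For reference, this preprocess kernel computes, per output row, delta_i = sum_d O[i, d] * dO[i, d]. A minimal PyTorch sketch of the same reduction, assuming o and do are already viewed as a contiguous (rows, head_dim) matrix, which is the layout the kernel indexes with ACTUAL_BLOCK_DMODEL as its row stride (the helper name here is illustrative, not part of this diff):

```python
import torch

def bwd_preprocess_use_o_ref(o: torch.Tensor, do: torch.Tensor) -> torch.Tensor:
    # o, do: (rows, head_dim), e.g. flattened over (batch, heads, seqlen_q)
    # returns delta: (rows,), the row-wise sum of o * do, matching tl.sum(o * do, axis=1)
    return (o.float() * do.float()).sum(dim=-1)
```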



@triton.jit
def _bwd_preprocess_use_o(
Out,
@@ -521,16 +552,16 @@ def attention_prefill_backward_triton_new_impl(
dq,
dk,
dv,
sm_scale,
sm_scale: float,
alibi_slopes,
causal,
layout,
layout: str,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
use_exp2,
bwd_preprocessing_use_o,
max_seqlen_q: int,
max_seqlen_k: int,
use_exp2: bool,
bwd_preprocessing_use_o: bool,
BLOCK_M=64,
BLOCK_N=64,
):
@@ -561,13 +592,23 @@ def attention_prefill_backward_triton_new_impl(
softmax_lse = softmax_lse.contiguous()

# get strides and shape
batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shape_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)
q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout)
stride_qz, stride_qh, stride_qm, stride_qk = q_strides
stride_kz, stride_kh, stride_kn, stride_kk = k_strides
stride_vz, stride_vh, stride_vn, stride_vk = v_strides
stride_oz, stride_oh, stride_om, stride_ok = o_strides
stride_dq_all = q.numel()
if True:
batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shape_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)
q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout)
stride_qz, stride_qh, stride_qm, stride_qk = q_strides
stride_kz, stride_kh, stride_kn, stride_kk = k_strides
stride_vz, stride_vh, stride_vn, stride_vk = v_strides
stride_oz, stride_oh, stride_om, stride_ok = o_strides
stride_dq_all = q.numel()
batch_headsize = batch * nheads_q
else:
batch_q, heads_q, seqlen_q, head_size_q = q.shape
batch_k, heads_k, seqlen_k, head_size_k = k.shape
batch_headsize = batch_q * heads_q
stride_dq_all = dq.numel()
stride_qz, stride_qh, stride_qm, stride_qk = q.stride(0), q.stride(1), q.stride(2), q.stride(3)
stride_kz, stride_kh, stride_kn, stride_kk = k.stride(0), k.stride(1), k.stride(2), k.stride(3)
stride_vz, stride_vh, stride_vn, stride_vk = v.stride(0), v.stride(1), v.stride(2), v.stride(3)

sequence_parallel = False
causal = False
@@ -627,7 +668,6 @@ def attention_prefill_backward_triton_new_impl(
assert dk.is_contiguous()
assert dv.is_contiguous()

batch_headsize = batch * nheads_q
num_warps = 4 # NOTE: original is 8. changing it to 1 caused issues, be careful
num_stages = 1

@@ -637,19 +677,28 @@
delta = torch.empty_like(softmax_lse)

if bwd_preprocessing_use_o:
_bwd_preprocess_use_o[(num_blocks_m, batch_headsize)](
_bwd_preprocess_use_o_old[(batch_headsize * num_blocks_m,)](
o,
do,
delta,
stride_oz, stride_oh, stride_om, stride_ok,
stride_oz, stride_oh, stride_om, stride_ok,
BLOCK_M=BLOCK_M,
BLOCK_DMODEL=BLOCK_DMODEL,
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
N_CTX_Q=seqlen_q,
Z=batch,
H=nheads_q,
N_CTX_Q=seqlen_q
)
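# NOTE: _bwd_preprocess_use_o_old runs on a flattened 1-D grid of batch_headsize * num_blocks_m
# programs and addresses Out/DO with ACTUAL_BLOCK_DMODEL as the row stride, so it effectively
# assumes a contiguous (batch * nheads_q * seqlen_q, head_dim) view of o and do; the
# commented-out _bwd_preprocess_use_o launch below instead uses a 2-D grid and explicit strides.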
# _bwd_preprocess_use_o[(num_blocks_m, batch_headsize)](
# o,
# do,
# delta,
# stride_oz, stride_oh, stride_om, stride_ok,
# stride_oz, stride_oh, stride_om, stride_ok,
# BLOCK_M=BLOCK_M,
# BLOCK_DMODEL=BLOCK_DMODEL,
# ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
# N_CTX_Q=seqlen_q,
# Z=batch,
# H=nheads_q,
# )
else:
_bwd_preprocess_use_p[(num_blocks_m, batch_headsize)](
q,
@@ -762,16 +811,16 @@ def attention_prefill_backward_triton_impl(
dq,
dk,
dv,
sm_scale,
sm_scale: float,
alibi_slopes,
causal,
layout,
layout: str,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
use_exp2,
bwd_preprocessing_use_o,
max_seqlen_q: int,
max_seqlen_k: int,
use_exp2: bool,
bwd_preprocessing_use_o: bool,
use_new,
):
if use_new:
65 changes: 35 additions & 30 deletions flash_attn/flash_attn_triton_amd/test.py
@@ -9,7 +9,7 @@
from .bwd_ref import attention_backward_pytorch_ref_impl
from .fwd_decode import dequantize_kv_fp16, quantize_kv_int4

DEBUG=False
DEBUG = True

# default fp16 tolerance is ATOL, RTOL = 1e-5, 1e-3. See table https://pytorch.org/docs/stable/testing.html
ATOL, RTOL = 1e-2, 1e-2 # old standard. maybe too loose.
@@ -227,7 +227,7 @@ def test_op_bwd(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, torch_sdpa_test, use_alibi, dtype=torch.float16):
def test_op_bwd(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, torch_sdpa_test, use_alibi, dtype=torch.float16):
torch.manual_seed(20)

DEBUG_INPUT = False # if DEBUG is True it fails
DEBUG_INPUT = False

# seqlens
seqlen_q = N_CTX_Q
@@ -341,6 +341,7 @@ def test_op_bwd(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, torch_sdpa_test, use_alibi, dtype=torch.float16):
(1, 1, 256, 512, 16),
(1, 1, 128, 128, 64),
(2, 4, 1024, 1024, 64),
(4, 6, 108, 256, 224),
(4, 8, 2048, 2048, 128),
(4, 16, 4096, 4096, 64),
(2, 4, 8192, 8192, 32),
@@ -382,7 +383,6 @@ def test_op_fwd_prefill_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, return_scores, layout, use_exp2, DEBUG_INPUT):
v = torch.randn(Z, H, N_CTX_K, D_HEAD, device='cuda', dtype=dtype, requires_grad=True)
o = torch.empty_like(q)
elif layout == 'bshd':
# Generate random inputs
q = torch.randn(Z, N_CTX_Q, H, D_HEAD, device='cuda', dtype=dtype, requires_grad=True)
k = torch.randn(Z, N_CTX_K, H, D_HEAD, device='cuda', dtype=dtype, requires_grad=True)
v = torch.randn(Z, N_CTX_K, H, D_HEAD, device='cuda', dtype=dtype, requires_grad=True)
@@ -469,47 +469,51 @@ def test_op_fwd_prefill_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, return_scores, layout, use_exp2, DEBUG_INPUT):


@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [
(1, 1, 1, 1, 1),
(1, 1, 4, 4, 4),
(1, 1, 4, 4, 16),
(1, 1, 16, 16, 16),
(1, 1, 32, 32, 16),
(1, 1, 64, 64, 16), # pass # smallest head_size = 16
(1, 1, 64, 64, 64), # pass # smallest seq len seems to be 64
(1, 1, 128, 128, 64),
(1, 1, 128, 256, 45),
(1, 1, 256, 256, 64),
(1, 1, 256, 512, 16),
(1, 1, 512, 512, 64),
(1, 1, 1024, 1024, 64),
# old tests that work
(4, 48, 1024, 1024, 73),
(4, 48, 1024, 1024, 64),
(4, 48, 2048, 2048, 64),
(1, 24, 4096, 4096, 64),
(1, 16, 1024, 1024, 64),
(1, 16, 1024, 1024, 128),
# (1, 1, 1, 1, 1),
# (1, 1, 4, 4, 4),
# (1, 1, 4, 4, 16),
# (1, 1, 16, 16, 16),
# (1, 1, 32, 32, 16),
# (1, 1, 64, 64, 16), # pass # smallest head_size = 16
# (1, 1, 64, 64, 64), # pass # smallest seq len seems to be 64
# (1, 1, 128, 128, 64),
# (1, 1, 128, 256, 45),
# (1, 1, 256, 256, 64),
# (1, 1, 256, 512, 16),
# (1, 1, 512, 512, 64),
# (1, 1, 1024, 1024, 64),
# fa configs
# (2, 2, 128, 128, 65),
(2, 2, 128, 128, 224),
# (2, 2, 128, 128, 224),
# (2, 2, 108, 256, 224),
# (4, 6, 108, 256, 224),
# (1, 1, 256, 512, 16),
# # old tests that work
# (4, 48, 1024, 1024, 73),
# (4, 48, 1024, 1024, 64),
# (4, 48, 2048, 2048, 64),
# (1, 24, 4096, 4096, 64),
# (1, 16, 1024, 1024, 64),
# (1, 16, 1024, 1024, 128),
# # old tests that were commented out
# (1, 16, 8192, 8192, 63),
# (1, 16, 1022, 1022, 64),
# bad fa configs
# (1, 1, 256, 512, 16),
])
@pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('use_exp2', [True, False])
@pytest.mark.parametrize('bwd_preprocessing_use_o', [True, False])
@pytest.mark.parametrize('layout', ["bhsd", "bshd"])
@pytest.mark.parametrize('use_exp2', [False])
@pytest.mark.parametrize('bwd_preprocessing_use_o', [False])
@pytest.mark.parametrize('layout', ["bhsd"])
@pytest.mark.parametrize('use_new', [True])
@pytest.mark.parametrize('DEBUG_INPUT', [False]) # debug output causes nans in both the new and old backends
def test_op_bwd_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, bwd_preprocessing_use_o, layout, use_new, DEBUG_INPUT):
def test_op_bwd_prefill_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, bwd_preprocessing_use_o, layout, use_new, DEBUG_INPUT):
dtype = torch.float16
torch.manual_seed(20) # seed from test_op_bwd

if DEBUG_INPUT:
sm_scale = 1
else:
sm_scale = D_HEAD ** -0.5
head_size = D_HEAD
alibi_slopes = None

if DEBUG_INPUT:
@@ -556,6 +560,7 @@ def test_op_bwd_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, bwd_preprocessing_use_o, layout, use_new, DEBUG_INPUT):
q_ref, k_ref, v_ref, sm_scale, causal, layout, use_exp2
)
if DEBUG:
print()
print("attention_scores_ref:", attention_scores_ref, attention_scores_ref.shape)
print("attention_shifted_scaled_scores_ref:", attention_shifted_scaled_scores_ref, attention_shifted_scaled_scores_ref.shape)
print("exp_scores_ref:", exp_scores_ref, exp_scores_ref.shape)
29 changes: 15 additions & 14 deletions flash_attn/flash_attn_triton_amd/utils.py
@@ -57,23 +57,24 @@ def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlens=False):


def get_shape_from_layout(q, k, layout, cu_seqlens_q = None, cu_seqlens_k = None, max_seqlen_q=None, max_seqlen_k=None):
if layout == 'thd':
nheads_q, nheads_k = q.shape[1], k.shape[1]
head_size = q.shape[-1]
batch = len(cu_seqlens_q) - 1
elif layout == 'bhsd':
batch, nheads_q, seqlen_q, head_size = q.shape
_, nheads_k, seqlen_k, _ = k.shape
if layout == 'bhsd':
batch_q, nheads_q, seqlen_q, head_size_q = q.shape
batch_k, nheads_k, seqlen_k, head_size_k = k.shape
elif layout == 'bshd':
batch, seqlen_q, nheads_q, head_size = q.shape
_, seqlen_k, nheads_k, _ = k.shape
batch_q, seqlen_q, nheads_q, head_size_q = q.shape
batch_k, seqlen_k, nheads_k, head_size_k = k.shape
elif layout == 'thd':
batch_q, seqlen_q, nheads_q, head_size_q = len(cu_seqlens_q) - 1, max_seqlen_q, q.shape[1], q.shape[2]
batch_k, seqlen_k, nheads_k, head_size_k = len(cu_seqlens_k) - 1, max_seqlen_k, k.shape[1], k.shape[2]
else:
assert False, "Got unsupported layout."

if layout == 'thd':
return batch, nheads_q, nheads_k, head_size, max_seqlen_q, max_seqlen_k
else:
return batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k
# assert
assert batch_q == batch_k
assert nheads_q == nheads_k # might not be true in mqa and gqa. Keep for now
assert head_size_q == head_size_k

return batch_q, nheads_q, nheads_k, head_size_q, seqlen_q, seqlen_k
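
A quick usage sketch of the reworked helper; the import path is assumed from the file location and the tensor shapes are illustrative only:

```python
import torch
from flash_attn.flash_attn_triton_amd.utils import get_shape_from_layout

q = torch.empty(2, 8, 128, 64)  # bhsd: (batch, nheads_q, seqlen_q, head_size)
k = torch.empty(2, 8, 256, 64)  # bhsd: (batch, nheads_k, seqlen_k, head_size)

batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shape_from_layout(q, k, "bhsd")
# -> 2, 8, 8, 64, 128, 256
# for layout == "thd", cu_seqlens_q/cu_seqlens_k and max_seqlen_q/max_seqlen_k are required,
# since batch and the sequence lengths are derived from them rather than from q.shape/k.shape
```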

def get_strides_from_layout(q, k, v, o, layout):
if layout == 'thd':
@@ -206,7 +207,7 @@ def need_dropout(self, dropout_p, return_scores):
def check_args(self, q, k, v, o):
assert q.dim() == k.dim() and q.dim() == v.dim()

batch, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(q, k, self.layout, self.cu_seqlens_q, self.cu_seqlens_k)
batch, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(q, k, self.layout, self.cu_seqlens_q, self.cu_seqlens_k, self.max_seqlens_q, self.max_seqlens_k)
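# for the 'thd' layout, get_shape_from_layout derives batch from cu_seqlens and the sequence
# lengths from max_seqlen_q/max_seqlen_k, so both are passed through here (the returned
# sequence lengths themselves are unused in check_args)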
if self.varlen:
assert q.dim() == 3
assert self.cu_seqlens_q is not None
