Skip to content

Commit

Permalink
precommit
Browse files Browse the repository at this point in the history
  • Loading branch information
Chi-Chu319 committed Dec 5, 2024
1 parent abdbc3d commit 21f2177
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions python/perf-kernels/flash-attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,17 +1070,17 @@ def forward(ctx, q, k, v, o, metadata: MetaData):
else:
q_descale = k_descale = p_scale = p_descale = v_descale = None

attn_fwd[grid](q, k, v, metadata.bias, metadata.sm_scale, M, o, *q_strides, *k_strides, *v_strides, *o_strides,
*bias_strides, *alibi_strides, q_descale, k_descale, p_scale, p_descale, v_descale,
metadata.cu_seqlens_q, metadata.cu_seqlens_k, dropout_p=metadata.dropout_p,
philox_seed=philox_seed, philox_offset_base=philox_offset, encoded_softmax=encoded_softmax,
alibi_slopes=metadata.alibi_slopes, HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size,
MAX_SEQLENS_Q=metadata.max_seqlens_q, MAX_SEQLENS_K=metadata.max_seqlens_k,
IS_CAUSAL=metadata.causal, VARLEN=metadata.varlen, BLOCK_DMODEL=padded_d_model,
USE_BIAS=False if metadata.bias is None else True,
USE_ALIBI=False if metadata.alibi_slopes is None else True, ENABLE_DROPOUT=metadata.dropout_p
> 0.0, RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax, INT8=metadata.int8,
USE_P_SCALE=metadata.int8 and metadata.use_p_scale, USE_KV_SCALE=metadata.int8 and metadata.use_kv_scale)
attn_fwd[grid](
q, k, v, metadata.bias, metadata.sm_scale, M, o, *q_strides, *k_strides, *v_strides, *o_strides,
*bias_strides, *alibi_strides, q_descale, k_descale, p_scale, p_descale, v_descale, metadata.cu_seqlens_q,
metadata.cu_seqlens_k, dropout_p=metadata.dropout_p, philox_seed=philox_seed,
philox_offset_base=philox_offset, encoded_softmax=encoded_softmax, alibi_slopes=metadata.alibi_slopes,
HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=metadata.max_seqlens_q,
MAX_SEQLENS_K=metadata.max_seqlens_k, IS_CAUSAL=metadata.causal, VARLEN=metadata.varlen,
BLOCK_DMODEL=padded_d_model, USE_BIAS=False if metadata.bias is None else True,
USE_ALIBI=False if metadata.alibi_slopes is None else True, ENABLE_DROPOUT=metadata.dropout_p > 0.0,
RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax, INT8=metadata.int8, USE_P_SCALE=metadata.int8
and metadata.use_p_scale, USE_KV_SCALE=metadata.int8 and metadata.use_kv_scale)

ctx.save_for_backward(q, k, v, o, M)
ctx.grid = grid
Expand Down

0 comments on commit 21f2177

Please sign in to comment.