Commit

Merge pull request #37 from ROCm/add_paged_kvcache
Add support for using Paged-KVCache
qianfengz authored Nov 26, 2024
2 parents bdfffaa + 95460bc commit 56dba6b
Showing 1,165 changed files with 12,823 additions and 12,426 deletions.
9 changes: 9 additions & 0 deletions tests/test_mem_eff_attention.py
@@ -2469,6 +2469,15 @@ def test_paged_attention(
B, MAX_T, num_quant_groups, page_size, op, bench=False, gappy=gappy
)

@cuda_only
@pytest.mark.parametrize("B", [1, 5, 128])
@pytest.mark.parametrize("MAX_T", [64, 128, 2048, 4096, 8192])
@pytest.mark.parametrize("page_size", [128, 256])
@pytest.mark.parametrize("gappy", [False, True], ids=lambda x: "gappy" if x else "")
def test_paged_attention_ck(B, MAX_T: int, page_size: int, gappy: bool):
    op = fmha.ck.FwOp
    num_quant_groups = 0
    paged_attention_run_inner(
        B, MAX_T, num_quant_groups, page_size, op, bench=False, gappy=gappy
    )

@sm80_or_better_only
@disable_on_rocm
2 changes: 1 addition & 1 deletion third_party/composable_kernel_tiled
2 changes: 1 addition & 1 deletion xformers/csrc/attention/attention.cpp
@@ -36,7 +36,7 @@ TORCH_LIBRARY_FRAGMENT(xformers, m) {
"xformers::efficient_attention_forward_ck(Tensor query, "
"Tensor key, Tensor value, Tensor? attn_bias, Tensor? seqstart_q, "
"Tensor? seqstart_k, int? max_seqlen_q, float dropout_p, "
"bool compute_logsumexp, int custom_mask_type, float? scale, Tensor? seqlen_k, int? window_size) -> (Tensor, Tensor, int, int)"));
"bool compute_logsumexp, int custom_mask_type, float? scale, Tensor? seqlen_k, int? window_size, Tensor? block_tables, int? page_size) -> (Tensor, Tensor, int, int)"));
m.def(TORCH_SELECTIVE_SCHEMA(
"xformers::efficient_attention_forward_decoder_ck(Tensor query, "
"Tensor key, Tensor value, Tensor? seq_positions, float scale) -> Tensor"));
@@ -66,7 +66,9 @@ efficient_attention_forward_ck(
int64_t custom_mask_type,
c10::optional<double> scale,
const c10::optional<at::Tensor>& seqlen_k,
const c10::optional<int64_t> window_size) {
const c10::optional<int64_t> window_size,
const c10::optional<at::Tensor>& block_tables,
const c10::optional<int64_t> page_size) {
TORCH_CHECK(query.dim() == 4);
TORCH_CHECK(key.dim() == 4);
TORCH_CHECK(value.dim() == 4);
@@ -93,13 +95,21 @@ efficient_attention_forward_ck(
TORCH_CHECK(seqstart_q->scalar_type() == at::ScalarType::Int);
TORCH_CHECK(seqstart_k->scalar_type() == at::ScalarType::Int);
TORCH_CHECK(seqstart_q->dim() == 1 && seqstart_k->dim() == 1);
TORCH_CHECK(seqstart_q->size(0) == seqstart_k->size(0));
TORCH_CHECK(
seqstart_q->size(0) == seqstart_k->size(0) ||
seqstart_q->size(0) == seqstart_k->size(0) + 1);
TORCH_CHECK(query.size(0) == 1, "cu_seqlen only supports batch_size=1");
TORCH_CHECK(max_seqlen_q_.has_value());
CHECK_NOSPARSE_CONTIGUOUS_CUDA((*seqstart_q));
CHECK_NOSPARSE_CONTIGUOUS_CUDA((*seqstart_k));
};

TORCH_CHECK(block_tables.has_value() == page_size.has_value());
TORCH_CHECK(!block_tables.has_value() || block_tables->dim() == 2);

// Currently xformers only uses Paged-KVCache in grouped mode
TORCH_CHECK(seqstart_q.has_value() || !block_tables.has_value());

// last dim is contiguous, device is kCUDA
CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(query);
CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(key);
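For context, the relaxed check above now admits two seqstart layouts in grouped mode. The sketch below only illustrates the size relations the checks accept; the concrete values are made up, and the reading of the shorter layout as the "gappy" case is an assumption based on the later `is_gappy` condition in this file, not something stated in the diff.

```cpp
#include <cstdint>
#include <vector>

int main() {
  // Grouped mode with batch size B = 2 (everything packed into query.size(0) == 1).
  // Regular layout: seqstart_q and seqstart_k are both prefix sums of length B + 1,
  // so seqstart_q->size(0) == seqstart_k->size(0).
  std::vector<int32_t> seqstart_q{0, 3, 7};
  std::vector<int32_t> seqstart_k_regular{0, 5, 9};

  // Layout newly admitted by the check: seqstart_k carries one start per sequence
  // (B entries), so seqstart_q->size(0) == seqstart_k->size(0) + 1. Later in this
  // file this is the case where seqstart_k->size(0) == seqlen_k->size(0) and
  // is_gappy is set.
  std::vector<int32_t> seqstart_k_gappy{0, 16};
  std::vector<int32_t> seqlen_k{5, 4};
  return 0;
}
```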
@@ -336,6 +346,22 @@ efficient_attention_forward_ck(
} else
p.seqlen_k_dev_ptr = nullptr;

p.is_gappy = false;
if (block_tables.has_value()) {
p.block_table_ptr = block_tables->data_ptr();
p.page_block_size = *page_size;
p.batch_stride_block_table = block_tables->stride(0);
p.use_paged_kvcache = true;

TORCH_CHECK(seqlen_k.has_value());

// PageBlockDiagonalGappyKeysMask uses seqstart_k in a special way,
// and the ck_tile kernel needs to know about it
if (seqstart_k->size(0) == seqlen_k->size(0))
p.is_gappy = true;
} else
p.use_paged_kvcache = false;

p.philox_seed = philox_seed;
p.philox_offset = philox_offset;
p.compute_logsumexp = compute_logsumexp;
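The block-table fields set above (block_table_ptr, page_block_size, batch_stride_block_table) are what let a kernel locate key/value data stored in fixed-size pages instead of one contiguous buffer. The following is a minimal sketch of that address translation, assuming an int32 block table of shape [num_batches, max_pages_per_seq]; it is illustrative only and not the ck_tile kernel code.

```cpp
#include <cstdint>
#include <utility>

// Sketch only: resolve a logical key position into (physical page, in-page offset).
// Field names mirror the parameters set above; everything else is assumed.
struct PagedKvLookup {
  const int32_t* block_table;        // p.block_table_ptr, assumed int32 page ids
  int64_t batch_stride_block_table;  // p.batch_stride_block_table
  int32_t page_block_size;           // p.page_block_size (the page_size argument)

  std::pair<int32_t, int32_t> operator()(int64_t batch, int32_t pos) const {
    const int32_t logical_page = pos / page_block_size;
    const int32_t offset_in_page = pos % page_block_size;
    const int32_t physical_page =
        block_table[batch * batch_stride_block_table + logical_page];
    return {physical_page, offset_in_page};
  }
};
```

In the real kernels this translation is fused into the key/value address computation; the sketch only shows what the three parameters mean.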
@@ -361,10 +387,14 @@ efficient_attention_forward_ck(
p.num_kv_splits = get_num_kv_splits_heuristic(
p.num_batches, p.Hq, p.max_seqlen_q, std::max(p.K, p.Kv), 32);

// fmha fwd split-kv kernel does not support dropout
p.use_split_kv = (!use_dropout && (p.num_kv_splits > 1)) ? true : false;
// 1) fmha fwd split-kv kernel does not support dropout
// 2) Paged-KVcache is only available from the split-kv kernel at present
p.use_split_kv =
(p.use_paged_kvcache || (!use_dropout && (p.num_kv_splits > 1)))
? true
: false;

if (p.use_split_kv) {
if (p.use_split_kv && p.num_kv_splits > 1) {
out_acc = at::empty({p.num_kv_splits, M, Hq, Kv}, opts.dtype(at::kFloat));
p.out_acc_ptr = out_acc.data_ptr();
p.out_acc_strides = {
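Condensed, the routing introduced above is: paged-KV requests always take the split-kv kernel, dropout never does, and the float accumulator out_acc is only allocated when the keys are actually split into more than one chunk. A sketch of that decision as a standalone helper; the helper itself is illustrative and not part of the sources.

```cpp
// Sketch of the split-kv routing above; not part of the actual xformers code.
struct SplitKvDecision {
  bool use_split_kv;   // launch the split-kv kernel family
  bool needs_out_acc;  // allocate the float accumulator buffer
};

inline SplitKvDecision decide_split_kv(
    bool use_paged_kvcache, bool use_dropout, int num_kv_splits) {
  SplitKvDecision d{};
  // 1) the fmha fwd split-kv kernel does not support dropout;
  // 2) Paged-KVCache is currently only handled by the split-kv kernel.
  d.use_split_kv = use_paged_kvcache || (!use_dropout && num_kv_splits > 1);
  // out_acc is only needed when the keys are really split, matching
  // `if (p.use_split_kv && p.num_kv_splits > 1)` above.
  d.needs_out_acc = d.use_split_kv && num_kv_splits > 1;
  return d;
}
```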
@@ -454,7 +484,9 @@ efficient_attention_forward_ck_meta(
int64_t custom_mask_type,
c10::optional<double> scale,
const c10::optional<at::Tensor>& seqlen_k,
const c10::optional<int64_t> window_size) {
const c10::optional<int64_t> window_size,
const c10::optional<at::Tensor>& block_tables,
const c10::optional<int64_t> page_size) {
int64_t B = query.size(0);
int64_t M = query.size(1);
int64_t N = key.size(1);
139 changes: 67 additions & 72 deletions xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_backward.h
@@ -18,12 +18,12 @@

template <
typename ScalarType,
bool kHasCausalMask,
bool kHasMask,
bool kHasBias,
bool kHasBiasGrad,
bool kHasDropout,
ck_tile::index_t MaxK>
struct batched_backward_causalmask_bias_dropout_dispatch {
struct batched_backward_mask_bias_dropout_dispatch {
using FmhaBlockDropout =
typename FmhaBwdBlockDropoutMaker<kHasDropout, MaxK>::dropout;

@@ -93,72 +93,67 @@ struct batched_backward_causalmask_bias_dropout_dispatch {
}

{
const bool has_local_attention = (param.window_size > 0) ? true : false;

BOOL_SWITCH(has_local_attention, USE_LOCAL_ATTENTION, [&] {
constexpr ck_tile::index_t occupancy = 1;
constexpr bool has_masking = kHasCausalMask || USE_LOCAL_ATTENTION;

using FmhaMask = ck_tile::SimplifiedGenericAttentionMask<has_masking>;

constexpr auto kBiasEnum = kHasBias
? ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS
: ck_tile::BlockAttentionBiasEnum::NO_BIAS;

constexpr bool kPadSeqLenQ = true;
constexpr bool kPadSeqLenK = true;

const bool pad_headdim_q =
!(param.K % FmhaBwdShape<MaxK>::kQKHeaddim == 0);
const bool pad_headdim_v =
!(param.Kv % FmhaBwdShape<MaxK>::kVHeaddim == 0);

BOOL_SWITCH_2(
pad_headdim_q, kPadHeadDimQ, pad_headdim_v, kPadHeadDimV, [&] {
using FmhaBwdTraits_ = ck_tile::TileFmhaTraits<
kPadSeqLenQ,
kPadSeqLenK,
kPadHeadDimQ,
kPadHeadDimV,
kBiasEnum,
kHasBiasGrad,
false, // kStoreLSE
false, // place-holder for kHasDropout, not actually used
false, // kDoFp8StaticQuant place-holder
occupancy>;

using FmhaBwdPipelineProblem =
FmhaBwdPipelineProblemTemp<FmhaBwdTraits_, FmhaMask>;

constexpr auto FmhaBwdPipelineEnum_ =
FmhaBwdPipelineEnumSelector<MaxK>::value;

using FmhaBwdPipeline_ = typename FmhaBwdPipelineMaker<
FmhaBwdPipelineEnum_,
FmhaBwdPipelineProblem>::pipeline;

using FmhaBwdKGradEpilogue_ =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<
typename FmhaBwdTypeConfig<ScalarType>::AccDataType,
typename FmhaBwdTypeConfig<ScalarType>::KGradDataType,
kPadSeqLenK,
kPadHeadDimQ>>;

using FmhaBwdVGradEpilogue_ =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<
typename FmhaBwdTypeConfig<ScalarType>::AccDataType,
typename FmhaBwdTypeConfig<ScalarType>::VGradDataType,
kPadSeqLenK,
kPadHeadDimV>>;

using FmhaBwdDQDKDVKernel_ = ck_tile::FmhaBwdDQDKDVKernel<
FmhaBwdPipeline_,
FmhaBwdKGradEpilogue_,
FmhaBwdVGradEpilogue_>;

RunWithBwdDQDKDVKernel<FmhaBwdDQDKDVKernel_>(param, stream);
});
});
constexpr ck_tile::index_t occupancy = 1;

using FmhaMask = ck_tile::SimplifiedGenericAttentionMask<kHasMask>;

constexpr auto kBiasEnum = kHasBias
? ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS
: ck_tile::BlockAttentionBiasEnum::NO_BIAS;

constexpr bool kPadSeqLenQ = true;
constexpr bool kPadSeqLenK = true;

const bool pad_headdim_q =
!(param.K % FmhaBwdShape<MaxK>::kQKHeaddim == 0);
const bool pad_headdim_v =
!(param.Kv % FmhaBwdShape<MaxK>::kVHeaddim == 0);

BOOL_SWITCH_2(
pad_headdim_q, kPadHeadDimQ, pad_headdim_v, kPadHeadDimV, [&] {
using FmhaBwdTraits_ = ck_tile::TileFmhaTraits<
kPadSeqLenQ,
kPadSeqLenK,
kPadHeadDimQ,
kPadHeadDimV,
kBiasEnum,
kHasBiasGrad,
false, // kStoreLSE
false, // place-holder for kHasDropout, not actually used
false, // kDoFp8StaticQuant place-holder
occupancy>;

using FmhaBwdPipelineProblem =
FmhaBwdPipelineProblemTemp<FmhaBwdTraits_, FmhaMask>;

constexpr auto FmhaBwdPipelineEnum_ =
FmhaBwdPipelineEnumSelector<MaxK>::value;

using FmhaBwdPipeline_ = typename FmhaBwdPipelineMaker<
FmhaBwdPipelineEnum_,
FmhaBwdPipelineProblem>::pipeline;

using FmhaBwdKGradEpilogue_ =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<
typename FmhaBwdTypeConfig<ScalarType>::AccDataType,
typename FmhaBwdTypeConfig<ScalarType>::KGradDataType,
kPadSeqLenK,
kPadHeadDimQ>>;

using FmhaBwdVGradEpilogue_ =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<
typename FmhaBwdTypeConfig<ScalarType>::AccDataType,
typename FmhaBwdTypeConfig<ScalarType>::VGradDataType,
kPadSeqLenK,
kPadHeadDimV>>;

using FmhaBwdDQDKDVKernel_ = ck_tile::FmhaBwdDQDKDVKernel<
FmhaBwdPipeline_,
FmhaBwdKGradEpilogue_,
FmhaBwdVGradEpilogue_>;

RunWithBwdDQDKDVKernel<FmhaBwdDQDKDVKernel_>(param, stream);
});
};
if constexpr (NeedConvertGradQ) {
constexpr ck_tile::index_t kBlockSize = 256;
@@ -352,17 +347,17 @@

template <
typename ScalarType,
bool kHasCausalMask,
bool kHasMask,
bool kHasBias,
bool kHasBiasGrad,
bool kHasDropout,
ck_tile::index_t MaxK>
void run_batched_backward_causalmask_bias_dropout_dispatch(
void run_batched_backward_mask_bias_dropout_dispatch(
BatchedBackwardParams& param,
hipStream_t stream) {
batched_backward_causalmask_bias_dropout_dispatch<
batched_backward_mask_bias_dropout_dispatch<
ScalarType,
kHasCausalMask,
kHasMask,
kHasBias,
kHasBiasGrad,
kHasDropout,
@@ -25,16 +25,18 @@ void batched_backward_bf16(BatchedBackwardParams& param, hipStream_t stream) {
[&] {
if constexpr (kHasBias || !kHasBiasGrad) {
FMHA_BWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, [&] {
if (param.custom_mask_type == 0)
run_batched_backward_causalmask_bias_dropout_dispatch<
if (param.custom_mask_type == 0 && param.window_size <= 0)
run_batched_backward_mask_bias_dropout_dispatch<
ck_tile::bf16_t,
false,
kHasBias,
kHasBiasGrad,
kHasDropout,
MaxK>(param, stream);
else if (param.custom_mask_type == 1 || param.custom_mask_type == 2)
run_batched_backward_causalmask_bias_dropout_dispatch<
else if (
param.custom_mask_type == 1 || param.custom_mask_type == 2 ||
param.window_size > 0)
run_batched_backward_mask_bias_dropout_dispatch<
ck_tile::bf16_t,
true,
kHasBias,
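The new dispatch condition in batched_backward_bf16 (and, below, batched_backward_fp16) folds local attention into the same kHasMask template flag as the causal mask types. Reduced to a helper for illustration only; this function is not part of the sources.

```cpp
// Illustrative only: the condition used above to pick the masked kernel.
// custom_mask_type 0 means no causal mask, 1/2 are the causal variants,
// and window_size > 0 enables local (windowed) attention.
inline bool needs_mask_kernel(int custom_mask_type, int window_size) {
  return custom_mask_type == 1 || custom_mask_type == 2 || window_size > 0;
}
// kHasMask corresponds to needs_mask_kernel(...); the mask-free instantiation
// is chosen only when custom_mask_type == 0 and window_size <= 0.
```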
@@ -25,16 +25,18 @@ void batched_backward_fp16(BatchedBackwardParams& param, hipStream_t stream) {
[&] {
if constexpr (kHasBias || !kHasBiasGrad) {
FMHA_BWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, [&] {
if (param.custom_mask_type == 0)
run_batched_backward_causalmask_bias_dropout_dispatch<
if (param.custom_mask_type == 0 && param.window_size <= 0)
run_batched_backward_mask_bias_dropout_dispatch<
ck_tile::fp16_t,
false,
kHasBias,
kHasBiasGrad,
kHasDropout,
MaxK>(param, stream);
else if (param.custom_mask_type == 1 || param.custom_mask_type == 2)
run_batched_backward_causalmask_bias_dropout_dispatch<
else if (
param.custom_mask_type == 1 || param.custom_mask_type == 2 ||
param.window_size > 0)
run_batched_backward_mask_bias_dropout_dispatch<
ck_tile::fp16_t,
true,
kHasBias,
16 changes: 8 additions & 8 deletions xformers/csrc/attention/hip_fmha/ck_tiled_fmha_batched_forward.h
@@ -12,37 +12,37 @@

template <
typename ScalarType,
bool kHasCausalMask,
bool kHasMask,
bool kHasBias,
bool kHasDropout,
ck_tile::index_t MaxK>
void run_batched_forward_causalmask_bias_dropout_dispatch(
void run_batched_forward_mask_bias_dropout_dispatch(
BatchedForwardParams& param,
hipStream_t stream) {
// currently split-kv implementation does not support dropout
if constexpr (!kHasDropout) {
#ifndef FMHA_FWD_SPLITKV_NOT_USED
if (param.use_split_kv) {
FMHA_FWD_SEQLEN_Q_SWITCH(param.M, MaxSeqlenQ, [&] {
batched_forward_splitkv_causalmask_bias_dropout_dispatch<
batched_forward_splitkv_mask_bias_dropout_dispatch<
ScalarType,
kHasCausalMask,
kHasMask,
kHasBias,
MaxK,
MaxSeqlenQ>::Run(param, stream);
});
} else
#endif
batched_forward_causalmask_bias_dropout_dispatch<
batched_forward_mask_bias_dropout_dispatch<
ScalarType,
kHasCausalMask,
kHasMask,
kHasBias,
kHasDropout,
MaxK>::Run(param, stream);
} else {
batched_forward_causalmask_bias_dropout_dispatch<
batched_forward_mask_bias_dropout_dispatch<
ScalarType,
kHasCausalMask,
kHasMask,
kHasBias,
kHasDropout,
MaxK>::Run(param, stream);
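In the renamed forward dispatch above, the split-kv path is only compiled when the instantiation has no dropout (and FMHA_FWD_SPLITKV_NOT_USED is not defined); with dropout everything falls through to the regular kernel. Below is a stripped-down sketch of that structure, with placeholder launch functions standing in for the real ck_tile kernel dispatches.

```cpp
// Placeholder launchers; the real code instantiates ck_tile kernel types.
inline void launch_splitkv_kernel() {}
inline void launch_regular_kernel() {}

template <bool kHasDropout>
void route_batched_forward(bool use_split_kv) {
  if constexpr (!kHasDropout) {
#ifndef FMHA_FWD_SPLITKV_NOT_USED
    if (use_split_kv) {
      launch_splitkv_kernel();  // split-kv path (also serves Paged-KVCache)
      return;
    }
#endif
    launch_regular_kernel();
  } else {
    // the split-kv kernel does not support dropout, so dropout always
    // takes the regular kernel
    launch_regular_kernel();
  }
}
```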
@@ -17,15 +17,17 @@ void batched_forward_bf16(BatchedForwardParams& param, hipStream_t stream) {
const bool has_dropout = (param.dropout_prob > 0.0f);
BOOL_SWITCH_2(param.has_attn_bias, kHasBias, has_dropout, kHasDropout, [&] {
FMHA_FWD_HEADDIM_SWITCH(param.K, param.Kv, MaxK, [&] {
if (param.custom_mask_type == 0)
run_batched_forward_causalmask_bias_dropout_dispatch<
if (param.custom_mask_type == 0 && param.window_size <= 0)
run_batched_forward_mask_bias_dropout_dispatch<
ck_tile::bf16_t,
false,
kHasBias,
kHasDropout,
MaxK>(param, stream);
else if (param.custom_mask_type == 1 || param.custom_mask_type == 2)
run_batched_forward_causalmask_bias_dropout_dispatch<
else if (
param.custom_mask_type == 1 || param.custom_mask_type == 2 ||
param.window_size > 0)
run_batched_forward_mask_bias_dropout_dispatch<
ck_tile::bf16_t,
true,
kHasBias,