FA3 kvcache + split kv + gqa parallelization #1236

Merged: 106 commits, Oct 15, 2024
Changes from 1 commit
Commits (106)
9dbd114
Adding the flash3 kv cache API. Just compiling for now.
ganeshcolfax Aug 15, 2024
7ee8ee4
start extending seqlen traits for kv cache
jayhshah Aug 15, 2024
3fdf7ee
added cache_batch_idx.
ganeshcolfax Aug 15, 2024
84e31c2
adding python interface.
ganeshcolfax Aug 15, 2024
e16053a
add test_kvcache.py.
ganeshcolfax Aug 15, 2024
be0e36d
enable use of actual seqlen for kv cache
jayhshah Aug 16, 2024
38ad0ac
add new param to handle cache_batch_size
jayhshah Aug 16, 2024
57de4da
add semaphore for kv cache causal
jayhshah Aug 16, 2024
435f86d
add comparison with fa2.
ganeshcolfax Aug 16, 2024
74f160b
change template parameter for SeqLenTraits for ease of further extension
jayhshah Aug 19, 2024
13bad55
modify seqlentraits for gqa parallelism
jayhshah Aug 19, 2024
ccf5b9b
modify Ktraits for decoding QO layouts
jayhshah Aug 20, 2024
fc8f704
decouple types of seqlen traits q and k
jayhshah Aug 20, 2024
d2f049c
change logic of Q loads for gqa parallelization
jayhshah Aug 20, 2024
c6311e4
fix o strides
jayhshah Aug 20, 2024
535b827
complete gqa parallel changes for non-causal
jayhshah Aug 22, 2024
5704a1f
fix some errors
jayhshah Aug 22, 2024
64a9cfb
add causal logic
jayhshah Aug 26, 2024
a06f1f9
add to kv cache api
jayhshah Aug 26, 2024
0c4cea9
add in lse writeout and store zero
jayhshah Sep 4, 2024
0a1a0c2
refactor for split kv
jayhshah Sep 5, 2024
1135dbd
re-enable fp16/bf16 fwd
jayhshah Sep 5, 2024
68ff3f7
add 1 mma warpgroup option, enable splitkv for hdim 256
jayhshah Sep 6, 2024
23bf5b0
fix bug with finalize for split kv
jayhshah Sep 12, 2024
ac19795
delete unused files
jayhshah Sep 17, 2024
1a5e40a
add hid=64.
ganeshcolfax Sep 13, 2024
c75c243
change flash api for rebase
jayhshah Sep 18, 2024
e9db102
avoid redundant compilation with combine kernel by only including nee…
jayhshah Sep 19, 2024
9250969
change Element to OutputType for template param in combine kernel. On…
jayhshah Sep 19, 2024
ec0130f
fix wrong tile size for hdim 64
jayhshah Sep 19, 2024
68b4bb9
revert OutputType change
jayhshah Sep 19, 2024
9c97808
changes for correct lse write out for splits=1 and splits > 1 case.
ganeshcolfax Sep 19, 2024
ecc5c49
added num_split_heuristics.
ganeshcolfax Sep 20, 2024
f3e5bd4
update parameters
jayhshah Sep 20, 2024
78736b4
remove unused code
jayhshah Sep 20, 2024
e8c7b2e
add num_split_heuristics.
ganeshcolfax Sep 20, 2024
75a6ce2
adding block_n and block_m for different headdim.
ganeshcolfax Sep 21, 2024
cf5bd5c
initialize semaphore when num splits != 1
jayhshah Sep 21, 2024
ac96c37
change combine kernel to condition on lse=-INF
jayhshah Sep 23, 2024
099ca28
add gqa decoding logic.
ganeshcolfax Sep 21, 2024
f53703b
add split kv heuristic modifications
jayhshah Sep 25, 2024
a30863f
recent version.
ganeshcolfax Sep 24, 2024
ffa48eb
more refactoring.
ganeshcolfax Sep 25, 2024
4a77193
update test script to use heuristic
jayhshah Sep 25, 2024
3615696
Add some more cases to test script and raise thresholds a bit for max…
jayhshah Sep 25, 2024
24b4b4f
add reference from python.
ganeshcolfax Sep 25, 2024
c516d63
Adding another test case.
ganeshcolfax Sep 25, 2024
9a4941c
add variable seqlen case.
ganeshcolfax Sep 26, 2024
70ff847
all cases passed.
ganeshcolfax Sep 26, 2024
e36e004
change fp8 code path to allow for split kernel and kv cache without p…
jayhshah Sep 27, 2024
cd55fb3
set correct tolerance limit
ganeshcolfax Sep 26, 2024
2472e5e
add 'in principle' fp8 kv cache support
jayhshah Sep 28, 2024
b5cac6d
rebase with Is_local disabled temporarily
jayhshah Sep 30, 2024
6111666
consolidate nblock min max methods
jayhshah Sep 30, 2024
be481ca
add Is_local back in
jayhshah Sep 30, 2024
81d4024
prune unused code
jayhshah Sep 30, 2024
64a0a91
enable Is_local with fp8
jayhshah Sep 30, 2024
0f560b7
update composable kernel
jayhshah Sep 30, 2024
cffef15
separate out fp8 in test_flash_attn
jayhshah Sep 30, 2024
2b840ef
fix the test case and re-factor too.
ganeshcolfax Sep 30, 2024
5df67d2
Merge branch 'fa3-kvcache-gqa' of github.com:Dao-AILab/flash-attentio…
ganeshcolfax Sep 30, 2024
aa45d75
dont write out zero for split kernel, only lse=-inf
jayhshah Sep 30, 2024
33f20a3
Merge branch 'fa3-kvcache-gqa' of github.com:Dao-AILab/flash-attentio…
jayhshah Sep 30, 2024
f77d9f7
fix composable kernel issue again
jayhshah Sep 30, 2024
5e3864f
change default output type of fp8 kernel to bf16
jayhshah Oct 1, 2024
16eb1e5
remove deprecated fp8 code
jayhshah Oct 1, 2024
7940377
correct indent
jayhshah Oct 1, 2024
31c71e0
add log max splits based on num splits to static switch
jayhshah Oct 1, 2024
6bb1092
change seq len class per discussion
jayhshah Oct 1, 2024
0085f04
add fp8 test case.
ganeshcolfax Oct 2, 2024
eaf8898
fix submodule
jayhshah Oct 2, 2024
a44596f
re-commiting.
ganeshcolfax Oct 2, 2024
c0c58ee
Revert "re-commiting."
ganeshcolfax Oct 2, 2024
b8f9dc2
change fp8 tolerances to be smaller
jayhshah Oct 2, 2024
49f1849
lower rtol for fp8 a bit
jayhshah Oct 2, 2024
bb230b8
separate gqa compilation
jayhshah Oct 2, 2024
03200a7
removed old gqa cu files and unified methods
jayhshah Oct 2, 2024
930c8ca
reorg mma code for less redundancy
jayhshah Oct 3, 2024
bc4b872
add crude hdim 64 heuristic
jayhshah Oct 3, 2024
fff4b5c
add split kv benchmark script
jayhshah Oct 4, 2024
aa0e699
move descale tensor declarations outside of conditional
jayhshah Oct 4, 2024
785d978
fix bug with fp8 q layout
jayhshah Oct 7, 2024
8fbefa8
adding rmem to gmem. (Not validating yet).
ganeshcolfax Oct 8, 2024
f0b4946
changes to use tiledcopy (still not passing).
ganeshcolfax Oct 8, 2024
8f45a8c
tests passing now for non-gqa impl
jayhshah Oct 9, 2024
4a4dbd2
move IsRegToGmem
jayhshah Oct 9, 2024
a075e76
handle gqa_parallel with rmem-to-gmem. Not validating yet.
ganeshcolfax Oct 10, 2024
dc2c952
compiles and builds. Not validating yet.
ganeshcolfax Oct 10, 2024
e49cb5f
passes except for hdim=256.
ganeshcolfax Oct 10, 2024
d437d3d
remove smem usage for when rmem -> gmem epilogue is used
jayhshah Oct 11, 2024
ab5d336
better writeout logic with vectorization
jayhshah Oct 12, 2024
7169b23
unify rmem -> gmem methods
jayhshah Oct 12, 2024
551b91f
uniform notation
jayhshah Oct 12, 2024
eb9c0ee
add rmem -> gmem for fp8
jayhshah Oct 14, 2024
b0f067e
revert epi change for fp8 due to measured perf regression
jayhshah Oct 14, 2024
35f3542
refactor names
jayhshah Oct 14, 2024
8374e1f
remove test code
jayhshah Oct 14, 2024
1ecf821
remove constexpr checks for actual seqlen in mainloop
jayhshah Oct 14, 2024
7c1473e
remove Is_batch_dynamic from seqlen traits and handle fp8 perf regres…
jayhshah Oct 15, 2024
c06cc0b
change cu_seqlens_k to seqused_k for kv cache api
jayhshah Oct 15, 2024
a7cce59
adjust tolerances in test script for kv cache
jayhshah Oct 15, 2024
8efb953
remove commented out code
jayhshah Oct 15, 2024
b3d60fa
prune more dead code
jayhshah Oct 15, 2024
50cb90a
comment out unimplemented kwargs from flash_attn_with_kvcache
jayhshah Oct 15, 2024
dec7dee
fix integer sign compare warning
jayhshah Oct 15, 2024
9b6cba1
remove some debug code
jayhshah Oct 15, 2024
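
Several of the commits above (the lse write-out, "fix bug with finalize for split kv", "change combine kernel to condition on lse=-INF") concern the split-KV combine step. The following is a minimal sketch of the underlying arithmetic only, not the PR's combine kernel, and every name in it (combine_splits, out_partial, lse_partial) is invented for illustration. Each split is assumed to produce a locally normalized partial output plus the log-sum-exp of its attention scores, and the partial results are merged per query row and head as follows.

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Sketch only: merge per-split partial attention results for one (query row, head).
// out_partial[s] is split s's locally normalized output, lse_partial[s] its log-sum-exp.
void combine_splits(const std::vector<std::vector<float>>& out_partial,  // [num_splits][head_dim]
                    const std::vector<float>& lse_partial,               // [num_splits]
                    std::vector<float>& out, float& lse) {
    const float neg_inf = -std::numeric_limits<float>::infinity();
    const int num_splits = static_cast<int>(lse_partial.size());
    const int head_dim = static_cast<int>(out_partial[0].size());
    // Global log-sum-exp over the per-split LSEs, computed in a numerically stable way.
    float lse_max = neg_inf;
    for (float l : lse_partial) lse_max = std::max(lse_max, l);
    float sum = 0.f;
    for (float l : lse_partial) {
        if (l != neg_inf) sum += std::exp(l - lse_max);
    }
    lse = (sum > 0.f) ? lse_max + std::log(sum) : neg_inf;
    // Rescale and accumulate. Splits whose KV range was entirely masked carry
    // lse = -inf and are skipped instead of contributing zeros.
    out.assign(head_dim, 0.f);
    for (int s = 0; s < num_splits; ++s) {
        if (lse_partial[s] == neg_inf) continue;
        const float scale = std::exp(lse_partial[s] - lse);
        for (int d = 0; d < head_dim; ++d) out[d] += scale * out_partial[s][d];
    }
}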
prune more dead code
jayhshah committed Oct 15, 2024
commit b3d60fa3a56ff58d3c2b1f27177e990573e3621d
19 changes: 8 additions & 11 deletions hopper/epilogue_fwd_sm90_tma.hpp
@@ -112,7 +112,6 @@ struct CollectiveEpilogueFwd {
Stride<_4, _32, _1, _0>>;
using ValueLayoutrO = Layout<Shape<_1, _2, Shape<_2, _2>, Int<kHeadDim/16>>,
Stride<_0, _2, Stride<_4, _1>, _8>>;
// using AccessTyperO = std::conditional_t<cutlass::sizeof_bits_v<Element> == 16, uint16_t, uint32_t>;
using TiledCopyrO = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<Element>, Element>{},
ThreadLayoutrO{}, ValueLayoutrO{}));
using TiledCopyShaperO = Shape<_8, Int<kBlockM/8>, _16, Int<kHeadDim/16>>;
@@ -248,22 +247,20 @@ struct CollectiveEpilogueFwd {
}
}
}

int write_warp_idx = kNWarps - 1;
if constexpr(!No_smem_O) {
if (cutlass::canonical_warp_idx_sync() == write_warp_idx) {
cutlass::arch::NamedBarrier::sync(
NumMmaThreads + cutlass::NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier
);
}
}

if constexpr (No_smem_O) {
flash::write_rmem_to_gmem<Seqlen_traits::UseGQAPacking, epi_column_permute>(
tOrO_out, epilogue_params.ptr_O, epilogue_params.layout_O, TileShapeOCopy{},
m_block, h_block, bidh, bidh_kv, bidb, n_split_idx,
tiled_mma, seqlen_traits_q, thread_idx);
} else {
int write_warp_idx = kNWarps - 1;
if (cutlass::canonical_warp_idx_sync() == write_warp_idx) {
cutlass::arch::NamedBarrier::sync(
NumMmaThreads + cutlass::NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier
);
}
TiledCopyO gmem_tiled_copy_O;
Tensor sO_out = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutOCopy{});
if constexpr(!Seqlen_traits::UseGQAPacking) {
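
Note on the hunk above: the epilogue now keeps the named-barrier sync and the shared-memory TiledCopyO path only in the else branch, while the No_smem_O path writes the accumulator from registers straight to global memory via flash::write_rmem_to_gmem. Below is a rough sketch of that register-to-gmem idea, assuming a deliberately simplified per-thread fragment layout rather than the real CuTe tiled copy.

// Illustrative only: each thread stores its accumulator fragment directly into
// its slots of the CTA's output tile in global memory, so no shared-memory
// staging buffer (and no epilogue barrier) is required on this path.
template <int kFragSize>
__device__ void write_rmem_to_gmem_sketch(const float (&acc)[kFragSize],
                                          float* o_tile,       // base of this CTA's output tile
                                          int row, int col,    // first element owned by this thread
                                          int row_stride) {
    #pragma unroll
    for (int i = 0; i < kFragSize; ++i) {
        // A real MMA fragment is strided across rows and columns; a contiguous
        // run per thread is assumed here purely for readability.
        o_tile[row * row_stride + col + i] = acc[i];
    }
}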
7 changes: 1 addition & 6 deletions hopper/flash_fwd_kernel.h
@@ -33,8 +33,6 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
) {

using Element = typename Ktraits::Element;
using ElementAccum = typename Ktraits::ElementAccum;
using SoftType = ElementAccum;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;

@@ -47,7 +45,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
static constexpr int kBlockM = Ktraits::kBlockM;
static constexpr int kBlockH = Ktraits::kBlockH;
// static constexpr int kBlockN = Ktraits::kBlockN;
// constexpr int kHeadDim = Ktraits::kHeadDim;
// static constexpr int kHeadDim = Ktraits::kHeadDim;

using CollectiveMainloop = CollectiveMainloopFwd<Ktraits, Is_causal, Is_local, Seqlen_traits, Seqlen_traits_Q>;
using CollectiveEpilogue = CollectiveEpilogueFwd<Ktraits, Seqlen_traits_Q>;
@@ -222,15 +220,12 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,

using Element = typename Ktraits::Element;
static_assert(cutlass::sizeof_bits_v<Element> == 8);
using ElementAccum = typename Ktraits::ElementAccum;
using SoftType = ElementAccum;
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;

static_assert(Ktraits::Is_WS);
static constexpr bool Is_WS = Ktraits::Is_WS;
static constexpr bool No_smem_O = Ktraits::No_smem_O;
// static constexpr bool UseVarSeqLen = Seqlen_traits::UseVarSeqLen;

static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup;
6 changes: 2 additions & 4 deletions hopper/flash_fwd_launch_template.h
@@ -108,11 +108,9 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {

int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM/Kernel_traits::kBlockH);
num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{});
int num_grid_heads = params.h_k * ceil_div(params.h_h_k_ratio, Kernel_traits::kBlockH);

// std::cout << "num blocks m = " << num_blocks_m << " num grid heads" << num_grid_heads << std::endl;
int num_blocks_h = params.h_k * ceil_div(params.h_h_k_ratio, Kernel_traits::kBlockH);
typename Scheduler::Arguments scheduler_args =
{num_blocks_m, Is_split ? params.num_splits : 1, num_grid_heads, params.b, params.tile_count_semaphore};
{num_blocks_m, Is_split ? params.num_splits : 1, num_blocks_h, params.b, params.tile_count_semaphore};
typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args);

// Get the ptr to kernel function.
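
For reference, the scheduler arguments above decompose the work grid along M blocks, KV splits, head blocks, and batch. A host-side sketch of that arithmetic follows; decode_grid and its example values are illustrative and not part of the PR, and the cluster-shape rounding of num_blocks_m is omitted.

#include <cstdio>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

struct GridShape { int blocks_m, splits, blocks_h, batch; };

// Mirrors num_blocks_m / num_blocks_h above: with GQA packing, kBlockH query
// heads that share one KV head are folded into the M dimension of a tile.
GridShape decode_grid(int seqlen_q, int h, int h_k, int b,
                      int kBlockM, int kBlockH, int num_splits, bool is_split) {
    int blocks_m = ceil_div(seqlen_q, kBlockM / kBlockH);
    int blocks_h = h_k * ceil_div(h / h_k, kBlockH);   // h / h_k is params.h_h_k_ratio
    return {blocks_m, is_split ? num_splits : 1, blocks_h, b};
}

int main() {
    // Example: decode (seqlen_q = 1), 32 query heads, 4 KV heads, kBlockH = 8.
    GridShape g = decode_grid(1, 32, 4, /*b=*/2, /*kBlockM=*/128, /*kBlockH=*/8,
                              /*num_splits=*/4, /*is_split=*/true);
    std::printf("%d x %d x %d x %d tiles\n", g.blocks_m, g.splits, g.blocks_h, g.batch);
}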
14 changes: 5 additions & 9 deletions hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp
@@ -86,11 +86,11 @@ struct CollectiveMainloopFwd {
using TileShape_MNK = typename Ktraits::TileShape_MNK;
using ClusterShape = typename Ktraits::ClusterShape_MNK;

static constexpr int kStages = Ktraits::kStages;
static constexpr int kHeadDim = Ktraits::kHeadDim;
// static constexpr int kBlockM = Ktraits::kBlockM;
// static constexpr int kBlockN = Ktraits::kBlockN;
// static constexpr int kBlockH = Ktraits::kBlockH;
static constexpr int kStages = Ktraits::kStages;
static constexpr int kHeadDim = Ktraits::kHeadDim;
// static constexpr int kBlockM = Ktraits::kBlockM;
// static constexpr int kBlockN = Ktraits::kBlockN;
// static constexpr int kBlockH = Ktraits::kBlockH;
static constexpr bool Is_split = Ktraits::Is_split;
static constexpr bool No_smem_O = Ktraits::No_smem_O;

@@ -250,7 +250,6 @@ struct CollectiveMainloopFwd {
n_block_max = cute::ceil_div(seqlen_k, kBlockN);

if constexpr(Is_split) {
// int const num_n_blocks = ceil_div(seqlen_k, kBlockN);
int const n_blocks_per_split
= mainloop_params.num_splits_divmod.divide(n_block_max + int(mainloop_params.num_splits_divmod) - 1);
n_block_min = n_split_idx * n_blocks_per_split;
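
To make the split arithmetic above concrete, assuming (as the surrounding code suggests) that each split then walks the contiguous range [n_block_min, n_block_min + n_blocks_per_split) clipped to the overall n_block_max, a small worked example:

#include <algorithm>
#include <cstdio>

int main() {
    const int seqlen_k = 4096, kBlockN = 128, num_splits = 3;
    const int total_blocks = (seqlen_k + kBlockN - 1) / kBlockN;                  // 32 KV blocks
    const int n_blocks_per_split = (total_blocks + num_splits - 1) / num_splits;  // ceil(32 / 3) = 11
    for (int split = 0; split < num_splits; ++split) {
        const int n_block_min = split * n_blocks_per_split;
        const int n_block_max = std::min(total_blocks, n_block_min + n_blocks_per_split);
        std::printf("split %d covers KV blocks [%d, %d)\n", split, n_block_min, n_block_max);
    }
}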
@@ -360,7 +359,6 @@ struct CollectiveMainloopFwd {
}
}

// int n_block_max = get_n_block_max(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
int n_block = n_block_max - 1;

int lane_predicate = cute::elect_one_sync();
@@ -498,7 +496,6 @@ struct CollectiveMainloopFwd {
}
}

// int n_block_max = get_n_block_max(mainloop_params, m_block, seqlen_traits_q, seqlen_traits_k);
int n_block = n_block_max - 1;

int lane_predicate = cute::elect_one_sync();
@@ -763,7 +760,6 @@ struct CollectiveMainloopFwd {
Tensor scores_scale = make_fragment_like(softmax.row_max);
clear(scores_scale);

// TODO: modify this for split kv to eliminate superfluous masking steps
constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM_div_H, kBlockN) + 1;
// Only go through these if Is_causal, since n_masking_steps = 1 when !Is_causal
#pragma unroll
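
One more note on the masking-step count above: for causal attention, only the KV blocks that straddle the diagonal of a query tile need element-wise masking, and after GQA packing a tile appears to span kBlockM / kBlockH rows of the query sequence, which gives the bound ceil_div(kBlockM_div_H, kBlockN) + 1. A tiny check of that formula, with purely illustrative sizes:

#include <cstdio>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
    const int kBlockM_div_H = 64;   // query rows per tile after GQA packing (assumed example size)
    const int kBlockN = 128;        // keys per KV block (assumed example size)
    const bool is_causal = true;
    const int n_masking_steps = !is_causal ? 1 : ceil_div(kBlockM_div_H, kBlockN) + 1;
    std::printf("n_masking_steps = %d\n", n_masking_steps);   // prints 2 for these sizes
}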