Decode and Prefill support (#3009)
Summary:
Pull Request resolved: #3009

X-link: facebookresearch/FBGEMM#104

This diff adds support for the Triton split-K attention kernel. It includes:
1/ prefill_varseq_attn and decode_attn
2/ a dequantize kernel
3/ fused quantization in the RoPE functions

TODO:
Dequantize + paged kv cache

Differential Revision: D60747287
Aya-ZIbra authored and facebook-github-bot committed Aug 18, 2024
1 parent 537aeb3 commit 0f4a11f
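The new trailing `qparam_k`/`qparam_v` arguments (see the updated schemas registered at the bottom of the diff below) let callers pass external scale/shift tensors for the quantized KV cache. As a rough illustration of how the prefill op might be invoked with fused FP8 quantization, here is a minimal Python sketch. The tensor shapes, dtypes, the 4-byte qparam row layout, and the `cache_logical_dtype_int=1` FP8 enum value are assumptions for illustration only, not taken from this diff, and the snippet presumes the fbgemm_gpu gen_ai extension is built and loaded.

```python
import torch

# Illustrative sizes (assumptions): batch, max context, Q/KV heads, head dim.
B, MAX_T, N_H, N_KVH, D_H = 2, 128, 8, 1, 128
T_PER_SEQ = 8
T = B * T_PER_SEQ  # total new tokens across the variable-length batch

xq = torch.randn(T, N_H, D_H, device="cuda", dtype=torch.bfloat16)
xk = torch.randn(T, N_KVH, D_H, device="cuda", dtype=torch.bfloat16)
xv = torch.randn(T, N_KVH, D_H, device="cuda", dtype=torch.bfloat16)

# FP8 KV cache stored as raw bytes; this layout is an assumption.
cache_k = torch.zeros(B, MAX_T, N_KVH, D_H, device="cuda", dtype=torch.uint8)
cache_v = torch.zeros_like(cache_k)

# Per-row quantization parameters written by the fused RoPE+quantize path;
# the 4-bytes-per-row layout is an assumption.
qparam_k = torch.zeros(B, MAX_T, N_KVH, 4, device="cuda", dtype=torch.uint8)
qparam_v = torch.zeros_like(qparam_k)

# Which sequence each token belongs to, and its position in that sequence.
varseq_batch = torch.repeat_interleave(
    torch.arange(B, device="cuda", dtype=torch.int32), T_PER_SEQ)
varseq_seqpos = torch.arange(
    T_PER_SEQ, device="cuda", dtype=torch.int32).repeat(B)

xq_out = torch.ops.fbgemm.rope_qkv_varseq_prefill(
    xq, xk, xv, cache_k, cache_v, varseq_batch, varseq_seqpos,
    theta=10000.0,
    cache_logical_dtype_int=1,  # assumed FP8 enum value
    qparam_k=qparam_k,          # new in this diff
    qparam_v=qparam_v,          # new in this diff
)
```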
Showing 2 changed files with 138 additions and 153 deletions.
140 changes: 19 additions & 121 deletions fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp
@@ -44,7 +44,9 @@ at::Tensor rope_qkv_varseq_prefill(
int64_t old_context_len,
double scaling_factor,
double lo_freq_factor,
double hi_freq_factor);
double hi_freq_factor,
std::optional<at::Tensor> qparam_k,
std::optional<at::Tensor> qparam_v);

at::Tensor rope_qkv_decoding(
at::Tensor XQ,
@@ -65,7 +67,9 @@ at::Tensor rope_qkv_decoding(
int64_t old_context_len,
double scaling_factor,
double lo_freq_factor,
double hi_freq_factor);
double hi_freq_factor,
std::optional<at::Tensor> qparam_k,
std::optional<at::Tensor> qparam_v);

at::Tensor xpos_qkv_varseq_prefill(
at::Tensor XQ,
@@ -88,7 +92,9 @@ at::Tensor xpos_qkv_varseq_prefill(
int64_t old_context_len,
double scaling_factor,
double lo_freq_factor,
double hi_freq_factor);
double hi_freq_factor,
std::optional<at::Tensor> qparam_k,
std::optional<at::Tensor> qparam_v);

at::Tensor xpos_qkv_decoding(
at::Tensor XQ,
@@ -112,7 +118,9 @@ at::Tensor xpos_qkv_decoding(
int64_t old_context_len,
double scaling_factor,
double lo_freq_factor,
double hi_freq_factor);
double hi_freq_factor,
std::optional<at::Tensor> qparam_k,
std::optional<at::Tensor> qparam_v);

std::tuple<at::Tensor, at::Tensor> dequantize_int4_cache(
at::Tensor cache_K,
@@ -121,121 +129,11 @@ std::tuple<at::Tensor, at::Tensor> dequantize_int4_cache(
std::optional<int64_t> num_groups);

std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
at::Tensor cache_K,
at::Tensor cache_V,
at::Tensor kv_seqlen);

at::Tensor mqa_attn(
at::Tensor XQ,
at::Tensor cache_K,
at::Tensor cache_V,
at::Tensor seq_positions,
double qk_scale,
std::optional<int64_t> num_groups,
int64_t cache_logical_dtype_int);

#define DEFAULT_PAGE_SIZE 64
#define STRING_(s) #s
#define STRING(x) STRING_(x)

at::Tensor rope_qkv_varseq_prefill(
at::Tensor XQ,
at::Tensor XK,
at::Tensor XV,
at::Tensor cache_K,
at::Tensor cache_V,
at::Tensor varseq_batch,
at::Tensor varseq_seqpos,
double theta,
std::optional<int64_t> num_groups,
std::optional<at::Tensor> block_tables,
int64_t page_size,
std::optional<at::Tensor> varseq_cache_seqpos,
int64_t cache_logical_dtype_int,
bool rope_scaling,
int64_t old_context_len,
double scaling_factor,
double lo_freq_factor,
double hi_freq_factor);

at::Tensor rope_qkv_decoding(
at::Tensor XQ,
at::Tensor XK,
at::Tensor XV,
at::Tensor cache_K,
at::Tensor cache_V,
at::Tensor seqpos,
double theta,
std::optional<int64_t> num_groups,
std::optional<at::Tensor> block_tables,
int64_t page_size,
std::optional<at::Tensor> actual_batch_size,
std::optional<at::Tensor> batch,
std::optional<at::Tensor> cache_seqpos,
int64_t cache_logical_dtype_int,
bool rope_scaling,
int64_t old_context_len,
double scaling_factor,
double lo_freq_factor,
double hi_freq_factor);

at::Tensor xpos_qkv_varseq_prefill(
at::Tensor XQ,
at::Tensor XK,
at::Tensor XV,
at::Tensor cache_K,
at::Tensor cache_V,
at::Tensor varseq_batch,
at::Tensor varseq_seqpos,
double theta,
double gamma,
double scale_base,
double exponent_offset,
std::optional<int64_t> num_groups,
std::optional<at::Tensor> block_tables,
int64_t page_size,
std::optional<at::Tensor> varseq_cache_seqpos,
int64_t cache_logical_dtype_int,
bool rope_scaling,
int64_t old_context_len,
double scaling_factor,
double lo_freq_factor,
double hi_freq_factor);

at::Tensor xpos_qkv_decoding(
at::Tensor XQ,
at::Tensor XK,
at::Tensor XV,
at::Tensor cache_K,
at::Tensor cache_V,
at::Tensor seqpos,
double theta,
double gamma,
double scale_base,
double exponent_offset,
std::optional<int64_t> num_groups,
std::optional<at::Tensor> block_tables,
int64_t page_size,
std::optional<at::Tensor> actual_batch_size,
std::optional<at::Tensor> batch,
std::optional<at::Tensor> cache_seqpos,
int64_t cache_logical_dtype_int,
bool rope_scaling,
int64_t old_context_len,
double scaling_factor,
double lo_freq_factor,
double hi_freq_factor);

std::tuple<at::Tensor, at::Tensor> dequantize_int4_cache(
at::Tensor cache_K,
at::Tensor cache_V,
at::Tensor kv_seqlen,
std::optional<int64_t> num_groups);

std::tuple<at::Tensor, at::Tensor> dequantize_fp8_cache(
at::Tensor cache_K,
at::Tensor cache_V,
at::Tensor kv_seqlen);
std::optional<at::Tensor> qparam_k,
std::optional<at::Tensor> qparam_v);

at::Tensor mqa_attn(
at::Tensor XQ,
@@ -248,23 +146,23 @@ at::Tensor mqa_attn(

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def("rope_qkv_varseq_prefill(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, float theta, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32) -> Tensor");
DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None) -> Tensor");
m.impl("rope_qkv_varseq_prefill", rope_qkv_varseq_prefill);
m.def("rope_qkv_decoding(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, float theta, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32) -> Tensor");
DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None) -> Tensor");
m.impl("rope_qkv_decoding", rope_qkv_decoding);
m.def("xpos_qkv_varseq_prefill(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, float theta, float gamma, float scale_base, float exponent_offset, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32) -> Tensor");
DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None) -> Tensor");
m.impl("xpos_qkv_varseq_prefill", xpos_qkv_varseq_prefill);
m.def("xpos_qkv_decoding(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, float theta, float gamma, float scale_base, float exponent_offset, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32) -> Tensor");
DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None) -> Tensor");
m.impl("xpos_qkv_decoding", xpos_qkv_decoding);

m.def(
"dequantize_int4_cache(Tensor cache_K, Tensor cache_V, Tensor kv_seqlen, int? num_groups=1) -> (Tensor, Tensor)");
m.impl("dequantize_int4_cache", dequantize_int4_cache);
m.def(
"dequantize_fp8_cache(Tensor cache_K, Tensor cache_V, Tensor kv_seqlen) -> (Tensor, Tensor)");
"dequantize_fp8_cache(Tensor cache_K, Tensor cache_V, Tensor kv_seqlen, Tensor? qparam_k=None, Tensor? qparam_v=None) -> (Tensor, Tensor)");
m.impl("dequantize_fp8_cache", dequantize_fp8_cache);
}
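
The `dequantize_fp8_cache` schema above likewise gains optional `qparam_k`/`qparam_v` arguments. A brief sketch of a call site, reusing the illustrative cache and qparam tensors from the earlier snippet (again, shapes and dtypes are assumptions, not taken from this diff):

```python
# Number of valid tokens per cached sequence (illustrative).
kv_seqlen = torch.full((B,), T_PER_SEQ, device="cuda", dtype=torch.int32)

# With qparam_k/qparam_v supplied, scales come from the external tensors;
# when they are left as None, the kernel presumably reads qparams packed
# inline in the cache rows.
k_bf16, v_bf16 = torch.ops.fbgemm.dequantize_fp8_cache(
    cache_k, cache_v, kv_seqlen, qparam_k=qparam_k, qparam_v=qparam_v)
```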

