llama : add flash attention (demo)
ggerganov committed Apr 5, 2023
Commit 36ddd12 (1 parent: 986b6ce)
Showing 1 changed file with 27 additions and 0 deletions.
llama.cpp: 27 additions, 0 deletions
@@ -28,6 +28,8 @@
#define LLAMA_USE_SCRATCH
#define LLAMA_MAX_SCRATCH_BUFFERS 16

#define LLAMA_USE_FLASH_ATTN
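// when defined, the attention block in llama_eval_internal uses the fused ggml_flash_attn op instead of the explicit attention graph built in the #else branch below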

#define LLAMA_ASSERT(x) \
do { \
if (!(x)) { \
@@ -829,6 +831,30 @@ static bool llama_eval_internal(
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
}

#ifdef LLAMA_USE_FLASH_ATTN
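// Q: current queries copied into an F16 tensor and permuted to (n_embd/n_head, N, n_head) — one (head_dim x N) matrix per head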
struct ggml_tensor * Q =
ggml_permute(ctx0,
ggml_cpy(ctx0,
Qcur,
ggml_new_tensor_3d(ctx0, GGML_TYPE_F16, n_embd/n_head, n_head, N)),
0, 2, 1, 3);

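// K: this layer's keys viewed from the KV cache and permuted to (n_embd/n_head, n_past + N, n_head)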
struct ggml_tensor * K =
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
n_embd/n_head, n_head, n_past + N),
0, 2, 1, 3);

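// V: 3D view of this layer's value cache with shape (n_past + N, n_embd/n_head, n_head) — tokens run along the first dimension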
struct ggml_tensor * V =
ggml_view_3d(ctx0, kv_self.v,
n_past + N, n_embd/n_head, n_head,
n_ctx*ggml_element_size(kv_self.v),
n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head,
il*n_ctx*ggml_element_size(kv_self.v)*n_embd);

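// fused masked scaled-dot-product attention over Q, K and V; the trailing `true` enables the causal mask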
struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, true);
#else
struct ggml_tensor * Q =
ggml_permute(ctx0,
Qcur,
@@ -872,6 +898,7 @@ static bool llama_eval_internal(
// is there a better way?
struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd/n_head, n_head));
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
#endif
#endif

// KQV_merged = KQV.permute(0, 2, 1, 3)
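For context, the single ggml_flash_attn node above replaces the conventional attention chain built by the #else branch. Below is a minimal sketch of that unfused path, reconstructed from the ggml ops of the period (ggml_scale, ggml_diag_mask_inf, ggml_soft_max and ggml_new_f32 are assumed here, since the collapsed lines are not shown in the diff); treat it as an approximation rather than the exact hidden code:

// KQ = K^T * Q for every head: shape (n_past + N, N, n_head)
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

// scale by 1/sqrt(head_dim)
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));

// causal mask: future positions are set to -inf before the softmax
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);

// row-wise softmax produces the attention weights
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);

// weighted sum of the values; V is made contiguous first, as in the #else branch
struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd/n_head, n_head));
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);

Fusing this chain into one ggml_flash_attn node avoids materializing KQ, KQ_scaled, KQ_masked and KQ_soft_max as separate graph tensors.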
