From d29faeed313a05fff47acdf42e38cf4ee6ee3bb0 Mon Sep 17 00:00:00 2001
From: vansangpfiev <vansangpfiev@gmail.com>
Date: Wed, 5 Jun 2024 16:55:00 +0700
Subject: [PATCH] feat: add cache_type parameter (#75)

* feat: expose cache_type_k and cache_type_v parameters

* fix: bugfix

* fix: validate cache type

---------

Co-authored-by: vansangpfiev <sang@jan.ai>
---
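Usage note (informational only; git am ignores text between the --- marker and the diffstat): a minimal sketch of a load-model request body that exercises the new parameter. Only the cache_type, flash_attn, and caching_enabled keys are taken from what LoadModelImpl reads in this diff; the remaining fields of a real load request (model path, context size, etc.) are omitted, and the jsoncpp calls are plain Json::Value usage rather than engine API.

    // Sketch only: build a JSON load body using the new cache_type key.
    // Accepted values per IsValidCacheType: "f16" (default), "q8_0", "q4_0";
    // anything else logs a warning and falls back to f16.
    #include <json/json.h>
    #include <iostream>

    int main() {
      Json::Value body;
      body["cache_type"] = "q8_0";   // quantize the KV cache
      body["flash_attn"] = true;     // implied anyway when cache_type != f16
      body["caching_enabled"] = true;
      std::cout << body.toStyledString();
      return 0;
    }

Requesting q8_0 or q4_0 force-enables Flash Attention even when flash_attn is absent or false, matching the force_enable_fa logic in the diff; cache_type_v is always set to the same value as cache_type_k.
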
 src/llama_engine.cc | 29 +++++++++++++++++++++++++++--
 1 file changed, 27 insertions(+), 2 deletions(-)

diff --git a/src/llama_engine.cc b/src/llama_engine.cc
index eb705d6..4c89379 100644
--- a/src/llama_engine.cc
+++ b/src/llama_engine.cc
@@ -11,6 +11,17 @@ constexpr const int k400BadRequest = 400;
 constexpr const int k409Conflict = 409;
 constexpr const int k500InternalServerError = 500;
 
+constexpr const auto kTypeF16 = "f16";
+constexpr const auto kType_Q8_0 = "q8_0";
+constexpr const auto kType_Q4_0 = "q4_0";
+
+bool IsValidCacheType(const std::string& c) {
+  if (c != kTypeF16 && c != kType_Q8_0 && c != kType_Q4_0) {
+    return false;
+  }
+  return true;
+}
+
 struct InferenceState {
   int task_id;
   LlamaServerContext& llama;
@@ -341,13 +352,27 @@ bool LlamaEngine::LoadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
             .asInt();
     params.cont_batching = jsonBody->get("cont_batching", false).asBool();
 
+    params.cache_type_k = jsonBody->get("cache_type", kTypeF16).asString();
+    if (!IsValidCacheType(params.cache_type_k)) {
+      LOG_WARN << "Unsupported cache type: " << params.cache_type_k
+               << ", falling back to f16";
+      params.cache_type_k = kTypeF16;
+    }
+    params.cache_type_v = params.cache_type_k;
+    LOG_DEBUG << "cache_type: " << params.cache_type_k;
+
     // Check for backward compatible
     auto fa0 = jsonBody->get("flash-attn", false).asBool();
     auto fa1 = jsonBody->get("flash_attn", false).asBool();
-    params.flash_attn = fa0 || fa1;
+    auto force_enable_fa = params.cache_type_k != kTypeF16;
+    if (force_enable_fa) {
+      LOG_DEBUG << "Using KV cache quantization, force-enabling Flash Attention";
+    }
+    params.flash_attn = fa0 || fa1 || force_enable_fa;
     if (params.flash_attn) {
       LOG_DEBUG << "Enabled Flash Attention";
     }
+
     server_map_[model_id].caching_enabled =
         jsonBody->get("caching_enabled", false).asBool();
     server_map_[model_id].user_prompt =
@@ -556,7 +581,7 @@ void LlamaEngine::HandleInferenceImpl(
     LOG_INFO << "Request " << request_id << ": " << formatted_output;
   }
 
-  data["prompt"] = formatted_output; 
+  data["prompt"] = formatted_output;
   for (const auto& sw : stop_words_json) {
     stopWords.push_back(sw.asString());
   }