From d29faeed313a05fff47acdf42e38cf4ee6ee3bb0 Mon Sep 17 00:00:00 2001
From: vansangpfiev <vansangpfiev@gmail.com>
Date: Wed, 5 Jun 2024 16:55:00 +0700
Subject: [PATCH] feat: add cache_type parameter (#75)

* feat: expose cache_type_k and cache_type_v parameters

* fix: bugfix

* fix: validate cache type

---------

Co-authored-by: vansangpfiev <sang@jan.ai>
---
 src/llama_engine.cc | 29 +++++++++++++++++++++++++++--
 1 file changed, 27 insertions(+), 2 deletions(-)

diff --git a/src/llama_engine.cc b/src/llama_engine.cc
index eb705d6..4c89379 100644
--- a/src/llama_engine.cc
+++ b/src/llama_engine.cc
@@ -11,6 +11,17 @@ constexpr const int k400BadRequest = 400;
 constexpr const int k409Conflict = 409;
 constexpr const int k500InternalServerError = 500;
 
+constexpr const auto kTypeF16 = "f16";
+constexpr const auto kType_Q8_0 = "q8_0";
+constexpr const auto kType_Q4_0 = "q4_0";
+
+bool IsValidCacheType(const std::string& c) {
+  if (c != kTypeF16 && c != kType_Q8_0 && c != kType_Q4_0) {
+    return false;
+  }
+  return true;
+}
+
 struct InferenceState {
   int task_id;
   LlamaServerContext& llama;
@@ -341,13 +352,27 @@ bool LlamaEngine::LoadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
             .asInt();
     params.cont_batching = jsonBody->get("cont_batching", false).asBool();
 
+    params.cache_type_k = jsonBody->get("cache_type", kTypeF16).asString();
+    if (!IsValidCacheType(params.cache_type_k)) {
+      LOG_WARN << "Unsupported cache type: " << params.cache_type_k
+               << ", fallback to f16";
+      params.cache_type_k = kTypeF16;
+    }
+    params.cache_type_v = params.cache_type_k;
+    LOG_DEBUG << "cache_type: " << params.cache_type_k;
+
     // Check for backward compatible
     auto fa0 = jsonBody->get("flash-attn", false).asBool();
     auto fa1 = jsonBody->get("flash_attn", false).asBool();
-    params.flash_attn = fa0 || fa1;
+    auto force_enable_fa = params.cache_type_k != kTypeF16;
+    if (force_enable_fa) {
+      LOG_DEBUG << "Using KV cache quantization, force enable Flash Attention";
+    }
+    params.flash_attn = fa0 || fa1 || force_enable_fa;
     if (params.flash_attn) {
       LOG_DEBUG << "Enabled Flash Attention";
     }
+
     server_map_[model_id].caching_enabled =
         jsonBody->get("caching_enabled", false).asBool();
     server_map_[model_id].user_prompt =
@@ -556,7 +581,7 @@ void LlamaEngine::HandleInferenceImpl(
       LOG_INFO << "Request " << request_id << ": " << formatted_output;
     }
 
-    data["prompt"] = formatted_output;
+    data["prompt"] = formatted_output;
     for (const auto& sw : stop_words_json) {
       stopWords.push_back(sw.asString());
     }
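
Note (not part of the patch): a minimal sketch of how a client-side load request might exercise the new parameter, assuming the engine is handed a jsoncpp Json::Value as in LoadModelImpl above. The helper name and the "model_path" key are hypothetical illustrations; only "cache_type" and "flash_attn" come from this patch. Accepted cache types are "f16" (default), "q8_0", and "q4_0"; anything else falls back to f16 with a warning, and any non-f16 value force-enables Flash Attention.

    #include <json/json.h>
    #include <string>

    // Hypothetical helper, for illustration only: builds a load-model request
    // body that asks for a quantized KV cache.
    Json::Value MakeLoadRequest(const std::string& model_path) {
      Json::Value body;
      body["model_path"] = model_path;  // key name is an assumption, not from this patch
      body["cache_type"] = "q8_0";      // new parameter added by this patch
      body["flash_attn"] = true;        // optional; implied anyway by a non-f16 cache_type
      return body;
    }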