feat: add cache_type parameter (#75)
* feat: expose cache_type_k and cache_type_v parameters

* fix: bugfix

* fix: validate cache type

---------

Co-authored-by: vansangpfiev <[email protected]>
vansangpfiev and sangjanai authored Jun 5, 2024
1 parent 9875632 commit d29faee
Showing 1 changed file with 27 additions and 2 deletions.
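
For context, the handler changed below reads the new setting from the JSON request body. Here is a minimal sketch of a body that exercises it, using jsoncpp as the engine does; how the body reaches LoadModelImpl is an assumption, and only the keys shown in this diff ("cache_type", "flash_attn", "cont_batching", "caching_enabled") come from the commit.

// Minimal sketch (assumptions: jsoncpp include path, helper name, and how the
// body is delivered to LoadModelImpl). Only the keys below appear in the diff.
#include <json/json.h>
#include <memory>

std::shared_ptr<Json::Value> MakeLoadModelBody() {
  auto body = std::make_shared<Json::Value>();
  (*body)["cache_type"] = "q8_0";     // "f16" (default), "q8_0", or "q4_0"; anything else falls back to f16
  (*body)["flash_attn"] = true;       // forced on anyway when cache_type != "f16"
  (*body)["cont_batching"] = false;
  (*body)["caching_enabled"] = false;
  return body;
}
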
29 changes: 27 additions & 2 deletions src/llama_engine.cc
@@ -11,6 +11,17 @@ constexpr const int k400BadRequest = 400;
constexpr const int k409Conflict = 409;
constexpr const int k500InternalServerError = 500;

constexpr const auto kTypeF16 = "f16";
constexpr const auto kType_Q8_0 = "q8_0";
constexpr const auto kType_Q4_0 = "q4_0";

bool IsValidCacheType(const std::string& c) {
if (c != kTypeF16 && c != kType_Q8_0 && c != kType_Q4_0) {
return false;
}
return true;
}

struct InferenceState {
int task_id;
LlamaServerContext& llama;
@@ -341,13 +352,27 @@ bool LlamaEngine::LoadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
.asInt();
params.cont_batching = jsonBody->get("cont_batching", false).asBool();

params.cache_type_k = jsonBody->get("cache_type", kTypeF16).asString();
if (!IsValidCacheType(params.cache_type_k)) {
LOG_WARN << "Unsupported cache type: " << params.cache_type_k
<< ", fallback to f16";
params.cache_type_k = kTypeF16;
}
params.cache_type_v = params.cache_type_k;
LOG_DEBUG << "cache_type: " << params.cache_type_k;

// Check for backward compatibility
auto fa0 = jsonBody->get("flash-attn", false).asBool();
auto fa1 = jsonBody->get("flash_attn", false).asBool();
params.flash_attn = fa0 || fa1;
auto force_enable_fa = params.cache_type_k != kTypeF16;
if (force_enable_fa) {
LOG_DEBUG << "Using KV cache quantization, force enable Flash Attention";
}
params.flash_attn = fa0 || fa1 || force_enable_fa;
if (params.flash_attn) {
LOG_DEBUG << "Enabled Flash Attention";
}

server_map_[model_id].caching_enabled =
jsonBody->get("caching_enabled", false).asBool();
server_map_[model_id].user_prompt =
@@ -556,7 +581,7 @@ void LlamaEngine::HandleInferenceImpl(
LOG_INFO << "Request " << request_id << ": " << formatted_output;
}

data["prompt"] = formatted_output;
data["prompt"] = formatted_output;
for (const auto& sw : stop_words_json) {
stopWords.push_back(sw.asString());
}
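
Taken together, the new validation and the flash-attention forcing behave as in the standalone sketch below. This is a compilable approximation rather than engine code: the constants and IsValidCacheType mirror the diff, while the driver loop and the sample values are illustrative only.

// Standalone sketch of the validation-and-fallback logic added in this commit.
#include <iostream>
#include <string>

constexpr const auto kTypeF16 = "f16";
constexpr const auto kType_Q8_0 = "q8_0";
constexpr const auto kType_Q4_0 = "q4_0";

bool IsValidCacheType(const std::string& c) {
  return c == kTypeF16 || c == kType_Q8_0 || c == kType_Q4_0;
}

int main() {
  for (const std::string requested : {"q8_0", "q5_1", "f16"}) {
    std::string cache_type_k = requested;
    if (!IsValidCacheType(cache_type_k)) {
      std::cout << "Unsupported cache type: " << cache_type_k
                << ", fallback to f16\n";
      cache_type_k = kTypeF16;
    }
    // As in the diff: quantizing the KV cache forces flash attention on.
    bool flash_attn = (cache_type_k != kTypeF16);
    std::cout << requested << " -> cache_type=" << cache_type_k
              << ", flash_attn=" << std::boolalpha << flash_attn << "\n";
  }
  return 0;
}
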
