Feat/support logit bias (#270)
* feat: support logit bias

* Fix build bug

* Format code

* restore RAM and VRAM usage

* remove unnecessary include
nguyenhoangthuan99 authored Oct 29, 2024
1 parent 0c8d0d7 commit 185a7cf
Showing 2 changed files with 59 additions and 4 deletions.
58 changes: 54 additions & 4 deletions src/chat_completion_request.h
@@ -1,8 +1,35 @@
#pragma once
#include <json.hpp>
#include "json/value.h"
#include "sampling.h"

namespace llama::inferences {

// Recursively mirrors a jsoncpp Json::Value into an equivalent nlohmann::json.
// inline: defined in a header, so avoid multiple-definition link errors.
inline nlohmann::json ConvertJsonCppToNlohmann(const Json::Value& input) {
  if (input.isNull()) {
    return nullptr;
  } else if (input.isBool()) {
    return input.asBool();
  } else if (input.isInt()) {
    return input.asInt();
  } else if (input.isDouble()) {
    return input.asDouble();
  } else if (input.isString()) {
    return input.asString();
  } else if (input.isArray()) {
    nlohmann::json arr = nlohmann::json::array();
    for (const auto& elem : input) {
      arr.push_back(ConvertJsonCppToNlohmann(elem));
    }
    return arr;
  } else if (input.isObject()) {
    nlohmann::json obj = nlohmann::json::object();
    for (const auto& key : input.getMemberNames()) {
      obj[key] = ConvertJsonCppToNlohmann(input[key]);
    }
    return obj;
  }
  return nullptr;
}
struct ChatCompletionRequest {
  bool stream = false;
  int max_tokens = 500;
@@ -31,6 +58,21 @@ struct ChatCompletionRequest {
  int n_probs = 0;
  int min_keep = 0;
  std::string grammar;
  Json::Value logit_bias = Json::Value(Json::arrayValue);

  // Flattens an OpenAI-style {"token_id": bias} object into the
  // [[token, bias], ...] pair array llama.cpp consumes; keys must be
  // numeric token ids, since std::stoi throws on anything else.
  static Json::Value ConvertLogitBiasToArray(const Json::Value& input) {
    Json::Value result(Json::arrayValue);
    if (input.isObject()) {
      const auto& memberNames = input.getMemberNames();
      for (const auto& tokenStr : memberNames) {
        Json::Value pair(Json::arrayValue);
        pair.append(std::stoi(tokenStr));
        pair.append(input[tokenStr].asFloat());
        result.append(pair);
      }
    }
    return result;
  }
};

inline ChatCompletionRequest fromJson(std::shared_ptr<Json::Value> jsonBody) {
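To make the new helpers concrete, here is a minimal usage sketch; the token ids and bias values are hypothetical and the driver assumes this repo's headers are on the include path:

// Sketch only: exercises the two helpers above with hypothetical values.
#include <cassert>
#include "chat_completion_request.h"

int main() {
  // OpenAI-style logit_bias: token id (as a string key) -> bias value.
  Json::Value bias(Json::objectValue);
  bias["15043"] = 5.0;    // nudge token 15043 upward
  bias["3923"] = -100.0;  // effectively ban token 3923

  // Flattened to the [[token, bias], ...] form llama.cpp consumes;
  // jsoncpp returns member names in sorted order, so "15043" comes first.
  auto pairs =
      llama::inferences::ChatCompletionRequest::ConvertLogitBiasToArray(bias);
  assert(pairs.isArray() && pairs.size() == 2);
  assert(pairs[0][0].asInt() == 15043);

  // Mirrored into nlohmann::json, ready for the server-bound request body.
  nlohmann::json j = llama::inferences::ConvertJsonCppToNlohmann(pairs);
  assert(j.is_array() && j.size() == 2);
  return 0;
}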
@@ -50,14 +92,17 @@ inline ChatCompletionRequest fromJson(std::shared_ptr<Json::Value> jsonBody) {
    completion.model_id = (*jsonBody).get("model", {}).asString();

    completion.seed = (*jsonBody).get("seed", -1).asInt();
    completion.dynatemp_range =
        (*jsonBody).get("dynatemp_range", 0.0f).asFloat();
    completion.dynatemp_exponent =
        (*jsonBody).get("dynatemp_exponent", 0.0f).asFloat();
    completion.top_k = (*jsonBody).get("top_k", 40).asInt();
    completion.min_p = (*jsonBody).get("min_p", 0.05f).asFloat();
    completion.tfs_z = (*jsonBody).get("tfs_z", 1.0f).asFloat();
    completion.typ_p = (*jsonBody).get("typ_p", 1.0f).asFloat();
    completion.repeat_last_n = (*jsonBody).get("repeat_last_n", 64).asInt();
    completion.penalty_repeat =
        (*jsonBody).get("repeat_penalty", 1.1f).asFloat();
    completion.mirostat = (*jsonBody).get("mirostat", false).asBool();
    completion.mirostat_tau = (*jsonBody).get("mirostat_tau", 5.0f).asFloat();
    completion.mirostat_eta = (*jsonBody).get("mirostat_eta", 0.1f).asFloat();
@@ -66,6 +111,11 @@ inline ChatCompletionRequest fromJson(std::shared_ptr<Json::Value> jsonBody) {
    completion.n_probs = (*jsonBody).get("n_probs", 0).asInt();
    completion.min_keep = (*jsonBody).get("min_keep", 0).asInt();
    completion.grammar = (*jsonBody).get("grammar", "").asString();
    const Json::Value& input_logit_bias = (*jsonBody)["logit_bias"];
    if (!input_logit_bias.isNull()) {
      completion.logit_bias =
          ChatCompletionRequest::ConvertLogitBiasToArray(input_logit_bias);
    }
  }
  return completion;
}
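End to end, fromJson now lifts logit_bias out of the request body alongside the existing sampling fields. A minimal sketch, assuming jsoncpp's Reader (hypothetical request payload; "json/reader.h" assumed available alongside "json/value.h"):

// Sketch only: hypothetical request body run through fromJson.
#include <memory>
#include <string>
#include "chat_completion_request.h"
#include "json/reader.h"

llama::inferences::ChatCompletionRequest ParseExample() {
  const std::string body = R"({
    "model": "llama3",
    "logit_bias": {"15043": 5.0, "3923": -100.0}
  })";

  auto json_body = std::make_shared<Json::Value>();
  Json::Reader reader;
  if (!reader.parse(body, *json_body)) {
    return {};  // malformed JSON: fall back to all defaults
  }

  // completion.logit_bias now holds [[15043, 5.0], [3923, -100.0]];
  // absent fields keep their defaults (top_k 40, repeat_penalty 1.1, ...).
  return llama::inferences::fromJson(json_body);
}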
5 changes: 5 additions & 0 deletions src/llama_engine.cc
@@ -635,6 +635,11 @@ void LlamaEngine::HandleInferenceImpl(
data["n_probs"] = completion.n_probs;
data["min_keep"] = completion.min_keep;
data["grammar"] = completion.grammar;
json arr = json::array();
for (const auto& elem : completion.logit_bias) {
arr.push_back(llama::inferences::ConvertJsonCppToNlohmann(elem));
}
data["logit_bias"] = std::move(arr);
int n_probs = completion.n_probs;
const Json::Value& messages = completion.messages;

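For reference, a sketch of what the loop above leaves in the server-bound body; `json` is assumed to be the nlohmann::json alias this translation unit already uses, and the values are hypothetical:

// Sketch only: for a request carrying {"logit_bias": {"15043": 5.0, "3923": -100.0}},
// the loop above leaves data["logit_bias"] numerically equal to:
json expected = json::array();
expected.push_back(json::array({15043, 5.0}));
expected.push_back(json::array({3923, -100.0}));
// (integral biases may serialize as ints, since ConvertJsonCppToNlohmann
// checks isInt() before isDouble(); nlohmann compares the two by value)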