From 855bf54376751a1376c10fa66bbc4b27c8d27752 Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Wed, 17 Apr 2024 11:27:52 +0200
Subject: [PATCH] feat(grpc): return consumed token count and update response
 accordingly

Fixes: #1920

Signed-off-by: Ettore Di Giacinto
---
 backend/backend.proto             | 2 ++
 backend/cpp/llama/grpc-server.cpp | 8 ++++++++
 core/backend/llm.go               | 6 ++++++
 3 files changed, 16 insertions(+)

diff --git a/backend/backend.proto b/backend/backend.proto
index 56d919efd3b0..62e1a1a64448 100644
--- a/backend/backend.proto
+++ b/backend/backend.proto
@@ -114,6 +114,8 @@ message PredictOptions {
 // The response message containing the result
 message Reply {
   bytes message = 1;
+  int32 tokens = 2;
+  int32 prompt_tokens = 3;
 }
 
 message ModelOptions {
diff --git a/backend/cpp/llama/grpc-server.cpp b/backend/cpp/llama/grpc-server.cpp
index a2e39a9c5f65..6fb086585f4e 100644
--- a/backend/cpp/llama/grpc-server.cpp
+++ b/backend/cpp/llama/grpc-server.cpp
@@ -2332,6 +2332,10 @@ class BackendServiceImpl final : public backend::Backend::Service {
             std::string completion_text = result.result_json.value("content", "");
 
             reply.set_message(completion_text);
+            int32_t tokens_predicted = result.result_json.value("tokens_predicted", 0);
+            reply.set_tokens(tokens_predicted);
+            int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0);
+            reply.set_prompt_tokens(tokens_evaluated);
 
             // Send the reply
             writer->Write(reply);
@@ -2357,6 +2361,10 @@ class BackendServiceImpl final : public backend::Backend::Service {
         task_result result = llama.queue_results.recv(task_id);
         if (!result.error && result.stop) {
             completion_text = result.result_json.value("content", "");
+            int32_t tokens_predicted = result.result_json.value("tokens_predicted", 0);
+            int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0);
+            reply->set_prompt_tokens(tokens_evaluated);
+            reply->set_tokens(tokens_predicted);
             reply->set_message(completion_text);
         }
         else
diff --git a/core/backend/llm.go b/core/backend/llm.go
index 493dc25cab2b..a4d1e5f35e42 100644
--- a/core/backend/llm.go
+++ b/core/backend/llm.go
@@ -153,6 +153,12 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
     if err != nil {
         return LLMResponse{}, err
     }
+    if tokenUsage.Prompt == 0 {
+        tokenUsage.Prompt = int(reply.PromptTokens)
+    }
+    if tokenUsage.Completion == 0 {
+        tokenUsage.Completion = int(reply.Tokens)
+    }
     return LLMResponse{
         Response: string(reply.Message),
         Usage:    tokenUsage,
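
For reference, below is a minimal Go sketch of how the new Reply fields can be
consumed on the caller side, mirroring the fallback logic added to
core/backend/llm.go above. The Reply and TokenUsage types here are stand-ins
for the protobuf message generated from backend.proto and for LocalAI's own
usage struct, so the names are illustrative rather than the project's API.

package main

import "fmt"

// TokenUsage mirrors the prompt/completion split used in core/backend/llm.go
// (illustrative stand-in, not the actual LocalAI type).
type TokenUsage struct {
	Prompt     int
	Completion int
}

// Reply stands in for the backend.Reply message generated from backend.proto,
// whose new tokens/prompt_tokens fields map to Tokens and PromptTokens.
type Reply struct {
	Message      []byte
	Tokens       int32
	PromptTokens int32
}

// usageFromReply keeps counts already computed by the caller and falls back
// to the counts reported by the backend, matching the patch's logic.
func usageFromReply(reply *Reply, usage TokenUsage) TokenUsage {
	if usage.Prompt == 0 {
		usage.Prompt = int(reply.PromptTokens)
	}
	if usage.Completion == 0 {
		usage.Completion = int(reply.Tokens)
	}
	return usage
}

func main() {
	reply := &Reply{Message: []byte("hello"), Tokens: 12, PromptTokens: 34}
	fmt.Printf("%+v\n", usageFromReply(reply, TokenUsage{}))
}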