Skip to content

Commit

Permalink
feat(grpc): return consumed token count and update response accordingly
Browse files Browse the repository at this point in the history
Fixes: #1920

Signed-off-by: Ettore Di Giacinto <[email protected]>
  • Loading branch information
mudler committed Apr 17, 2024
1 parent 70cad33 commit 855bf54
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 0 deletions.
2 changes: 2 additions & 0 deletions backend/backend.proto
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ message PredictOptions {
// The response message containing the result.
//
// Carries the generated output together with token accounting so that the
// caller can report usage without counting tokens itself.
message Reply {
// Generated output payload; the llama.cpp backend fills this from the
// "content" field of its result.
bytes message = 1;
// Number of tokens generated for the completion. The llama.cpp backend fills
// this from "tokens_predicted" and falls back to 0 when that key is absent.
// NOTE(review): other backends may leave this unset — confirm before relying
// on a non-zero value.
int32 tokens = 2;
// Number of prompt tokens consumed. The llama.cpp backend fills this from
// "tokens_evaluated" and falls back to 0 when that key is absent.
int32 prompt_tokens = 3;
}

message ModelOptions {
Expand Down
8 changes: 8 additions & 0 deletions backend/cpp/llama/grpc-server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2332,6 +2332,10 @@ class BackendServiceImpl final : public backend::Backend::Service {
std::string completion_text = result.result_json.value("content", "");

reply.set_message(completion_text);
int32_t tokens_predicted = result.result_json.value("tokens_predicted", 0);
reply.set_tokens(tokens_predicted);
int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0);
reply.set_prompt_tokens(tokens_evaluated);

// Send the reply
writer->Write(reply);
Expand All @@ -2357,6 +2361,10 @@ class BackendServiceImpl final : public backend::Backend::Service {
task_result result = llama.queue_results.recv(task_id);
if (!result.error && result.stop) {
completion_text = result.result_json.value("content", "");
int32_t tokens_predicted = result.result_json.value("tokens_predicted", 0);
int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0);
reply->set_prompt_tokens(tokens_evaluated);
reply->set_tokens(tokens_predicted);
reply->set_message(completion_text);
}
else
Expand Down
6 changes: 6 additions & 0 deletions core/backend/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,12 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
if err != nil {
return LLMResponse{}, err
}
if tokenUsage.Prompt == 0 {
tokenUsage.Prompt = int(reply.PromptTokens)
}
if tokenUsage.Completion == 0 {
tokenUsage.Completion = int(reply.Tokens)
}
return LLMResponse{
Response: string(reply.Message),
Usage: tokenUsage,
Expand Down

0 comments on commit 855bf54

Please sign in to comment.