feat(api): add correlationID to Track Chat requests #3668

Merged · 6 commits · Sep 28, 2024
1 change: 1 addition & 0 deletions backend/backend.proto
@@ -136,6 +136,7 @@ message PredictOptions {
repeated Message Messages = 44;
repeated string Videos = 45;
repeated string Audios = 46;
string CorrelationId = 47;
}

// The response message containing the result
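The new field gives a gRPC client a slot to carry a request-scoped identifier into the backend. Below is a minimal sketch of populating it from Go, assuming the generated bindings are imported as pb from the LocalAI proto package; the import path and the Prompt field come from the wider repository, not from this diff.

package main

import (
    pb "github.com/mudler/LocalAI/pkg/grpc/proto" // assumed path to the generated bindings
)

// buildPredictOptions is a hypothetical helper showing where the new
// field sits on the generated message.
func buildPredictOptions(prompt, correlationID string) *pb.PredictOptions {
    return &pb.PredictOptions{
        Prompt:        prompt,        // existing field
        CorrelationId: correlationID, // new field 47 added by this PR
    }
}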
14 changes: 14 additions & 0 deletions backend/cpp/llama/grpc-server.cpp
@@ -2106,6 +2106,9 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, llama
data["ignore_eos"] = predict->ignoreeos();
data["embeddings"] = predict->embeddings();

// Add the correlationid to json data
data["correlation_id"] = predict->correlationid();

// for each image in the request, add the image data
//
for (int i = 0; i < predict->images_size(); i++) {
@@ -2344,6 +2347,11 @@ class BackendServiceImpl final : public backend::Backend::Service {
int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0);
reply.set_prompt_tokens(tokens_evaluated);

// Log Request Correlation Id
LOG_VERBOSE("correlation:", {
{ "id", data["correlation_id"] }
});

// Send the reply
writer->Write(reply);

@@ -2367,6 +2375,12 @@ class BackendServiceImpl final : public backend::Backend::Service {
std::string completion_text;
task_result result = llama.queue_results.recv(task_id);
if (!result.error && result.stop) {

// Log Request Correlation Id
LOG_VERBOSE("correlation:", {
{ "id", data["correlation_id"] }
});

completion_text = result.result_json.value("content", "");
int32_t tokens_predicted = result.result_json.value("tokens_predicted", 0);
int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0);
7 changes: 7 additions & 0 deletions core/http/endpoints/openai/chat.go
@@ -161,6 +161,12 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
textContentToReturn = ""
id = uuid.New().String()
created = int(time.Now().Unix())
// Set CorrelationID
correlationID := c.Get("X-Correlation-ID")
if len(strings.TrimSpace(correlationID)) == 0 {
correlationID = id
}
c.Set("X-Correlation-ID", correlationID)

modelFile, input, err := readRequest(c, cl, ml, startupOptions, true)
if err != nil {
@@ -444,6 +450,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
c.Set("Cache-Control", "no-cache")
c.Set("Connection", "keep-alive")
c.Set("Transfer-Encoding", "chunked")
c.Set("X-Correlation-ID", id)

responses := make(chan schema.OpenAIResponse)

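With these changes, the chat handler reads X-Correlation-ID from the incoming request, falls back to the generated response id when the header is empty, and sets the header on the streaming response. A sketch of exercising the endpoint from a Go client, assuming a local instance on the default port and a placeholder model name:

package main

import (
    "bytes"
    "fmt"
    "net/http"
)

func main() {
    body := bytes.NewBufferString(`{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`)
    req, err := http.NewRequest(http.MethodPost, "http://localhost:8080/v1/chat/completions", body)
    if err != nil {
        panic(err)
    }
    req.Header.Set("Content-Type", "application/json")
    // Supply an ID to correlate logs across services; if omitted,
    // the endpoint substitutes a generated UUID.
    req.Header.Set("X-Correlation-ID", "my-trace-1234")

    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        panic(err)
    }
    defer resp.Body.Close()

    // The handler echoes a correlation ID back on the response headers.
    fmt.Println("correlation:", resp.Header.Get("X-Correlation-ID"))
}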
2 changes: 2 additions & 0 deletions core/http/endpoints/openai/completion.go
@@ -57,6 +57,8 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
}

return func(c *fiber.Ctx) error {
// Add Correlation
c.Set("X-Correlation-ID", id)
modelFile, input, err := readRequest(c, cl, ml, appConfig, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
13 changes: 12 additions & 1 deletion core/http/endpoints/openai/request.go
@@ -6,6 +6,7 @@ import (
"fmt"

"github.com/gofiber/fiber/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
"github.com/mudler/LocalAI/core/schema"
@@ -15,6 +16,11 @@ import (
"github.com/rs/zerolog/log"
)

type correlationIDKeyType string

// CorrelationIDKey to track request across process boundary
const CorrelationIDKey correlationIDKeyType = "correlationID"

func readRequest(c *fiber.Ctx, cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
input := new(schema.OpenAIRequest)

@@ -24,9 +30,14 @@ func readRequest(c *fiber.Ctx, cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
}

received, _ := json.Marshal(input)
// Extract or generate the correlation ID
correlationID := c.Get("X-Correlation-ID", uuid.New().String())

ctx, cancel := context.WithCancel(o.Context)
input.Context = ctx
// Add the correlation ID to the new context
ctxWithCorrelationID := context.WithValue(ctx, CorrelationIDKey, correlationID)

input.Context = ctxWithCorrelationID
input.Cancel = cancel

log.Debug().Msgf("Request received: %s", string(received))
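Storing the ID under a dedicated unexported key type, rather than a bare string, prevents collisions with context values set by other packages; this is the standard Go pattern for request-scoped metadata. A sketch of how code in the same package could read the ID back out of a request context (the accessor itself is hypothetical, not part of this diff):

package openai

import "context"

// correlationIDFromContext is a hypothetical accessor; the type
// assertion fails gracefully when no ID was attached.
func correlationIDFromContext(ctx context.Context) (string, bool) {
    id, ok := ctx.Value(CorrelationIDKey).(string)
    return id, ok
}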