diff --git a/backend/backend.proto b/backend/backend.proto index 31bd63e50867..b2d4518e1333 100644 --- a/backend/backend.proto +++ b/backend/backend.proto @@ -136,6 +136,7 @@ message PredictOptions { repeated Message Messages = 44; repeated string Videos = 45; repeated string Audios = 46; + string CorrelationId = 47; } // The response message containing the result diff --git a/backend/cpp/llama/grpc-server.cpp b/backend/cpp/llama/grpc-server.cpp index 56d59d217a7c..791612dbcc99 100644 --- a/backend/cpp/llama/grpc-server.cpp +++ b/backend/cpp/llama/grpc-server.cpp @@ -2106,6 +2106,9 @@ json parse_options(bool streaming, const backend::PredictOptions* predict, llama data["ignore_eos"] = predict->ignoreeos(); data["embeddings"] = predict->embeddings(); + // Add the correlationid to json data + data["correlation_id"] = predict->correlationid(); + // for each image in the request, add the image data // for (int i = 0; i < predict->images_size(); i++) { @@ -2344,6 +2347,11 @@ class BackendServiceImpl final : public backend::Backend::Service { int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0); reply.set_prompt_tokens(tokens_evaluated); + // Log Request Correlation Id + LOG_VERBOSE("correlation:", { + { "id", data["correlation_id"] } + }); + // Send the reply writer->Write(reply); @@ -2367,6 +2375,12 @@ class BackendServiceImpl final : public backend::Backend::Service { std::string completion_text; task_result result = llama.queue_results.recv(task_id); if (!result.error && result.stop) { + + // Log Request Correlation Id + LOG_VERBOSE("correlation:", { + { "id", data["correlation_id"] } + }); + completion_text = result.result_json.value("content", ""); int32_t tokens_predicted = result.result_json.value("tokens_predicted", 0); int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0); diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index b937120a3331..1ac1387eed3e 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -161,6 +161,12 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup textContentToReturn = "" id = uuid.New().String() created = int(time.Now().Unix()) + // Set CorrelationID + correlationID := c.Get("X-Correlation-ID") + if len(strings.TrimSpace(correlationID)) == 0 { + correlationID = id + } + c.Set("X-Correlation-ID", correlationID) modelFile, input, err := readRequest(c, cl, ml, startupOptions, true) if err != nil { @@ -444,6 +450,7 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup c.Set("Cache-Control", "no-cache") c.Set("Connection", "keep-alive") c.Set("Transfer-Encoding", "chunked") + c.Set("X-Correlation-ID", id) responses := make(chan schema.OpenAIResponse) diff --git a/core/http/endpoints/openai/completion.go b/core/http/endpoints/openai/completion.go index b087cc5f8d35..e5de1b3f0296 100644 --- a/core/http/endpoints/openai/completion.go +++ b/core/http/endpoints/openai/completion.go @@ -57,6 +57,8 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a } return func(c *fiber.Ctx) error { + // Add Correlation + c.Set("X-Correlation-ID", id) modelFile, input, err := readRequest(c, cl, ml, appConfig, true) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) diff --git a/core/http/endpoints/openai/request.go b/core/http/endpoints/openai/request.go index e24dd28f2e4b..d6182a391fe8 100644 --- a/core/http/endpoints/openai/request.go +++ b/core/http/endpoints/openai/request.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/gofiber/fiber/v2" + "github.com/google/uuid" "github.com/mudler/LocalAI/core/config" fiberContext "github.com/mudler/LocalAI/core/http/ctx" "github.com/mudler/LocalAI/core/schema" @@ -15,6 +16,11 @@ import ( "github.com/rs/zerolog/log" ) +type correlationIDKeyType string + +// CorrelationIDKey to track request across process boundary +const CorrelationIDKey correlationIDKeyType = "correlationID" + func readRequest(c *fiber.Ctx, cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) { input := new(schema.OpenAIRequest) @@ -24,9 +30,14 @@ func readRequest(c *fiber.Ctx, cl *config.BackendConfigLoader, ml *model.ModelLo } received, _ := json.Marshal(input) + // Extract or generate the correlation ID + correlationID := c.Get("X-Correlation-ID", uuid.New().String()) ctx, cancel := context.WithCancel(o.Context) - input.Context = ctx + // Add the correlation ID to the new context + ctxWithCorrelationID := context.WithValue(ctx, CorrelationIDKey, correlationID) + + input.Context = ctxWithCorrelationID input.Cancel = cancel log.Debug().Msgf("Request received: %s", string(received))