From c1c51722da8139ae200b2bbec24e4ee9c7716448 Mon Sep 17 00:00:00 2001 From: woorui Date: Mon, 26 Aug 2024 17:29:43 +0800 Subject: [PATCH 1/2] feat: ovid provider's errors are sent directly to the end user --- pkg/bridge/ai/api_server.go | 29 ++++++++++++++++++++--------- pkg/bridge/ai/service.go | 10 +++++----- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/pkg/bridge/ai/api_server.go b/pkg/bridge/ai/api_server.go index db110f0a6..721632888 100644 --- a/pkg/bridge/ai/api_server.go +++ b/pkg/bridge/ai/api_server.go @@ -111,7 +111,7 @@ func decorateReqContext(service *Service, logger *slog.Logger) func(handler http caller, err := service.LoadOrCreateCaller(r) if err != nil { - RespondWithError(w, http.StatusBadRequest, err) + RespondWithError(w, http.StatusBadRequest, err, logger) return } ctx = WithCallerContext(ctx, caller) @@ -146,7 +146,7 @@ func (h *Handler) HandleOverview(w http.ResponseWriter, r *http.Request) { tcs, err := register.ListToolCalls(FromCallerContext(r.Context()).Metadata()) if err != nil { - RespondWithError(w, http.StatusInternalServerError, err) + RespondWithError(w, http.StatusInternalServerError, err, h.service.logger) return } @@ -169,7 +169,7 @@ func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) { ) defer r.Body.Close() - req, err := DecodeRequest[ai.InvokeRequest](r, w) + req, err := DecodeRequest[ai.InvokeRequest](r, w, h.service.logger) if err != nil { return } @@ -181,7 +181,7 @@ func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) { res, err := h.service.GetInvoke(ctx, req.Prompt, baseSystemMessage, transID, FromCallerContext(ctx), req.IncludeCallStack) if err != nil { - RespondWithError(w, http.StatusInternalServerError, err) + RespondWithError(w, http.StatusInternalServerError, err, h.service.logger) return } @@ -197,7 +197,7 @@ func (h *Handler) HandleChatCompletions(w http.ResponseWriter, r *http.Request) ) defer r.Body.Close() - req, err := DecodeRequest[openai.ChatCompletionRequest](r, w) + req, err := DecodeRequest[openai.ChatCompletionRequest](r, w, h.service.logger) if err != nil { return } @@ -206,18 +206,18 @@ func (h *Handler) HandleChatCompletions(w http.ResponseWriter, r *http.Request) defer cancel() if err := h.service.GetChatCompletions(ctx, req, transID, FromCallerContext(ctx), w); err != nil { - RespondWithError(w, http.StatusBadRequest, err) + RespondWithError(w, http.StatusBadRequest, err, h.service.logger) return } } // DecodeRequest decodes the request body into given type. -func DecodeRequest[T any](r *http.Request, w http.ResponseWriter) (T, error) { +func DecodeRequest[T any](r *http.Request, w http.ResponseWriter, logger *slog.Logger) (T, error) { var req T err := json.NewDecoder(r.Body).Decode(&req) if err != nil { w.Header().Set("Content-Type", "application/json") - RespondWithError(w, http.StatusBadRequest, err) + RespondWithError(w, http.StatusBadRequest, err, logger) return req, err } @@ -225,7 +225,18 @@ func DecodeRequest[T any](r *http.Request, w http.ResponseWriter) (T, error) { } // RespondWithError writes an error to response according to the OpenAI API spec. -func RespondWithError(w http.ResponseWriter, code int, err error) { +func RespondWithError(w http.ResponseWriter, code int, err error, logger *slog.Logger) { + logger.Error("bridge server error", "error", err) + + oerr, ok := err.(*openai.APIError) + if ok { + if oerr.HTTPStatusCode >= 400 { + code = http.StatusInternalServerError + w.WriteHeader(code) + w.Write([]byte(fmt.Sprintf(`{"error":{"code":"%d","message":"%s"}}`, code, "Internal Server Error, Please Try Again Later."))) + return + } + } w.WriteHeader(code) w.Write([]byte(fmt.Sprintf(`{"error":{"code":"%d","message":"%s"}}`, code, err.Error()))) } diff --git a/pkg/bridge/ai/service.go b/pkg/bridge/ai/service.go index fd970465e..d5ea787aa 100644 --- a/pkg/bridge/ai/service.go +++ b/pkg/bridge/ai/service.go @@ -255,15 +255,15 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl // 4. request first chat for getting tools if req.Stream { _, firstCallSpan := srv.option.Tracer.Start(reqCtx, "first_call_request") - var ( - flusher = eventFlusher(w) - isFunctionCall = false - ) + resStream, err := srv.provider.GetChatCompletionsStream(reqCtx, req, md) if err != nil { return err } - + var ( + flusher = eventFlusher(w) + isFunctionCall = false + ) var ( i int // number of chunks j int // number of tool call chunks From 7a868720b0986365e9e8c29ab1ea7559bd029d71 Mon Sep 17 00:00:00 2001 From: woorui Date: Mon, 26 Aug 2024 17:33:57 +0800 Subject: [PATCH 2/2] improve code --- pkg/bridge/ai/api_server.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/bridge/ai/api_server.go b/pkg/bridge/ai/api_server.go index 721632888..52dc3656e 100644 --- a/pkg/bridge/ai/api_server.go +++ b/pkg/bridge/ai/api_server.go @@ -228,17 +228,17 @@ func DecodeRequest[T any](r *http.Request, w http.ResponseWriter, logger *slog.L func RespondWithError(w http.ResponseWriter, code int, err error, logger *slog.Logger) { logger.Error("bridge server error", "error", err) + errString := err.Error() oerr, ok := err.(*openai.APIError) if ok { if oerr.HTTPStatusCode >= 400 { code = http.StatusInternalServerError - w.WriteHeader(code) - w.Write([]byte(fmt.Sprintf(`{"error":{"code":"%d","message":"%s"}}`, code, "Internal Server Error, Please Try Again Later."))) - return + errString = "Internal Server Error, Please Try Again Later." } } + w.WriteHeader(code) - w.Write([]byte(fmt.Sprintf(`{"error":{"code":"%d","message":"%s"}}`, code, err.Error()))) + w.Write([]byte(fmt.Sprintf(`{"error":{"code":"%d","message":"%s"}}`, code, errString))) } func getLocalIP() (string, error) {