diff --git a/pkg/bridge/ai/api_server.go b/pkg/bridge/ai/api_server.go index db110f0a6..52dc3656e 100644 --- a/pkg/bridge/ai/api_server.go +++ b/pkg/bridge/ai/api_server.go @@ -111,7 +111,7 @@ func decorateReqContext(service *Service, logger *slog.Logger) func(handler http caller, err := service.LoadOrCreateCaller(r) if err != nil { - RespondWithError(w, http.StatusBadRequest, err) + RespondWithError(w, http.StatusBadRequest, err, logger) return } ctx = WithCallerContext(ctx, caller) @@ -146,7 +146,7 @@ func (h *Handler) HandleOverview(w http.ResponseWriter, r *http.Request) { tcs, err := register.ListToolCalls(FromCallerContext(r.Context()).Metadata()) if err != nil { - RespondWithError(w, http.StatusInternalServerError, err) + RespondWithError(w, http.StatusInternalServerError, err, h.service.logger) return } @@ -169,7 +169,7 @@ func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) { ) defer r.Body.Close() - req, err := DecodeRequest[ai.InvokeRequest](r, w) + req, err := DecodeRequest[ai.InvokeRequest](r, w, h.service.logger) if err != nil { return } @@ -181,7 +181,7 @@ func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) { res, err := h.service.GetInvoke(ctx, req.Prompt, baseSystemMessage, transID, FromCallerContext(ctx), req.IncludeCallStack) if err != nil { - RespondWithError(w, http.StatusInternalServerError, err) + RespondWithError(w, http.StatusInternalServerError, err, h.service.logger) return } @@ -197,7 +197,7 @@ func (h *Handler) HandleChatCompletions(w http.ResponseWriter, r *http.Request) ) defer r.Body.Close() - req, err := DecodeRequest[openai.ChatCompletionRequest](r, w) + req, err := DecodeRequest[openai.ChatCompletionRequest](r, w, h.service.logger) if err != nil { return } @@ -206,18 +206,18 @@ func (h *Handler) HandleChatCompletions(w http.ResponseWriter, r *http.Request) defer cancel() if err := h.service.GetChatCompletions(ctx, req, transID, FromCallerContext(ctx), w); err != nil { - RespondWithError(w, http.StatusBadRequest, err) + RespondWithError(w, http.StatusBadRequest, err, h.service.logger) return } } // DecodeRequest decodes the request body into given type. -func DecodeRequest[T any](r *http.Request, w http.ResponseWriter) (T, error) { +func DecodeRequest[T any](r *http.Request, w http.ResponseWriter, logger *slog.Logger) (T, error) { var req T err := json.NewDecoder(r.Body).Decode(&req) if err != nil { w.Header().Set("Content-Type", "application/json") - RespondWithError(w, http.StatusBadRequest, err) + RespondWithError(w, http.StatusBadRequest, err, logger) return req, err } @@ -225,9 +225,20 @@ func DecodeRequest[T any](r *http.Request, w http.ResponseWriter) (T, error) { } // RespondWithError writes an error to response according to the OpenAI API spec. -func RespondWithError(w http.ResponseWriter, code int, err error) { +func RespondWithError(w http.ResponseWriter, code int, err error, logger *slog.Logger) { + logger.Error("bridge server error", "error", err) + + errString := err.Error() + oerr, ok := err.(*openai.APIError) + if ok { + if oerr.HTTPStatusCode >= 400 { + code = http.StatusInternalServerError + errString = "Internal Server Error, Please Try Again Later." + } + } + w.WriteHeader(code) - w.Write([]byte(fmt.Sprintf(`{"error":{"code":"%d","message":"%s"}}`, code, err.Error()))) + w.Write([]byte(fmt.Sprintf(`{"error":{"code":"%d","message":"%s"}}`, code, errString))) } func getLocalIP() (string, error) { diff --git a/pkg/bridge/ai/service.go b/pkg/bridge/ai/service.go index fd970465e..d5ea787aa 100644 --- a/pkg/bridge/ai/service.go +++ b/pkg/bridge/ai/service.go @@ -255,15 +255,15 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl // 4. request first chat for getting tools if req.Stream { _, firstCallSpan := srv.option.Tracer.Start(reqCtx, "first_call_request") - var ( - flusher = eventFlusher(w) - isFunctionCall = false - ) + resStream, err := srv.provider.GetChatCompletionsStream(reqCtx, req, md) if err != nil { return err } - + var ( + flusher = eventFlusher(w) + isFunctionCall = false + ) var ( i int // number of chunks j int // number of tool call chunks