diff --git a/service/aiproxy/common/consume/consume.go b/service/aiproxy/common/consume/consume.go new file mode 100644 index 00000000000..7ba0d3d2ef6 --- /dev/null +++ b/service/aiproxy/common/consume/consume.go @@ -0,0 +1,162 @@ +package consume + +import ( + "context" + "sync" + + "github.com/labring/sealos/service/aiproxy/common/balance" + "github.com/labring/sealos/service/aiproxy/model" + "github.com/labring/sealos/service/aiproxy/relay/meta" + relaymodel "github.com/labring/sealos/service/aiproxy/relay/model" + "github.com/shopspring/decimal" + log "github.com/sirupsen/logrus" +) + +var consumeWaitGroup sync.WaitGroup + +func Wait() { + consumeWaitGroup.Wait() +} + +func AsyncConsume( + ctx context.Context, + postGroupConsumer balance.PostGroupConsumer, + code int, + usage *relaymodel.Usage, + meta *meta.Meta, + inputPrice, + outputPrice float64, + content string, + requestDetail *model.RequestDetail, +) { + if meta.IsChannelTest { + return + } + + consumeWaitGroup.Add(1) + defer func() { + consumeWaitGroup.Done() + if r := recover(); r != nil { + log.Errorf("panic in consume: %v", r) + } + }() + + go Consume(ctx, postGroupConsumer, code, usage, meta, inputPrice, outputPrice, content, requestDetail) +} + +func Consume( + ctx context.Context, + postGroupConsumer balance.PostGroupConsumer, + code int, + usage *relaymodel.Usage, + meta *meta.Meta, + inputPrice, + outputPrice float64, + content string, + requestDetail *model.RequestDetail, +) { + if meta.IsChannelTest { + return + } + + amount := calculateAmount(ctx, usage, inputPrice, outputPrice, postGroupConsumer, meta) + + err := recordConsume(meta, code, usage, inputPrice, outputPrice, content, requestDetail, amount) + if err != nil { + log.Error("error batch record consume: " + err.Error()) + } +} + +func calculateAmount( + ctx context.Context, + usage *relaymodel.Usage, + inputPrice, outputPrice float64, + postGroupConsumer balance.PostGroupConsumer, + meta *meta.Meta, +) float64 { + if usage == nil { + return 0 + } + + promptTokens := usage.PromptTokens + completionTokens := usage.CompletionTokens + totalTokens := promptTokens + completionTokens + + if totalTokens == 0 { + return 0 + } + + promptAmount := decimal.NewFromInt(int64(promptTokens)). + Mul(decimal.NewFromFloat(inputPrice)). + Div(decimal.NewFromInt(model.PriceUnit)) + completionAmount := decimal.NewFromInt(int64(completionTokens)). + Mul(decimal.NewFromFloat(outputPrice)). + Div(decimal.NewFromInt(model.PriceUnit)) + amount := promptAmount.Add(completionAmount).InexactFloat64() + + if amount > 0 { + return processGroupConsume(ctx, amount, postGroupConsumer, meta) + } + + return 0 +} + +func processGroupConsume( + ctx context.Context, + amount float64, + postGroupConsumer balance.PostGroupConsumer, + meta *meta.Meta, +) float64 { + consumedAmount, err := postGroupConsumer.PostGroupConsume(ctx, meta.Token.Name, amount) + if err != nil { + log.Error("error consuming token remain amount: " + err.Error()) + if err := model.CreateConsumeError( + meta.RequestID, + meta.RequestAt, + meta.Group.ID, + meta.Token.Name, + meta.OriginModel, + err.Error(), + amount, + meta.Token.ID, + ); err != nil { + log.Error("failed to create consume error: " + err.Error()) + } + return amount + } + return consumedAmount +} + +func recordConsume(meta *meta.Meta, code int, usage *relaymodel.Usage, inputPrice, outputPrice float64, content string, requestDetail *model.RequestDetail, amount float64) error { + promptTokens := 0 + completionTokens := 0 + if usage != nil { + promptTokens = usage.PromptTokens + completionTokens = usage.CompletionTokens + } + + var channelID int + if meta.Channel != nil { + channelID = meta.Channel.ID + } + + return model.BatchRecordConsume( + meta.RequestID, + meta.RequestAt, + meta.Group.ID, + code, + channelID, + promptTokens, + completionTokens, + meta.OriginModel, + meta.Token.ID, + meta.Token.Name, + amount, + inputPrice, + outputPrice, + meta.Endpoint, + content, + meta.Mode, + requestDetail, + ) +} diff --git a/service/aiproxy/main.go b/service/aiproxy/main.go index 9b12bb947d8..2742f42ce00 100644 --- a/service/aiproxy/main.go +++ b/service/aiproxy/main.go @@ -18,10 +18,10 @@ import ( "github.com/labring/sealos/service/aiproxy/common" "github.com/labring/sealos/service/aiproxy/common/balance" "github.com/labring/sealos/service/aiproxy/common/config" + "github.com/labring/sealos/service/aiproxy/common/consume" "github.com/labring/sealos/service/aiproxy/controller" "github.com/labring/sealos/service/aiproxy/middleware" "github.com/labring/sealos/service/aiproxy/model" - relaycontroller "github.com/labring/sealos/service/aiproxy/relay/controller" "github.com/labring/sealos/service/aiproxy/router" log "github.com/sirupsen/logrus" ) @@ -185,8 +185,8 @@ func main() { log.Info("server shutdown successfully") } - log.Info("shutting down relay consumer...") - relaycontroller.ConsumeWaitGroup.Wait() + log.Info("shutting down consumer...") + consume.Wait() log.Info("shutting down sync services...") wg.Wait() diff --git a/service/aiproxy/middleware/distributor.go b/service/aiproxy/middleware/distributor.go index f4322e2975c..136adeaa5b2 100644 --- a/service/aiproxy/middleware/distributor.go +++ b/service/aiproxy/middleware/distributor.go @@ -46,9 +46,9 @@ func getGroupRPMRatio(group *model.GroupCache) float64 { return groupRPMRatio } -func checkGroupModelRPMAndTPM(c *gin.Context, group *model.GroupCache, requestModel string, modelRPM int64, modelTPM int64) bool { +func checkGroupModelRPMAndTPM(c *gin.Context, group *model.GroupCache, requestModel string, modelRPM int64, modelTPM int64) error { if modelRPM <= 0 { - return true + return nil } groupConsumeLevelRpmRatio := calculateGroupConsumeLevelRpmRatio(group.UsedAmount) @@ -65,10 +65,7 @@ func checkGroupModelRPMAndTPM(c *gin.Context, group *model.GroupCache, requestMo ) if !ok { - abortWithMessage(c, http.StatusTooManyRequests, - group.ID+" is requesting too frequently", - ) - return false + return fmt.Errorf("group (%s) is requesting too frequently", group.ID) } if modelTPM > 0 { @@ -76,17 +73,14 @@ func checkGroupModelRPMAndTPM(c *gin.Context, group *model.GroupCache, requestMo if err != nil { log.Errorf("get group model tpm (%s:%s) error: %s", group.ID, requestModel, err.Error()) // ignore error - return true + return nil } if tpm >= modelTPM { - abortWithMessage(c, http.StatusTooManyRequests, - group.ID+" tpm is too high", - ) - return false + return fmt.Errorf("group (%s) tpm is too high", group.ID) } } - return true + return nil } func Distribute(c *gin.Context) { @@ -111,6 +105,12 @@ func Distribute(c *gin.Context) { SetLogModelFields(log.Data, requestModel) + mc, ok := GetModelCaches(c).ModelConfigMap[requestModel] + if !ok { + abortWithMessage(c, http.StatusServiceUnavailable, requestModel+" is not available") + return + } + token := GetToken(c) if len(token.Models) == 0 || !slices.Contains(token.Models, requestModel) { abortWithMessage(c, @@ -122,13 +122,8 @@ func Distribute(c *gin.Context) { return } - mc, ok := GetModelCaches(c).ModelConfigMap[requestModel] - if !ok { - abortWithMessage(c, http.StatusServiceUnavailable, requestModel+" is not available") - return - } - - if !checkGroupModelRPMAndTPM(c, group, requestModel, mc.RPM, mc.TPM) { + if err := checkGroupModelRPMAndTPM(c, group, requestModel, mc.RPM, mc.TPM); err != nil { + abortWithMessage(c, http.StatusTooManyRequests, err.Error()) return } diff --git a/service/aiproxy/model/modelconfig.go b/service/aiproxy/model/modelconfig.go index a1b51f886be..160db525d56 100644 --- a/service/aiproxy/model/modelconfig.go +++ b/service/aiproxy/model/modelconfig.go @@ -10,6 +10,11 @@ import ( "gorm.io/gorm" ) +const ( + // /1K tokens + PriceUnit = 1000 +) + //nolint:revive type ModelConfig struct { CreatedAt time.Time `gorm:"index;autoCreateTime" json:"created_at"` diff --git a/service/aiproxy/relay/controller/consume.go b/service/aiproxy/relay/controller/consume.go index 9142b200194..b343dbc495f 100644 --- a/service/aiproxy/relay/controller/consume.go +++ b/service/aiproxy/relay/controller/consume.go @@ -2,19 +2,13 @@ package controller import ( "context" - "sync" "github.com/labring/sealos/service/aiproxy/common/balance" - "github.com/labring/sealos/service/aiproxy/middleware" "github.com/labring/sealos/service/aiproxy/model" "github.com/labring/sealos/service/aiproxy/relay/meta" - relaymodel "github.com/labring/sealos/service/aiproxy/relay/model" "github.com/shopspring/decimal" - log "github.com/sirupsen/logrus" ) -var ConsumeWaitGroup sync.WaitGroup - type PreCheckGroupBalanceReq struct { InputTokens int MaxTokens int @@ -33,7 +27,7 @@ func getPreConsumedAmount(req *PreCheckGroupBalanceReq) float64 { return decimal. NewFromInt(preConsumedTokens). Mul(decimal.NewFromFloat(req.InputPrice)). - Div(decimal.NewFromInt(PriceUnit)). + Div(decimal.NewFromInt(model.PriceUnit)). InexactFloat64() } @@ -57,115 +51,3 @@ func getGroupBalance(ctx context.Context, meta *meta.Meta) (float64, balance.Pos return balance.Default.GetGroupRemainBalance(ctx, meta.Group.ID) } - -func postConsumeAmount( - ctx context.Context, - consumeWaitGroup *sync.WaitGroup, - postGroupConsumer balance.PostGroupConsumer, - code int, - usage *relaymodel.Usage, - meta *meta.Meta, - inputPrice, - outputPrice float64, - content string, - requestDetail *model.RequestDetail, -) { - defer func() { - consumeWaitGroup.Done() - if r := recover(); r != nil { - log.Errorf("panic in post consume amount: %v", r) - } - }() - - if meta.IsChannelTest { - return - } - - log := middleware.NewLogger() - middleware.SetLogFieldsFromMeta(meta, log.Data) - - amount := calculateAmount(ctx, usage, inputPrice, outputPrice, postGroupConsumer, meta, log) - - err := recordConsume(meta, code, usage, inputPrice, outputPrice, content, requestDetail, amount) - if err != nil { - log.Error("error batch record consume: " + err.Error()) - } -} - -func calculateAmount(ctx context.Context, usage *relaymodel.Usage, inputPrice, outputPrice float64, postGroupConsumer balance.PostGroupConsumer, meta *meta.Meta, log *log.Entry) float64 { - if usage == nil { - return 0 - } - - promptTokens := usage.PromptTokens - completionTokens := usage.CompletionTokens - totalTokens := promptTokens + completionTokens - - if totalTokens == 0 { - return 0 - } - - promptAmount := decimal.NewFromInt(int64(promptTokens)). - Mul(decimal.NewFromFloat(inputPrice)). - Div(decimal.NewFromInt(PriceUnit)) - completionAmount := decimal.NewFromInt(int64(completionTokens)). - Mul(decimal.NewFromFloat(outputPrice)). - Div(decimal.NewFromInt(PriceUnit)) - amount := promptAmount.Add(completionAmount).InexactFloat64() - - if amount > 0 { - return processGroupConsume(ctx, amount, postGroupConsumer, meta, log) - } - - return 0 -} - -func processGroupConsume(ctx context.Context, amount float64, postGroupConsumer balance.PostGroupConsumer, meta *meta.Meta, log *log.Entry) float64 { - consumedAmount, err := postGroupConsumer.PostGroupConsume(ctx, meta.Token.Name, amount) - if err != nil { - log.Error("error consuming token remain amount: " + err.Error()) - if err := model.CreateConsumeError( - meta.RequestID, - meta.RequestAt, - meta.Group.ID, - meta.Token.Name, - meta.OriginModel, - err.Error(), - amount, - meta.Token.ID, - ); err != nil { - log.Error("failed to create consume error: " + err.Error()) - } - return amount - } - return consumedAmount -} - -func recordConsume(meta *meta.Meta, code int, usage *relaymodel.Usage, inputPrice, outputPrice float64, content string, requestDetail *model.RequestDetail, amount float64) error { - promptTokens := 0 - completionTokens := 0 - if usage != nil { - promptTokens = usage.PromptTokens - completionTokens = usage.CompletionTokens - } - - return model.BatchRecordConsume( - meta.RequestID, - meta.RequestAt, - meta.Group.ID, - code, - meta.Channel.ID, - promptTokens, - completionTokens, - meta.OriginModel, - meta.Token.ID, - meta.Token.Name, - amount, - inputPrice, - outputPrice, - meta.Endpoint, - content, - meta.Mode, - requestDetail, - ) -} diff --git a/service/aiproxy/relay/controller/handle.go b/service/aiproxy/relay/controller/handle.go index 68cb02cbcba..99b90696b0d 100644 --- a/service/aiproxy/relay/controller/handle.go +++ b/service/aiproxy/relay/controller/handle.go @@ -9,6 +9,7 @@ import ( "github.com/gin-gonic/gin" "github.com/labring/sealos/service/aiproxy/common" "github.com/labring/sealos/service/aiproxy/common/config" + "github.com/labring/sealos/service/aiproxy/common/consume" "github.com/labring/sealos/service/aiproxy/common/conv" "github.com/labring/sealos/service/aiproxy/middleware" "github.com/labring/sealos/service/aiproxy/model" @@ -26,15 +27,28 @@ func Handle(meta *meta.Meta, c *gin.Context, preProcess func() (*PreCheckGroupBa adaptor, ok := channeltype.GetAdaptor(meta.Channel.Type) if !ok { log.Errorf("invalid (%s[%d]) channel type: %d", meta.Channel.Name, meta.Channel.ID, meta.Channel.Type) - return openai.ErrorWrapperWithMessage("invalid channel error", "invalid_channel_type", http.StatusInternalServerError) + return openai.ErrorWrapperWithMessage( + "invalid channel error", "invalid_channel_type", http.StatusInternalServerError) } // 2. Get group balance groupRemainBalance, postGroupConsumer, err := getGroupBalance(ctx, meta) if err != nil { log.Errorf("get group (%s) balance failed: %v", meta.Group.ID, err) - return openai.ErrorWrapper( - fmt.Errorf("get group (%s) balance failed", meta.Group.ID), + errMsg := fmt.Sprintf("get group (%s) balance failed", meta.Group.ID) + consume.AsyncConsume( + context.Background(), + nil, + http.StatusInternalServerError, + nil, + meta, + 0, + 0, + errMsg, + nil, + ) + return openai.ErrorWrapperWithMessage( + errMsg, "get_group_quota_failed", http.StatusInternalServerError, ) @@ -53,9 +67,8 @@ func Handle(meta *meta.Meta, c *gin.Context, preProcess func() (*PreCheckGroupBa RequestBody: conv.BytesToString(body), } } - ConsumeWaitGroup.Add(1) - go postConsumeAmount(context.Background(), - &ConsumeWaitGroup, + consume.AsyncConsume( + context.Background(), nil, http.StatusBadRequest, nil, @@ -90,9 +103,8 @@ func Handle(meta *meta.Meta, c *gin.Context, preProcess func() (*PreCheckGroupBa log.Errorf("handle failed: %+v", respErr.Error) } - ConsumeWaitGroup.Add(1) - go postConsumeAmount(context.Background(), - &ConsumeWaitGroup, + consume.AsyncConsume( + context.Background(), postGroupConsumer, respErr.StatusCode, usage, @@ -106,9 +118,8 @@ func Handle(meta *meta.Meta, c *gin.Context, preProcess func() (*PreCheckGroupBa } // 6. Post consume - ConsumeWaitGroup.Add(1) - go postConsumeAmount(context.Background(), - &ConsumeWaitGroup, + consume.AsyncConsume( + context.Background(), postGroupConsumer, http.StatusOK, usage, diff --git a/service/aiproxy/relay/controller/price.go b/service/aiproxy/relay/controller/price.go index aa5652912ee..b36b2a29377 100644 --- a/service/aiproxy/relay/controller/price.go +++ b/service/aiproxy/relay/controller/price.go @@ -5,11 +5,6 @@ import ( "github.com/labring/sealos/service/aiproxy/model" ) -const ( - // /1K tokens - PriceUnit = 1000 -) - func GetModelPrice(modelConfig *model.ModelConfig) (float64, float64, bool) { if !config.GetBillingEnabled() { return 0, 0, true diff --git a/service/aiproxy/router/relay.go b/service/aiproxy/router/relay.go index 32f54850c77..9a55af893f2 100644 --- a/service/aiproxy/router/relay.go +++ b/service/aiproxy/router/relay.go @@ -9,22 +9,23 @@ import ( ) func SetRelayRouter(router *gin.Engine) { - router.Use(middleware.CORS()) + router.Use( + middleware.CORS(), + middleware.TokenAuth, + ) // https://platform.openai.com/docs/api-reference/introduction modelsRouter := router.Group("/v1/models") - modelsRouter.Use(middleware.TokenAuth) { modelsRouter.GET("", controller.ListModels) modelsRouter.GET("/:model", controller.RetrieveModel) } dashboardRouter := router.Group("/v1/dashboard") - dashboardRouter.Use(middleware.TokenAuth) { dashboardRouter.GET("/billing/subscription", controller.GetSubscription) dashboardRouter.GET("/billing/usage", controller.GetUsage) } relayV1Router := router.Group("/v1") - relayV1Router.Use(middleware.TokenAuth, middleware.Distribute) + relayV1Router.Use(middleware.Distribute) { relayV1Router.POST("/completions", controller.NewRelay(relaymode.Completions)) relayV1Router.POST("/chat/completions", controller.NewRelay(relaymode.ChatCompletions))