Skip to content

Commit

Permalink
feat: consume
Browse files Browse the repository at this point in the history
  • Loading branch information
zijiren233 committed Jan 4, 2025
1 parent 40a2365 commit cc30b33
Show file tree
Hide file tree
Showing 8 changed files with 213 additions and 162 deletions.
162 changes: 162 additions & 0 deletions service/aiproxy/common/consume/consume.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
package consume

import (
"context"
"sync"

"github.com/labring/sealos/service/aiproxy/common/balance"
"github.com/labring/sealos/service/aiproxy/model"
"github.com/labring/sealos/service/aiproxy/relay/meta"
relaymodel "github.com/labring/sealos/service/aiproxy/relay/model"
"github.com/shopspring/decimal"
log "github.com/sirupsen/logrus"
)

// consumeWaitGroup tracks in-flight asynchronous consume goroutines so the
// process can drain them before exiting (see Wait, called from main on shutdown).
var consumeWaitGroup sync.WaitGroup

// Wait blocks until every pending asynchronous consume operation has finished.
// It is intended to be called once during graceful shutdown.
func Wait() {
	consumeWaitGroup.Wait()
}

// AsyncConsume records billing for a finished request on a background
// goroutine and returns immediately. The goroutine is registered with
// consumeWaitGroup so Wait() can drain it during graceful shutdown.
// Channel self-test requests are never billed.
//
// Parameters mirror Consume; see Consume for their meaning.
func AsyncConsume(
	ctx context.Context,
	postGroupConsumer balance.PostGroupConsumer,
	code int,
	usage *relaymodel.Usage,
	meta *meta.Meta,
	inputPrice,
	outputPrice float64,
	content string,
	requestDetail *model.RequestDetail,
) {
	if meta.IsChannelTest {
		return
	}

	consumeWaitGroup.Add(1)
	// BUG FIX: Done and recover must run inside the spawned goroutine.
	// Previously they were deferred in this (calling) goroutine, so Done fired
	// as soon as AsyncConsume returned — Wait() never actually waited for the
	// consume work — and recover() could not catch a panic from Consume,
	// because recover only works in the goroutine that panics.
	go func() {
		defer func() {
			if r := recover(); r != nil {
				log.Errorf("panic in consume: %v", r)
			}
			consumeWaitGroup.Done()
		}()
		Consume(ctx, postGroupConsumer, code, usage, meta, inputPrice, outputPrice, content, requestDetail)
	}()
}

// Consume synchronously bills a finished request: it converts token usage
// into a monetary amount, charges it against the group's balance, and writes
// a consumption record. Requests issued by channel self-tests are skipped.
//
// code is the upstream HTTP status, usage may be nil (treated as zero tokens),
// inputPrice/outputPrice are per model.PriceUnit tokens, and requestDetail
// optionally carries the raw request/response for auditing.
func Consume(
	ctx context.Context,
	postGroupConsumer balance.PostGroupConsumer,
	code int,
	usage *relaymodel.Usage,
	meta *meta.Meta,
	inputPrice,
	outputPrice float64,
	content string,
	requestDetail *model.RequestDetail,
) {
	// Never generate billing records for channel self-tests.
	if meta.IsChannelTest {
		return
	}

	billed := calculateAmount(ctx, usage, inputPrice, outputPrice, postGroupConsumer, meta)

	if err := recordConsume(meta, code, usage, inputPrice, outputPrice, content, requestDetail, billed); err != nil {
		log.Error("error batch record consume: " + err.Error())
	}
}

// calculateAmount converts token usage into a monetary amount and, when that
// amount is positive, settles it against the group's balance via
// processGroupConsume. It returns the amount actually charged, or 0 when
// usage is nil, no tokens were consumed, or the computed amount is not
// positive. Prices are expressed per model.PriceUnit tokens; decimal
// arithmetic is used to avoid float rounding during the computation.
func calculateAmount(
	ctx context.Context,
	usage *relaymodel.Usage,
	inputPrice, outputPrice float64,
	postGroupConsumer balance.PostGroupConsumer,
	meta *meta.Meta,
) float64 {
	if usage == nil {
		return 0
	}

	prompt := usage.PromptTokens
	completion := usage.CompletionTokens
	if prompt+completion == 0 {
		return 0
	}

	unit := decimal.NewFromInt(model.PriceUnit)
	total := decimal.NewFromInt(int64(prompt)).
		Mul(decimal.NewFromFloat(inputPrice)).
		Div(unit).
		Add(decimal.NewFromInt(int64(completion)).
			Mul(decimal.NewFromFloat(outputPrice)).
			Div(unit))

	if amount := total.InexactFloat64(); amount > 0 {
		return processGroupConsume(ctx, amount, postGroupConsumer, meta)
	}
	return 0
}

// processGroupConsume charges amount against the group's balance and returns
// the amount the balance service reports as actually consumed. If the charge
// fails, the failure is persisted as a consume-error record (for later
// reconciliation) and the originally computed amount is returned instead.
func processGroupConsume(
	ctx context.Context,
	amount float64,
	postGroupConsumer balance.PostGroupConsumer,
	meta *meta.Meta,
) float64 {
	consumedAmount, err := postGroupConsumer.PostGroupConsume(ctx, meta.Token.Name, amount)
	if err == nil {
		return consumedAmount
	}

	log.Error("error consuming token remain amount: " + err.Error())
	// Persist the failed charge so it can be reconciled out of band.
	if recordErr := model.CreateConsumeError(
		meta.RequestID,
		meta.RequestAt,
		meta.Group.ID,
		meta.Token.Name,
		meta.OriginModel,
		err.Error(),
		amount,
		meta.Token.ID,
	); recordErr != nil {
		log.Error("failed to create consume error: " + recordErr.Error())
	}
	// Fall back to the computed amount when the balance service is unavailable.
	return amount
}

// recordConsume writes one consumption record via model.BatchRecordConsume.
// A nil usage is recorded as zero prompt/completion tokens, and a nil
// meta.Channel is recorded as channel ID 0. amount is the final billed value
// (already settled against the group's balance by the caller).
func recordConsume(meta *meta.Meta, code int, usage *relaymodel.Usage, inputPrice, outputPrice float64, content string, requestDetail *model.RequestDetail, amount float64) error {
	var promptTokens, completionTokens int
	if usage != nil {
		promptTokens = usage.PromptTokens
		completionTokens = usage.CompletionTokens
	}

	var channelID int
	if ch := meta.Channel; ch != nil {
		channelID = ch.ID
	}

	return model.BatchRecordConsume(
		meta.RequestID,
		meta.RequestAt,
		meta.Group.ID,
		code,
		channelID,
		promptTokens,
		completionTokens,
		meta.OriginModel,
		meta.Token.ID,
		meta.Token.Name,
		amount,
		inputPrice,
		outputPrice,
		meta.Endpoint,
		content,
		meta.Mode,
		requestDetail,
	)
}
6 changes: 3 additions & 3 deletions service/aiproxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ import (
"github.com/labring/sealos/service/aiproxy/common"
"github.com/labring/sealos/service/aiproxy/common/balance"
"github.com/labring/sealos/service/aiproxy/common/config"
"github.com/labring/sealos/service/aiproxy/common/consume"
"github.com/labring/sealos/service/aiproxy/controller"
"github.com/labring/sealos/service/aiproxy/middleware"
"github.com/labring/sealos/service/aiproxy/model"
relaycontroller "github.com/labring/sealos/service/aiproxy/relay/controller"
"github.com/labring/sealos/service/aiproxy/router"
log "github.com/sirupsen/logrus"
)
Expand Down Expand Up @@ -185,8 +185,8 @@ func main() {
log.Info("server shutdown successfully")
}

log.Info("shutting down relay consumer...")
relaycontroller.ConsumeWaitGroup.Wait()
log.Info("shutting down consumer...")
consume.Wait()

log.Info("shutting down sync services...")
wg.Wait()
Expand Down
33 changes: 14 additions & 19 deletions service/aiproxy/middleware/distributor.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ func getGroupRPMRatio(group *model.GroupCache) float64 {
return groupRPMRatio
}

func checkGroupModelRPMAndTPM(c *gin.Context, group *model.GroupCache, requestModel string, modelRPM int64, modelTPM int64) bool {
func checkGroupModelRPMAndTPM(c *gin.Context, group *model.GroupCache, requestModel string, modelRPM int64, modelTPM int64) error {
if modelRPM <= 0 {
return true
return nil
}

groupConsumeLevelRpmRatio := calculateGroupConsumeLevelRpmRatio(group.UsedAmount)
Expand All @@ -65,28 +65,22 @@ func checkGroupModelRPMAndTPM(c *gin.Context, group *model.GroupCache, requestMo
)

if !ok {
abortWithMessage(c, http.StatusTooManyRequests,
group.ID+" is requesting too frequently",
)
return false
return fmt.Errorf("group (%s) is requesting too frequently", group.ID)
}

if modelTPM > 0 {
tpm, err := model.CacheGetGroupModelTPM(group.ID, requestModel)
if err != nil {
log.Errorf("get group model tpm (%s:%s) error: %s", group.ID, requestModel, err.Error())
// ignore error
return true
return nil
}

if tpm >= modelTPM {
abortWithMessage(c, http.StatusTooManyRequests,
group.ID+" tpm is too high",
)
return false
return fmt.Errorf("group (%s) tpm is too high", group.ID)
}
}
return true
return nil
}

func Distribute(c *gin.Context) {
Expand All @@ -111,6 +105,12 @@ func Distribute(c *gin.Context) {

SetLogModelFields(log.Data, requestModel)

mc, ok := GetModelCaches(c).ModelConfigMap[requestModel]
if !ok {
abortWithMessage(c, http.StatusServiceUnavailable, requestModel+" is not available")
return
}

token := GetToken(c)
if len(token.Models) == 0 || !slices.Contains(token.Models, requestModel) {
abortWithMessage(c,
Expand All @@ -122,13 +122,8 @@ func Distribute(c *gin.Context) {
return
}

mc, ok := GetModelCaches(c).ModelConfigMap[requestModel]
if !ok {
abortWithMessage(c, http.StatusServiceUnavailable, requestModel+" is not available")
return
}

if !checkGroupModelRPMAndTPM(c, group, requestModel, mc.RPM, mc.TPM) {
if err := checkGroupModelRPMAndTPM(c, group, requestModel, mc.RPM, mc.TPM); err != nil {
abortWithMessage(c, http.StatusTooManyRequests, err.Error())
return
}

Expand Down
5 changes: 5 additions & 0 deletions service/aiproxy/model/modelconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ import (
"gorm.io/gorm"
)

const (
	// PriceUnit is the token denomination that model prices are quoted in:
	// inputPrice/outputPrice are amounts per 1K tokens (see consume.calculateAmount).
	PriceUnit = 1000
)

//nolint:revive
type ModelConfig struct {
CreatedAt time.Time `gorm:"index;autoCreateTime" json:"created_at"`
Expand Down
Loading

0 comments on commit cc30b33

Please sign in to comment.