From 0eda61130c9a0f67a75be59b2611555e5d70e991 Mon Sep 17 00:00:00 2001
From: zijiren233
Date: Thu, 2 Jan 2025 17:40:26 +0800
Subject: [PATCH] feat: get rpm from redis

---
 .../common/{rate-limit.go => rpmlimit/mem.go} |   2 +-
 service/aiproxy/common/rpmlimit/rate-limit.go | 137 ++++++++++++++++++
 service/aiproxy/controller/dashboard.go       |  26 ++++
 service/aiproxy/middleware/distributor.go     |  10 +-
 service/aiproxy/middleware/rate-limit.go      |  84 -----------
 5 files changed, 168 insertions(+), 91 deletions(-)
 rename service/aiproxy/common/{rate-limit.go => rpmlimit/mem.go} (99%)
 create mode 100644 service/aiproxy/common/rpmlimit/rate-limit.go
 delete mode 100644 service/aiproxy/middleware/rate-limit.go

diff --git a/service/aiproxy/common/rate-limit.go b/service/aiproxy/common/rpmlimit/mem.go
similarity index 99%
rename from service/aiproxy/common/rate-limit.go
rename to service/aiproxy/common/rpmlimit/mem.go
index a94b6496fc2..5463d164b45 100644
--- a/service/aiproxy/common/rate-limit.go
+++ b/service/aiproxy/common/rpmlimit/mem.go
@@ -1,4 +1,4 @@
-package common
+package rpmlimit
 
 import (
 	"sync"
diff --git a/service/aiproxy/common/rpmlimit/rate-limit.go b/service/aiproxy/common/rpmlimit/rate-limit.go
new file mode 100644
index 00000000000..0c6c30da93e
--- /dev/null
+++ b/service/aiproxy/common/rpmlimit/rate-limit.go
@@ -0,0 +1,137 @@
+package rpmlimit
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	"github.com/labring/sealos/service/aiproxy/common"
+	"github.com/labring/sealos/service/aiproxy/common/config"
+	log "github.com/sirupsen/logrus"
+)
+
+var inMemoryRateLimiter InMemoryRateLimiter
+
+const (
+	groupModelRPMKey = "group_model_rpm:%s:%s"
+)
+
+// 1. Use a Redis list to store request timestamps
+// 2. The list length represents the number of requests in the current window
+// 3. If the request count is below the limit, push the new request and return success
+// 4. If the limit has been reached, check whether the oldest request has expired
+// 5. If the oldest request has expired, remove it and add the new one; otherwise reject the new request
+// 6. Set a key expiration via PEXPIRE so stale data is cleaned up automatically
+var luaScript = `
+local key = KEYS[1]
+local max_requests = tonumber(ARGV[1])
+local window = tonumber(ARGV[2])
+local current_time = tonumber(ARGV[3])
+
+local count = redis.call('LLEN', key)
+
+if count < max_requests then
+	redis.call('LPUSH', key, current_time)
+	redis.call('PEXPIRE', key, window)
+	return 1
+else
+	local oldest = redis.call('LINDEX', key, -1)
+	if current_time - tonumber(oldest) >= window then
+		redis.call('LPUSH', key, current_time)
+		redis.call('LTRIM', key, 0, max_requests - 1)
+		redis.call('PEXPIRE', key, window)
+		return 1
+	else
+		return 0
+	end
+end
+`
+
+var getRPMSumLuaScript = `
+local pattern = ARGV[1]
+local window = tonumber(ARGV[2])
+local current_time = tonumber(ARGV[3])
+
+local keys = redis.call('KEYS', pattern)
+local total = 0
+
+for _, key in ipairs(keys) do
+	local timestamps = redis.call('LRANGE', key, 0, -1)
+	for _, ts in ipairs(timestamps) do
+		if current_time - tonumber(ts) < window then
+			total = total + 1
+		end
+	end
+end
+
+return total
+`
+
+func GetRPM(ctx context.Context, group, model string) (int64, error) {
+	if !common.RedisEnabled {
+		return 0, nil
+	}
+
+	var pattern string
+	if group == "" && model == "" {
+		pattern = "group_model_rpm:*:*"
+	} else if group == "" {
+		pattern = "group_model_rpm:*:" + model
+	} else if model == "" {
+		pattern = fmt.Sprintf("group_model_rpm:%s:*", group)
+	} else {
+		pattern = fmt.Sprintf("group_model_rpm:%s:%s", group, model)
+	}
+
+	rdb := common.RDB
+	currentTime := time.Now().UnixMilli()
+	result, err := rdb.Eval(ctx, getRPMSumLuaScript, []string{}, pattern, time.Minute.Milliseconds(), currentTime).Int64()
+	if err != nil {
+		return 0, err
+	}
+
+	return result, nil
+}
+
+func redisRateLimitRequest(ctx context.Context, group, model string, maxRequestNum int64, duration time.Duration) (bool, error) {
+	rdb := common.RDB
+	currentTime := time.Now().UnixMilli()
+	result, err := rdb.Eval(ctx, luaScript, []string{
+		fmt.Sprintf(groupModelRPMKey, group, model),
+	}, maxRequestNum, duration.Milliseconds(), currentTime).Int64()
+	if err != nil {
+		return false, err
+	}
+	return result == 1, nil
+}
+
+func RateLimit(ctx context.Context, group, model string, maxRequestNum int64, duration time.Duration) (bool, error) {
+	if maxRequestNum == 0 {
+		return true, nil
+	}
+	if common.RedisEnabled {
+		return redisRateLimitRequest(ctx, group, model, maxRequestNum, duration)
+	}
+	return MemoryRateLimit(ctx, group, model, maxRequestNum, duration), nil
+}
+
+// ForceRateLimit ignores Redis errors and falls back to the in-memory limiter
+func ForceRateLimit(ctx context.Context, group, model string, maxRequestNum int64, duration time.Duration) bool {
+	if maxRequestNum == 0 {
+		return true
+	}
+	if common.RedisEnabled {
+		ok, err := redisRateLimitRequest(ctx, group, model, maxRequestNum, duration)
+		if err == nil {
+			return ok
+		}
+		log.Error("rate limit error: " + err.Error())
+	}
+	return MemoryRateLimit(ctx, group, model, maxRequestNum, duration)
+}
+
+func MemoryRateLimit(_ context.Context, group, model string, maxRequestNum int64, duration time.Duration) bool {
+	// It's safe to call multiple times.
+	inMemoryRateLimiter.Init(config.RateLimitKeyExpirationDuration)
+	return inMemoryRateLimiter.Request(fmt.Sprintf(groupModelRPMKey, group, model), int(maxRequestNum), duration)
+}
diff --git a/service/aiproxy/controller/dashboard.go b/service/aiproxy/controller/dashboard.go
index 3bb54f1f11b..0104aaefe7c 100644
--- a/service/aiproxy/controller/dashboard.go
+++ b/service/aiproxy/controller/dashboard.go
@@ -6,6 +6,8 @@ import (
 	"time"
 
 	"github.com/gin-gonic/gin"
+	"github.com/labring/sealos/service/aiproxy/common"
+	"github.com/labring/sealos/service/aiproxy/common/rpmlimit"
 	"github.com/labring/sealos/service/aiproxy/middleware"
 	"github.com/labring/sealos/service/aiproxy/model"
 )
@@ -111,6 +113,8 @@ func getTimeSpanWithDefault(c *gin.Context, defaultTimeSpan time.Duration) time.
 }
 
 func GetDashboard(c *gin.Context) {
+	log := middleware.GetLogger(c)
+
 	start, end, timeSpan := getDashboardTime(c.Query("type"))
 	modelName := c.Query("model")
 	timeSpan = getTimeSpanWithDefault(c, timeSpan)
@@ -122,10 +126,22 @@ func GetDashboard(c *gin.Context) {
 	}
 
 	dashboards.ChartData = fillGaps(dashboards.ChartData, start, end, timeSpan)
+
+	if common.RedisEnabled {
+		rpm, err := rpmlimit.GetRPM(c.Request.Context(), "", modelName)
+		if err != nil {
+			log.Errorf("failed to get rpm: %v", err)
+		} else {
+			dashboards.RPM = rpm
+		}
+	}
+
 	middleware.SuccessResponse(c, dashboards)
 }
 
 func GetGroupDashboard(c *gin.Context) {
+	log := middleware.GetLogger(c)
+
 	group := c.Param("group")
 	if group == "" {
 		middleware.ErrorResponse(c, http.StatusOK, "invalid parameter")
@@ -144,5 +160,15 @@ func GetGroupDashboard(c *gin.Context) {
 	}
 
 	dashboards.ChartData = fillGaps(dashboards.ChartData, start, end, timeSpan)
+
+	if common.RedisEnabled && tokenName == "" {
+		rpm, err := rpmlimit.GetRPM(c.Request.Context(), group, modelName)
+		if err != nil {
+			log.Errorf("failed to get rpm: %v", err)
+		} else {
+			dashboards.RPM = rpm
+		}
+	}
+
 	middleware.SuccessResponse(c, dashboards)
 }
diff --git a/service/aiproxy/middleware/distributor.go b/service/aiproxy/middleware/distributor.go
index ac2ce9b8219..911b4df007c 100644
--- a/service/aiproxy/middleware/distributor.go
+++ b/service/aiproxy/middleware/distributor.go
@@ -9,15 +9,12 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/labring/sealos/service/aiproxy/common/config"
 	"github.com/labring/sealos/service/aiproxy/common/ctxkey"
+	"github.com/labring/sealos/service/aiproxy/common/rpmlimit"
 	"github.com/labring/sealos/service/aiproxy/model"
 	"github.com/labring/sealos/service/aiproxy/relay/meta"
 	log "github.com/sirupsen/logrus"
 )
 
-const (
-	groupModelRPMKey = "group_model_rpm:%s:%s"
-)
-
 type ModelRequest struct {
 	Model string `form:"model" json:"model"`
 }
@@ -59,9 +56,10 @@ func checkGroupModelRPMAndTPM(c *gin.Context, group *model.GroupCache, requestMo
 
 	adjustedModelRPM := int64(float64(modelRPM) * groupRPMRatio * groupConsumeLevelRpmRatio)
 
-	ok := ForceRateLimit(
+	ok := rpmlimit.ForceRateLimit(
 		c.Request.Context(),
-		fmt.Sprintf(groupModelRPMKey, group.ID, requestModel),
+		group.ID,
+		requestModel,
 		adjustedModelRPM,
 		time.Minute,
 	)
diff --git a/service/aiproxy/middleware/rate-limit.go b/service/aiproxy/middleware/rate-limit.go
deleted file mode 100644
index 103138e41c0..00000000000
--- a/service/aiproxy/middleware/rate-limit.go
+++ /dev/null
@@ -1,84 +0,0 @@
-package middleware
-
-import (
-	"context"
-	"time"
-
-	"github.com/labring/sealos/service/aiproxy/common"
-	"github.com/labring/sealos/service/aiproxy/common/config"
-	log "github.com/sirupsen/logrus"
-)
-
-var inMemoryRateLimiter common.InMemoryRateLimiter
-
-// 1. 使用Redis列表存储请求时间戳
-// 2. 列表长度代表当前窗口内的请求数
-// 3. 如果请求数未达到限制,直接添加新请求并返回成功
-// 4. 如果达到限制,则检查最老的请求是否已经过期
-// 5. 如果最老的请求已过期,移除它并添加新请求,否则拒绝新请求
-// 6. 通过EXPIRE命令设置键的过期时间,自动清理过期数据
-var luaScript = `
-local key = KEYS[1]
-local max_requests = tonumber(ARGV[1])
-local window = tonumber(ARGV[2])
-local current_time = tonumber(ARGV[3])
-
-local count = redis.call('LLEN', key)
-
-if count < max_requests then
-	redis.call('LPUSH', key, current_time)
-	redis.call('PEXPIRE', key, window)
-	return 1
-else
-	local oldest = redis.call('LINDEX', key, -1)
-	if current_time - tonumber(oldest) >= window then
-		redis.call('LPUSH', key, current_time)
-		redis.call('LTRIM', key, 0, max_requests - 1)
-		redis.call('PEXPIRE', key, window)
-		return 1
-	else
-		return 0
-	end
-end
-`
-
-func redisRateLimitRequest(ctx context.Context, key string, maxRequestNum int64, duration time.Duration) (bool, error) {
-	rdb := common.RDB
-	currentTime := time.Now().UnixMilli()
-	result, err := rdb.Eval(ctx, luaScript, []string{key}, maxRequestNum, duration.Milliseconds(), currentTime).Int64()
-	if err != nil {
-		return false, err
-	}
-	return result == 1, nil
-}
-
-func RateLimit(ctx context.Context, key string, maxRequestNum int64, duration time.Duration) (bool, error) {
-	if maxRequestNum == 0 {
-		return true, nil
-	}
-	if common.RedisEnabled {
-		return redisRateLimitRequest(ctx, key, maxRequestNum, duration)
-	}
-	return MemoryRateLimit(ctx, key, maxRequestNum, duration), nil
-}
-
-// ignore redis error
-func ForceRateLimit(ctx context.Context, key string, maxRequestNum int64, duration time.Duration) bool {
-	if maxRequestNum == 0 {
-		return true
-	}
-	if common.RedisEnabled {
-		ok, err := redisRateLimitRequest(ctx, key, maxRequestNum, duration)
-		if err == nil {
-			return ok
-		}
-		log.Error("rate limit error: " + err.Error())
-	}
-	return MemoryRateLimit(ctx, key, maxRequestNum, duration)
-}
-
-func MemoryRateLimit(_ context.Context, key string, maxRequestNum int64, duration time.Duration) bool {
-	// It's safe to call multi times.
-	inMemoryRateLimiter.Init(config.RateLimitKeyExpirationDuration)
-	return inMemoryRateLimiter.Request(key, int(maxRequestNum), duration)
-}