Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allowlist with specified rate limits #156

Merged
merged 1 commit into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 2 additions & 16 deletions common/ratelimit/limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package ratelimit

import (
"context"
"strings"
"time"

"github.com/Layr-Labs/eigenda/common"
Expand All @@ -12,35 +11,23 @@ type BucketStore = common.KVStore[common.RateBucketParams]

type rateLimiter struct {
globalRateParams common.GlobalRateParams

bucketStore BucketStore
allowlist []string
bucketStore BucketStore

logger common.Logger
}

func NewRateLimiter(rateParams common.GlobalRateParams, bucketStore BucketStore, allowlist []string, logger common.Logger) common.RateLimiter {
func NewRateLimiter(rateParams common.GlobalRateParams, bucketStore BucketStore, logger common.Logger) common.RateLimiter {
return &rateLimiter{
globalRateParams: rateParams,
bucketStore: bucketStore,
allowlist: allowlist,
logger: logger,
}
}

// Checks whether a request from the given requesterID is allowed
func (d *rateLimiter) AllowRequest(ctx context.Context, requesterID common.RequesterID, blobSize uint, rate common.RateParam) (bool, error) {
// TODO: temporary allowlist that unconditionally allows request
// for testing purposes only
for _, id := range d.allowlist {
if strings.Contains(requesterID, id) {
return true, nil
}
}

// Retrieve bucket params for the requester ID
// This will be from dynamo for Disperser and from local storage for DA node

bucketParams, err := d.bucketStore.GetItem(ctx, requesterID)
if err != nil {

Expand Down Expand Up @@ -68,7 +55,6 @@ func (d *rateLimiter) AllowRequest(ctx context.Context, requesterID common.Reque

// Update the bucket level
bucketParams.BucketLevels[i] = getBucketLevel(bucketParams.BucketLevels[i], size, interval, deduction)

allowed = allowed && bucketParams.BucketLevels[i] > 0
}

Expand Down
10 changes: 0 additions & 10 deletions common/ratelimit/limiter_cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,12 @@ const (
BucketMultipliersFlagName = "bucket-multipliers"
CountFailedFlagName = "count-failed"
BucketStoreSizeFlagName = "bucket-store-size"
AllowlistFlagName = "allowlist"
)

type Config struct {
common.GlobalRateParams
BucketStoreSize int
UniformRateParam common.RateParam
Allowlist []string
}

func RatelimiterCLIFlags(envPrefix string, flagPrefix string) []cli.Flag {
Expand Down Expand Up @@ -54,13 +52,6 @@ func RatelimiterCLIFlags(envPrefix string, flagPrefix string) []cli.Flag {
EnvVar: common.PrefixEnvVar(envPrefix, "BUCKET_STORE_SIZE"),
Required: false,
},
cli.StringSliceFlag{
Name: common.PrefixFlag(flagPrefix, AllowlistFlagName),
Usage: "Allowlist of IPs to bypass rate limiting",
EnvVar: common.PrefixEnvVar(envPrefix, "ALLOWLIST"),
Required: false,
Value: &cli.StringSlice{},
},
}
}

Expand Down Expand Up @@ -106,7 +97,6 @@ func ReadCLIConfig(ctx *cli.Context, flagPrefix string) (Config, error) {
cfg.Multipliers = multipliers
cfg.GlobalRateParams.CountFailed = ctx.Bool(common.PrefixFlag(flagPrefix, CountFailedFlagName))
cfg.BucketStoreSize = ctx.Int(common.PrefixFlag(flagPrefix, BucketStoreSizeFlagName))
cfg.Allowlist = ctx.StringSlice(common.PrefixFlag(flagPrefix, AllowlistFlagName))

err := validateConfig(cfg)
if err != nil {
Expand Down
18 changes: 1 addition & 17 deletions common/ratelimit/ratelimit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func makeTestRatelimiter() (common.RateLimiter, error) {
return nil, err
}

ratelimiter := ratelimit.NewRateLimiter(globalParams, bucketStore, []string{"testRetriever2"}, &mock.Logger{})
ratelimiter := ratelimit.NewRateLimiter(globalParams, bucketStore, &mock.Logger{})

return ratelimiter, nil

Expand All @@ -50,19 +50,3 @@ func TestRatelimit(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, false, allow)
}

func TestRatelimitAllowlist(t *testing.T) {
ratelimiter, err := makeTestRatelimiter()
assert.NoError(t, err)

ctx := context.Background()

retreiverID := "testRetriever2"

// 10x more requests allowed for allowlisted IDs
for i := 0; i < 100; i++ {
allow, err := ratelimiter.AllowRequest(ctx, retreiverID, 10, 100)
assert.NoError(t, err)
assert.Equal(t, true, allow)
}
}
51 changes: 51 additions & 0 deletions disperser/apiserver/rate_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package apiserver

import (
"fmt"
"log"
"strconv"
"strings"

"github.com/Layr-Labs/eigenda/common"
"github.com/Layr-Labs/eigenda/core"
Expand All @@ -16,6 +18,7 @@ const (
TotalUnauthBlobRateFlagName = "auth.total-unauth-blob-rate"
PerUserUnauthBlobRateFlagName = "auth.per-user-unauth-blob-rate"
ClientIPHeaderFlagName = "auth.client-ip-header"
AllowlistFlagName = "auth.allowlist"

// We allow the user to specify the blob rate in blobs/sec, but internally we use blobs/sec * 1e6 (i.e. blobs/microsec).
// This is because the rate limiter takes an integer rate.
Expand All @@ -29,9 +32,17 @@ type QuorumRateInfo struct {
TotalUnauthBlobRate common.RateParam
}

type PerUserRateInfo struct {
Throughput common.RateParam
BlobRate common.RateParam
}

type Allowlist = map[string]map[core.QuorumID]PerUserRateInfo

type RateConfig struct {
QuorumRateInfos map[core.QuorumID]QuorumRateInfo
ClientIPHeader string
Allowlist Allowlist
}

func CLIFlags(envPrefix string) []cli.Flag {
Expand Down Expand Up @@ -73,6 +84,13 @@ func CLIFlags(envPrefix string) []cli.Flag {
Value: "",
EnvVar: common.PrefixEnvVar(envPrefix, "CLIENT_IP_HEADER"),
},
cli.StringSliceFlag{
Name: AllowlistFlagName,
Usage: "Allowlist of IPs and corresponding blob/byte rates to bypass rate limiting. Format: <IP>:<quorum ID>:<blob rate>:<byte rate>. Example: 127.0.0.1:0:10:10485760",
EnvVar: common.PrefixEnvVar(envPrefix, "ALLOWLIST"),
Required: false,
Value: &cli.StringSlice{},
},
}
}

Expand Down Expand Up @@ -112,8 +130,41 @@ func ReadCLIConfig(c *cli.Context) (RateConfig, error) {
}
}

// Parse allowlist
allowlist := make(Allowlist)
for _, allowlistEntry := range c.StringSlice(AllowlistFlagName) {
allowlistEntrySplit := strings.Split(allowlistEntry, ":")
if len(allowlistEntrySplit) != 4 {
log.Printf("invalid allowlist entry: entry should contain exactly 4 elements: %s", allowlistEntry)
continue
}
ip := allowlistEntrySplit[0]
quorumID, err := strconv.Atoi(allowlistEntrySplit[1])
if err != nil {
log.Printf("invalid allowlist entry: failed to convert quorum ID from string: %s", allowlistEntry)
continue
}
blobRate, err := strconv.ParseFloat(allowlistEntrySplit[2], 64)
if err != nil {
log.Printf("invalid allowlist entry: failed to convert blob rate from string: %s", allowlistEntry)
continue
}
byteRate, err := strconv.ParseFloat(allowlistEntrySplit[3], 64)
if err != nil {
log.Printf("invalid allowlist entry: failed to convert throughput from string: %s", allowlistEntry)
continue
}
allowlist[ip] = map[core.QuorumID]PerUserRateInfo{
core.QuorumID(quorumID): {
Throughput: common.RateParam(byteRate),
BlobRate: common.RateParam(blobRate * blobRateMultiplier),
},
}
}

return RateConfig{
QuorumRateInfos: quorumRateInfos,
ClientIPHeader: c.String(ClientIPHeaderFlagName),
Allowlist: allowlist,
}, nil
}
80 changes: 65 additions & 15 deletions disperser/apiserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"net"
"strings"
"sync"
"time"

Expand All @@ -18,8 +19,10 @@ import (
"google.golang.org/grpc/reflection"
)

var errSystemRateLimit = fmt.Errorf("request ratelimited: system limit")
var errAccountRateLimit = fmt.Errorf("request ratelimited: account limit")
var errSystemBlobRateLimit = fmt.Errorf("request ratelimited: system blob limit")
var errSystemThroughputRateLimit = fmt.Errorf("request ratelimited: system throughput limit")
var errAccountBlobRateLimit = fmt.Errorf("request ratelimited: account blob limit")
var errAccountThroughputRateLimit = fmt.Errorf("request ratelimited: account throughput limit")

const systemAccountKey = "system"

Expand Down Expand Up @@ -55,6 +58,11 @@ func NewDispersalServer(
ratelimiter common.RateLimiter,
rateConfig RateConfig,
) *DispersalServer {
for ip, rateInfoByQuorum := range rateConfig.Allowlist {
for quorumID, rateInfo := range rateInfoByQuorum {
logger.Info("[Allowlist]", "ip", ip, "quorumID", quorumID, "throughput", rateInfo.Throughput, "blobRate", rateInfo.BlobRate)
}
}
return &DispersalServer{
config: config,
blobStore: store,
Expand Down Expand Up @@ -139,9 +147,9 @@ func (s *DispersalServer) DisperseBlob(ctx context.Context, req *pb.DisperseBlob
if err != nil {
for _, param := range securityParams {
quorumId := string(uint8(param.GetQuorumId()))
if errors.Is(err, errSystemRateLimit) {
if errors.Is(err, errSystemBlobRateLimit) || errors.Is(err, errSystemThroughputRateLimit) {
s.metrics.HandleSystemRateLimitedRequest(quorumId, blobSize, "DisperseBlob")
} else if errors.Is(err, errAccountRateLimit) {
} else if errors.Is(err, errAccountBlobRateLimit) || errors.Is(err, errAccountThroughputRateLimit) {
s.metrics.HandleAccountRateLimitedRequest(quorumId, blobSize, "DisperseBlob")
} else {
s.metrics.HandleFailedRequest(quorumId, blobSize, "DisperseBlob")
Expand Down Expand Up @@ -173,23 +181,65 @@ func (s *DispersalServer) DisperseBlob(ctx context.Context, req *pb.DisperseBlob
}, nil
}

func (s *DispersalServer) getAccountRate(origin string, quorumID core.QuorumID) (*PerUserRateInfo, error) {
unauthRates, ok := s.rateConfig.QuorumRateInfos[quorumID]
if !ok {
return nil, fmt.Errorf("no configured rate exists for quorum %d", quorumID)
}

for ip, rateInfoByQuorum := range s.rateConfig.Allowlist {
if !strings.Contains(origin, ip) {
continue
}

rateInfo, ok := rateInfoByQuorum[quorumID]
if !ok {
continue
}

throughput := unauthRates.PerUserUnauthThroughput
if rateInfo.Throughput > 0 {
throughput = rateInfo.Throughput
}

blobRate := unauthRates.PerUserUnauthBlobRate
if rateInfo.BlobRate > 0 {
blobRate = rateInfo.BlobRate
}

return &PerUserRateInfo{
Throughput: throughput,
BlobRate: blobRate,
}, nil
}

return &PerUserRateInfo{
Throughput: unauthRates.PerUserUnauthThroughput,
BlobRate: unauthRates.PerUserUnauthBlobRate,
}, nil
}

func (s *DispersalServer) checkRateLimitsAndAddRates(ctx context.Context, blob *core.Blob, origin string) error {

// TODO(robert): Remove these locks once we have resolved ratelimiting approach
s.mu.Lock()
defer s.mu.Unlock()

for _, param := range blob.RequestHeader.SecurityParams {
for i, param := range blob.RequestHeader.SecurityParams {

rates, ok := s.rateConfig.QuorumRateInfos[param.QuorumID]
if !ok {
return fmt.Errorf("no configured rate exists for quorum %d", param.QuorumID)
}
accountRates, err := s.getAccountRate(origin, param.QuorumID)
if err != nil {
return err
}

// Get the encoded blob size from the blob header. Calculation is done in a way that nodes can replicate
blobSize := len(blob.Data)
length := core.GetBlobLength(uint(blobSize))
encodedLength := core.GetEncodedBlobLength(length, uint8(blob.RequestHeader.SecurityParams[param.QuorumID].QuorumThreshold), uint8(blob.RequestHeader.SecurityParams[param.QuorumID].AdversaryThreshold))
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's no guarantee that the security params are ordered by the quorum ID

encodedLength := core.GetEncodedBlobLength(length, uint8(param.QuorumThreshold), uint8(param.AdversaryThreshold))
encodedSize := core.GetBlobSize(encodedLength)

s.logger.Debug("checking rate limits", "origin", origin, "quorum", param.QuorumID, "encodedSize", encodedSize, "blobSize", blobSize)
Expand All @@ -202,7 +252,7 @@ func (s *DispersalServer) checkRateLimitsAndAddRates(ctx context.Context, blob *
}
if !allowed {
s.logger.Warn("system byte ratelimit exceeded", "systemQuorumKey", systemQuorumKey, "rate", rates.TotalUnauthThroughput)
return errSystemRateLimit
return errSystemThroughputRateLimit
}

systemQuorumKey = fmt.Sprintf("%s:%d-blobrate", systemAccountKey, param.QuorumID)
Expand All @@ -212,35 +262,35 @@ func (s *DispersalServer) checkRateLimitsAndAddRates(ctx context.Context, blob *
}
if !allowed {
s.logger.Warn("system blob ratelimit exceeded", "systemQuorumKey", systemQuorumKey, "rate", float32(rates.TotalUnauthBlobRate)/blobRateMultiplier)
return errSystemRateLimit
return errSystemBlobRateLimit
}

// Check Account Ratelimit

blob.RequestHeader.AccountID = "ip:" + origin

userQuorumKey := fmt.Sprintf("%s:%d", blob.RequestHeader.AccountID, param.QuorumID)
allowed, err = s.ratelimiter.AllowRequest(ctx, userQuorumKey, encodedSize, rates.PerUserUnauthThroughput)
allowed, err = s.ratelimiter.AllowRequest(ctx, userQuorumKey, encodedSize, accountRates.Throughput)
if err != nil {
return fmt.Errorf("ratelimiter error: %v", err)
}
if !allowed {
s.logger.Warn("account byte ratelimit exceeded", "userQuorumKey", userQuorumKey, "rate", rates.PerUserUnauthThroughput)
return errAccountRateLimit
s.logger.Warn("account byte ratelimit exceeded", "userQuorumKey", userQuorumKey, "rate", accountRates.Throughput)
return errAccountThroughputRateLimit
}

userQuorumKey = fmt.Sprintf("%s:%d-blobrate", blob.RequestHeader.AccountID, param.QuorumID)
allowed, err = s.ratelimiter.AllowRequest(ctx, userQuorumKey, blobRateMultiplier, rates.PerUserUnauthBlobRate)
allowed, err = s.ratelimiter.AllowRequest(ctx, userQuorumKey, blobRateMultiplier, accountRates.BlobRate)
if err != nil {
return fmt.Errorf("ratelimiter error: %v", err)
}
if !allowed {
s.logger.Warn("account blob ratelimit exceeded", "userQuorumKey", userQuorumKey, "rate", float32(rates.PerUserUnauthBlobRate)/blobRateMultiplier)
return errAccountRateLimit
s.logger.Warn("account blob ratelimit exceeded", "userQuorumKey", userQuorumKey, "rate", float32(accountRates.BlobRate)/blobRateMultiplier)
return errAccountBlobRateLimit
}

// Update the quorum rate
blob.RequestHeader.SecurityParams[param.QuorumID].QuorumRate = rates.PerUserUnauthThroughput
blob.RequestHeader.SecurityParams[i].QuorumRate = accountRates.Throughput
}
return nil

Expand Down
Loading
Loading